In [45]:
%matplotlib notebook
import numpy as np
import cupy as cp
import sigpy.plot as pl
from scipy.sparse import csr_matrix
import pywt
# device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
def wavMask(dims, scale):
    """Build a per-pixel weighting mask of shape ``dims``.

    The mask starts as all ones. On each pass a top-left corner block is
    halved, and the block shrinks by a factor of two per pass, so entries
    closer to the top-left corner end up with smaller weights.

    Parameters
    ----------
    dims : sequence of two ints
        Shape of the mask to create.
    scale : number
        Reduces the number of halving passes; the pass count is
        ``int(min(round(log2(dims))) - scale + 2) // 2``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``dims`` holding the weights.
    """
    # Unpacking enforces that `dims` has exactly two entries.
    n_rows, n_cols = dims
    mask = np.ones(dims)
    log_sizes = np.round(np.log2(dims))
    n_passes = int(np.min(log_sizes) - scale + 2) // 2
    for step in range(n_passes):
        row_end = int(np.round(2 ** (log_sizes[0] - step)))
        col_end = int(np.round(2 ** (log_sizes[1] - step)))
        mask[:row_end, :col_end] /= 2
    return mask


def imshowWAV(Wim, scale=1):
    """Display the magnitude of ``Wim`` in grayscale, weighted by ``wavMask``.

    Parameters
    ----------
    Wim : array
        2-D array to display; its ``shape`` is passed to ``wavMask``.
    scale : number, optional
        Forwarded to ``wavMask`` (default 1).
    """
    # NOTE(review): `plt` is not imported in any visible cell -- confirm that
    # `import matplotlib.pyplot as plt` exists elsewhere in the notebook.
    magnitude = np.abs(Wim)
    weights = wavMask(Wim.shape, scale)
    plt.imshow(magnitude * weights, cmap=plt.get_cmap('gray'))

    
def coeffs2img(LL, coeffs):
    """Tile four sub-bands into one array.

    Layout of the result::

        [ LL | LH ]
        [ HL | HH ]

    Parameters
    ----------
    LL : array
        Block placed in the top-left corner.
    coeffs : sequence of three arrays
        ``(LH, HL, HH)``, placed top-right, bottom-left and bottom-right.

    Returns
    -------
    numpy.ndarray
        The four blocks joined into a single array.
    """
    top_right, bottom_left, bottom_right = coeffs
    upper_half = np.hstack((LL, top_right))
    lower_half = np.hstack((bottom_left, bottom_right))
    return np.vstack((upper_half, lower_half))


def unstack_coeffs(Wim):
    """Split an array into its four quadrants (inverse of ``coeffs2img``).

    Parameters
    ----------
    Wim : array
        Array whose first two dimensions are both even.

    Returns
    -------
    tuple
        ``(LL, [LH, HL, HH])`` where LL is the top-left quadrant, LH the
        top-right, HL the bottom-left and HH the bottom-right.
    """
    upper_half, lower_half = np.vsplit(Wim, 2)
    top_left, top_right = np.hsplit(upper_half, 2)
    bottom_left, bottom_right = np.hsplit(lower_half, 2)
    return top_left, [top_right, bottom_left, bottom_right]

    
def img2coeffs(Wim, levels=3):
    """Convert a tiled coefficient array back into a nested coefficient list.

    The array is split into quadrants by ``unstack_coeffs``; the top-left
    quadrant is then split again, for ``levels`` splits in total (at least
    one split is always performed).

    Parameters
    ----------
    Wim : array
        Tiled coefficient array, laid out as produced by ``coeffs2img``.
    levels : int, optional
        Number of decomposition levels to peel off (default 3).

    Returns
    -------
    list
        ``[LL, details_coarsest, ..., details_finest]`` where each
        ``details_*`` is the ``[LH, HL, HH]`` list from ``unstack_coeffs``.
    """
    approx, details = unstack_coeffs(Wim)
    finest_first = [details]
    for _ in range(levels - 1):
        approx, details = unstack_coeffs(approx)
        finest_first.append(details)
    # Collected finest-to-coarsest; emit coarsest-first after the approximation.
    return [approx] + finest_first[::-1]
    
    
def dwt2(im):
    """Compute a 3-level 2-D wavelet decomposition and tile it into one array.

    Uses ``pywt.wavedec2`` with the 'db4' wavelet in 'per' mode, then joins
    the coefficients level by level with ``coeffs2img`` so that the
    approximation band ends up in the top-left corner.

    Parameters
    ----------
    im : array
        Input image passed to ``pywt.wavedec2``.

    Returns
    -------
    numpy.ndarray
        Single array holding all coefficient bands.
    """
    approx, *detail_levels = pywt.wavedec2(im, wavelet='db4', mode='per', level=3)
    mosaic = approx
    for details in detail_levels:
        mosaic = coeffs2img(mosaic, details)
    return mosaic


def idwt2(Wim):
    """Invert ``dwt2``: rebuild an image from a tiled coefficient array.

    The array is unpacked into a 3-level coefficient list by ``img2coeffs``
    and reconstructed with ``pywt.waverec2`` ('db4' wavelet, 'per' mode).

    Parameters
    ----------
    Wim : array
        Tiled coefficient array in the layout produced by ``dwt2``.

    Returns
    -------
    numpy.ndarray
        Reconstructed image.
    """
    pyramid = img2coeffs(Wim, levels=3)
    image = pywt.waverec2(pyramid, wavelet='db4', mode='per')
    return image

In [48]:
# Show the shape of B.
# NOTE(review): in the visible cells B is only bound by the later
# `A,B,C = ista(...)` cell, so this cell fails on a fresh kernel run top to bottom.
B.shape

(256, 256)

In [57]:
# Sanity check: copy B to the host, run dwt2, and show the shape of the tiled
# coefficient array (recorded output matches B's shape, (256, 256)).
dwt2(cp.asnumpy(B)).shape

(256, 256)

In [14]:
def M_forward(M,c):
    """Return the matrix product of ``M`` and ``c`` (``M`` on the left), via ``cp.matmul``."""
    product = cp.matmul(M, c)
    return product
    

In [15]:
def C_forward(C,m):
    """Return the matrix product of ``m`` and ``C`` (``C`` on the right), via ``cp.matmul``."""
    product = cp.matmul(m, C)
    return product


In [16]:
def M_adjoint(M,x):
    """Return the conjugate transpose of ``M`` multiplied by ``x`` (adjoint of ``M_forward``)."""
    M_hermitian = cp.conj(M.T)
    return cp.matmul(M_hermitian, x)

In [17]:
def C_adjoint(C,x):
    """Return ``x`` multiplied by the conjugate transpose of ``C`` (adjoint of ``C_forward``)."""
    C_hermitian = cp.conj(C.T)
    return cp.matmul(x, C_hermitian)

In [21]:
#Verification
# Random complex operands on the GPU for checking the forward/adjoint helpers:
# M and x are (n*tdim, n), C is (n, n).
# NOTE(review): no RNG seed is set, so the values (and the outputs recorded
# below) change on every run. These names M, C, x are rebound by later cells.
n=256
tdim=24
M=cp.random.rand(n*tdim,n)+1j*cp.random.rand(n*tdim,n)
C=cp.random.rand(n,n)+1j*cp.random.rand(n,n)
x=cp.random.rand(n*tdim,n)+1j*cp.random.rand(n*tdim,n)

In [22]:
# Left-hand side of the adjoint identity: C^H (M^H x), with M^H x from M_adjoint.
# Compare by eye against the (M C)^H x cell that follows.
cp.dot(cp.conj(C.T),M_adjoint(M,x))

array([[384901.49535756-399208.87147394j,
        382990.23437094-402400.295204j  ,
        382330.29698954-403121.68916165j, ...,
        384334.97728483-399135.44771367j,
        384052.84207356-400144.17530657j,
        385455.46111125-401772.7642212j ],
       [391807.97307671-399140.65734638j,
        389820.65445562-402337.31794222j,
        389106.41497834-403088.26175801j, ...,
        390980.82184885-399028.70564089j,
        390830.19711142-400030.93854197j,
        392208.65019686-401576.18804863j],
       [405555.55895358-388345.07404408j,
        403536.85841716-391530.36320478j,
        402903.56127761-392345.27271657j, ...,
        404775.3351876 -388152.13913242j,
        404601.90716607-389306.80383992j,
        406217.54607615-390711.59033423j],
       ...,
       [385149.88279043-385503.73138422j,
        383267.96722142-388721.79745187j,
        382483.39938009-389239.87044631j, ...,
        384361.60290664-385370.4587125j ,
        384188.72486721-386433.10269813j,

In [23]:
# Right-hand side of the adjoint identity: (M C)^H x, with M C from C_forward.
# Should equal C^H (M^H x) from the preceding cell.
cp.dot(C_forward(C,M).T.conj(),x)

array([[384901.49535756-399208.87147394j,
        382990.23437094-402400.295204j  ,
        382330.29698954-403121.68916165j, ...,
        384334.97728483-399135.44771367j,
        384052.84207356-400144.17530657j,
        385455.46111125-401772.7642212j ],
       [391807.97307671-399140.65734638j,
        389820.65445562-402337.31794222j,
        389106.41497834-403088.26175801j, ...,
        390980.82184885-399028.70564089j,
        390830.19711142-400030.93854197j,
        392208.65019686-401576.18804863j],
       [405555.55895358-388345.07404408j,
        403536.85841716-391530.36320478j,
        402903.56127761-392345.27271657j, ...,
        404775.3351876 -388152.13913242j,
        404601.90716607-389306.80383992j,
        406217.54607615-390711.59033423j],
       ...,
       [385149.88279043-385503.73138422j,
        383267.96722142-388721.79745187j,
        382483.39938009-389239.87044631j, ...,
        384361.60290664-385370.4587125j ,
        384188.72486721-386433.10269813j,

In [58]:
def soft_thresh_complex_np(x, l):
    """Complex soft-thresholding on the host (NumPy).

    Shrinks the magnitude of every entry of ``x`` by ``l`` (clipping at zero)
    while keeping its phase.

    Parameters
    ----------
    x : numpy.ndarray or scalar
        Real or complex input.
    l : number or array
        Threshold subtracted from the magnitudes.

    Returns
    -------
    complex array or scalar
        ``sign(|x|) * max(|x| - l, 0) * exp(1j * angle(x))``.
    """
    magnitude = abs(x)
    shrunk = np.maximum(magnitude - l, 0.)
    phase = np.exp(1j * np.angle(x))
    return np.sign(magnitude) * shrunk * phase

In [25]:
def soft_thresh_complex(x, l):
    """Complex soft-thresholding on the GPU (CuPy).

    Shrinks the magnitude of every entry of ``x`` by ``l`` (clipping at zero)
    while keeping its phase. GPU counterpart of ``soft_thresh_complex_np``.

    Parameters
    ----------
    x : cupy.ndarray
        Real or complex input.
    l : number or cupy.ndarray
        Threshold subtracted from the magnitudes.

    Returns
    -------
    cupy.ndarray
        ``sign(|x|) * max(|x| - l, 0) * exp(1j * angle(x))``.
    """
    magnitude = abs(x)
    # Use cp.maximum rather than np.maximum so the whole expression stays in
    # CuPy and does not depend on NumPy dispatching the ufunc to CuPy.
    shrunk = cp.maximum(magnitude - l, 0.)
    phase = cp.exp(1j * cp.angle(x))
    return cp.sign(magnitude) * shrunk * phase

In [60]:
def ista(C1,M1,L,lamda,CS,x,N,n_inner=100):
    """Alternating ISTA for the factorisation ``x ~ M @ C``.

    Each outer iteration runs ``n_inner`` proximal-gradient steps on ``C``
    with ``M`` fixed, then ``n_inner`` steps on ``M`` with ``C`` fixed:

    * C-step: gradient step on the residual ``M @ C - x``, followed by
      soft-thresholding of the wavelet coefficients (``dwt2`` / ``idwt2``,
      computed on the host) with threshold ``CS / L``.
    * M-step: gradient step on the same residual, followed by entry-wise
      soft-thresholding of ``M`` with threshold ``lamda / L``.

    Parameters
    ----------
    C1, M1 : cupy.ndarray
        Initial guesses for ``C`` and ``M``.
    L : number
        Step-size divisor; both sub-problems use step ``1 / L``.
    lamda : number
        Threshold weight for the entries of ``M``.
    CS : number
        Threshold weight for the wavelet coefficients of ``C``.
    x : cupy.ndarray
        Data that ``M @ C`` is fitted to.
    N : int
        Number of outer (alternating) iterations.
    n_inner : int, optional
        Inner iterations per sub-problem (default 100, the value that was
        previously hard-coded).

    Returns
    -------
    tuple
        ``(M, C, converge)`` where ``converge`` is the list of residual norms
        ``||M @ C - x||`` (CuPy scalars) recorded after every inner step.

    Side effects
    ------------
    After every outer iteration the current residual norm is printed and
    ``M0.npy``, ``C0.npy`` and ``converge.npy`` are written to the current
    working directory as checkpoints.
    """
    C0=C1
    M0=M1
    converge = []
    for _ in range(N):
        # --- C-step: M0 is held fixed ---
        for _ in range(n_inner):
            C01 = C0 - (1/L)*M_adjoint(M0,M_forward(M0,C0)-x)
            # The wavelet transform runs on the host, so round-trip through NumPy.
            w = dwt2(cp.asnumpy(C01))
            C00 = soft_thresh_complex_np(w,CS/L)
            C01 = cp.asarray(idwt2(C00))
            converge.append(cp.linalg.norm(M_forward(M0,C01)-x))
            C0 = C01
        # --- M-step: C0 is held fixed ---
        for _ in range(n_inner):
            w = M0 - (1/L)*C_adjoint(C0,C_forward(C0,M0)-x)
            M01 = soft_thresh_complex(w,lamda/L)
            converge.append(cp.linalg.norm(M_forward(M01,C0)-x))
            M0 = M01
        print(cp.linalg.norm(M_forward(M0,C0)-x))
        cp.save("M0.npy",M0)
        cp.save("C0.npy",C0)
        # cp.save expects a CuPy array, not a Python list of CuPy scalars, so
        # stack the history into a 1-D array first (skipped while it is empty).
        if converge:
            cp.save("converge.npy",cp.stack(converge))

    return M0,C0,converge

In [27]:
import scipy.io 

In [28]:
# Load the MATLAB file 'images.mat' (path relative to the working directory).
image_data=scipy.io.loadmat('images.mat')

In [29]:
# Pick the 'LplusS' variable out of the loaded MATLAB file.
LPS_image = image_data['LplusS']

In [30]:
# Interactive view of the loaded image stack.
pl.ImagePlot(LPS_image)

<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7f207dc577b8>

In [31]:
# Confirm the stack dimensions (recorded output: (256, 256, 24)).
LPS_image.shape

(256, 256, 24)

In [32]:
import numpy.matlib

In [33]:
# Data matrix: move the last axis (24 frames) to the front and stack the
# frames vertically, giving a (256*24, 256) array.
x=LPS_image.transpose(2,0,1).reshape(256*24,256)
# Initial C: the frame at index 1 along the last axis.
C=LPS_image[:,:,1]
# Initial M: 24 vertically stacked 256x256 identity matrices, so that M @ C
# repeats C once per frame. np.tile replaces the deprecated
# np.matlib.repmat and produces the same (256*24, 256) array.
M=np.tile(np.eye(256,256),(24, 1))

In [35]:
# Product of the initial M and C (host arrays at this point).
# NOTE(review): `im` is not used by any visible later cell.
im = M.dot(C)

In [36]:
# Copy the operands to the GPU; the host arrays of the same names are replaced.
C = cp.array(C)
M = cp.array(M)
x = cp.array(x)

In [37]:
# Run the alternating ISTA solver: A is the recovered M, B the recovered C.
# NOTE(review): the `ista` defined above takes (C1, M1, L, lamda, CS, x, N).
# This call passes no CS value, so with that definition `x` binds to CS and
# the call raises TypeError for the missing `x`. The recorded output
# presumably comes from an earlier version of `ista` -- confirm and add CS.
# NOTE(review): the third result (the convergence history) is bound to `C`,
# replacing the input array of that name, so this cell is not safely re-runnable.
A,B,C = ista(C,M,500,0.01,x,N=100)

16.90745074420497
15.70669601751328
15.124172563983539
14.733457811977477
14.44227397211905
14.211827206242436
14.022022120888188
13.861167191426254
13.721926622427768
13.59932769324839
13.489863704957989
13.390996900207574
13.300844933568534
13.217967305340798
13.141265031429032
13.069837856855354
13.002912038048922
12.939895704574965
12.88029629115608
12.823724255547958
12.769846879780893
12.718348290262542
12.668997384790512
12.621560935022732
12.575836220921541
12.531702575357356
12.489013888093078
12.447671868239889
12.407587030894172
12.368670890533927
12.330830936433923
12.29396635924214
12.258033332218348
12.223038186231825
12.188934491437468
12.155687139872667
12.123271286261204
12.091649975326714
12.060808242106505
12.030712741662864
12.001341669107145
11.972669746752494
11.944672456447424
11.917340525835492
11.89060678409448
11.864454919875076
11.83884967291305
11.813781264076969
11.789211618968908
11.76511861694022
11.741490342888794
11.71830200400235
11.695526697744581
11.

In [38]:
# Show the original image stack again for reference.
pl.ImagePlot(LPS_image)

<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7f207d9560f0>

In [39]:
# Copy the image stack to the GPU (rebinds LPS_image to a CuPy array).
LPS_image = cp.array(LPS_image)

In [40]:
# Reconstruction: multiply the recovered factors and unstack the rows back
# into 24 frames of 256x256.
rec = A.dot(B).reshape(24,256,256)
# Original stack with the last axis moved to the front, to match `rec`.
org = LPS_image.transpose(2,0,1)

In [41]:
# Join reconstruction and original along the last axis so each frame shows
# them side by side (reconstruction first).
cod = cp.concatenate((rec,org),axis=2)

In [42]:
# Interactive view of the reconstruction/original comparison.
pl.ImagePlot(cod)

<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7f207d9a2128>

In [55]:
# Scratch check that cp.conj negates the imaginary part (1+1j -> 1-1j).
(cp.conj(1+1j))

array(1.-1.j)