In [1]:
%matplotlib notebook
import numpy as np
import sigpy as sp
import sigpy.plot as pl
import cupy as cp
import scipy.io
import math
from scipy.sparse import csc_matrix
import numpy.matlib
import matplotlib.pyplot as plt

In [2]:
cp.cuda.Device(2).use()

In [3]:
# load the cardiac data
cardiac_images = scipy.io.loadmat("images.mat")
imgs = cardiac_images["LplusS"]
imgs=imgs[87:183,65:161,:]
imgs = imgs.transpose(2,0,1)
imgs = cp.array(imgs)

In [4]:
C0 = imgs[0,:,:][None,:,:]/10

In [5]:
M0 = np.matlib.repmat(np.eye(32*32,dtype=np.complex),25*24,1).reshape(25,24,1024,1024)

In [6]:
def M_forward(M,c,gpu = False):
    if gpu:
        xp = cp
    else:
        xp = np
    return xp.matmul(M,c)

def C_forward(C,m,gpu = False):
    if gpu:
        xp = cp
    else:
        xp = np
    return xp.matmul(m,C)

def M_adjoint(M,x,gpu = False):
    if gpu:
        xp = cp
    else:
        xp = np
    return xp.matmul(M.T.conj(),x)

def C_adjoint(C,x,gpu = False):
    if gpu:
        xp = cp
    else:
        xp = np
    return xp.matmul(x,C.T.conj())



def soft_thresh_complex(x, l,gpu = False):
    if gpu:
        xp = cp
    else:
        xp = np
    return xp.sign(abs(x)) * xp.maximum(xp.abs(x) - l, 0.)*xp.exp(1j*xp.angle(x))


In [7]:
def R_forward(im,patch_no,patch_size,stride_length):
    [frames,n,m]=im.shape
    n_patch_per_side=math.floor((n-patch_size)/stride_length)+1
    row_patch_no=math.floor((patch_no-1)/n_patch_per_side)
    column_patch_no=(patch_no-1)%n_patch_per_side
    crop=im[:,int(row_patch_no*stride_length):int(row_patch_no*stride_length+patch_size),int(column_patch_no*stride_length):int(column_patch_no*stride_length+patch_size)]
    crop=crop.reshape((frames*patch_size*patch_size,1))
    return crop
def R_adjoint(crop,patch_no,im_size,im_frames,stride_length):
    [a,b]=crop.shape
    patch_size=int((a/im_frames)**(1/2))
#     print(patch_size)
    crop=crop.reshape((im_frames,patch_size,patch_size))
    n_patch_per_side=math.floor((im_size-patch_size)/stride_length)+1
    row_patch_no=math.floor((patch_no-1)/n_patch_per_side)
    column_patch_no=(patch_no-1)%n_patch_per_side
    padded=cp.zeros((im_frames,im_size,im_size),dtype=cp.complex)
    padded[:,int(row_patch_no*stride_length):int(row_patch_no*stride_length+patch_size),int(column_patch_no*stride_length):int(column_patch_no*stride_length+patch_size)]=crop
    return padded


def powermethods(matrix_A,iterations = 10):
    max_eig=0
    for i in range(iterations):
        c = cp.random.rand(32*32,1)
        frwrd=M_forward(matrix_A,c,gpu = True)
        mag_c=cp.dot(cp.transpose(c),c)
        eig=cp.dot(cp.transpose(frwrd),frwrd)/mag_c
        eig = eig[0][0]
#         print(eig)
        max_eig=max(eig,max_eig)
#         print(max_eig)
        
    return float(abs(max_eig))

def modelCforward(M,C):
    # C shape: 1,256,256
    # M shape: 225,24,1032,1032
    patchnum = 25
    result=cp.zeros((24,96,96),dtype=cp.complex)
    for i in range(patchnum):
        M_current=cp.array(M[i,:,:,:].reshape(24*1024,1024))
        patch=R_forward(C,patch_no=i+1,stride_length=16,patch_size=32)
        res = M_forward(M=M_current,c=patch,gpu=True)
#         print(R_adjoint(crop=res,im_frames=24,im_size=256,patch_no=i+1,stride_length=16) )
        result+=R_adjoint(crop=res,im_frames=24,im_size=96,patch_no=i+1,stride_length=16) 
    return result
def modelCadjoint(Im,M):
    # Im shape: (24,256,256)
    # M shape: 225,24,1032,1032
    patch_c = cp.zeros((225,96,96),dtype=cp.complex)
    for i in range(25):
        patch = R_forward(Im,patch_no=i+1,patch_size=32,stride_length=16)
        M_patch = cp.array(M[i,:,:,:].reshape(24*1024,1024))
        c_est = M_adjoint(M_patch,patch,gpu=True)
        radj = R_adjoint(c_est,im_frames=1,patch_no=i+1,stride_length=16,im_size=96)  
        patch_c[i,:,:] = radj.squeeze()
    return patch_c.sum(axis=0)[None,:,:]
def modelMadjoint(Im,C):
    # C shape: (1,256,256)
    # Im shape: (24,256,256)
#     patch_m = np.zeros((225,24,1024,1024),dtype=cp.complex)
    for i in range(25):
        patch = R_forward(Im,patch_no=i+1,patch_size=32,stride_length=16)
        C_patch = R_forward(C,patch_no=i+1,patch_size=32,stride_length=16)
        cadj = C_adjoint(C_patch,patch,gpu=True).reshape(24,1024,1024)
        patch_m[i,:,:,:] = cp.asnumpy(cadj)

In [8]:
patch_m = np.zeros((25,24,1024,1024),dtype=cp.complex)

In [10]:
def ista(C1,M1,L,lamda,x,gpu = False):
    if gpu:
        xp = cp
    else:
        xp = np
    C0=C1
    M0=M1
    converge = []
    print(xp.linalg.norm(modelCforward(M0,C0)-x))
#     patch_m = cp.zeros((225,256,256),dtype=cp.complex)
    for i in range(40):
        print('Iteration no:'+str(i))
        for t in range(10):
            print(t)
            modelMadjoint(modelCforward(M0,C0)-x,C0)
            M0 = M0 - (1/100)*patch_m
#             print(C_adjoint(C0,C_forward(C0,M0,gpu = gpu)-x,gpu = gpu))
            #M01 = soft_thresh_complex(M01,lamda/10000,gpu = gpu)
            converge.append(xp.linalg.norm(modelCforward(M0,C0)-x))
            print(xp.linalg.norm(modelCforward(M0,C0)-x),"ekin")
        for j in range(10):
            print(j)
            C01 = C0 - (1/L)*modelCadjoint(modelCforward(M0,C0)-x,M0)
#             C01 = soft_thresh_complex(w,lamda/L,gpu = gpu)
            converge.append(xp.linalg.norm(modelCforward(M0,C01)-x))
            C0 = C01
            print(xp.linalg.norm(modelCforward(M0,C0)-x),"ekin_1")
            

    return M0,C0,converge

In [11]:
M1,C1,converge = ista(C0,M0,1000,0.01,imgs,gpu=True)

62.31955061234166
Iteration no:0
0
61.16788609765075 ekin
1
60.047419104218115 ekin
2


KeyboardInterrupt: 

In [11]:
def modelCforward(M,C):
    # C shape: 1,256,256
    # M shape: 225,24,1032,1032
    patchnum = 225
    result=cp.zeros((24,256,256),dtype=cp.complex)
    for i in range(patchnum):
        M_current=cp.array(M[i,:,:,:].reshape(24*1024,1024))
        patch=R_forward(C,patch_no=i+1,stride_length=16,patch_size=32)
        res = M_forward(M=M_current,c=patch,gpu=True)
#         print(R_adjoint(crop=res,im_frames=24,im_size=256,patch_no=i+1,stride_length=16) )
        result+=R_adjoint(crop=res,im_frames=24,im_size=256,patch_no=i+1,stride_length=16) 
    return result
def modelCadjoint(Im,M):
    # Im shape: (24,256,256)
    # M shape: 225,24,1032,1032
    patch_c = cp.zeros((225,256,256),dtype=cp.complex)
    for i in range(225):
        patch = R_forward(Im,patch_no=i+1,patch_size=32,stride_length=16)
        M_patch = cp.array(M[i,:,:,:].reshape(24*1024,1024))
        c_est = M_adjoint(M_patch,patch,gpu=True)
        radj = R_adjoint(c_est,im_frames=1,patch_no=i+1,stride_length=16,im_size=256)  
        patch_c[i,:,:] = radj.squeeze()
    return patch_c.sum(axis=0)[None,:,:]
def modelMadjoint(Im,C,patch_m):
    # C shape: (1,256,256)
    # Im shape: (24,256,256)
#     patch_m = np.zeros((225,24,1024,1024),dtype=cp.complex)
    for i in range(225):
        patch = R_forward(Im,patch_no=i+1,patch_size=32,stride_length=16)
        C_patch = R_forward(C,patch_no=i+1,patch_size=32,stride_length=16)
        cadj = C_adjoint(C_patch,patch,gpu=True).reshape(24,1024,1024)
        patch_m[i,:,:,:] = cp.asnumpy(cadj)
    return patch_m

In [109]:
FM = modelCforward(M1,C1)

In [119]:
plt.figure()
plt.plot(converge)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f0ce5b6f160>]

In [116]:
pl.ImagePlot(cp.asnumpy(FM))

<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7f0cdcecf550>

In [118]:
cp.save("M_10b10.npy",M1)
cp.save("C_10b10.npy",C1)
cp.save("converge_10b10.npy",converge)

In [28]:
def calc_vector_patch(weights,index1,index2):
    [n,m]=weights.shape
    vector=np.zeros((2,1))
    for i in range(n):
        for j in range(n):
            vector[0]+=(index1-i)*abs(weights[i,j])
            vector[1]+=(index2-j)*abs(weights[i,j])
      
    return vector

In [98]:
vector_maps=np.zeros((96*96*24,2),dtype=cp.complex)
M_new=M1.reshape((25,24,1024,32,32))
b=0
for p in range(2):
    print('p='+str(p))
    patch_vmap=np.zeros((2,24,32*32))
    for i in range(24):
        print('i='+str(i))
        for j in range(32*32):
            a=j%32
            vect=calc_vector_patch(M_new[p,i,j,:,:],a,b)
            if np.linalg.norm(vect)>=0:
                patch_vmap[:,i,j]=vect.reshape(2,)
            if a==31 & b==31:
                b=0
            elif a==31:
                b+=1
    patch_vmap=cp.array(patch_vmap.reshape(24*2*32*32,1))
    patch_vmap=R_adjoint(patch_vmap,p+1,96,24*2,16)
    patch_vmap=patch_vmap.reshape((2,24*96*96)).transpose((1,0))
    vector_maps+=cp.asnumpy(patch_vmap)
        

p=0
i=0
i=1
i=2
i=3
i=4
i=5
i=6
i=7
i=8
i=9
i=10
i=11
i=12
i=13
i=14
i=15
i=16
i=17
i=18
i=19
i=20
i=21
i=22
i=23
p=1
i=0
i=1
i=2
i=3
i=4
i=5
i=6
i=7
i=8
i=9
i=10
i=11
i=12
i=13
i=14
i=15
i=16
i=17
i=18
i=19
i=20
i=21
i=22
i=23


In [99]:
maps=vector_maps.reshape((24,96,96,2))

In [104]:
maps[15,:,:,:]

array([[[-1.98768506+0.j, -1.65356502+0.j],
        [ 0.03427376+0.j, -1.84781644+0.j],
        [-0.67925277+0.j, -4.48785361+0.j],
        ...,
        [ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j]],

       [[-2.68573301+0.j, -0.31477675+0.j],
        [-3.84614004+0.j, -3.16466275+0.j],
        [-1.74132194+0.j, -3.38653685+0.j],
        ...,
        [ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j]],

       [[-2.43548343+0.j,  1.68291075+0.j],
        [-1.83469535+0.j,  0.3590956 +0.j],
        [-1.22222634+0.j, -0.99324064+0.j],
        ...,
        [ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j]],

       ...,

       [[ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j],
        [ 0.        +0.j,  0.        +0.j],
        ...,
     

In [107]:
test=maps[1,:,:,:]
plt.figure()
plt.quiver(-abs(test[:,:,1]),abs(test[:,:,0]),scale=1000, scale_units='inches')

<IPython.core.display.Javascript object>

<matplotlib.quiver.Quiver at 0x7f0ce5e1dcc0>

In [77]:
fig=plt.figure()
columns = 4
rows = 6
maps=vector_maps.reshape((24,96,96,2))
for i in range(1,25):
    img = FM[:,:,i-1]
    fig.add_subplot(rows, columns, i)
    #plt.imshow(img.T,aspect='auto')
    test=maps[i-1,:,:,:]
    plt.quiver(-test[:,:,1],test[:,:,0],scale=700, scale_units='inches')
plt.show()

<IPython.core.display.Javascript object>