In [1]:
import scipy.io
import numpy as np
import tensorflow as tf
import scipy.linalg as la
from scipy.sparse import csr_matrix, kron, csc_matrix, block_diag

from torch import tensor
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import os

In [2]:
mat = scipy.io.loadmat('C:\\IOT\\ADMM\\mask.mat')
mask_3d = mat['mask_3D'] #(256, 256, 172)
mat2 = scipy.io.loadmat('C:\\IOT\\ADMM\\Ottawa.mat')
x3d_ref = mat2['X3D_ref']  #(256, 256, 172)
mat3 = scipy.io.loadmat('C:\\IOT\\ADMM\\X3DL.mat')
x3d_dl = mat3['X3D_DL'] #(256, 256, 172)
x3d = x3d_ref * mask_3d
t1 = time.time()
# Define parameters 
N = 10
lambda_ = 0.01
mu = 0.001

In [3]:
# Compute s_dl
def compute_basis(x3d, n):
    x = np.reshape(x3d,(x3d.shape[0]*x3d.shape[1],x3d.shape[2]), order='F')
    xt = np.transpose(x)
    m = xt.shape[0]
    u = xt
    c = u @ np.transpose(u)
    D, V = np.linalg.eig(c)
    ind = np.argsort(D)
    D = D[ind]
    D = np.diag(D)
    V = V[:,ind]
    E = V[:,m-n:]
    return E

def compute_s_dl(x3d):
    row, colum, channel = x3d.shape
    rsize = row * colum
    x2d_dl = np.reshape(x3d, (172, rsize), order='F')
    e_dl = compute_basis(x3d_dl, N)
    s_dl = np.transpose(e_dl) @ x2d_dl
    return s_dl, e_dl

In [4]:
s_dl, e_dl = compute_s_dl(x3d)
# ADMM
row, colum, channel = x3d.shape
rsize = row * colum
empty = np.zeros((channel, channel, rsize))
truerow, truecol, truechannel = np.where(mask_3d == 1)
empty[truechannel, truechannel, (truerow + row * truecol)] = 1
rp_tensor = np.einsum('kij, lk -> lij', tensor(empty), np.transpose(e_dl))  # (10, 172, 65536)
rrt_tensor = np.einsum('kij, li -> klj', tensor(rp_tensor), np.transpose(e_dl)) #(10, 10, 65536)
x2d = np.transpose(np.reshape(x3d,(65536, channel), order='F')) # (172, 65536)
rpy = np.zeros((10, rsize)) # (10, 65536)
for i in range(rsize):
    rpy[:,i] = rp_tensor[:,:,i] @ x2d[:,i]
rpy = np.reshape(rpy,(655360,1), order='F') 


In [5]:
RRtrps_per = np.transpose(rrt_tensor,(2,0,1)) # 65536 10 10
I = (mu/2) * np.eye(N, order='F')
S_left = np.zeros(RRtrps_per.shape)
for i in range(RRtrps_per.shape[0]):
    S_left[i,:,:] = np.linalg.inv(RRtrps_per[i,:,:] + I)
print(S_left.shape)
S_left = [csc_matrix(S_left[n,:,:]) for n in range(S_left.shape[0])]
S_left = block_diag(S_left)
print(S_left.shape)
S2D = np.zeros((N, rsize))
D = np.zeros((N,rsize))


(65536, 10, 10)
(655360, 655360)


In [6]:
for i in tqdm(range(0, 50)):
    Z = (1/(mu+lambda_))*(lambda_*s_dl+mu*(S2D-D))
    DELTA = (Z+D)
    delta = np.reshape(DELTA,(-1,1), order="F")
    s_right = rpy +  (mu/2)*delta
    s = S_left @ s_right
    S2D = np.reshape(s,(N,row*colum),order="F")
    D = D - S2D + Z 
X2D_rec = e_dl @ S2D  
X3D_rec = np.reshape(np.transpose(X2D_rec),(row,colum,channel),order="F")
scipy.io.savemat('C:\\IOT\\ADMM\\output2.mat', {'output2':X3D_rec})
t2 = time.time()
print(t2 - t1)

100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 23.41it/s]


3198.5085022449493
