# Parallel non-Cartesian Spatial-Temporal Dictionary Learning Neural Networks (stDLNN) for Accelerating 4D-MRI

**Author:** Zhijun Wang, Huajun She  
**Affiliation:** Shanghai Jiao Tong University  
**Email:** wzj@mriee.com   
**Date:** 2022/10/2 


## Requirements

In [None]:
import os, time, copy
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch #1.8
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchkbnufft as tkbn #1.1.0 

from torch.nn.utils import clip_grad_norm_
from torch.utils.checkpoint import checkpoint

## Settings

In [None]:
# Number of unrolled iterations; each iteration is one de-aliasing network
# (DN) followed by one data-consistency (DC) step.
M_I = 8

# Configuration of a single DN.
patch_1 = [2, 3, 3, 3]  # patch size: p_a, p_x, p_y, p_t
spars_1 = [2, 3, 3, 3]  # forwarded to Dict_4D as its sparse_size argument
H_L_1 = [128, 64, 32]   # hidden-layer widths of the CEM
N_I_1 = 3               # number of thresholding iterations inside the PDM

# Every iteration uses the same configuration. The entries of each list
# reference the same underlying object (no copies are made).
patch_list = [patch_1] * M_I
sparse_list = [spars_1] * M_I
H_L_list = [H_L_1] * M_I
N_I_list = [N_I_1] * M_I

In [None]:
# Data / reconstruction settings.
R = 25        # presumably the acceleration (undersampling) rate -- confirm
alpha = 0.56  # initial value of the learnable data-consistency step size

# Optimisation settings.
b = 1            # batch size
lr = 1e-3        # learning rate handed to Adam
num_epoch = 200

device = torch.device("cuda")

## Model

### CEM

Coefficient Estimation Module  

x.shape = $ (b,l_a, l_b,p_1 \times p_2  \times p_3  \times p_4) $  
L_list: list of layer widths, input width first and output width last (each pair of consecutive entries defines one linear layer)  
lam.shape = $ (b,l_a, l_b,1) $    

In [None]:
class CEM(nn.Module):
    """Coefficient Estimation Module: a small MLP applied over the last axis.

    ``L_list`` holds the layer widths from input to output; every pair of
    consecutive widths becomes one ``nn.Linear`` with bias. A ReLU follows
    each layer except the last one, whose raw output is returned as ``lam``.
    """

    def __init__(self, L_list):
        super().__init__()
        self.actf = nn.ReLU(inplace=True)
        # One linear layer per consecutive (in, out) pair of widths.
        self.Ls = nn.ModuleList(
            nn.Linear(n_in, n_out, bias=True)
            for n_in, n_out in zip(L_list, L_list[1:])
        )

    def forward(self, x):
        """Map ``x`` (features on the last axis) to the coefficients ``lam``."""
        head = self.Ls[-1]
        for hidden in self.Ls[:-1]:
            x = self.actf(hidden(x))
        # No activation on the output layer.
        return head(x)

### PDM

Patch De-aliasing Module  
  
x.shape = $ (b, l_a, l_b,p_1 \times p_2  \times p_3  \times p_4) $  
lam.shape = $ (b, l_a, l_b, 1) $  
psi.shape = $ (b, l_a, l_b, p_1 \times p_2  \times p_3  \times p_4) $    
   

In [None]:
from utils.dict4 import Dict_4D

In [None]:
def soft_thresh(x, l):
    """Element-wise soft-thresholding (shrinkage) of ``x`` by ``l``.

    Each magnitude is reduced by ``l`` and floored at zero; the sign of the
    corresponding entry of ``x`` is kept.
    """
    shrunk = torch.clamp(torch.abs(x) - l, min=0)
    return torch.sign(x) * shrunk

In [None]:
class PDM(nn.Module):
    """Patch De-aliasing Module: unrolled iterative soft-thresholding.

    Learnable state:
      * ``Dict`` -- 2-D dictionary, initialised from
        ``Dict_4D(patch_size, sparse_size)``.
      * ``zeta`` -- scalar step size, initialised to ``1 / (2 * s**2)`` where
        ``s`` is the matrix 2-norm of the initial dictionary.
    ``Diag`` is a fixed identity matrix (``requires_grad=False``) whose size
    equals the number of dictionary columns.
    """

    def __init__(self, N_I, patch_size, sparse_size):
        super().__init__()
        self.N_I = N_I

        atoms = torch.from_numpy(Dict_4D(patch_size, sparse_size)).float()
        _, n_atoms = atoms.shape  # also enforces that the dictionary is 2-D
        self.Dict = nn.Parameter(atoms)

        self.Diag = nn.Parameter(torch.eye(n_atoms), requires_grad=False)

        # NumPy converts the (CPU) tensor implicitly; ord=2 on a 2-D array is
        # the largest singular value.
        step = 1 / np.linalg.norm(atoms, ord=2) ** 2
        self.zeta = nn.Parameter(torch.FloatTensor((step / 2,)))

    def forward(self, x, lam):
        """Return the de-aliased patches ``psi`` for patches ``x``.

        ``lam`` is multiplied by ``zeta`` to form the threshold; it must
        broadcast against ``x.matmul(Dict)``.
        """
        gram = self.Dict.T.mm(self.Dict)
        S = self.Diag - 2 * self.zeta * gram
        t = 2 * self.zeta * x.matmul(self.Dict)
        theta = lam * self.zeta

        # One initial shrinkage, then N_I refinement steps.
        code = soft_thresh(t, theta)
        for _ in range(self.N_I):
            code = soft_thresh(code.matmul(S) + t, theta)

        # Map the coefficients back to patch space.
        return code.matmul(self.Dict.T)

### DN

De-aliasing Network  
  
x.shape = $ (b, N_a, N_x, N_y, N_t) $  
z.shape = $ (b, N_a, N_x, N_y, N_t) $  

In [None]:
from utils.fold4 import unFold, Fold

In [None]:
class DN(nn.Module):
    """De-aliasing Network: patch-wise de-aliasing with learned recombination.

    The input is split into patches (``unFold``), a per-patch coefficient is
    predicted by the CEM, the PDM de-aliases every patch, and the patches are
    merged back (``Fold``) as a weighted average. The per-position weights
    ``q`` are learnable and drawn from N(1, 0.1**2) at construction.
    """

    def __init__(self, patch_size, sparse_size, H_L, N_I):
        super().__init__()
        self.patch_size = patch_size

        patch_len = np.prod(patch_size)
        # CEM widths: flattened patch -> hidden layers -> one scalar per patch.
        self.cem = CEM([patch_len] + H_L + [1])
        self.pdm = PDM(N_I, patch_size, sparse_size)
        self.q = nn.Parameter(
            torch.normal(mean=1.0, std=1.0 / 10 * torch.ones(patch_len))
        )

    def forward(self, x):
        """De-alias ``x`` and return a tensor folded back to ``x.shape``."""
        full_shape = x.shape

        patches = unFold(x, self.patch_size)
        lam = self.cem(patches)

        # Weight every patch entry by q (in place).
        weighted = self.pdm(patches, lam)
        weighted.mul_(self.q)

        # Folding the weights alone yields the per-voxel normaliser.
        weights = torch.ones_like(weighted).mul_(self.q)

        z = Fold(weighted, full_shape, kernel_size=self.patch_size)
        z.div_(Fold(weights, full_shape, kernel_size=self.patch_size))
        return z

### DC

parallel non-Cartesian Data Consistency  
  
x.shape = $ (b, N_a, N_x, N_y, N_t) $  
x0.shape = $ (b, N_a, N_x, N_y, N_t) $  
smap.shape =  $ (1, N_c, N_x, N_y) $   
kern: Toeplitz-embedding filter kernels used by the Toeplitz NUFFT operator, one kernel per temporal frame   
xn.shape = $ (b, N_a, N_x, N_y, N_t) $  

In [None]:
class DC(nn.Module):
    """Parallel non-Cartesian data-consistency step.

    Computes ``z - alpha * (T(z) - x0)``, where ``T`` applies the Toeplitz
    NUFFT operator frame by frame. Inputs and output store the real and
    imaginary parts at indices 0 and 1 of axis 1; the frames are on the
    last axis.
    """

    def __init__(self):
        super().__init__()
        self.toep_ob = tkbn.ToepNufft()

    @staticmethod
    def _to_complex(t):
        """Move the last axis to position 1 and merge real/imag into complex."""
        t = t.permute((0, 4, 1, 2, 3)).contiguous()
        return t[:, :, 0, ...] + t[:, :, 1, ...] * 1j

    def multi_teop(self, z, smap, kerns):
        """Apply the Toeplitz operator to each frame of ``z`` (axis 1).

        Frame ``t`` is processed with ``kerns[t]``; the results are
        concatenated back along axis 1.
        """
        frames = [
            self.toep_ob(z[:, t:t + 1, ...], kern_t, smaps=smap)
            for t, kern_t in enumerate(kerns)
        ]
        return torch.cat(frames, dim=1)

    def forward(self, z, x0, smap, kern, alpha):
        """Return the data-consistent update of ``z`` in the input layout."""
        z_c = self._to_complex(z)
        x0_c = self._to_complex(x0)

        toep = self.multi_teop(z_c, smap, kern)
        updated = z_c - alpha * (toep - x0_c)

        # Split back into real/imag and restore the original axis order.
        parts = torch.stack([torch.real(updated), torch.imag(updated)], dim=2)
        return parts.permute((0, 2, 3, 4, 1))

### stDLNN

Parallel non-Cartesian Spatial-Temporal Dictionary Learning Neural Networks
  
x.shape = $ (b, N_a, N_x, N_y, N_t) $  
smap.shape =  $ (1, N_c, N_x, N_y) $   
kern: Toeplitz-embedding filter kernels used by the Toeplitz NUFFT operator, one kernel per temporal frame   

In [None]:
class stDLNN(nn.Module):
    """Unrolled network alternating de-aliasing (DN) and data consistency (DC).

    Args:
        device: accepted for interface compatibility; not used by the module.
        M_I: number of unrolled DN + DC iterations.
        patch_list, sparse_list, H_L_list, N_I_list: per-iteration DN
            settings; entry ``i`` configures the ``i``-th DN.
        init_alpha: initial value of the learnable DC step size ``alpha``,
            which is shared by all iterations.
    """

    def __init__(self, device, M_I, patch_list, sparse_list, H_L_list, N_I_list, init_alpha):
        super().__init__()
        self.M_I = M_I

        self.DNs = nn.ModuleList(
            [
                DN(patch_size=patch_list[i], sparse_size=sparse_list[i],
                   H_L=H_L_list[i], N_I=N_I_list[i])
                for i in range(self.M_I)
            ]
        )
        self.dc = DC()
        self.alpha = nn.Parameter(torch.FloatTensor([init_alpha]))

    def forward(self, x_und, smap, kern):
        """Reconstruct from the undersampled input ``x_und``.

        ``x_und`` is both the starting point of the iterations and the
        reference passed to every DC step. The caller's tensor is left
        unmodified.
        """
        # The checkpointed DN needs an input that requires grad, otherwise no
        # gradient reaches its parameters. Use a detached alias (shares memory
        # with x_und) instead of flipping the flag on the caller's tensor:
        # that would make every backward pass accumulate an input-sized .grad
        # on each training sample that nothing ever clears.
        if x_und.requires_grad:
            x = x_und
        else:
            x = x_und.detach().requires_grad_()

        for i in range(self.M_I):
            x = checkpoint(self.DNs[i], x)  # gradient checkpointing
            x = self.dc(x, x_und, smap, kern, self.alpha)

        return x


## Training

In [None]:
# Build the network and move it to the GPU.
stdlnn = stDLNN(
    device,
    M_I,
    patch_list,
    sparse_list,
    H_L_list,
    N_I_list,
    alpha,
).cuda()

# Mean-squared-error loss, optimised with Adam.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(stdlnn.parameters(), lr=lr, betas=(0.5, 0.999))

im_u_set: training set (undersampled)  
gnd_set:  training set (ground truth)  
smap_set: sensitivity maps of training set   
kern: Toeplitz filter kernels computed with [tkbn.calc_toeplitz_kernel](https://torchkbnufft.readthedocs.io/en/stable/generated/torchkbnufft.calc_toeplitz_kernel.html); the weights passed to it account for the density compensation of the radial trajectory and for the acceleration rate       

In [None]:
def train_epoch():
    """Run one optimisation pass over the training set; return the mean loss.

    Reads the module-level ``im_u_set``, ``gnd_set``, ``smap_set``, ``kern``,
    ``stdlnn``, ``criterion`` and ``optimizer``. Gradients are clipped to a
    total norm of 1e-4 before every optimiser step.

    Raises:
        ZeroDivisionError: if the training set yields no samples.
    """
    losses = []

    for im_u, gnd, smap in zip(im_u_set, gnd_set, smap_set):
        optimizer.zero_grad()
        loss = criterion(stdlnn(im_u, smap, kern), gnd)
        loss.backward()
        clip_grad_norm_(stdlnn.parameters(), 1e-4)
        optimizer.step()

        losses.append(loss.item())

    return sum(losses) / len(losses)