In [None]:
# import necessary packages

import os
import torch
import pickle
import tables
import numpy as np
import torch.nn as nn
from torch import cat
from torch.fft import *
from scipy.io import savemat
from einops import rearrange
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import OrderedDict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

In [None]:
# load complementary undersampling cartesian trajectories

datapath = 'data/bruker/5_cu_96x96_cest'
imsize = [96, 96]

# The .mat files are read through PyTables (HDF5); `with` guarantees the
# file handles are closed once the arrays have been copied out.
with tables.open_file(os.path.join(datapath, 'MTCUNOffs.mat')) as h5:
    # number of complementary undersampling (CU) patterns
    nCU = int(np.array(h5.root.MTCUNOffs).squeeze())

with tables.open_file(os.path.join(datapath, 'MT_spatial_phase_1.mat')) as h5:
    PEs = np.array(h5.root.MT_spatial_phase_1)

# Map stored phase-encode positions to integer column indices.
# NOTE(review): assumes positions are normalized to [-1, 1); a value of
# exactly +1 would map to imsize[1], which is out of range — confirm.
PEs = np.round(PEs * imsize[1] / 2 + imsize[1] / 2).astype(np.int32).squeeze()
# Rotate right by one (the last PE line moves to the front), then use int64
# so the indices are valid for torch advanced indexing on every platform.
PEs = np.roll(PEs, 1).astype(np.int64)
# One index array per CU pattern; np.split raises if len(PEs) % nCU != 0.
PEs = np.split(PEs, nCU)

In [None]:
# define utility functions

def ifft2c(k):
    """Centered 2-D inverse FFT over the last two (spatial) dims.

    Applies ``ifftshift -> ifft2 -> fftshift`` on dims (-2, -1) only, so k-space
    with DC at the array centre maps to a centred image. Leading dims (batch,
    frame, ...) are not shifted. This is the exact inverse of ``fft2c`` for even
    and odd image sizes alike.

    Args:
        k: complex tensor whose last two dims are the k-space grid.

    Returns:
        Complex tensor of the same shape in image space.
    """
    # ifftshift must come first and fftshift last: for odd sizes the two shifts
    # differ, and the reversed order does not round-trip with fft2c. Limiting
    # the shifts to the spatial axes leaves batch/frame axes untouched.
    return fftshift(ifft2(ifftshift(k, dim=(-2, -1))), dim=(-2, -1))

def fft2c(x):
    """Centered 2-D forward FFT over the last two (spatial) dims.

    Applies ``ifftshift -> fft2 -> fftshift`` on dims (-2, -1) only, so a centred
    image maps to k-space with DC at the array centre. Leading dims (batch,
    frame, ...) are not shifted.

    Args:
        x: complex (or real) tensor whose last two dims are the image grid.

    Returns:
        Complex tensor of the same shape in k-space.
    """
    # torch's fftshift/ifftshift default to *all* dims; restricting them to the
    # spatial axes avoids rolling batch/frame axes back and forth (those rolls
    # cancel, but each one costs a full copy of the tensor).
    return fftshift(fft2(ifftshift(x, dim=(-2, -1))), dim=(-2, -1))

def view_as_complex(x):
    """Fold stacked real/imag channels back into a complex tensor.

    Dim 1 is split into two halves with ``chunk``; the first half becomes the
    real part and the second half the imaginary part. This undoes
    ``view_as_real`` below. Unlike ``torch.view_as_complex``, it returns a new
    tensor rather than a view.
    """
    halves = x.chunk(2, dim=1)
    real_part, imag_part = halves[0], halves[1]
    return real_part + 1j * imag_part

def view_as_real(x):
    """Unfold a complex tensor into real channels ``[real, imag]`` along dim 1.

    The channel count doubles, and the result is cast to float32 so it can feed
    ``nn.Conv2d`` layers. Unlike ``torch.view_as_real``, which appends a
    trailing size-2 dim, this concatenates on dim 1.
    """
    parts = (x.real, x.imag)
    return torch.cat(parts, dim=1).to(torch.float32)

In [None]:
# define dataset

class dataset(Dataset):
    """Sliding-window dataset of retrospectively undersampled frames.

    Every subject ``<name>`` in ``datapath`` is a pair of files: ``<name>.npy``,
    a stack of frames, and ``<name>.pkl``, a pickled sequence of per-frame
    offsets. The frames are presumably k-space, since the code applies
    ``ifft2c`` to get images. Frames whose first offset value is >= 100 are
    treated as M0 reference scans. They are transformed to image space and
    averaged into one image. All other frames stay in k-space as functional
    frames.

    A sample is a window of ``nCU`` consecutive functional frames. Frame ``i``
    of the window keeps only the columns listed in ``PEs[i]``.

    Args:
        datapath: directory holding the ``.npy``/``.pkl`` pairs.
        nCU: window length = number of complementary undersampling patterns.
        PEs: sequence of ``nCU`` index arrays of sampled columns.
    """

    def __init__(self, datapath, nCU, PEs):
        self.PEs = PEs
        self.nCU = nCU
        # Only .npy files define subjects (stray files are ignored). Sorting
        # makes the global index -> (subject, window) mapping reproducible;
        # iterating a set of str depends on hash randomization.
        names = sorted(
            stem for stem, ext in map(os.path.splitext, os.listdir(datapath))
            if ext == '.npy'
        )
        self.data_dict = {}
        for name in names:
            self.data_dict[name] = self._load_subject(datapath, name)

    @staticmethod
    def _load_subject(datapath, name):
        """Load one subject; return {'m0': averaged M0 image, 'fun': functional frames}."""
        data = torch.from_numpy(np.load(os.path.join(datapath, f'{name}.npy')))
        # NOTE(review): pickle.load can execute arbitrary code; only use this
        # dataset on trusted, locally produced files.
        with open(os.path.join(datapath, f'{name}.pkl'), 'rb') as f:
            offsets = pickle.load(f)
        m0imgs, funimgs = [], []
        for offset, img in zip(offsets, data):
            # NOTE(review): offsets >= 100 are taken to mark M0 scans — confirm units.
            if int(offset[0]) >= 100:
                m0imgs.append(ifft2c(img))
            else:
                funimgs.append(img)
        if not m0imgs or not funimgs:
            raise ValueError(
                f'{name}: expected both M0 (offset >= 100) and functional frames')
        return {
            'm0': torch.mean(torch.stack(m0imgs), 0, keepdim=True),
            'fun': torch.stack(funimgs),
        }

    def _num_windows(self, key):
        """Number of length-nCU windows over subject ``key`` (0 if too few frames)."""
        return max(0, len(self.data_dict[key]['fun']) - self.nCU + 1)

    def __len__(self):
        return sum(self._num_windows(key) for key in self.data_dict)

    def __getitem__(self, index):
        """Return ``(xu, xf, ku, m0)`` for the ``index``-th window.

        xu: zero-filled undersampled images (``ifft2c(ku)``).
        xf: fully sampled images (``ifft2c`` of the window's frames).
        ku: window frames with unsampled columns zeroed.
        m0: averaged M0 image with a leading singleton dim.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f'index out of range for dataset of length {total}')
        # Walk subjects in order until the global index falls inside one.
        for key in self.data_dict:
            nseq = self._num_windows(key)
            if index < nseq:
                break
            index -= nseq
        m0 = self.data_dict[key]['m0']
        kf = self.data_dict[key]['fun'][index:index + self.nCU]
        ku = []
        for i in range(self.nCU):
            # Keep only the columns sampled by pattern i; assumes each frame is
            # 2-D with the sampled axis at dim 1 — TODO confirm.
            mask = torch.zeros_like(kf[i], dtype=torch.bool)
            mask[:, self.PEs[i]] = True
            ku.append(kf[i] * mask)
        ku = torch.stack(ku)
        xf = ifft2c(kf)
        xu = ifft2c(ku)
        return xu, xf, ku, m0

In [None]:
# define the variational network (unrolled Transformer cascade with data consistency)

class Attention(nn.Module):
    """Multi-head self-attention across channels, plus a depthwise-conv position term.

    For each head, q and k are L2-normalised over the spatial positions. The
    attention map is (dim_head x dim_head), so it mixes feature channels
    rather than pixels, and is scaled by a learnable per-head ``rescale``. A
    3x3 depthwise convolution of the input is added to the projected output.
    Maps (b, dim, h, w) to (b, dim, h, w).
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        inner = dim_head * heads
        # q/k/v projections from dim to heads * dim_head (no bias)
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_k = nn.Linear(dim, inner, bias=False)
        self.to_v = nn.Linear(dim, inner, bias=False)
        # learnable per-head scale on the attention logits
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(inner, dim, bias=True)
        # depthwise 3x3 conv used as a positional term
        self.pos_emb = nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim)
        self.dim = dim

    def forward(self, x_in):
        b, c, h, w = x_in.shape
        n = h * w
        heads, d = self.num_heads, self.dim_head
        tokens = x_in.permute(0, 2, 3, 1).reshape(b, n, c)

        def split_heads(t):
            # (b, n, heads*d) -> (b, heads, d, n)
            return t.reshape(b, n, heads, d).permute(0, 2, 3, 1)

        q = F.normalize(split_heads(self.to_q(tokens)), dim=-1, p=2)
        k = F.normalize(split_heads(self.to_k(tokens)), dim=-1, p=2)
        v = split_heads(self.to_v(tokens))
        # (b, heads, d, d) channel-attention map
        scores = (k @ q.transpose(-2, -1)) * self.rescale
        weights = scores.softmax(dim=-1)
        # (b, heads, d, n) -> (b, n, heads*d)
        mixed = (weights @ v).permute(0, 3, 1, 2).reshape(b, n, heads * d)
        out_c = self.proj(mixed).view(b, h, w, c).permute(0, 3, 1, 2)
        return out_c + self.pos_emb(x_in)

class FeedForward(nn.Module):
    """Convolutional feed-forward: 1x1 expand -> GELU -> 3x3 depthwise -> GELU -> 1x1 project.

    Maps (b, dim, h, w) to (b, dim, h, w) with a hidden width of ``dim * mult``.
    No bias is used in any conv.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Conv2d(dim, hidden, 1, 1, bias=False),                       # pointwise expansion
            nn.GELU(),
            nn.Conv2d(hidden, hidden, 3, 1, 1, bias=False, groups=hidden),  # depthwise 3x3
            nn.GELU(),
            nn.Conv2d(hidden, dim, 1, 1, bias=False),                       # pointwise projection
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Transformer(nn.Module):
    """Residual transformer denoiser on complex images with an M0 side input.

    ``x`` and ``m0`` are complex. Each is unfolded into [real, imag] channels by
    ``view_as_real`` and concatenated on dim 1; the total must equal ``idim``.
    A 3x3 conv stem is followed by ``num_blocks`` residual
    (Attention, FeedForward) pairs and a 1x1 head to ``odim`` real channels.
    The head output is folded back to ``odim // 2`` complex channels.
    """

    def __init__(self, idim, hdim, odim, dim_head=64, heads=8, num_blocks=2):
        super().__init__()
        # stem: lift the real/imag input channels to hdim features
        self.convin = nn.Sequential(
            nn.Conv2d(idim, hdim, 3, 1, 1, bias=False),
            nn.GELU(),
        )
        # each entry is an [attention, feed-forward] pair used with residuals
        self.blocks = nn.ModuleList(
            nn.ModuleList([
                Attention(dim=hdim, dim_head=dim_head, heads=heads),
                FeedForward(dim=hdim),
            ])
            for _ in range(num_blocks)
        )
        # head: 1x1 conv to odim real channels
        self.convout = nn.Conv2d(hdim, odim, 1, 1, 0)

    def forward(self, x, m0):
        feats = torch.cat([view_as_real(x), view_as_real(m0)], 1)
        feats = self.convin(feats)
        for attn, ff in self.blocks:
            feats = attn(feats) + feats
            feats = ff(feats) + feats
        return view_as_complex(self.convout(feats))

class vn(nn.Module):
    """Unrolled network: ``niter`` residual Transformer stages with hard data consistency.

    After each stage, k-space entries where ``ku`` is non-zero are overwritten
    with the acquired values ``ku``. Entries where ``ku == 0`` are treated as
    unsampled and keep the network's estimate.
    """

    def __init__(self, idim, hdim, odim, niter):
        super().__init__()
        self.iters = nn.ModuleList([Transformer(idim, hdim, odim) for _ in range(niter)])

    def forward(self, x, ku, m0):
        # ku is not modified below, so the unsampled mask is loop-invariant.
        unsampled = ku == 0
        for stage in self.iters:
            x = x + stage(x, m0)
            x = ifft2c(fft2c(x) * unsampled + ku)
        return x

In [None]:
# define learning rate scheduler

class PolyScheduler(LambdaLR):
    """Polynomial decay: ``lr = base_lr * max(0, 1 - step / t_total) ** exponent``.

    The multiplier falls from 1 at step 0 to 0 at ``step == t_total`` and is
    held at 0 for any later step.

    Args:
        optimizer: wrapped optimizer.
        t_total: number of ``scheduler.step()`` calls over which lr decays to 0.
        exponent: power of the decay polynomial.
        last_epoch: passed through to ``LambdaLR``.
    """

    def __init__(self, optimizer, t_total, exponent=0.9, last_epoch=-1):
        self.t_total = t_total
        self.exponent = exponent
        super(PolyScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the lr multiplier for ``step`` (never negative)."""
        # Clamp before exponentiating: for step > t_total the base is negative,
        # and a negative float ** 0.9 is a Python complex, which would corrupt
        # the optimizer's learning rate.
        return max(0.0, 1 - step / self.t_total) ** self.exponent

In [None]:
# run training

# Seed torch so weight init and shuffling are reproducible across runs.
torch.manual_seed(0)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

trainset = dataset('data/processed/trains', nCU, PEs)
trainloader = DataLoader(trainset, batch_size=8, shuffle=True, drop_last=True)
# With drop_last=True a set smaller than one batch yields no batches at all,
# which would leave `xr` undefined for the final snapshot.
assert len(trainloader) > 0, 'training set is smaller than one batch'


def show_magnitude_pairs(top, bottom, cmap='gray'):
    """Plot |top[i]| in row 1 and |bottom[i]| in row 2 for i in range(nCU)."""
    for i in range(nCU):
        plt.subplot(2, nCU, i + 1)
        plt.imshow(abs(top[i]).detach().cpu().numpy(), cmap=cmap)
        plt.subplot(2, nCU, nCU + i + 1)
        plt.imshow(abs(bottom[i]).detach().cpu().numpy(), cmap=cmap)
    plt.show()


# take a snapshot of xu (row 1) and xf (row 2); print shapes incl. m0
xu, xf, ku, m0 = trainset[0]
print(xu.shape, xf.shape, ku.shape, m0.shape)
show_magnitude_pairs(xu, xf)

# hyper-parameters
lr, weight_decay, epochs = 1e-4, 1e-4, 60

net = vn(idim=2 * (1 + nCU), hdim=16, odim=2 * nCU, niter=6).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = PolyScheduler(optimizer, t_total=epochs)
criterion = nn.L1Loss()

global_step = 0  # TensorBoard x-axis; avoids shadowing the builtin `iter`
writer = SummaryWriter()
try:
    for epoch in range(epochs):
        for xu, xf, ku, m0 in trainloader:
            xu, xf, ku, m0 = xu.to(device), xf.to(device), ku.to(device), m0.to(device)
            xr = net(xu, ku, m0)
            loss = criterion(xr, xf)
            # backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            optimizer.step()
            # log
            writer.add_scalar('Loss/train', loss.item(), global_step)
            global_step += 1
        # the poly decay is stepped once per epoch, matching t_total=epochs
        scheduler.step()
finally:
    writer.close()

ckpt_path = 'variables/motr_cartesian/motr.pth'
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
torch.save(net.state_dict(), ckpt_path)

# take a snapshot of xr and xf (first sample of the last training batch)
show_magnitude_pairs(xr[0], xf[0])