In [None]:
# generate multiple spiral trajectories
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal.windows import gaussian as gausswin

im_size = 64
spokelength = im_size * 2
ntrajs = 8

# equally sample high frequency and low frequency:
# the angular increments are the falling half of a Gaussian window, accumulated into the angle
theta = gausswin(spokelength*2, std=im_size)[spokelength:]
theta = np.cumsum(theta)
theta = theta / theta.max() * 4*np.pi    # range of theta determines number of circles: 2*pi forms one circle

# spiral curve with radius following gaussian density
radius = np.linspace(0, np.pi, len(theta))  # range of radius determines how far lines can go
k0 = radius * np.exp(1j * theta) # a single spiral (complex samples: real part = kx, imaginary part = ky)

ktrajs = []
# generate multiple spirals with different angles
cirp = 2j*np.pi * np.arange(im_size)/im_size   # im_size equally spaced rotation angles, as exponents
for i in range(ntrajs):
    # rotation angles of the interleaves that make up trajectory i
    itlvp = cirp[range(i, im_size, im_size//ntrajs)]
    # outer product: every sample of the base spiral rotated by every interleave angle
    ktraj = np.matmul(k0[:, np.newaxis], np.exp(itlvp[np.newaxis, :]))
    # convert k-space trajectory to tensor of shape (2, spokelength * number of interleaves), ky first
    kx, ky = ktraj.real, ktraj.imag
    ktraj = np.stack((ky.flatten(), kx.flatten()), axis=0)
    ktraj = torch.from_numpy(ktraj)
    ktrajs.append(ktraj)

# plot kspace trajectory: kx / ky still hold the last generated trajectory, one column per interleave,
# so iterate over the columns that actually exist rather than over the number of trajectories
for j in range(kx.shape[1]):
    plt.plot(kx[:, j], ky[:, j])
plt.axis('equal')
plt.title('kspace trajectory')
plt.show()

In [None]:
# turn every spiral trajectory into a binary sampling mask on the image grid (used for data consistency)
grid_max = im_size - 1
mask_list = []
for traj in ktrajs:
    # map both coordinate rows from [-pi, pi] onto pixel positions in [0, im_size-1]
    pixels = (traj / np.pi + 1) / 2 * grid_max
    # nearest grid point (ties go to the even index, same as the builtin round)
    first_idx = torch.round(pixels[0]).long()
    second_idx = torch.round(pixels[1]).long()
    mask = torch.zeros([im_size, im_size])
    mask[first_idx, second_idx] = 1.
    mask_list.append(mask)
# add a leading batch axis: (1, number of trajectories, im_size, im_size)
masks = torch.stack(mask_list)[None]

In [None]:
# define some utility functions
from torch.fft import *

def ifft2c(k):
    """Centered 2-D inverse FFT over the last two dimensions.

    The zero-frequency sample of ``k`` is expected at index ``n // 2`` of each transformed
    axis (the layout produced by ``fft2c``); the returned image is centered the same way.
    Leading (batch / channel) dimensions are left untouched.

    :param k - complex tensor of shape (..., H, W) holding centered k-space
    :return complex tensor of the same shape holding the centered image
    """
    axes = (-2, -1)
    # undo the centering, transform, then center the image again; this order makes
    # ifft2c the exact inverse of fft2c for odd sizes as well as even ones
    return fftshift(ifft2(ifftshift(k, dim=axes), dim=axes), dim=axes)

def fft2c(x):
    """Centered 2-D forward FFT over the last two dimensions.

    The input image is treated as centered (origin at index ``n // 2``) and the result has its
    zero-frequency sample at index ``n // 2`` of each transformed axis. Leading (batch / channel)
    dimensions are left untouched.

    :param x - real or complex tensor of shape (..., H, W)
    :return complex tensor of the same shape holding centered k-space
    """
    axes = (-2, -1)
    # shift only the two transformed axes instead of every axis of the tensor
    return fftshift(fft2(ifftshift(x, dim=axes), dim=axes), dim=axes)

In [None]:
# setup seed
import torch
import random
def setup_seed(seed):
    """Seed every random number generator used in this notebook and ask cuDNN for deterministic behaviour.

    :param seed - integer seed shared by Python's random module, NumPy and PyTorch (CPU and all GPUs)
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # the cuDNN auto-tuner may pick a different algorithm from run to run, so switch it off as well
    torch.backends.cudnn.benchmark = False
setup_seed(12345)
# set num threads
torch.set_num_threads(8)    # too many threads may kill python kernel

In [None]:
# define dataset
import os
import torch
import pickle
import numpy as np
import torchkbnufft as tkbn
from torch.utils.data import Dataset

class dataset(Dataset):
    """
    Sliding-window dataset pairing fully sampled frames with their spiral-undersampled reconstructions.

    Every <name>.npy / <name>.pkl pair found in data_path is loaded into memory at construction.
    One sample is a window of nimgs consecutive functional frames taken from a single file.
    The batched NUFFT call pairs frames with trajectories along the first axis, so nimgs is
    expected to equal len(ktrajs).

    :param data_path - folder holding paired <name>.npy (frames, treated as centered k-space) and <name>.pkl (offsets) files
    :param ktrajs - list of spiral trajectories, run the trajectory cell to obtain it
    :param im_size - image size handed to the NUFFT operators
    :param nimgs - number of consecutive functional frames per sample
    :return float32 zero-filled images zf, original images xf and the M0 image m0
    """
    def __init__(self, data_path, ktrajs, im_size, nimgs):
        # load all data into memory
        self.nimgs = nimgs
        self.ktrajs = torch.stack(ktrajs).cuda()
        # sorted: iterating a set of strings gives a different order on every interpreter start,
        # which would change the index -> sample mapping from run to run
        names = sorted({os.path.splitext(name)[0] for name in os.listdir(data_path)})
        self.data_dict = {}
        for name in names:
            self.data_dict[name] = {}
            data = torch.from_numpy(np.load(os.path.join(data_path, f'{name}.npy')))
            # NOTE(review): pickle.load can execute arbitrary code, only point data_path at trusted files
            with open(os.path.join(data_path, f'{name}.pkl'), 'rb') as f:
                offsets = pickle.load(f)
            # sort out M0 images and functional images
            m0imgs, funimgs = [], []
            for offset, img in zip(offsets, data):
                if int(offset[0]) == 25527: # offset at far end is regarded as m0
                    img = abs(ifft2c(img))
                    m0imgs.append(img)
                else:
                    funimgs.append(img)
            # M0 is stored as a magnitude image averaged over all M0 acquisitions, shape (1, H, W)
            self.data_dict[name]['m0'] = torch.mean(torch.stack(m0imgs), 0, keepdim=True)
            # functional frames stay in k-space until they are fetched
            self.data_dict[name]['fun'] = torch.stack(funimgs)
        # construct nufft object
        self.nufft = tkbn.KbNufft(im_size).cuda()
        self.adj = tkbn.KbNufftAdjoint(im_size).cuda()
        self.dcf = tkbn.calc_density_compensation_function(self.ktrajs, im_size).cuda()
    def _num_windows(self, key):
        """Number of windows of nimgs consecutive frames in one file (0 when the file is too short)."""
        return max(len(self.data_dict[key]['fun']) - self.nimgs + 1, 0)
    def __len__(self):
        return sum(self._num_windows(key) for key in self.data_dict)
    def __getitem__(self, index):
        total = len(self)
        if index < 0:
            index += total
        if index < 0 or index >= total:
            # without this an out-of-range index silently addressed the last file
            raise IndexError(f'sample index out of range for dataset of length {total}')
        # index corresponds which sample: walk the files until the window falls inside one
        for key in self.data_dict:
            nseq = self._num_windows(key)
            if index < nseq:
                break
            index -= nseq
        # fetch data
        m0 = self.data_dict[key]['m0']
        kf = self.data_dict[key]['fun'][index:index+self.nimgs].cuda()
        # in fft2/ifft2, last two dimensions are the default to be transformed
        xf = ifft2c(kf).unsqueeze(1)
        # simulate the spiral acquisition, then grid it back with density compensation
        ku = self.nufft(xf, self.ktrajs)
        zf = self.adj(ku*self.dcf, self.ktrajs)
        # normalize; only the singleton axis added above is dropped, so nimgs == 1 keeps its frame axis
        zf, xf = abs(zf).squeeze(1), abs(xf).squeeze(1)
        xf = xf / (m0.max() - m0.min())
        zf = zf / (zf.max() - zf.min())
        m0 = (m0 - m0.min()) / (m0.max() - m0.min())
        return zf.to(torch.float32), xf.to(torch.float32), m0.to(torch.float32)

In [None]:
# transformer with data consistency
import torch
import torch.nn as nn
from torch import cat
from einops import rearrange
import torch.nn.functional as F
from collections import OrderedDict

class Attention(nn.Module):
    """
    Multi-head self-attention computed between feature channels instead of spatial positions.

    For every head the attention map has size dim_head x dim_head; queries and keys are
    L2-normalised along the flattened spatial axis and the map is scaled by a learnable
    per-head factor. A depth-wise 3x3 convolution of the input is added as positional term.

    :param dim - number of input and output channels
    :param dim_head - channels per head
    :param heads - number of heads
    """
    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.dim = dim
        self.dim_head = dim_head
        self.num_heads = heads
        inner_dim = dim_head * heads
        # projections are created in a fixed order so seeded initialisation stays reproducible
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(inner_dim, dim, bias=True)
        self.pos_emb = nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim)
    def forward(self, x_in):
        batch, chans, height, width = x_in.shape
        npix = height * width
        # flatten the spatial grid into tokens: b, hw, c
        tokens = x_in.permute(0, 2, 3, 1).reshape(batch, npix, chans)

        def split_heads(t):
            # b, hw, heads*d -> b, heads, d, hw
            return rearrange(t, 'b n (h d) -> b h d n', h=self.num_heads)

        q = F.normalize(split_heads(self.to_q(tokens)), p=2, dim=-1)
        k = F.normalize(split_heads(self.to_k(tokens)), p=2, dim=-1)
        v = split_heads(self.to_v(tokens))
        # channel-to-channel similarity with one learnable scale per head: b, heads, d, d
        attn = torch.matmul(k, q.transpose(-2, -1)) * self.rescale
        attn = F.softmax(attn, dim=-1)
        # mix the value channels, then merge the heads again: b, heads, d, hw -> b, hw, heads*d
        mixed = torch.matmul(attn, v)
        mixed = mixed.permute(0, 3, 1, 2).reshape(batch, npix, self.num_heads * self.dim_head)
        # project back to dim channels and restore the b, c, h, w layout
        out_c = self.proj(mixed).view(batch, height, width, chans).permute(0, 3, 1, 2)
        # positional branch works on the untouched input
        out_p = self.pos_emb(x_in)
        return out_c + out_p

class FeedForward(nn.Module):
    """
    Convolutional feed-forward block: 1x1 expansion, GELU, 3x3 depth-wise conv, GELU, 1x1 projection.

    :param dim - number of input and output channels
    :param mult - expansion factor of the hidden width
    """
    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Conv2d(dim, hidden, kernel_size=1, stride=1, bias=False),
            nn.GELU(),
            # groups == channels: every hidden channel gets its own 3x3 filter
            nn.Conv2d(hidden, hidden, kernel_size=3, stride=1, padding=1, groups=hidden, bias=False),
            nn.GELU(),
            nn.Conv2d(hidden, dim, kernel_size=1, stride=1, bias=False),
        ]
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        out = self.net(x)
        return out

class Transformer(nn.Module):
    """
    Stack of residual blocks, each an Attention layer followed by a FeedForward layer.

    :param dim - number of channels, unchanged by the stack
    :param dim_head - channels per attention head
    :param heads - number of attention heads
    :param num_blocks - number of (Attention, FeedForward) pairs
    """
    def __init__(self, dim, dim_head=64, heads=8, num_blocks=2):
        super().__init__()
        pairs = [
            nn.ModuleList([
                Attention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim),
            ])
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(pairs)
    def forward(self, x):
        for pair in self.blocks:
            attention, feed_forward = pair
            # both sub-layers are wrapped in a skip connection
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class MoTR(nn.Module):
    """
    Multi-stage transformer reconstruction network with a hard data-consistency step.

    Every stage is embedding conv -> Transformer -> mapping conv. The mapping conv doubles the
    channel count, so stage i runs at dim * 2**i channels with dim channels per attention head.
    A final 1x1 conv maps back to nimgs channels.

    :param nimgs - number of frames reconstructed together; the network input has 1+nimgs channels (m0 and zf)
    :param dim - channel width of the first stage and per-head width of every stage
    :param stages - number of stages
    :param num_blocks - number of transformer blocks per stage, needs at least `stages` entries
    """
    def __init__(self, nimgs, dim=16, stages=2, num_blocks=(2, 2)):
        super().__init__()
        # fail up front with a clear message instead of an IndexError halfway through construction
        if len(num_blocks) < stages:
            raise ValueError(f'num_blocks needs one entry per stage: got {len(num_blocks)} entries for {stages} stages')
        self.nimgs = nimgs
        dim_stage, dim_in = dim, 1+nimgs
        self.stages = nn.ModuleList()
        # embedding - transformer - mapping
        for i in range(stages):
            self.stages.append(
                nn.Sequential(OrderedDict([
                    ('embedding', nn.Sequential(nn.Conv2d(dim_in, dim_stage, 3, 1, 1, bias=False), nn.GELU())),
                    ('transformer', Transformer(dim_stage, dim, dim_stage//dim, num_blocks[i])),
                    ('mapping', nn.Sequential(nn.Conv2d(dim_stage, dim_stage*2, 3, 1, 1, bias=False), nn.GELU()))
                ]))
            )
            # the mapping conv doubled the width; the next stage starts from there
            dim_stage *= 2
            dim_in = dim_stage
        # output
        self.convout = nn.Conv2d(dim_stage, nimgs, 1)
    def forward(self, zf, m0, k0):
        """
        :param zf - zero-filled images, concatenated with m0 along dim 1
        :param m0 - M0 image
        :param k0 - acquired k-space, exactly zero wherever nothing was sampled
        :return magnitude of the reconstruction after data consistency
        """
        x = cat([m0, zf], 1)
        for layer in self.stages:
            x = layer(x)
        xr = self.convout(x)
        # data consistency: keep the network's k-space only where k0 holds no sample
        unsampled = k0 == 0
        kr = fft2c(xr) * unsampled + k0
        xr = abs(ifft2c(kr))
        return xr

In [None]:
# construct training dataloader and execute training loop
import os
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

nimgs = 8
trainset = dataset('data/processed/trains', ktrajs=ktrajs, im_size=[64, 64], nimgs=nimgs)
trainloader = DataLoader(trainset, batch_size=8, shuffle=True, drop_last=True)  # shuffle permutes the indices of all samples through torch.randperm(n)

# take a snapshot of zf, xf, and m0
import matplotlib.pyplot as plt
zf, xf, m0 = trainset[nimgs]
plt.subplot(1, 3, 1)
plt.imshow(zf[0].cpu().numpy())
plt.subplot(1, 3, 2)
plt.imshow(xf[0].cpu().numpy())
plt.subplot(1, 3, 3)
plt.imshow(m0[0].cpu().numpy())
plt.show()

# hyper-parameters
lr, weight_decay, epochs = 1e-4, 1e-4, 120

net = MoTR(nimgs=nimgs, dim=16, stages=3, num_blocks=[3, 2, 1]).cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
criterion = nn.L1Loss()

# create the checkpoint folder before training so a missing folder cannot lose a finished run
checkpoint_path = 'variables/motr.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# the sampling masks never change: move them to the GPU once instead of once per iteration
masks_gpu = masks.to(device='cuda', dtype=torch.float32)

global_step = 0    # not called iter, which would shadow the builtin
writer = SummaryWriter()
try:
    for epoch in range(epochs):
        for zf, xf, m0 in trainloader:
            zf, xf, m0 = zf.cuda(), xf.cuda(), m0.cuda()
            # forward, you can always trust k0 since it's directly from scanner
            k0 = fft2c(xf) * masks_gpu
            xr = net(zf, m0, k0)
            loss = criterion(xr, xf)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # log
            writer.add_scalar('Loss/train', loss.item(), global_step)
            global_step = global_step + 1
        # learning rate decays once per epoch
        scheduler.step()
finally:
    # flush pending events to disk even when training is interrupted
    writer.close()

torch.save(net.state_dict(), checkpoint_path)

# take a snapshot of xr and xf (frame 4 of the first sample in the last batch)
plt.subplot(1, 2, 1)
plt.imshow(xr[0, 4].detach().cpu().numpy())
plt.subplot(1, 2, 2)
plt.imshow(xf[0, 4].detach().cpu().numpy())
plt.show()

In [None]:
# construct validation dataloader and execute validation loop
import torch
from torch.utils.data import DataLoader
from skimage.metrics import structural_similarity as ssim

nimgs = 8
validset = dataset('data/processed/valids', ktrajs=ktrajs, im_size=[64, 64], nimgs=nimgs)
validloader = DataLoader(validset, batch_size=1, shuffle=False, drop_last=False)

net = MoTR(nimgs=nimgs, dim=16, stages=3, num_blocks=[3, 2, 1]).cuda()
net.load_state_dict(torch.load('variables/motr.pth'))

net.eval()
net.requires_grad_(False)

# the sampling masks never change: move them to the GPU once instead of once per sample
masks_gpu = masks.to(device='cuda', dtype=torch.float32)

SSIMs, MAEs = [], []
# no autograd bookkeeping is needed for evaluation
with torch.no_grad():
    for zf, xf, m0 in validloader:
        zf, xf, m0 = zf.cuda(), xf.cuda(), m0.cuda()
        k0 = fft2c(xf) * masks_gpu
        xr = net(zf, m0, k0)
        # metrics are computed on frame 4 of every window
        xr, xf = xr[0, 4].cpu().numpy(), xf[0, 4].cpu().numpy()
        SSIMs.append(ssim(xr, xf, data_range=1.))
        MAEs.append(abs(xr-xf).mean())
print(f'average ssim of all slices: {np.array(SSIMs).mean()}')
print(f'average mae of all slices: {np.array(MAEs).mean()}')

# take a snapshot of reconstruction results (last validation sample): reconstruction, reference, error
import matplotlib.pyplot as plt
plt.imshow(xr, cmap='gray')
plt.show()
plt.imshow(xf, cmap='gray')
plt.show()
plt.imshow(abs(xr-xf))
plt.show()