# T2 Shuffling implemented in SigPy

This is a demo of T2 Shuffling, implemented using SigPy.
Currently, it uses the "slow" version of the normal equations, because SigPy does not support directly
supplying the normal operator to its iterative algorithms.

Run this with the same data found in https://github.com/jtamir/t2shuffling-support

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np


import sigpy as sp
import sigpy.plot as pl

import scipy.io


import time
import sys
import os.path
import joblib


import cfl


%matplotlib notebook
import matplotlib.pyplot as plt


In [2]:
def nrmse(x, y):
    """Relative error of ``y`` with respect to the reference ``x``.

    Computes ||x - y|| / ||x|| using the 2-norm of the flattened arrays, so
    arrays of any matching shape are accepted.
    """
    residual = x - y
    return np.linalg.norm(residual) / np.linalg.norm(x)

In [3]:
data_dir = 'data'

# Each input is read with the local `cfl` reader and squeezed; singleton axes are
# then re-inserted so the arrays line up as (ny, nz, coils, echoes-or-coefficients),
# matching the shapes printed below.
teksp = cfl.readcfl(os.path.join(data_dir, 'teksp')).squeeze()
cfksp = cfl.readcfl(os.path.join(data_dir, 'cfksp')).squeeze()
sens = np.expand_dims(cfl.readcfl(os.path.join(data_dir, 'sens')).squeeze(), axis=3)
pat = np.expand_dims(cfl.readcfl(os.path.join(data_dir, 'mask')).squeeze(), axis=2)
bas = cfl.readcfl(os.path.join(data_dir, 'bas')).squeeze()
cfimg = np.expand_dims(cfl.readcfl(os.path.join(data_dir, 'cfimg_recon')).squeeze(), axis=2)

print('\n'.join(
    '{}: {}'.format(label, arr.shape)
    for label, arr in (('teksp', teksp), ('cfksp', cfksp), ('sens', sens),
                       ('pat', pat), ('bas', bas), ('cfimg', cfimg))
))

# Echo k-space dimensions, and the number of subspace coefficients in the reference image.
ny, nz, nc, T = teksp.shape
K_full = cfimg.shape[3]

# Temporal basis restricted to its first K_full columns (real part only): shape (T, K_full).
Phi = bas[:, :K_full].real

teksp: (260, 240, 7, 80)
cfksp: (260, 240, 7, 4)
sens: (260, 240, 7, 1)
pat: (260, 240, 1, 80)
bas: (80, 80)
cfimg: (260, 240, 1, 4)


## Scan and recon params

In [4]:
def stkern_mat(bas, mask):
    """Precompute the per-k-space-location space-time kernel Phi^H P Phi.

    For every k-space location (y, z) this forms the K x K matrix

        kern[y, z, 0, k1, k2] = sum_t conj(bas[t, k1]) * mask[y, z, 0, t] * bas[t, k2]

    so the normal operator of ``P * Phi`` can be applied as one K x K matrix
    multiply per location instead of expanding to all T time points.

    Args:
        bas: (T, K) temporal basis matrix. Real or complex; the left factor is
            conjugated, which is a no-op for a real basis.
        mask: (ny, nz, 1, T) sampling pattern.

    Returns:
        (ny, nz, 1, K, K) kernel array.

    Raises:
        AssertionError: if the mask's time axis length differs from ``bas``'s.
    """
    T, K = bas.shape
    # Use the `mask` argument, not a notebook global, so the function is self-contained.
    ny, nz, _, nt = mask.shape
    assert nt == T, 'mask has {} time points but basis has {}'.format(nt, T)
    # (ny, nz, T, 1, 1) * (1, 1, T, 1, K): mask[t] * bas[t, k2]
    weighted = mask.reshape((ny, nz, T, 1, 1)) * bas.reshape((1, 1, T, 1, K))
    # Multiply by conj(bas[t, k1]) along axis 3 and sum over t -> (ny, nz, K, K).
    kern = np.sum(weighted * np.conj(bas).reshape((1, 1, T, K, 1)), axis=2)
    return np.expand_dims(kern, axis=2)


# Build the per-location K x K kernel once; it is reused by the fast normal operator.
print('Computing space time kernel')
tic = time.time()
stk = stkern_mat(Phi, pat)
toc = time.time()
print('Done: {}'.format(toc - tic), end='\n\n')

Computing space time kernel
Done: 0.7876341342926025



In [5]:
# Forward-model operators. Array layout: (ny, nz, coils-or-1, echoes-or-K, 1);
# the trailing singleton axis turns each voxel's echo/coefficient vector into a
# column so sp.linop.MatMul can apply a matrix per voxel.
# 2-D FFT over the two spatial axes.
F_op = sp.linop.FFT([*cfksp.shape, 1], axes = (0, 1))
# Coil-sensitivity weighting: (ny, nz, 1, K, 1) broadcast against sens -> (ny, nz, nc, K, 1).
S_op = sp.linop.Multiply([*cfimg.shape, 1], sens[...,None])
# Per-location K x K space-time kernel from stkern_mat, applied in k-space.
Psi_op = sp.linop.MatMul([*cfksp.shape, 1], stk)
# Temporal basis: K subspace coefficients -> T echoes at every voxel.
Phi_op = sp.linop.MatMul([*cfksp.shape, 1], Phi[None, None, None, ...])
# Sampling mask over k-space and echoes.
P_op = sp.linop.Multiply([*teksp.shape, 1], pat[...,None])
# "Fast" normal operator: Phi^H P Phi is collapsed into Psi because Phi acts along the
# echo axis and F along the spatial axes. Matching A^H A also needs P^H P == P, i.e.
# NOTE(review): assumes pat is a 0/1 mask -- confirm; the next cell checks numerically.
T2sh_nrml_op = S_op.H * F_op.H * Psi_op * F_op * S_op
# Full forward model A = P Phi F S: coefficient image -> sampled multi-coil echo k-space.
T2sh_op = P_op * Phi_op * F_op * S_op
print(T2sh_op.ishape, T2sh_op.oshape)

[260, 240, 1, 4, 1] [260, 240, 7, 80, 1]


In [6]:
# Apply the normal operator A^H A to the reference coefficient image in two ways and
# time each: via the precomputed space-time kernel, and by composing A^H with A.
print('Fast normal eq')
tic = time.time()
z1 = T2sh_nrml_op * cfimg[..., None]
toc = time.time()
print('Done: {}'.format(toc - tic), end='\n\n')

print('Slow normal eq')
tic = time.time()
z2 = (T2sh_op.H * T2sh_op) * cfimg[..., None]
toc = time.time()
print('Done: {}'.format(toc - tic), end='\n\n')

# Relative difference between the two results (should be at float-precision level).
print('Error: {}'.format(nrmse(z1, z2)))

Fast normal eq
Done: 0.42125701904296875

Slow normal eq
Done: 1.3517980575561523

Error: 9.636945463853408e-08


In [24]:
# Locally-low-rank (LLR) regularization parameters.
# W: spatial block width; blocks are W x W pixels.
# S: block stride; S == W gives non-overlapping blocks.
# lamda: regularization weight; multiplies the singular-value threshold.
W = 10
S = 10
lamda = .04

def llr_soft_thresh(lam, block_op, W, S, input):
    """Soft-threshold the singular values of every spatial block of ``input``.

    Args:
        lam: threshold applied to the singular values.
        block_op: array-to-blocks linear operator acting on ``input`` with all
            axes after the first two flattened into one.
        W: block width; each block is flattened to a matrix with ``W*W`` columns.
        S: block stride. Unused here because the stride is already baked into
            ``block_op``; kept for interface compatibility.
        input: array whose first two axes are spatial.

    Returns:
        Array with the same shape as ``input``, rebuilt from the thresholded blocks
        via ``block_op.H``.
    """
    n_x, n_y = input.shape[:2]
    blocks = block_op * input.reshape((n_x, n_y, -1))
    # One matrix per block position: remaining block axes (presumably the
    # coefficient axis -- TODO confirm against block_op's output layout) x W*W pixels.
    block_mats = blocks.reshape(blocks.shape[:2] + (-1, W * W))
    left, sing, right_h = np.linalg.svd(block_mats, full_matrices=False)
    shrunk = sp.soft_thresh(lam, sing)
    low_rank = np.matmul(left * shrunk[..., np.newaxis, :], right_h)
    restored = block_op.H * low_rank.reshape(blocks.shape)
    return restored.reshape(input.shape)

class LLRProx(sp.prox.Prox):
    """Locally-low-rank proximal operator built on ``llr_soft_thresh``.

    The image is cut into ``W x W`` spatial blocks with stride ``S``, and the
    singular values of each block are soft-thresholded by ``lamda * alpha``.

    Args:
        shape: input shape. Axes 0 and 1 are spatial. The block operator is
            built for ``(shape[0], shape[1], shape[3])``, so the product of
            ``shape[2:]`` must equal ``shape[3]`` (other trailing axes singleton).
        lamda: regularization weight.
        W: block width.
        S: block stride.
        rand_shift: if True, circularly shift the two spatial axes by a random
            offset in [0, W) before blocking and undo the shift afterwards.
    """

    def __init__(self, shape, lamda, W, S, rand_shift=True):
        self.lamda = lamda
        self.W = W
        self.S = S
        self.rand_shift = rand_shift
        blocked_shape = [shape[0], shape[1], shape[3]]
        self.block_op = sp.linop.ArrayToBlocks(blocked_shape, [W, W, 1], [S, S, 1])
        super().__init__(shape)

    def _prox(self, alpha, input):
        threshold = self.lamda * alpha
        if not self.rand_shift:
            return llr_soft_thresh(threshold, self.block_op, self.W, self.S, input)

        # Move the block grid by a random offset on each spatial axis so block
        # boundaries do not stay fixed from one call to the next.
        shifts = list(np.random.randint(0, self.W, size=2))
        shift_op = sp.linop.Circshift(self.shape, shifts, axes=[0, 1])
        shrunk = llr_soft_thresh(threshold, self.block_op, self.W, self.S, shift_op * input)
        return shift_op.H * shrunk
    
# LLR regularizer for the reconstruction. With random shifting (the default) the block
# grid moves on every prox call; set llr_rand_shift = False for a fixed block grid.
llr_rand_shift = True
proxg = LLRProx(T2sh_op.ishape, lamda, W, S, rand_shift=llr_rand_shift)

In [25]:
# Intensity scale from a low-resolution image. Keep the central 32 x 32 of k-space,
# inverse-FFT the spatial axes, and take the 2-norm over echoes and then over coils.
# The 95th-percentile magnitude is used below to normalize k-space before the recon.
teksp_center = sp.resize(teksp, [32, 32] + list(teksp.shape[2:]))
cfimg_center = np.linalg.norm(
    np.linalg.norm(sp.ifft(teksp_center, axes=[0, 1]), axis=-1),
    axis=-1,
)
scale = np.percentile(np.abs(cfimg_center), 95)


In [None]:
max_iter = 250
alpha = .95  # step size passed to the gradient method

# Reconstruct from normalized k-space, then undo the normalization on the result.
teksp_scale = teksp / scale
T2sh_app = sp.app.LinearLeastSquares(
    T2sh_op,
    teksp_scale[..., None],
    alg_name='GradientMethod',
    max_iter=max_iter,
    alpha=alpha,
    accelerate=True,
    proxg=proxg,
)
cfimg_recon = T2sh_app.run() * scale


HBox(children=(IntProgress(value=0, description='LinearLeastSquares', max=250, style=ProgressStyle(description…

In [None]:
# Move the coefficient axis to the front and the two spatial axes to the end for display.
pl.ImagePlot(np.transpose(cfimg_recon, (3, 2, 4, 0, 1)))


In [None]:
# Expand the reconstructed subspace coefficients to echo images by multiplying each
# voxel's K-vector by the (T x K) basis Phi; uses the reconstruction output directly.
teimg_recon = sp.linop.MatMul(cfimg_recon.shape, Phi[None, None, None, ...])(cfimg_recon)

In [None]:
# Move the echo axis to the front and the two spatial axes to the end for display.
pl.ImagePlot(np.transpose(teimg_recon, (3, 2, 4, 0, 1)))