In [1]:
import sys
# Make the in-repo `mirtorch` package importable from the notebook directory or
# its parent; the membership check keeps re-runs from adding duplicate entries.
for _path in ['.', '..']:
    if _path not in sys.path:
        sys.path.append(_path)
import torch
import numpy as np
from mirtorch.alg.cg import CG
# Diff2d is needed by the field-corrected (Gmri) reconstruction further below.
from mirtorch.linear import LinearMap, FFTCn, NuSense, Sense, Identity, Diff2dframe, Diff2d, Gmri
import matplotlib.pyplot as plt
import copy
import h5py
import torchkbnufft as tkbn

## FFT example
The first example shows the basic usage of linear operators.
Operators can be combined with +, -, and * like matrices, as long as their sizes match.
`.H` gives the adjoint operator.

In [None]:
# Assign device: the first GPU when available, otherwise fall back to the CPU
# so the example also runs on machines without CUDA.
device0 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Load image
from skimage.data import shepp_logan_phantom
I_shepp = torch.tensor(shepp_logan_phantom()).to(device0)
# Define the FFT operator over both image axes (0, 1) with orthonormal scaling.
# Input/output sizes are taken from the phantom instead of being hard-coded.
img_shape = tuple(I_shepp.shape)
Fop = FFTCn(img_shape, img_shape, (0,1), norm = 'ortho')
k_shepp = Fop*I_shepp

Show the k-space (log magnitude).

In [None]:
# Display the log-magnitude of the k-space computed above.
fig, ax = plt.subplots()
log_mag = torch.log(torch.abs(k_shepp))
im = ax.imshow(log_mag.detach().cpu().numpy())
fig.colorbar(im, ax=ax)

## MRI parallel imaging example (SENSE)
Here we try an 8-fold equidistant 1D undersampling mask (every 8th column is kept).
Both the zero-filled (conjugate phase) reconstruction ($A'y$)
and the quadratic-roughness regularized least-squares reconstruction ($\arg\min_x \|Ax-y\|_2^2 + \lambda \|Rx\|_2^2$) are shown.

In [None]:
# Load the multi-coil Cartesian example on the CPU. The `with` block closes the
# .npz file handle once the arrays have been copied into tensors.
device0 = torch.device('cpu')
with np.load('AXT2_210_6001736_layer11.npz') as ex_multi:
    # k-space (real/imag parts stored separately); the division by 32767
    # presumably undoes an int16 storage scaling — TODO confirm
    k_c = torch.tensor(ex_multi['k_r'] + 1j*ex_multi['k_i']).to(device0)/32767.0
    # Sense map
    s_c = torch.tensor(ex_multi['s_r'] + 1j*ex_multi['s_i']).to(device0)/32767.0
# Operator sizes come from the sensitivity maps, presumably (coils, height, width)
(nc, nh, nw) = s_c.shape
Fop = FFTCn((nc, nh, nw), (nc, nh, nw), (1,2), norm = 'ortho')
# Fully-sampled reference: coil-wise adjoint FFT (over axes 1, 2) of the full
# k-space, then root-sum-of-squares over the coil axis (dim 0)
I1 = Fop.H*k_c
I1 = torch.sqrt(I1.abs().pow(2).sum(dim=0))
# 8x equidistant undersampling: keep every 8th column. The mask is built on
# device0 so it matches the data if device0 is switched to a GPU.
mask = torch.zeros(nh, nw, device=device0)
mask[:, 0:nw:8] = 1
# Define sense operator
Sop = Sense(s_c, mask, batchmode = False)
# Zero-filled reconstruction: adjoint of the undersampled SENSE operator
I0 = Sop.H*k_c
plt.figure(figsize=(20,10))
# .cpu() keeps plotting working when device0 is a GPU
plt.imshow(torch.abs(I0).cpu().data.numpy())
plt.colorbar()
plt.title('zero-filled')

Define the quadratic roughness penalty and the corresponding CG solver.

In [None]:
# Quadratic roughness penalty. It is added directly (not as T.H*T), which
# presumably means Diff2dframe already represents the Gram form R'R — TODO confirm.
lambda_tik = 0.01
T = Diff2dframe(Sop.size_in)
normal_op_tik = Sop.H*Sop + lambda_tik*T
CG_tik = CG(normal_op_tik, max_iter = 40, tol=1e-2, alert = False)

In [None]:
# Run CG with the zero-filled image I0 as both the starting point and the
# right-hand side b of Ax = b, then display the magnitude of the result.
I_tik = CG_tik.run(I0, I0)
fig, ax = plt.subplots(figsize=(20,10))
im = ax.imshow(torch.abs(I_tik).detach().cpu().numpy())
fig.colorbar(im, ax=ax)
ax.set_title('Recovered')

## Non-Cartesian reconstruction
Here we test non-Cartesian reconstruction, first without and then with field (B0) correction.

In [None]:
# Download the field-map dataset once. wget.download writes to a renamed file
# (e.g. "b0 (1).h5") when the target already exists, so re-running this cell
# would re-download and leave stray copies; skip it when the file is present.
import os
import wget
url = "https://www.dropbox.com/s/q1cr3u1yyvzjtoj/b0.h5?dl=1"
if not os.path.exists('./b0.h5'):
    wget.download(url, './b0.h5')

In [None]:
# Load the non-Cartesian dataset. The `with` block closes the HDF5 file once
# every dataset has been read into memory via [()].
with h5py.File('./b0.h5', 'r') as hf:
    # Load non-Cartesian k-space trajectory and wrap it into [-pi, pi)
    ktraj = hf['ktraj'][()]
    ktraj = np.remainder(ktraj + np.pi, 2*np.pi)-np.pi
    print('traj shape', ktraj.shape)
    # Load k-space
    k = hf['k_r'][()] + 1j*hf['k_i'][()]
    [ncoil, nslice, nshot, ns] = k.shape
    print('k shape', k.shape)
    # Load density compensation function
    dcf = hf['dcf'][()]
    print('dcf shape', dcf.shape)
    # Load sensitivity maps. The single transpose (3,0,2,1) equals applying
    # (3,0,1,2) followed by (0,1,3,2); axis 0 is presumably the slice axis and
    # the last two axes are the in-plane image size (nx, ny).
    smap = np.transpose(hf['s_r'][()] + 1j*hf['s_i'][()], (3,0,2,1))
    [_, _, nx, ny] = smap.shape
    print('smap shape', smap.shape)
    # Load fmaps, swapping the two in-plane axes (presumably to match smap)
    fmap = hf['b0'][()]
    fmap = np.transpose(fmap, (0,2,1))
    print('fmap shape', fmap.shape)

# Define 5x retrospective undersampling: keep every 5th shot (slice 0:-1:5) of
# both trajectory and k-space. Shapes are derived from the k-space dimensions.
assert ktraj.size == 2 * nshot * ns, 'trajectory does not match k-space shape'
ktrajunder = ktraj.reshape(2, nshot, ns)
ktrajunder = ktrajunder[:, 0:-1:5, :].reshape(2, -1)
kunder = k[:, :, 0:-1:5, :]

In [None]:
# Move slice `iz` of the undersampled data to the device and define the
# non-Cartesian SENSE operator.
im_size = (nx,ny)
iz = 6
# First GPU when available, otherwise CPU, so the section also runs without CUDA
device0 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# k-space of one slice, flattened over (shot, sample) into (1, ncoil, n_samples)
k0 = torch.tensor(kunder[:,iz,:,:]).to(device = device0).reshape(1, ncoil, -1)
s0 = torch.tensor(smap[iz,:,:,:]).to(device = device0).unsqueeze(0)
traj0 = torch.tensor(ktrajunder).to(device = device0)
Nop = NuSense(s0, traj0)

In [None]:
# PWLS reconstruction: CG on (A'A + beta*T) x = A'y, where the adjoint image
# I0 is both the initial guess and the right-hand side.
beta_fd = 1e-4
I0 = Nop.H*k0
T = Diff2dframe(Nop.size_in)
normal_op_fd = Nop.H*Nop + beta_fd*T
CG_FD = CG(normal_op_fd, max_iter = 40)
I_FD = CG_FD.run(I0, I0)

In [None]:
# Display the magnitude of the PWLS result with the two leading singleton axes removed.
fig, ax = plt.subplots(figsize=(20,10))
img_fd = torch.abs(I_FD.squeeze(0).squeeze(0))
im = ax.imshow(img_fd.detach().cpu().numpy())
fig.colorbar(im, ax=ax)
ax.set_title('Recovered')

In [None]:
# Define field-corrected NuSENSE operator. The field map of slice `iz` is passed
# negated as `zmap`; the sign presumably follows Gmri's convention — TODO confirm.
b0 = torch.tensor(fmap[iz,:,:]).to(device0).unsqueeze(0)
traj_shots = traj0.reshape(2, nshot//5, ns).unsqueeze(0)
Gop = Gmri(smaps=s0, zmap=-b0, traj=traj_shots)

In [None]:
# Reshape k-space to (1, ncoil, shots, samples), form the adjoint image, and run
# CG on the finite-difference regularized normal equations.
k0_shots = k0.reshape(1, ncoil, nshot//5, ns)
Ib0 = Gop.H*k0_shots
T = Diff2d(Gop.size_in, dims = (1,2))
beta_b0 = 1e-4
normal_op_b0 = Gop.H*Gop + beta_b0*T.H*T
CG_FD_b0 = CG(normal_op_b0, max_iter = 40)
I_FD_b0 = CG_FD_b0.run(Ib0, Ib0)

In [None]:
# Display the magnitude of the first entry of the B0-informed reconstruction.
fig, ax = plt.subplots(figsize=(20,10))
img_b0 = torch.abs(I_FD_b0[0])
im = ax.imshow(img_b0.detach().cpu().numpy())
fig.colorbar(im, ax=ax)
ax.set_title('B0-informed reconstruction')