In [None]:
pip install prox-TV

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
#from skimage.data import shepp_logan_phantom
from skimage.transform import radon, rescale
from skimage.transform import iradon
from skimage.draw import circle, line, rectangle
import prox_tv as ptv

In [None]:
# Soft-thresholding operator
# This is the proximal operator for the L1-norm:
# prox[w*||.||_1](x) = softT(x,w)
def softT(x,w):
    """Soft-threshold ``x`` by ``w`` and return the result as a new array.

    Entries with ``|x| <= w`` are set to 0; every other entry is moved
    towards zero by ``w``.

    Parameters
    ----------
    x : array_like
        Values to threshold. The input is not modified, so callers can keep
        using it after the call.
    w : float
        Threshold.

    Returns
    -------
    numpy.ndarray
        Thresholded values, same shape as ``x``.
    """
    x = np.asarray(x)
    # sign(x) * max(|x| - w, 0) is the closed form of the two-branch rule above
    return np.sign(x) * np.maximum(np.abs(x) - w, 0)

In [None]:
# --- iFBP settings ---
# Total number of iterations
Niter = 50
# Subsampling factor
ssf = 2
# Reconstruction behavior, either 'FBP' or 'SBP'
case = 'FBP'
# Regularization, either 'ST' or 'TV'
filt = 'ST'

# --- Display settings ---
# Number of iterations between two figure refreshes
stepViz = 10
# Window width for error images
eww = 100
# Gray-level window for standard displays
vmin = -100
vmax = 100
#vmin, vmax = 800, 1200

In [None]:
# Case-dependent parameters, each keyed by the reconstruction behavior
# Gradient step
t = {
    'SBP': 0.008,
    'FBP': 0.7,
}
# Reconstruction filter handed to iradon
# (this name shadows the Python builtin `filter` for the rest of the notebook)
filter = {
    'SBP': None,
    'FBP': 'ramp',
}
# Largest and smallest regularization values
#lmax={'SBP': 800, 'FBP': 0}
#lmax={'SBP': 800, 'FBP': 500}
lmax = {
    'SBP': 800,
    'FBP': 50,
}
lmin = {
    'SBP': 0.,
    'FBP': 0,
}
# Regularizer lookup: the TV filter or the ST operator
regularizer = {
    'TV': ptv.tv1_2d,
    'ST': softT,
}

In [None]:
# Example image read from disk (disabled)
#image = np.fromfile("./brainCT_slc115.raw", dtype=np.float32).reshape((512,512))
# Alternative in use: start from an all-zero 512x512 image
image = np.zeros(shape=(512, 512))

In [None]:
# Add synthetic discs to the image.
# Each entry is (center row, center column, radius, intensity offset).
synthetic_discs = (
    # ... vessels...
    (200, 145, 9, 8000.0),
    (150, 180, 5, 8000.0),
    # ... small hyperdensities (bleeding)...
    (250, 330, 40, 10.0),
    # ... or small hypodensities (infarcts)
    (300, 150, 30, -30.0),
)
for disc_row, disc_col, disc_radius, disc_offset in synthetic_discs:
    rr, cc = circle(disc_row, disc_col, disc_radius, image.shape)
    image[rr,cc] += disc_offset

In [None]:
# Downscale the image for computation speed
# (this also smooths out additional structures)
#image = rescale(image, scale=0.35, mode='reflect', multichannel=False)
image = rescale(image, scale=0.5, mode='reflect', multichannel=False)

# Add high-contrasted lines with different spacing to analyse the spatial
# resolution: the same segment is drawn once, then redrawn after successive
# row shifts of 2, 3, 4, 5 and 6 pixels
rr, cc = line(108,108,112,118)
for row_shift in (0, 2, 3, 4, 5, 6):
    rr += row_shift
    image[rr,cc] += 2000.0

In [None]:
# Show the phantom with the standard gray-level window
gray_window = dict(cmap=plt.cm.Greys_r, vmin=vmin, vmax=vmax)
plt.figure(figsize=(10,10))
plt.imshow(image, **gray_window)

In [None]:
# Choose a number of projections
# Empirically, a full angular sampling corresponds to
# N ~ image height/width
# N is used as a sample count (np.linspace) further down, so it has to be an
# integer: floor-divide instead of producing a float with "/"
N = int(max(image.shape) // ssf)
print("Number of projections: "+str(N))
print("Angular spacing: "+str(round(180.0/N,3))+" deg")

In [None]:
# Define angular range and start/stop positions
# Acquisitions are assumed to be equally spaced
# radon/iradon take projection angles in degrees, so the quarter-pi start angle
# is converted to degrees before the 180-degree span is added to it
theta_min = np.degrees(0.25*np.pi)
theta_max = 180.0+theta_min
# The number of samples must be an integer
theta = np.linspace(theta_min, theta_max, int(N), endpoint=False)

# Compute sinogram with the Radon transform
sinogram = radon(image, theta=theta, circle=True)

# sinogram.T has one row per angle and one column per detector position, so the
# extent maps columns to the detector index and rows (top to bottom) to theta
plt.imshow(sinogram.T, cmap=plt.cm.Greys_r, extent=(0, sinogram.shape[0], theta_max, theta_min), aspect='auto')
plt.xlabel("Detector position (pixels)")
plt.ylabel("Projection angle (deg)")

In [None]:
# Build the circular FOV mask: 1 inside a disc centered on the image whose
# diameter is derived from the sinogram's first dimension, 0 elsewhere
w = sinogram.shape[0]
fov_center_row = 0.5*(image.shape[0]-1.0)
fov_center_col = 0.5*(image.shape[1]-1.0)
fov_radius = 0.5*(w-1.0)
rr, cc = circle(fov_center_row, fov_center_col, fov_radius, image.shape)
mask = np.zeros(image.shape)
mask[rr,cc] = 1.0

In [None]:
# Backproject the sinogram with the 'SBP' filter setting, then divide by
# the number of projections
sbp_filter = filter['SBP']
simple_backprojection = iradon(sinogram, theta=theta, circle=True, filter=sbp_filter) / N
# Show the result with a window that scales with the subsampling factor
plt.figure(figsize=(10,10))
plt.imshow(simple_backprojection, cmap=plt.cm.Greys_r, vmin=0, vmax=1500*ssf)

In [None]:
# Initial reconstruction, using the filter selected by `case`
init_filter = filter[case]
fbp_init = iradon(sinogram, theta=theta, circle=True, filter=init_filter)
# Show it with the standard gray-level window
plt.figure(figsize=(10,10))
plt.imshow(fbp_init, cmap=plt.cm.Greys_r, vmin=vmin, vmax=vmax)

In [None]:
# Initialize image for iFBP
# The regularizer is given a copy so that fbp_init keeps the unregularized
# reconstruction even when the selected regularizer works in place; otherwise
# the error image below and the later fbp_init displays would show the
# regularized image instead
reconstruction_fbp = regularizer[filt](fbp_init.copy(),lmax[case])
# Record error image of the unregularized initial reconstruction
err_init = fbp_init-image
# Display: regularized initialization (left) and initial error (right)
plt.figure(figsize=(25,10))
plt.subplot(121)
plt.imshow(reconstruction_fbp, cmap=plt.cm.Greys_r, vmin=vmin, vmax=vmax)
plt.subplot(122)
plt.imshow(err_init, cmap=plt.cm.Greys_r, vmin=-0.5*eww, vmax=0.5*eww)

In [None]:
# Compute one-step of gradient descent
def one_step_gd(f,p,step):
    """Return ``f`` after one gradient-descent update.

    ``f`` is forward-projected with ``radon``, ``p`` is subtracted from the
    result, the difference is backprojected with ``iradon`` and ``step``
    times that backprojection is subtracted from ``f``.

    Reads the notebook-level ``theta``, ``filter`` and ``case``.

    Parameters
    ----------
    f : current image estimate
    p : projections that the forward projection of ``f`` is compared with
    step : gradient step size
    """
    residual = radon(f, theta=theta, circle=True) - p
    correction = iradon(residual, theta=theta, circle=True, filter=filter[case])
    return f - step * correction

    ### end of code

In [None]:
# Run gradient descent
# NOTE(review): the part of the loop body before the end marker is empty, so
# this loop only prints the iteration counter and leaves reconstruction_fbp
# untouched -- looks like an unfilled exercise placeholder; confirm.
for k in range(Niter):

    ### end of code
    print(f"Iteration {k+1}/{Niter}")

In [None]:
# Error of the current reconstruction with respect to the reference image
err = reconstruction_fbp - image
# Three panels: initial reconstruction, current reconstruction, error
gray_cmap = dict(cmap=plt.cm.Greys_r)
plt.figure(figsize=(35,10))
plt.subplot(1, 3, 1)
plt.imshow(fbp_init, vmin=vmin, vmax=vmax, **gray_cmap)
plt.subplot(1, 3, 2)
plt.imshow(reconstruction_fbp, vmin=vmin, vmax=vmax, **gray_cmap)
plt.subplot(1, 3, 3)
plt.imshow(err, vmin=-0.5*eww, vmax=0.5*eww, **gray_cmap)

In [None]:
# Regularization weights, going linearly from lmax to lmin over the iterations
lambd = np.linspace(start=lmax[case], stop=lmin[case], num=Niter)
print(lambd)
# One MSE slot per iteration, NaN until that iteration has been run
MSE_vec = np.full(Niter, np.nan)

In [None]:
for k in range(Niter):
    ### your code here
    # Gradient step, then the regularizer with a weight scaled by the step,
    # then restriction to the circular FOV
    gd_estimate = one_step_gd(reconstruction_fbp, sinogram, t[case])
    reconstruction_fbp = regularizer[filt](gd_estimate, lambd[k] * t[case])
    reconstruction_fbp *= mask

    ### end of code
    # Track the mean squared error against the reference image
    err_cur = reconstruction_fbp-image
    MSE_vec[k] = round(np.mean(np.square(err_cur)),5)
    print("\t".join(str(value) for value in (k+1, lambd[k], MSE_vec[k])))
    # Only refresh the figures every stepViz iterations
    if (k+1)%stepViz != 0:
        continue
    plt.figure(figsize=(25,10))
    # Left panel: MSE history (iterations not run yet are still NaN)
    plt.subplot(1, 2, 1)
    plt.title("MSE")
    plt.plot(np.arange(1, Niter+1),MSE_vec)
    # Right panel: current reconstruction
    plt.subplot(1, 2, 2)
    plt.title(f"Iteration {k+1}\nMSE = {MSE_vec[k]!s}")
    plt.imshow(reconstruction_fbp, cmap=plt.cm.Greys_r, vmin=vmin, vmax=vmax)
    plt.pause(0.05)

In [None]:
# Final comparison: initial reconstruction, final reconstruction, and the
# final error with respect to the reference image
final_cmap = dict(cmap=plt.cm.Greys_r)
plt.figure(figsize=(35,10))
plt.subplot(1, 3, 1)
plt.imshow(fbp_init, vmin=vmin, vmax=vmax, **final_cmap)
plt.subplot(1, 3, 2)
plt.imshow(reconstruction_fbp, vmin=vmin, vmax=vmax, **final_cmap)
plt.subplot(1, 3, 3)
plt.imshow(reconstruction_fbp-image, vmin=-0.5*eww, vmax=0.5*eww, **final_cmap)