# ECE 697 - Final Project Demo Notebook
### By: Cameron Craig, Justin Hiemstra
---
Thank you for evaluating our work on the ECE 697 Capstone Project. The purpose of this notebook is to provide a high-level interface from which to run the main components of our code. This is intended to facilitate the reproduction of our results, and to allow interested parties to perform forward passes through our trained model to enhance their own medical image data.

This notebook covers the primary accomplishments of the project, but not everything in the codebase. Other Python notebooks in the repository contain our data exploration and analysis, as well as the generation of synthetic data. Additionally, code and shell scripts from our work using CycleGAN to correct noise and intensity inhomogeneity are available as a submodule of this repository.

---

#### 1. Clone the repository and unzip sample data

In [None]:
# Fetch the project repository and flatten its contents into the working directory.
# NOTE(review): the `*` glob in `mv` does not match dotfiles, and `git clone`
# without `--recursive` presumably does not fetch the CycleGAN submodule mentioned above.
!git clone https://github.com/ccraig3/ece697-mri-denoising.git
!mv ece697-mri-denoising/* ./
!rm -rf ece697-mri-denoising
# Extract the sample images used for training (knees) and validation/testing (brains) below
!unzip sample_brains.zip
!unzip sample_knees.zip

Install the required packages

In [None]:
# pytorch-lightning, piq, antspyx (imported as `ants`) and nibabel are imported in the next cell.
!pip install pytorch-lightning
!pip install piq
!pip install antspyx
!pip install nibabel
# wandb is presumably used by train_unet.py for run logging (see --proj_name / --run_name) — TODO confirm
!pip install wandb
# May prompt interactively for a Weights & Biases API key
!wandb login

Imports

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
import ants
import cv2
from lightning_unet import LitUNet
from mri_sup_dataset import MriSupDataset
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import transforms, _utils
import os
import pytorch_lightning as pl
import pickle as pkl
from unet import UNet
import piq
import nibabel as nib
import math

---
#### 2. Train a sample version of the model using the provided sample resources

The model checkpoints will be saved in a new directory called `model_ckpts`. A GPU is required for this step.

In [None]:
# You need to have a gpu to run this line
# Trains on the sample knee images and validates/tests on the sample brain images,
# using the same synthetic bias-field pickle for all three splits.
# NOTE(review): the next section loads model_ckpts/epoch=9.ckpt, presumably the last of
# the --max_epochs 10 (0-indexed) epochs — confirm the checkpoint naming in train_unet.py.
!python3 train_unet.py --train_imgs 'sample_knees' --val_imgs 'sample_brains' --test_imgs 'sample_brains' --train_bias '25_sample_synth_bias_fields.pkl' --val_bias '25_sample_synth_bias_fields.pkl' --test_bias '25_sample_synth_bias_fields.pkl' --proj_name 'UNet-L1+L2-Demo' --run_name 'my first run' --max_epochs 10 --batch_size 25 --wf 6

---
#### 3. Evaluate your trained model on the sample test set

In [None]:
# Checkpoint produced by the training cell above (weights are saved under model_ckpts/)
CKPT_PATH = r'model_ckpts/epoch=9.ckpt'

def my_criterion(prediction, y):
  """Combined "L1+L2" loss: mean squared error plus mean absolute error."""
  squared_term = F.mse_loss(prediction, y)
  absolute_term = F.l1_loss(prediction, y)
  return squared_term + absolute_term

# Restore the trained Lightning module from the checkpoint above
unet_model = LitUNet.load_from_checkpoint(CKPT_PATH)
# Evaluation mode for inference (affects layers such as dropout / batch-norm, if the UNet has any)
unet_model.eval()

In [None]:
# Fixed Model Parameters
SIZE = (320, 320) # (Height, Width) of the generated bias field
num_points = 60 # number of points sampled along the simulated coil array (read as a global by genBiasField)

# Uniform Random Variable Bounds
# NOTE: genCompositeField currently hard-codes its parameters; in this notebook these
# bounds are only referenced by the commented-out sampleRV(...) calls there.
LOW_BOOST_BOUNDS = (0, 0.2) # range for low_boost (minimum intensity of the field)
COIL_VERT_POS_BOUNDS = (SIZE[0] - 1, SIZE[0] * 1.2) # coil row: from the last image row to 20% beyond the image height
PARAM_B_ADJUST_BOUNDS = (-0.02, 0.02) # range for b_adj (used only by the disabled exponential falloff model)
COIL_WIDTH_BOUNDS = (0.1, 0.4) # coil half-width as a fraction of the image width

# Helper functions

def poly_dec(x):
  """Power-law intensity falloff as a function of distance *x*.

  Computes 1927.5 * (x + 37) ** -2.093; works element-wise on NumPy arrays.
  """
  shifted_distance = x + 37
  return 1927.5 * shifted_distance ** -2.093

# Normalize pixel values to the range [0, 1]
def normalize(image):
  """Min-max scale *image* so its minimum maps to 0 and its maximum to 1.

  A constant image has zero dynamic range; dividing by it would yield NaNs
  (and a RuntimeWarning), so it is mapped to an all-zero float array instead.
  """
  shifted = image - np.min(image)
  value_range = np.max(shifted)
  if value_range == 0:
    return np.zeros_like(shifted, dtype=float)
  return shifted / value_range

def sampleRV(BOUNDS):
  """Draw a single sample from U(BOUNDS[0], BOUNDS[1]) using NumPy's global RNG."""
  low = BOUNDS[0]
  high = BOUNDS[1]
  return np.random.uniform(low=low, high=high)

# `normalize` is defined once, earlier in this cell; a second identical copy that
# silently shadowed it was removed from here.

def clip_img(image):
  """Clamp pixel values into [0, 1]: values above 1 become 1, values below 0 become 0."""
  return np.where(image < 0, 0, np.where(image > 1, 1, image))

def add_rician_noise(image, intensity=1):
  """Corrupt *image* with (approximately) Rician noise and clip the result to [0, 1].

  Two independent standard-normal fields, each divided by its own maximum so its
  peak equals 1, are scaled by *intensity* and added as the real and imaginary
  parts of the signal; the returned image is the magnitude of that complex sum.
  """
  # Drawn in order: real-part field first, imaginary-part field second
  real_noise, imag_noise = [
      field / np.max(field)
      for field in (np.random.normal(0, 1, image.shape) for _ in range(2))
  ]
  complex_signal = image + intensity * real_noise + intensity * imag_noise * 1j
  return clip_img(np.abs(complex_signal))

# Create a Bias Field
def genBiasField(SIZE, coil_left, coil_right, coil_vert_pos, b_adj, low_boost):
  """Simulate a receive-coil bias field of shape SIZE.

  A horizontal line of `num_points` coil points (module-level global) is placed
  on row round(coil_vert_pos), spanning columns round(SIZE[1]*coil_left) to
  round(SIZE[1]*coil_right). Each pixel's value is poly_dec() of its Euclidean
  distance to the nearest coil point. The field is then min-max normalized and
  its weak end lifted so values lie in [low_boost, 1].

  Args:
    SIZE: (height, width) of the field.
    coil_left, coil_right: horizontal extent of the coil line as fractions of the width.
    coil_vert_pos: row coordinate of the coil line (may lie below the image).
    b_adj: unused; kept for the disabled exponential-falloff model.
    low_boost: value assigned to the weakest part of the field.

  Returns:
    float ndarray of shape SIZE.
  """
  # Coil points: num_points samples along a horizontal line at a fixed row
  coil_cols = np.linspace(round(SIZE[1] * coil_left), round(SIZE[1] * coil_right), num=num_points)
  coil_row = round(coil_vert_pos)
  coil_rows = np.linspace(coil_row, coil_row, num=num_points)

  # Row / column index of every pixel in the field
  rows, cols = np.indices(SIZE)

  # Squared distance from each pixel to its nearest coil point. The pixel math is
  # vectorized; only the short loop over coil points runs in Python (the previous
  # per-pixel loop made ~H*W interpreter iterations).
  min_sq_dist = np.full(SIZE, np.inf)
  for r, c in zip(coil_rows, coil_cols):
    np.minimum(min_sq_dist, (rows - r) ** 2 + (cols - c) ** 2, out=min_sq_dist)

  # Power-law falloff with distance from the coil
  B = poly_dec(np.sqrt(min_sq_dist))

  # Normalize B on range [0, 1], then boost the weak end so the field spans [low_boost, 1]
  B_norm = normalize(B)
  return B_norm * (1 - low_boost) + low_boost

def genCompositeField(SIZE, lb):
  """Build the evaluation bias field as the mean of one or more coil fields.

  Only the low-end boost *lb* is a parameter. The coil count (1), coil
  half-width (0.3, widened by 0.1 when there is a single coil), coil row (350)
  and b_adj (0) are fixed; they were originally drawn with sampleRV from the
  *_BOUNDS constants. Returns a float array of shape SIZE.
  """
  coil_count = 1
  half_width = 0.3 + (0.1 if coil_count == 1 else 0.0)
  left_frac, right_frac = 0.5 - half_width, 0.5 + half_width
  share = 1. / coil_count
  coil_row = 350
  b_adjust = 0

  # Each coil covers its own share of the width
  fields = [
      genBiasField(SIZE, left_frac * share + k * share, right_frac * share + k * share, coil_row, b_adjust, lb)
      for k in range(coil_count)
  ]
  return np.stack(fields).mean(axis=0)

def run_n4(img, mask):
  """Apply ANTs N4 bias-field correction to *img* within *mask*.

  *img* is quantized to uint8 via img * 255 (so it is presumably expected in
  [0, 1]), and the result is scaled back by 1/255. Arrays are transposed on the
  way into and out of ANTs. The mask is divided by its maximum and passed through
  threshold_image(1, 2), presumably keeping only pixels equal to that maximum —
  TODO confirm threshold_image semantics.
  """
  volume = ants.from_numpy((img.copy() * 255.).astype('uint8').T)
  roi = ants.from_numpy(mask.copy().T)
  roi = (roi / roi.max()).threshold_image(1, 2)
  corrected = ants.n4_bias_field_correction(volume, mask=roi, rescale_intensities=True)
  return corrected.numpy().T / 255.

def predict(x):
  """Run the global `unet_model` on one 320x320 image and return its prediction.

  *x* must contain 320*320 values, presumably in [0, 1]. It is quantized to
  uint8, converted by the global `transform`, min-max normalized, and fed to the
  network as a (1, 1, 320, 320) float tensor. The callers in this notebook wrap
  this call in torch.no_grad().

  Returns:
    The network output as a (320, 320) NumPy array (not clamped).
  """
  # Clip before the uint8 cast: float -> uint8 conversion of out-of-range values
  # wraps around / is platform-dependent, turning small overshoots into garbage.
  x_np = np.clip(x.reshape((320, 320)) * 255., 0, 255).astype('uint8')
  x_tensor = transform(x_np).float()
  x_tensor -= x_tensor.min()
  # Epsilon avoids division by zero on a constant input
  x_tensor /= (x_tensor.max() + 1e-9)
  y_hat = unet_model(x_tensor.view(1, 1, 320, 320))
  return y_hat.cpu().detach().numpy().reshape((320, 320))

def save_img(img, filename):
  """Write a [0, 1]-valued grayscale image to *filename* as an 8-bit image via OpenCV.

  Values are clipped to [0, 1] before scaling to 0-255, so out-of-range pixels
  (e.g. raw network predictions) saturate instead of wrapping around in the
  uint8 cast.

  Raises:
    OSError: if cv2.imwrite reports failure (for example, a missing target directory).
  """
  img_cv = (np.clip(img, 0., 1.) * 255.).astype('uint8')
  if not cv2.imwrite(filename, img_cv):
    raise OSError('Could not write image to ' + filename)

In [None]:
TEST_DIR = r'sample_brains'
# Sorted so that the per-image index (used in the results/ file names) is reproducible;
# os.listdir returns entries in arbitrary order.
test_file_names = sorted(
    name for name in os.listdir(TEST_DIR) if os.path.isfile(os.path.join(TEST_DIR, name))
)

N_GAIN = 0.06 # Noise intensity for add_rician_noise: each noise field is scaled so its peak equals this value
B_LOW_END = 0.055 # Intensity of the darkest part of the bias field (passed as low_boost)

# One fixed bias field, applied to every test image
B = genCompositeField(SIZE, B_LOW_END)

# Converts HxW arrays to (1, H, W) float tensors and resizes them to 320x320
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320))
])

# Per-image metrics for the N4 baseline
n4_psnrs = []
n4_ssims = []
n4_ms_ssims = []
n4_losses = []

# Per-image metrics for the UNet model
model_psnrs = []
model_ssims = []
model_ms_ssims = []
model_losses = []

# Output directory for example images (no error if it already exists, unlike `!mkdir`)
os.makedirs('results', exist_ok=True)

Run the test set through the model and the N4 baseline

In [None]:
def _to_unit_tensor(image):
  """Convert an HxW array to a (1, 1, 320, 320) float tensor clamped to [0, 1]."""
  return torch.clamp(transform(image).float().view(1, 1, 320, 320), 0, 1)


def _score(prediction, target):
  """Return (PSNR, SSIM, MS-SSIM, L1+L2 loss) of *prediction* against *target* as floats."""
  return (
      piq.psnr(prediction, target, data_range=1.).item(),
      piq.ssim(prediction, target).item(),
      piq.multi_scale_ssim(prediction, target, data_range=1., reduction='mean').item(),
      my_criterion(prediction, target).item(),
  )


for i, file_name in enumerate(tqdm(test_file_names)):

  # Get original test image
  img = cv2.imread(os.path.join(TEST_DIR, file_name), cv2.IMREAD_GRAYSCALE)
  if img is None:
    # cv2.imread returns None for files it cannot decode; skip them instead of crashing
    tqdm.write('Skipping unreadable file: ' + file_name)
    continue
  img = cv2.resize(img, dsize=(320, 320), interpolation=cv2.INTER_CUBIC)
  img = img - np.min(img)
  peak = np.max(img)
  if peak == 0:
    # A blank image would divide by zero, and its NaN metrics would poison the averages below
    tqdm.write('Skipping blank image: ' + file_name)
    continue
  img = img / peak

  # Make noisy version: apply the bias field inside the foreground mask, then add Rician noise
  mask_np = np.where(img > 0, 1, 0).astype('float32')
  img_x = add_rician_noise(normalize(img * B * mask_np), intensity=N_GAIN)

  # Get n4 correction (baseline)
  img_n4 = run_n4(img_x, mask_np)

  # Get model prediction
  with torch.no_grad():
    img_y_hat = predict(img_x)

  # Save every 20th example for visual inspection
  if i % 20 == 0:
    save_img(img, "results/y_" + str(i) + ".png")
    save_img(img_x, "results/x_" + str(i) + ".png")
    save_img(img_n4, "results/n4_" + str(i) + ".png")
    save_img(img_y_hat, "results/y_hat_" + str(i) + ".png")

  # Evaluate both corrections against the clean image
  img_tensor = _to_unit_tensor(img)
  n4_scores = _score(_to_unit_tensor(img_n4), img_tensor)
  model_scores = _score(_to_unit_tensor(img_y_hat), img_tensor)

  for store, value in zip((n4_psnrs, n4_ssims, n4_ms_ssims, n4_losses), n4_scores):
    store.append(value)
  for store, value in zip((model_psnrs, model_ssims, model_ms_ssims, model_losses), model_scores):
    store.append(value)

In [None]:
# Print the mean of every metric: first the N4 baseline, then the model
for method, metric_lists in (
    ('N4', (n4_psnrs, n4_ssims, n4_ms_ssims, n4_losses)),
    ('model', (model_psnrs, model_ssims, model_ms_ssims, model_losses)),
):
  if method == 'model':
    print('\n')
  for metric, values in zip(('PSNR', 'SSIM', 'MS-SSIM', 'L1+L2 loss'), metric_lists):
    print(metric + ' for ' + method + ': ' + str(np.mean(values)))

---
#### 4. Run inference with the pre-trained model

You can download the pre-trained model weights [here](https://drive.google.com/file/d/11MCFPiCKgFMSn52G6Szqt0yGZSV0j7cE/view?usp=sharing).

Be sure to upload the `.ckpt` file to the root directory of this notebook, and set `ckpt_path` in the next cell to its location.

In [None]:
# Path to the pre-trained model checkpoint (uploaded by you)
# Alternatively, you may use the unet_model trained above
# NOTE: /content is presumably the Colab working directory; adjust for other environments.
ckpt_path = r'/content/epoch=32-step=39600.ckpt'

# You need to upload and provide this, we can't provide you with a sample for medical privacy reasons.
# Set this to the path of your own NIfTI volume.
INPUT_NIFTI_PATH = r''
# Stop here with a clear message rather than failing later inside nib.load
if not INPUT_NIFTI_PATH or not os.path.isfile(INPUT_NIFTI_PATH):
  raise FileNotFoundError('ERROR: You need to provide the path to your own nifti file for inferencing.')

unet_model = LitUNet.load_from_checkpoint(ckpt_path)
unet_model.eval()

In [None]:
# Load the NIfTI volume and scale it by its global maximum
nifti_img = nib.load(INPUT_NIFTI_PATH)
volume_data = nifti_img.get_fdata()
whole_img = volume_data / volume_data.max()

In [None]:
# Display middle slice before correction
# (middle index along the last axis; transposed to match the orientation used for inference below)
plt.imshow(whole_img[:, :, whole_img.shape[-1] // 2].T, cmap='gray')

In [None]:
# Output directory for the per-slice predictions (no error if it already exists, unlike `!mkdir`)
os.makedirs('predictions', exist_ok=True)

# Correct each slice of the input individually.
# NOTE(review): assumes a 3-D volume whose last axis indexes slices; a 4-D (e.g. dynamic)
# NIfTI would need an extra index. predict() also reads the global `transform`
# defined in section 3's setup cell, so that cell must have been run.
for i in tqdm(range(whole_img.shape[-1])):
  img = whole_img[:, :, i].T
  img = cv2.resize(img, dsize=(320, 320), interpolation=cv2.INTER_CUBIC)
  img = img - np.min(img)
  peak = np.max(img)
  if peak == 0:
    # Blank slice: normalizing would divide by zero and feed NaNs to the model.
    # Write an empty image instead so the output numbering stays contiguous.
    save_img(np.zeros((320, 320)), "predictions/y_hat_slice_" + str(i) + ".png")
    continue
  img = img / peak

  # Get model prediction
  with torch.no_grad():
    img_y_hat = predict(img)

  # Save model prediction
  save_img(img_y_hat, "predictions/y_hat_slice_" + str(i) + ".png")
