# UniRes - Super-resolution Demo
Colab notebook demonstrating the basic installation and use of UniRes and NITorch.

## Installations & setup (**OPEN - UNIRES NEEDS LOGIN**)

### NITorch (**N**euro**I**maging Py**Torch**)

First, clone the NITorch repository.

In [None]:
# Skip the clone when the folder already exists, so the cell can be re-run safely.
![ -d nitorch ] || git clone https://github.com/balbasty/nitorch

Install NITorch and its dependencies.

In [None]:
! pip install numpy
! pip install nibabel
! pip install matplotlib
! pip install scipy

In [None]:
# %pip (unlike !pip) targets the environment of the running kernel.
%pip install ./nitorch/

### UniRes
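Private repository: cloning requires GitHub credentials. GitHub no longer accepts account passwords for Git over HTTPS, so enter a personal access token when prompted for the password.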

In [None]:
import os
import subprocess
import urllib.parse  # `import urllib` alone does not guarantee urllib.parse is loaded
from getpass import getpass

# Clone the private UniRes repository over HTTPS using the user's credentials.
# GitHub rejects account passwords for Git over HTTPS: enter a personal access
# token at the password prompt.
REPO_URL = 'https://github.com/brudfors/UniRes.git'

user = input('User name: ')
password = getpass('Password / personal access token: ')

# URL-encode both parts. safe='' also escapes '/', which quote() keeps by
# default and which would corrupt the user-info part of the URL.
auth = '{0}:{1}'.format(urllib.parse.quote(user, safe=''),
                        urllib.parse.quote(password, safe=''))
clone_url = REPO_URL.replace('https://', 'https://{0}@'.format(auth), 1)

# Arguments are passed as a list (no shell), so nothing typed by the user can be
# interpreted as shell syntax. GIT_TERMINAL_PROMPT=0 stops git from waiting on a
# credential prompt that the notebook cannot answer.
result = subprocess.run(
    ['git', 'clone', clone_url, 'UniRes'],
    capture_output=True, text=True,
    env=dict(os.environ, GIT_TERMINAL_PROMPT='0'),
)
if result.returncode == 0:
    # git stores the clone URL, credentials included, in UniRes/.git/config;
    # replace it with the credential-free URL.
    subprocess.run(['git', '-C', 'UniRes', 'remote', 'set-url', 'origin', REPO_URL],
                   check=True)
    print('Cloned UniRes')
else:
    # Redact the encoded credentials before showing git's error message.
    print('git clone failed:', result.stderr.replace(auth, '***'))

# Best-effort removal of the credentials from the notebook namespace.
password = auth = clone_url = ''

The cloned repository can now be installed as normal.

In [None]:
# Explicit %pip magic: a bare `pip install` only works when automagic is enabled
# and the IPython version provides a pip magic.
%pip install ./UniRes/

## Download data

### Load CT/MR public data (TCIA) using RESTful API

(https://www.cancerimagingarchive.net/)

The downloaded DICOM series are then converted to NIfTI.

In [None]:
import os

# Folder for the downloaded DICOM archives and the converted NIfTI volumes.
downloadPath = './data'
# Unlike `mkdir`, this does not error when the folder exists, so the cell can be re-run.
os.makedirs(downloadPath, exist_ok=True)

In [None]:
import requests

In [None]:
def download_url(url, save_path, chunk_size=128, timeout=60):
    """Stream the resource at ``url`` into the file ``save_path``.

    Parameters
    ----------
    url : str
        Address fetched with an HTTP GET.
    save_path : str
        Destination file. It is only created when the server answers 200.
    chunk_size : int
        Number of bytes requested per iteration of ``iter_content``.
    timeout : float
        Seconds ``requests`` waits on the server (connect / between bytes)
        before raising. Without it, a stalled server blocks the cell forever.

    Returns
    -------
    bool
        True if the server answered 200 and the body was written to disk,
        False otherwise (a message with the status code is printed).
    """
    import requests  # local import keeps the helper self-contained

    # The context manager releases the streamed connection on every exit path.
    with requests.get(url, stream=True, timeout=timeout) as r:
        if r.status_code != 200:
            print('Request unsuccessful, code', r.status_code)
            return False
        print('Request successful, code', r.status_code)
        with open(save_path, 'wb') as fd:
            for chunk in r.iter_content(chunk_size=chunk_size):
                fd.write(chunk)
    return True

In [None]:
# TCIA REST getImage queries, one per DICOM series: (CT, T1w MRI).
ct_url, mri_url = (
    'https://services.cancerimagingarchive.net/services/v4/TCIA/query/getImage?SeriesInstanceUID=1.3.6.1.4.1.14519.5.2.1.7009.2402.882136884134365981035682566340',
    'https://services.cancerimagingarchive.net/services/v4/TCIA/query/getImage?SeriesInstanceUID=1.3.6.1.4.1.14519.5.2.1.7009.2402.327122726537459238654047774771',
)

In [None]:
import os

# Local destinations for the zipped DICOM series, inside the download folder.
# (The original called os.path.join with a single argument, which is a no-op,
# and ignored downloadPath.)
ct_zip = os.path.join(downloadPath, 'ct_dicom.zip')
mri_zip = os.path.join(downloadPath, 'mri_dicom.zip')

In [None]:
# Fetch both zipped DICOM series from TCIA, CT first and then MRI.
for series_url, zip_path in ((ct_url, ct_zip), (mri_url, mri_zip)):
    download_url(url=series_url, save_path=zip_path)

In [None]:
# -o overwrites existing files without an interactive prompt, which would
# otherwise block a re-run. -q suppresses the per-file listing.
! unzip -o -q ./data/ct_dicom.zip -d ./data/ct_dicom/
! unzip -o -q ./data/mri_dicom.zip -d ./data/mri_dicom/

In [None]:
# %pip (unlike !pip) installs into the running kernel's environment.
%pip install dicom2nifti

In [None]:
import dicom2nifti

In [None]:
# Convert each unzipped DICOM series into one reoriented NIfTI file.
conversions = (
    ('./data/ct_dicom/', './data/ct.nii'),
    ('./data/mri_dicom/', './data/mri.nii'),
)
for dicom_dir, nifti_path in conversions:
    dicom2nifti.dicom_series_to_nifti(dicom_dir, nifti_path, reorient_nifti=True)

## Analysis (**OPEN TO VISUALISE**)

#### Import dependencies

In [None]:
%matplotlib inline
import torch
import numpy as np
from nitorch.plot import show_slices
from nitorch.spatial import voxel_size
from nitorch.io import map
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from unires.struct import settings
from unires.run import preproc

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device_type = 'cuda' if use_cuda else 'cpu'
device = torch.device(device_type)

if device_type != 'cuda':
    print('CPU')
else:
    gpu_name = torch.cuda.get_device_name(0)
    print('GPU: ' + gpu_name + ', CUDA: ' + str(torch.cuda.is_available()))

#### Load volumes

The image volume and its affine matrix are loaded as PyTorch tensors.

In [None]:
# CT
nii_ct = map('./data/ct.nii') # load file object
dat_in_ct = nii_ct.fdata(dtype=torch.float32, device=device) # load image data
mat_in_ct = nii_ct.affine.to(device).type(torch.float64) # load affine matrix

In [None]:
# T1w MRI
nii_mri = map('./data/mri.nii') # load file object
dat_in_mri = nii_mri.fdata(dtype=torch.float32, device=device) # load image data
mat_in_mri = nii_mri.affine.to(device).type(torch.float64) # load affine matrix

#### Preview image volumes (**OPEN TO VIEW**)

In [None]:
fig_ct, ax_ct = plt.subplots(1, 3, figsize=(15,5))
show_slices(dat_in_ct, fig_ax=[fig_ct, ax_ct], title='Raw CT', fig_num=1)

# Colour bar scaled to the middle slice along the first axis. A ScalarMappable is
# built directly: calling plt.imshow here would paint the slice into the colour-bar
# axes, because append_axes() makes the new axes the current one.
# Only that slice is copied to the CPU, not the whole volume.
mid_slice_ct = dat_in_ct[dat_in_ct.shape[0] // 2].cpu().numpy()
norm_ct = plt.Normalize(vmin=np.nanmin(mid_slice_ct), vmax=np.nanmax(mid_slice_ct))
cax = make_axes_locatable(ax_ct[2]).append_axes('right', size='5%', pad=0.05)
fig_ct.colorbar(plt.cm.ScalarMappable(norm=norm_ct, cmap='gray'), cax=cax, orientation='vertical');

In [None]:
# Clip CT intensities to [0, 150] (presumably a soft-tissue HU window) for contrast.
dat_in_ct_clip = torch.clamp(dat_in_ct, min=0, max=150)

fig_ct_clip, ax_ct_clip = plt.subplots(1, 3, figsize=(15,5))
show_slices(dat_in_ct_clip, fig_ax=[fig_ct_clip, ax_ct_clip], title='Intensity-clipped CT', fig_num=2)

# Colour bar for the middle clipped slice (first axis), attached to THIS figure.
# The original attached it to fig_ct. A ScalarMappable avoids plt.imshow painting
# the slice into the colour-bar axes.
mid_slice_clip = dat_in_ct_clip[dat_in_ct_clip.shape[0] // 2].cpu().numpy()
norm_clip = plt.Normalize(vmin=np.nanmin(mid_slice_clip), vmax=np.nanmax(mid_slice_clip))
cax = make_axes_locatable(ax_ct_clip[2]).append_axes('right', size='5%', pad=0.05)
fig_ct_clip.colorbar(plt.cm.ScalarMappable(norm=norm_clip, cmap='gray'), cax=cax, orientation='vertical');

In [None]:
fig_mri, ax_mri = plt.subplots(1, 3, figsize=(15,5))
show_slices(dat_in_mri, fig_ax=[fig_mri, ax_mri], title='Raw (T1w) MRI', fig_num=3)
# Let every panel stretch to fill its axes.
for panel in ax_mri:
    panel.set_aspect('auto')

# Colour bar for the middle slice (first axis), attached to THIS figure.
# The original attached it to fig_ct. A ScalarMappable avoids plt.imshow painting
# the slice into the colour-bar axes.
mid_slice_mri = dat_in_mri[dat_in_mri.shape[0] // 2].cpu().numpy()
norm_mri = plt.Normalize(vmin=np.nanmin(mid_slice_mri), vmax=np.nanmax(mid_slice_mri))
cax = make_axes_locatable(ax_mri[2]).append_axes('right', size='5%', pad=0.05)
fig_mri.colorbar(plt.cm.ScalarMappable(norm=norm_mri, cmap='gray'), cax=cax, orientation='vertical');

#### Perform basic upscaling with traditional method (trilinear interpolation)

*Run UniRes with the maximum number of iterations set to 0, so no super-resolution optimisation is performed.*

In [None]:
# Settings for the baseline run: a fresh UniRes settings object with the
# iteration count forced to 0 (see the markdown above for the intent).
s = settings()

s.vx = 1 # reconstruction voxel size (1mm isotropic)
s.plot_conv = True # produce plot of convergence
s.max_iter = 0 # zero iterations; intended to yield plain (trilinear) interpolation

In [None]:
# Baseline CT resample using the 0-iteration settings `s`. The third output is
# discarded (presumably the output file path, cf. pth_sr_* later).
(dat_r_ct, mat_r_ct, _) = preproc('./data/ct.nii', sett=s)

In [None]:
# Baseline MRI resample using the 0-iteration settings `s`. The third output is
# discarded (presumably the output file path, cf. pth_sr_* later).
(dat_r_mri, mat_r_mri, _) = preproc('./data/mri.nii', sett=s)

#### Perform super-resolution with UniRes 'preproc' function

*Restore the default settings (maximum number of iterations = 512).*

In [None]:
# Fresh settings object for the super-resolution run. max_iter is not set here,
# so it keeps the library default (512 per the markdown above; confirm against unires).
s = settings()

s.vx = 1 # reconstruction voxel size (1mm isotropic)
s.plot_conv = True # produce plot of convergence

In [None]:
# Full UniRes super-resolution of the CT volume using the default-iteration settings `s`.
(dat_sr_ct, mat_sr_ct, pth_sr_ct) = preproc('./data/ct.nii', sett=s)

In [None]:
# Full UniRes super-resolution of the MRI volume using the default-iteration settings `s`.
(dat_sr_mri, mat_sr_mri, pth_sr_mri) = preproc('./data/mri.nii', sett=s)

### Visualise output volumes (**OPEN TO VIEW**)

In [None]:
# 3x3 CT comparison: one row per method (raw input, 0-iteration resample,
# UniRes super-resolution), three show_slices panels per row.
fig_all_ct, ax_all_ct = plt.subplots(3, 3, figsize=(15,15))
rows_ct = [
    ('Raw', dat_in_ct, mat_in_ct),
    ('Linear resample', dat_r_ct, mat_r_ct),
    ('UniRes', dat_sr_ct, mat_sr_ct),
]
for row, (_label, dat, _mat) in enumerate(rows_ct):
    show_slices(dat, fig_ax=[fig_all_ct, ax_all_ct[row, :]])

fig_all_ct.suptitle('CT', fontsize=24)

# Row labels with the voxel size rounded to the nearest 0.5 mm.
for row, (label, _dat, mat) in enumerate(rows_ct):
    vx = np.round(voxel_size(mat).cpu().numpy() * 2) / 2
    fig_all_ct.text(-0.05, 0.75 - 0.25 * row,
                    '{}\n({}x{}x{}mm)'.format(label, vx[0], vx[1], vx[2]), fontsize=20)

# Shared colour bar scaled to the middle raw slice (first axis). The original used
# an undefined `dat_in_ct_cpu` (NameError) and the wrong figure (fig_ct). A
# ScalarMappable avoids plt.imshow painting into the colour-bar axes.
mid_slice = dat_in_ct[dat_in_ct.shape[0] // 2].cpu().numpy()
fig_all_ct.subplots_adjust(right=0.85)
cax = fig_all_ct.add_axes([0.88, 0.125, 0.03, 0.755])
norm = plt.Normalize(vmin=np.nanmin(mid_slice), vmax=np.nanmax(mid_slice))
fig_all_ct.colorbar(plt.cm.ScalarMappable(norm=norm, cmap='gray'), cax=cax, orientation='vertical');

In [None]:
# 3x3 CT comparison with every volume clipped to [0, 150] (presumably a
# soft-tissue HU window): raw input, 0-iteration resample, UniRes output.
fig_all_clip_ct, ax_all_clip_ct = plt.subplots(3, 3, figsize=(15,15))
rows_clip_ct = [
    ('Raw', dat_in_ct, mat_in_ct),
    ('Linear resample', dat_r_ct, mat_r_ct),
    ('UniRes', dat_sr_ct, mat_sr_ct),
]
for row, (_label, dat, _mat) in enumerate(rows_clip_ct):
    # Clamp inside the loop so only one clipped copy is alive at a time.
    show_slices(torch.clamp(dat, min=0, max=150), fig_ax=[fig_all_clip_ct, ax_all_clip_ct[row, :]])

fig_all_clip_ct.suptitle('CT (clipped)', fontsize=24)

# Row labels with the voxel size rounded to the nearest 0.5 mm.
for row, (label, _dat, mat) in enumerate(rows_clip_ct):
    vx = np.round(voxel_size(mat).cpu().numpy() * 2) / 2
    fig_all_clip_ct.text(-0.05, 0.75 - 0.25 * row,
                         '{}\n({}x{}x{}mm)'.format(label, vx[0], vx[1], vx[2]), fontsize=20)

# Shared colour bar scaled to the clipped middle raw slice (first axis). The
# original referenced an undefined `dat_in_ct_cpu`. A ScalarMappable avoids
# plt.imshow painting into the colour-bar axes.
mid_slice = torch.clamp(dat_in_ct[dat_in_ct.shape[0] // 2], min=0, max=150).cpu().numpy()
fig_all_clip_ct.subplots_adjust(right=0.85)
cax = fig_all_clip_ct.add_axes([0.88, 0.125, 0.03, 0.755])
norm = plt.Normalize(vmin=np.nanmin(mid_slice), vmax=np.nanmax(mid_slice))
fig_all_clip_ct.colorbar(plt.cm.ScalarMappable(norm=norm, cmap='gray'), cax=cax, orientation='vertical');

In [None]:
# 3x3 T1w MRI comparison: one row per method (raw input, 0-iteration
# resample, UniRes super-resolution), three show_slices panels per row.
fig_all_mri, ax_all_mri = plt.subplots(3, 3, figsize=(15,15))
rows_mri = [
    ('Raw', dat_in_mri, mat_in_mri),
    ('Linear resample', dat_r_mri, mat_r_mri),
    ('UniRes', dat_sr_mri, mat_sr_mri),
]
for row, (_label, dat, _mat) in enumerate(rows_mri):
    show_slices(dat, fig_ax=[fig_all_mri, ax_all_mri[row, :]])

fig_all_mri.suptitle('MRI (T1w)', fontsize=24)

# Row labels with the voxel size rounded to the nearest 0.5 mm.
for row, (label, _dat, mat) in enumerate(rows_mri):
    vx = np.round(voxel_size(mat).cpu().numpy() * 2) / 2
    fig_all_mri.text(-0.05, 0.75 - 0.25 * row,
                     '{}\n({}x{}x{}mm)'.format(label, vx[0], vx[1], vx[2]), fontsize=20)

# Shared colour bar scaled to the middle raw slice (first axis). The original
# referenced an undefined `dat_in_mri_cpu` (NameError). A ScalarMappable avoids
# plt.imshow painting into the colour-bar axes.
mid_slice = dat_in_mri[dat_in_mri.shape[0] // 2].cpu().numpy()
fig_all_mri.subplots_adjust(right=0.85)
cax = fig_all_mri.add_axes([0.88, 0.125, 0.03, 0.755])
norm = plt.Normalize(vmin=np.nanmin(mid_slice), vmax=np.nanmax(mid_slice))
fig_all_mri.colorbar(plt.cm.ScalarMappable(norm=norm, cmap='gray'), cax=cax, orientation='vertical');