# 3D Resnet for NIfTI images

Showcase a simple 3D resnet—built in pytorch and fastai—for MR image synthesis, which is the task of taking a specific MR image contrast and making it look like another MR image contrast (e.g., T1-weighted to FLAIR).

## Setup notebook

In [1]:
!whereis gcc 
!gcc -v

gcc: /usr/bin/gcc /usr/lib/gcc
Using built-in specs.
COLLECT_GCC=gcc
COLLECT_LTO_WRAPPER=/usr/lib/gcc/x86_64-linux-gnu/6/lto-wrapper
Target: x86_64-linux-gnu
Configured with: ../src/configure -v --with-pkgversion='Debian 6.3.0-18+deb9u1' --with-bugurl=file:///usr/share/doc/gcc-6/README.Bugs --enable-languages=c,ada,c++,java,go,d,fortran,objc,obj-c++ --prefix=/usr --program-suffix=-6 --program-prefix=x86_64-linux-gnu- --enable-shared --enable-linker-build-id --libexecdir=/usr/lib --without-included-gettext --enable-threads=posix --libdir=/usr/lib --enable-nls --with-sysroot=/ --enable-clocale=gnu --enable-libstdcxx-debug --enable-libstdcxx-time=yes --with-default-libstdcxx-abi=new --enable-gnu-unique-object --disable-vtable-verify --enable-libmpx --enable-plugin --enable-default-pie --with-system-zlib --disable-browser-plugin --enable-java-awt=gtk --enable-gtk-cairo --with-java-home=/usr/lib/jvm/java-1.5.0-gcj-6-amd64/jre --enable-java-home --with-jvm-root-dir=/usr/lib/jvm/java-1.5.0-gc

In [2]:
# !pip uninstall -y niftynet numpy
!pip install bottleneck
!pip install fastai torch torchvision numpy



In [3]:
from pathlib import PosixPath, Path
import os
import sys

import fastai.vision as faiv
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision

 Support in-notebook plotting

In [4]:
%matplotlib inline

Report versions

In [5]:
print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'fastai version: {faiv.__version__}')
print(f'pytorch version: {torch.__version__}')
print(f'torchvision version: {torchvision.__version__}')

numpy version: 1.16.3
matplotlib version: 3.0.3
fastai version: 1.0.52
pytorch version: 1.1.0
torchvision version: 0.3.0


In [6]:
pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))

python version: 3.6.8


Reload packages where content for package development

In [7]:
%load_ext autoreload
%autoreload 2

Check GPU

In [8]:
!nvidia-smi

/bin/sh: 1: nvidia-smi: not found


## Define test images

In [9]:
data_dir = Path('data/Utrecht')
data_dirs = Path().glob('data/Utrecht/**/pre'); data_dirs
# data_dir = PosixPath('/iacl/pg19/jacobr/zs/blog/')

<generator object Path.glob at 0x7fa3386b2150>

In [13]:
!ls {data_dir/'**/**/*.nii*'}

data/Utrecht/0/orig/3DT1.nii.gz        data/Utrecht/31/orig/3DT1.nii.gz
data/Utrecht/0/orig/3DT1_mask.nii.gz   data/Utrecht/31/orig/3DT1_mask.nii.gz
data/Utrecht/0/orig/FLAIR.nii.gz       data/Utrecht/31/orig/FLAIR.nii.gz
data/Utrecht/0/orig/T1.nii.gz	       data/Utrecht/31/orig/T1.nii.gz
data/Utrecht/0/pre/3DT1.nii.gz	       data/Utrecht/31/pre/3DT1.nii.gz
data/Utrecht/0/pre/FLAIR.nii.gz        data/Utrecht/31/pre/FLAIR.nii.gz
data/Utrecht/0/pre/T1.nii.gz	       data/Utrecht/31/pre/T1.nii.gz
data/Utrecht/11/orig/3DT1.nii.gz       data/Utrecht/33/orig/3DT1.nii.gz
data/Utrecht/11/orig/3DT1_mask.nii.gz  data/Utrecht/33/orig/3DT1_mask.nii.gz
data/Utrecht/11/orig/FLAIR.nii.gz      data/Utrecht/33/orig/FLAIR.nii.gz
data/Utrecht/11/orig/T1.nii.gz	       data/Utrecht/33/orig/T1.nii.gz
data/Utrecht/11/pre/3DT1.nii.gz        data/Utrecht/33/pre/3DT1.nii.gz
data/Utrecht/11/pre/FLAIR.nii.gz       data/Utrecht/33/pre/FLAIR.nii.gz
data/Utrecht/11/pre/T1.nii.gz	       data/Utrecht/33/pre/T1.nii.gz
d

In [14]:
# help(faiv)

## Test 3d fastai transforms

In [15]:
dev = "cuda" if torch.cuda.is_available() else "cpu"
print(dev)
device = torch.device(dev)

cpu


In [16]:
import nibabel as nib


def open_nii(fn:str) -> faiv.Image:
    """ Return fastai `Image` object created from NIfTI image in file `fn`."""
    x = nib.load(str(fn)).get_data().astype(np.int32)
    y = torch.Tensor(x)
    return faiv.Image(y)

    
class NiftiItemList(faiv.ImageList):
    """ custom item list for nifti files """
    def open(self, fn:faiv.PathOrStr)->faiv.Image: return open_nii(fn)

    
class NiftiNiftiList(NiftiItemList):
    """ item list suitable for synthesis tasks """
    _label_cls = NiftiItemList

In [17]:
# help(NiftiItemList)

In [18]:
from functools import singledispatch

@faiv.TfmPixel
@singledispatch
def crop(x, pct, axis:int) -> torch.Tensor:
    """" crop a 3d image along an axis """
    s = x.shape
    i0, i1 = int(s[axis]*pct[0]), int(s[axis]*pct[1])
    return x[np.newaxis,i0:i1,:,:].contiguous() if axis == 0 else \
           x[np.newaxis,:,i0:i1,:].contiguous() if axis == 1 else \
           x[np.newaxis,:,:,i0:i1].contiguous()

tfms = [crop(pct=(0.20,0.80),axis=2)]

In [26]:
def get_y_fn(x):
    par = os.path.dirname(os.path.dirname(x))
#     print(par)
#     a = '{}/../..'.format(os.path.dirname(x))
#     print(os.path.dirname(a))
#     parent = 'train' if 'train' in str(x) else 'valid'
    fn = '{}/wmh.nii.gz'.format(par)
#     print(fn)
    return fn

In [27]:
def filterMe(x):
    shouldFilter = 'pre' in str(x) and '/T1.nii' in str(x)
#     print('{} - {}'.format(x, shouldFilter))
    return shouldFilter


In [28]:
# x =NiftiNiftiList(glob.glob('data/Utrecht/**/pre'));
# print(x)
a= NiftiNiftiList.from_folder(data_dir, extensions=('.gz')); a

NiftiNiftiList (160 items)
Image (240, 240, 48),Image (256, 256, 192),Image (256, 256, 192),Image (240, 240, 48),Image (240, 240, 48)
Path: data/Utrecht

In [29]:


b = a.filter_by_func(filterMe); b

NiftiNiftiList (20 items)
Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48)
Path: data/Utrecht

In [30]:
b = a.split_by_rand_pct(.1, 42);b 

ItemLists;

Train: NiftiNiftiList (18 items)
Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48)
Path: data/Utrecht;

Valid: NiftiNiftiList (2 items)
Image (240, 240, 48),Image (240, 240, 48)
Path: data/Utrecht;

Test: None

In [31]:
 c    = b.label_from_func(get_y_fn); c


LabelLists;

Train: LabelList (18 items)
x: NiftiNiftiList
Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48)
y: NiftiItemList
Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48),Image (240, 240, 48)
Path: data/Utrecht;

Valid: LabelList (2 items)
x: NiftiNiftiList
Image (240, 240, 48),Image (240, 240, 48)
y: NiftiItemList
Image (240, 240, 48),Image (240, 240, 48)
Path: data/Utrecht;

Test: None

In [32]:
idb = (NiftiNiftiList.from_folder(data_dir, extensions=('.gz'))
       .filter_by_func(filterMe)
       .split_by_rand_pct(0.4, 42)
       .label_from_func(get_y_fn)
       .transform((tfms,tfms), tfm_y=True)
       .databunch(bs=2))

idb

ImageDataBunch;

Train: LabelList (12 items)
x: NiftiNiftiList
Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29)
y: NiftiItemList
Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29)
Path: data/Utrecht;

Valid: LabelList (8 items)
x: NiftiNiftiList
Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29)
y: NiftiItemList
Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29),Image (1, 240, 240, 29)
Path: data/Utrecht;

Test: None

In [33]:
spectral_norm = nn.utils.spectral_norm
weight_norm = nn.utils.weight_norm

In [34]:
def conv3d(ni:int, nf:int, ks:int=3, stride:int=1, pad:int=1, norm='batch'):
    bias = not norm == 'batch'
    conv = faiv.init_default(nn.Conv3d(ni,nf,ks,stride,pad,bias=bias))
    conv = spectral_norm(conv) if norm == 'spectral' else \
           weight_norm(conv) if norm == 'weight' else conv
    layers = [conv]
    layers += [nn.ReLU(inplace=True)]  # use inplace due to memory constraints
    layers += [nn.BatchNorm3d(nf)] if norm == 'batch' else []
    return nn.Sequential(*layers)

def res3d_block(ni, nf, ks=3, norm='batch', dense=False):
    """ 3d Resnet block of `nf` features """
    return faiv.SequentialEx(conv3d(ni, nf, ks, pad=ks//2, norm=norm),
                             conv3d(nf, nf, ks, pad=ks//2, norm=norm),
                             faiv.MergeLayer(dense))

In [35]:
norm = 'batch'
layers = ([res3d_block(1,15,7,norm=norm,dense=True)] +
          [res3d_block(16,16,norm=norm) for _ in range(4)] +
          [conv3d(16,1,ks=1,pad=0,norm=None)])
model = nn.Sequential(*layers)

In [36]:
loss = nn.MSELoss()
learner = faiv.Learner(idb, model, loss_func=loss)

In [40]:
learner.lr_find(num_it=2)
learner.recorder.plot()

LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.


Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 230, in _feed
    close()
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor


RuntimeError: DataLoader worker (pid(s) 2104) exited unexpectedly

In [38]:
cbs = [faiv.callbacks.CSVLogger(learner, 'history')]

learner.fit_one_cycle(100, 1e-2, callbacks=cbs)

epoch,train_loss,valid_loss,time


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3296, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-38-7a37df9b5ade>", line 3, in <module>
    learner.fit_one_cycle(100, 1e-2, callbacks=cbs)
  File "/opt/conda/lib/python3.6/site-packages/fastai/train.py", line 22, in fit_one_cycle
    learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
  File "/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py", line 199, in fit
    fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
  File "/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py", line 99, in fit
    for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
  File "/opt/conda/lib/python3.6/site-packages/fastprogress/fastprogress.py", line 72, in __iter__
    for i,o in enumerate(self._gen):
  File "/opt/conda/lib/python3.6/site-packages/fastai/basic_data.py", line 75, in __it

RuntimeError: Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/conda/lib/python3.6/site-packages/fastai/torch_core.py", line 127, in data_collate
    return torch.utils.data.dataloader.default_collate(to_data(batch))
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 68, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 68, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 41, in default_collate
    storage = batch[0].storage()._new_shared(numel)
  File "/opt/conda/lib/python3.6/site-packages/torch/storage.py", line 126, in _new_shared
    return cls._new_using_fd(size)
RuntimeError: unable to write to file </torch_2050_396257799>


In [None]:
learner.save('test')

In [None]:
import nibabel as nib

In [None]:
obj = nib.load(str(data_dir/'t1/test/KKI2009-11-MPRAGE_zscore.nii.gz'))
test = torch.Tensor(obj.get_data()).to(device)

In [None]:
res = learner.model.forward(test[None,None,...]).cpu().detach().numpy()

In [None]:
plt.figure(figsize=(8,8));
plt.imshow(np.rot90(np.squeeze(res)[:,:,150],3),cmap='gray');
plt.axis('off');

In [None]:
nib.Nifti1Image(res,obj.affine,obj.header).to_filename('test.nii.gz')

In [None]:
!ls {data_dir/'flair/test'}

In [None]:
flair = nib.load(str(data_dir/'flair/test/KKI2009-11-FLAIR_reg_zscore.nii.gz'))

In [None]:
i = 150
def imp(data,ax,i,v,t=''): ax.imshow(np.rot90(data[:,:,i],3),cmap='gray',vmin=v[0],vmax=v[1]); ax.axis('off'); ax.set_title(t);
fig, (ax1,ax2,ax3) = plt.subplots(1,3,figsize=(16,12))
imp(obj.get_data(),ax1,i,(None,None),'T1');imp(flair.get_data(),ax2,i,(0,3.5),'FLAIR');imp(res.squeeze(),ax3,i,(0,3.5),'Syn');
plt.savefig('~/Downloads/blg.png',dpi=200)