# Porting a NiftyNet model to PyTorch

Introduction, motivation, related works

## HighRes3DNet

HighRes3DNet is a residual convolutional neural network designed to have a large receptive field and preserve a high resolution using a relatively small number of parameters. It was presented in 2017 by Li et al. at IPMI: [*On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task*](https://arxiv.org/abs/1707.01992).

<img src="images/network.png" alt="drawing" width="600"/>

The authors used [NiftyNet](https://niftynet.io/) to train a model based on this architecture to perform [brain parcellation](https://ieeexplore.ieee.org/document/7086081?arnumber=7086081) from $T_1$-weighted MR images using the [ADNI dataset](http://adni.loni.usc.edu/).

<img src="images/li-000.png" alt="drawing" width="150"/> <img src="images/li-001.png" alt="drawing" width="150"/>

The code of the architecture is on [NiftyNet's GitHub repository](https://github.com/NifTK/NiftyNet/blob/dev/niftynet/network/highres3dnet.py) and the weights and configuration file have been uploaded to the [Model Zoo](https://github.com/NifTK/NiftyNetModelZoo/tree/5-reorganising-with-lfs/highres3dnet_brain_parcellation).

After reading the paper and the code, one can easily [implement the architecture using PyTorch](https://github.com/fepegar/highresnet).

In this notebook we will:


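1.   Extract the parameters from a TensorFlow checkpoint
2.   Map them to the parameter names and shapes used by the PyTorch implementation
3.   Load them into the PyTorch model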



## Setup

### Install and import libraries

In [1]:
%%capture
!pip install -r requirements.txt

In [2]:
from pathlib import Path

import torch
from torchsummary import summary

import numpy as np
import pandas as pd
import nibabel as nib
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.color import label2rgb

import tf2pt
import utils

from highresnet import HighRes3DNet

### Download NiftyNet data

We can use NiftyNet's [`net_download`](https://niftynet.readthedocs.io/en/dev/model_zoo.html#net-download) to get all we need from the [Model Zoo](https://github.com/NifTK/NiftyNetModelZoo/tree/5-reorganising-with-lfs):

In [3]:
%%capture
#!net_download highres3dnet_brain_parcellation_model_zoo

In [4]:
niftynet_dir = Path('~/niftynet').expanduser()
utils.list_files(niftynet_dir)

niftynet/
    extensions/
        __init__.py
        network/
            __init__.py
        highres3dnet_brain_parcellation/
            __init__.py
            highres3dnet_config_eval.ini
    models/
        highres3dnet_brain_parcellation/
            databrain_std_hist_models_otsu.txt
            models/
                model.ckpt-33000.meta
                model.ckpt-33000.data-00000-of-00001
                model.ckpt-33000.index
    data/
        OASIS/
            OAS1_0145_MR2_mpr_n4_anon_sbj_111.nii.gz
            license


There are three directories under `~/niftynet`:
1. `extensions` is a Python package and contains the configuration file `highres3dnet_config_eval.ini`
2. `models` contains the landmarks for histogram standardisation and the parameters
3. `data` contains an MRI that can be used to test the model

## TensorFlow world 
<img src="https://static.nvidiagrid.net/ngc/containers/tensorflow.png" alt="drawing" width="50"/>

In [8]:
output_csv_path = 'state_dict_tf.csv'
models_dir = niftynet_dir / 'models'
checkpoint_name = 'model.ckpt-33000'
checkpoint_path = models_dir / 'highres3dnet_brain_parcellation' / 'models' / checkpoint_name

state_dict_tf_path = 'state_dict_tf.pth'
state_dict_pt_path = 'state_dict_pt.pth'

pd.set_option('display.max_colwidth', None)  # do not truncate strings when displaying data frames (-1 is deprecated)
pd.set_option('display.max_rows', None)  # show all rows

filter_variables = True

Let's see what variables are stored in the checkpoint.

Some of them are filtered out by `tf2pt.checkpoint_to_state_dict()` for clarity:
* Variables used by the Adam optimizer during training
* Variables with no shape, which are not useful here
* Variables whose name contains `biased` or `ExponentialMovingAverage`. Results obtained with these variables turned out to differ from the ones produced by NiftyNet

We'll store the TensorFlow variable names in a data frame, to list them in this notebook, and their values in a Python dictionary, to retrieve them later.
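The filtering rules listed above can be sketched in a few lines of plain Python. This is only an illustration: `keep_variable`, `EXCLUDED_SUBSTRINGS` and the sample variable names are hypothetical, and the actual logic lives inside `tf2pt.checkpoint_to_state_dict()`, which may differ.

```python
# Hypothetical filter; the real rules are implemented in tf2pt
EXCLUDED_SUBSTRINGS = ('Adam', 'biased', 'ExponentialMovingAverage')


def keep_variable(name, shape):
    """Return True if a checkpoint variable should be kept."""
    if not shape:  # variables with no shape have an empty shape tuple
        return False
    return not any(substring in name for substring in EXCLUDED_SUBSTRINGS)


# Made-up (name, shape) pairs written in the style of the checkpoint
variables = [
    ('conv_0_bn_relu/conv_/w', (3, 3, 3, 1, 16)),
    ('conv_0_bn_relu/conv_/w/Adam', (3, 3, 3, 1, 16)),
    ('beta1_power', ()),
    ('conv_0_bn_relu/bn_/moving_mean/biased', (16,)),
]
kept = [name for name, shape in variables if keep_variable(name, shape)]
print(kept)
```

Only the first variable survives: the others are an optimizer slot, a variable with no shape and a `biased` variable.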

In [6]:
data_frame_tf, state_dict_tf = tf2pt.checkpoint_to_state_dict(checkpoint_path)
data_frame_tf

Restoring session...
INFO:tensorflow:Restoring parameters from /home/fernando/niftynet/models/highres3dnet_brain_parcellation/models/model.ckpt-33000


100%|██████████| 380/380 [00:00<00:00, 422.51it/s]


|   | name | shape |
|---|------|-------|
| 0 | conv_0_bn_relu/bn_/beta | 16 |
| 1 | conv_0_bn_relu/bn_/gamma | 16 |
| 2 | conv_0_bn_relu/bn_/moving_mean | 16 |
| 3 | conv_0_bn_relu/bn_/moving_variance | 16 |
| 4 | conv_0_bn_relu/conv_/w | 3, 3, 3, 1, 16 |
| 5 | conv_1_bn_relu/bn_/beta | 80 |
| 6 | conv_1_bn_relu/bn_/gamma | 80 |
| 7 | conv_1_bn_relu/bn_/moving_mean | 80 |
| 8 | conv_1_bn_relu/bn_/moving_variance | 80 |
| 9 | conv_1_bn_relu/conv_/w | 1, 1, 1, 64, 80 |


The layer names and parameter shapes are broadly consistent with the figure in the paper, but there's an additional 1×1×1 convolution with 80 output channels. It's also in the [code](https://github.com/NifTK/NiftyNet/blob/1832a516c909b67d0d9618acbd04a7642c12efca/niftynet/network/highres3dnet.py#L93). This seems to be the model with dropout used in the study to compute the model's uncertainty, so our implementation of the architecture should include this layer as well.
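To read the convolution shapes in the table, recall that TensorFlow stores 3D kernels as (depth, height, width, input channels, output channels). The small helper below, `describe_conv`, is a hypothetical name introduced only for this illustration; it splits a shape into its components and counts the weights.

```python
from math import prod


def describe_conv(shape_tf):
    """Split a TensorFlow 3D kernel shape into its components."""
    *kernel, in_channels, out_channels = shape_tf
    return {
        'kernel': tuple(kernel),
        'in_channels': in_channels,
        'out_channels': out_channels,
        'num_weights': prod(shape_tf),
    }


# Shapes copied from the TensorFlow variables listed above
first = describe_conv((3, 3, 3, 1, 16))
extra = describe_conv((1, 1, 1, 64, 80))
print(first)
print(extra)
```

The second shape corresponds to the additional layer: a 1×1×1 kernel that takes 64 channels and produces 80.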

## PyTorch world
<img src="https://s3-ap-south-1.amazonaws.com/av-blog-media/wp-content/uploads/2019/01/pytorch-logo.png" alt="drawing" width="100"/>


In [14]:
num_input_modalities = 1
num_classes = 160
model = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True)

Let's see which names PyTorch creates:

In [15]:
state_dict_pt = model.state_dict()
rows = []
for name, parameters in state_dict_pt.items():
    if 'num_batches_tracked' in name:  # for clarity
        continue
    shape = ', '.join(str(n) for n in parameters.shape)
    row = {'name': name, 'shape': shape}
    rows.append(row)
df_pt = pd.DataFrame(rows)  # one row per parameter, built from the list of dicts
df_pt

|   | name | shape |
|---|------|-------|
| 0 | block.0.convolutional_block.1.weight | 16, 1, 3, 3, 3 |
| 1 | block.0.convolutional_block.2.weight | 16 |
| 2 | block.0.convolutional_block.2.bias | 16 |
| 3 | block.0.convolutional_block.2.running_mean | 16 |
| 4 | block.0.convolutional_block.2.running_var | 16 |
| 5 | block.1.dilation_block.0.residual_block.0.convolutional_block.0.weight | 16 |
| 6 | block.1.dilation_block.0.residual_block.0.convolutional_block.0.bias | 16 |
| 7 | block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_mean | 16 |
| 8 | block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_var | 16 |
| 9 | block.1.dilation_block.0.residual_block.0.convolutional_block.3.weight | 16, 16, 3, 3, 3 |


The names and shapes look good and there are 104 lines in both lists, so we should be able to create a mapping between the TensorFlow and the PyTorch variables. 
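Comparing the two tables suggests what such a mapping has to do: rename the parameter suffixes and reorder the axes of the convolution kernels. The sketch below works on names and shapes only; `SUFFIXES`, `CONV_AXES`, `translate_suffix` and `permute_shape` are hypothetical names, and the real conversion is done by `tf2pt.tf2pt()`, which also has to translate the layer prefixes.

```python
# TensorFlow suffix -> PyTorch suffix, as read from the two tables
SUFFIXES = {
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var',
    'w': 'weight',
}

# TensorFlow kernels are (D, H, W, in, out); PyTorch ones are (out, in, D, H, W)
CONV_AXES = (4, 3, 0, 1, 2)


def translate_suffix(name_tf):
    """Translate the last component of a TensorFlow variable name."""
    return SUFFIXES[name_tf.split('/')[-1]]


def permute_shape(shape_tf):
    """Reorder a 5D kernel shape; leave 1D batch norm shapes untouched."""
    if len(shape_tf) != 5:
        return tuple(shape_tf)
    return tuple(shape_tf[axis] for axis in CONV_AXES)


print(translate_suffix('conv_0_bn_relu/bn_/beta'))
print(permute_shape((3, 3, 3, 1, 16)))
```

With actual tensors, the same axis order could be passed to `torch.Tensor.permute` or `numpy.transpose`.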

In [17]:
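for name_tf, tensor_tf in tqdm(list(state_dict_tf.items())):
    shape_tf = tuple(tensor_tf.shape)
    print(f'{str(shape_tf):18}', name_tf)

    name_pt, tensor_pt = tf2pt.tf2pt(name_tf, tensor_tf)

    # Check the key before indexing, so that a bad mapping gives a clear error
    if name_pt not in state_dict_pt:
        raise KeyError(f'"{name_pt}" is not in the PyTorch state dict')
    shape_pt = tuple(state_dict_pt[name_pt].shape)
    print(f'{str(shape_pt):18}', name_pt)

    # The converted tensor must have exactly the shape that PyTorch expects
    if tuple(tensor_pt.shape) != shape_pt:
        raise ValueError(f'Shape mismatch for "{name_pt}": {tuple(tensor_pt.shape)} vs {shape_pt}')

    state_dict_pt[name_pt] = tensor_pt
    print()
print('Saving state dictionary to', state_dict_pt_path)
torch.save(state_dict_pt, state_dict_pt_path)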

100%|██████████| 105/105 [00:00<00:00, 3104.33it/s]

(16,)              conv_0_bn_relu/bn_/beta
(16,)              block.0.convolutional_block.2.bias

(16,)              conv_0_bn_relu/bn_/gamma
(16,)              block.0.convolutional_block.2.weight

(16,)              conv_0_bn_relu/bn_/moving_mean
(16,)              block.0.convolutional_block.2.running_mean

(16,)              conv_0_bn_relu/bn_/moving_variance
(16,)              block.0.convolutional_block.2.running_var

(3, 3, 3, 1, 16)   conv_0_bn_relu/conv_/w
(16, 1, 3, 3, 3)   block.0.convolutional_block.1.weight

(80,)              conv_1_bn_relu/bn_/beta
(80,)              block.4.convolutional_block.1.bias

(80,)              conv_1_bn_relu/bn_/gamma
(80,)              block.4.convolutional_block.1.weight

(80,)              conv_1_bn_relu/bn_/moving_mean
(80,)              block.4.convolutional_block.1.running_mean

(80,)              conv_1_bn_relu/bn_/moving_variance
(80,)              block.4.convolutional_block.1.running_var

(1, 1, 1, 64, 80)  conv_1_bn_relu/conv_/w
(80




If PyTorch is happy when loading our state dict into the model, we should be on the right track.

In [21]:
model.load_state_dict(state_dict_pt)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

No incompatible keys. Phew!
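The two lists in that output have a simple meaning: `missing_keys` are names the model expects but the state dict lacks, and `unexpected_keys` are names in the state dict that the model does not know. A minimal pure-Python sketch of the same comparison, with the hypothetical helper `compare_keys` and made-up key names, could look like this:

```python
def compare_keys(expected, provided):
    """Return (missing, unexpected) key lists, mimicking the report above."""
    expected, provided = set(expected), set(provided)
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    return missing, unexpected


# Made-up keys: one is absent from the state dict and one is extra
model_keys = ['block.0.weight', 'block.0.bias']
state_dict_keys = ['block.0.weight', 'block.9.weight']
missing, unexpected = compare_keys(model_keys, state_dict_keys)
print(missing, unexpected)
```

Both lists being empty, as in our case, means that the key sets match exactly. Note that with its default `strict=True`, `load_state_dict` raises an error on any mismatch instead of only reporting it.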

Set params

Plot weights

Visualize activations

Run inference