# Running Task 2 with a Pre-Trained Model

This notebook walks through an example of using the existing pre-trained models ([pre-trained weights](https://www.dropbox.com/sh/pzza5nuh93r9s18/AAAZsISLUl1H_u3U0TDSeOjNa?dl=0)) to perform pixel-level segmentation of brain microstructure (blood vessels, cells, axons).

## Intro

In MTNeuro, we provide multiple tasks that evaluate a model across multiple scales and capabilities. This notebook covers `Task 2`, i.e., the **Pixel Level Brain Microstructure Segmentation Task**, and shows how to download and load the pre-trained weights and model configs so that you can run pixel-level segmentation and explore the results.

#### Citation
For more details about `Task 2` and the other tasks, please refer to:

```
Quesada, J., Sathidevi, L., Liu, R., Ahad, N., Jackson, J.M., Azabou, M., ... & Dyer, E. L. (2022). MTNeuro: A Benchmark for Evaluating Representations of Brain Structure Across Multiple Levels of Abstraction. Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.

Prasad, J. A., Balwani, A. H., Johnson, E. C., Miano, J. D., Sampathkumar, V., De Andrade, V., ... & Dyer, E. L. (2020). A three-dimensional thalamocortical dataset for characterizing brain heterogeneity. Scientific Data, 7(1), 1-7.
```


### Specifying the Model, Setting, and Mode for Image Segmentation

Specify your preferred model, setting, and mode in the code below (refer to the table below for the available options).

**Available Options**:

| 2D Models | 3D Models | Settings | Modes |
|---|---|---|---|
| `UNet` | `UNet` | `3class` | `3D` |
| `smp_UnetPlusPlus` | `mzp_HighResNet` | `4class` | `2D` |
| `smp_PSPNet` | `mzp_VNetLight` | | |
| `smp_PAN` | | | |
| `smp_FPN` | | | |
| `smp_MAnet` | | | |


Note: `smp` indicates models imported from the [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) library and `mzp` indicates models imported from the [medicalzoopytorch](https://github.com/black0017/MedicalZooPytorch) library.


**3class Setting**: Segmentation using 3 labels: cell bodies, blood vessels, and background (axons are treated as background).

**4class Setting**: Segmentation using 4 labels: cell bodies, blood vessels, axons, and background. In this setting the ZI region is excluded, because clearly distinguishing axons in slices from the ZI region would be difficult even for a knowledgeable human annotator.
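
To make the relationship between the two settings concrete, the sketch below collapses a 4-class label mask into a 3-class one by relabeling axon pixels as background. The integer label IDs (0 = background, 1 = blood vessels, 2 = cell bodies, 3 = axons) and the helper name `to_three_class` are assumptions for illustration only; check the dataset's actual label encoding before relying on them.

```python
# Hypothetical label IDs; the real MTNeuro encoding may differ.
BACKGROUND, BLOOD_VESSEL, CELL_BODY, AXON = 0, 1, 2, 3


def to_three_class(mask):
    """Return a copy of a 2D label mask with axon pixels relabeled as background."""
    return [[BACKGROUND if label == AXON else label for label in row] for row in mask]


four_class_mask = [
    [0, 3, 3],
    [1, 2, 3],
]
three_class_mask = to_three_class(four_class_mask)
print(three_class_mask)
```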

_Refer to the [MTNeuro paper](https://openreview.net/forum?id=5xuowSQ17vy) for more details on the models, tasks, and settings._

In [28]:
Model = 'UNet'         # e.g. UNet, smp_PAN
Setting = '3class'     # e.g. 3class, 4class
Mode = '2D'            # e.g. 3D, 2D

# Save the selection to config.txt for use in later steps
with open('config.txt', 'w') as fh:
    fh.write(Model + ' ' + Setting + ' ' + Mode)
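
Because later steps depend on the contents of `config.txt`, it can help to read the file back and check the values against the options table before downloading anything. The following is a minimal sketch; the helper name `read_config` and the validation rules are assumptions based on the table of available options, not part of the MTNeuro package.

```python
import os
import tempfile

# Options copied from the table of available models, settings, and modes.
MODELS = {
    '2D': {'UNet', 'smp_UnetPlusPlus', 'smp_PSPNet', 'smp_PAN', 'smp_FPN', 'smp_MAnet'},
    '3D': {'UNet', 'mzp_HighResNet', 'mzp_VNetLight'},
}
SETTINGS = {'3class', '4class'}


def read_config(path='config.txt'):
    """Read 'Model Setting Mode' from path and validate each field."""
    with open(path) as fh:
        fields = fh.read().split()
    if len(fields) != 3:
        raise ValueError(f'expected 3 fields, got {len(fields)}')
    model, setting, mode = fields
    if mode not in MODELS:
        raise ValueError(f'unknown mode: {mode}')
    if model not in MODELS[mode]:
        raise ValueError(f'{model} is not listed as a {mode} model')
    if setting not in SETTINGS:
        raise ValueError(f'unknown setting: {setting}')
    return model, setting, mode


# Usage: write a config in the same format as the notebook cell, then read it back.
demo_path = os.path.join(tempfile.mkdtemp(), 'config.txt')
with open(demo_path, 'w') as fh:
    fh.write('UNet' + ' ' + '3class' + ' ' + '2D')
print(read_config(demo_path))
```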

### Cloning Required Repositories, Installing Libraries and Dependencies, and Importing Packages

Installing required packages:

In [None]:
!git clone https://github.com/MTNeuro/MTNeuro && cd MTNeuro && pip install -e .

# Setting up the segmentation_models.pytorch library
!pip install segmentation-models-pytorch
!pip install torchsummaryX

# Setting up MedicalZooPytorch (its lib folder is moved into the working directory)
!git clone https://github.com/black0017/MedicalZooPytorch/ && cd MedicalZooPytorch && mv lib ../.

Importing the Packages:

In [2]:
#import libraries
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import json
from tqdm import tqdm
import glob

#pytorch imports
import torch
from torchvision import transforms
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader, Subset

# SMP library (models and metric calculation): https://github.com/qubvel/segmentation_models.pytorch
import segmentation_models_pytorch as smp
# MedicalZooPytorch models (from the lib folder moved during setup)
import lib.medzoo as medzoo

The `BossDBDataset` class is used to download the required slices of brain imaging data from `BossDB`. Let's import it from the cloned MTNeuro package:

In [3]:
#BossDB MTNeuro dataset loader
from MTNeuro.bossdbdataset import BossDBDataset               

### Downloading the trained weights and corresponding configuration file

The trained weights and configuration files for the models can also be downloaded manually from: [Link](https://www.dropbox.com/sh/pzza5nuh93r9s18/AAAZsISLUl1H_u3U0TDSeOjNa?dl=0). 

The script below does this automatically (using the model and setting specified earlier):

In [None]:
!bash scripts/download_pretrained_weights_and_conf.sh

### Loading the appropriate Configuration File

**task config**: Settings corresponding to the task, like the x, y and z region of the slices. [Example](https://github.com/MTNeuro/MTNeuro/blob/main/MTNeuro/taskconfig/task2_2D_3class.json)

**network config**: Settings corresponding to the model and the training, like model layer sizes, training batch size, etc. [Example](https://github.com/MTNeuro/MTNeuro/blob/main/MTNeuro/networkconfig/UNet_2D_3class.json)

In [31]:
# Load the task config for the specified mode and setting
# (specifies the x, y ranges to pick from the data when forming the slices)
with open('../MTNeuro/taskconfig/task2_' + str(Mode) + '_' + str(Setting) + '.json') as f:
    task_config = json.load(f)

# Load the network config for the specified setting
# (contains the batch size and model config information);
# this picks the first .json file found in the current directory
with open(glob.glob("*.json")[0]) as f:
    network_config = json.load(f)
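
Since the rest of the notebook indexes `network_config['batch_size']` and `network_config['classes']`, a small loader that fails early when a key is missing can save debugging time. This is a sketch only: `load_json_config` is a hypothetical helper, and the demo file is a stand-in with made-up values rather than a real MTNeuro network config.

```python
import json
import os
import tempfile


def load_json_config(path, required_keys=()):
    """Load a JSON config and raise KeyError if a required key is absent."""
    with open(path) as f:
        config = json.load(f)
    missing = [key for key in required_keys if key not in config]
    if missing:
        raise KeyError(f'{path} is missing keys: {missing}')
    return config


# Stand-in config with illustrative values (not taken from MTNeuro).
demo_path = os.path.join(tempfile.mkdtemp(), 'demo_network_config.json')
with open(demo_path, 'w') as f:
    json.dump({'batch_size': 4, 'classes': 3}, f)

demo_config = load_json_config(demo_path, required_keys=('batch_size', 'classes'))
print(demo_config['batch_size'], demo_config['classes'])
```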

### Loading the Model and the corresponding pre-trained weights

Importing the custom `load_model()` method from `loading_model.py` for loading the specified model:

#### Loading the Model

Specify the device onto which you want to load the model and call the `load_model()` method.

**Syntax**: `model_object = load_model(<network configuration file for the model>, <device onto which to load the model>)`
    

In [None]:
from scripts.loading_model import load_model

# Specify device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model_object = load_model(network_config, device)

#### Loading the pre-trained weights to the model

The custom `load_weights()` method loads the downloaded pre-trained weights into the specified `model_object`.

**Syntax**: `model_object = load_weights(<name of the model>, <Mode>, <Setting>, <the object of the actual model>)`

In [None]:
from scripts.load_pretrained_weights import load_weights
model_object = load_weights(Model, Mode, Setting, model_object)

### Preparing the DataLoader using BossDBDataset

**BossDBDataset**: Helper class that uses the `intern` library to download a 3D brain volume from BossDB and expose it as a PyTorch dataset of image slices that can be wrapped in a `DataLoader`.

Note: The `ToTensor` transform needs to be applied to enable the conversion from NumPy arrays to PyTorch tensors.

In [34]:
# Set up the test dataloader
test_data = BossDBDataset(task_config, None, 'test')

# Dropping the last batch due to unequal size
if Mode == "3D":
    test_data = Subset(test_data, list(range(len(test_data))[:-network_config['batch_size']]))

test_dataloader = DataLoader(dataset=test_data,
                             batch_size=network_config['batch_size'],
                             shuffle=False)
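
If the concern is an incomplete final batch, `DataLoader` also accepts `drop_last=True`, which discards only that partial batch. The plain-Python sketch below compares the batch sizes that result from slicing off the last `batch_size` samples with those that result from `drop_last`. Here `batch_sizes` is a hypothetical helper that mirrors how a sequential loader splits the samples, and the sample counts are made up for illustration.

```python
def batch_sizes(n_samples, batch_size, drop_last=False):
    """Sizes of the batches a sequential loader would yield for n_samples items."""
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


n_samples, batch_size = 10, 4  # illustrative numbers only

# Slicing approach: remove the last `batch_size` samples, then batch the rest.
kept = list(range(n_samples))[:-batch_size]
sliced = batch_sizes(len(kept), batch_size)
print(sliced)   # [4, 2]

# drop_last approach: keep all samples, discard only the incomplete batch.
dropped = batch_sizes(n_samples, batch_size, drop_last=True)
print(dropped)  # [4, 4]
```

Which behavior is appropriate depends on why the final samples are unequal in the 3D mode, so treat this as a comparison rather than a recommended change.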

### Model Prediction

#### Prediction Function


In [35]:
# Function to predict using the trained model
def predict(img, model, device):
    model.eval()
    x = img.to(device)  # send the input batch to the device
    with torch.no_grad():
        out = model(x)  # forward pass through the network

    out_argmax = torch.argmax(out, dim=1)  # pick the highest-scoring class per pixel
    return out_argmax
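
The `argmax` over `dim=1` selects, for every pixel, the class with the largest score. Applying a softmax first would give the same labels, because softmax preserves the ordering of the scores. The dependency-free sketch below checks this for a single pixel; the logit values are made up.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of class scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


pixel_logits = [0.2, 2.5, -1.0]  # illustrative scores for 3 classes
probabilities = softmax(pixel_logits)

print(argmax(pixel_logits), argmax(probabilities))
```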

#### Prediction Loop


In [None]:
batch_iter = tqdm(enumerate(test_dataloader), 'test', total=len(test_dataloader), leave=False)
# predict the segmentations of test set
tp_tot = torch.empty(0,network_config['classes'])
fp_tot = torch.empty(0,network_config['classes'])
fn_tot = torch.empty(0,network_config['classes'])
tn_tot = torch.empty(0,network_config['classes'])

# Specify the device to load the input data to
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

y_ = []
x_ = []
output_ = []
for i, (x, y) in batch_iter:
    target = y.to(device) #can do this on CPU
    y_.append(y)
    x_.append(x.squeeze())
    with torch.no_grad():
        # get the output image (make prediction)
        output = predict(x, model_object, device)
        output_.append(output)
        tp_, fp_, fn_, tn_ = smp.metrics.get_stats(output, target, mode='multiclass', num_classes = network_config['classes'])
        tp_tot = torch.vstack((tp_tot,tp_))
        fp_tot = torch.vstack((fp_tot,fp_))
        fn_tot = torch.vstack((fn_tot,fn_))
        tn_tot = torch.vstack((tn_tot,tn_))
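
Once the loop finishes, the stacked counts can be reduced to per-class scores. `smp.metrics` provides functions such as `iou_score` and `f1_score` that take these four tensors directly. The plain-Python sketch below only illustrates the underlying arithmetic for IoU, summing the counts over all images before dividing; the counts are made up and `per_class_iou` is a hypothetical helper name.

```python
def per_class_iou(tp_rows, fp_rows, fn_rows):
    """IoU per class from per-image count rows: TP / (TP + FP + FN)."""
    n_classes = len(tp_rows[0])
    scores = []
    for c in range(n_classes):
        tp = sum(row[c] for row in tp_rows)
        fp = sum(row[c] for row in fp_rows)
        fn = sum(row[c] for row in fn_rows)
        denom = tp + fp + fn
        # A class absent from both prediction and target has no defined IoU.
        scores.append(tp / denom if denom else float('nan'))
    return scores


# Made-up counts: 2 images (rows) x 3 classes (columns).
tp = [[90, 5, 8], [80, 5, 2]]
fp = [[10, 2, 1], [10, 3, 4]]
fn = [[5, 3, 2], [5, 2, 3]]

iou = per_class_iou(tp, fp, fn)
print(iou)
```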