<a href="https://colab.research.google.com/github/neuroneural/brainchop/blob/master/py2tfjs/Convert_Trained_Model_To_TFJS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch Model Conversion to tfjs

This example walks through the steps for converting a trained **MeshNet** model into a tfjs model that can be used with [**Brainchop**](https://neuroneural.github.io/brainchop/). The model here is a Gray Matter White Matter (GWM) segmentation model that segments the brain into three regions. A successful conversion to tfjs produces two main files, the `model.json` file and the binary weights file:

- The `model.json` file contains the model topology and the weights manifest.
- The binary weights file (i.e. `*.bin`) contains the concatenated weight values.
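As a rough illustration of how the two files fit together, the sketch below walks a `weightsManifest` and computes where each tensor's bytes sit in the concatenated weights buffer. It assumes the usual tfjs manifest layout (groups with `paths` and `weights`, each weight entry carrying `name`, `shape` and `dtype`), unquantized weights, and that all shards read in order form one contiguous buffer. The manifest and tensor names are hypothetical, not the ones `meshnet2tfjs` actually writes.

```python
from math import prod

# Bytes per element for uncompressed dtypes (assumption: no weight quantization).
DTYPE_BYTES = {"float32": 4, "int32": 4, "bool": 1}


def weight_offsets(weights_manifest):
    """Return {name: (byte_offset, byte_length)} for every tensor in the manifest.

    Assumes the shard files of all groups, read in order, form one contiguous
    buffer (true for a single group stored in a single model.bin).
    """
    offsets = {}
    cursor = 0
    for group in weights_manifest:
        for entry in group["weights"]:
            # prod([]) == 1, so scalar tensors are handled too
            n_bytes = prod(entry["shape"]) * DTYPE_BYTES[entry["dtype"]]
            offsets[entry["name"]] = (cursor, n_bytes)
            cursor += n_bytes
    return offsets


# Hypothetical manifest for a single Conv3D layer with 1 input and 15 output channels
manifest = [{
    "paths": ["model.bin"],
    "weights": [
        {"name": "conv_0/kernel", "shape": [3, 3, 3, 1, 15], "dtype": "float32"},
        {"name": "conv_0/bias", "shape": [15], "dtype": "float32"},
    ],
}]
print(weight_offsets(manifest))
# {'conv_0/kernel': (0, 1620), 'conv_0/bias': (1620, 60)}
```

The byte length of the `.bin` file should equal the final cursor value, which makes this a quick sanity check on a converted model.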



This conversion pipeline example is part of the [**Brainchop**](https://neuroneural.github.io/brainchop/) project, where the basic MeshNet model is trained using **PyTorch** and the resulting model is converted to a **TensorFlow.js** (tfjs) model for use with Brainchop.

---

**Authors:** [Mohamed Masoud](https://github.com/Mmasoud1), and [Sergey Plis](https://github.com/sergeyplis)




### Fetch required python files from brainchop repository [**folder**](https://github.com/neuroneural/brainchop/tree/master/py2tfjs/conversion_example)

We'll import three Python scripts from the Brainchop repository as libraries, so they need to be downloaded first.


In [1]:
!wget --no-cache --backups=1 https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/blendbatchnorm.py
!wget --no-cache --backups=1 https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet.py
!wget --no-cache --backups=1 https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet2tfjs.py


--2024-03-28 09:08:58--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/blendbatchnorm.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2254 (2.2K) [text/plain]
Saving to: ‘blendbatchnorm.py’


2024-03-28 09:08:58 (39.2 MB/s) - ‘blendbatchnorm.py’ saved [2254/2254]

--2024-03-28 09:08:58--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4711 (4.6K) [text/plain]
Saving to: ‘meshnet.py’



Next, download the trained PyTorch model (`model.pth`) and its architecture configuration file (`modelAE.json`), which are the inputs to the conversion to TFJS.

In [2]:
!wget --no-cache --backups=1 https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/modelAE.json
!wget --no-cache --backups=1 https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/model.pth


--2024-03-28 09:09:02--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/modelAE.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3044 (3.0K) [text/plain]
Saving to: ‘modelAE.json’


2024-03-28 09:09:02 (45.2 MB/s) - ‘modelAE.json’ saved [3044/3044]

--2024-03-28 09:09:02--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/model.pth
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 484293 (473K) [application/octet-stream]
Saving to: ‘model.pth’




In [3]:
import torch
from blendbatchnorm import fuse_bn_recursively
from meshnet2tfjs import meshnet2tfjs
from meshnet import (
    MeshNet,
    enMesh_checkpoint,
)

device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)


# Normalization: rescale intensities so that the qmin and qmax quantiles map to 0 and 1
def preprocess_image(img, qmin=0.01, qmax=0.99):
    """Unit interval preprocessing"""
    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
    return img
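`preprocess_image` expects a PyTorch tensor. If a volume is held as a NumPy array instead (for example, after loading it with a NIfTI reader), an equivalent sketch is shown below. It is an illustration rather than part of the Brainchop scripts, and it relies on `np.quantile` and `torch.quantile` both defaulting to linear interpolation. As in the original, nothing is clipped, so values below the `qmin` quantile or above the `qmax` quantile fall slightly outside [0, 1].

```python
import numpy as np


def preprocess_volume(vol, qmin=0.01, qmax=0.99):
    """NumPy counterpart of preprocess_image: map the qmin/qmax quantiles to 0 and 1."""
    vol = np.asarray(vol, dtype=float)
    # Compute both quantiles in one call (linear interpolation, like torch.quantile)
    lo, hi = np.quantile(vol, [qmin, qmax])
    return (vol - lo) / (hi - lo)


# Usage on a synthetic ramp of intensities 0..100
ramp = np.arange(101)
out = preprocess_volume(ramp)
print(out[1], out[99])  # the 1% and 99% quantiles land at 0.0 and 1.0
```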

In [4]:
# number of classes the model predicts
n_classes = 3
# architecture configuration file
config_file = "modelAE.json"
# number of feature channels in each layer of the saved model
model_channels = 15
# path to the saved model weights
model_path = "model.pth"
# output directory for the tfjs model (relative to the Colab working directory)
tfjs_model_dir = "model_tfjs"

meshnet_model = enMesh_checkpoint(
    in_channels=1,
    n_classes=n_classes,
    channels=model_channels,
    config_file=config_file,
)

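# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime as well
checkpoint = torch.load(model_path, map_location=device)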
meshnet_model.load_state_dict(checkpoint)

<All keys matched successfully>

In [5]:
meshnet_model.eval()

enMesh_checkpoint(
  (model): Sequential(
    (0): Sequential(
      (0): Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0, inplace=True)
    )
    (1): Sequential(
      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))
      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0, inplace=True)
    )
    (2): Sequential(
      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))
      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0, inplace=True)
    )
    (3): Sequential(
      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))
      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=T

In [6]:
meshnet_model.to(device)


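`fuse_bn_recursively` goes through the model's sequential blocks and fuses each batch normalization layer into the preceding convolution, so the fused model contains only convolution and ELU layers.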
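In eval mode, batch normalization uses its running statistics and is therefore a fixed per-channel affine map, which is why it can be folded into the convolution before it. With running mean `mean`, running variance `var`, scale `gamma`, shift `beta` and `eps`, each output channel's kernel is multiplied by `s = gamma / sqrt(var + eps)` and its bias becomes `(b - mean) * s + beta`. The pure-Python sketch below shows that arithmetic with each output channel's kernel flattened into a list; it illustrates the technique and is not the code in `blendbatchnorm.py`. Each output voxel of a convolution is a dot product between a kernel and an input patch plus a bias, so the toy check compares one such response before and after folding. All layer values are hypothetical.

```python
import math


def fold_bn(kernels, biases, gamma, beta, mean, var, eps=1e-5):
    """Fold an eval-mode batch norm into the preceding convolution.

    kernels: one flattened kernel (list of floats) per output channel.
    biases, gamma, beta, mean, var: one float per output channel
    (use 0.0 biases if the convolution was created with bias=False).
    Returns (fused_kernels, fused_biases).
    """
    fused_k, fused_b = [], []
    for k, b, g, bt, m, v in zip(kernels, biases, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)
        fused_k.append([w * scale for w in k])
        fused_b.append((b - m) * scale + bt)
    return fused_k, fused_b


def conv_at(kernel, bias, patch):
    """Response of one output channel at one position: <kernel, patch> + bias."""
    return sum(w * x for w, x in zip(kernel, patch)) + bias


def batch_norm(y, g, bt, m, v, eps=1e-5):
    """Eval-mode batch norm applied to a single value."""
    return g * (y - m) / math.sqrt(v + eps) + bt


# Toy layer: 2 output channels, kernels flattened to length 3 (hypothetical values)
kernels = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
biases = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
mean, var = [0.4, -0.1], [2.0, 0.5]
patch = [1.0, 2.0, -0.5]

fk, fb = fold_bn(kernels, biases, gamma, beta, mean, var)
for c in range(len(kernels)):
    reference = batch_norm(conv_at(kernels[c], biases[c], patch),
                           gamma[c], beta[c], mean[c], var[c])
    # conv followed by BN equals the single fused conv
    assert math.isclose(conv_at(fk[c], fb[c], patch), reference,
                        rel_tol=1e-9, abs_tol=1e-12)
```

Folding removes one layer per block from the exported graph without changing the model's eval-mode output, up to floating-point rounding.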

In [7]:
mnm = fuse_bn_recursively(meshnet_model)

del meshnet_model
mnm.model.eval()

Sequential(
  (0): Sequential(
    (0): Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ELU(alpha=1.0, inplace=True)
  )
  (1): Sequential(
    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))
    (1): ELU(alpha=1.0, inplace=True)
  )
  (2): Sequential(
    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))
    (1): ELU(alpha=1.0, inplace=True)
  )
  (3): Sequential(
    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))
    (1): ELU(alpha=1.0, inplace=True)
  )
  (4): Sequential(
    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))
    (1): ELU(alpha=1.0, inplace=True)
  )
  (5): Sequential(
    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))
    (1): ELU(alpha=1.0, inplace=True)
  )
  (6): Sequential(
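In the printout, every 3×3×3 convolution uses a `padding` equal to its `dilation`, and the dilation grows by one from block to block. A short sketch of the standard stride-1 arithmetic explains why: the output length along an axis is `n + 2p - d(k - 1)`, so `p = d` keeps the volume size unchanged, while each layer widens the receptive field by `d(k - 1)` voxels. The printout is truncated, so the example only uses the dilations 1 through 6 visible above, and the 256-voxel axis is just an example size.

```python
def conv_output_size(n, kernel=3, padding=0, dilation=1, stride=1):
    """Output length along one spatial axis (same formula as PyTorch's Conv3d docs)."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def receptive_field(dilations, kernel=3):
    """Receptive field (voxels per axis) of stacked stride-1 dilated convolutions."""
    rf = 1
    for d in dilations:
        rf += d * (kernel - 1)
    return rf


dilations = [1, 2, 3, 4, 5, 6]  # first six blocks shown in the printout above
for d in dilations:
    # padding == dilation keeps a 256-voxel axis at 256 voxels
    assert conv_output_size(256, padding=d, dilation=d) == 256
print(receptive_field(dilations))  # 43
```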


Convert MeshNet model to TensorFlow.js

In [8]:
meshnet2tfjs(mnm, tfjs_model_dir)

In [9]:
!ls

bcmodel.zip	     meshnet2tfjs.py	meshnet.py.1	model.pth    __pycache__
blendbatchnorm.py    meshnet2tfjs.py.1	modelAE.json	model.pth.1  sample_data
blendbatchnorm.py.1  meshnet.py		modelAE.json.1	model_tfjs


Save the converted files to a zip file

In [10]:
!zip -r "/content/bcmodel.zip" "model_tfjs"

updating: model_tfjs/ (stored 0%)
updating: model_tfjs/model.json (deflated 97%)
updating: model_tfjs/model.bin (deflated 6%)
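Before downloading, it can be worth confirming that the archive is complete. The sketch below uses Python's standard `zipfile` and `json` modules to read `model.json` from the archive and check that every weight file listed in its `weightsManifest` is present. The folder name `model_tfjs` comes from `tfjs_model_dir`, and `check_tfjs_archive` is a hypothetical helper, not part of the Brainchop scripts.

```python
import json
import zipfile


def check_tfjs_archive(zip_path, folder="model_tfjs"):
    """Return the files referenced by model.json that are missing from the archive.

    An empty list means model.json and all of its weight shards are present.
    """
    with zipfile.ZipFile(zip_path) as zf:
        names = set(zf.namelist())
        json_name = f"{folder}/model.json"
        if json_name not in names:
            return [json_name]
        model = json.loads(zf.read(json_name))
    wanted = [f"{folder}/{path}"
              for group in model["weightsManifest"]
              for path in group["paths"]]
    return [p for p in wanted if p not in names]


# Usage in the Colab session (not executed here):
# missing = check_tfjs_archive("/content/bcmodel.zip")
# print("archive complete" if not missing else f"missing: {missing}")
```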


Download the TensorFlow.js model

In [15]:
from google.colab import files
files.download("/content/bcmodel.zip")


# **Final notes**

This tutorial provides a simple example of how to convert a segmentation model from a PyTorch pipeline to TensorFlow.js files (i.e. `model.json` and `model.bin`).