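# Validating the JIT-traced model

Checks that a TorchScript export (`torch.jit.trace`) of the trained segmentation model reproduces the eager model's output on the same input.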

In [1]:
from pathlib import Path
import os

In [2]:
# User home directory. Path.home() honours $HOME on POSIX like the old
# os.environ['HOME'] lookup, but does not raise KeyError when HOME is unset
# (e.g. Windows or minimal container environments).
HOME = Path.home()

In [3]:
# --- Configuration -----------------------------------------------------------
MODELNAME = "fpn"            # architecture to build: "fpn" or "unet"
ENCODER = "efficientnet-b5"  # encoder backbone name handed to segmentation_models_pytorch
DOWN = True                  # prepend a stride-2 convolutional downsampling stem
# Trained weights to load (source model path).
SRC = HOME.joinpath("ucsi", "fastai", "models", "bestmodel_4.pth")
# Where the traced TorchScript model is written (destination model path).
DST = HOME.joinpath("ucsi", "jit", "fpn_b5_e4.pth")

In [4]:
from torch import jit
import segmentation_models_pytorch as smp
import torch
from torch import nn

In [5]:
# Resolve the configured architecture name to its segmentation_models_pytorch class.
if MODELNAME == "fpn":
    model_class = smp.FPN
elif MODELNAME == "unet":
    model_class = smp.Unet
else:
    # Fail here with a clear message instead of a confusing NameError on
    # `model_class` when the model is constructed further down.
    raise ValueError("Unsupported MODELNAME %r; expected 'fpn' or 'unet'" % (MODELNAME,))

### Loading the model

In [6]:
# Constructor arguments for the segmentation network. No pretrained encoder
# weights are requested: the trained weights come from the checkpoint loaded later.
seg_conf = dict(
    encoder_name=ENCODER,
    encoder_weights=None,
    classes=4,
    activation="sigmoid",
)

print("Constructing the model")
print(seg_conf)
if not DOWN:
    # Plain segmentation network, fed the input image directly.
    model = model_class(**seg_conf)
else:
    class majorModel(nn.Module):
        """Convolutional stem followed by the segmentation network.

        Two 3x3 convolutions (3 -> 12 -> 3 channels, the second with stride 2),
        each followed by ReLU, halve the spatial resolution before `seg_model`.
        Everything lives in `self.seq` in this exact order so the state-dict
        keys line up with the saved checkpoint.
        """

        def __init__(self, seg_model):
            super().__init__()
            stem_and_body = [
                nn.Conv2d(3, 12, kernel_size=(3, 3), padding=1, stride=1),
                nn.ReLU(),
                nn.Conv2d(12, 3, kernel_size=(3, 3), padding=1, stride=2),
                nn.ReLU(),
                seg_model,
            ]
            self.seq = nn.Sequential(*stem_and_body)

        def forward(self, x):
            return self.seq(x)

    model = majorModel(model_class(**seg_conf))

Constructing the model
{'encoder_name': 'efficientnet-b5', 'encoder_weights': None, 'classes': 4, 'activation': 'sigmoid'}


In [7]:
# Whether a GPU is usable; every later device move is conditioned on this flag.
CUDA = torch.cuda.is_available()
print(f"CUDA available:\t{CUDA}")

CUDA available:	True


In [8]:
print("Loading from weights:\t%s"%(SRC))
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine;
# load_state_dict below copies the tensors onto whatever device `model` is on.
state = torch.load(SRC, map_location="cpu")
# The weights may be nested under a "model" key (presumably a fastai checkpoint
# that also stores optimizer state) -- unwrap them if so.
if "model" in state:
    state = state["model"]
if CUDA:
    model = model.cuda()
# Strict load (the default): raises if any key is missing or unexpected.
# Kept as the last expression so the notebook displays the match report.
model.load_state_dict(state)

Loading from weights:	/home/b2ray2c/ucsi/fastai/models/bestmodel_4.pth


<All keys matched successfully>

In [9]:
# Random input batch used both to trace the model and to compare eager vs traced
# outputs. A dedicated seeded generator makes the batch (and thus the reported
# error) reproducible across runs without touching torch's global RNG state.
# Shape (2, 3, 320, 640) is assumed to match what the model accepts -- TODO confirm.
TEST_SEED = 0
testimg = torch.rand(2, 3, 320, 640, generator=torch.Generator().manual_seed(TEST_SEED))
if CUDA:
    testimg = testimg.cuda()

In [10]:
# Inference mode (disables training-only behaviour such as dropout), then record
# the eager model's output as the reference for the traced copy.
model.eval()
with torch.no_grad():
    y1 = model(testimg)

### Saving to TorchScript (JIT)

In [11]:
print("Saving to jit traced model:\t%s"%(DST))
# ScriptModule.save does not create missing directories, so make sure the
# destination folder exists before writing.
DST.parent.mkdir(parents=True, exist_ok=True)
# Trace on the same input used for the reference output. The traced module keeps
# its parameters on the device they were traced on (the GPU when CUDA is set);
# loading it on a CPU-only host would presumably need jit.load(..., map_location="cpu").
with torch.no_grad():
    traced = jit.trace(model, testimg)
traced.save(str(DST))

Saving to jit traced model:	/home/b2ray2c/ucsi/jit/fpn_b5_e4.pth


In [12]:
# Move the eager model off the GPU (presumably to free memory before the traced
# copy is loaded); the reference output y1 stays where it was computed.
model = model.cpu() if CUDA else model

### Reloading the saved TorchScript model

In [13]:
# Reload the traced model from disk and run it on the same input as the eager model.
recovered = jit.load(str(DST))
recovered = recovered.cuda() if CUDA else recovered
with torch.no_grad():
    y2 = recovered(testimg)

In [14]:
# Compare traced vs eager outputs on the same input. Check shapes first, because
# `y1 - y2` would silently broadcast mismatched shapes.
assert y1.shape == y2.shape, "shape mismatch: eager %s vs traced %s" % (tuple(y1.shape), tuple(y2.shape))
abs_err = torch.abs(y1 - y2)
print("Absolute Mean Error:%s"%(abs_err.mean().item()))
print("Max Absolute Error:%s"%(abs_err.max().item()))
# Fail loudly instead of relying on a reader to inspect the printed number. The
# small tolerance allows for non-deterministic GPU kernels.
assert torch.allclose(y1, y2, atol=1e-4), "traced model output deviates from the eager model"

Absolute Mean Error:0.0
