# The validity of JIT model

In [1]:
from pathlib import Path
import os

In [2]:
HOME = Path(os.environ['HOME'])

In [3]:
MODELNAME = "fpn"
ENCODER = "efficientnet-b5"
DOWN = False # Downsample at the bottom
SRC = HOME/"ucsi"/"fastai"/"models"/"bestmodel_3.pth" # source model path
DST = HOME/"ucsi"/"jit"/"fpn_b5_e3.pth" # desitination model path

In [4]:
from torch import jit
import segmentation_models_pytorch as smp
import torch
from torch import nn

In [5]:
if MODELNAME =="fpn":
    model_class = smp.FPN
elif MODELNAME == "unet":
    model_class = smp.Unet

### Loading The Model

In [6]:
seg_conf = {
    "encoder_name":ENCODER,
    "encoder_weights":None,
    "classes":4,
    "activation":"sigmoid",
}

print("Constructing the model")
print(seg_conf)
if DOWN:
    class majorModel(nn.Module):
        def __init__(self, seg_model):
            super().__init__()
            self.seq = nn.Sequential(*[
                nn.Conv2d(3,12,kernel_size=(3,3), padding=1, stride=1, ),
                nn.ReLU(),
                nn.Conv2d(12,3,kernel_size=(3,3), padding=1, stride=2),
                nn.ReLU(),
                seg_model,])
            
        def forward(self,x):
            return self.seq(x)
    model = majorModel(model_class(**seg_conf))
    
else:
    model = model_class(**seg_conf)

Constructing the model
{'encoder_name': 'efficientnet-b5', 'encoder_weights': None, 'classes': 4, 'activation': 'sigmoid'}


In [7]:
CUDA = torch.cuda.is_available()
print("CUDA available:\t%s"%(CUDA))

CUDA available:	True


In [8]:
print("Loading from weights:\t%s"%(SRC))
state = torch.load(SRC)
if "model" in state:
    state = state["model"]
if CUDA:
    model = model.cuda()
model.load_state_dict(state)

Loading from weights:	/home/b2ray2c/ucsi/fastai/models/bestmodel_3.pth


<All keys matched successfully>

In [9]:
testimg = torch.rand(2, 3, 320, 640)
if CUDA:
    testimg = testimg.cuda()

In [10]:
model = model.eval()
with torch.no_grad():
    y1 = model(testimg)

### Save to JIT

In [11]:
print("Saving to jit traced model:\t%s"%(DST))
with torch.no_grad():
    traced = jit.trace(model, testimg)
    traced.save(str(DST))

Saving to jit traced model:	/home/b2ray2c/ucsi/jit/fpn_b5_e3.pth


RuntimeError: 
Could not export Python function call 'SwishImplementation'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home/b2ray2c/github/EfficientNet-PyTorch/efficientnet_pytorch/utils.py(57): forward
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(531): _slow_forward
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(545): __call__
/home/b2ray2c/github/EfficientNet-PyTorch/efficientnet_pytorch/utils.py(66): relu_fn
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/segmentation_models_pytorch/encoders/efficientnet.py(19): forward
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(531): _slow_forward
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(545): __call__
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/segmentation_models_pytorch/base/encoder_decoder.py(24): forward
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(531): _slow_forward
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(545): __call__
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py(904): trace_module
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py(772): trace
<ipython-input-11-b8b1769c2dd5>(3): <module>
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3326): run_code
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3249): run_ast_nodes
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3058): run_cell_async
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2881): _run_cell
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2855): run_cell
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel/zmqshell.py(536): run_cell
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel/ipkernel.py(294): do_execute
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(209): wrapper
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel/kernelbase.py(542): execute_request
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(209): wrapper
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel/kernelbase.py(272): dispatch_shell
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(209): wrapper
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel/kernelbase.py(365): process_one
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(748): run
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(714): __init__
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(225): wrapper
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel/kernelbase.py(378): dispatch_queue
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(748): run
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/gen.py(787): inner
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/ioloop.py(743): _run_callback
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/ioloop.py(690): <lambda>
/home/b2ray2c/anaconda3/lib/python3.7/asyncio/events.py(88): _run
/home/b2ray2c/anaconda3/lib/python3.7/asyncio/base_events.py(1771): _run_once
/home/b2ray2c/anaconda3/lib/python3.7/asyncio/base_events.py(534): run_forever
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/tornado/platform/asyncio.py(148): start
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel/kernelapp.py(563): start
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/traitlets/config/application.py(664): launch_instance
/home/b2ray2c/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py(16): <module>
/home/b2ray2c/anaconda3/lib/python3.7/runpy.py(85): _run_code
/home/b2ray2c/anaconda3/lib/python3.7/runpy.py(193): _run_module_as_main


In [None]:
if CUDA:
    model = model.cpu()

### Recover from saved JIT

In [None]:
recovered = jit.load(str(DST))
if CUDA: 
    recovered = recovered.cuda()
with torch.no_grad():
    y2 = recovered(testimg)

In [None]:
print("Absolute Mean Error:%s"%(torch.abs(y1-y2).mean().item()))