# Export FairMOT model in torchscript format

This notebook contains the modifications necessary to export the FairMOT model with the YoloV5s backbone in TorchScript format. The exported model can later be used as the input to an [Apache TVM](https://tvm.apache.org) or [SageMaker Neo](https://aws.amazon.com/sagemaker/neo/) compilation job. At the end of the notebook you will also find an example of how to launch a SageMaker Neo compilation job with the exported model.

Set up the FairMOT environment as described in the README. If you also want to compile the model with Neo, you will need the `boto3` and `sagemaker` Python SDKs:

```bash
$ pip install boto3 sagemaker
```

The model is modified in the following ways:

 - The library implementation of the Sigmoid Linear Unit (SiLU) function, which is used as the activation function in the YoloV5 backbone and as a separate layer in the FairMOT heads, is not supported by the TVM compiler. The packaged SiLU is therefore replaced by a manual implementation. The replacement code was partly inspired by the original [YoloV5 repo](https://github.com/ultralytics/yolov5/blob/master/export.py).
 - The FairMOT network, implemented in the `PoseYOLOv5s` class (`./src/lib/models/yolo.py`), returns a dictionary, which TVM does not support as an output type. A new flag was added to the `./src/lib/models/networks/config/yolov5s.yaml` configuration file: setting `tuple_output` to `true` makes the network return a tuple instead of a `dict`. The elements of the tuple are ordered by the lexicographic order of the dictionary keys.

In [None]:
# Work from the FairMOT `src` directory: the relative paths used in the cells
# below (e.g. `../models`) are written relative to it.
# NOTE(review): the target is a relative path, so run this cell only once per
# kernel session.
%cd src

In [None]:
import datetime
import json
import os
import posixpath
import time

import boto3
import sagemaker

import torch
import torch.nn as nn
print('pytorch version:', torch.__version__)

import _init_paths
from opts import opts
from models.model import create_model, load_model

In [None]:
def wait_for_compilation_job(compilation_job_name, polling=10):
    """Block until a SageMaker compilation job reaches a terminal state.

    The job is described repeatedly; while its status is not one of
    COMPLETED, FAILED or STOPPED a dot is printed and the function sleeps
    for ``polling`` seconds before asking again. The final status is
    printed, followed by the failure reason when the job has failed.

    Args:
        compilation_job_name: Name of the compilation job to watch.
        polling: Seconds to sleep between two status requests.

    Returns:
        A dict with the final status under ``'compilation_job_result'`` and,
        for failed jobs only, the reason under
        ``'compilation_job_failure_reason'``.
    """
    terminal_states = ('COMPLETED', 'FAILED', 'STOPPED')
    client = boto3.client('sagemaker')

    while True:
        description = client.describe_compilation_job(CompilationJobName=compilation_job_name)
        job_status = description['CompilationJobStatus']
        if job_status in terminal_states:
            break
        # Progress indicator: one dot per unfinished poll, all on one line.
        print('.', end='', flush=True)
        time.sleep(polling)

    print(job_status)
    result = {'compilation_job_result': job_status}
    if job_status == 'FAILED':
        # Literal backslash-n sequences in the reason become real line breaks.
        failure_reason = description.get('FailureReason', '(No reason provided)').replace('\\n', '\n')
        result['compilation_job_failure_reason'] = failure_reason
        print(failure_reason)
    return result

In [None]:
# Command-line style arguments handed to the FairMOT option parser.
args = [
    'mot',
    '--load_model', '../models/fairmot_lite.pth',
    '--conf_thres', '0.4',
    '--arch', 'yolo',
]
opt = opts().init(args)

In [None]:
# Build the network described by the parsed options.
model = create_model(opt.arch, opt.heads, opt.head_conv)
print(f'arch={opt.arch}, heads={opt.heads}, head_conv={opt.head_conv}')

# Load the checkpoint given by --load_model and switch to evaluation mode.
model = load_model(model, opt.load_model).eval()

In [None]:
# SiLU https://arxiv.org/pdf/1606.08415.pdf
class SiLU(nn.Module):
    """Export-friendly stand-in for ``nn.SiLU``.

    Computes ``x * sigmoid(x)`` from elementary tensor operations instead of
    relying on the packaged SiLU implementation.
    """

    @staticmethod
    def forward(x):
        """Return ``x`` multiplied element-wise by its sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
    
from models.common import Conv
from models.yolo import Detect

# Settings written to every YOLOv5 `Detect` layer found in the model.
# The original cell referenced undefined names here, which raised a NameError
# as soon as a `Detect` module was encountered.
# NOTE(review): False is assumed to match the defaults of the upstream YOLOv5
# export script this loop was taken from; confirm for the version in use.
detect_inplace = False
detect_onnx_dynamic = False

replaced_silu = 0
modified_detect = 0
num_silu = 0

# from https://github.com/ultralytics/yolov5/blob/master/export.py
# Iterate over a snapshot of the modules: children are swapped inside the
# loop, and a snapshot keeps the traversal independent of those swaps.
for m in list(model.modules()):
    if isinstance(m, Conv):  # assign export-friendly activations
        if isinstance(m.act, nn.SiLU):
            m.act = SiLU()
            replaced_silu += 1
    elif isinstance(m, Detect):
        m.inplace = detect_inplace
        m.onnx_dynamic = detect_onnx_dynamic
        modified_detect += 1

    # in FairMOT heads there's also silu as layers
    elif isinstance(m, nn.Sequential):
        # Snapshot the children as well, since entries are replaced by index.
        for idx, submod in enumerate(list(m)):
            if isinstance(submod, nn.SiLU):
                m[idx] = SiLU()
                num_silu += 1

print('No. of replaced SiLU activations:', replaced_silu)
print('No. of modified Detect layers:', modified_detect)
print('No. of replaced SiLU layers:', num_silu)

In [None]:
# Fixed input shape the model is exported with.
# NOTE(review): presumably (batch, channels, height, width) - confirm.
input_size = [1, 3, 608, 1088]

# Random tensor of that shape, used as the example input for the cells below.
dummy_input = torch.randn(input_size)
print('Exported model input size:', input_size)

In [None]:
# Run one forward pass to show which container type the network returns.
# Gradients are not needed for this check, so autograd tracking is disabled
# to avoid building a graph for the full-size dummy input.
with torch.no_grad():
    print('Model output type:', type(model(dummy_input)))

In [None]:
# Make sure the output directory exists. Pure-Python replacement for the
# `mkdir -p` shell escape, so the cell does not depend on the shell in use.
os.makedirs('../models', exist_ok=True)

In [None]:
# Destination of the TorchScript export.
traced_model_filename = '../models/fairmot_lite_torchscript.pth'

# Record the model's operations on the dummy input and serialize the result.
# NOTE(review): strict=False relaxes the tracer's output checks; see the
# torch.jit.trace documentation for the exact effect.
traced_model = torch.jit.trace(model, example_inputs=dummy_input, strict=False)
traced_model.save(traced_model_filename)

print('Traced model was saved to:', traced_model_filename)

In [None]:
import tarfile

model_archive_filename = '../models/fairmot_lite_torchscript.tar.gz'

# Store the traced model at the archive root, without its directory part.
archive_member_name = os.path.basename(traced_model_filename)
with tarfile.open(model_archive_filename, mode="w:gz") as archive:
    archive.add(traced_model_filename, arcname=archive_member_name)
print('Traced model was archived as:', model_archive_filename)

## Compile the traced model with SageMaker Neo

In [None]:
# S3 key prefix used for the artifacts of the compilation jobs.
compiler_uri_prefix = 'model_compiler/'

# SageMaker session and its default bucket.
sm_session = sagemaker.Session()
bucket = sm_session.default_bucket()

# Low-level SageMaker API client.
sm_client = boto3.client('sagemaker')

In [None]:
# Give every run a unique job name by appending a timestamp.
base_job_name = 'fairmot-pt-jetson-xavier'
job_name = f'{base_job_name}-{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}'

# S3 keys always use '/' as separator, so they are joined with posixpath
# instead of os.path, whose separator depends on the operating system.
job_prefix = posixpath.join(compiler_uri_prefix, job_name)
model_input_prefix = posixpath.join(job_prefix, 'input')

# Upload the model archive under the job's input prefix.
model_input_uri = sm_session.upload_data(path=model_archive_filename, key_prefix=model_input_prefix)
print(f'Compilation job input was uploaded to:\n{model_input_uri}')

In [None]:
# Compilation target. Entries left as None are omitted from the request.
target_device = 'jetson_xavier'
target_platform = None
compiler_options = None

role = sagemaker.get_execution_role()
framework = 'PYTORCH'
# Keep only the first two components of the installed PyTorch version.
framework_version = '.'.join(torch.__version__.split('.')[:2])
# S3 keys always use '/' as separator, so the prefix is joined with posixpath
# instead of os.path, whose separator depends on the operating system.
output_prefix = posixpath.join(job_prefix, 'output')
output_uri = f's3://{bucket}/{output_prefix}'

# The input shape is passed as JSON under the input name 'input0'.
input_config = {
    'S3Uri': model_input_uri,
    'DataInputConfig': json.dumps({'input0': input_size}),
    'Framework': framework
}
if framework_version is not None:
    input_config['FrameworkVersion'] = framework_version

# Only the target settings that are actually set are added to the request.
output_config = {
    'S3OutputLocation': output_uri,
}
if target_device:
    output_config['TargetDevice'] = target_device
if target_platform:
    output_config['TargetPlatform'] = target_platform
if compiler_options:
    output_config['CompilerOptions'] = json.dumps(compiler_options)

# Submit the job; the API response is shown as the cell output.
sm_client.create_compilation_job(
    CompilationJobName=job_name,
    RoleArn=role,
    InputConfig=input_config,
    OutputConfig=output_config,
    StoppingCondition={
        'MaxRuntimeInSeconds': 600,
        'MaxWaitTimeInSeconds': 900
    }
)

In [None]:
# Poll until the compilation job reaches a terminal state; the returned status
# dictionary is shown as the cell output.
wait_for_compilation_job(job_name)