In [1]:
import sys
# Make the parent directory importable so the `src` and `train` modules used in
# later cells can be resolved. The path is relative to the kernel's working
# directory (presumably the notebook lives one level below the repo root - confirm).
sys.path.append('..')

In [2]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.pyplot import imshow
from tqdm import tqdm
import torch
from torch.nn import Module
from torchsummary import summary

from collections import namedtuple, defaultdict
from pathlib import Path
import time

%matplotlib inline
# Default size (width, height) in inches for every matplotlib figure in this notebook.
rcParams['figure.figsize'] = (10, 15)

In [None]:
from src.constructor.config_structure import TrainConfigParams
from src.registry import TASKS
from src.constructor.data import create_dataset
from src.constructor.trainer import download_s3_artifact
from train import load_config

In [4]:
def load_hparams(config_path):
    """Load a training config and wrap it in ``TrainConfigParams``.

    Args:
        config_path: Location of the config file; passed unchanged to
            ``load_config``.

    Returns:
        The ``TrainConfigParams`` instance built by unpacking the loaded
        mapping as keyword arguments.
    """
    raw_config = load_config(config_path)
    return TrainConfigParams(**raw_config)

In [None]:
# Training config of the model to export; the path is relative to the kernel's
# working directory. `config` is reused by the cells below to build the model.
config_path = '../configs/siloiz_pairwise_xbm_resnet50_adam_512d.yml'
config = load_hparams(config_path)

In [None]:
# Prevent XBM from loading: clear its parameters on the task config before the
# model is built from `config`.
# NOTE(review): presumably XBM is only needed during training and the checkpoint
# still loads without it - confirm against the task implementation.
config.task.params['xbm_params'] = None

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Fetch the checkpoint of the trained run from S3 ('/tmp' is passed as the second
# argument - presumably the local download directory) and deserialize it directly
# onto the selected device.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint_path = download_s3_artifact(
    's3://ml-fips/mlruns/eedfa4c178bd44f8bc6c44cdcc0f4814/artifacts/last.ckpt',
    '/tmp',
)
checkpoint = torch.load(checkpoint_path, map_location=device)

# Build the task from the config using whatever the TASKS registry returns for the
# configured task name, move it to the device and restore the trained weights.
task_factory = TASKS.get(config.task.name)
model = task_factory(config).to(device)
model.load_state_dict(checkpoint['state_dict'])

# Switch to inference mode; the trailing semicolon hides the cell's repr output.
model.eval();

In [None]:
# Trace the forward pass with a random example tensor of shape (4, 3, 224, 224)
# (presumably batch, channels, height, width - confirm) placed on the same device
# as the model.
# NOTE(review): tracing records only the operations executed for this example
# input, so any data-dependent control flow in the model would not be captured -
# confirm the model has none.
example_input = torch.rand(4, 3, 224, 224, device=device)
traced = torch.jit.trace(model.forward, example_input)

In [None]:
# Serialize the traced model to the benchmark models directory.
# NOTE(review): the file name says "contrastive" while the config loaded earlier in
# this notebook is named "pairwise" - confirm the output name is intentional.
output_path = '/mnt/ext1/fips-benchmark/models/siloiz_contrastive_xbm_resnet50_adam_512d.pt'
traced.save(output_path)