In [1]:
import sys

# Make the repository root importable so the `src.*` imports below resolve.
# Guarded so re-running this cell does not keep appending duplicate entries.
REPO_ROOT = '../../'
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.pyplot import imshow
from tqdm import tqdm
import torch
from torch.nn import Module
from torchsummary import summary
from collections import namedtuple, defaultdict
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
import time

%matplotlib inline
rcParams['figure.figsize'] = (10, 15)

In [None]:
from src.constructor.config_structure import ConfigParams
from src.constructor import TASKS


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Paths are relative to the kernel's working directory.
config_path = 'examples/configs/representation_arcface_sop.yaml'
# NOTE(review): run-specific checkpoint. The original path read '..logs/...'
# (missing '/'), i.e. a directory literally named '..logs'; adjust to your run.
checkpoint_path = '../logs/sop_arcface/sop_arcface/2022-06-28_19-33-57/16-15810.ckpt'

# Load the YAML, resolve interpolations, then merge onto the structured schema
# so missing fields get schema defaults and types are validated.
config = OmegaConf.load(config_path)
OmegaConf.resolve(config)
schema = OmegaConf.structured(ConfigParams)
config = OmegaConf.merge(schema, config)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads on a CPU-only machine.
# torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Build the task named in the config, move it to the device, restore weights.
model = TASKS.get(config.task.name)(config)
model = model.to(device)
model.load_state_dict(checkpoint['state_dict'])
model.eval();

In [None]:
# Create vectors on validation dataset: collect targets and embeddings for
# every batch of the first validation dataloader.
dataloader = model.val_dataloader()[0]

targets = []
vectors = []
# Pure inference: without no_grad() autograd would record a graph for every
# batch (detach() on the output does not prevent that), wasting memory.
with torch.no_grad():
    for batch in tqdm(dataloader):
        targets.append(batch['target'])
        batch['image'] = batch['image'].to(device)
        vectors.append(model.forward_with_gt(batch)['embeddings'].detach().cpu())

In [None]:
# Export the model to TorchScript by tracing it on a dummy batch.
# Presumably (batch, channels, height, width) -- TODO confirm against the model's
# expected input. Random values suffice only if forward has no data-dependent
# control flow, since tracing records just the ops executed for this input.
example_input = torch.rand(4, 3, 224, 224, device=device)
traced = torch.jit.trace(model.forward, example_input)
traced.save('sop_resnet50_arcface.pt')