In [None]:
import os
import pathlib
import tempfile
from datetime import datetime

import furl
from clearml import Task, OutputModel

from cub_tools.trainer import ClearML_Ignite_Trainer

In [None]:
# Path to the training config YAML used to rebuild the model architecture. Defaults to the
# original author's local path; set the CUB_MODEL_CONFIG environment variable to override it
# instead of editing this cell.
model_config = os.environ.get(
    'CUB_MODEL_CONFIG',
    '/home/edmorris/projects/image_classification/caltech_birds/scripts/configs/timm/resnext101_64x4d_config.yaml',
)

In [None]:
# ID of the ClearML experiment whose trained weights are exported to TorchScript below.
SOURCE_TASK_ID = "f485f5447d5e4c1b871729649a5a39b5"
task = Task.get_task(task_id=SOURCE_TASK_ID)

In [None]:
# Guard: this export flow expects the source experiment to have exactly one output model.
# The previous message claimed "more than one" even when zero models were found; report the
# actual count instead.
# NOTE(review): the checkpoint-selection cell further down searches the output models for one
# named "best_model", which suggests several checkpoints may exist -- confirm this check is
# not stricter than intended.
n_output_models = len(task.get_models()['output'])
assert n_output_models == 1, (
    f'[ERROR] Expected exactly one output model, found {n_output_models}; unable to proceed.'
)

In [None]:
# Fresh, timestamp-prefixed scratch directory that holds the TorchScript file before upload.
dirname = tempfile.mkdtemp(prefix=f"ignite_torchscripts_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S_')}")
# Full local path of the TorchScript file; joined with pathlib (already imported above).
temp_file_path = str(pathlib.Path(dirname) / 'model.pt')

In [None]:
import torch

# Local modules
from cub_tools.trainer import ClearML_Ignite_Trainer
from cub_tools.args import get_parser
from cub_tools.config import get_cfg_defaults, get_key_value_dict

In [None]:
# Config overrides for the trainer, flattened below into [KEY, VALUE, KEY, VALUE, ...] order.
config_overrides = {
    'DIRS.CLEAN_UP': False,     # Don't do anything to the directory structure.
    'MODEL.PRETRAINED': False,  # Don't load default weights, as we want to load our own.
}
cmd_args = [token for key_value in config_overrides.items() for token in key_value]

trainer = ClearML_Ignite_Trainer(
    task=task,
    config=model_config,
    cmd_args=cmd_args,
)

In [None]:
# Build the network architecture described by the config (MODEL.PRETRAINED was overridden to
# False, so no default weights are loaded here).
# NOTE(review): load_to_device=False presumably leaves the model on the CPU rather than moving
# it to trainer.device -- confirm in cub_tools.trainer.
trainer.create_model(load_to_device=False)

In [None]:
# Select the best model weights file for this experiment: the first output model whose name
# contains "best_model". Previously, a failed search silently fell through to whatever the
# loop variable last held (or to a stale value left in the kernel); make both cases explicit.
output_models = trainer.task.get_models()['output']
if not output_models:
    raise RuntimeError('[ERROR] The task has no output models; nothing to export.')

chkpnt_model = None
for candidate in output_models:
    print(candidate.name)
    print(candidate.url)
    if "best_model" in candidate.name:
        chkpnt_model = candidate
        break

if chkpnt_model is None:
    # Keep the original fallback (last output model), but say so instead of doing it silently.
    chkpnt_model = output_models[-1]
    print(f'[WARNING] No output model named "best_model" found; falling back to {chkpnt_model.name}.')

# Get the model weights file locally and update the model
local_cache_path = chkpnt_model.get_local_copy()
if not local_cache_path:
    raise RuntimeError(f'[ERROR] Could not obtain a local copy of the weights at {chkpnt_model.url}.')
trainer.update_model_from_checkpoint(checkpoint_file=local_cache_path)

In [None]:
# Build the transforms and data loaders so a sample image batch can be drawn for tracing.
# NOTE(review): the shuffle mapping names 'train' and 'test', while the sample batch is read
# from trainer.val_loader -- confirm which key governs the validation loader.
trainer.create_datatransforms()
loader_shuffle = dict(train=True, test=True)
trainer.create_dataloaders(shuffle=loader_shuffle)

In [None]:
# Draw one validation batch to use as the example input for tracing (labels are unused).
X, y = next(iter(trainer.val_loader))
# Put the example batch on the same device as the model's weights. The model was created with
# load_to_device=False, so moving X to trainer.device could mismatch the weights' device.
model_device = next(trainer.model.parameters()).device
X = X.to(model_device)
# Trace in evaluation mode so layers such as dropout / batch-norm are recorded with their
# inference behaviour rather than their training behaviour.
trainer.model.eval()
with torch.no_grad():
    traced_module = torch.jit.trace(trainer.model, X)
# Write the TorchScript module to the local scratch file; it is uploaded to ClearML further down.
traced_module.save(temp_file_path)

In [None]:
# Build the remote location of the TorchScript file, next to the best model weights.
model_furl = furl.furl(chkpnt_model.url)
# Path component of the existing weights URL (furl's `pathstr` is deprecated; str(path) replaces it).
model_path = pathlib.Path(str(model_furl.path))
# TorchScript filename: "<weights stem>_torchscript<weights suffix>", e.g. weights.pt -> weights_torchscript.pt
fname = f"{model_path.stem}_torchscript{model_path.suffix}"
# Same origin (scheme://host[:port]) and directory as the existing weights file.
new_model_furl = furl.furl(origin=model_furl.origin, path=str(model_path.parent))

In [None]:
# Register a new output model on this experiment's ClearML task. Nothing is uploaded yet: the
# TorchScript file is attached (and uploaded) by the update_weights call in the next cell.
new_output_model = OutputModel(task=trainer.task)

In [None]:
# Upload the local TorchScript file under its new name to the directory of the original weights.
upload_kwargs = dict(
    weights_filename=temp_file_path,
    target_filename=fname,
    upload_uri=new_model_furl.url,
)
new_output_model.update_weights(**upload_kwargs)