# TorchScript

In [1]:
from pathlib import Path

import torch

from icecube.model.lstm import LSTM

In [2]:
# Dimensions of the dummy example tensor used when tracing, built elsewhere
# in this notebook as torch.randn(BS, NUM_LENGTHS, NUM_FEATURES).
BS = 1
NUM_LENGTHS = 96
NUM_FEATURES = 6

# Keyword arguments forwarded verbatim to the LSTM constructor.
# `net_name` is only a placeholder here: save_model() derives the actual
# value from the checkpoint file name.
model_args = dict(
    input_size=NUM_FEATURES,
    hidden_size=192,
    num_bins=24,
    num_layers=3,
    bias=False,
    batch_first=True,
    dropout=0,
    bidirectional=True,
    task="clf",
    net_name="lstm",
)

# Checkpoints to export; the prefix before the first underscore of each
# file name is used as the network type.
model_dir = Path("../../models/lstm/")
models = [
    "gru_mae_1.026_epoch_052.ckpt",
    "lstm_mae_1.026_epoch_058.ckpt",
]


def save_model(model_path, inputs=torch.randn(BS, NUM_LENGTHS, NUM_FEATURES), model_args=model_args, device=torch.device("cuda"), test=True):
    """Export a checkpoint to TorchScript next to the original file.

    The network type is taken from the part of the checkpoint file name
    before the first underscore (e.g. ``gru_mae_...ckpt`` -> ``"gru"``) and
    passed to the model constructor as ``net_name``.

    Args:
        model_path: ``pathlib.Path`` of the checkpoint to convert. The loaded
            object must be a mapping with a ``'state_dict'`` entry.
        inputs: Example tensor used for tracing and for the optional check.
            The default is created once, when the function is defined.
        model_args: Keyword arguments for ``LSTM``. The mapping is NOT
            modified; ``net_name`` is overridden on a per-call copy.
        device: Device used for the model and the example input. The default
            requires a CUDA-capable machine.
        test: When true, assert that the exported module and the original
            model produce matching outputs on ``inputs``.

    Returns:
        Path of the exported file: ``model_path`` with its suffix replaced
        by ``.script``.

    Raises:
        AssertionError: If ``test`` is true and the outputs do not match.
    """
    model_type = model_path.stem.split("_")[0]
    # Build a per-call copy instead of writing into the caller's dict: the
    # default argument is the shared module-level `model_args`, and mutating
    # it would leak the last exported net_name into every later use.
    model_args = {**model_args, 'net_name': model_type}
    print(f"Loading {model_type} from {model_path}")

    model = LSTM(**model_args).eval().to(device)
    inputs = inputs.to(device)

    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)['state_dict']
    model.load_state_dict(state_dict)

    script_path = model_path.parent / (model_path.stem + '.script')
    # Presumably the LightningModule API, which writes the traced module to
    # `file_path` -- TODO confirm LSTM derives from LightningModule.
    script = model.to_torchscript(file_path=script_path, method='trace', example_inputs=inputs)

    print(f"Save model to {script_path}")

    if test:
        # Inference-only comparison, so no autograd graph is needed.
        with torch.no_grad():
            torch.testing.assert_close(
                actual=script(inputs),
                expected=model(inputs),
                msg="Test failed!",
            )
        print('Test passed!')
    return script_path


In [3]:
# Export every checkpoint listed in `models`, keeping the returned paths.
script_paths = [
    save_model(model_dir / checkpoint_name)
    for checkpoint_name in models
]

Loading gru from ../../models/lstm/gru_mae_1.026_epoch_052.ckpt
Save model to ../../models/lstm/gru_mae_1.026_epoch_052.script
Test passed!
Loading lstm from ../../models/lstm/lstm_mae_1.026_epoch_058.ckpt
Save model to ../../models/lstm/lstm_mae_1.026_epoch_058.script
Test passed!
