## Tracing a ResNet Model

In [None]:
import torch
import torchvision

#### Load Model

In [None]:
# Load ResNet-18 with pretrained weights and switch it to eval mode, so batch
# norm uses its running statistics and the trace records inference behaviour.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=torchvision.models.ResNet18_Weights.DEFAULT`; it is kept here for
# compatibility with older torchvision releases.
# Use the GPU when one is available, and fall back to the CPU instead of
# crashing on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(pretrained=True).to(device)
model.eval()

#### Generate Example Tensor

In [None]:
# Build an example input for tracing. You can use either a random tensor of the
# right shape or a real sample loaded from your dataset: tracing records the
# operations that execute, so the values matter only if the model has
# data-dependent control flow.
# Shape is (batch, channels, height, width): one 3-channel 224x224 image.
# The tensor is created on the same device as the model, so the trace cannot
# fail on a CPU/GPU mismatch.
example = torch.rand(1, 3, 224, 224, device=next(model.parameters()).device)

#### Trace Model Using Example Tensor

In [None]:
# Run `example` through the model once and record the executed operations as
# a TorchScript module. Only the path taken for this input is captured, so any
# data-dependent branching is frozen into the trace.
traced_script_module = torch.jit.trace(model, example_inputs=example)

#### Verify Traced Model

In [None]:
# Verify that the traced model reproduces the original model's output.
# Compare with a small tolerance rather than exact equality: float results can
# differ in the last bits between eager and traced runs (e.g. from
# non-deterministic GPU kernels), so a correct trace can still fail an exact
# check. The tolerance mirrors torch.jit.trace's default check_tolerance (1e-5).
with torch.no_grad():  # inference only; no need to build an autograd graph
    orig_output = model(example)
    traced_output = traced_script_module(example)
print(torch.allclose(orig_output, traced_output, atol=1e-5))
print("max abs diff:", (orig_output - traced_output).abs().max().item())

In [None]:
# Serialize the traced ResNet (graph and weights) to disk. It can be reloaded
# later with torch.jit.load, without the Python model definition.
# NOTE: the BERT section later in this notebook also saves to "model.pt", so a
# top-to-bottom run overwrites this file.
torch.jit.save(traced_script_module, "model.pt")

## Tracing a BERT Model

In [None]:
import torch
from transformers import BertModel, BertTokenizer, BertConfig

In [None]:
# Load the BERT encoder and its matching tokenizer from one checkpoint name.
# torchscript=True prepares the model for TorchScript export. Per the Hugging
# Face TorchScript guide, it makes the model return plain tuples and unties
# shared weights, which tracing cannot export.
checkpoint = "bert-base-uncased"
model = BertModel.from_pretrained(checkpoint, torchscript=True)
enc = BertTokenizer.from_pretrained(checkpoint)
model.eval()  # inference mode (disables dropout) before tracing

In [None]:
# Tokenize a sentence pair whose special tokens are written into the text
# itself. The segment ids below assume each "[SEP]" survives tokenization as a
# single token; the original hard-coded ids assumed this too.
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Segment (token type) ids: 0 up to and including the first "[SEP]", 1 after.
# Deriving them from the tokens, rather than hard-coding a list, keeps their
# length in sync with the tokens if `text` is edited. They are computed before
# masking so the "[SEP]" lookup cannot be affected by the mask.
first_sep = tokenized_text.index("[SEP]")
segments_ids = [0] * (first_sep + 1) + [1] * (len(tokenized_text) - first_sep - 1)

# Mask one of the input tokens. This changes only the example values, not the
# shapes the trace is recorded with.
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)

# Wrap each id list in a batch of one, giving tensors of shape (1, seq_len).
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]



In [None]:
# Trace BERT with the two example tensors as positional inputs. A tuple is the
# canonical form for torch.jit.trace; other iterables are converted to one.
# NOTE(review): in Hugging Face's BertModel.forward the second positional
# parameter appears to be attention_mask, not token_type_ids. Verify this
# against the installed transformers version. If so, segments_tensors is fed
# as the attention mask, and the traced module's second input is a mask.
traced_model = torch.jit.trace(model, (tokens_tensor, segments_tensors))

In [None]:
# Verify that the traced BERT model reproduces the original model's outputs.
# With torchscript=True the model returns a tuple, so every element is checked.
# Each one is compared with a small tolerance rather than exact equality,
# because float results can differ in the last bits between eager and traced
# runs. The tolerance mirrors torch.jit.trace's default check_tolerance (1e-5).
# NOTE(review): the second positional argument of BertModel.forward appears to
# be attention_mask, not token_type_ids. Confirm this; if so,
# segments_tensors acts as an attention mask in both calls below.
with torch.no_grad():  # inference only; no need to build an autograd graph
    orig_output = model(tokens_tensor, segments_tensors)
    traced_output = traced_model(tokens_tensor, segments_tensors)
for i, (orig, traced) in enumerate(zip(orig_output, traced_output)):
    max_diff = (orig - traced).abs().max().item()
    print(
        f"output[{i}]: allclose={torch.allclose(orig, traced, atol=1e-5)}, "
        f"max abs diff={max_diff:.3g}"
    )

In [None]:
# Serialize the traced BERT module (graph and weights); reload it with
# torch.jit.load.
# NOTE: this reuses the "model.pt" path from the ResNet section and overwrites
# that file.
traced_model.save("model.pt")