In [1]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification

import random
import numpy as np

### Set Up CUDA Device

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

SEED = 19

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == torch.device("cuda"):
    torch.cuda.manual_seed_all(SEED)

### Set up Tokenizer and Max Input Length

In [3]:
MAX_LEN = 256
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',do_lower_case=True)

In [5]:
def prep_input(input_text: str):
    input_ids = tokenizer.encode(input_text, add_special_tokens=True,max_length=MAX_LEN,padding="max_length",truncation=True)
    attention_mask = [float(i>0) for i in input_ids]
    return input_ids, attention_mask

In [5]:
# model = BertForSequenceClassification.from_pretrained("model/")
# model.eval();

### TorchScript Tracing

In [12]:
# Tokenizing input text and creating attention mask
text = "I am feeling awfully sad right now."
input_ids, attention_mask = prep_input(text)

# Creating a dummy input
input_tensor = torch.tensor([input_ids])
attention_tensor = torch.tensor([attention_mask])
dummy_input = [input_tensor, attention_tensor]

# Initializing the model with torchscript flag
model = BertForSequenceClassification.from_pretrained("model/", torchscript=True)
model.eval()

# Creating the trace
traced_model = torch.jit.trace(model, [input_tensor, attention_tensor])
torch.jit.save(traced_model, "cpp_implementation/traced_model.pt")

In [11]:
traced_model(input_tensor, attention_tensor)

(tensor([[-1.9769, -1.6921, -1.8883, -0.9651,  7.1175, -1.5798]],
        grad_fn=<AddmmBackward0>),)

In [13]:
text = "I am feeling awfully sad right now."
print(prep_input(text))

([101, 1045, 2572, 3110, 9643, 2135, 6517, 2157, 2085, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.