In [1]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification

import random
import numpy as np

### Set Up CUDA Device

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

SEED = 19

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == torch.device("cuda"):
    torch.cuda.manual_seed_all(SEED)

### Set up Tokenizer and Max Input Length

In [3]:
MAX_LEN = 256
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',do_lower_case=True)

In [4]:
def prep_input(input_text: str):
    input_ids = tokenizer.encode(input_text, add_special_tokens=True,max_length=MAX_LEN,padding="max_length",truncation=True)
    attention_mask = [float(i>0) for i in input_ids]
    return input_ids, attention_mask

### Test Model

In [5]:
class TestModule(torch.nn.Module):
    def __init__(self, N, M):
        super(TestModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N,M))

    def forward(self, input):
        output = self.weight + input
        return output

test_model = TestModule(10,20)
traced_test_model = torch.jit.trace(test_model, torch.rand(10,20))
torch.jit.save(traced_test_model, "cpp_implementation/model/traced_test_model.pt")

In [6]:
traced_test_model(torch.rand(10,20))

tensor([[1.3234, 0.4146, 1.5300, 1.0233, 0.9296, 0.9264, 1.2571, 0.2846, 1.1305,
         1.8966, 1.0159, 0.8344, 0.3349, 0.1053, 1.4521, 0.8973, 1.0102, 0.0964,
         1.0318, 1.0100],
        [0.2821, 0.9347, 0.9755, 1.0972, 0.7162, 0.7428, 1.5503, 0.7036, 0.6201,
         1.0893, 1.8649, 1.2398, 1.3660, 1.8776, 1.0299, 0.2941, 0.5068, 1.2674,
         0.3258, 1.6474],
        [0.5563, 0.6325, 1.0845, 0.7499, 0.7171, 1.3896, 1.4508, 0.2293, 1.4863,
         0.8581, 0.8256, 0.5737, 1.3381, 0.3949, 0.1705, 0.8039, 0.7444, 1.1505,
         0.9634, 1.4007],
        [0.6120, 1.5510, 0.9598, 0.9128, 1.0110, 1.1201, 1.5673, 0.7439, 1.5159,
         1.6113, 0.3120, 0.8220, 1.0192, 1.3265, 0.4369, 1.4185, 1.0756, 0.9926,
         0.9987, 0.9466],
        [0.4969, 0.9322, 1.4344, 1.1974, 1.4854, 1.2386, 0.7164, 0.9035, 1.4181,
         1.4654, 0.1916, 0.7707, 0.6331, 1.4596, 0.1925, 1.8217, 0.2660, 1.6215,
         1.5253, 1.4100],
        [1.0160, 1.4793, 0.7380, 1.0065, 1.3689, 1.1495, 0.9

### TorchScript Tracing

In [7]:
# Tokenizing input text and creating attention mask
text = "I am feeling awfully sad right now."
input_ids, attention_mask = prep_input(text)

# Creating a dummy input
input_tensor = torch.tensor([input_ids])
attention_tensor = torch.tensor([attention_mask])
dummy_input = [input_tensor, attention_tensor]

# Initializing the model with torchscript flag
model = BertForSequenceClassification.from_pretrained("model/", torchscript=True)
model.eval()

# Creating the trace
traced_model = torch.jit.trace(model, [input_tensor, attention_tensor])
torch.jit.save(traced_model, "cpp_implementation/model/traced_model.pt")

In [11]:
print(input_tensor.size())
print(attention_tensor.size())

torch.Size([1, 256])
torch.Size([1, 256])


In [8]:
traced_model(input_tensor, attention_tensor)

(tensor([[-1.9769, -1.6921, -1.8883, -0.9651,  7.1175, -1.5798]],
        grad_fn=<AddmmBackward0>),)

In [9]:
text_1 = "i can't be sad about that. [ unaffable"
print(prep_input(text_1))

([101, 1045, 2064, 1005, 1056, 2022, 6517, 2055, 2008, 1012, 1031, 14477, 20961, 3468, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0