In [1]:
!gdown 13d-tjvmDbKKrMfCvsJzaQ7pgHqQykJYB
!gdown 1rWTsDh0wEFIQFiSkWxBqVOa_V_J5EhBq

Downloading...
From: https://drive.google.com/uc?id=13d-tjvmDbKKrMfCvsJzaQ7pgHqQykJYB
To: /content/full_data.json
100% 2.50M/2.50M [00:00<00:00, 35.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1rWTsDh0wEFIQFiSkWxBqVOa_V_J5EhBq
To: /content/labels.torch
100% 31.4k/31.4k [00:00<00:00, 44.4MB/s]


In [23]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import json
from transformers import AutoTokenizer, BertModel
import tqdm
import torch.nn as nn
import math
import random

class OurDataset(Dataset):
    def __init__(self, data_file, labels_file):
        self.full_data = json.load(open(data_file))
        self.labels = torch.load(labels_file)

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertModel.from_pretrained("bert-base-uncased")

    def __len__(self):
        return len(self.full_data)

    def __getitem__(self, idx):
        inputs = self.tokenizer(self.full_data[idx], return_tensors="pt")
        outputs = self.model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states, self.labels[idx]

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerClassifier(nn.Module):

    def __init__(
        self,
        d_model=768,
        n_classes=3,
        nhead=4,
        dim_feedforward=512,
        num_layers=6,
        dropout=0.1,
        activation="relu",
        classifier_dropout=0.1,
    ):

        super().__init__()

        self.pos_encoder = PositionalEncoding(
            d_model=d_model,
            dropout=dropout,
            max_len=5000,
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        self.head = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes),
            nn.Softmax(dim=1)
        )

        self.d_model = d_model
        self.dropout = nn.Dropout(p=classifier_dropout)

    def forward(self, x):
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)
        x = self.dropout(x)
        x = self.head(x)

        return x

In [24]:
ds = OurDataset("full_data.json", "labels.torch")
torch.manual_seed(29592)  # set the seed for reproducibility
# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device='cpu'

In [25]:
def train_model(model):
  correct = 0
  loss_fn = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.00001, momentum =0.9)
  for i in range(int(ds.__len__())):
    idx = random.randint(0, ds.__len__()-1)
    x, y = ds.__getitem__(idx)
    y = y.long()
    x = x.to(device)
    y = y.to(device)
    output = model(x)
    correct += output.argmax().eq(y).sum()
    loss = loss_fn(output, y)
    if i%100==0:
      print(i, correct, loss.item())
    loss.backward()
    optimizer.step()
  acc = correct/int((ds.__len__()))
  print(acc)

In [26]:
def eval(model):
  model.eval()
  correct = 0
  with torch.no_grad():
    for i in range(int(ds.__len__())):
      idx = random.randint(0, ds.__len__()-1)
      x, y = ds.__getitem__(idx)
      y = y.long()
      x = x.to(device)
      y = y.to(device)
      output = model(x)
      correct += output.argmax().eq(y).sum()
      if i%100==0:
        print(correct)
  acc = correct/int((ds.__len__()))
  return acc

In [27]:
import time
def time_model_evaluation(model):
    s = time.time()
    loss = eval(model)
    elapsed = time.time() - s
    print('''accuracy: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

import os
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [28]:
t = TransformerClassifier()
t.to(device)
print_size_of_model(t)

Size (MB): 91.849724


In [31]:
train_model(t)

0 tensor(1) 1.0495038032531738
100 tensor(69) 0.5517643690109253
200 tensor(126) 1.5514447689056396
300 tensor(200) 1.5514447689056396
400 tensor(266) 0.5514447093009949
500 tensor(335) 0.5514447093009949
600 tensor(402) 1.5514447689056396
700 tensor(469) 0.5514447093009949
800 tensor(523) 0.5514447093009949
900 tensor(589) 0.5514447093009949
1000 tensor(647) 1.5514447689056396
1100 tensor(711) 0.5514447093009949
1200 tensor(772) 0.5514447093009949
1300 tensor(836) 0.5514447093009949
1400 tensor(903) 0.5514447093009949
1500 tensor(962) 0.5514447093009949
1600 tensor(1028) 0.5514447093009949
1700 tensor(1096) 0.5514447093009949
1800 tensor(1164) 0.5514447093009949
1900 tensor(1231) 1.5514447689056396
2000 tensor(1296) 1.5514447689056396
2100 tensor(1356) 0.5514447093009949
2200 tensor(1424) 0.5514447093009949
2300 tensor(1497) 0.5514447093009949
2400 tensor(1557) 0.5514447093009949
2500 tensor(1620) 1.5514447689056396
2600 tensor(1682) 0.5514447093009949
2700 tensor(1754) 0.551444709300

In [32]:
import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    t, {nn.Linear}, dtype=torch.qint8
)

In [33]:
print_size_of_model(quantized_model)

Size (MB): 77.112473


In [34]:
torch.set_num_threads(1)
time_model_evaluation(t)

tensor(1)
tensor(71)
tensor(136)
tensor(197)
tensor(271)
tensor(332)
tensor(400)
tensor(462)
tensor(530)
tensor(590)
tensor(654)
tensor(713)
tensor(774)
tensor(840)
tensor(897)
tensor(953)
tensor(1014)
tensor(1090)
tensor(1153)
tensor(1216)
tensor(1276)
tensor(1341)
tensor(1410)
tensor(1472)
tensor(1537)
tensor(1599)
tensor(1655)
tensor(1725)
tensor(1787)
tensor(1855)
tensor(1916)
tensor(1983)
tensor(2047)
tensor(2117)
tensor(2182)
tensor(2249)
tensor(2314)
tensor(2382)
tensor(2448)
tensor(2516)
tensor(2577)
tensor(2645)
tensor(2707)
tensor(2769)
tensor(2829)
tensor(2896)
tensor(2946)
tensor(3013)
tensor(3083)
tensor(3147)
tensor(3211)
tensor(3271)
tensor(3335)
tensor(3398)
tensor(3461)
tensor(3522)
tensor(3579)
tensor(3642)
tensor(3707)
tensor(3770)
tensor(3835)
tensor(3896)
tensor(3965)
tensor(4035)
tensor(4104)
tensor(4171)
tensor(4239)
tensor(4303)
tensor(4366)
tensor(4429)
tensor(4484)
tensor(4545)
tensor(4608)
tensor(4674)
tensor(4738)
tensor(4805)
tensor(4863)
accuracy: 0.640
el

In [35]:
time_model_evaluation(quantized_model)

tensor(1)
tensor(67)
tensor(131)
tensor(194)
tensor(260)
tensor(317)
tensor(376)
tensor(451)
tensor(522)
tensor(583)
tensor(654)
tensor(712)
tensor(775)
tensor(839)
tensor(904)
tensor(962)
tensor(1022)
tensor(1085)
tensor(1152)
tensor(1218)
tensor(1280)
tensor(1341)
tensor(1411)
tensor(1479)
tensor(1544)
tensor(1606)
tensor(1661)
tensor(1721)
tensor(1777)
tensor(1847)
tensor(1917)
tensor(1982)
tensor(2052)
tensor(2113)
tensor(2178)
tensor(2232)
tensor(2296)
tensor(2366)
tensor(2427)
tensor(2495)
tensor(2551)
tensor(2611)
tensor(2672)
tensor(2743)
tensor(2804)
tensor(2868)
tensor(2932)
tensor(2996)
tensor(3055)
tensor(3123)
tensor(3183)
tensor(3253)
tensor(3313)
tensor(3381)
tensor(3440)
tensor(3500)
tensor(3566)
tensor(3632)
tensor(3694)
tensor(3752)
tensor(3820)
tensor(3882)
tensor(3943)
tensor(4001)
tensor(4072)
tensor(4137)
tensor(4206)
tensor(4275)
tensor(4340)
tensor(4395)
tensor(4467)
tensor(4525)
tensor(4587)
tensor(4648)
tensor(4718)
tensor(4783)
tensor(4850)
accuracy: 0.638
el