# RoBERTa Text Classifier (PyTorch)
Source: https://pytorch.org/text/stable/tutorials/sst2_classification_non_distributed.html

In [22]:
ENVIRONMENT = 'gcp'  # gcp, colab, or local
if ENVIRONMENT == 'colab':
    from google.colab import drive
    drive.mount('/content/drive')
    %cd '/content/drive/MyDrive/deep-learning-project/roberta_text_classifier'
    data_path = '../data'
elif ENVIRONMENT == 'gcp':
    !pip install torchtext
    data_path = 'gs://hateful_memes/hateful_memes'
else:
    data_path = '../data'



In [16]:
import numpy as np
import torch
import torch.nn as nn
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data Transformation

In [17]:
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url

class PadTransform(nn.Module):
  """Pad tensor to fixed length"""
  def __init__(self, max_length: int, pad_value: int):
    super().__init__()
    self.max_length = max_length
    self.pad_value = pad_value

  def forward(self, x):
    """
    :param x: The tensor to pad
    :type x: Tensor
    :return: Tensor padded up to max_length with pad value
    :rtype: Tensor
    """
    if type(x) == list:
      while len(x) < self.max_length:
        x.append(self.pad_value)
    else:
      max_encoded_length = x.size(-1)
      if max_encoded_length < self.max_length:
        pad_amount = self.max_length - max_encoded_length
        x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
    return x

padding_idx = 1
bos_idx = 0
eos_idx = 2
max_seq_len = 50  # 256
xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"

text_transform = T.Sequential(
    T.SentencePieceTokenizer(xlmr_spm_model_path),
    T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),
    T.Truncate(max_seq_len - 2),
    T.AddToken(token=bos_idx, begin=True),
    T.AddToken(token=eos_idx, begin=False),
    PadTransform(max_length=max_seq_len, pad_value=padding_idx)
)

# Alternative method
#from torchtext.models import XLMR_BASE_ENCODER
#text_transform = XLMR_BASE_ENCODER.transform()

## Create dataset

In [20]:
import pandas as pd

# Train Data
train = pd.read_csv(f'{data_path}/train_captioned.csv')
train['context'] = train['text'] + '. ' + train['caption']
train.drop(columns=['Unnamed: 0', 'text', 'caption'], inplace=True)

train = train[['context', 'label']]



In [23]:
# Validation Data
valid = pd.read_csv(f'{data_path}/dev_unseen_captioned.csv')
#valid = pd.read_csv('../data/dev_seen_captioned.csv')
#valid = pd.read_csv('../data/test_unseen_captioned.csv')
#valid = pd.read_csv('../data/test_seen_captioned.csv')
valid['context'] = valid['text'] + '. ' + valid['caption']
valid.drop(columns=['Unnamed: 0', 'text', 'caption'], inplace=True)

dev = valid[['context', 'label']]

In [24]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
  def __init__(self, text, labels):
    self.label = labels
    self.text = text
    self.token_ids = text.apply(lambda x: text_transform(x))

  def __len__(self):
    return len(self.label)

  def __getitem__(self, idx):
    label = self.label[idx]
    text = self.text[idx]
    token_ids = self.token_ids[idx]
    sample = {"text": text, "token_ids": token_ids, "label": label}
    return sample

In [25]:
train_dataset = TextDataset(text=train['context'], labels=train['label'])
dev_dataset = TextDataset(text=dev['context'], labels=dev['label'])

In [26]:
train_dataset[7001]

{'text': 'no one: steven hawking:. a black and white photo of a computer keyboard and mouse.',
 'token_ids': [0,
  110,
  1632,
  12,
  2288,
  1353,
  6,
  187404,
  214,
  12,
  5,
  10,
  22556,
  136,
  35011,
  16186,
  111,
  10,
  13909,
  149186,
  136,
  114669,
  5,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'label': 1}

In [31]:
BATCH_SIZE = 16  # original: 16,  if None: batch size appears to be 1
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=True)

## Model Preparation

In [32]:
num_classes = 2
input_dim = 768  # 768

from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
classifier_head = RobertaClassificationHead(num_classes=num_classes, input_dim=input_dim)
model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
#model = ROBERTA_BASE_ENCODER.get_model(head=classifier_head)
model.to(DEVICE)

RobertaModel(
  (encoder): RobertaEncoder(
    (transformer): TransformerEncoder(
      (token_embedding): Embedding(250002, 768, padding_idx=1)
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (dropout): Dropout(p=0.1, inplace=False)
          (attention): MultiheadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (input_projection): Linear(in_features=768, out_features=2304, bias=True)
            (output_projection): Linear(in_features=768, out_features=768, bias=True)
          )
          (residual_mlp): ResidualMLP(
            (mlp): Sequential(
              (0): Linear(in_features=768, out_features=3072, bias=True)
              (1): GELU()
              (2): Dropout(p=0.1, inplace=False)
              (3): Linear(in_features=3072, out_features=768, bias=True)
              (4): Dropout(p=0.1, inplace=False)
            )
          )
          (attention_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)


## Training Methods

In [60]:
import torchtext.functional as F
from torch.optim import AdamW

learning_rate = 1e-3  # 1e-5

optim = AdamW(model.parameters(), lr=learning_rate)
#criteria = nn.CrossEntropyLoss()
criteria = nn.BCEWithLogitsLoss()
#sigmoid = nn.Sigmoid()
#criteria = nn.BCELoss()


# Load model checkpoint if one exists
LOAD_MODEL = True
if LOAD_MODEL:
  from os import listdir
  try:
    checkpoint = torch.load('model.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    accuracy = checkpoint['accuracy']
  except Exception as e:
    print("Model checkpoint could not be loaded:", e)


def train_step(input, target):
    output = model(input)  # original
    #output = torch.argmax(model(input), dim=1).type(torch.float)  # use with BCELoss
    #output = model(input).requires_grad_(True)
    #preds = torch.argmax(output, dim=1).type(torch.float)#.requires_grad_(True)
    #loss = torch.sum(torch.abs(torch.sub(preds, target))).requires_grad_(True)
    #loss = criteria(preds, target).requires_grad_(True)
    loss = criteria(output, target)
    optim.zero_grad()
    loss.backward()
    optim.step()


def eval_step(input, target):
    output = model(input)  # original
    #output = torch.argmax(model(input), dim=1)
    #target = target.squeeze(1)
    #loss = torch.sum(torch.abs(torch.sub(preds, target)))
    loss = criteria(output, target).item()
    #return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()
    return float(loss), (output.argmax(1) == target.argmax(1)).type(torch.float).sum().item()


def evaluate():
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    counter = 0
    with torch.no_grad():
        for batch in dev_dataloader:
            
            batch_token_ids = []
            for b in range(BATCH_SIZE):
                try:
                    token_ids_b = []
                    for i in range(max_seq_len):
                        token_ids_b.append(batch['token_ids'][i][b].item())
                    batch_token_ids.append(token_ids_b)
                except:
                    continue
            
            input = F.to_tensor(batch_token_ids, padding_value=padding_idx).to(DEVICE)
            #input = F.to_tensor(batch['token_ids'], padding_value=padding_idx).view(1,-1).to(DEVICE)
            target = torch.tensor(batch['label']).to(DEVICE)
            target = torch.nn.functional.one_hot(
                target.clone().detach(), 
                num_classes=num_classes
            ).type(torch.float).to(DEVICE)
            loss, predictions = eval_step(input, target)
            total_loss += loss
            correct_predictions += predictions
            total_predictions += len(target)
            counter += 1

    return total_loss / counter, correct_predictions / total_predictions

Model checkpoint could not be loaded: [Errno 2] No such file or directory: 'model.pt'


## Train

In [61]:
num_epochs = 5

for e in range(num_epochs):
  for batch in train_dataloader:
    
    # convert batch['token_ids'] into list of lists where each sub-list contains the tokens for sequence. 
    # This is done because tokens are coming in as list of tensors where tensor i has ith token for all sequences...
    batch_token_ids = []
    for b in range(BATCH_SIZE):
        try:
            token_ids_b = []
            for i in range(max_seq_len):
                token_ids_b.append(batch['token_ids'][i][b].item())
            batch_token_ids.append(token_ids_b)
        except:
            continue
        
    input = F.to_tensor(batch_token_ids, padding_value=padding_idx).to(DEVICE)
    #input = F.to_tensor(batch['token_ids'], padding_value=padding_idx).view(1,-1).to(DEVICE)
    target = torch.tensor(batch['label']).to(DEVICE)
    target = torch.nn.functional.one_hot(
        target.clone().detach(), 
        num_classes=num_classes
    #).view(-1, num_classes).type(torch.float).requires_grad_(True).to(DEVICE)
    ).type(torch.float).requires_grad_(True).to(DEVICE)
    train_step(input, target)

  loss, accuracy = evaluate()
  print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))



Epoch = [0], loss = [0.6635420269825879], accuracy = [0.6296296296296297]
Epoch = [1], loss = [0.6626101241392248], accuracy = [0.6296296296296297]
Epoch = [2], loss = [0.6593696369844324], accuracy = [0.6296296296296297]
Epoch = [3], loss = [0.6587388760903302], accuracy = [0.6296296296296297]
Epoch = [4], loss = [0.6629548923057669], accuracy = [0.6296296296296297]


# Save Model Checkpoint

In [62]:
SAVE_MODEL = True

if SAVE_MODEL:

  PATH = 'model.pt'

  try:
    EPOCH = e
  except:
    EPOCH = num_epochs

  torch.save(
      {
        'epoch': EPOCH,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': loss,
        'accuracy': accuracy
      }, 
      PATH
  )