In [None]:
!pip install gdown

In [None]:
import os

import gdown

# Google Drive location of the MIDV-2020-Text archive
url = 'https://drive.google.com/uc?id=1gaEWwTnNEOebN2aTLCDXePVQwkvBwl8W'

# Local file name the archive is stored under
output = 'MIDV-2020-Text.zip'

# Only hit the network when the archive is not on disk yet, so that
# "Restart & Run All" does not re-download the dataset every time.
if not os.path.exists(output):
  # NOTE(review): some gdown versions report a failed download by returning
  # None instead of raising; fail loudly here rather than at the unzip step.
  if gdown.download(url, output) is None:
    raise RuntimeError(f'Could not download {url} to {output}')

In [None]:
!unzip MIDV-2020-Text.zip >/dev/null

In [None]:
# Import necessary packages
from tqdm import tqdm
from torch import nn
from torch.optim import Adam
from transformers import BertModel
from transformers import BertTokenizer
import os
import copy
import torch
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Define the name of the dataset (also the directory the archive extracts to)
DATASET_NAME = 'MIDV-2020-Text'

# Seed the random number generators so that weight initialisation, shuffling
# and dropout are repeatable between runs.
# NOTE(review): some GPU kernels are still non-deterministic; this makes runs
# comparable, not necessarily bit-identical.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Determine the device to be used for training and evaluation
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pin host memory during data loading only when batches are copied to a GPU
PIN_MEMORY = DEVICE == 'cuda'

# HYPERPARAMETERS
# Initialize learning rate, number of epochs, and batch size
LEARNING_RATE = 1e-6
NUM_EPOCHS = 5
BATCH_SIZE = 2

# Load the labels. LABELS is indexed by the values of the 'category' column
# and its length is used as the number of output classes
# (presumably a dict mapping category name -> integer class index).
# SECURITY NOTE: pickle.load can execute arbitrary code; only load labels.pkl
# from a dataset archive you trust.
with open(os.path.join(DATASET_NAME, 'labels.pkl'), 'rb') as f:
  LABELS = pickle.load(f)

In [None]:
# Define the dataset class
class Dataset(torch.utils.data.Dataset):
  """Text-classification dataset that tokenizes every sample up front.

  Each item is a pair ``(encoding, label)`` where ``encoding`` is whatever the
  tokenizer returned for the text and ``label`` is the ``LABELS`` entry of the
  sample's category wrapped in a NumPy array.
  """

  def __init__(self, df, tokenizer, max_length=512):
    """Encode all rows of ``df``.

    Args:
      df: table-like object with a 'category' and a 'text' column.
      tokenizer: callable invoked as ``tokenizer(text, padding=...,
        max_length=..., truncation=..., return_tensors='pt')``.
      max_length: length every text is padded/truncated to. Defaults to 512,
        the value that was previously hard-coded.

    Raises:
      KeyError: if a category is not present in ``LABELS``.
    """
    self.labels = [LABELS[label] for label in df['category']]
    # Tokenization happens once here, so __getitem__ is a plain list lookup.
    self.texts = [tokenizer(text,
                            padding='max_length', max_length=max_length,
                            truncation=True, return_tensors='pt')
                  for text in df['text']]

  def classes(self):
    """Return the list of label values, one per sample."""
    return self.labels

  def __len__(self):
    """Return the number of samples."""
    return len(self.labels)

  def get_batch_labels(self, idx):
    """Return the label of sample ``idx`` as a NumPy array."""
    return np.array(self.labels[idx])

  def get_batch_texts(self, idx):
    """Return the tokenizer output of sample ``idx``."""
    return self.texts[idx]

  def __getitem__(self, idx):
    """Return ``(encoding, label)`` for sample ``idx``."""
    batch_texts = self.get_batch_texts(idx)
    batch_y = self.get_batch_labels(idx)
    return batch_texts, batch_y

In [None]:
# Define the BERT classifier
class BertClassifier(nn.Module):
  """BERT encoder followed by a single linear classification layer.

  The linear layer produces one logit per entry of ``LABELS``.
  """

  def __init__(self, model_name='bert-base-cased'):
    """Load the pretrained encoder and build the classification head.

    Args:
      model_name: checkpoint passed to ``BertModel.from_pretrained``.
        Defaults to 'bert-base-cased', the checkpoint used previously.
    """
    super(BertClassifier, self).__init__()
    self.bert = BertModel.from_pretrained(model_name)
    # Take the input width from the encoder's config instead of hard-coding
    # 768, so the head stays correct for checkpoints of a different size.
    self.out = nn.Linear(self.bert.config.hidden_size, len(LABELS))

  def forward(self, input_id, mask):
    """Return the class logits for a batch.

    Args:
      input_id: token ids, forwarded to the encoder as ``input_ids``.
      mask: attention mask, forwarded to the encoder as ``attention_mask``.

    Returns:
      Output of the linear head applied to the second element of the
      encoder's tuple output (the pooled output per the Transformers docs).
    """
    # return_dict=False makes the encoder return a plain tuple; the first
    # element (per-token hidden states) is not needed here.
    _, pooled = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
    return self.out(pooled)

In [None]:
# Load the train and test data
df_train = pd.read_csv(os.path.join(DATASET_NAME, 'train.csv'))
df_test = pd.read_csv(os.path.join(DATASET_NAME, 'test.csv'))

# Create the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# Create the train and test datasets (all texts are tokenized here, up front)
train_ds = Dataset(df_train, tokenizer)
test_ds = Dataset(df_test, tokenizer)

# os.cpu_count() returns None when the CPU count cannot be determined, which
# is not a valid num_workers value; fall back to loading in the main process.
num_workers = os.cpu_count() or 0

# Create the train and test data loaders; only the training data is shuffled
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=num_workers, pin_memory=PIN_MEMORY)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=num_workers, pin_memory=PIN_MEMORY)

In [None]:
# Initialize the BERT classifier
model = BertClassifier().to(DEVICE)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# Initialize a dictionary to store loss history (one entry per epoch)
loss_history = {'train_loss': [], 'test_loss': []}

# Initialize a dictionary to store accuracy history (one entry per epoch)
accuracy_history = {'train_accuracy': [], 'test_accuracy': []}

# Variables to store the best trained model (lowest test loss seen so far)
best_model_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_loss = float('inf')
best_accuracy = 0.0

# Train the model
for e in range(NUM_EPOCHS):
  total_acc_train = 0
  total_loss_train = 0

  # Training mode: enables dropout inside the encoder
  model.train()
  for train_input, train_label in tqdm(train_loader):
    train_label = train_label.type(torch.LongTensor)
    train_label = train_label.to(DEVICE)
    # The tokenizer output carries a singleton dimension at position 1;
    # drop it from both tensors so the encoder receives 2-D inputs.
    mask = train_input['attention_mask'].squeeze(1).to(DEVICE)
    input_id = train_input['input_ids'].squeeze(1).to(DEVICE)

    output = model(input_id, mask)

    batch_loss = criterion(output, train_label)
    # criterion returns the mean over the batch; weight it by the batch size
    # so that dividing by the sample count below gives a per-sample mean.
    total_loss_train += batch_loss.item() * train_label.size(0)

    acc = (output.argmax(dim=1) == train_label).sum().item()
    total_acc_train += acc

    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

  total_acc_test = 0
  total_loss_test = 0

  # Evaluation mode: disables dropout so the test metrics are deterministic
  model.eval()
  with torch.no_grad():
    for test_input, test_label in test_loader:
      test_label = test_label.type(torch.LongTensor)
      test_label = test_label.to(DEVICE)
      mask = test_input['attention_mask'].squeeze(1).to(DEVICE)
      input_id = test_input['input_ids'].squeeze(1).to(DEVICE)

      output = model(input_id, mask)

      batch_loss = criterion(output, test_label)
      total_loss_test += batch_loss.item() * test_label.size(0)

      acc = (output.argmax(dim=1) == test_label).sum().item()
      total_acc_test += acc

  train_loss = total_loss_train / len(df_train)
  train_acc = total_acc_train / len(df_train)
  test_loss = total_loss_test / len(df_test)
  test_acc = total_acc_test / len(df_test)

  loss_history['train_loss'].append(train_loss)
  accuracy_history['train_accuracy'].append(train_acc)
  loss_history['test_loss'].append(test_loss)
  accuracy_history['test_accuracy'].append(test_acc)

  print(f'Epoch: {e + 1} | Train Loss: {train_loss: .5f} \
            | Train Accuracy: {train_acc: .5f} \
            | Test Loss: {test_loss: .5f} \
            | Test Accuracy: {test_acc: .5f}')

  # Remember the weights of the epoch with the lowest test loss.
  # NOTE(review): selecting on the test split leaks it into model selection;
  # switch to a validation split if one becomes available.
  if test_loss < best_loss:
    best_loss = test_loss
    best_accuracy = test_acc
    best_epoch = e + 1
    best_model_weights = copy.deepcopy(model.state_dict())

print(f'Best epoch: {best_epoch} | Test Loss: {best_loss: .5f} | Test Accuracy: {best_accuracy: .5f}')

# Restore the best weights, then serialize the model to disk
model.load_state_dict(best_model_weights)
torch.save(model, DATASET_NAME + '_model.pt')

In [None]:
# Draw the per-epoch train and test loss curves and save the figure to disk
plt.style.use('ggplot')
fig, ax = plt.subplots()
epochs = range(1, NUM_EPOCHS + 1)
for series in ('train_loss', 'test_loss'):
  # The history key doubles as the legend label
  ax.plot(epochs, loss_history[series], label=series)
ax.set_title('Training Loss vs Testing Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend(loc='lower left')
fig.savefig('loss_history.png')

In [None]:
# Draw the per-epoch train and test accuracy curves and save the figure to disk
plt.style.use('ggplot')
fig, ax = plt.subplots()
epochs = range(1, NUM_EPOCHS + 1)
for series in ('train_accuracy', 'test_accuracy'):
  # The history key doubles as the legend label
  ax.plot(epochs, accuracy_history[series], label=series)
ax.set_title('Training Accuracy vs Testing Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend(loc='lower left')
fig.savefig('accuracy_history.png')