In [1]:
# Install the third-party packages imported by the next cell.
# NOTE(review): versions are not pinned; the captured log below shows
# transformers 4.20.1 being resolved at the time of this run.
!pip install transformers
!pip install torchmetrics
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 9.3 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 64.5 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 57.3 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 11.7 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstal

In [2]:
from google.colab import drive
from torch import nn
from transformers import BertTokenizer, BertModel
import torch
import csv
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW
from torchmetrics import Accuracy
import os
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence
import tqdm
import time
import math
import numpy as np
from transformers import DataCollatorForLanguageModeling
import pytorch_lightning as pl

In [3]:
# Mount Google Drive so the CSV files under MyDrive are readable from Colab.
# `drive` is provided by the notebook's import cell, so it is not re-imported here.
drive.mount('/content/drive')

# Root folder that holds the train/ and test/ splits.
data_path = '/content/drive/MyDrive/data'

# Training split: one CSV per label (true.csv / false.csv).
train_data_path = f'{data_path}/train'
true_train_path = f'{train_data_path}/true.csv'
false_train_path = f'{train_data_path}/false.csv'

# Test split, laid out the same way as the training split.
test_data_path = f'{data_path}/test'
true_test_path = f'{test_data_path}/true.csv'
false_test_path = f'{test_data_path}/false.csv'

Mounted at /content/drive


In [4]:


class TextClassificationModel(nn.Module):
    """BERT encoder followed by a linear head and a sigmoid.

    The hidden state of the first token of each sequence is fed to the
    linear layer, and the sigmoid maps the result into (0, 1).

    Args:
        embed_dim: Size of the encoder's hidden state fed to the linear head.
        num_class: Number of output units of the linear head.
    """

    def __init__(self, embed_dim=768, num_class=1):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.fc = nn.Linear(embed_dim, num_class)
        self.sigmoid = nn.Sigmoid()
        # Use the GPU when one is present (the previous behaviour), but fall
        # back to CPU instead of crashing on machines without CUDA.
        self.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def forward(self, input_ids, attention_masks):
        """Return sigmoid scores for a batch of tokenized sequences.

        Args:
            input_ids: Integer tensor of token ids, indexed as (batch, position).
            attention_masks: Tensor with the same layout as ``input_ids``,
                passed to the encoder as ``attention_mask``.

        Returns:
            Tensor of shape (batch, num_class) with values in (0, 1).
        """
        # Follow the module's own device so inputs created on another device
        # (e.g. on CPU by a DataLoader) are accepted.
        device = self.fc.weight.device
        outputs = self.bert(
            input_ids=input_ids.to(device),
            attention_mask=attention_masks.to(device),
        )
        # Keep only the hidden state at position 0 of every sequence.
        first_token_state = outputs.last_hidden_state[:, 0, :]
        return self.sigmoid(self.fc(first_token_state))

# class BertLM(pl.LightningModule):

#     def __init__(self, class_name):
#         super().__init__()
#         self.bert = BertForMaskedLM.from_pretrained('bert-base-uncased')
#         self.epoch_number = 0
#         self.class_name = class_name

#     def forward(self, input_ids, labels):
#         return self.bert(input_ids=input_ids,labels=labels)

#     def training_step(self, batch, batch_idx):
#         input_ids = batch["input_ids"]
#         labels = batch["labels"]
#         outputs = self(input_ids=input_ids, labels=labels)
#         loss = outputs[0]
#         return {"loss": loss}

#     def training_epoch_end(self, outputs):
#         super().training_epoch_end(outputs)
#         mean_loss = 0
#         n_batch  = len(outputs)
#         for i in range(n_batch):
#             mean_loss += outputs[i]['loss'].cpu().numpy() / n_batch
#         print(f"End of epoch {self.epoch_number} with mean loss '{mean_loss}' on label {self.class_name}.", "fine_tuning")
#         self.epoch_number += 1
#     def configure_optimizers(self):
#         return AdamW(self.parameters(), lr=1e-5)

# class BertLMPred(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.bert = BertForMaskedLM.from_pretrained('bert-base-uncased')

#     def forward(self, input_ids, labels=None):
#         return self.bert(input_ids=input_ids,labels=labels)


In [5]:
class SpellCheckingDataset(Dataset):
    """Labelled text dataset read from one or more CSV files.

    The first column of every non-empty CSV row is taken as the text; all
    rows read from ``data_paths[i]`` receive the label ``labels[i]``.

    Args:
        tokenizer: Callable that maps a string to a mapping containing
            ``'input_ids'`` and ``'attention_mask'``.
        data_paths: Sequence of CSV file paths.
        labels: Sequence of labels, aligned by index with ``data_paths``.
        batch_size: Stored on the instance; not used by this class itself.
    """

    def __init__(self, tokenizer, data_paths, labels, batch_size=32):
        self.dataset = []
        for i, data_path in enumerate(data_paths):
            # newline='' is what the csv module expects so that quoted fields
            # containing line breaks are parsed correctly.
            with open(data_path, 'r', encoding='utf-8', newline='') as file:
                for row in csv.reader(file):
                    # csv.reader yields [] for blank lines; skip them instead
                    # of failing on row[0].
                    if not row:
                        continue
                    self.dataset.append((row[0], labels[i]))
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for sample ``idx``.

        Raises:
            Exception: If no tokenizer has been set on the dataset.
        """
        if self.tokenizer is None:
            raise Exception('Tokenizer cannot be null')

        tweet, label = self.dataset[idx]
        tokenized_tweet = self.tokenizer(tweet)
        return tokenized_tweet['input_ids'], tokenized_tweet['attention_mask'], label

    def __len__(self):
        """Number of text samples loaded from all CSV files."""
        return len(self.dataset)

class LMSpellCheckingDataset(Dataset):
    """Text dataset read from CSV files that yields token-id tensors only.

    Rows are loaded as in ``__init__`` below: the first column of every
    non-empty row is the text, paired with ``labels[i]`` for file
    ``data_paths[i]``. The label is stored but not returned by
    ``__getitem__``.

    Args:
        tokenizer: Callable that maps a string to a mapping containing
            ``'input_ids'``.
        data_paths: Sequence of CSV file paths.
        labels: Sequence of labels, aligned by index with ``data_paths``.
        batch_size: Stored on the instance; not used by this class itself.
    """

    def __init__(self, tokenizer, data_paths, labels, batch_size=32):
        self.dataset = []
        for i, data_path in enumerate(data_paths):
            # newline='' is what the csv module expects so that quoted fields
            # containing line breaks are parsed correctly.
            with open(data_path, 'r', encoding='utf-8', newline='') as file:
                for row in csv.reader(file):
                    # csv.reader yields [] for blank lines; skip them instead
                    # of failing on row[0].
                    if not row:
                        continue
                    self.dataset.append((row[0], labels[i]))
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def __getitem__(self, idx):
        """Return the token ids of sample ``idx`` as a 1-D tensor.

        Raises:
            Exception: If no tokenizer has been set on the dataset.
        """
        if self.tokenizer is None:
            raise Exception('Tokenizer cannot be null')

        tweet = self.dataset[idx][0]
        return torch.tensor(self.tokenizer(tweet)['input_ids'])

    def __len__(self):
        """Number of text samples loaded from all CSV files."""
        return len(self.dataset)

In [6]:
# CHECKPOINTS_DIR = '/content/models'

def pad_batched_sequence(batch, device=None):
    """Collate ``(input_ids, attention_mask, label)`` samples into padded tensors.

    Args:
        batch: Non-empty list of samples; each sample is indexable with the
            token ids at position 0, the attention mask at position 1 and the
            label (or ``None``) at position 2.
        device: Device the resulting tensors are moved to. When omitted, CUDA
            is used if available and CPU otherwise, so the function also works
            on machines without a GPU.

    Returns:
        Tuple ``(input_ids, attention_masks, labels)``. The first two are
        right-padded with 0 to the longest sequence in the batch. ``labels``
        is a float64 tensor of shape (batch, 1), or ``None`` when the first
        sample carries no label.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_ids = pad_sequence(
        [torch.tensor(item[0]) for item in batch],
        padding_value=0,
        batch_first=True,
    )
    attention_masks = pad_sequence(
        [torch.tensor(item[1]) for item in batch],
        padding_value=0,
        batch_first=True,
    )

    labels = None
    # Only the first sample is inspected to decide whether the batch is labelled.
    if batch[0][2] is not None:
        labels = torch.tensor([[item[2]] for item in batch]).double().to(device)

    return input_ids.to(device), attention_masks.to(device), labels

class SpellCheckingTrainer():
  """Trains a model with binary cross-entropy and saves its weights at the end.

  Args:
      model: Module called as ``model(input_ids, attention_masks)`` whose
          output is compared against the labels with ``nn.BCELoss``.
      train_dataset: Dataset whose samples are collated by
          ``pad_batched_sequence``.
      save_data_path: File path the model's ``state_dict`` is written to
          after the last epoch.
      batch_size: Batch size of the training loader (incomplete final
          batches are dropped).
      epochs: Number of passes over the training data.
      lr: Learning rate passed to ``torch.optim.Adam``.
          NOTE(review): 1e-3 is much higher than the learning rates usually
          recommended for fine-tuning BERT — confirm this default is intended.
  """

  def __init__(self,
               model,
               train_dataset,
               save_data_path,
               batch_size=32,
               epochs=20,
               lr=0.001):

    self.model = model
    self.epochs = epochs
    self.batch_size = batch_size
    self.train_loader = DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=pad_batched_sequence)
    self.save_path = save_data_path
    self.loss_function = nn.BCELoss()
    self.optimizer = torch.optim.Adam(list(self.model.parameters()), lr=lr)
    # Kept as an attribute for compatibility; the training loop below computes
    # accuracy by hand and does not use this metric object.
    self.accuracy = Accuracy(num_classes=1)

  def train_one_epoch(self, epoch_index):
    """Run one pass over the training loader, logging every 10 batches.

    Args:
        epoch_index: Index of the current epoch (accepted for the caller's
            convenience; not used inside the loop).

    Returns:
        Mean loss of the most recently logged 10-batch window. If the epoch
        has fewer than 10 batches, the mean loss over the batches that were
        processed; 0.0 if the loader is empty.
    """
    log_every = 10
    running_loss = 0.
    running_accuracy = 0.
    last_loss = 0.
    batches_seen = 0

    for i, data in tqdm.tqdm(enumerate(self.train_loader),
                             total=len(self.train_loader)):
        # Every data instance is an input + label pair
        input_ids, attention_masks, labels = data

        # Gradients accumulate by default, so reset them for every batch
        self.optimizer.zero_grad()

        outputs = self.model(input_ids, attention_masks)

        # BCELoss needs matching dtypes; labels arrive as float64
        loss = self.loss_function(outputs.double(), labels.double())
        loss.backward()
        self.optimizer.step()

        # Accumulate plain Python floats so no device tensors are kept alive
        running_loss += loss.item()
        predictions = (outputs.detach() > 0.5).to(labels.dtype)
        running_accuracy += (predictions == labels).float().mean().item()
        batches_seen += 1

        if i % log_every == log_every - 1:
            last_loss = running_loss / log_every # loss per batch
            last_accuracy = running_accuracy / log_every # accuracy per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            print('  batch {} accuracy: {}'.format(i + 1, last_accuracy))
            running_loss = 0.
            running_accuracy = 0.

    # No full window was ever completed: report what was seen instead of 0.
    if 0 < batches_seen < log_every:
        last_loss = running_loss / batches_seen

    return last_loss

  def train(self):
    """Train for ``self.epochs`` epochs, then save the model's ``state_dict``."""
    for epoch in range(self.epochs):
        print('EPOCH {}:'.format(epoch + 1))

        # Training mode for the pass over the data ...
        self.model.train(True)
        self.train_one_epoch(epoch)

        # ... and back to evaluation mode once the epoch is done
        self.model.train(False)

    # Create the target folder first so a missing directory cannot discard
    # the result of the whole training run.
    save_dir = os.path.dirname(self.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(self.model.state_dict(), self.save_path)
    
          
          

In [7]:


# Tokenizer shared by all datasets built in this cell.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): `model` is not referenced again in the cells visible here;
# each TextClassificationModel() below loads its own BERT encoder. Confirm
# whether this extra load is still needed.
model = BertModel.from_pretrained("bert-base-uncased")

# Numeric labels assigned to rows from true.csv / false.csv.
true_label = 1
false_label = 0

# NOTE(review): each training dataset contains rows of a single label only,
# so each classifier is trained on one class. The captured log of the next
# cell (loss near 0 and accuracy 1.0 after 20 batches) is consistent with a
# constant prediction — confirm whether one combined two-class training set
# was intended.
true_train_dataset = SpellCheckingDataset(tokenizer, [true_train_path], [true_label])
false_train_dataset = SpellCheckingDataset(tokenizer, [false_train_path], [false_label])

# The test dataset mixes both CSV files, each with its own label.
test_dataset = SpellCheckingDataset(tokenizer, 
                                    [true_test_path, false_test_path],
                                    [true_label, false_label])

# Destination files for the trained weights (written by SpellCheckingTrainer.train).
true_saved_model_path = '/content/drive/MyDrive/models/true_bert.berm_lm'
false_saved_model_path = '/content/drive/MyDrive/models/false_bert.berm_lm'

# One trainer per dataset, each with a freshly initialised classifier.
true_trainer = SpellCheckingTrainer(TextClassificationModel(), 
                                    true_train_dataset,
                                    true_saved_model_path)

false_trainer = SpellCheckingTrainer(TextClassificationModel(), 
                                     false_train_dataset,
                                     false_saved_model_path)
 

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.trans

In [None]:
# Train the two classifiers one after the other; each train() call writes the
# model's state_dict to its save path once all epochs have finished.
true_trainer.train()
false_trainer.train()




EPOCH 1:


  2%|▏         | 10/403 [00:03<02:13,  2.95it/s]

  batch 10 loss: 0.0809702505817
  batch 10 accuracy: 0.90625


  5%|▍         | 20/403 [00:06<02:03,  3.09it/s]

  batch 20 loss: 9.74941587150786e-05
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:06,  2.95it/s]

  batch 30 loss: 7.665900075088309e-05
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:13<02:01,  2.98it/s]

  batch 40 loss: 6.321273556593354e-05
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:16<02:01,  2.91it/s]

  batch 50 loss: 5.304097863122118e-05
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:20<01:56,  2.96it/s]

  batch 60 loss: 4.5547186205881256e-05
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:23<01:49,  3.05it/s]

  batch 70 loss: 3.9787640177359925e-05
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:26<01:51,  2.90it/s]

  batch 80 loss: 3.525267347145879e-05
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:30<01:52,  2.79it/s]

  batch 90 loss: 3.170458740125791e-05
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:33<01:43,  2.92it/s]

  batch 100 loss: 2.8935122533390755e-05
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:37<01:43,  2.82it/s]

  batch 110 loss: 2.6636554094839917e-05
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:40<01:41,  2.80it/s]

  batch 120 loss: 2.491654307511618e-05
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:44<01:36,  2.83it/s]

  batch 130 loss: 2.305795099310669e-05
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:47<01:31,  2.87it/s]

  batch 140 loss: 2.1738423486908814e-05
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:51<01:29,  2.81it/s]

  batch 150 loss: 2.0418897750883825e-05
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:55<01:27,  2.77it/s]

  batch 160 loss: 1.9377658321589114e-05
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [00:58<01:20,  2.90it/s]

  batch 170 loss: 1.8309597432919868e-05
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:02<01:15,  2.94it/s]

  batch 180 loss: 1.7340259636269857e-05
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:05<01:19,  2.69it/s]

  batch 190 loss: 1.6409666386565475e-05
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:09<01:14,  2.72it/s]

  batch 200 loss: 1.5627715316709767e-05
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:12<01:08,  2.81it/s]

  batch 210 loss: 1.4994406221393324e-05
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:16<01:06,  2.76it/s]

  batch 220 loss: 1.4252690068359558e-05
  batch 220 accuracy: 1.0


 57%|█████▋    | 230/403 [01:20<01:06,  2.59it/s]

  batch 230 loss: 1.3673771893515073e-05
  batch 230 accuracy: 1.0


 60%|█████▉    | 240/403 [01:23<00:57,  2.82it/s]

  batch 240 loss: 1.3108637743068882e-05
  batch 240 accuracy: 1.0


 62%|██████▏   | 250/403 [01:27<00:52,  2.90it/s]

  batch 250 loss: 1.2597148740918153e-05
  batch 250 accuracy: 1.0


 65%|██████▍   | 260/403 [01:30<00:50,  2.84it/s]

  batch 260 loss: 1.2024564569842986e-05
  batch 260 accuracy: 1.0


 67%|██████▋   | 270/403 [01:34<00:49,  2.67it/s]

  batch 270 loss: 1.1501155152340908e-05
  batch 270 accuracy: 1.0


 69%|██████▉   | 280/403 [01:38<00:45,  2.68it/s]

  batch 280 loss: 1.1081682658301286e-05
  batch 280 accuracy: 1.0


 72%|███████▏  | 290/403 [01:41<00:37,  2.99it/s]

  batch 290 loss: 1.0680464593949693e-05
  batch 290 accuracy: 1.0


 74%|███████▍  | 300/403 [01:45<00:34,  2.96it/s]

  batch 300 loss: 1.025801221713314e-05
  batch 300 accuracy: 1.0


 77%|███████▋  | 310/403 [01:48<00:32,  2.90it/s]

  batch 310 loss: 9.864990090937238e-06
  batch 310 accuracy: 1.0


 79%|███████▉  | 320/403 [01:52<00:28,  2.87it/s]

  batch 320 loss: 9.444773294338556e-06
  batch 320 accuracy: 1.0


 82%|████████▏ | 330/403 [01:55<00:24,  3.00it/s]

  batch 330 loss: 9.10316099819139e-06
  batch 330 accuracy: 1.0


 84%|████████▍ | 340/403 [01:58<00:21,  2.98it/s]

  batch 340 loss: 8.711256937489243e-06
  batch 340 accuracy: 1.0


 87%|████████▋ | 350/403 [02:02<00:18,  2.90it/s]

  batch 350 loss: 8.334254327664292e-06
  batch 350 accuracy: 1.0


 89%|████████▉ | 360/403 [02:05<00:14,  2.94it/s]

  batch 360 loss: 8.08316768123915e-06
  batch 360 accuracy: 1.0


 92%|█████████▏| 370/403 [02:09<00:11,  2.75it/s]

  batch 370 loss: 7.80898418846819e-06
  batch 370 accuracy: 1.0


 94%|█████████▍| 380/403 [02:12<00:07,  2.99it/s]

  batch 380 loss: 7.509841063534641e-06
  batch 380 accuracy: 1.0


 97%|█████████▋| 390/403 [02:16<00:04,  2.79it/s]

  batch 390 loss: 7.2412456212096496e-06
  batch 390 accuracy: 1.0


 99%|█████████▉| 400/403 [02:19<00:01,  2.72it/s]

  batch 400 loss: 7.040451081027228e-06
  batch 400 accuracy: 1.0


100%|██████████| 403/403 [02:21<00:00,  2.86it/s]


EPOCH 2:


  2%|▏         | 10/403 [00:03<02:24,  2.73it/s]

  batch 10 loss: 6.685428441123563e-06
  batch 10 accuracy: 1.0


  5%|▍         | 20/403 [00:07<02:22,  2.69it/s]

  batch 20 loss: 6.439557668171569e-06
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:15,  2.76it/s]

  batch 30 loss: 6.238018198450572e-06
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:14<02:10,  2.79it/s]

  batch 40 loss: 5.998108032219447e-06
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:18<02:12,  2.66it/s]

  batch 50 loss: 5.739943863123175e-06
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:21<02:01,  2.81it/s]

  batch 60 loss: 5.600244717112281e-06
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:25<01:55,  2.89it/s]

  batch 70 loss: 5.362197343447454e-06
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:28<01:50,  2.93it/s]

  batch 80 loss: 5.183754988794325e-06
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:32<01:51,  2.81it/s]

  batch 90 loss: 5.026174404764259e-06
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:36<01:58,  2.55it/s]

  batch 100 loss: 4.839163919218084e-06
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:39<01:49,  2.67it/s]

  batch 110 loss: 4.668172264800533e-06
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:43<01:39,  2.83it/s]

  batch 120 loss: 4.541884359620236e-06
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:47<01:36,  2.84it/s]

  batch 130 loss: 4.364187216549926e-06
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:50<01:28,  2.96it/s]

  batch 140 loss: 4.177176813471847e-06
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:53<01:24,  2.98it/s]

  batch 150 loss: 4.037105333499139e-06
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:57<01:24,  2.88it/s]

  batch 160 loss: 3.976382892578533e-06
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [01:00<01:18,  2.97it/s]

  batch 170 loss: 3.7826670046668317e-06
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:04<01:15,  2.95it/s]

  batch 180 loss: 3.6500461870483907e-06
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:07<01:17,  2.75it/s]

  batch 190 loss: 3.541639860797967e-06
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:11<01:14,  2.74it/s]

  batch 200 loss: 3.449624872567994e-06
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:14<01:06,  2.92it/s]

  batch 210 loss: 3.3333954101563583e-06
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:18<01:03,  2.90it/s]

  batch 220 loss: 3.255909160827826e-06
  batch 220 accuracy: 1.0


 56%|█████▌    | 224/403 [01:19<01:03,  2.81it/s]