<a href="https://colab.research.google.com/github/mmsamiei/MS-Thesis-Phase2/blob/master/Tiny-Bert-L2-H128-Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!cp /content/drive/My\ Drive/Thesis/phase-2/history_sentence_pairs_train.csv ./train.csv
!cp /content/drive/My\ Drive/Thesis/phase-2/history_sentence_pairs_valid.csv ./valid.csv

In [2]:
!pip -q install transformers

[K     |████████████████████████████████| 542kB 41.7MB/s 
[K     |████████████████████████████████| 1.0MB 54.4MB/s 
[K     |████████████████████████████████| 3.7MB 50.6MB/s 
[K     |████████████████████████████████| 870kB 44.0MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [1]:
from torch.utils.data import Dataset, DataLoader
import os
import torch
import json
from torch.utils.data.sampler import SubsetRandomSampler

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('srush/bert_uncased_L-2_H-128_A-2')

In [0]:
import random
import pandas as pd
import logging

class MyDataset(Dataset):
    """My dataset."""

    def __init__(self, csv_file, frac=1):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
        """
        self.dialogues = pd.read_csv(csv_file).sample(frac=frac)
        

        s = (self.dialogues.true_sentence.str.len() + self.dialogues.history.str.len()).sort_values().index
        self.dialogues = self.dialogues.reindex(s)
        s = (self.dialogues.false_sentence.str.len() + self.dialogues.history.str.len()).sort_values().index
        self.dialogues = self.dialogues.reindex(s)


        self.dialogues.dropna(inplace=True)


    def __len__(self):
        return len(self.dialogues)

    def truncuate_join_pair_sentence(self, sentence1, sentence2, max_len=510):

        """
        truncuate sentence one from head and sentence two from tail
        Args:
            sentence1 (string): first sentence
            sentence2 (string): seconde sentence
        """
        temp1 = tokenizer.encode(sentence1,add_special_tokens=False)
        temp2 = tokenizer.encode(sentence2,add_special_tokens=False)
        ### two above line may cause warning but no problem because we've handle them below
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
        seq_1 = temp1
        seq_2 = temp2
        num_tokens_to_remove = len(temp1) + len(temp2) + 3 - max_len
        if num_tokens_to_remove > 0 :
            seq_1, seq_2, _ = tokenizer.truncate_sequences(temp1[::-1],temp2, num_tokens_to_remove=num_tokens_to_remove)
            seq_1.reverse()
        result_list = [tokenizer.cls_token_id]+seq_1+[tokenizer.sep_token_id]+seq_2+[tokenizer.sep_token_id]
        return result_list


    def __getitem__(self, idx):
      
        
        history = self.dialogues.iloc[idx].history
        true_sentence = self.dialogues.iloc[idx].true_sentence
        false_sentence = self.dialogues.iloc[idx].false_sentence


        true_pair = self.truncuate_join_pair_sentence(history, true_sentence)
        false_pair = self.truncuate_join_pair_sentence(history, false_sentence)
        
        

        true_pair = torch.LongTensor(true_pair)
        false_pair = torch.LongTensor(false_pair)

        sample = {'true_pair': true_pair, 'false_pair': false_pair}

        return sample

In [24]:
train_dataset = MyDataset('train.csv', frac=0.02)
valid_dataset = MyDataset('valid.csv', frac=0.1)
print(len(train_dataset))
print(len(valid_dataset))

55514
14743


In [25]:
def my_collate_fn(batch):

  len_batch = len(batch)

  
  max_len_true_pair = max([len(data['true_pair']) for data in batch])
  max_len_false_pair = max([len(data['false_pair']) for data in batch])
  
  padding_ind = 0 ## for bert is 0
  result_true_pair = torch.zeros(len_batch, max_len_true_pair)
  result_false_pair = torch.zeros(len_batch, max_len_false_pair)

  for i, data in enumerate(batch):
    p1 = len(data['true_pair'])
    result_true_pair[i, :p1] = data['true_pair']
    p2 = len(data['false_pair'])
    result_false_pair[i, :p2] = data['false_pair']


  return result_true_pair.long(), result_false_pair.long()

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128,
                                             shuffle=True, collate_fn=my_collate_fn,
                                           num_workers=32)

valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, sampler=valid_sampler,
                                             shuffle=False, collate_fn=my_collate_fn)

for batch_idx, batch  in tqdm(enumerate(train_loader)):
  true_batch, false_batch = batch
  print(false_batch.shape)
  break

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

torch.Size([128, 288])


In [26]:
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import AutoModel

class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()

    self.bert = AutoModel.from_pretrained("srush/bert_uncased_L-2_H-128_A-2")
    self.fc = nn.Linear(128,1)
    self.activation = nn.Tanh()

    for p in self.bert.embeddings.parameters():
      p.requires_grad = False
    
    nn.init.xavier_normal_(self.fc.weight)
  
  def forward(self, x):
        temp = x
        temp = self.bert(temp)[0]
        ## temp = [batch, len, hid_size]
        temp = temp[:,0,:]
        temp = self.fc(temp)
        temp = self.activation(temp)
        temp = (temp - (-1))/2 
        return temp

dev = torch.device('cuda')

x = torch.LongTensor(200, 40).random_(1,1000).to(dev)
model = Model().to(dev)
print(model(x).shape)


def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_parameters(model))

torch.Size([200, 1])
413185


In [0]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
    
    def zero_grad(self):
        self.optimizer.zero_grad()

In [0]:
optimizer = NoamOpt(128, 1, 2000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

#optimizer = torch.optim.Adam(model.parameters(), lr=0.000001, betas=(0.9, 0.98), eps=1e-9)

In [0]:
import torch.nn.functional as F

def mahdi_loss(true_sml, false_sml):
  loss = torch.mean( -torch.log(true_sml)-torch.log(1-false_sml))
  return loss

In [0]:
def train_step(batch_idx, batch):
  true_pairs, false_pairs = batch
  true_pairs = true_pairs.to(dev)
  false_pairs = false_pairs.to(dev)
  optimizer.zero_grad()
  true_sml = model(true_pairs)
  false_sml = model(false_pairs)
  loss = mahdi_loss(true_sml, false_sml)
  loss.backward()
  optimizer.step()
  return loss.item()

def valid_step(batch_idx, batch):

  true_pairs, false_pairs = batch
  true_pairs = true_pairs.to(dev)
  false_pairs = false_pairs.to(dev)
  true_sml = model(true_pairs)
  false_sml = model(false_pairs)
  z = true_sml - false_sml
  num_err = z[z<0].size()[0]
  return num_err

In [0]:
from tqdm.auto import tqdm

MAX_EPOCH = 100
for epoch_num in range(MAX_EPOCH):
  model.train()
  total_loss_epoch = 0
  for batch_idx, batch in tqdm(enumerate(iter(train_loader))):
    step_loss = train_step(batch_idx, batch)
    total_loss_epoch += step_loss
  print(total_loss_epoch / len(train_loader))
  ### Validation ###
  total_error = 0
  for batch_idx, batch in enumerate(valid_loader):
    total_error += valid_step(batch_idx, batch)
  print("Error rate: {}".format(total_error / len(valid_dataset)))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


1.4102256572740968
Error rate: 0.4130095638608153


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))