<a href="https://colab.research.google.com/github/mmsamiei/MS-Thesis-Phase2/blob/master/Models/Hashemi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!cp /content/drive/My\ Drive/Thesis/phase-2/history_sentence_pairs_train.csv ./train.csv
!cp /content/drive/My\ Drive/Thesis/phase-2/history_sentence_pairs_valid.csv ./valid.csv

In [0]:
!pip -q install transformers

In [0]:
from torch.utils.data import Dataset, DataLoader
import os
import torch
import json
from torch.utils.data.sampler import SubsetRandomSampler

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('srush/bert_uncased_L-2_H-128_A-2')

In [0]:
import random
import pandas as pd
import logging

class MyDataset(Dataset):
    """My dataset."""

    def __init__(self, csv_file, frac=1):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
        """
        self.dialogues = pd.read_csv(csv_file).sample(frac=frac)
        

        # s = (self.dialogues.true_sentence.str.len() + self.dialogues.history.str.len()).sort_values().index
        # self.dialogues = self.dialogues.reindex(s)
        # s = (self.dialogues.false_sentence.str.len() + self.dialogues.history.str.len()).sort_values().index
        # self.dialogues = self.dialogues.reindex(s)


        self.dialogues.dropna(inplace=True)


    def __len__(self):
        return len(self.dialogues)

    @staticmethod
    def truncuate_join_pair_sentence(sentence1, sentence2, max_len=510):

        """
        truncuate sentence one from head and sentence two from tail
        Args:
            sentence1 (string): first sentence
            sentence2 (string): seconde sentence
        """
        temp1 = tokenizer.encode(sentence1,add_special_tokens=False)
        temp2 = tokenizer.encode(sentence2,add_special_tokens=False)
        ### two above line may cause warning but no problem because we've handle them below
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
        seq_1 = temp1
        seq_2 = temp2
        num_tokens_to_remove = len(temp1) + len(temp2) + 3 - max_len
        if num_tokens_to_remove > 0 :
            seq_1, seq_2, _ = tokenizer.truncate_sequences(temp1[::-1],temp2, num_tokens_to_remove=num_tokens_to_remove)
            seq_1.reverse()
        result_list = [tokenizer.cls_token_id]+seq_1+[tokenizer.sep_token_id]+seq_2+[tokenizer.sep_token_id]
        return result_list


    def __getitem__(self, idx):
      
        
        history = self.dialogues.iloc[idx].history
        true_sentence = self.dialogues.iloc[idx].true_sentence
        false_sentence = self.dialogues.iloc[idx].false_sentence


        true_pair = MyDataset.truncuate_join_pair_sentence(history, true_sentence)
        false_pair = MyDataset.truncuate_join_pair_sentence(history, false_sentence)
        
        

        true_pair = torch.LongTensor(true_pair)
        false_pair = torch.LongTensor(false_pair)

        sample = {'true_pair': true_pair, 'false_pair': false_pair}

        return sample

In [0]:
train_dataset = MyDataset('train.csv', frac=1)
valid_dataset = MyDataset('valid.csv', frac=1)
print(len(train_dataset))
print(len(valid_dataset))

2775678
147428


In [0]:
from tqdm.auto import tqdm

def my_collate_fn(batch):

  len_batch = len(batch)

  
  max_len_true_pair = max([len(data['true_pair']) for data in batch])
  max_len_false_pair = max([len(data['false_pair']) for data in batch])
  
  padding_ind = 0 ## for bert is 0
  result_true_pair = torch.zeros(len_batch, max_len_true_pair)
  result_false_pair = torch.zeros(len_batch, max_len_false_pair)

  for i, data in enumerate(batch):
    p1 = len(data['true_pair'])
    result_true_pair[i, :p1] = data['true_pair']
    p2 = len(data['false_pair'])
    result_false_pair[i, :p2] = data['false_pair']


  return result_true_pair.long(), result_false_pair.long()

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128,
                                             shuffle=True, collate_fn=my_collate_fn,
                                           num_workers=32)

valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, sampler=valid_sampler,
                                             shuffle=False, collate_fn=my_collate_fn, num_workers=2)

for batch_idx, batch  in tqdm(enumerate(train_loader)):
  true_batch, false_batch = batch
  print(false_batch.shape)
  break

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

torch.Size([128, 310])


In [0]:
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import AutoModel

class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()

    self.bert = AutoModel.from_pretrained("srush/bert_uncased_L-2_H-128_A-2")
    self.fc = nn.Linear(128,1)
    self.activation = nn.Tanh()

    for p in self.bert.embeddings.parameters():
      p.requires_grad = False
    
    for p in self.bert.encoder.layer[0].parameters():
      p.requires_grad = False
    
    nn.init.xavier_normal_(self.fc.weight)
  
  def forward(self, x):
        temp = x
        temp = self.bert(temp)[0]
        ## temp = [batch, len, hid_size]
        temp = temp[:,0,:]
        temp = self.fc(temp)
        temp = self.activation(temp)
        temp = (temp - (-1))/2 
        return temp

dev = torch.device('cuda')
model = Model().to(dev)


# x = torch.LongTensor(200, 40).random_(1,1000).to(dev)
# print(model(x).shape)


def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_parameters(model))

214913


In [0]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
    
    def zero_grad(self):
        self.optimizer.zero_grad()

In [0]:
optimizer = NoamOpt(128, 1, 2000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

#optimizer = torch.optim.Adam(model.parameters(), lr=0.000001, betas=(0.9, 0.98), eps=1e-9)

In [0]:
import torch.nn.functional as F

def mahdi_loss(true_sml, false_sml):
  eps = 1e-6
  loss = torch.mean( -torch.log(true_sml+eps)-torch.log(1-false_sml+eps))
  return loss

In [0]:
from tqdm.auto import tqdm

def train_step(batch_idx, batch):
  true_pairs, false_pairs = batch
  true_pairs = true_pairs.to(dev)
  false_pairs = false_pairs.to(dev)
  optimizer.zero_grad()
  true_sml = model(true_pairs)
  false_sml = model(false_pairs)
  loss = mahdi_loss(true_sml, false_sml)
  loss.backward()
  optimizer.step()
  return loss.item()

def valid_step(batch_idx, batch):
  true_pairs, false_pairs = batch
  true_pairs = true_pairs.to(dev)
  false_pairs = false_pairs.to(dev)
  true_sml = model(true_pairs)
  false_sml = model(false_pairs)
  z = true_sml - false_sml
  num_err = z[z<0].size()[0]
  return num_err

def valid_loop(valid_loader):
  total_error = 0
  model.eval()
  for batch_idx, batch in tqdm(enumerate(valid_loader),  total=len(valid_loader)):
    total_error += valid_step(batch_idx, batch)
  return total_error / len(valid_dataset)

In [0]:
from tqdm.auto import tqdm

MAX_STEP = 20000
STEP_CHECK = 200
step_num = 1
log_list = []

while step_num <= MAX_STEP:
  model.train()
  for batch_idx, batch in tqdm(enumerate(iter(train_loader)), total=len(train_loader)):
    step_loss = train_step(batch_idx, batch)
    log = {'step':step_num, 'train_loss':step_loss}
    if(step_num % STEP_CHECK == 0):
      valid_error = valid_loop(valid_loader)
      print("Error rate: {} at step {}".format(valid_error, step_num))  
      log['valid_error'] = valid_error
      log_list.append(log)
      torch.save({
            'model_state_dict': model.state_dict(),
            'log_list': log_list
            }, 'hashemi_{}steps.model'.format(step_num))
      model.train()
      step_num += 1
      continue
    log_list.append(log)
    model.train()
    step_num += 1


HBox(children=(IntProgress(value=0, max=21685), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.3436117969449494 at step 200


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.2756328512901213 at step 400


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.23506389559649457 at step 600


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.22665979325501262 at step 800


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.21429443524974903 at step 1000


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.21289036004015519 at step 1200


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.2090308489567789 at step 1400


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.20116938437745882 at step 1600


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.20220039612556637 at step 1800


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.20981767371191362 at step 2000


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.20288547630029574 at step 2200


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.19832731909813603 at step 2400


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.19933798193016253 at step 2600


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.1964620017907046 at step 2800


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.18942805979868138 at step 3000


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.19629921046205606 at step 3200


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.19305016686111187 at step 3400


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.19324009007786852 at step 3600


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.18949588951895163 at step 3800


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.18664704126760182 at step 4000


HBox(children=(IntProgress(value=0, max=1152), HTML(value='')))


Error rate: 0.18700653878503404 at step 4200


KeyboardInterrupt: ignored

In [0]:
checkpoint = torch.load('bagheri_1750steps.model')
step = checkpoint['log_list'][-1]['step']
model.load_state_dict(checkpoint['model_state_dict'])
model = model.eval()

In [0]:
checkpoint['log_list'][-100:]

[{'step': 1651, 'train_loss': 1.0663644075393677},
 {'step': 1652, 'train_loss': 1.0181926488876343},
 {'step': 1653, 'train_loss': 1.0271081924438477},
 {'step': 1654, 'train_loss': 1.0073920488357544},
 {'step': 1655, 'train_loss': 0.8714293241500854},
 {'step': 1656, 'train_loss': 1.1811153888702393},
 {'step': 1657, 'train_loss': 1.07233726978302},
 {'step': 1658, 'train_loss': 1.043323040008545},
 {'step': 1659, 'train_loss': 1.0266834497451782},
 {'step': 1660, 'train_loss': 1.1659026145935059},
 {'step': 1661, 'train_loss': 1.0925660133361816},
 {'step': 1662, 'train_loss': 1.1036810874938965},
 {'step': 1663, 'train_loss': 1.099855661392212},
 {'step': 1664, 'train_loss': 1.175348162651062},
 {'step': 1665, 'train_loss': 1.1531702280044556},
 {'step': 1666, 'train_loss': 1.0803329944610596},
 {'step': 1667, 'train_loss': 1.1233198642730713},
 {'step': 1668, 'train_loss': 1.096791386604309},
 {'step': 1669, 'train_loss': 0.9994050860404968},
 {'step': 1670, 'train_loss': 0.99732

In [0]:
test_df = pd.read_csv('./valid.csv')
test_df.fillna(' ', inplace=True)
mrr_df = pd.DataFrame(data=test_df.history.unique(), columns=['history'])
mrr_df['correct_rank'] = 0

In [0]:
for i in tqdm(range(len(test_df))):
  history = test_df.loc[i]['history']
  true_sentence = test_df.loc[i]['true_sentence']
  false_sentence = test_df.loc[i]['false_sentence']
  true_pair = torch.LongTensor(MyDataset.truncuate_join_pair_sentence(history, true_sentence)).reshape(1, -1).to(dev)
  false_pair = torch.LongTensor(MyDataset.truncuate_join_pair_sentence(history, false_sentence)).reshape(1, -1).to(dev)
  ### [1, sent_len]
  true_sml = model(true_pair)
  false_sml = model(false_pair)
  if(true_sml.item() <= false_sml.item()):
    mrr_df.loc[mrr_df['history']==history, 'correct_rank'] += 1  

HBox(children=(IntProgress(value=0, max=147429), HTML(value='')))




In [0]:
len(mrr_df[ mrr_df['correct_rank'] < 1 ]) / len(mrr_df)

0.25371024734982334