<a href="https://colab.research.google.com/github/mmsamiei/just-practice-deep/blob/master/lightning_pytorch_pair_ranking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Copy the train/validation pair CSVs into the working directory as train.csv / valid.csv.
# NOTE(review): assumes Google Drive is already mounted at /content/drive — the mount step is not shown here.
!cp /content/drive/My\ Drive/Thesis/phase-2/history_sentence_pairs_train.csv ./train.csv
!cp /content/drive/My\ Drive/Thesis/phase-2/history_sentence_pairs_valid.csv ./valid.csv

In [2]:
# Install the third-party packages used below (-q keeps pip quiet, -U upgrades tqdm if it is already present).
# NOTE(review): no versions are pinned, so a later re-run may install releases whose APIs differ
# from the ones this notebook was written against.
!pip -q install transformers
!pip install -q  pytorch-lightning
!pip install -U tqdm

Requirement already up-to-date: tqdm in /usr/local/lib/python3.6/dist-packages (4.43.0)


In [3]:
from torch.utils.data import Dataset, DataLoader
import os
import torch
import json
from torch.utils.data.sampler import SubsetRandomSampler

from transformers import AutoTokenizer

# Module-level tokenizer; MyDataset reads this global name when it encodes and truncates sentences.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [0]:
import random
import pandas as pd
import logging

class MyDataset(Dataset):
    """Dataset of (history, true_sentence, false_sentence) ranking triples.

    Every item is a dict with two 1-D LongTensors of token ids, each laid out
    as ``[CLS] history [SEP] candidate [SEP]``: ``'true_pair'`` uses the true
    candidate and ``'false_pair'`` the false one. Tokenization relies on the
    module-level ``tokenizer``.
    """

    def __init__(self, json_file):
        """
        Args:
            json_file (string): Path to the CSV file holding the pairs. The
                parameter name is kept for backward compatibility; the file is
                read with ``pd.read_csv`` and must provide the ``history``,
                ``true_sentence`` and ``false_sentence`` columns.
        """
        # Over-long inputs make the tokenizer log a length warning on encode();
        # they are truncated explicitly later, so silence the logger once, up
        # front, instead of after the first warning has already been emitted.
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)

        # Drop incomplete rows first so the length computations only see text.
        dialogues = pd.read_csv(json_file).dropna()

        # Order rows by character length so neighbouring rows (and therefore
        # sequential batches) have similar sizes. Both passes use a stable
        # sort: the false-pair length is the primary key and the true-pair
        # length from the first pass survives as the tie-breaker.
        true_len = dialogues.true_sentence.str.len() + dialogues.history.str.len()
        dialogues = dialogues.reindex(true_len.sort_values(kind="mergesort").index)
        false_len = dialogues.false_sentence.str.len() + dialogues.history.str.len()
        dialogues = dialogues.reindex(false_len.sort_values(kind="mergesort").index)

        self.dialogues = dialogues

    def __len__(self):
        """Return the number of complete (non-NaN) rows."""
        return len(self.dialogues)

    def truncuate_join_pair_sentence(self, sentence1, sentence2, max_len=510):
        """Tokenize two sentences and join them as ``[CLS] s1 [SEP] s2 [SEP]``.

        When the joined sequence (including the three special tokens) would
        exceed ``max_len``, tokens are removed through
        ``tokenizer.truncate_sequences``; sentence one is handed over reversed
        so that it loses tokens from its head, while sentence two loses tokens
        from its tail.

        Args:
            sentence1 (string): first sentence (truncated from the head)
            sentence2 (string): second sentence (truncated from the tail)
            max_len (int): maximum length of the returned id list

        Returns:
            list of int: token ids including the special tokens.
        """
        seq_1 = tokenizer.encode(sentence1, add_special_tokens=False)
        seq_2 = tokenizer.encode(sentence2, add_special_tokens=False)

        # +3 accounts for [CLS] and the two [SEP] tokens added below.
        num_tokens_to_remove = len(seq_1) + len(seq_2) + 3 - max_len
        if num_tokens_to_remove > 0:
            # truncate_sequences trims from the tail; reversing sentence one
            # turns that into trimming from its head.
            seq_1, seq_2, _ = tokenizer.truncate_sequences(
                seq_1[::-1], seq_2, num_tokens_to_remove=num_tokens_to_remove)
            seq_1 = list(reversed(seq_1))

        return ([tokenizer.cls_token_id] + list(seq_1) + [tokenizer.sep_token_id]
                + list(seq_2) + [tokenizer.sep_token_id])

    def __getitem__(self, idx):
        """Return the encoded true and false pairs for the row at position ``idx``.

        Returns:
            dict: ``{'true_pair': LongTensor, 'false_pair': LongTensor}``.
        """
        # Fetch the row once instead of once per column.
        row = self.dialogues.iloc[idx]
        history = row["history"]

        true_pair = self.truncuate_join_pair_sentence(history, row["true_sentence"])
        false_pair = self.truncuate_join_pair_sentence(history, row["false_sentence"])

        return {'true_pair': torch.LongTensor(true_pair),
                'false_pair': torch.LongTensor(false_pair)}

In [0]:
# Build the training and validation datasets (in that order) from the CSVs in the working directory.
dataset, valid_dataset = (MyDataset(csv_path) for csv_path in ('train.csv', 'valid.csv'))

In [0]:
def my_collate_fn(batch):
  """Pad a list of dataset samples into two rectangular LongTensors.

  Args:
    batch: sequence of dicts with a 'true_pair' and a 'false_pair' entry,
      each a 1-D sequence (tensor or list) of token ids.

  Returns:
    A ``(true_pairs, false_pairs)`` tuple of int64 tensors. Each has one row
    per sample and as many columns as the longest sequence of that kind in
    the batch; shorter rows are right-padded with ``padding_ind``. An empty
    batch yields two tensors of shape (0, 0).
  """
  padding_ind = 0 ## for bert is 0

  def pad_field(key):
    rows = [torch.as_tensor(sample[key], dtype=torch.long) for sample in batch]
    width = max((row.numel() for row in rows), default=0)
    # Allocate the buffer as int64 from the start: staging ids in a float32
    # buffer would silently round any id above 2**24.
    padded = torch.full((len(rows), width), padding_ind, dtype=torch.long)
    for i, row in enumerate(rows):
      padded[i, :row.numel()] = row
    return padded

  return pad_field('true_pair'), pad_field('false_pair')

def _make_sequential_loader(source, batch_size):
  """Return (sampler, loader) iterating `source` in index order, batched with my_collate_fn."""
  order = torch.utils.data.SequentialSampler(source)
  loader = DataLoader(
      source,
      batch_size=batch_size,
      sampler=order,
      shuffle=False,
      collate_fn=my_collate_fn,
  )
  return order, loader


# Training batches hold 128 pairs, validation batches 32; both are read in dataset order.
sampler, dataset_loader = _make_sequential_loader(dataset, 128)
valid_sampler, valid_dataset_loader = _make_sequential_loader(valid_dataset, 32)



In [0]:
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from transformers import AutoModel

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):
    """Pairwise ranker: DistilBERT encoder plus a linear scoring head.

    A (history, candidate) token sequence is mapped to one scalar score.
    Training uses a margin ranking loss that pushes the score of the true
    pair above the score of the false pair; validation counts the pairs
    where the true pair scores strictly lower than the false pair.
    """

    # Batches are zero-padded by the collate function, so id 0 marks padding.
    PAD_TOKEN_ID = 0

    def __init__(self, lr=0.02, margin=1.0):
        """
        Args:
            lr (float): learning rate handed to Adam.
            margin (float): margin of the ranking loss.
        """
        super(CoolSystem, self).__init__()
        self.learning_rate = lr
        # not the best model...
        self.bert = AutoModel.from_pretrained("distilbert-base-uncased")
        self.fc = nn.Linear(768,1)
        # Built once here instead of on every training step.
        self.criterion = nn.MarginRankingLoss(margin=margin)

        # Freeze every transformer layer except the last one, and the embeddings.
        for p in self.bert.transformer.layer[:-1].parameters():
          p.requires_grad = False

        for p in self.bert.embeddings.parameters():
          p.requires_grad = False

        nn.init.normal_(self.fc.weight)

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: integer tensor of token ids, one zero-padded sequence per row.

        Returns:
            Tensor with one score per row, shape (batch, 1).
        """
        # Mask out padding so a sequence's score does not depend on how much
        # padding the rest of the batch forced onto it.
        attention_mask = (x != self.PAD_TOKEN_ID).long()
        hidden = self.bert(x, attention_mask=attention_mask)[0]
        # The hidden state at position 0 is used as the sequence representation.
        return self.fc(hidden[:,0,:])

    def training_step(self, batch, batch_idx):
        # REQUIRED
        """Compute the margin ranking loss for one batch of (true, false) pairs."""
        true_pair, false_pair = batch
        # Flatten (batch, 1) scores to (batch,) so inputs and target share one
        # shape; a (batch,) target against (batch, 1) scores would broadcast
        # to (batch, batch) or be rejected outright by newer PyTorch.
        true_sml = self(true_pair).squeeze(-1)
        false_sml = self(false_pair).squeeze(-1)

        # Target +1 means "first input should rank higher"; ones_like keeps it
        # on the same device and dtype as the scores.
        y_batch_tensor = torch.ones_like(true_sml)
        loss = self.criterion(true_sml, false_sml, y_batch_tensor)

        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        """Count the pairs in this batch where the true pair scores below the false pair."""
        true_pair, false_pair = batch
        true_sml = self(true_pair)
        false_sml = self(false_pair)

        total_num = false_sml.size(0)
        num_err = int((true_sml - false_sml < 0).sum().item())

        return {'num_err': num_err, 'total_num':total_num}

    def validation_end(self, outputs):
        # OPTIONAL
        """Sum the per-batch error counts and report the total."""
        total_err = sum(x['num_err'] for x in outputs)
        total_num = sum(x['total_num'] for x in outputs)
        tensorboard_logs = {'total_err': total_err}
        print("Epoch complete, total error is: {} from {}".format(total_err, total_num))
        return {'total_err': total_err, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        # Returns the module-level training loader.
        return dataset_loader

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # Returns the module-level validation loader.
        return valid_dataset_loader
    


In [0]:
from pytorch_lightning import Trainer
from tqdm.auto import tqdm

model = CoolSystem()

# most basic trainer, uses good defaults
# Runs on GPU index 0. accumulate_grad_batches=16 presumably accumulates gradients over 16 batches
# per optimizer step; with the training loader's batch_size of 128 that would be 2048 pairs per
# update — TODO confirm against the installed pytorch-lightning version.
trainer = Trainer(gpus=[0], accumulate_grad_batches=16)    

# fit() is given only the model; the model's train_dataloader/val_dataloader methods supply the data.
trainer.fit(model)   

INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json from cache at /root/.cache/torch/transformers/a41e817d5c0743e29e86ff85edc8c257e61bc8d88e4271bb1b243b6e7614c633.587f67ec28c540f4294c9c2ac7dcf7841ff371aeb12cdeb6a17f69da39ad9452
INFO:transformers.configuration_utils:Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": null,
  "dim": 768,
  "do_sample": false,
  "dropout": 0.1,
  "eos_token_ids": null,
  "finetuning_task": null,
  "hidden_dim": 3072,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "num_beams": 1,
  "num_labels"

Epoch complete, total error is: 37 from 160
Epoch 1:   0%|          | 0/26293 [00:00<?, ?batch/s]



Epoch 1:   1%|          | 295/26293 [03:17<5:06:34,  1.41batch/s, batch_idx=294, gpu=0, loss=1.086, v_num=9]