<a href="https://colab.research.google.com/github/oldaandozerskaya/Bulletin_of_opposition/blob/master/bart_scisum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import libraries

In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; the datasets and checkpoint folder used
# below live under /content/gdrive/MyDrive. force_remount=False presumably leaves
# an existing mount in place.
drive.mount('/content/gdrive', force_remount=False)

In [None]:
!pip install -q pytorch-lightning==1.2.9
!pip install -q transformers

In [None]:
import transformers
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset
import pandas as pd
import numpy as np

import torch.nn.functional as F
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

import math
import random
import re
import argparse

# Model

In [None]:
import numpy as np

class LitModel(pl.LightningModule):
  """Lightning wrapper that fine-tunes a seq2seq model (BART here) for summarization.

  Args:
    learning_rate: learning rate for the Adam optimizer.
    tokenizer: tokenizer matching `model`; its pad id is ignored by the loss.
    model: conditional-generation model. `freeze_embeds` expects the BART layout
      (`model.model.shared`, `model.model.encoder/decoder.embed_positions/embed_tokens`).
    hparams: namespace with boolean `freeze_encoder` and `freeze_embeds` flags.
  """
  def __init__(self, learning_rate, tokenizer, model, hparams):
    super().__init__()
    self.tokenizer = tokenizer
    self.model = model

    self.learning_rate = learning_rate
    self.hparams = hparams

    if self.hparams.freeze_encoder:
      freeze_params(self.model.get_encoder())

    if self.hparams.freeze_embeds:
      self.freeze_embeds()

  def freeze_embeds(self):
    """Freeze the shared embedding plus encoder/decoder positional and token embeddings."""
    freeze_params(self.model.model.shared)
    for d in [self.model.model.encoder, self.model.model.decoder]:
      freeze_params(d.embed_positions)
      freeze_params(d.embed_tokens)

  def forward(self, input_ids, **kwargs):
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

  def _teacher_forced_loss(self, batch):
    """Cross-entropy of the model's logits against the target ids of a batch.

    `batch` is indexed as (source ids, source attention mask, target ids), matching
    the TensorDatasets built by SummaryDataModule. Pad positions in the target are
    ignored.
    """
    src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
    # Decoder input = target rotated right by one; the last non-pad target token
    # (presumably </s> for BART tokenization) is moved to position 0.
    decoder_input_ids = shift_tokens_right(tgt_ids, self.tokenizer.pad_token_id)
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
    return ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

  def training_step(self, batch, batch_idx):
    loss = self._teacher_forced_loss(batch)
    return {'loss': loss}

  def validation_step(self, batch, batch_idx):
    val_loss = self._teacher_forced_loss(batch)
    # Log it; a bare returned dict is otherwise discarded without validation_epoch_end.
    self.log('val_loss', val_loss)
    return {'loss': val_loss}

  def generate_text(self, text, eval_beams, early_stopping = True, max_len = 64, decoder_start_token_id = None):
    """Beam-search generate from tokenized input and decode to strings.

    Args:
      text: tokenizer output with "input_ids" and "attention_mask" tensors.
      eval_beams: number of beams.
      early_stopping: forwarded to `generate`.
      max_len: maximum generated length.
      decoder_start_token_id: first decoder token. Defaults to the tokenizer's EOS
        id, which matches training. During training `shift_tokens_right` places the
        last non-pad target token at decoder position 0, and for BART-tokenized
        targets that token should be </s>. Starting from pad (the old behavior)
        feeds a token the decoder never saw there. Confirm if the tokenizer changes.

    Returns:
      List of decoded strings, one per generated sequence.
    """
    if decoder_start_token_id is None:
      decoder_start_token_id = self.tokenizer.eos_token_id
    generated_ids = self.model.generate(
        # Move inputs to wherever the model currently lives (CPU or GPU).
        text["input_ids"].to(self.device),
        attention_mask=text["attention_mask"].to(self.device),
        use_cache=True,
        decoder_start_token_id = decoder_start_token_id,
        num_beams= eval_beams,
        max_length = max_len,
        early_stopping = early_stopping
    )
    return [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]

def freeze_params(model):
  """Stop gradient updates for every parameter of `model` (modified in place).

  Args:
    model: any torch.nn.Module.
  """
  for param in model.parameters():
    # Must be `requires_grad`. The previous misspelling (`requires_grade`) only
    # created an unused attribute and left every parameter trainable.
    param.requires_grad = False

In [None]:
class SummaryDataModule(pl.LightningDataModule):
  """Lightning data module serving tokenized (abstract -> title) pairs.

  Train and dev data come from two tab-separated files whose first column is used
  as the index; both must have `abstract` and `title` columns.

  Args:
    tokenizer: passed through to `encode_sentences`.
    data_file: path of the training TSV.
    dev_file: path of the validation TSV.
    batch_size: batch size used by both loaders.
    num_examples: stored on the instance; no method here reads it.
  """
  def __init__(self, tokenizer, data_file, dev_file, batch_size, num_examples = 110000):
    super().__init__()
    self.tokenizer = tokenizer
    self.data_file, self.dev_file = data_file, dev_file
    self.batch_size = batch_size
    self.num_examples = num_examples

  def _read_shuffled(self, csv_path):
    """Load a TSV into self.data and return a copy shuffled with seed 42."""
    self.data = pd.read_csv(csv_path, sep = '\t', index_col = 0)
    return self.data.sample(frac=1, random_state = 42)

  def prepare_data(self):
    # NOTE(review): Lightning recommends assigning state in setup() rather than in
    # prepare_data(), which may run on only one process. This works for the
    # single-GPU Trainer used in this notebook.
    train_frame = self._read_shuffled(self.data_file)
    dev_frame = self._read_shuffled(self.dev_file)
    self.train = encode_sentences(self.tokenizer, train_frame['abstract'], train_frame['title'])
    self.validate = encode_sentences(self.tokenizer, dev_frame['abstract'], dev_frame['title'])

  @staticmethod
  def _as_tensor_dataset(encoded):
    """Wrap an encode_sentences() result as (input_ids, attention_mask, labels) rows."""
    columns = ('input_ids', 'attention_mask', 'labels')
    return TensorDataset(*(encoded[name] for name in columns))

  def train_dataloader(self):
    dataset = self._as_tensor_dataset(self.train)
    return DataLoader(dataset, sampler = RandomSampler(dataset), batch_size = self.batch_size)

  def val_dataloader(self):
    return DataLoader(self._as_tensor_dataset(self.validate), batch_size = self.batch_size)

  # Disabled test loader (prepare_data does not build a test split):
  # def test_dataloader(self):
  #   dataset = TensorDataset(self.test['input_ids'], self.test['attention_mask'], self.test['labels'])
  #   test_data = DataLoader(dataset, batch_size = self.batch_size)
  #   return test_data

In [None]:
# Settings read by LitModel: freeze_encoder freezes the encoder's parameters and
# freeze_embeds freezes the embedding tables. eval_beams is not read by the code
# shown here; the generation cell passes its beam count explicitly.
hparams = argparse.Namespace(
    freeze_encoder=True,
    freeze_embeds=True,
    eval_beams=4,
)

In [None]:
def shift_tokens_right(input_ids, pad_token_id):
  """Rotate each row of `input_ids` one step to the right.

  Every token moves one position later. Position 0 receives the row's last
  non-pad token, located by counting non-pad tokens, which assumes padding is
  on the right. The input tensor is not modified.

  Args:
    input_ids: 2-D tensor of token ids (rows are sequences).
    pad_token_id: id treated as padding.

  Returns:
    A new tensor with the same shape as `input_ids`.
  """
  last_real_pos = input_ids.ne(pad_token_id).sum(dim=1).sub(1).unsqueeze(-1)
  wrapped_tokens = input_ids.gather(1, last_real_pos).squeeze()
  shifted = input_ids.clone()
  shifted[:, 1:] = input_ids[:, :-1]
  shifted[:, 0] = wrapped_tokens
  return shifted

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=512, pad_to_max_length=True, return_tensors="pt"):
  """Tokenize aligned source/target texts into stacked tensors.

  Args:
    tokenizer: Hugging Face tokenizer, called with `add_prefix_space=True`.
    source_sentences: iterable of source strings (e.g. a pandas Series).
    target_sentences: iterable of target strings, aligned with the sources.
    max_length: truncation length, and also the padding length when
      `pad_to_max_length` is true.
    pad_to_max_length: pad every row to `max_length`. Otherwise pad each side to its
      longest row so the rows still form one tensor. The old code passed
      `padding=None` here, which is not a valid padding strategy.
    return_tensors: tensor type requested from the tokenizer.

  Returns:
    Dict with "input_ids" and "attention_mask" (sources) and "labels" (target ids,
    not shifted; shifting happens in the model's training step).
  """
  padding = "max_length" if pad_to_max_length else "longest"

  def _tokenize(sentences):
    # A single batched call gives the same rows as tokenizing one sentence at a
    # time and concatenating, without the per-row tensor copies.
    return tokenizer(
        list(sentences),
        max_length=max_length,
        padding=padding,
        truncation=True,
        return_tensors=return_tensors,
        add_prefix_space=True,
    )

  source = _tokenize(source_sentences)
  target = _tokenize(target_sentences)

  return {
      "input_ids": source["input_ids"],
      "attention_mask": source["attention_mask"],
      "labels": target["input_ids"],
  }

# Load BART


In [None]:
# Load the model
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartConfig

# Tokenizer and model are loaded from the same checkpoint name. Other checkpoints
# tried earlier: 'sshleifer/distilbart-cnn-12-6', 'VictorSanh/bart-base-finetuned-xsum'.
pretrained_name = 'facebook/bart-base'
tokenizer = BartTokenizer.from_pretrained(pretrained_name, add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained(pretrained_name)

In [None]:
# Load the data into the model for training
# Training and validation TSVs on the mounted Drive.
train_path = '/content/gdrive/MyDrive/DAAD/Datasets_DAAD/SciTLDR/arXiv_train_q1.csv'
dev_path = '/content/gdrive/MyDrive/DAAD/Datasets_DAAD/arXiv/arXiv_dev.csv'
# NOTE: SummaryDataModule stores num_examples but does not use it to subsample.
summary_data = SummaryDataModule(tokenizer, train_path, dev_path, batch_size = 4, num_examples = 20000)

# To resume instead of starting fresh:
# model = LitModel.load_from_checkpoint(path,
#                                       learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)
model = LitModel(
    learning_rate = 2e-5,
    tokenizer = tokenizer,
    model = bart_model,
    hparams = hparams,
)

In [None]:
#trainer.save_checkpoint("/content/gdrive/MyDrive/BART_results/1/model.pt")

# Training


In [None]:
# Checkpoints are written to this Drive folder.
checkpoint = ModelCheckpoint(dirpath='/content/gdrive/MyDrive/Untitled Folder/checkpoint_files_2')
# Register the checkpoint callback through `callbacks`. Passing a ModelCheckpoint
# instance as `checkpoint_callback=` is deprecated in Lightning 1.x and was
# later removed; that argument is meant to be a boolean.
trainer = pl.Trainer(gpus = 1,
                     max_epochs = 2,
                     min_epochs = 1,
                     auto_lr_find = False,
                     callbacks = [checkpoint],
                     progress_bar_refresh_rate = 500)

In [None]:
# Fine-tune; the data module supplies both train and validation loaders.
trainer.fit(model, datamodule=summary_data)

In [None]:
# Drop the notebook's reference to the data module and its encoded tensors
# (the trainer may still hold its own reference).
del summary_data

# Test

In [None]:
# Inference below runs on the CPU, with the model switched to eval mode.
model.cpu()
model.eval()

In [None]:
!pip install rouge-score
from rouge_score import rouge_scorer
# ROUGE-1, ROUGE-2 and ROUGE-L with stemming; the call below is only a smoke test.
rouge_types = ['rouge1', 'rouge2', 'rougeL']
scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)

reference_text = 'The quick brown fox jumps over the lazy dog'
candidate_text = 'The quick brown dog jumps on the log.'
scores = scorer.score(reference_text, candidate_text)

In [None]:
import pickle  # previously missing: pickle.dump below raised NameError

generated_tldr = []

# Fill in before running: path of the test TSV (first column used as the index,
# must have a `title` column) and path of the output pickle.
path = ''
path_save = ''

test = pd.read_csv(path, sep = '\t', index_col = 0)
originals = test.title.values  # reference titles, scored in the ROUGE cell

for i, line in enumerate(test.values):
  if i % 200 == 0:
    # Periodic progress output.
    print(i)
    print(line)
  # NOTE(review): positional column 1 is assumed to hold the source text (the
  # abstract); confirm against the test file's column order.
  prompt_line_tokens = tokenizer(line[1], max_length = 512, return_tensors = "pt", truncation = True)
  line_ = model.generate_text(prompt_line_tokens, eval_beams = 4, max_len = 512)
  generated_tldr.append(line_)

# Each entry is the list of decoded strings returned by generate_text.
with open(path_save, 'wb') as fp:
  pickle.dump(generated_tldr, fp)

In [None]:
# Mean ROUGE-1/2/L F-measure of the generated titles against the reference titles.
# Uses `originals` and `generated_tldr` from the generation cell. The old code
# referenced undefined `originals_` / `generated`. Each generated entry is a list
# from generate_text, and its first string is scored.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)


def _normalize(text):
  # Newlines become spaces, not '', so words on adjacent lines are not fused into
  # a single token. Then lowercase.
  return text.replace('\n', ' ').lower()


num_scored = len(generated_tldr)
if num_scored == 0:
  raise ValueError('No generated summaries to score.')

rouge1 = rouge2 = rougel = 0.0
for reference, prediction in zip(originals, generated_tldr):
  scores = scorer.score(_normalize(reference), _normalize(prediction[0]))
  # Baseline alternative: score first sentences instead of generated titles
  # scores = scorer.score(_normalize(reference), _normalize(first_sentence))
  rouge1 += scores['rouge1'].fmeasure
  rouge2 += scores['rouge2'].fmeasure
  rougel += scores['rougeL'].fmeasure

print(rouge1 / num_scored)
print(rouge2 / num_scored)
print(rougel / num_scored)