In [None]:
# -*- coding: utf-8 -*-
"""bart.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1J2GpbAAQtHXt3MmcmfVt8aiQ8aSYAMWH
"""

In [None]:
# This run uses PyTorch Lightning to fine-tune the model.
# NOTE(review): neither install is version-pinned, so a fresh run may pull
# releases newer than the ones this notebook was written against.
!pip install -q pytorch-lightning
!pip install git+https://github.com/huggingface/transformers

In [None]:
# imports
import transformers
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset
import pandas as pd
import numpy as np
import os

import torch.nn.functional as F
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

import math
import random
import re
import argparse

In [None]:
from google.colab import drive

# Mount Google Drive (forcing a remount) so the project files are reachable.
drive.mount('/content/drive', force_remount=True)

# Locations inside the mounted Drive used by the rest of the notebook:
#   base_dir       - CSV with the seeker/supporter conversation pairs
#   save_dir       - output folder for this dialogue-generation run
#   checkpoint_dir - folder holding previously saved model checkpoints
root_dir = "/content/drive/MyDrive/"
base_dir = f"{root_dir}CS425-NLP project/datasets/BART_input.csv"
save_dir = f"{root_dir}CS425-NLP project/datasets/BART-dialogue generation/"
checkpoint_dir = f"{root_dir}CS425-NLP project/bart_saved_model"

Mounted at /content/drive


In [None]:
# Read the conversation CSV from Drive into a DataFrame for inspection.
conversation_dataset = pd.read_csv(base_dir)


In [None]:
# Show the loaded conversation pairs.
display(conversation_dataset)

Unnamed: 0,seeker,supporter
0,"I like acting, I hope to be an actor, what abo...",I do too. Wat do you like?
1,"No, but someday.",that is ok. have any kids?
2,After I am done with school I plan to have a f...,that is good. I have 2
3,"I hope so, how old are your kids?",that is great! you will be ready
4,I would imagine. I am sure they a great kids.,5 & 7. they take up a lot of my time
...,...,...
31073,"Yes, but this time I do not think that I was d...",Did you feel like you were not really asleep? ...
31074,"Exactly. I drank some warm milk, hoping that ...","Well, mine does that sometimes too when I drin..."
31075,I am sorry to hear that. Have you tried drink...,I haven't tried that but I might because I dri...
31076,The other thing that you might try to calm you...,Will it help my being so bloated? My stomach f...


In [None]:
class LitModel(pl.LightningModule):
  """Lightning wrapper used to fine-tune a sequence-to-sequence model.

  Batches are indexed positionally as (input_ids, attention_mask, labels),
  which is the order the TensorDatasets in this notebook are built in.
  """

  # Instantiate the model
  def __init__(self, learning_rate, tokenizer, model, hparams):
    """
    Args:
      learning_rate: learning rate handed to the Adam optimizer.
      tokenizer: tokenizer paired with `model`; provides pad_token_id and decode().
      model: the wrapped model; it is called directly and via .generate().
      hparams: extra settings, stored unchanged on `self.hp`.
    """
    super().__init__()
    self.tokenizer = tokenizer
    self.model = model
    self.learning_rate = learning_rate
    self.hp = hparams
    # NOTE(review): nothing in this class reads self.hp, so flags such as
    # freeze_encoder / freeze_embeds have no effect unless freeze_embeds()
    # is called explicitly.

  def freeze_embeds(self):
    ''' freeze the shared, positional and token embedding parameters of the model; adapted from finetune.py '''
    freeze_params(self.model.model.shared)
    for d in [self.model.model.encoder, self.model.model.decoder]:
      freeze_params(d.embed_positions)
      freeze_params(d.embed_tokens)

  # Do a forward pass through the model
  def forward(self, input_ids, **kwargs):
    """Delegate to the wrapped model, forwarding all keyword arguments."""
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    """Return a plain Adam optimizer over all parameters of this module."""
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

  def _compute_loss(self, batch):
    """Cross-entropy between the model logits and the un-shifted target ids.

    Shared by training_step and validation_step so both use exactly the same
    loss definition.
    """
    src_ids, src_mask = batch[0], batch[1]
    tgt_ids = batch[2]
    # Use the tokenizer owned by this module, not a notebook-level global,
    # so the module works wherever it is instantiated.
    pad_token_id = self.tokenizer.pad_token_id

    # Shift the decoder tokens right (but NOT the tgt_ids)
    decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

    # Run the model and get the logits
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]

    # Padding positions in the targets are excluded from the loss
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
    # Calculate the loss on the un-shifted tokens
    return ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

  def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch; returns {'loss': tensor}."""
    loss = self._compute_loss(batch)
    return {'loss':loss}

  def validation_step(self, batch, batch_idx):
    """Compute the validation loss for one batch; returns {'loss': tensor}."""
    val_loss = self._compute_loss(batch)
    # Log the value so the validation pass is visible in the progress bar and
    # logger instead of being computed and thrown away.
    self.log('val_loss', val_loss, prog_bar=True)
    return {'loss': val_loss}

  # Method that generates text using the BartForConditionalGeneration's generate() method
  def generate_text(self, text, eval_beams, early_stopping = True, max_len = 40):
    """ Function to generate text

    Args:
      text: mapping with "input_ids" and "attention_mask" entries.
      eval_beams: number of beams passed to generate().
      early_stopping: forwarded to generate().
      max_len: maximum generated length, forwarded as max_length.
    Returns:
      list of decoded strings, one per generated sequence.
    """
    # NOTE(review): generation starts from the pad token, whereas the decoder
    # inputs built for training by shift_tokens_right start with the wrapped
    # last non-pad token - confirm this difference is intended.
    generated_ids = self.model.generate(
        text["input_ids"],
        attention_mask=text["attention_mask"],
        use_cache=True,
        decoder_start_token_id = self.tokenizer.pad_token_id,
        num_beams= eval_beams,
        max_length = max_len,
        early_stopping = early_stopping
    )
    return [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]

def freeze_params(model):
  """ Freeze every parameter of a model (or part of a model) so it is not updated during training.

      Args: model - any object exposing .parameters() (e.g. a torch.nn.Module)
  """
  for param in model.parameters():
    # The attribute autograd reads is `requires_grad`; any other spelling just
    # creates an unused attribute and leaves the parameter trainable.
    param.requires_grad = False

In [None]:
# Create a dataloading module as per the PyTorch Lightning Docs
class SummaryDataModule(pl.LightningDataModule):
  """Loads the seeker/supporter CSV, splits it 60/20/20 and serves tokenized DataLoaders.

  After `prepare_data` the `train` / `validate` / `test` attributes hold
  DataFrames; after `setup` they hold the dicts of tensors returned by
  `encode_sentences` (keys: input_ids, attention_mask, labels).
  """

  def __init__(self, tokenizer, data_file, batch_size, num_examples = 20000, seed = None):
    """
    Args:
      tokenizer: tokenizer passed through to encode_sentences.
      data_file: path of the CSV file; it must contain 'seeker' and 'supporter' columns.
      batch_size: batch size used by all three DataLoaders.
      num_examples: maximum number of rows taken from the start of the file.
      seed: optional random_state for the shuffle; None keeps it unseeded.
    """
    super().__init__()
    self.tokenizer = tokenizer
    self.data_file = data_file
    self.batch_size = batch_size
    self.num_examples = num_examples
    self.seed = seed
    # Filled in by prepare_data() / setup()
    self.data = None
    self.train = None
    self.validate = None
    self.test = None

  # Loads and splits the data into training, validation and test sets with a 60/20/20 split
  def prepare_data(self):
    # Split only once: a second shuffle would produce a different partition
    # and could move rows that were in the test split into the training split.
    if self.train is not None:
      return
    self.data = pd.read_csv(self.data_file, lineterminator='\n')[:self.num_examples]
    shuffled = self.data.sample(frac=1, random_state=self.seed)
    train_end = int(.6*len(shuffled))
    validate_end = int(.8*len(shuffled))
    self.train = shuffled.iloc[:train_end]
    self.validate = shuffled.iloc[train_end:validate_end]
    self.test = shuffled.iloc[validate_end:]

  def _encode_split(self, split):
    """Tokenize a raw DataFrame split; anything else is taken to be already encoded."""
    if isinstance(split, pd.DataFrame):
      return encode_sentences(self.tokenizer, split['seeker'], split['supporter'])
    return split

  # encode the sentences using the tokenizer
  def setup(self, stage = None):
    # Safe to call repeatedly: splits that were already encoded are left untouched.
    self.train = self._encode_split(self.train)
    self.validate = self._encode_split(self.validate)
    self.test = self._encode_split(self.test)

  def _make_loader(self, encoded, shuffle = False):
    """Wrap one encoded split in a TensorDataset and a DataLoader."""
    dataset = TensorDataset(encoded['input_ids'], encoded['attention_mask'], encoded['labels'])
    return DataLoader(dataset, shuffle = shuffle, batch_size = self.batch_size)

  # Load the training, validation and test sets in Pytorch Dataset objects
  def train_dataloader(self):
    # Only the training batches are drawn in random order
    return self._make_loader(self.train, shuffle = True)

  def val_dataloader(self):
    return self._make_loader(self.validate)

  def test_dataloader(self):
    return self._make_loader(self.test)

In [None]:
# Settings handed to the model wrapper, bundled in a single Namespace
hparams = argparse.Namespace(
    freeze_encoder=True,
    freeze_embeds=True,
    eval_beams=4,
)

In [None]:
def shift_tokens_right(input_ids, pad_token_id):
  """ Build decoder inputs from a 2-D batch of token ids.

      Every row is moved one position to the right (dropping its final
      column) and the row's last non-pad token is placed at position 0.
      The input tensor itself is not modified.
  """
  # Column index of the last non-pad token in each row
  non_pad_counts = input_ids.ne(pad_token_id).sum(dim=1)
  last_token_index = (non_pad_counts - 1).unsqueeze(-1)
  # The token that gets wrapped around to the front, kept as a column
  wrapped_token = input_ids.gather(1, last_token_index)
  # First column: wrapped token; remaining columns: the original ids without the last one
  return torch.cat([wrapped_token, input_ids[:, :-1]], dim=1)

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=32, pad_to_max_length=True, return_tensors="pt"):
  """ Function that tokenizes parallel source and target sentences
      Args: tokenizer - the BART tokenizer (called once per sentence)
            source_sentences, target_sentences - iterables of strings
            max_length - truncation length, also the padded length when padding is on
            pad_to_max_length - pad every sentence to max_length; when False no padding
              is requested, so the concatenation below only succeeds if all sentences
              tokenize to the same length
            return_tensors - forwarded to the tokenizer; torch.cat below needs tensors ("pt")
      Returns: Dictionary with keys: input_ids, attention_mask, labels
  """
  # False is the tokenizer's documented "do not pad" value
  padding = "max_length" if pad_to_max_length else False

  def _tokenize(sentence):
    # One call per sentence, with identical settings for sources and targets
    return tokenizer(
          sentence,
          max_length=max_length,
          padding=padding,
          truncation=True,
          return_tensors=return_tensors,
          add_prefix_space = True
      )

  source_encodings = [_tokenize(sentence) for sentence in source_sentences]
  input_ids = torch.cat([encoded['input_ids'] for encoded in source_encodings], dim = 0)
  attention_masks = torch.cat([encoded['attention_mask'] for encoded in source_encodings], dim = 0)

  # Only the ids of the targets are kept; they become the labels
  target_ids = torch.cat([_tokenize(sentence)['input_ids'] for sentence in target_sentences], dim = 0)

  batch = {
      "input_ids": input_ids,
      "attention_mask": attention_masks,
      "labels": target_ids,
  }

  return batch

In [None]:
# Load the tokenizer and the pretrained model from the same 'facebook/bart-base' checkpoint
# NOTE(review): AdamW and BartConfig are not referenced in the visible cells, and AdamW
# may no longer be exported by recent transformers releases - confirm before re-running
# against an unpinned install.
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartConfig

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', add_prefix_space=True)

bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [None]:
# Build the data module. Tokenization is left to trainer.fit(), which calls
# prepare_data() and setup() on the data module itself, so doing it by hand
# here would only tokenize every sentence twice.
summary_data = SummaryDataModule(tokenizer, base_dir, batch_size = 16, num_examples = 140000)

# Load and split the raw data so a few training rows can be previewed
summary_data.prepare_data()
display(summary_data.train.head())


# Load the model from a pre-saved checkpoint; alternatively use the code below to fine-tune from the pretrained weights
# model = LitModel.load_from_checkpoint(checkpoint_dir + "/epoch=2-step=11418.ckpt",
#                                       learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)

model = LitModel(2e-5, tokenizer, bart_model, hparams)

In [None]:
# Save checkpoints under the run's output folder on Drive
checkpoint = ModelCheckpoint(dirpath=save_dir + 'checkpoint_files/')

# Train on a single GPU. `accelerator` / `devices` replace the deprecated
# `gpus` argument; `auto_lr_find=False` was only restating the default.
trainer = pl.Trainer(accelerator = "gpu",
                     devices = 1,
                     max_epochs = 3,
                     min_epochs = 1,
                     callbacks=[checkpoint])

  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Fine-tune the model on the batches served by the data module.
trainer.fit(model, summary_data)



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Me as well, that's awesome! Do you love art?
i wonder if there even is life on other planets
Well I like to pick up heavy stuff and throw it as far as I can. This time it landed on my bed. What do you do for fun?
Sounds like a good idea!
Nice! Do you have a large audience?
What's the best thing you have cooked?
i have spare cats , i'm a veterinarian , maybe that will do ?
no , it is just me and the cats .
I'm so excitied about going to my friend's big party tomorrow night! I really can't wait!
That's a good way of reacting! I think it teaches you to not respond to situations with anger.
anyone i have never been there
sorry . i speak english as well . how are you ?
Yeah me either, those aren't as interesting to watch.
do you think i could live up to your standards
what "inflatable things" would that be?
Yeah, it's always more fun when you have a home crowd. What is your favorite sport?
I clean aquariums and manage a hotel,

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 139 M 
-------------------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params
557.682   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


In [None]:
def generate_reply(model_):
  """Run an interactive chat loop on the CPU until the user types "Bye!".

  Args:
    model_: object exposing .to(), .eval(), a callable .tokenizer and
      .generate_text(tokens, eval_beams=...), such as a LitModel instance.
  """
  # Put the model on eval mode
  model_.to(torch.device('cpu'))
  model_.eval()

  while True:
    txt = input("What's on your mind?: ")
    if txt == "Bye!":
      break
    # Tokenize and generate with the model that was passed in, not with
    # whatever notebook-level `model` / `tokenizer` happen to exist.
    prompt_line_tokens = model_.tokenizer(text = txt, max_length = 32, return_tensors = "pt", truncation = True)
    reply = model_.generate_text(prompt_line_tokens, eval_beams = 4)
    print("FRANKLY: ", reply[0].strip())
  print("FRANKLY: Have a great day ahead! \n")

In [None]:
# Start the interactive chat loop with the model trained above.
generate_reply(model_ = model)
# Say "Bye!" to exit (Case-sensitive)

What's on your mind?: my name is faith
FRANKLY:  That's cool.  What is your name?
What's on your mind?: i am faith
FRANKLY:  That's great.  What do you do for fun?
What's on your mind?: i play with my dogs
FRANKLY:  i love to play with my dogs
What's on your mind?: Bye!
FRANKLY: Hope you are feeling better. Have a great day ahead!
