<a href="https://colab.research.google.com/github/magdalena-b/Bairon/blob/master/BART_load_from_checkpoint.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://towardsdatascience.com/teaching-bart-to-rap-fine-tuning-hugging-faces-bart-model-41749d38f3ef

In [1]:
!pip install -q pytorch-lightning
!pip install -q transformers

[K     |████████████████████████████████| 915 kB 21.2 MB/s 
[K     |████████████████████████████████| 829 kB 36.8 MB/s 
[K     |████████████████████████████████| 118 kB 67.2 MB/s 
[K     |████████████████████████████████| 5.6 MB 61.4 MB/s 
[K     |████████████████████████████████| 234 kB 61.1 MB/s 
[K     |████████████████████████████████| 636 kB 50.5 MB/s 
[K     |████████████████████████████████| 1.3 MB 52.7 MB/s 
[K     |████████████████████████████████| 142 kB 51.6 MB/s 
[K     |████████████████████████████████| 294 kB 59.3 MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 2.6 MB 29.3 MB/s 
[K     |████████████████████████████████| 895 kB 42.6 MB/s 
[K     |████████████████████████████████| 3.3 MB 57.7 MB/s 
[?25h

In [2]:
import transformers
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset
import pandas as pd
import numpy as np

import torch.nn.functional as F
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

import math
import random
import re
import argparse

In [3]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=False)

Mounted at /content/gdrive


In [15]:
def shift_tokens_right(input_ids, pad_token_id):
  """ Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
      This is taken directly from modeling_bart.py
  """
  prev_output_tokens = input_ids.clone()
  index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
  prev_output_tokens[:, 1:] = input_ids[:, :-1]
  return prev_output_tokens

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=32, pad_to_max_length=True, return_tensors="pt"):
  ''' Function that tokenizes a sentence 
      Args: tokenizer - the BART tokenizer; source and target sentences are the source and target sentences
      Returns: Dictionary with keys: input_ids, attention_mask, target_ids
  '''

  input_ids = []
  attention_masks = []
  target_ids = []
  tokenized_sentences = {}

  for sentence in source_sentences:
    encoded_dict = tokenizer(
          sentence,
          max_length=max_length,
          padding="max_length" if pad_to_max_length else None,
          truncation=True,
          return_tensors=return_tensors,
          add_prefix_space = True
      )

    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

  input_ids = torch.cat(input_ids, dim = 0)
  attention_masks = torch.cat(attention_masks, dim = 0)

  for sentence in target_sentences:
    encoded_dict = tokenizer(
          sentence,
          max_length=max_length,
          padding="max_length" if pad_to_max_length else None,
          truncation=True,
          return_tensors=return_tensors,
          add_prefix_space = True
      )
    # Shift the target ids to the right
    # shifted_target_ids = shift_tokens_right(encoded_dict['input_ids'], tokenizer.pad_token_id)
    target_ids.append(encoded_dict['input_ids'])

  target_ids = torch.cat(target_ids, dim = 0)
  

  batch = {
      "input_ids": input_ids,
      "attention_mask": attention_masks,
      "labels": target_ids,
  }

  return batch


def noise_sentence(sentence_, percent_words, replacement_token = "<mask>"):
  '''
  Function that noises a sentence by adding <mask> tokens
  Args: sentence - the sentence to noise
        percent_words - the percent of words to replace with <mask> tokens; the number is rounded up using math.ceil
  Returns a noised sentence
  '''
  # Create a list item and copy
  sentence_ = sentence_.split(' ')
  sentence = sentence_.copy()
  
  num_words = math.ceil(len(sentence) * percent_words)
  
  # Create an array of tokens to sample from; don't include the last word as an option because in the case of lyrics
  # that word is often a rhyming word and plays an important role in song construction
  sample_tokens = set(np.arange(0, np.maximum(1, len(sentence)-1)))
  
  words_to_noise = random.sample(sample_tokens, num_words)
  
  # Swap out words, but not full stops
  for pos in words_to_noise:
      if sentence[pos] != '.':
          sentence[pos] = replacement_token
  
  # Remove redundant spaces
  sentence = re.sub(r' {2,5}', ' ', ' '.join(sentence))
  
  # Combine concurrent <mask> tokens into a single token; this just does two rounds of this; more could be done
  sentence = re.sub(r'<mask> <mask>', "<mask>", sentence)
  sentence = re.sub(r'<mask> <mask>', "<mask>", sentence)
  return sentence

In [16]:
class LitModel(pl.LightningModule):

  def __init__(self, learning_rate, tokenizer, model):
    super().__init__()
    self.learning_rate = learning_rate
    self.tokenizer = tokenizer
    self.model = model
    self.eval_beams = 4

    # if self.hparams.freeze_encoder:
    freeze_params(self.model.get_encoder())

    # if self.hparams.freeze_embeds:
    self.freeze_embeds()


  
  def freeze_embeds(self):
    ''' freeze the positional embedding parameters of the model; adapted from finetune.py '''
    freeze_params(self.model.model.shared)
    for d in [self.model.model.encoder, self.model.model.decoder]:
      freeze_params(d.embed_positions)
      freeze_params(d.embed_tokens)

  # Do a forward pass through the model
  def forward(self, input_ids, **kwargs):
    return self.model(input_ids, **kwargs)
  
  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

  def training_step(self, batch, batch_idx):
    # Load the data into variables
    src_ids, src_mask = batch[0], batch[1]
    tgt_ids = batch[2]
    # Shift the decoder tokens right (but NOT the tgt_ids)
    decoder_input_ids = shift_tokens_right(tgt_ids, tokenizer.pad_token_id)

    # Run the model and get the logits
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    # Create the loss function
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
    # Calculate the loss on the un-shifted tokens
    loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

    return {'loss':loss}

  def validation_step(self, batch, batch_idx):

    src_ids, src_mask = batch[0], batch[1]
    tgt_ids = batch[2]

    decoder_input_ids = shift_tokens_right(tgt_ids, tokenizer.pad_token_id)
    
    # Run the model and get the logits
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]

    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
    val_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

    return {'loss': val_loss}
  
  # Method that generates text using the BartForConditionalGeneration's generate() method
  def generate_text(self, text, eval_beams, early_stopping = True, max_len = 40):
    ''' Function to generate text '''
    generated_ids = self.model.generate(
        text["input_ids"],
        attention_mask=text["attention_mask"],
        use_cache=True,
        decoder_start_token_id = self.tokenizer.pad_token_id,
        num_beams= eval_beams,
        max_length = max_len,
        early_stopping = early_stopping
    )
    return [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]

def freeze_params(model):
  ''' Function that takes a model as input (or part of a model) and freezes the layers for faster training
      adapted from finetune.py '''
  for layer in model.parameters():
    layer.requires_grade = False

In [8]:
# Load the model
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartConfig

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', add_prefix_space=True)

bart_model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355863.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1627.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=557941479.0, style=ProgressStyle(descri…




In [5]:
hparams = argparse.Namespace()

hparams.freeze_encoder = True
hparams.freeze_embeds = True
hparams.eval_beams = 4

In [23]:
root_dir = "/content/gdrive/My Drive/"
base_dir = root_dir + 'BART/'

In [11]:
model_shakespeare = LitModel.load_from_checkpoint(base_dir + "shakespeare_w_last_word_10_epoch.ckpt",
                                      learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)

In [21]:
model_cummings = LitModel.load_from_checkpoint(base_dir + "cummings_w_last_word_10_epoch.ckpt",
                                      learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)

In [None]:
model_ginsberg = LitModel.load_from_checkpoint(base_dir + "ginsberg_w_last_word_10_epoch.ckpt",
                                      learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)

In [17]:
def generate_lyrics_shakespeare(seed_line, num_lines, model_, noise_percent = 0.25, multiple_lines = False, max_line_history = 3):
  ''' Function that generates lyrics based on previously generated lyrics 
      Args: seed_line - a line to start off the machine
            num_lines - the number of lines to generate
            model_ - the model used to generate the text
            multiple_lines - whether the model generates based on multiple previous lines or just the past line
            max_line_history - the maximum number of previous lines used in the current input
      Returns a list with num_lines of rap lines
  '''
  # Put the model on eval mode
  model_.to(torch.device('cpu'))
  model_.eval()
  lyrics = []
  lyrics.append(seed_line)
  prompt_line_tokens = tokenizer(noise_sentence(seed_line, 0.5), max_length = 60, return_tensors = "pt", truncation = True)
  # Loop through the number of lines generating a new line based on the old

  line = [seed_line]
  for i in range(num_lines):
    # Print out the new line
    print(line[0].strip())
    lyrics.append(line[0])
    line = model.generate_text(prompt_line_tokens, eval_beams = 4)
    # This deals with an artefact in the training data that I had an issue cleaning
    if line[0].find(":") != -1:
      line[0] = re.sub(r'[A-Z]+: ', '', line[0])
    # This allows the model to generate a new line conditioned on more than one line
    if multiple_lines:
      start_line = np.maximum(0, i - max_line_history)
      end_line = i
      prompt_line = ' '.join(lyrics[start_line:end_line]) # Going to end_line is fine because it is non-inclusive
    else:
      prompt_line = lyrics[i]
    prompt_line_tokens = tokenizer(noise_sentence(prompt_line, noise_percent), max_length = 32, return_tensors = "pt", truncation = True)

  return lyrics

In [19]:
def generate_lyrics_cummings(seed_line, num_lines, model_, noise_percent = 0.25, multiple_lines = False, max_line_history = 3):
  ''' Function that generates lyrics based on previously generated lyrics 
      Args: seed_line - a line to start off the machine
            num_lines - the number of lines to generate
            model_ - the model used to generate the text
            multiple_lines - whether the model generates based on multiple previous lines or just the past line
            max_line_history - the maximum number of previous lines used in the current input
      Returns a list with num_lines of rap lines
  '''
  # Put the model on eval mode
  model_.to(torch.device('cpu'))
  model_.eval()
  lyrics = []
  lyrics.append(seed_line)
  prompt_line_tokens = tokenizer(noise_sentence(seed_line, 0.7), max_length = 60, return_tensors = "pt", truncation = True)
  # Loop through the number of lines generating a new line based on the old

  line = [seed_line]
  for i in range(num_lines):
    # Print out the new line
    print(line[0].strip())
    lyrics.append(line[0])
    line = model.generate_text(prompt_line_tokens, eval_beams = 4)
    # This deals with an artefact in the training data that I had an issue cleaning
    if line[0].find(":") != -1:
      line[0] = re.sub(r'[A-Z]+: ', '', line[0])
    # This allows the model to generate a new line conditioned on more than one line
    if multiple_lines:
      start_line = np.maximum(0, i - max_line_history)
      end_line = i
      prompt_line = ' '.join(lyrics[start_line:end_line]) # Going to end_line is fine because it is non-inclusive
    else:
      prompt_line = lyrics[i]
    prompt_line_tokens = tokenizer(noise_sentence(prompt_line, noise_percent), max_length = 32, return_tensors = "pt", truncation = True)

  return lyrics

In [20]:
def generate_lyrics_ginsberg(seed_line, num_lines, model_, noise_percent = 0.25, multiple_lines = False, max_line_history = 3):
  ''' Function that generates lyrics based on previously generated lyrics 
      Args: seed_line - a line to start off the machine
            num_lines - the number of lines to generate
            model_ - the model used to generate the text
            multiple_lines - whether the model generates based on multiple previous lines or just the past line
            max_line_history - the maximum number of previous lines used in the current input
      Returns a list with num_lines of rap lines
  '''
  # Put the model on eval mode
  model_.to(torch.device('cpu'))
  model_.eval()
  lyrics = []
  lyrics.append(seed_line)
  prompt_line_tokens = tokenizer(noise_sentence(seed_line, 0.7), max_length = 60, return_tensors = "pt", truncation = True)
  # Loop through the number of lines generating a new line based on the old

  line = [seed_line]
  for i in range(num_lines):
    # Print out the new line
    print(line[0].strip())
    lyrics.append(line[0])
    line = model.generate_text(prompt_line_tokens, eval_beams = 4)
    # This deals with an artefact in the training data that I had an issue cleaning
    if line[0].find(":") != -1:
      line[0] = re.sub(r'[A-Z]+: ', '', line[0])
    # This allows the model to generate a new line conditioned on more than one line
    if multiple_lines:
      start_line = np.maximum(0, i - max_line_history)
      end_line = i
      prompt_line = ' '.join(lyrics[start_line:end_line]) # Going to end_line is fine because it is non-inclusive
    else:
      prompt_line = lyrics[i]
    prompt_line_tokens = tokenizer(noise_sentence(prompt_line, noise_percent), max_length = 32, return_tensors = "pt", truncation = True)

  return lyrics

In [18]:
new_song = generate_lyrics_shakespeare(seed_line = "Look in thy glass and tell the face thou viewest,", num_lines = 10, model_ = model,
                           noise_percent = 0.75, multiple_lines = True, max_line_history = 5)

Look in thy glass and tell the face thou viewest,


To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


Look in the rearward, and see what thou seeest,
And that is my love, my love's love,
And you, my dear friend, are my best friend,
When thou see'st the world, and sees not what thou seeest,
When I see your face, my eyes see,
When I see thee, and hear thee speak,
When I see you, I see thee, and I know you.
When I see her in the rearward view,
When I see your beauty, and your beauty in my sight;


In [None]:
new_song = generate_lyrics_cummings(seed_line = "somewhere i have never travelled, gladly beyond", num_lines = 10, model_ = model,
                           noise_percent = 0.3, multiple_lines = True, max_line_history = 10)

In [None]:
new_song = generate_lyrics_ginsberg(seed_line = "Strange now to think of you", num_lines = 10, model_ = model,
                           noise_percent = 0.3, multiple_lines = True, max_line_history = 10)