# Fine-tuning BART on the Big Patent dataset

In [None]:
# This run uses PyTorch Lightning to fine-tune the model
!pip install -q pytorch-lightning
!pip install -q transformers

In [None]:
# imports
import transformers
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset
import pandas as pd
import numpy as np

import torch.nn.functional as F
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

import math
import random
import re
import argparse

# Firing up Google Drive
Mount Google Drive so that model checkpoints and weights can be saved there (the Big Patent data itself is downloaded with the Hugging Face `datasets` library further down).

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
from google.colab import drive

# Mount Drive under /content/gdrive; base_dir is where checkpoint folders are created.
drive.mount('/content/gdrive', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
base_dir = f"{root_dir}BART test/"

Mounted at /content/gdrive


# PyTorch Lightning for running the training (adapted from Medium)
The code below uses PyTorch Lightning to train the model; the library is explained very well (and simply) at https://pytorch-lightning.readthedocs.io/en/latest/. Very briefly, most of the usual methods one would set up for a PyTorch model are defined in a `pl.LightningModule` subclass, and Lightning then automates much of the training loop, for example stepping the optimizer and clearing gradients.

In [None]:
class LitModel(pl.LightningModule):
  ''' LightningModule that fine-tunes a Hugging Face BART model for summarization.

      Args:
        learning_rate - learning rate for the Adam optimizer
        tokenizer - tokenizer matching `model`; its pad_token_id is used to build the
                    decoder inputs and is ignored by the loss
        model - a BartForConditionalGeneration (the notebook passes facebook/bart-base)
        hparams - argparse.Namespace with boolean `freeze_encoder` and `freeze_embeds`
                  attributes; all of its fields are copied into self.hparams
  '''
  # Instantiate the model
  def __init__(self, learning_rate, tokenizer, model, hparams):
    super().__init__()
    self.tokenizer = tokenizer
    self.model = model
    self.learning_rate = learning_rate
    self.hparams.update(vars(hparams))

    if self.hparams.freeze_encoder:
      freeze_params(self.model.get_encoder())

    if self.hparams.freeze_embeds:
      self.freeze_embeds()

  def freeze_embeds(self):
    ''' Freeze the shared token embeddings plus the encoder/decoder positional and token
        embedding parameters of the model; adapted from finetune.py '''
    freeze_params(self.model.model.shared)
    for d in [self.model.model.encoder, self.model.model.decoder]:
      freeze_params(d.embed_positions)
      freeze_params(d.embed_tokens)

  # Do a forward pass through the model
  def forward(self, input_ids, **kwargs):
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

  def _teacher_forced_forward(self, batch):
    ''' Run the model on a (src_ids, src_mask, tgt_ids) batch with teacher forcing.
        Returns (model outputs, tgt_ids). '''
    src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
    # Shift the decoder tokens right (but NOT the tgt_ids). Use this module's own
    # tokenizer rather than whatever global `tokenizer` happens to exist.
    decoder_input_ids = shift_tokens_right(tgt_ids, self.tokenizer.pad_token_id)
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    return outputs, tgt_ids

  def _token_loss(self, lm_logits, tgt_ids):
    ''' Token-level cross-entropy of the logits against the un-shifted targets, ignoring padding. '''
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
    return ce_loss_fct(lm_logits.reshape(-1, lm_logits.shape[-1]), tgt_ids.reshape(-1))

  def training_step(self, batch, batch_idx):
    outputs, tgt_ids = self._teacher_forced_forward(batch)
    # outputs[0] are the language-model logits
    loss = self._token_loss(outputs[0], tgt_ids)
    return {'loss': loss}

  def validation_step(self, batch, batch_idx):
    outputs, tgt_ids = self._teacher_forced_forward(batch)
    val_loss = self._token_loss(outputs[0], tgt_ids)
    # Log the value so it is actually recorded (and usable as a ModelCheckpoint monitor);
    # no validation_epoch_end consumes the returned dict in this module.
    self.log('val_loss', val_loss)
    return {'loss': val_loss}

  def predict_step(self, batch, batch_idx):
    ''' Return the raw model outputs (logits first) for a batch. These are computed with
        teacher forcing -- the decoder sees the shifted reference targets -- so they are
        not free-running predictions. '''
    outputs, _ = self._teacher_forced_forward(batch)
    return outputs

  # Method that generates text using the BartForConditionalGeneration's generate() method
  def generate_text(self, text, eval_beams, early_stopping = True, max_len = 32):
    ''' Generate summaries with the wrapped model's generate() method.

        Args:
          text - tokenizer output containing 'input_ids' and 'attention_mask'
          eval_beams - number of beams for beam search
          early_stopping - forwarded to generate()
          max_len - maximum generated length in tokens
        Returns a list of decoded strings, one per input row.
    '''
    generated_ids = self.model.generate(
        text["input_ids"],
        attention_mask=text["attention_mask"],
        use_cache=True,
        # Start decoding with the same kind of token training used: shift_tokens_right puts
        # the last non-pad target token (the eos) first. BART configs set
        # decoder_start_token_id to that eos id; the pad id used previously was never the
        # first decoder token during training.
        decoder_start_token_id = self.model.config.decoder_start_token_id,
        num_beams= eval_beams,
        max_length = max_len,
        early_stopping = early_stopping
    )
    return [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]

def freeze_params(model):
  ''' Disable gradient updates for every parameter of `model` (a whole model or any
      sub-module such as the BART encoder or an embedding layer). Adapted from finetune.py.

      Args: model - a torch.nn.Module whose parameters should stop receiving gradients
  '''
  for param in model.parameters():
    # The attribute is `requires_grad`; the former `requires_grade` typo merely created an
    # unused attribute, so nothing was actually frozen.
    param.requires_grad = False


In [None]:
# Create a dataloading module as per the PyTorch Lightning Docs
class SummaryDataModule(pl.LightningDataModule):
  ''' LightningDataModule that reads a CSV with `source` and `target` text columns and
      serves tokenized TensorDatasets of (input_ids, attention_mask, labels).

      Args:
        tokenizer - tokenizer passed to encode_sentences
        data_file - path to the CSV file
        batch_size - batch size for every DataLoader
        num_examples - only the first `num_examples` rows of the CSV are used
        max_length - token length that sources and targets are truncated/padded to
        split_fractions - None (default): every row is used as the test/predict set and no
                          train/validation data exists (train_dataloader/val_dataloader then
                          fail). Otherwise a pair (train_end, val_end) of fractions, e.g.
                          (0.6, 0.8), giving a shuffled 60/20/20 train/validation/test split.
  '''
  def __init__(self, tokenizer, data_file, batch_size, num_examples = 20000, max_length = 32, split_fractions = None):
    super().__init__()
    self.tokenizer = tokenizer
    self.data_file = data_file
    self.batch_size = batch_size
    self.num_examples = num_examples
    self.max_length = max_length
    self.split_fractions = split_fractions

  def prepare_data(self):
    ''' Read the CSV and assign the raw (untokenized) splits.
        NOTE: Lightning recommends not assigning state here when running multi-process;
        it is kept here because the notebook calls prepare_data() then setup() by hand. '''
    self.data = pd.read_csv(self.data_file)[:self.num_examples]
    if self.split_fractions is None:
      self.test = self.data
    else:
      n = len(self.data)
      train_end, val_end = (int(frac * n) for frac in self.split_fractions)
      shuffled = self.data.sample(frac=1)
      self.train = shuffled.iloc[:train_end]
      self.validate = shuffled.iloc[train_end:val_end]
      self.test = shuffled.iloc[val_end:]

  def setup(self, stage = None):
    ''' Tokenize whichever raw splits are present. Safe to call repeatedly: splits that
        are already encoded (dicts, not DataFrames) are left untouched. '''
    for name in ('train', 'validate', 'test'):
      raw = getattr(self, name, None)
      if isinstance(raw, pd.DataFrame):
        setattr(self, name, encode_sentences(self.tokenizer, raw['source'], raw['target'], max_length = self.max_length))

  def _tensor_dataset(self, encoded):
    ''' Wrap an encode_sentences() dict as a TensorDataset of (input_ids, attention_mask, labels). '''
    return TensorDataset(encoded['input_ids'], encoded['attention_mask'], encoded['labels'])

  # Load the training, validation and test sets in Pytorch Dataset objects
  def train_dataloader(self):
    dataset = self._tensor_dataset(self.train)
    return DataLoader(dataset, sampler = RandomSampler(dataset), batch_size = self.batch_size)

  def val_dataloader(self):
    return DataLoader(self._tensor_dataset(self.validate), batch_size = self.batch_size)

  def test_dataloader(self):
    return DataLoader(self._tensor_dataset(self.test), batch_size = self.batch_size)

  def predict_dataloader(self):
    # Prediction runs over the same rows as the test split.
    return DataLoader(self._tensor_dataset(self.test), batch_size = self.batch_size)



In [None]:
# Hyperparameters handed to LitModel. An argparse.Namespace is used purely as a mutable
# attribute container so the values are easy to find and edit in one place.
hparams = argparse.Namespace(
    freeze_encoder = True,
    freeze_embeds = True,
    eval_beams = 4,
)

In [None]:
def shift_tokens_right(input_ids, pad_token_id):
  """ Build decoder inputs by shifting each row of `input_ids` one step right and placing
      that row's last non-pad token (usually <eos>) in position 0. `input_ids` is not
      modified. Taken from the older modeling_bart.py implementation.
  """
  # Position of the last non-pad token in every row
  last_real = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  wrapped = torch.gather(input_ids, 1, last_real).squeeze()

  shifted = input_ids.clone()
  shifted[:, 1:] = input_ids[:, :-1]
  shifted[:, 0] = wrapped
  return shifted

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=32, pad_to_max_length=True, return_tensors="pt"):
  ''' Tokenize parallel source/target sentences into padded id tensors.

      Args:
        tokenizer - the BART tokenizer (any Hugging Face tokenizer with the same call interface)
        source_sentences - iterable of source strings (e.g. a pandas Series)
        target_sentences - iterable of target strings
        max_length - sequences are truncated to this many tokens
        pad_to_max_length - if True pad every sequence to max_length; otherwise pad only to
                            the longest sequence within the sources (and, separately, the
                            targets) so each side can still be returned as one tensor
        return_tensors - tensor type requested from the tokenizer ("pt" for PyTorch)
      Returns: dict with keys "input_ids" and "attention_mask" (from the sources) and
               "labels" (token ids of the targets, not shifted)
  '''
  # "longest" rather than None: tokenizers reject padding=None, and unpadded rows of
  # different lengths could not be stacked into a single tensor anyway.
  padding = "max_length" if pad_to_max_length else "longest"

  def _tokenize(sentences):
    # One batched tokenizer call per side instead of one call per sentence.
    return tokenizer(
        list(sentences),
        max_length=max_length,
        padding=padding,
        truncation=True,
        return_tensors=return_tensors,
        add_prefix_space=True,
    )

  source = _tokenize(source_sentences)
  target = _tokenize(target_sentences)

  return {
      "input_ids": source["input_ids"],
      "attention_mask": source["attention_mask"],
      "labels": target["input_ids"],
  }


def noise_sentence(sentence_, percent_words, replacement_token = "<mask>"):
  '''
  Function that noises a sentence by replacing randomly chosen words with a mask token
  Args: sentence_ - the sentence to noise (words are split on single spaces)
        percent_words - the fraction of words to replace; the count is rounded up using
                        math.ceil and capped at the number of words eligible for masking
        replacement_token - the token substituted for each chosen word (default "<mask>")
  Returns a noised sentence in which every run of adjacent mask tokens is merged into one
  '''
  words = sentence_.split(' ')

  # Candidate positions exclude the last word (in lyrics it is often a rhyming word that
  # matters for song construction); a one-word sentence still offers position 0.
  candidates = list(range(max(1, len(words) - 1)))
  num_words = min(math.ceil(len(words) * percent_words), len(candidates))

  # random.sample needs a sequence: sampling from a set raises TypeError on Python 3.11+.
  for pos in random.sample(candidates, num_words):
    # Swap out words, but not full stops
    if words[pos] != '.':
      words[pos] = replacement_token

  # Collapse runs of spaces (empty "words" come from repeated spaces in the input)
  sentence = re.sub(r' {2,}', ' ', ' '.join(words))

  # Merge any run of consecutive mask tokens into a single token, whatever the token is
  token = re.escape(replacement_token)
  sentence = re.sub(rf'{token}(?: {token})+', lambda _: replacement_token, sentence)
  return sentence
  

# Load BART
Here we load the model. I used "bart-base" because I had memory issues with "bart-large". "bart-base" appears to load fine without the use_cache argument, whereas for "bart-large" use_cache necessarily has to be set to False.

In [None]:
# Load the model
# Only the tokenizer and model classes are needed. transformers' own AdamW (previously
# imported here but unused) is deprecated in favour of torch.optim.AdamW and has been
# dropped from newer transformers releases, where importing it fails; BartConfig was unused.
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', add_prefix_space=True)

bart_model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base")


In [None]:
! pip install datasets

Collecting datasets
  Downloading datasets-2.1.0-py3-none-any.whl (325 kB)
[?25l[K     |█                               | 10 kB 37.5 MB/s eta 0:00:01[K     |██                              | 20 kB 45.3 MB/s eta 0:00:01[K     |███                             | 30 kB 25.2 MB/s eta 0:00:01[K     |████                            | 40 kB 13.8 MB/s eta 0:00:01[K     |█████                           | 51 kB 12.8 MB/s eta 0:00:01[K     |██████                          | 61 kB 14.9 MB/s eta 0:00:01[K     |███████                         | 71 kB 14.5 MB/s eta 0:00:01[K     |████████                        | 81 kB 14.1 MB/s eta 0:00:01[K     |█████████                       | 92 kB 15.6 MB/s eta 0:00:01[K     |██████████                      | 102 kB 14.1 MB/s eta 0:00:01[K     |███████████                     | 112 kB 14.1 MB/s eta 0:00:01[K     |████████████                    | 122 kB 14.1 MB/s eta 0:00:01[K     |█████████████                   | 133 kB 14.1 MB/s eta

In [None]:
pip install folium==0.2.1 

Collecting folium==0.2.1
  Downloading folium-0.2.1.tar.gz (69 kB)
[?25l[K     |████▊                           | 10 kB 14.2 MB/s eta 0:00:01[K     |█████████▍                      | 20 kB 10.3 MB/s eta 0:00:01[K     |██████████████                  | 30 kB 8.7 MB/s eta 0:00:01[K     |██████████████████▊             | 40 kB 8.2 MB/s eta 0:00:01[K     |███████████████████████▍        | 51 kB 4.6 MB/s eta 0:00:01[K     |████████████████████████████    | 61 kB 5.4 MB/s eta 0:00:01[K     |████████████████████████████████| 69 kB 3.7 MB/s 
Building wheels for collected packages: folium
  Building wheel for folium (setup.py) ... [?25l[?25hdone
  Created wheel for folium: filename=folium-0.2.1-py3-none-any.whl size=79808 sha256=900463614432dcf07776fa90a02442ba3f2192dfa3a2b4129da99767e7275f12
  Stored in directory: /root/.cache/pip/wheels/9a/f0/3a/3f79a6914ff5affaf50cabad60c9f4d565283283c97f0bdccf
Successfully built folium
Installing collected packages: folium
  Attempting unin

In [None]:
from datasets import load_dataset

dataset = load_dataset('big_patent','a')

Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.86k [00:00<?, ?B/s]

Downloading and preparing dataset big_patent/a (download: 6.01 GiB, generated: 3.45 GiB, post-processed: Unknown size, total: 9.45 GiB) to /root/.cache/huggingface/datasets/big_patent/a/1.0.0/bdefa7c0b39fba8bba1c6331b70b738e30d63c8ad4567f983ce315a5fef6131c...


Downloading data:   0%|          | 0.00/6.45G [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/174134 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/9674 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9675 [00:00<?, ? examples/s]

Dataset big_patent downloaded and prepared to /root/.cache/huggingface/datasets/big_patent/a/1.0.0/bdefa7c0b39fba8bba1c6331b70b738e30d63c8ad4567f983ce315a5fef6131c. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
train_data = dataset['train']

In [None]:

train_data = train_data.rename_columns({'description':'source', 'abstract':'target'})

In [None]:
train_data

Dataset({
    features: ['source', 'target'],
    num_rows: 174134
})

In [None]:
training = train_data.to_csv('train.csv')

Creating CSV from Arrow format:   0%|          | 0/18 [00:00<?, ?ba/s]

In [None]:
f = open('train.csv','r')

In [None]:
# Show only the CSV header line; readlines() would load the whole (multi-GB) file into memory.
f.readline()

',source,target\n'

In [None]:
# Load the data into the model for training
summary_data = SummaryDataModule(tokenizer, '/content/train.csv',
                                 batch_size = 8, num_examples = 180000)

# Load the model from a pre-saved checkpoint; alternatively use the code below to start training from scratch
model = LitModel.load_from_checkpoint("/content/lightning_logs/version_0/checkpoints/epoch=0-step=21549.ckpt",
                                      learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)

#model = LitModel(learning_rate = 2.2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)

# Training the model with Pytorch Lightning
The below code utilises Pytorch Lightning's fantastic Trainer module that helps to control the training process. After creating a ModelCheckpoint object, the other options are fed into the Trainer module. 


In [None]:
# Save checkpoints under the Drive folder. The callback must be passed via `callbacks=`:
# `checkpoint_callback=` only takes a bool in PyTorch Lightning >= 1.3, so a ModelCheckpoint
# passed there is treated as True and a default callback writing to lightning_logs/ is used
# instead (the checkpoint loaded earlier came from that default location).
checkpoint = ModelCheckpoint(dirpath = base_dir + 'checkpoint_files_2/')
trainer = pl.Trainer(gpus = 1,
                     max_epochs = 1,
                     min_epochs = 1,
                     auto_lr_find = False,
                     callbacks = [checkpoint],
                     progress_bar_refresh_rate = 500 )

NameError: ignored

In [None]:
# Fit the instantiated model to the data
trainer.fit(model, summary_data)


Missing logger folder: /content/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 139 M 
-------------------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params
557.682   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Save the fine-tuned model as a Lightning checkpoint (loadable with LitModel.load_from_checkpoint).
# torch.save() requires an object and a destination; with no arguments it raises a TypeError.
trainer.save_checkpoint(base_dir + 'checkpoint_files_2/bart_big_patent_final.ckpt')

In [None]:
model.freeze()

In [None]:
test_data = dataset['test']

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
test_data = test_data.rename_columns({'description':'source','abstract':'target'})

In [None]:
#test_data = test_data.remove_columns('target')

In [None]:
test_data

Dataset({
    features: ['source', 'target'],
    num_rows: 9675
})

In [None]:
testing = test_data.to_csv('test.csv')

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
# NOTE(review): `test_set` is only created in a later cell (test_set = pd.read_csv('test.csv')),
# so running the notebook top to bottom raises a NameError here; even then it would be a
# pandas DataFrame, which DataLoader cannot index like a Dataset. Prediction below uses
# `summarytest` via trainer.predict instead, so this looks like leftover scratch code.
test = DataLoader(test_set, batch_size= 8, num_workers= 8)


In [None]:
summarytest =  SummaryDataModule(tokenizer, '/content/test.csv' ,
                                 batch_size = 8, num_examples =10000)
summarytest.prepare_data()
summarytest.setup(stage = 'test')

In [None]:
output = trainer.predict(model, summarytest, return_predictions= True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Predicting: 7it [00:00, ?it/s]

In [None]:
output[0][0].shape

torch.Size([16, 32, 50265])

In [None]:
outputs = output[0][0]

In [None]:
#https://github.com/huggingface/transformers/issues/3853
def text_predictions(self, input_ids):
  ''' Generate one summary string per row of `input_ids`.

      `self` is the LitModel instance (called as text_predictions(model, ...)); it must expose
      `.model.generate` and `.tokenizer.decode`. Decoding is greedy (num_beams=1), capped at
      32 tokens, with repetition_penalty=2.5.
  '''
  generation_kwargs = dict(
      input_ids=input_ids,
      num_beams=1,
      max_length=32,
      repetition_penalty=2.5,
      length_penalty=1.0,
      early_stopping=True,
  )
  generated_ids = self.model.generate(**generation_kwargs)

  decode = self.tokenizer.decode
  summaries = []
  for ids in generated_ids:
    summaries.append(decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True))
  return summaries


In [None]:
preds = text_predictions(model, input_ids = summarytest.test['input_ids'])

In [None]:
len(preds)

9675

In [None]:
! pip install nltk



In [None]:
from datasets import load_metric

In [None]:
! pip install rouge-score

Collecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Installing collected packages: rouge-score
Successfully installed rouge-score-0.0.4


In [None]:
metric = load_metric('rouge')

Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

In [None]:
test_set = pd.read_csv('test.csv')

In [None]:
test_dict = test_set.to_dict()

In [None]:
# Reference abstracts in row order: to_dict() maps column -> {row index: value}.
labels = list(test_dict['target'].values())

In [None]:
preds[1291]

' an orthodontic bracket includes a base member having first and second ends, the upper end being configured to be mounted on one of two opposing'

In [None]:
labels[1291]

'an orthodontic bracket , a method of manufacture and method of installing the bracket . the bracket is provided with visually enhanced reference edges for assisting in alignment of the bracket with respect to the tooth .'

In [None]:
 import nltk
 nltk.download('punkt')
 # ROUGE-Lsum expects one sentence per line, so split each text into sentences.
 # An empty or missing abstract comes back from pd.read_csv as float NaN; treat it as empty
 # text instead of crashing on .strip().
 decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
 decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip() if isinstance(label, str) else '')) for label in labels]

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
result_2 = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    


In [None]:
result 

{'rouge1': AggregateScore(low=Score(precision=0.4525933518807978, recall=0.13519238321862234, fmeasure=0.19872305076270172), mid=Score(precision=0.45518128438725647, recall=0.1366086882068861, fmeasure=0.20025631899288138), high=Score(precision=0.45780566580866194, recall=0.13818391702279048, fmeasure=0.2019331900193935)),
 'rouge2': AggregateScore(low=Score(precision=0.1126905043096126, recall=0.03296328130596104, fmeasure=0.04854426525227362), mid=Score(precision=0.115117908651504, recall=0.03382375842060322, fmeasure=0.049703316885100365), high=Score(precision=0.11748862276323936, recall=0.03467744689845681, fmeasure=0.050845509456202635)),
 'rougeL': AggregateScore(low=Score(precision=0.33181748066240546, recall=0.09898531682610702, fmeasure=0.14552193382210546), mid=Score(precision=0.3344289951268984, recall=0.10003933452718704, fmeasure=0.14671730283217305), high=Score(precision=0.33658424278511, recall=0.10118255299407498, fmeasure=0.14803422988985532)),
 'rougeLsum': AggregateS

In [None]:
result_2

{'rouge1': 20.025631899288136,
 'rouge2': 4.970331688510036,
 'rougeL': 14.671730283217304,
 'rougeLsum': 17.057013890096787}

# **Old Version of decode**

In [None]:
arr = np.array(outputs, dtype = float) # Sentence, Token, Logits 

In [None]:

arr = np.nan_to_num(arr, copy=False)
arr.shape
ten = torch.tensor(arr, dtype=torch.float)
ten[0,4,:2]

tensor([-1.0525, -3.8278])

In [None]:
out_test = [x for x in outputs[0][2] if x is not None]

In [None]:
x = torch.tensor([0, 34, 45, 23.2, 54, 65.2, 765, 2])
tokenizer.decode(x.tolist())

'<s> has not at who one short</s>'

In [None]:
out = tokenizer.decode(ten[0][14], clean_up_tokenization_spaces = True )

TypeError: ignored

In [None]:
tokenizer.decode()

In [None]:
out

'<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>

In [None]:
out = tokenizer.decode(outputs[0,9,:][:10], skip_special_tokens=True)

TypeError: ignored

In [None]:
out

' I the- the- of. to the'

In [None]:
decoded_preds = tokenizer.decode(outputs[0][0])

TypeError: ignored

In [None]:
import nltk 
decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in outputs]

AttributeError: ignored

In [None]:
# Old approach: greedily decode the logits returned by trainer.predict. Each element of
# `output` is one batch's model output whose index 0 holds the logits (see output[0][0].shape
# above); argmax over the vocabulary picks one token per position.
# NOTE: predict_step feeds the shifted reference targets to the decoder (teacher forcing),
# so ROUGE computed from these logits is optimistic compared with text_predictions() above.
decoded_preds = []
for batch_out in output:
  pred_ids = batch_out[0].argmax(dim=-1)
  decoded_preds.extend(tokenizer.batch_decode(pred_ids, skip_special_tokens=True))

# `labels` already holds the reference abstracts as plain strings read from test.csv, so
# there are no -100 ids to replace and nothing to decode; NaN (missing) becomes ''.
decoded_labels = [label if isinstance(label, str) else '' for label in labels]

# Rouge expects a newline after each sentence
decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

[Seq2SeqLMOutput([('logits',
                   tensor([[[38.4177,  5.3495, 12.7170,  ...,  5.2870,  5.3244,  2.0695],
                            [ 2.2303, -5.8807,  5.8119,  ..., -5.4862, -5.5116, -2.8966],
                            [ 0.3222, -1.3144,  6.2740,  ..., -1.8403, -1.6395, -4.1810],
                            ...,
                            [-4.0877, -4.7807,  8.3982,  ..., -5.3914, -4.9821, -3.7013],
                            [-2.6754, -4.8575,  8.0365,  ..., -5.1284, -4.9273, -4.9749],
                            [-1.5540, -4.0637, 18.7742,  ..., -5.0277, -5.0703, -3.6696]],
                   
                           [[38.7793,  5.0033, 12.8514,  ...,  5.2229,  5.1317,  1.7553],
                            [ 2.4501, -6.0714,  6.6301,  ..., -5.6788, -6.1070, -3.3783],
                            [ 0.5230, -5.2756,  6.3479,  ..., -5.4470, -5.6043, -3.3518],
                            ...,
                            [-4.4695, -5.0789,  6.8240,  ..., -6.1327, -6.

In [None]:
# Tokenize one real test document. Passing the literal 'test.csv' only encoded the filename
# itself (5 ids). Truncate to the tokenizer's limit (1024 for facebook/bart-base) rather than
# 3000, which exceeds BART's positional embeddings and would fail in the forward pass.
test = tokenizer(test_set['source'].iloc[0], max_length = tokenizer.model_max_length, return_tensors = 'pt', truncation = True)

In [None]:
test

{'input_ids': tensor([[    0,  1296,     4, 49079,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [None]:
predictions = model(test['input_ids'])

In [None]:
predictions

Seq2SeqLMOutput([('logits',
                  tensor([[[33.9910,  1.5302, 11.5015,  ...,  2.0414,  1.6683,  0.8358],
                           [-1.0331, -8.0223,  4.2332,  ..., -7.5121, -7.6632, -6.1119],
                           [ 0.9695, -5.7495,  4.3166,  ..., -5.4018, -5.5078, -5.2968],
                           [-3.6087, -4.7621,  4.4822,  ..., -4.6860, -4.6129, -4.5500],
                           [-2.2516, -7.1955,  2.8918,  ..., -7.1643, -7.1276, -7.6788]]])),
                 ('past_key_values',
                  ((tensor([[[[-1.5290e-01, -5.4000e-01, -6.5076e-01,  ...,  3.4145e-01,
                                9.1939e-02, -6.2074e-02],
                              [ 3.0408e-02,  3.7936e-01, -7.1570e-01,  ...,  6.6755e-02,
                                3.0837e-02,  3.1995e-01],
                              [ 7.1285e-01, -1.3107e+00,  5.7219e-01,  ...,  4.2248e-02,
                               -1.1942e+00,  3.2631e-01],
                              [-6.0390e-01, -

In [None]:
# If you want to manually save a checkpoint, this works, although the model should automatically save (progressively better)
# checkpoints as it moves through the epochs
# trainer.save_checkpoint(base_dir + "checkpoint_files_2/8_ep_140k_simple_0210.ckpt")