In [1]:
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq 
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

import datasets
from datasets import load_from_disk, load_metric
import random
import pandas as pd
import nltk
import numpy as np
import torch
from IPython.display import display, HTML

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [2]:
from datasets import load_dataset

dataset = load_dataset('big_patent', 'h')

Reusing dataset big_patent (/home/ccmilne/.cache/huggingface/datasets/big_patent/h/1.0.0/bdefa7c0b39fba8bba1c6331b70b738e30d63c8ad4567f983ce315a5fef6131c)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
model_checkpoint = 'sshleifer/distilbart-xsum-12-1' #'sshleifer/distill-pegasus-xsum-16-4' #"t5-small"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)

## Processing the Data

In [4]:
if model_checkpoint in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

In [5]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["description"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["abstract"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [6]:
tokenized_dataset = dataset.map(preprocess_function, batched=True)

  0%|          | 0/258 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

In [7]:
tokenized_dataset.save_to_disk("processed/big_patent")

## Fine-Tuning

In [3]:
tokenized_dataset = load_from_disk("processed/big_patent")

In [4]:
batch_size = 3
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned",
    
    #Training Loss
    save_strategy = "steps",
    logging_strategy = 'steps',
    logging_steps = 200,

    #Validation Loss
    evaluation_strategy = 'no', #"epoch",
#     eval_steps = 10,
    
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,)

In [5]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [6]:
metric = load_metric("rouge")

In [7]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [8]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [9]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: abstract, description. If abstract, description are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 257019
  Num Epochs = 1
  Instantaneous batch size per device = 3
  Total train batch size (w. parallel, distributed & accumulation) = 6
  Gradient Accumulation steps = 1
  Total optimization steps = 42837


Step,Training Loss
200,4.9774
400,4.4846
600,4.3332
800,4.2673
1000,4.2286
1200,4.1177
1400,4.0736
1600,4.0458
1800,4.0017
2000,3.9351


Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-41500] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-1000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-1000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-1000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-1000/special_tokens_map.json
Deleting old

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-5500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-5500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-5500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-5500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-5500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-4000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-6000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-6000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-6000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-6000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-6000/special_tokens_map.json
Deleting

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-10500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-10500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-10500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-10500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-10500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-9000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-11000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-11000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-11000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-11000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-11000/special_tokens_map.jso

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-15500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-15500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-15500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-15500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-15500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-14000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-16000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-16000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-16000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-16000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-16000/special_tokens_map.js

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-20500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-20500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-20500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-20500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-20500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-19000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-21000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-21000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-21000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-21000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-21000/special_tokens_map.js

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-25500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-25500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-25500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-25500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-25500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-24000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-26000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-26000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-26000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-26000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-26000/special_tokens_map.js

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-30500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-30500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-30500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-30500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-30500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-29000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-31000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-31000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-31000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-31000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-31000/special_tokens_map.js

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-35500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-35500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-35500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-35500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-35500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-34000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-36000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-36000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-36000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-36000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-36000/special_tokens_map.js

Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-40500
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-40500/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-40500/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-40500/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-40500/special_tokens_map.json
Deleting older checkpoint [distilbart-xsum-12-1-finetuned/checkpoint-39000] due to args.save_total_limit
Saving model checkpoint to distilbart-xsum-12-1-finetuned/checkpoint-41000
Configuration saved in distilbart-xsum-12-1-finetuned/checkpoint-41000/config.json
Model weights saved in distilbart-xsum-12-1-finetuned/checkpoint-41000/pytorch_model.bin
tokenizer config file saved in distilbart-xsum-12-1-finetuned/checkpoint-41000/tokenizer_config.json
Special tokens file saved in distilbart-xsum-12-1-finetuned/checkpoint-41000/special_tokens_map.js

TrainOutput(global_step=42837, training_loss=3.3366516485160136, metrics={'train_runtime': 25318.5563, 'train_samples_per_second': 10.151, 'train_steps_per_second': 1.692, 'total_flos': 2.652226151157596e+17, 'train_loss': 3.3366516485160136, 'epoch': 1.0})

In [10]:
# predictions = trainer.predict(tokenized_dataset["test"])

In [17]:
# predictions

PredictionOutput(predictions=array([[    0,     3,     9, ...,     3,  8499,     3],
       [    0,     3,     9, ..., 19972,     3,     5],
       [    0,     3,     9, ...,  1904,     3,     6],
       ...,
       [    0,     3,     9, ...,   579,  1899, 15786],
       [    0,     3,     9, ...,    44,  5590,     7],
       [    0,     3,     9, ...,     9, 16188, 14286]]), label_ids=array([[    3,     9,  3317, ...,    84, 10446,     1],
       [    3,     9,  1573, ...,     3,   117,     1],
       [    3,     9, 23795, ...,  -100,  -100,  -100],
       ...,
       [    3,     9,  1573, ...,  7415,     8,     1],
       [    3,     9,  3240, ...,   689,    21,     1],
       [    3,     9,     3, ...,  -100,  -100,  -100]]), metrics={'test_loss': 2.4342410564422607, 'test_rouge1': 17.7916, 'test_rouge2': 6.4898, 'test_rougeL': 14.9019, 'test_rougeLsum': 16.0102, 'test_gen_len': 19.0, 'test_runtime': 640.4784, 'test_samples_per_second': 22.294, 'test_steps_per_second': 2.787})

In [15]:
# preds = np.argmax(predictions.predictions, axis=-1)

In [11]:
trainer.save_model('trained_models/bart_trained')

Saving model checkpoint to trained_models/bart_trained
Configuration saved in trained_models/bart_trained/config.json
Model weights saved in trained_models/bart_trained/pytorch_model.bin
tokenizer config file saved in trained_models/bart_trained/tokenizer_config.json
Special tokens file saved in trained_models/bart_trained/special_tokens_map.json
