In [None]:
import nltk
import torch
import evaluate
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from model_utils import load_raw_data
import warnings
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)

In [None]:
def preprocess_data(examples):
    """Tokenize a batch of source/target text pairs for seq2seq training.

    Args:
        examples: Batch mapping with an ``"extract"`` column (source texts)
            and an ``"abstract"`` column (target texts).

    Returns:
        The tokenizer output for the source texts, with the target token ids
        stored under ``"labels"``.

    Relies on the notebook-level ``tokenizer``, ``max_input_length`` and
    ``max_target_length`` being defined before this function is called.
    """
    # Source side: truncated to max_input_length tokens.
    source_texts = list(examples["extract"])
    encoded = tokenizer(source_texts, max_length=max_input_length, truncation=True)

    # Target side: tokenized through text_target, truncated to max_target_length.
    target_encoding = tokenizer(
        text_target=examples["abstract"],
        max_length=max_target_length,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

def compute_rouge(pred):
    """Compute ROUGE scores and the mean generated length for an evaluation run.

    Args:
        pred: A ``(predictions, labels)`` pair of token-id arrays. ``predictions``
            may also be a tuple, in which case its first element is used.

    Returns:
        dict: Every score returned by ``metric.compute`` scaled by 100, plus
        ``gen_len`` (mean count of non-pad tokens per prediction), each rounded
        to 4 decimals.

    Relies on the notebook-level ``tokenizer`` and ``metric``.
    """
    predictions, labels = pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is not a valid token id and cannot be decoded. Labels use it as the
    # ignore index, and generated ids may be padded with it when batches of
    # different lengths are gathered, so map it back to the pad token in both.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decode_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute the scores and express them on a 0-100 scale.
    res = metric.compute(predictions=decode_predictions, references=decode_labels, use_stemmer=True)
    res = {key: value * 100 for key, value in res.items()}

    # Count generated tokens per sequence, excluding padding (the -100 entries
    # were already converted to pad above, so they are not counted either).
    pred_lens = [np.count_nonzero(seq != tokenizer.pad_token_id) for seq in predictions]
    res['gen_len'] = np.mean(pred_lens)
    return {k: round(v, 4) for k, v in res.items()}

#### Dataset, Pretrained Model, Tokenizer

In [None]:
# ROUGE scorer used by compute_rouge
metric = evaluate.load('rouge')

# Sentence-pair parquet files (train / test)
train_sentence_pair = '../../dataset/to_abstractive/train_pair.parquet'
test_sentence_pair = '../../dataset/to_abstractive/test_pair.parquet'

# Raw dataset built from both sentence-pair files
raw_data = load_raw_data(train_sentence_pair, test_sentence_pair)

In [None]:
# Pretrained checkpoint to fine-tune (or use GanjinZero/biobart-v2-base)
model_checkpoint = "facebook/bart-base"

# Model and tokenizer both come from the same checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Batch collator for the trainer, built from the tokenizer and the model
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

#### Preprocessing the data

In [None]:
# Truncation limits (in tokens) read by preprocess_data
max_input_length, max_target_length = 1024, 512

# Tokenize every split in batches and drop the raw text columns afterwards
tokenize_data = raw_data.map(
    preprocess_data,
    batched=True,
    remove_columns=['extract', 'abstract'],
)
tokenize_data

#### Seq2SeqTrainingArguments settings

In [None]:
# Generation hyperparameters (used when the model generates summaries)
generation_config = {
    'num_beams': 5,
    'max_length': 512,
    'min_length': 64,
    'length_penalty': 2.0,
    'early_stopping': True,
    'no_repeat_ngram_size': 3
    }

# Legacy location: older transformers releases read decoding defaults from model.config.
model.config.update(generation_config)

# Newer transformers releases read decoding defaults from model.generation_config,
# so mirror the same values there when that attribute exists. Otherwise the
# settings above may not be used by generate().
if getattr(model, "generation_config", None) is not None:
    model.generation_config.update(**generation_config)

In [None]:
# Training hyperparameters, collected in a dict and unpacked into the arguments object
training_kwargs = {
    'output_dir': 'model/checkpoint_bart',
    # Optimisation schedule
    'learning_rate': 8e-5,
    'lr_scheduler_type': 'linear',
    'warmup_ratio': 0.1,
    'num_train_epochs': 9,
    # 'weight_decay': 0.01,
    # Batch sizes and accumulation
    'per_device_train_batch_size': 1,
    'gradient_accumulation_steps': 32,
    'per_device_eval_batch_size': 1,
    'eval_accumulation_steps': 64,
    # Checkpointing, evaluation and logging
    'save_total_limit': 9,
    'save_strategy': 'epoch',
    'evaluation_strategy': 'epoch',
    'logging_steps': 5625,
    'log_level': 'error',
    # Evaluate with generated sequences so compute_rouge can score them
    'predict_with_generate': True,
    'fp16': True,
    'seed': 42,
}

args = Seq2SeqTrainingArguments(**training_kwargs)

#### Seq2SeqTrainer

In [None]:
# Initialise the trainer with the model, arguments, tokenized splits and metric function
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=collator,
    tokenizer=tokenizer,
    train_dataset=tokenize_data['train'],
    eval_dataset=tokenize_data['validation'],
    compute_metrics=compute_rouge,
)

In [None]:
# Start training with the trainer configured above
trainer.train()

#### Supplemental code

The current model accepts input sequences of at most 1024 tokens. To handle longer inputs, you can try one of the following:  
+ Load an LSG model that supports longer input sequences  
+ Use the LSG Converter to extend the original model to longer sequences

In [None]:
## LSG Model
# model_checkpoint = "ccdv/lsg-bart-base-4096"
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, trust_remote_code = True, pass_global_tokens_to_decoder = True)

## LSG Converter
# from lsg_converter import LSGConverter
# converter = LSGConverter(max_sequence_length=4096)
# model, tokenizer = converter.convert_from_pretrained("GanjinZero/biobart-v2-base", architecture="BartForConditionalGeneration",
#                                                      num_global_tokens=1,
#                                                      block_size=128, sparse_block_size=128,
#                                                      sparsity_factor=2, mask_first_token=True)

If training is interrupted, you can resume it from a saved checkpoint:

In [None]:
# trainer.train(resume_from_checkpoint=True)