In [None]:
import datasets
import transformers
import rouge

In [None]:
# CNN/DailyMail v3.0.0: the full train split, plus the first 10% of validation to keep evaluation cheap.
train_data = datasets.load_dataset(path="cnn_dailymail", name="3.0.0", split="train")
validation_data = datasets.load_dataset(path="cnn_dailymail", name="3.0.0", split="validation[:10%]")

### Sample of the data

In [None]:
def show_examples(dataset, num_samples=3, seed=42):
    """Display `num_samples` rows of `dataset` chosen by a seeded shuffle.

    For each row, the article, its highlights and its id are shown, followed by a
    separator line.
    """
    picked = dataset.shuffle(seed=seed).select(range(num_samples))
    for position, example in enumerate(picked):
        lines = (
            f'sample {position}: {example["article"]} \n',
            f'highlight {position}: {example["highlights"]} \n',
            f'id: {example["id"]}',
            '-------',
        )
        for line in lines:
            display(line)
        
def get_samples(dataset, num_samples=10, seed=1):
    """Return the first `num_samples` rows of `dataset` after a seeded shuffle.

    Args:
        dataset: a `datasets.Dataset` (anything with `shuffle(seed=...)` and `select(...)`).
        num_samples: number of rows to keep.
        seed: shuffle seed; the default of 1 matches the previously hard-coded value,
            so the same seed always yields the same rows.
    """
    return dataset.shuffle(seed=seed).select(range(num_samples))

def get_random_sample(dataset, seed=1):
    """Return ``[article, highlights]`` for the first row of ``dataset`` after a seeded shuffle.

    Despite the name, the result is deterministic for a given ``seed``. The default of 1
    matches the previously hard-coded value; pass another seed to pick a different row.
    Raises IndexError if ``dataset`` is empty.
    """
    row = dataset.shuffle(seed=seed)[0]
    return [row["article"], row["highlights"]]

In [None]:
# get_random_sample(train_data)

### Tokenizer

In [None]:
batch_size = 4            # change to 16 for full training
encoder_max_length = 512  # pad/truncate length for article tokens (encoder input)
decoder_max_length = 128  # pad/truncate length for highlight tokens (decoder input / labels)


# Same checkpoint that the encoder and decoder of bert2bert are warm-started from.
tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

def tokenize_data_to_model_input(batch):
    """Tokenize a batch of CNN/DailyMail rows into encoder-decoder model inputs.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column names to lists,
    and must contain the string columns ``"article"`` and ``"highlights"``.

    Adds to ``batch`` (and returns it):
      * ``input_ids`` / ``attention_mask``: the article, padded/truncated to
        ``encoder_max_length``.
      * ``decoder_input_ids`` / ``decoder_attention_mask``: the highlights, padded/truncated
        to ``decoder_max_length``.
      * ``labels``: ``decoder_input_ids`` shifted one position to the left, so that
        ``labels[t]`` is the token the decoder must predict after reading
        ``decoder_input_ids[: t + 1]``. Padding positions and the final slot vacated by
        the shift are -100, so the loss ignores them. ``labels`` keeps the same length
        as ``decoder_input_ids``.
    """
    inputs = tokenizer(batch["article"], padding="max_length",
                       truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["highlights"], padding="max_length",
                        truncation=True, max_length=decoder_max_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["decoder_attention_mask"] = outputs.attention_mask

    # EncoderDecoderModel only derives decoder inputs by shifting `labels` when
    # decoder_input_ids are NOT supplied; its loss compares logits[t] with labels[t]
    # directly. Because decoder_input_ids are passed explicitly here, the labels must
    # already be the next-token targets. Otherwise the decoder is trained to copy its
    # current input token.
    # NOTE(review): very old transformers releases (< ~4.12) shifted labels inside
    # BertLMHeadModel instead; confirm against the installed version.
    pad_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_id else token for token in ids[1:]] + [-100]
        for ids in outputs.input_ids
    ]
    return batch

# Number of rows kept per split before tokenization. The small values keep the notebook
# fast; set them to None to use the full splits.
TRAIN_SUBSAMPLE = 32
VALIDATION_SUBSAMPLE = 16

# Columns produced by tokenize_data_to_model_input that the model consumes.
MODEL_INPUT_COLUMNS = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]


def prepare_model_inputs(dataset, num_examples=None):
    """Optionally subsample `dataset`, tokenize it, and expose the model columns as torch tensors.

    Args:
        dataset: a raw split with "article", "highlights" and "id" columns.
        num_examples: keep only the first `num_examples` rows (clamped to the split size);
            None keeps every row.

    Returns:
        The tokenized dataset, formatted as torch tensors over MODEL_INPUT_COLUMNS.
    """
    if num_examples is not None:
        dataset = dataset.select(range(min(num_examples, len(dataset))))
    tokenized = dataset.map(
        tokenize_data_to_model_input,
        batched=True,
        batch_size=batch_size,
        remove_columns=["article", "highlights", "id"],
    )
    tokenized.set_format(type="torch", columns=MODEL_INPUT_COLUMNS)
    return tokenized


# Note: this reassigns the raw splits, so re-running this cell requires re-running the
# data-loading cell first (the raw text columns are removed by tokenization).
train_data = prepare_model_inputs(train_data, num_examples=TRAIN_SUBSAMPLE)
validation_data = prepare_model_inputs(validation_data, num_examples=VALIDATION_SUBSAMPLE)
print(train_data)
    

### Encoder-Decoder Model

In [None]:
# Warm-start both the encoder and the decoder from the same BERT checkpoint the tokenizer uses.
bert2bert = transformers.EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path="bert-base-uncased",
    decoder_pretrained_model_name_or_path="bert-base-uncased",
)


In [None]:
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.bos_token_id
bert2bert.config.eos_token_id = tokenizer.eos_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

# sensible parameters for beam search
bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

### Evaluation Metric

In [None]:
from rouge import Rouge 

# Shared ROUGE scorer (from the `rouge` package), used by compute_evaluation_metric below.
rouge_scorer = Rouge()

def compute_evaluation_metric(prediction):
    """Compute the mean ROUGE-2 F-measure of generated summaries against the references.

    Passed to the trainer as ``compute_metrics``. ``prediction`` must expose
    ``predictions`` (generated token ids, since ``predict_with_generate=True``) and
    ``label_ids`` (reference token ids, with -100 on ignored positions).

    Some examples cannot be scored: ``rouge`` raises on empty hypotheses. Any example whose
    decoded prediction or reference is empty or whitespace therefore counts as 0 in the
    average, instead of aborting evaluation.

    Returns:
        ``{"rouge2_fmeasure": float}`` rounded to 4 decimals.
    """
    pad_id = tokenizer.pad_token_id

    pred_ids = prediction.predictions
    if isinstance(pred_ids, tuple):  # predictions may arrive wrapped in a tuple; take the ids
        pred_ids = pred_ids[0]
    # Work on copies so the trainer's arrays are not modified in place. -100 is not a valid
    # token id: it marks ignored label positions, and the trainer may also use it to pad
    # generated sequences gathered across batches. Map it to padding before decoding.
    pred_ids = pred_ids.copy()
    label_ids = prediction.label_ids.copy()
    pred_ids[pred_ids == -100] = pad_id
    label_ids[label_ids == -100] = pad_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    scorable = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if hyp.strip() and ref.strip()]
    if not scorable:
        return {"rouge2_fmeasure": 0.0}

    hyps, refs = map(list, zip(*scorable))
    # Rouge.get_scores takes (hypotheses, references); avg=True averages over the pairs.
    mean_scores = rouge_scorer.get_scores(hyps, refs, avg=True)
    # Rescale so unscorable examples contribute 0 rather than being silently dropped.
    rouge2_f = mean_scores["rouge-2"]["f"] * len(scorable) / len(pred_str)
    return {"rouge2_fmeasure": round(rouge2_f, 4)}

### Training

In [None]:
from typing_extensions import Protocol, runtime_checkable
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    # Checkpoints are written to the current working directory; at most 3 are kept.
    output_dir="./",
    overwrite_output_dir=True,
    save_total_limit=3,
    # Both batch sizes reuse `batch_size` from the tokenizer cell.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate with generate() so compute_metrics receives generated token ids.
    predict_with_generate=True,
    # NOTE(review): newer transformers versions rename this to `eval_strategy`;
    # confirm against the installed version.
    evaluation_strategy="steps",
    # Tiny schedule for the notebook subsample.
    max_steps=16,      # delete for full training
    warmup_steps=1,    # set to 2000 for full training
    logging_steps=2,   # set to 1000 for full training
    eval_steps=4,      # set to 8000 for full training
    save_steps=16,     # set to 500 for full training
)

# args = Seq2SeqTrainingArguments(
#     f"finetuned-xsum",
#     evaluation_strategy = "epoch",
#     learning_rate=2e-5,
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     weight_decay=0.01,
#     save_total_limit=3,
#     num_train_epochs=1,
#     predict_with_generate=True,
# #     fp16=True,
# )

# predict_with_generate=True in training_args makes evaluation run generate(), so
# compute_evaluation_metric is given generated token ids.
trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_evaluation_metric,
)

trainer.train()

In [None]:
# from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# learning_rate=2e-5
# weight_decay = 0.01
# training_args = Seq2SeqTrainingArguments(
#     output_dir="./",
#     evaluation_strategy="steps",
#     learning_rate=learning_rate,
# #     per_device_train_batch_size=batch_size,
# #     per_device_eval_batch_size=batch_size,
#     predict_with_generate=True,
#     weight_decay=weight_decay,
# #     logging_steps=2,  # set to 1000 for full training
# #     save_steps=16,  # set to 500 for full training
# #     eval_steps=4,  # set to 8000 for full training
# #     warmup_steps=1,  # set to 2000 for full training
#     max_steps=16, # delete for full training
#     overwrite_output_dir=True,
#     save_total_limit=3,
# #     fp16=True, 
# )

# data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, model=bert2bert)

# # instantiate trainer
# trainer = Seq2SeqTrainer(
#     model=bert2bert,
#     tokenizer=tokenizer,
#     args=training_args,
#     compute_metrics=compute_evaluation_metric,
#     train_dataset=train_data,
# )
# trainer.train()