In [1]:
import os
import datasets
import numpy as np
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from knowledge_distillation import DistillationTrainingArguments, DistillationTrainer, TinyTrainer
from huggingface_hub import notebook_login
from dotenv import load_dotenv
load_dotenv(verbose=True)
import torch
import wandb
import sys
sys.path.append('/opt/ml/final-project-level3-nlp-02')
from rouge import compute

# Train on the first CUDA GPU if one is visible, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Open the Weights & Biases run that the Trainer reports to (report_to='wandb').
# NOTE(review): the run name hard-codes 'tiny' and '25ep'; keep it in sync with
# distillation_type / num_train_epochs in the configuration cell.
wandb.init(
    name='edpl3_tiny_25ep',
    project='optimization',
    entity="final_project",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mabbymark[0m (use `wandb login --relogin` to force relogin)


In [8]:
# Random seed for the training-subset shuffle.
seed = 42

train_size = 5000   # number of shuffled training examples used
eval_size = 500     # number of validation examples to evaluate on
check_point = 'encoder_decoder_pruned_last_5'  # plain fine-tuning checkpoint (alternative: "gogamza/kobart-summarization")
max_input_length = 512   # token limit for the source text (examples['text'])
max_target_length = 30   # token limit for the target title (examples['title'])

batch_size = 8
num_train_epochs = 25
learning_rate = 5.6e-5
weight_decay = 0.01
logging_steps = 500
model_name = check_point.split("/")[-1]  # drop any Hub namespace so it can be used in a directory name

# Knowledge distillation mode:
#   falsy (None / '') -> plain Seq2Seq fine-tuning of `check_point`
#   'distil'          -> DistillationTrainer
#   'tiny'            -> TinyTrainer
distillation_type = 'tiny'
VALID_DISTILLATION_TYPES = ('distil', 'tiny')
# Fail fast on a typo. Otherwise the later cells would take the distillation
# branch for args/models but no trainer branch would match.
if distillation_type and distillation_type not in VALID_DISTILLATION_TYPES:
    raise ValueError(
        f"Unknown distillation_type {distillation_type!r}; "
        f"expected one of {VALID_DISTILLATION_TYPES} or a falsy value."
    )
if distillation_type:
    student_check_point = "encoder_decoder_pruned_last_3"
    teacher_check_point = "kobart-summarization-finetuned-5000/checkpoint-1000"
    alpha = 0.5
    temperature = 30


# Loading Dataset

In [9]:
# The dataset repository is private, so authenticate with the token from .env.
api_token = os.environ.get('HF_DATASET_API_TOKEN')
dataset = datasets.load_dataset(
    'metamong1/summarization_paper',
    use_auth_token=api_token,
)

Reusing dataset paper_summarization (/opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f)


  0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
# Deterministic random subset of the training split: the same `seed` yields the same examples.
training_dataset = (
    dataset['train']
    .shuffle(seed=seed)
    .select(range(train_size))
)

Loading cached shuffled indices for dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-93a7be41ce043bd9.arrow


# Prepare Training

In [11]:
# Use the tokenizer of the model actually being trained: the student when
# distilling, otherwise the plain fine-tuning checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    student_check_point if distillation_type else check_point
)

In [12]:
def preprocess_function(examples):
    """Tokenize a batch of papers for title generation.

    The body text (``examples['text']``) becomes the model inputs, truncated to
    ``max_input_length`` tokens. The title (``examples['title']``) is tokenized in
    target mode and truncated to ``max_target_length`` tokens; its token ids are
    stored under ``'labels'``. No padding is applied here; the data collator
    pads each batch.

    Args:
        examples: batched mapping with at least ``'text'`` and ``'title'`` columns.

    Returns:
        The tokenizer's encoding of the text, with an added ``'labels'`` entry.
    """
    encoded = tokenizer(
        examples['text'],
        max_length=max_input_length,
        truncation=True,
    )

    # Targets go through the tokenizer's target mode.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(
            examples['title'],
            max_length=max_target_length,
            truncation=True,
        )['input_ids']

    encoded['labels'] = target_ids
    return encoded

In [13]:
# Tokenize the training subset in batches. The raw columns stay in the dataset;
# the Trainer ignores the ones the model's forward() does not accept.
tokenized_train_dataset = training_dataset.map(
    preprocess_function,
    batched=True,
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-0083b6d5c7c91f98.arrow


In [14]:
# Evaluate on the first `eval_size` validation examples. The config constant is
# used instead of a hard-coded 500. There is no shuffle, so the eval subset is
# identical across runs.
tokenized_eval_dataset = (
    dataset['validation']
    .select(range(eval_size))
    .map(preprocess_function, batched=True)
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-d9529bd184d76b26.arrow


# Training

In [15]:
# Load weights in the checkpoint's stored dtype (torch_dtype='auto') and move them to `device`.
if not distillation_type:
    model = AutoModelForSeq2SeqLM.from_pretrained(check_point, torch_dtype='auto').to(device)
else:
    # Distillation needs both networks: the student being trained and its teacher.
    student_model = AutoModelForSeq2SeqLM.from_pretrained(
        student_check_point, torch_dtype='auto'
    ).to(device)
    teacher_model = AutoModelForSeq2SeqLM.from_pretrained(
        teacher_check_point, torch_dtype='auto'
    ).to(device)



In [16]:
# Settings shared by plain fine-tuning and distillation runs. `seed` is passed
# explicitly so the config-cell seed also governs the Trainer (model init,
# dropout, data order), not only the training-subset shuffle.
shared_training_args = dict(
    evaluation_strategy='steps',   # eval_steps defaults to logging_steps when unset
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    save_total_limit=2,            # keep only the two most recent checkpoints
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,    # evaluate on generated sequences so ROUGE can be computed
    logging_steps=logging_steps,
    seed=seed,
    report_to='wandb',
)

if distillation_type:
    # NOTE(review): output_dir omits distillation_type and the epoch count, so
    # different distillation runs of the same student share a directory. In that
    # case save_total_limit can delete another run's checkpoints; the training
    # log shows checkpoint-30500 being removed. Add a per-run suffix if no
    # downstream code depends on this path.
    args = DistillationTrainingArguments(
        output_dir=f'{student_check_point}-distilled-{train_size}',
        alpha=alpha,
        temperature=temperature,
        **shared_training_args,
    )
else:
    args = Seq2SeqTrainingArguments(
        output_dir=f'{model_name}-finetuned-{train_size}',
        **shared_training_args,
    )

In [17]:
def compute_metrics(eval_pred):
    """Decode generated titles and reference titles and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays from the
            Trainer's evaluation loop (``predict_with_generate=True``).

    Returns:
        dict mapping each key returned by ``compute`` to its ``mid.fmeasure``,
        scaled to 0-100 and rounded to 4 decimals.
    """
    predictions, labels = eval_pred

    # -100 is the ignore index and is not a valid token id. It appears in the
    # labels (masked padding). It can also appear in the predictions, because the
    # evaluation loop pads generated batches of different lengths with -100 when
    # concatenating them. Map it to the pad token in both before decoding.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = compute(
        predictions=decoded_preds, references=decoded_labels, tokenizer=tokenizer
    )

    # `.mid` is presumably the median of an aggregated score, as in the
    # `datasets` rouge metric. TODO confirm for the local `rouge.compute`.
    return {key: round(value.mid.fmeasure * 100, 4) for key, value in result.items()}

In [18]:
# DataCollatorForSeq2Seq pads inputs and labels per batch. It gets the model
# being trained (the student when distilling) so it can also derive decoder
# inputs from the labels when the model supports that.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=student_model if distillation_type else model,
)

In [19]:
# Arguments every trainer variant receives.
trainer_kwargs = dict(
    args=args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

if not distillation_type:
    trainer = Seq2SeqTrainer(model=model, **trainer_kwargs)
else:
    # Map each distillation mode to its trainer. An unknown mode used to fall
    # through to Seq2SeqTrainer(model=model), but `model` is only loaded when
    # distillation is off, so that raised NameError or trained a stale model.
    distillation_trainers = {'distil': DistillationTrainer, 'tiny': TinyTrainer}
    if distillation_type not in distillation_trainers:
        raise ValueError(
            f"Unknown distillation_type {distillation_type!r}; expected one of "
            f"{sorted(distillation_trainers)} or a falsy value for plain fine-tuning."
        )
    trainer = distillation_trainers[distillation_type](
        model=student_model,
        teacher_model=teacher_model,
        **trainer_kwargs,
    )

In [20]:
# Run the full training schedule. With save_total_limit=2, only the two newest
# checkpoints in args.output_dir are kept.
train_result = trainer.train()
train_result

The following columns in the training set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: doc_id, text, title, file, doc_type.
***** Running training *****
  Num examples = 5000
  Num Epochs = 25
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 15625
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
500,18.2379,6.121234,5.6757,0.5561,5.2032,5.1845
1000,12.9761,5.468546,7.2584,0.7969,6.4861,6.4869
1500,11.2839,4.995747,8.3674,1.1262,7.5826,7.5403
2000,10.3155,4.619568,10.0879,1.8653,8.9028,8.8778
2500,9.5694,4.279552,12.9361,3.4349,11.5578,11.5425
3000,8.8365,3.957933,15.4092,5.0143,13.8356,13.8007
3500,8.2072,3.671987,19.2798,7.8962,17.4057,17.4129
4000,7.8477,3.42772,21.8236,9.8593,19.892,19.9132
4500,7.3809,3.23194,24.6379,11.9653,22.3859,22.3394
5000,7.0077,3.111772,26.233,14.0747,24.1015,24.1009


The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: doc_id, text, title, file, doc_type.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
Saving model checkpoint to encoder_decoder_pruned_last_3-distilled-5000/checkpoint-500
Configuration saved in encoder_decoder_pruned_last_3-distilled-5000/checkpoint-500/config.json
Model weights saved in encoder_decoder_pruned_last_3-distilled-5000/checkpoint-500/pytorch_model.bin
tokenizer config file saved in encoder_decoder_pruned_last_3-distilled-5000/checkpoint-500/tokenizer_config.json
Special tokens file saved in encoder_decoder_pruned_last_3-distilled-5000/checkpoint-500/special_tokens_map.json
Deleting older checkpoint [encoder_decoder_pruned_last_3-distilled-5000/checkpoint-30500] due to args.save_total_limit
The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forwar

KeyboardInterrupt: 

In [57]:
# Score the current trainer model on the tokenized evaluation subset.
# NOTE(review): the saved output below comes from a different kernel state
# (In [57], 'epoch': 10.0) than the 25-epoch run configured above. Re-run
# Restart & Run All to refresh it.
eval_metrics = trainer.evaluate()
eval_metrics

The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: title, file, doc_type, text, doc_id.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8


{'eval_loss': 3.3818862438201904,
 'eval_rouge1': 31.6557,
 'eval_rouge2': 18.708,
 'eval_rougeL': 28.8796,
 'eval_rougeLsum': 28.8763,
 'eval_runtime': 8.5031,
 'eval_samples_per_second': 58.802,
 'eval_steps_per_second': 7.409,
 'epoch': 10.0}