In [1]:
import os
import datasets
import numpy as np
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from knowledge_distillation import DistillationTrainingArguments, DistillationTrainer
from huggingface_hub import notebook_login
from dotenv import load_dotenv
load_dotenv(verbose=True)
import torch
import wandb
import sys
sys.path.append('/opt/ml/final-project-level3-nlp-02')
from rouge import compute

# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Start a Weights & Biases run under the "final_project" entity.
wandb.init(entity="final_project")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mabbymark[0m (use `wandb login --relogin` to force relogin)


In [16]:
# ---------------------------------------------------------------------------
# Experiment configuration: every tunable value of the notebook lives here.
# ---------------------------------------------------------------------------

# Reproducibility
seed = 42

# Dataset subset sizes
train_size = 5000
eval_size = 500

# Base checkpoint and sequence-length limits (in tokens)
check_point = "gogamza/kobart-summarization"
max_input_length = 512
max_target_length = 30

# Optimisation hyper-parameters
batch_size = 8
num_train_epochs = 5
learning_rate = 5.6e-5
weight_decay = 0.01
logging_steps = 100

# Last path component of the checkpoint id, used to name the output directory.
model_name = check_point.rsplit("/", 1)[-1]

# Knowledge-distillation switch; the settings below exist only when it is on.
is_distillation = True
if is_distillation:
    student_check_point = "encoder_decoder_pruned_last_3"
    teacher_check_point = "kobart-summarization-finetuned-paper-sample-size-1000/checkpoint-1000"
    alpha = 0.5
    temperature = 2.0


# Loading Dataset

In [3]:
# The Hub token is read from the environment (a .env file is loaded in the
# import cell); it is None when the variable is not set.
api_token = os.environ.get('HF_DATASET_API_TOKEN')

# Load every split of the paper-summarization dataset.
dataset = datasets.load_dataset(
    'metamong1/summarization_paper',
    use_auth_token=api_token,
)

Reusing dataset paper_summarization (/opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f)


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Shuffle the train split with the configured seed, then keep the first
# `train_size` rows as the training subset.
training_dataset = (
    dataset['train']
    .shuffle(seed=seed)
    .select(range(train_size))
)

Loading cached shuffled indices for dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-93a7be41ce043bd9.arrow


# Prepare Training

In [5]:
# The tokenizer is loaded from the student checkpoint when distilling and
# from the base checkpoint otherwise.
tokenizer = AutoTokenizer.from_pretrained(
    student_check_point if is_distillation else check_point
)

In [6]:
def preprocess_function(examples):
    """Tokenize a batch of examples for sequence-to-sequence training.

    Args:
        examples: Batch mapping with a 'text' entry (model inputs) and a
            'title' entry (targets).

    Returns:
        The tokenizer output for 'text', truncated to `max_input_length`,
        with an added 'labels' entry holding the input ids of 'title'
        truncated to `max_target_length`. No padding is applied here.
    """
    encoded = tokenizer(
        examples['text'],
        max_length=max_input_length,
        truncation=True,
    )

    # Targets are tokenized inside the tokenizer's target-mode context.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(
            examples['title'],
            max_length=max_target_length,
            truncation=True,
        )

    encoded['labels'] = target_encoding['input_ids']
    return encoded

In [7]:
# Tokenize the training subset in batches.
tokenized_train_dataset = training_dataset.map(
    preprocess_function,
    batched=True,
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-0083b6d5c7c91f98.arrow


In [8]:
# Tokenize the first `eval_size` validation examples. The size comes from the
# configuration cell instead of a hard-coded literal, so changing `eval_size`
# there actually changes the evaluation set.
tokenized_eval_dataset = (
    dataset['validation']
    .select(range(eval_size))
    .map(preprocess_function, batched=True)
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-d9529bd184d76b26.arrow


# Training

In [9]:
# Load the model(s) for the selected mode and move them to `device`.
if not is_distillation:
    # Plain fine-tuning: a single model from the base checkpoint.
    model = AutoModelForSeq2SeqLM.from_pretrained(check_point)
    model = model.to(device)
else:
    # Distillation: a student to train and a teacher handed to the trainer.
    student_model = AutoModelForSeq2SeqLM.from_pretrained(student_check_point)
    student_model = student_model.to(device)

    teacher_model = AutoModelForSeq2SeqLM.from_pretrained(teacher_check_point)
    teacher_model = teacher_model.to(device)



In [17]:
# Arguments shared by both training modes. Keeping them in one place stops the
# fine-tuning and distillation configurations from drifting apart.
common_training_kwargs = dict(
    evaluation_strategy='steps',
    # Previously implicit: the library initialised `eval_steps` from
    # `logging_steps`. Stated explicitly so the evaluation cadence is visible.
    eval_steps=logging_steps,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    save_total_limit=2,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
    # Pass the notebook-level seed on, so that one constant governs both the
    # data shuffling and the trainer.
    seed=seed,
    report_to='wandb',
)

if is_distillation:
    # Distillation adds `alpha` and `temperature` on top of the shared
    # arguments. NOTE(review): their exact meaning is defined by the
    # project's DistillationTrainingArguments; confirm there.
    args = DistillationTrainingArguments(
        output_dir=f'{student_check_point}-distilled-{train_size}',
        alpha=alpha,
        temperature=temperature,
        **common_training_kwargs,
    )
else:
    args = Seq2SeqTrainingArguments(
        output_dir=f'{model_name}-finetuned-{train_size}',
        **common_training_kwargs,
    )

using `logging_steps` to initialize `eval_steps` to 100
PyTorch: setting up devices


In [18]:
def compute_metrics(eval_pred):
    """Decode generated and reference summaries and score them with ROUGE.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays.
            ``predictions`` may itself be a tuple, in which case its first
            element is taken as the generated token ids.

    Returns:
        Dict mapping each metric name returned by ``compute`` to the
        F-measure of its ``mid`` aggregate, scaled to 0-100 and rounded to
        four decimals.
    """
    predictions, labels = eval_pred

    # Some model/trainer configurations return extra outputs next to the
    # generated ids; only the first element holds the tokens.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is a padding marker, not a vocabulary id, so it cannot be decoded.
    # Replace it with the pad token in predictions and labels alike.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode both sides into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores with the project's `compute` helper
    scores = compute(
        predictions=decoded_preds, references=decoded_labels, tokenizer=tokenizer
    )

    # Keep the mid F-measure of every metric, as a percentage
    return {
        name: round(aggregate.mid.fmeasure * 100, 4)
        for name, aggregate in scores.items()
    }

In [19]:
# The collator receives whichever model is going to be trained: the student
# when distilling, the base model otherwise.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=student_model if is_distillation else model,
)

In [20]:
# Build the trainer for the selected mode.
if not is_distillation:
    # Plain fine-tuning, with ROUGE computed at evaluation time.
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_eval_dataset,
        compute_metrics=compute_metrics,
    )
else:
    # Distillation: the student is trained and the teacher is passed along.
    # NOTE(review): `compute_metrics` is not passed in this branch (it was
    # commented out); confirm whether that is intentional.
    trainer = DistillationTrainer(
        model=student_model,
        teacher_model=teacher_model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_eval_dataset,
    )

In [21]:
# Run training with the trainer built above. `args` uses the 'steps'
# evaluation strategy, so evaluation runs periodically during training.
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, file, doc_type, doc_id, title.
***** Running training *****
  Num examples = 5000
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 3125
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss
100,127.4428,108.516457
200,119.4237,104.567047
300,117.9695,101.266464
400,115.666,98.521538
500,111.3656,95.728188
600,106.8622,93.204033
700,103.8432,90.837326
800,102.9255,88.669579
900,98.6589,86.641563
1000,95.592,84.876205


The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, file, doc_type, doc_id, title.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, file, doc_type, doc_id, title.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, file, doc_type, doc_id, title.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, file, doc_type, doc_id, title.
***** Running Evaluation *****
  Num examples = 500
  Batch s

In [15]:
# Evaluate the trainer on its evaluation dataset and display the metrics dict.
# NOTE(review): this cell's execution count (15) is lower than the training
# cell's (21), so the saved output appears to come from an earlier run.
# Re-run after training to refresh it.
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, file, doc_type, doc_id, title.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8


{'eval_loss': 112.42849731445312,
 'eval_runtime': 4.9179,
 'eval_samples_per_second': 101.669,
 'eval_steps_per_second': 12.81,
 'epoch': 2.0}