In [1]:
import os
import datasets
import numpy as np
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from knowledge_distillation import DistillationTrainingArguments, DistillationTrainer
from huggingface_hub import notebook_login
from dotenv import load_dotenv
load_dotenv(verbose=True)
import torch
import wandb
import sys
# `rouge` is imported immediately after this path tweak and is later called
# with a `tokenizer=` argument, so it is presumably a module inside the project
# checkout rather than an installed package -- TODO confirm.
# FINAL_PROJECT_ROOT (from the environment or the .env file loaded earlier in
# this cell) overrides the checkout location; the original absolute path is
# kept as the default so existing setups behave exactly as before.
project_root = os.getenv('FINAL_PROJECT_ROOT', '/opt/ml/final-project-level3-nlp-02')
if project_root not in sys.path:  # re-running the cell must not stack duplicate entries
    sys.path.append(project_root)
from rouge import compute

# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Start the Weights & Biases run under the "final_project" entity.
wandb.init(entity="final_project")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mabbymark[0m (use `wandb login --relogin` to force relogin)


In [41]:
# ---- Reproducibility ----
seed = 42

# ---- Dataset subset sizes ----
train_size = 5000
eval_size = 500

# ---- Base checkpoint and sequence lengths ----
# Alternative checkpoint noted by the author: "gogamza/kobart-summarization"
check_point = 'encoder_decoder_pruned_last_5'
# Last path component of the checkpoint; used to build the fine-tuning output_dir.
model_name = check_point.rsplit("/", 1)[-1]
max_input_length = 512   # truncation length applied to the input text
max_target_length = 30   # truncation length applied to the target title

# ---- Optimisation hyperparameters ----
batch_size = 8
num_train_epochs = 20
learning_rate = 5.6e-5
weight_decay = 0.01
logging_steps = 100

# ---- Knowledge distillation ----
# The student/teacher settings below only exist when distillation is enabled.
is_distillation = True
if is_distillation:
    student_check_point = "encoder_decoder_pruned_last_5"
    teacher_check_point = "kobart-summarization-finetuned-5000/checkpoint-1000"
    alpha = 0.5
    temperature = 30


# Loading Dataset

In [42]:
# Access token for the Hugging Face dataset, read from the environment
# (None when the variable is not set).
api_token = os.environ.get('HF_DATASET_API_TOKEN')
dataset = datasets.load_dataset(
    'metamong1/summarization_paper',
    use_auth_token=api_token,
)

Reusing dataset paper_summarization (/opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f)


  0%|          | 0/2 [00:00<?, ?it/s]

In [43]:
# Shuffle the train split with the configured seed, then keep the first
# `train_size` rows of the shuffled result.
training_dataset = (
    dataset['train']
    .shuffle(seed=seed)
    .select(range(train_size))
)

Loading cached shuffled indices for dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-93a7be41ce043bd9.arrow


# Prepare Training

In [44]:
# When distilling, the tokenizer comes from the student checkpoint; otherwise
# it comes from `check_point`. The conditional expression only evaluates the
# selected name, so `student_check_point` is not needed outside distillation.
tokenizer = AutoTokenizer.from_pretrained(
    student_check_point if is_distillation else check_point
)

Didn't find file encoder_decoder_pruned_last_5/added_tokens.json. We won't load it.
loading file encoder_decoder_pruned_last_5/vocab.json
loading file encoder_decoder_pruned_last_5/merges.txt
loading file encoder_decoder_pruned_last_5/tokenizer.json
loading file None
loading file encoder_decoder_pruned_last_5/special_tokens_map.json
loading file encoder_decoder_pruned_last_5/tokenizer_config.json


In [45]:
def preprocess_function(examples):
    """Tokenize a batch of examples into model inputs with label ids.

    Args:
        examples: mapping with a 'text' entry (model input) and a 'title'
            entry (target), as indexed below.

    Returns:
        The tokenizer output for 'text', with an added 'labels' entry holding
        the token ids of the tokenized 'title'.
    """
    # No padding is applied here; sequences are only truncated.
    encoded = tokenizer(
        examples['text'],
        truncation=True,
        max_length=max_input_length,
    )

    # Titles are tokenized inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(
            examples['title'],
            truncation=True,
            max_length=max_target_length,
        )['input_ids']

    encoded['labels'] = target_ids
    return encoded

In [46]:
# Tokenize the training subset, passing examples to preprocess_function in batches.
tokenized_train_dataset = training_dataset.map(
    function=preprocess_function,
    batched=True,
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-7e2fedc45ccdd79e.arrow


In [47]:
# Tokenize the first `eval_size` validation examples. The subset size comes
# from the config cell instead of a hard-coded 500, so changing `eval_size`
# actually changes the evaluation set.
tokenized_eval_dataset = (
    dataset['validation']
    .select(range(eval_size))
    .map(preprocess_function, batched=True)
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-6ade4e4681b6569a.arrow


# Training

In [48]:
# Load the model(s) for the selected mode and move them onto `device`.
if not is_distillation:
    model = AutoModelForSeq2SeqLM.from_pretrained(check_point, torch_dtype='auto')
    model = model.to(device)
else:
    # Distillation needs both the student (trained) and the teacher model.
    student_model = AutoModelForSeq2SeqLM.from_pretrained(student_check_point, torch_dtype='auto')
    student_model = student_model.to(device)
    teacher_model = AutoModelForSeq2SeqLM.from_pretrained(teacher_check_point, torch_dtype='auto')
    teacher_model = teacher_model.to(device)

loading configuration file encoder_decoder_pruned_last_5/config.json
Model config BartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "author": "Heewon Jeon(madjakarta@gmail.com)",
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.1,
  "d_model": 768,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 1,
  "decoder_start_token_id": 2,
  "do_blenderbot_90_layernorm": false,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 1,
  "eos_token_id": 1,
  "extra_pos_embeddings": 2,
  "force_bos_token_to_be_generated": false,
  "forced_eos_token_id": 2,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "lab

In [49]:
# Training arguments shared by both modes. Keeping them in one dict prevents
# the distillation and fine-tuning branches from drifting apart.
common_training_kwargs = dict(
    # No eval_steps is given; the recorded output shows it is initialised
    # from logging_steps.
    evaluation_strategy='steps',  # 'epoch',
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    save_total_limit=2,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
    # Forward the configured seed so it drives training as well as the
    # dataset shuffle; previously the trainer silently used its own default.
    seed=seed,
    report_to='wandb',
    # push_to_hub=True,
)

if is_distillation:
    # alpha and temperature are the distillation-specific settings from the config cell.
    args = DistillationTrainingArguments(
        output_dir=f'{student_check_point}-distilled-{train_size}',
        alpha=alpha,
        temperature=temperature,
        **common_training_kwargs,
    )
else:
    args = Seq2SeqTrainingArguments(
        output_dir=f'{model_name}-finetuned-{train_size}',
        **common_training_kwargs,
    )

using `logging_steps` to initialize `eval_steps` to 100
PyTorch: setting up devices


In [50]:
def compute_metrics(eval_pred):
    """Compute ROUGE F-measures (as percentages) for generated summaries.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may also be a tuple, in which case its first
            element is taken as the token ids.

    Returns:
        dict mapping every key returned by ``compute`` to its
        ``mid.fmeasure`` multiplied by 100 and rounded to four decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id

    # -100 cannot be decoded, so swap it for the pad token id. This is done
    # for predictions as well as labels, since predictions can also carry
    # -100 padding.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    # Decode generated and reference summaries into text.
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores with the project's `compute` helper.
    result = compute(
        predictions=decoded_preds, references=decoded_labels, tokenizer=tokenizer
    )

    # Keep the `mid` F-measure of each score, expressed as a rounded percentage.
    return {key: round(value.mid.fmeasure * 100, 4) for key, value in result.items()}

In [51]:
# The collator is bound to whichever model is being trained: the student when
# distilling, otherwise the plain fine-tuning model. Only the selected name is
# evaluated, so the other one does not need to exist.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=student_model if is_distillation else model,
)

In [52]:
# Keyword arguments that both trainer types receive unchanged.
shared_trainer_kwargs = dict(
    args=args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

if is_distillation:
    trainer = DistillationTrainer(
        model=student_model,
        teacher_model=teacher_model,
        # compute_metrics=compute_metrics,  (left disabled for the distillation trainer)
        **shared_trainer_kwargs,
    )
else:
    trainer = Seq2SeqTrainer(
        model=model,
        compute_metrics=compute_metrics,
        **shared_trainer_kwargs,
    )

In [53]:
# Run training. As the last expression of the cell, the call's return value
# (a TrainOutput in the recorded run) is displayed as the cell output.
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: doc_id, text, doc_type, file, title.
***** Running training *****
  Num examples = 5000
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 12500
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss
100,27.4586,21.147156
200,20.4465,18.38345
300,18.5931,17.028332
400,17.1957,16.156927
500,16.3961,15.512665
600,15.7487,15.080248
700,15.2661,14.60751
800,15.0126,14.232364
900,14.7253,13.945302
1000,14.1601,13.730156


The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: doc_id, text, doc_type, file, title.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: doc_id, text, doc_type, file, title.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: doc_id, text, doc_type, file, title.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: doc_id, text, doc_type, file, title.
***** Running Evaluation *****
  Num examples = 500
  Batch s

TrainOutput(global_step=12500, training_loss=10.382124545898437, metrics={'train_runtime': 1722.2156, 'train_samples_per_second': 58.065, 'train_steps_per_second': 7.258, 'total_flos': 4752394983751680.0, 'train_loss': 10.382124545898437, 'epoch': 20.0})

In [57]:
# Evaluate the trainer on its evaluation set (500 examples in the recorded
# run); the returned metrics dict is displayed as the cell output.
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: title, file, doc_type, text, doc_id.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8


{'eval_loss': 3.3818862438201904,
 'eval_rouge1': 31.6557,
 'eval_rouge2': 18.708,
 'eval_rougeL': 28.8796,
 'eval_rougeLsum': 28.8763,
 'eval_runtime': 8.5031,
 'eval_samples_per_second': 58.802,
 'eval_steps_per_second': 7.409,
 'epoch': 10.0}