In [1]:
import os
import datasets
import numpy as np
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from knowledge_distillation import DistillationTrainingArguments, DistillationTrainer
from huggingface_hub import notebook_login
from dotenv import load_dotenv
load_dotenv(verbose=True)
import torch
import wandb
import sys
sys.path.append('/opt/ml/final-project-level3-nlp-02')
from rouge import compute

# Run on the first CUDA device when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Start a Weights & Biases run under the "final_project" entity.
# NOTE(review): no `project=` is passed here — confirm which W&B project runs should land in.
wandb.init(entity="final_project")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mabbymark[0m (use `wandb login --relogin` to force relogin)


In [28]:
# --- Reproducibility and data sizes ---
seed = 42                # shuffle seed for the training split
train_size = 5000        # training examples kept after shuffling
eval_size = 500          # size of the validation subset

# --- Base model and tokenization limits ---
check_point = "gogamza/kobart-summarization"
model_name = check_point.rsplit("/", 1)[-1]  # repo name without the owner, e.g. "kobart-summarization"
max_input_length = 512   # max tokens kept from each example's 'text'
max_target_length = 30   # max tokens kept from each target 'title'

# --- Optimisation ---
batch_size = 8
num_train_epochs = 2
learning_rate = 5.6e-5
weight_decay = 0.01
logging_steps = 100

# --- Optional knowledge distillation (these names exist only when enabled) ---
is_distillation = False
if is_distillation:
    student_check_point = "encoder_decoder_pruned_last_3"
    teacher_check_point = "kobart-summarization-finetuned-paper-sample-size-1000/checkpoint-1000"
    # Passed through to DistillationTrainingArguments; presumably the loss-mixing weight
    # and softmax temperature — confirm against knowledge_distillation.py.
    alpha = 0.5
    temperature = 2.0


# Loading Dataset

In [29]:
# Hugging Face token read from the environment (possibly populated from a .env file via load_dotenv).
# If the variable is unset this is None and the dataset is requested without a token.
api_token = os.environ.get('HF_DATASET_API_TOKEN')
dataset = datasets.load_dataset(
    'metamong1/summarization_paper',
    use_auth_token=api_token,
)

Reusing dataset paper_summarization (/opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f)


  0%|          | 0/2 [00:00<?, ?it/s]

In [30]:
# Deterministic training subset: shuffle the full train split with `seed`,
# then keep the first `train_size` rows.
training_dataset = (
    dataset['train']
    .shuffle(seed=seed)
    .select(range(train_size))
)

Loading cached shuffled indices for dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-93a7be41ce043bd9.arrow


# Prepare Training

In [31]:
# Use the student's tokenizer when distilling, otherwise the base checkpoint's.
# Only the selected branch of the conditional is evaluated, so `student_check_point`
# need not exist when distillation is off.
tokenizer = AutoTokenizer.from_pretrained(
    student_check_point if is_distillation else check_point
)

Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/gogamza/kobart-summarization/resolve/main/config.json from cache at /opt/ml/.cache/huggingface/transformers/1c32baaf6a1067a5e27a0dfbac0a3d23a86d958ab10b092d5ea4150bd451de17.4e52ef6c87e6938c92ba0d19888607d76e30e950e81060a8fa6cb1189c93614d
Model config BartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "author": "Heewon Jeon(madjakarta@gmail.com)",
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.1,
  "d_model": 768,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "do_blenderbot_90_layernorm": false,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "enc

In [32]:
def preprocess_function(examples):
    """Tokenize a batch of examples for seq2seq training.

    Args:
        examples: a batched mapping (as passed by ``Dataset.map(..., batched=True)``)
            with a ``'text'`` column used as the model input and a ``'title'``
            column used as the target.

    Returns:
        The tokenizer output for ``'text'`` (truncated to ``max_input_length``)
        with an added ``'labels'`` entry holding the ``input_ids`` of the
        tokenized ``'title'`` (truncated to ``max_target_length``).
    """
    # Only truncate here; no padding is requested at tokenization time.
    encoded = tokenizer(examples['text'], max_length=max_input_length, truncation=True)

    # Targets are tokenized inside the tokenizer's target-mode context.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(
            examples['title'], max_length=max_target_length, truncation=True
        )['input_ids']

    encoded['labels'] = target_ids
    return encoded

In [33]:
# Tokenize the training subset in batches with preprocess_function.
tokenized_train_dataset = training_dataset.map(
    function=preprocess_function,
    batched=True,
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-53d2d619f903a834.arrow


In [34]:
# Evaluate on the first `eval_size` validation examples. The split is not shuffled,
# so the subset is identical across runs. The size comes from the config cell rather
# than a literal, and is clamped to the split length so a smaller validation split
# does not raise inside `select`.
tokenized_eval_dataset = (
    dataset['validation']
    .select(range(min(eval_size, len(dataset['validation']))))
    .map(preprocess_function, batched=True)
)

Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/metamong1___paper_summarization/Paper Summarization/1.4.0/24bb09528ebb04fdc6aafb6e110202e52fbb818c0f204839bc833d8ce1e86a5f/cache-4917a92c5c9016f3.arrow


# Training

In [35]:
# Load the seq2seq model(s) and move them to `device`.
if not is_distillation:
    # Plain fine-tuning of the base checkpoint.
    model = AutoModelForSeq2SeqLM.from_pretrained(check_point).to(device)
else:
    # Distillation: the student is the model handed to the trainer; the teacher is
    # passed to DistillationTrainer separately.
    student_model = AutoModelForSeq2SeqLM.from_pretrained(student_check_point).to(device)
    teacher_model = AutoModelForSeq2SeqLM.from_pretrained(teacher_check_point).to(device)

loading configuration file https://huggingface.co/gogamza/kobart-summarization/resolve/main/config.json from cache at /opt/ml/.cache/huggingface/transformers/1c32baaf6a1067a5e27a0dfbac0a3d23a86d958ab10b092d5ea4150bd451de17.4e52ef6c87e6938c92ba0d19888607d76e30e950e81060a8fa6cb1189c93614d
Model config BartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "author": "Heewon Jeon(madjakarta@gmail.com)",
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.1,
  "d_model": 768,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "do_blenderbot_90_layernorm": false,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 1,


In [36]:
# One constructor call covers both modes. `is_distillation` picks the arguments class,
# the output directory and the distillation-only settings (alpha, temperature).
# Every other setting is shared, so the two runs cannot drift apart.
args = (DistillationTrainingArguments if is_distillation else Seq2SeqTrainingArguments)(
    output_dir=(
        f'{student_check_point}-distilled-{train_size}'
        if is_distillation
        else f'{model_name}-finetuned-{train_size}'
    ),
    evaluation_strategy='steps',  # eval_steps is initialized from logging_steps
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    save_total_limit=2,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
    report_to='wandb',
    # push_to_hub=True,
    **({'alpha': alpha, 'temperature': temperature} if is_distillation else {}),
)

using `logging_steps` to initialize `eval_steps` to 100
PyTorch: setting up devices


In [37]:
def compute_metrics(eval_pred):
    """Decode generated and reference titles and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair handed over by the trainer. With
            ``predict_with_generate=True`` the predictions are generated token ids.
            If the predictions arrive wrapped in a tuple, its first element is used.

    Returns:
        Dict mapping each key returned by ``compute`` to its mid F-measure,
        scaled to 0-100 and rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is the ignore/padding index: it fills label positions and can also appear
    # in generated predictions when the trainer pads them across batches. The tokenizer
    # cannot decode it, so map it to the pad token, which skip_special_tokens drops.
    predictions = np.asarray(predictions)
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.asarray(labels)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = compute(
        predictions=decoded_preds, references=decoded_labels, tokenizer=tokenizer
    )

    # Each value is presumably a rouge AggregateScore (low/mid/high); report the mid F-measure.
    return {key: round(value.mid.fmeasure * 100, 4) for key, value in result.items()}

In [38]:
# Give the collator the model that is actually trained (the student when distilling).
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=student_model if is_distillation else model,
)

In [39]:
# A single constructor call for both modes. `is_distillation` picks the trainer class
# and the model, and only DistillationTrainer receives `teacher_model`.
# Everything else is shared.
trainer = (DistillationTrainer if is_distillation else Seq2SeqTrainer)(
    model=student_model if is_distillation else model,
    args=args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    **({'teacher_model': teacher_model} if is_distillation else {}),
)

In [40]:
# Run training with the trainer configured above. Per `args`, evaluation runs every
# eval_steps steps, and eval_steps was initialized from logging_steps.
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: title, file, text, doc_id, doc_type.
***** Running training *****
  Num examples = 5000
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 1250
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
100,2.427,2.219417,38.737,25.6732,34.9461,34.9241


The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: title, file, text, doc_id, doc_type.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8


KeyboardInterrupt: 

Exception in thread Thread-7:
Traceback (most recent call last):
  File "/opt/conda/envs/final/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/opt/conda/envs/final/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/final/lib/python3.8/site-packages/wandb/sdk/wandb_run.py", line 149, in check_network_status
    status_response = self._interface.communicate_network_status()
  File "/opt/conda/envs/final/lib/python3.8/site-packages/wandb/sdk/interface/interface.py", line 120, in communicate_network_status
    resp = self._communicate_network_status(status)
  File "/opt/conda/envs/final/lib/python3.8/site-packages/wandb/sdk/interface/interface_queue.py", line 411, in _communicate_network_status
    resp = self._communicate(req, local=True)
  File "/opt/conda/envs/final/lib/python3.8/site-packages/wandb/sdk/interface/interface_queue.py", line 232, in _communicate
    return self._communicate_