In [None]:
import os
# Enumerate GPUs by PCI bus id and expose only device 1 to this process.
# This runs before torch is imported, so CUDA sees these settings when it initialises.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="1")

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, logging
import torch
import torch.nn as nn

# Show INFO-level messages from transformers (model loading, training progress).
logging.set_verbosity(logging.INFO)

# Options

In [None]:
# Label category to train on. The Korean name is the CSV label column;
# the English name appears in the data file names.
#   유형 -> type, 극성 -> polarity, 시제 -> tense, 확실성 -> certainty
category = "확실성"
category_dict = {"유형": "type", "극성": "polarity", "시제": "tense", "확실성": "certainty"}
english_category = category_dict[category]
# Base checkpoint for the tokenizer and the student model.
pretrained_model_name_or_path = "kykim/electra-kor-base"

# Prepare Dataset

In [None]:
from datasets import load_dataset

# Train / validation CSVs for the chosen category; the validation file is
# exposed under the "test" split name.
csv_files = {
    "train": f"data/train_data_{english_category}.csv",
    "test": f"data/validation_data_{english_category}.csv",
}
ds = load_dataset("csv", data_files=csv_files)

In [None]:
from datasets import ClassLabel

# Build the label vocabulary from the training split.
# Sorting is essential: iteration order of a `set` of strings depends on
# PYTHONHASHSEED, so `list(set(...))` could assign different ids to the same
# label on every kernel restart, breaking reproducibility and the alignment
# between student classes and teacher logits in the distillation loss.
# NOTE(review): the teacher checkpoint must use this same id -> name mapping;
# confirm against its saved config before trusting the KL term.
names = sorted(set(ds["train"][category]))
num_labels = len(names)
cl = ClassLabel(num_classes=num_labels, names=names)
id2label = dict(enumerate(cl.names))

# Encode the string label column as integer class ids using `cl`.
ds = ds.cast_column(category, cl)

# Model

In [None]:
# Fine-tuned teacher checkpoint per label category, keyed by the Korean
# category name used in `category`.
ckpt_list = dict([
    ("극성", "results/극성/2022-12-21_034027/checkpoint-1239"),
    ("시제", "results/시제/2022-12-21_053512/checkpoint-826"),
    ("유형", "results/유형/2022-12-21_014535/checkpoint-1239"),
    ("확실성", "results/확실성/2022-12-21_073113/checkpoint-1239"),
])

In [None]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# Student: the base checkpoint with a new classification head sized to the
# label set. Pass `label2id` together with `id2label` so the two mappings in
# the saved config agree; with only `id2label` given, `label2id` may be left
# holding placeholder LABEL_i names from the base config.
student_model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path,
    num_labels=num_labels,
    id2label=id2label,
    label2id={label: idx for idx, label in id2label.items()},
)

In [None]:
# Teacher: the previously fine-tuned checkpoint for the selected category.
teacher_ckpt = ckpt_list[category]
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_ckpt)

# Preprocess

In [None]:
# Columns to drop after tokenization: everything except tokenizer outputs and
# the label column. A comprehension (rather than a set difference) keeps the
# dataset's column order, so the list shown below is identical on every run.
keep_columns = {"input_ids", "token_type_ids", "attention_mask", category}
remove_columns = [column for column in ds["train"].features if column not in keep_columns]
remove_columns

In [None]:
def tokenize_function(batch):
    """Tokenize the `문장` (sentence) column of a batch of examples.

    Sequences are truncated and padded to the tokenizer's max length, so
    every row has the same length.
    """
    return tokenizer(batch["문장"], truncation=True, padding="max_length")


# Tokenize both splits in batches and drop the columns listed in `remove_columns`.
ds = ds.map(tokenize_function, batched=True, remove_columns=remove_columns)

In [None]:
# Serve torch tensors and expose the label column under the name `labels`,
# the keyword the model's forward pass takes to compute its loss.
ds = ds.with_format("torch").rename_column(category, "labels")

# Metrics

In [None]:
class ConfiguredMetric:
    """Wrap a metric so fixed extra arguments are supplied to every ``compute``.

    This lets a metric that needs an option (e.g. ``average='weighted'``) be
    used where the caller only passes predictions/references. The other
    methods and the ``name`` property forward to the wrapped metric.
    """

    def __init__(self, metric, *metric_args, **metric_kwargs):
        self.metric = metric
        # Positional args appended after the caller's own positional args.
        self.metric_args = metric_args
        # Default keyword args; keywords passed to `compute` override them.
        self.metric_kwargs = metric_kwargs

    def add(self, *args, **kwargs):
        """Forward to the wrapped metric's ``add``."""
        return self.metric.add(*args, **kwargs)

    def add_batch(self, *args, **kwargs):
        """Forward to the wrapped metric's ``add_batch``."""
        return self.metric.add_batch(*args, **kwargs)

    def compute(self, *args, **kwargs):
        """Call the wrapped metric's ``compute`` with the bound arguments added.

        Keyword arguments given here take precedence over those bound at
        construction, instead of raising ``TypeError`` for a duplicate keyword.
        """
        merged_kwargs = {**self.metric_kwargs, **kwargs}
        return self.metric.compute(*args, *self.metric_args, **merged_kwargs)

    @property
    def name(self):
        """Name of the wrapped metric."""
        return self.metric.name

    def _feature_names(self):
        # Mirrors the wrapped metric's private hook so this wrapper can stand
        # in for it inside `evaluate.combine`.
        return self.metric._feature_names()

In [None]:
import evaluate
import numpy as np

# Accuracy plus F1 / precision / recall, each of the latter three computed
# with average='weighted' via ConfiguredMetric.
metrics = evaluate.combine(
    [evaluate.load('accuracy')]
    + [
        ConfiguredMetric(evaluate.load(metric_name), average='weighted')
        for metric_name in ('f1', 'precision', 'recall')
    ]
)

def compute_metrics(eval_pred):
    """Score evaluation predictions with the combined `metrics`.

    `eval_pred` unpacks into (logits, labels); the predicted class is the
    arg-max over the last axis of the logits.
    """
    scores, gold_labels = eval_pred
    predicted_classes = np.argmax(scores, axis=-1)
    return metrics.compute(references=gold_labels, predictions=predicted_classes)

# Trainer

In [None]:
from datetime import datetime
# Timestamp used as the run name and as the output sub-directory.
now = datetime.now()
name = f"{now:%Y-%m-%d_%H%M%S}"
name

In [None]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass


@dataclass
class DistillationTrainingArguments(TrainingArguments):
    """`TrainingArguments` plus the knowledge-distillation hyper-parameters.

    `alpha` and `temperature` are declared as dataclass fields, the same way
    transformers' own argument subclasses add options, instead of being
    assigned in a hand-written ``__init__``. That makes them part of
    ``dataclasses.fields(self)``, which ``to_dict()`` is built from, so they are
    recorded with the rest of the run configuration instead of being dropped.
    Construction by keyword (``alpha=..., temperature=...``) is unchanged.

    Attributes:
        alpha: weight of the student's hard-label loss; ``1 - alpha`` weights
            the distillation (KL) term. Must lie in [0, 1].
        temperature: softmax temperature applied to student and teacher
            logits. Must be > 0 because the logits are divided by it.
    """

    alpha: float = 0.5
    temperature: float = 2.0

    def __post_init__(self):
        # Let TrainingArguments run its own post-init processing first.
        super().__post_init__()
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha}")
        if self.temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {self.temperature}")

In [None]:
import wandb

# Start the W&B run: named by the timestamp, tagged with the label category
# and the base checkpoint.
wandb.init(
    project="huggingface",
    name=name,
    tags=[category, "teacher-student", pretrained_model_name_or_path],
)

In [None]:
# Run configuration. alpha/temperature are not passed, so the
# DistillationTrainingArguments defaults apply.
training_args = DistillationTrainingArguments(
    run_name=name,
    output_dir=f'./results/{category}/{name}',  # results/<category>/<timestamp>
    logging_dir='./logs',
    logging_steps=50,
    # Optimisation
    learning_rate=1e-4,
    num_train_epochs=10,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,  # effective per-device train batch: 16 * 2
    warmup_steps=500,
    weight_decay=0.01,
    fp16=True,
    # Evaluate and checkpoint once per epoch; restore the best checkpoint by F1.
    do_eval=True,
    per_device_eval_batch_size=64,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

In [None]:
class DistillationTrainer(Trainer):
    """Trainer that distils a frozen teacher model into the student ``model``.

    The loss is::

        alpha * student_loss + (1 - alpha) * T**2 * KL(teacher_T || student_T)

    where ``teacher_T`` / ``student_T`` are the softmax distributions of the
    logits divided by the temperature ``T``. ``alpha`` and ``T`` are read from
    ``self.args.alpha`` and ``self.args.temperature``.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        """Same arguments as ``Trainer``, plus the teacher.

        Args:
            teacher_model: trained model whose softened logits the student
                imitates. Required despite the ``None`` default (kept for
                keyword-only compatibility).

        Raises:
            ValueError: if ``teacher_model`` is not given.
        """
        if teacher_model is None:
            raise ValueError("DistillationTrainer requires a `teacher_model`.")
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # Put the teacher on the student's device. It is only used for
        # inference, so keep it in eval mode (Trainer only toggles self.model).
        self.teacher.to(self.model.device)
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Weighted sum of the student's own loss and the distillation loss.

        Args:
            model: the student model being trained.
            inputs: batch dict, including ``labels`` for the student loss.
            return_outputs: if true, also return the student's outputs.
            num_items_in_batch: accepted for compatibility with Trainer
                versions that pass it; unused here.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` if ``return_outputs``.

        Raises:
            ValueError: if student and teacher logits differ in shape.
        """
        outputs_student = model(**inputs)
        student_loss = outputs_student["loss"] if isinstance(outputs_student, dict) else outputs_student[0]

        # The teacher only provides soft targets: drop `labels` so it does not
        # compute a loss that would be thrown away.
        teacher_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        with torch.no_grad():
            outputs_teacher = self.teacher(**teacher_inputs)

        student_logits = outputs_student["logits"]
        teacher_logits = outputs_teacher["logits"]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if student_logits.size() != teacher_logits.size():
            raise ValueError(
                f"Student logits {tuple(student_logits.size())} and teacher logits "
                f"{tuple(teacher_logits.size())} differ in shape."
            )

        # Soften both distributions with the temperature; scaling by T**2 keeps
        # the KL term's gradient magnitude comparable to the hard-label loss.
        temperature = self.args.temperature
        distill_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)

        loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * distill_loss
        return (loss, outputs_student) if return_outputs else loss

In [None]:
# Wire the student, the frozen teacher, both splits and the metric function
# into the distillation trainer.
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    teacher_model=teacher_model,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    compute_metrics=compute_metrics,
)

In [None]:
# Run distillation training and display what trainer.train() returns.
train_output = trainer.train()
train_output