<a href="https://colab.research.google.com/github/hsong-77/transformer-practice/blob/main/model-compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the Hugging Face libraries that the cells below import.
!pip install transformers
!pip install datasets

In [None]:
from datasets import load_dataset

# Load the "plus" configuration of the "clinc_oos" dataset.
clinc = load_dataset("clinc_oos", "plus")
# Bare expression: shown as the cell output when run in a notebook.
clinc

In [None]:
# Pick the test-split example at index 42 for inspection.
sample = clinc["test"][42]
sample

In [None]:
# The "intent" feature of the test split; it converts between integer ids
# and label strings (int2str here, str2int in the benchmark class).
intents = clinc["test"].features["intent"]
# Show the label string for the sample's integer intent.
intents.int2str(sample["intent"])

In [None]:
from transformers import pipeline

# Checkpoint name passed to the text-classification pipeline.
ckpt = "transformersbook/bert-base-uncased-finetuned-clinc"
pipe = pipeline("text-classification", model=ckpt)

In [None]:
# Run the pipeline on one example query and display its prediction.
query = """Hey, I'd like to rent a vehicle from Nov 1st to Nov 15th in Paris and I need a 15 passenger van"""
pipe(query)

In [None]:
import torch
import numpy as np
from datasets import load_metric
from pathlib import Path
from time import perf_counter

# Accuracy metric object; its compute(predictions=..., references=...) method
# is called by the benchmark class and by compute_metrics.
# NOTE(review): `load_metric` appears to be deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package - confirm against the
# installed version.
accuracy_score = load_metric("accuracy")

class PerformanceBenchmark:
  """Measure model size, latency and accuracy of a classification pipeline.

  Args:
    pipeline: Callable invoked with a single text; the ``"label"`` entry of
      its first result is used as the prediction. Its ``model`` attribute
      is read when measuring the serialized size.
    dataset: Iterable of examples with ``"text"`` and ``"intent"`` keys.
    optim_type: Key under which ``run_benchmark`` stores the metrics.
  """

  def __init__(self, pipeline, dataset, optim_type="BERT baseline"):
    self.pipeline = pipeline
    self.dataset = dataset
    self.optim_type = optim_type

  def compute_accuracy(self):
    """Compute the pipeline's accuracy over ``self.dataset``.

    Uses the module-level ``intents`` object to turn predicted label
    strings into integer ids, and the module-level ``accuracy_score``
    metric to compare them with each example's ``"intent"``.

    Returns:
      The result of ``accuracy_score.compute``; its ``"accuracy"`` entry
      is also printed.
    """
    preds, labels = [], []
    for example in self.dataset:
      pred = self.pipeline(example["text"])[0]["label"]
      label = example["intent"]
      preds.append(intents.str2int(pred))
      labels.append(label)

    accuracy = accuracy_score.compute(predictions=preds, references=labels)
    print(f"Accuracy on test set - {accuracy['accuracy']:.3f}")
    return accuracy

  def compute_size(self):
    """Measure the on-disk size of the model's state dict.

    The state dict is saved to a scratch file ``model.pt`` in the current
    working directory, measured, and then removed.

    Returns:
      ``{"size_mb": <size in mebibytes>}``.
    """
    state_dict = self.pipeline.model.state_dict()
    tmp_path = Path("model.pt")
    try:
      torch.save(state_dict, tmp_path)
      size_mb = tmp_path.stat().st_size / (1024 * 1024)
    finally:
      # Remove the scratch file even when saving or stat-ing fails, so a
      # failed run does not leave model.pt behind.
      if tmp_path.exists():
        tmp_path.unlink()

    print(f"Model size (MB) - {size_mb:.2f}")
    return {"size_mb": size_mb}

  def time_pipeline(self, query):
    """Time repeated pipeline calls on ``query``.

    Runs 10 untimed warm-up calls followed by 100 timed calls.

    Returns:
      ``{"time_avg_ms": mean, "time_std_ms": standard deviation}`` of the
      timed calls, in milliseconds.
    """
    latencies = []
    # warm up
    for _ in range(10):
      _ = self.pipeline(query)
    # timed run
    for _ in range(100):
      start_time = perf_counter()
      _ = self.pipeline(query)
      latency = perf_counter() - start_time
      latencies.append(latency)

    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

  def run_benchmark(self, query):
    """Run the size, latency and accuracy measurements in that order.

    Args:
      query: Text passed to ``time_pipeline``.

    Returns:
      ``{self.optim_type: merged metrics}`` combining the three results.
    """
    metrics = {}
    metrics[self.optim_type] = self.compute_size()
    metrics[self.optim_type].update(self.time_pipeline(query))
    metrics[self.optim_type].update(self.compute_accuracy())

    return metrics

In [None]:
# Benchmark the pipeline on the test split. No optim_type is given, so the
# metrics are stored under the default key "BERT baseline".
pb = PerformanceBenchmark(pipe, clinc["test"])
perf_metrics = pb.run_benchmark("What is the pin number for my account?")

In [None]:
from transformers import TrainingArguments

class DistillationTrainingArguments(TrainingArguments):
  """``TrainingArguments`` that also carries ``alpha`` and ``temperature``.

  Args:
    alpha: Stored as ``self.alpha`` (default 0.5).
    temperature: Stored as ``self.temperature`` (default 2.0).
    *args, **kwargs: Forwarded unchanged to ``TrainingArguments``.
  """

  def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
    # Let the parent consume everything except the two extra settings.
    super().__init__(*args, **kwargs)
    self.temperature, self.alpha = temperature, alpha

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

class DistillationTrainer(Trainer):
  """``Trainer`` whose loss mixes the student's loss with a distillation term.

  Args:
    teacher_model: Model called on the same inputs as the student; its
      logits serve as the distillation target.
    *args, **kwargs: Forwarded unchanged to ``Trainer``.
  """

  def __init__(self, *args, teacher_model=None, **kwargs):
    super().__init__(*args, **kwargs)
    self.teacher_model = teacher_model

  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Return ``alpha * loss_ce + (1 - alpha) * loss_kd``.

    ``loss_ce`` is the loss reported by the student's own forward pass.
    ``loss_kd`` is the batch-mean KL divergence between the
    temperature-scaled student and teacher distributions, multiplied by
    ``temperature ** 2``. ``alpha`` and ``temperature`` are read from
    ``self.args``.

    Args:
      model: The student model; called as ``model(**inputs)``.
      inputs: Keyword arguments passed to both student and teacher.
      return_outputs: When true, return ``(loss, student_outputs)``.
      num_items_in_batch: Accepted only so the method matches the
        signature newer ``Trainer`` versions call it with; not used.
    """
    outputs_stu = model(**inputs)
    loss_ce = outputs_stu.loss
    logits_stu = outputs_stu.logits

    # The teacher only supplies targets, so no gradients are recorded.
    with torch.no_grad():
      outputs_tea = self.teacher_model(**inputs)
      logits_tea = outputs_tea.logits

    temperature = self.args.temperature
    loss_fct = nn.KLDivLoss(reduction="batchmean")
    # KLDivLoss takes log-probabilities as input and probabilities as target.
    loss_kd = temperature ** 2 * loss_fct(
        F.log_softmax(logits_stu / temperature, dim=-1),
        F.softmax(logits_tea / temperature, dim=-1))

    loss = self.args.alpha * loss_ce + (1 - self.args.alpha) * loss_kd
    return (loss, outputs_stu) if return_outputs else loss

In [None]:
from transformers import AutoTokenizer

# Checkpoint name whose tokenizer is loaded for tokenize_text below.
student_ckpt = "distilbert-base-uncased"
student_tokenizer = AutoTokenizer.from_pretrained(student_ckpt)

def tokenize_text(batch):
  """Tokenize ``batch["text"]`` with ``student_tokenizer``, truncation on."""
  texts = batch["text"]
  encoded = student_tokenizer(texts, truncation=True)
  return encoded

# Apply tokenize_text in batches and drop the raw "text" column.
clinc_enc = clinc.map(tokenize_text, batched=True, remove_columns=["text"])
# Rename "intent" to "labels" - presumably so the model receives the targets
# under the argument name it expects; confirm against the model's forward().
clinc_enc = clinc_enc.rename_column("intent", "labels")
clinc_enc

In [None]:
def compute_metrics(pred):
  """Score predictions with the module-level ``accuracy_score`` metric.

  ``pred`` is unpacked into two items: per-class scores and reference
  labels. The predicted class is the argmax over the last axis.
  """
  scores, references = pred
  predicted_ids = np.argmax(scores, axis=-1)
  return accuracy_score.compute(predictions=predicted_ids, references=references)