<a href="https://colab.research.google.com/github/hsong-77/transformer-practice/blob/main/model-compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install datasets
!pip install optuna

In [None]:
from datasets import load_dataset

# Load the `clinc_oos` intent dataset ("plus" configuration). Later cells use
# its "train", "validation" and "test" splits; the bare `clinc` displays it.
clinc = load_dataset("clinc_oos", "plus")
clinc

In [None]:
# Inspect a single test example; later code reads its "text" and "intent" fields.
sample = clinc["test"][42]
sample

In [None]:
# `intents` is the dataset's "intent" label feature; `int2str` maps the
# integer label of `sample` back to its intent name. Later cells reuse
# `intents` (e.g. `str2int`, `num_classes`).
intents = clinc["test"].features["intent"]
intents.int2str(sample["intent"])

In [None]:
from transformers import pipeline

# Baseline model: a BERT checkpoint (per its name, already fine-tuned on CLINC)
# wrapped in a text-classification pipeline. The same checkpoint is reused as
# the distillation teacher further down.
ckpt = "transformersbook/bert-base-uncased-finetuned-clinc"
pipe = pipeline("text-classification", model=ckpt)

In [None]:
# Sanity-check the baseline pipeline on a free-form query.
query = """Hey, I'd like to rent a vehicle from Nov 1st to Nov 15th in Paris and I need a 15 passenger van"""
pipe(query)

In [None]:
import torch
import numpy as np
from datasets import load_metric
from pathlib import Path
from time import perf_counter

# Accuracy metric used by `compute_metrics` during training/evaluation.
# NOTE(review): `datasets.load_metric` is deprecated and no longer exists in
# recent `datasets` releases; `evaluate.load("accuracy")` from the `evaluate`
# package is the usual replacement — confirm against the installed version.
accuracy_score = load_metric("accuracy")

class PerformanceBenchmark:
  """Measure size, latency and accuracy of a text-classification pipeline.

  Args:
    pipeline: callable pipeline. Calling it on a string must return a list
      whose first element has a "label" key, and `pipeline.model` must
      provide `state_dict()`.
    dataset: iterable of examples with "text" and integer "intent" fields.
    optim_type: name under which `run_benchmark` files the metrics.
    label_encoder: object with a `str2int(label_name) -> int` method used to
      map predicted label names to ids. When omitted, the notebook-level
      `intents` feature is used.
  """

  def __init__(self, pipeline, dataset, optim_type="BERT baseline", label_encoder=None):
    self.pipeline = pipeline
    self.dataset = dataset
    self.optim_type = optim_type
    self.label_encoder = label_encoder

  def compute_accuracy(self):
    """Return {"accuracy": fraction of examples whose predicted id equals "intent"}.

    Raises:
      ValueError: if the dataset yields no examples.
    """
    # Resolve lazily so the `intents` global is only needed when no encoder was given.
    encoder = intents if self.label_encoder is None else self.label_encoder
    preds, labels = [], []
    for example in self.dataset:
      pred = self.pipeline(example["text"])[0]["label"]
      preds.append(encoder.str2int(pred))
      labels.append(example["intent"])

    if not labels:
      raise ValueError("cannot compute accuracy on an empty dataset")
    # Plain exact-match accuracy, computed directly instead of through the
    # notebook-level `load_metric` object.
    accuracy = {"accuracy": float(np.mean(np.asarray(preds) == np.asarray(labels)))}
    print(f"Accuracy on test set - {accuracy['accuracy']:.3f}")
    return accuracy

  def compute_size(self):
    """Return {"size_mb": size in MiB of the serialized model state_dict}."""
    import tempfile

    state_dict = self.pipeline.model.state_dict()
    # Serialize into a private temporary directory: unlike a fixed "model.pt"
    # in the working directory this cannot clobber an existing file, and the
    # directory is removed even if saving fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
      tmp_path = Path(tmp_dir) / "model.pt"
      torch.save(state_dict, tmp_path)
      size_mb = tmp_path.stat().st_size / (1024 * 1024)

    print(f"Model size (MB) - {size_mb:.2f}")
    return {"size_mb": size_mb}

  def time_pipeline(self, query="What is the pin number for my account?", n_warmup=10, n_runs=100):
    """Time single-query inference.

    Runs `n_warmup` untimed calls, then `n_runs` timed calls of the pipeline
    on `query`, and returns the mean and standard deviation of latency in ms.

    Raises:
      ValueError: if `n_runs` is smaller than 1.
    """
    if n_runs < 1:
      raise ValueError("n_runs must be at least 1")
    # Warm-up calls are discarded to keep one-off start-up costs out of the timing.
    for _ in range(n_warmup):
      self.pipeline(query)
    # Timed runs.
    latencies = []
    for _ in range(n_runs):
      start_time = perf_counter()
      self.pipeline(query)
      latencies.append(perf_counter() - start_time)

    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

  def run_benchmark(self):
    """Run all measurements and return {optim_type: {size, latency, accuracy metrics}}."""
    metrics = {}
    metrics[self.optim_type] = self.compute_size()
    metrics[self.optim_type].update(self.time_pipeline())
    metrics[self.optim_type].update(self.compute_accuracy())

    return metrics

In [None]:
# Benchmark the BERT baseline (default optim_type) on the test split; later
# cells merge further results into `perf_metrics`.
pb = PerformanceBenchmark(pipe, clinc["test"])
perf_metrics = pb.run_benchmark()

In [None]:
from transformers import TrainingArguments

class DistillationTrainingArguments(TrainingArguments):
  """`TrainingArguments` extended with two knowledge-distillation settings.

  Extra keyword-only arguments:
    alpha: weight of the student's cross-entropy loss; the distillation
      term is weighted by ``1 - alpha``.
    temperature: softmax temperature applied to student and teacher logits.
  Every other argument is forwarded unchanged to `TrainingArguments`.
  """

  def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
    # `alpha` and `temperature` are captured by name here and are not
    # forwarded to the parent initialiser.
    super().__init__(*args, **kwargs)
    self.alpha, self.temperature = alpha, temperature

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

class DistillationTrainer(Trainer):
  """`Trainer` that optimises a blend of cross-entropy and distillation loss.

  The loss is ``alpha * CE + (1 - alpha) * KD``. KD is the temperature-scaled
  KL divergence between the student's and the teacher's softened class
  distributions. ``alpha`` and ``temperature`` are read from ``self.args``,
  which is expected to be a ``DistillationTrainingArguments``.

  Args:
    teacher_model: model whose logits supervise the student; its forward pass
      runs under ``torch.no_grad()``. Assumed to live on the same device as
      the batches fed to the student — TODO confirm. Only required when
      ``alpha < 1``.
  """

  def __init__(self, *args, teacher_model=None, **kwargs):
    super().__init__(*args, **kwargs)
    self.teacher_model = teacher_model

  def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    """Return the combined loss, plus the student outputs if `return_outputs`.

    Extra keyword arguments (newer `transformers` releases pass e.g.
    ``num_items_in_batch``) are accepted and ignored.

    Raises:
      ValueError: if ``alpha < 1`` but no teacher model was supplied.
    """
    outputs_stu = model(**inputs)
    loss_ce = outputs_stu.loss
    alpha = self.args.alpha

    if alpha == 1:
      # The KD term has zero weight, so skip the (costly) teacher forward pass.
      loss = loss_ce
    else:
      if self.teacher_model is None:
        raise ValueError("DistillationTrainer needs a teacher_model when alpha < 1")
      temperature = self.args.temperature
      with torch.no_grad():
        logits_tea = self.teacher_model(**inputs).logits
      loss_fct = nn.KLDivLoss(reduction="batchmean")
      # KLDivLoss takes log-probabilities as input and probabilities as target.
      # The temperature**2 factor is the usual distillation scaling
      # (Hinton et al., 2015), which keeps the KD gradients comparable across temperatures.
      loss_kd = temperature ** 2 * loss_fct(
          F.log_softmax(outputs_stu.logits / temperature, dim=-1),
          F.softmax(logits_tea / temperature, dim=-1))
      loss = alpha * loss_ce + (1 - alpha) * loss_kd

    return (loss, outputs_stu) if return_outputs else loss

In [None]:
from transformers import AutoTokenizer

# Student checkpoint and its matching tokenizer.
student_ckpt = "distilbert-base-uncased"
student_tokenizer = AutoTokenizer.from_pretrained(student_ckpt)

def tokenize_text(batch):
  """Tokenize the "text" column of a batch with the student tokenizer.

  Long sequences are truncated (`truncation=True`); no `padding` argument is
  passed here.
  """
  texts = batch["text"]
  encoded = student_tokenizer(texts, truncation=True)
  return encoded

# Tokenize every split in batches and drop the raw "text" column.
clinc_enc = clinc.map(tokenize_text, batched=True, remove_columns=["text"])
# Rename the label column to "labels" so batches carry a `labels` key; the
# student's `model(**inputs)` call needs it to return a `.loss`.
clinc_enc = clinc_enc.rename_column("intent", "labels")
clinc_enc

In [None]:
def compute_metrics(pred):
  """Compute classification accuracy for `Trainer` evaluation.

  Args:
    pred: an `EvalPrediction`, or a `(predictions, labels)` pair. The
      predictions are per-class logits, presumably shaped
      (n_examples, n_classes); the labels are integer class ids.

  Returns:
    dict with a single "accuracy" key holding a float in [0, 1].
  """
  # Prefer the named fields: depending on the `transformers` version and
  # settings, iterating an EvalPrediction may yield more than two elements
  # (e.g. inputs), which would break plain 2-tuple unpacking.
  if hasattr(pred, "predictions") and hasattr(pred, "label_ids"):
    logits, labels = pred.predictions, pred.label_ids
  else:
    logits, labels = pred
  predictions = np.argmax(logits, axis=-1)
  return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}

In [None]:
from transformers import AutoConfig

batch_size = 48

# Plain fine-tuning run for the student: alpha=1 gives the cross-entropy term
# full weight and the distillation term zero weight.
# NOTE(review): recent `transformers` releases renamed `evaluation_strategy`
# to `eval_strategy`; confirm against the installed version.
finetuned_ckpt = "distilbert-base-uncased-finetuned-clinc"
student_training_args = DistillationTrainingArguments(output_dir=finetuned_ckpt, evaluation_strategy="epoch",
                                                      num_train_epochs=5, learning_rate=2e-5,
                                                      per_device_train_batch_size=batch_size,
                                                      per_device_eval_batch_size=batch_size,
                                                      alpha=1, weight_decay=0.01, push_to_hub=False)

# Reuse the baseline/teacher label mappings (`pipe` is still the BERT pipeline
# here) so student and teacher agree on which output index is which intent.
id2label = pipe.model.config.id2label
label2id = pipe.model.config.label2id

num_labels = intents.num_classes
student_config = AutoConfig.from_pretrained(student_ckpt, num_labels=num_labels,
                                            id2label=id2label, label2id=label2id)

In [None]:
from transformers.models import distilbert
from transformers import AutoModelForSequenceClassification

# Use the GPU when one is available; both student and teacher are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def student_init():
  """Build a fresh DistilBERT student from `student_ckpt`/`student_config` on `device`.

  Passed to the Trainer as `model_init`, so each training run starts from
  the pretrained checkpoint.
  """
  student = AutoModelForSequenceClassification.from_pretrained(student_ckpt, config=student_config)
  return student.to(device)

# Teacher: the same fine-tuned BERT checkpoint used as the baseline.
teacher_ckpt = "transformersbook/bert-base-uncased-finetuned-clinc"
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_ckpt, num_labels=num_labels).to(device)

# `model_init` (instead of `model=`) lets the Trainer build a fresh student
# per run; `hyperparameter_search` below relies on this.
distilbert_trainer = DistillationTrainer(model_init=student_init,
                                         teacher_model=teacher_model, args=student_training_args,
                                         train_dataset=clinc_enc["train"], eval_dataset=clinc_enc["validation"],
                                         compute_metrics=compute_metrics, tokenizer=student_tokenizer)

distilbert_trainer.train()

In [None]:
# Wrap the fine-tuned student in a pipeline and benchmark it under "DistilBERT".
pipe = pipeline("text-classification", model=distilbert_trainer.model, tokenizer=student_tokenizer)

optim_type = "DistilBERT"
pb = PerformanceBenchmark(pipe, clinc["test"], optim_type=optim_type)
perf_metrics.update(pb.run_benchmark())

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

def plot_metrics(perf_metrics, current_optim_type, ylim=(70, 90)):
  """Scatter-plot accuracy against average latency for every benchmarked model.

  Marker area is the model size in MB. The model named `current_optim_type`
  is drawn with a dashed-circle marker.

  Args:
    perf_metrics: mapping optim_type -> dict with "time_avg_ms", "accuracy"
      (fraction) and "size_mb", as returned by `PerformanceBenchmark.run_benchmark`.
    current_optim_type: key of the model to highlight.
    ylim: (low, high) y-axis range, in accuracy percent.
  """
  df = pd.DataFrame.from_dict(perf_metrics, orient="index")

  for idx in df.index:
    df_opt = df.loc[idx]
    style = dict(s=df_opt["size_mb"], label=idx, alpha=0.5)
    if idx == current_optim_type:
      # Add a dashed circle around the current optimization type
      style["marker"] = '$\u25CC$'
    plt.scatter(df_opt["time_avg_ms"], df_opt["accuracy"] * 100, **style)

  legend = plt.legend(bbox_to_anchor=(1, 1))
  # Matplotlib 3.7 renamed `legendHandles` to `legend_handles` and later
  # removed the old name; support both.
  handles = getattr(legend, "legend_handles", None)
  if handles is None:
    handles = legend.legendHandles
  for handle in handles:
    # Uniform legend marker size, independent of model size.
    handle.set_sizes([20])

  plt.ylim(*ylim)
  # Use the slowest model to define the x-axis range
  xlim = int(df["time_avg_ms"].max() + 7)
  plt.xlim(1, xlim)
  plt.ylabel("Accuracy (%)")
  plt.xlabel("Average latency (ms)")
  plt.show()

# Compare all models benchmarked so far, highlighting the DistilBERT run.
plot_metrics(perf_metrics, optim_type)

In [None]:
import optuna

def hp_space(trial):
  """Search space for `Trainer.hyperparameter_search`.

  Samples the epoch count in [5, 10], `alpha` in [0, 1] and an integer
  `temperature` in [2, 20]. The keys are attribute names on the training
  arguments.
  """
  epochs = trial.suggest_int("num_train_epochs", 5, 10)
  alpha = trial.suggest_float("alpha", 0, 1)
  temperature = trial.suggest_int("temperature", 2, 20)
  return {"num_train_epochs": epochs, "alpha": alpha, "temperature": temperature}

# Run 20 trials over `hp_space`, maximising the Trainer's evaluation objective
# (backend presumably Optuna, which is imported above — confirm).
# `best_run.hyperparameters` holds the winning values.
best_run = distilbert_trainer.hyperparameter_search(n_trials=20, direction="maximize", hp_space=hp_space)
best_run

In [None]:
# Copy the best hyperparameters onto the training arguments. This mutates the
# same `student_training_args` object that `distilbert_trainer` was built with.
for k, v in best_run.hyperparameters.items():
  setattr(student_training_args, k, v)

distilled_ckpt = "distilbert-base-uncased-distilled-clinc"
student_training_args.output_dir = distilled_ckpt

# Retrain the student from the pretrained checkpoint (via `student_init`)
# with the tuned settings.
distil_trainer = DistillationTrainer(model_init=student_init,
                                     teacher_model=teacher_model, args=student_training_args,
                                     train_dataset=clinc_enc['train'], eval_dataset=clinc_enc['validation'],
                                     compute_metrics=compute_metrics, tokenizer=student_tokenizer)

distil_trainer.train()

In [None]:
# Benchmark the tuned, distilled student and add it to the comparison plot.
pipe = pipeline("text-classification", model=distil_trainer.model, tokenizer=student_tokenizer)
optim_type = "Distillation"
pb = PerformanceBenchmark(pipe, clinc["test"], optim_type=optim_type)
perf_metrics.update(pb.run_benchmark())

plot_metrics(perf_metrics, optim_type)

In [None]:
state_dict = pipe.model.state_dict()
# `student_init` moved the model to `device` (CUDA when available), and
# `.numpy()` only works on CPU tensors, so copy this weight matrix to the host.
# The quantization experiments in the following cells then work on CPU data too.
weights = state_dict["distilbert.transformer.layer.0.attention.out_lin.weight"].cpu()
plt.hist(weights.flatten().numpy(), bins=250, range=(-0.3,0.3), edgecolor="C0")
plt.show()

In [None]:
# Manual affine quantization to int8: spread the observed weight range over
# the 255 steps between -128 and 127, then scale, shift, clamp, round and cast
# to int8 (`.char()`).
zero_point = 0
scale = (weights.max() - weights.min()) / (127 - (-128))
(weights / scale + zero_point).clamp(-128, 127).round().char()

In [None]:
from torch import quantize_per_tensor

# Same quantization through PyTorch's per-tensor quantizer, reusing the scale
# and zero point above; `int_repr()` exposes the stored integer values.
dtype = torch.qint8
quantized_weights = quantize_per_tensor(weights, scale, zero_point, dtype)
quantized_weights.int_repr()

In [None]:
%%timeit
# Element-wise float product, the same operation `q_fn.mul` times on the
# quantized tensor below (`@` would be a matrix multiply instead).
weights * weights

In [None]:
from torch.nn import quantized
from torch.nn.quantized import QFunctional

# QFunctional provides arithmetic on quantized tensors (its `mul` is timed below).
q_fn = QFunctional()

In [None]:
%%timeit
# Element-wise product of the quantized weights with themselves.
q_fn.mul(quantized_weights, quantized_weights)

In [None]:
import sys

# Ratio of the storage size of the original weights to the qint8 ones; if the
# originals are float32 this should be close to 4.
# NOTE(review): `Tensor.storage()` is deprecated in recent PyTorch releases;
# `element_size() * nelement()` is an alternative way to compare sizes.
sys.getsizeof(weights.storage()) / sys.getsizeof(quantized_weights.storage())

In [None]:
from torch.quantization import quantize_dynamic

# `student_init` placed the model on `device` (CUDA when available), but
# dynamically quantized `nn.Linear` layers target CPU inference, so move the
# model to CPU first. Note that `.to("cpu")` moves `distil_trainer.model` in place.
model_quantized = quantize_dynamic(distil_trainer.model.to("cpu"), {nn.Linear}, dtype=torch.qint8)

In [None]:
# Benchmark the distilled and dynamically int8-quantized student.
pipe = pipeline("text-classification", model=model_quantized, tokenizer=student_tokenizer)
optim_type = "Distillation + quantization"
pb = PerformanceBenchmark(pipe, clinc["test"], optim_type=optim_type)
perf_metrics.update(pb.run_benchmark())

In [None]:
# Final comparison of all benchmarked models, highlighting the quantized one.
plot_metrics(perf_metrics, optim_type)