<a href="https://colab.research.google.com/github/hsong-77/transformer-practice/blob/main/model-compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the two packages imported by the cells below (`transformers`, `datasets`).
!pip install transformers
!pip install datasets

In [None]:
from datasets import load_dataset

# Load the "plus" configuration of the "clinc_oos" dataset.
clinc = load_dataset("clinc_oos", "plus")
# Trailing bare expression: shown as the cell's output in the notebook.
clinc

In [None]:
# Take a single example (index 42) from the "test" split for inspection.
sample = clinc["test"][42]
sample

In [None]:
# The "intent" feature of the test split. This global is also used by
# PerformanceBenchmark.compute_accuracy (through `intents.str2int`).
intents = clinc["test"].features["intent"]
# Map the sample's stored intent value to its string name.
intents.int2str(sample["intent"])

In [None]:
from transformers import pipeline

# Checkpoint identifier handed to `pipeline` as the model to load.
ckpt = "transformersbook/bert-base-uncased-finetuned-clinc"
# Text-classification pipeline built on that checkpoint.
pipe = pipeline("text-classification", model=ckpt)

In [None]:
# Example free-text request; run it through the pipeline and display the result.
query = """Hey, I'd like to rent a vehicle from Nov 1st to Nov 15th in Paris and I need a 15 passenger van"""
pipe(query)

In [None]:
import torch
from datasets import load_metric
from pathlib import Path

# Accuracy metric object; PerformanceBenchmark.compute_accuracy calls its
# `.compute(predictions=..., references=...)` method.
# NOTE(review): `load_metric` appears to be deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package -- confirm against the
# installed version before upgrading.
accuracy_score = load_metric("accuracy")

class PerformanceBenchmark:
  """Benchmark a text-classification pipeline on size, latency and accuracy.

  Args:
    pipeline: Callable that takes a text string and returns a list whose
      first element is a dict with a "label" key. Its `.model` attribute
      must provide `state_dict()` (used by `compute_size`).
    dataset: Iterable of examples, each with "text" and "intent" keys.
    optim_type: Key under which `run_benchmark` stores this run's metrics.
  """

  def __init__(self, pipeline, dataset, optim_type="BERT baseline"):
    self.pipeline = pipeline
    self.dataset = dataset
    self.optim_type = optim_type

  def compute_accuracy(self):
    """Run the pipeline over the dataset and score its predictions.

    Relies on two module-level names defined in earlier cells: `intents`
    (to turn predicted label strings into ids) and `accuracy_score` (the
    metric object).

    Returns:
      The dict produced by `accuracy_score.compute`; it contains an
      "accuracy" entry.
    """
    preds, labels = [], []
    for example in self.dataset:
      # The pipeline returns a list of results; only the first one is used.
      pred = self.pipeline(example["text"])[0]["label"]
      preds.append(intents.str2int(pred))
      labels.append(example["intent"])

    accuracy = accuracy_score.compute(predictions=preds, references=labels)
    print(f"Accuracy on test set - {accuracy['accuracy']:.3f}")
    return accuracy

  def compute_size(self):
    """Measure the serialized size of the model's state dict.

    The state dict is saved into a throwaway temporary directory, so an
    existing `model.pt` in the working directory is never overwritten or
    deleted, and the file is removed even if saving fails.

    Returns:
      `{"size_mb": <size of the saved file in MiB>}`.
    """
    import tempfile

    state_dict = self.pipeline.model.state_dict()
    with tempfile.TemporaryDirectory() as tmp_dir:
      tmp_path = Path(tmp_dir) / "model.pt"
      torch.save(state_dict, tmp_path)
      size_mb = tmp_path.stat().st_size / (1024 * 1024)

    print(f"Model size (MB) - {size_mb:.2f}")
    return {"size_mb": size_mb}

  def time_pipeline(self, query="What is the pin number for my account?",
                    n_warmup=10, n_runs=100):
    """Measure the pipeline's latency on a single query.

    Args:
      query: Text sent to the pipeline on every call.
      n_warmup: Number of untimed calls made before measuring.
      n_runs: Number of timed calls; must be at least 1.

    Returns:
      `{"time_avg_ms": mean latency, "time_std_ms": population standard
      deviation}`, both in milliseconds.

    Raises:
      ValueError: If `n_runs` is smaller than 1.
    """
    import statistics
    from time import perf_counter

    if n_runs < 1:
      raise ValueError("n_runs must be at least 1")

    # Untimed calls so one-off start-up costs do not skew the measurement.
    for _ in range(n_warmup):
      self.pipeline(query)

    latencies = []
    for _ in range(n_runs):
      start = perf_counter()
      self.pipeline(query)
      latencies.append(perf_counter() - start)

    time_avg_ms = 1000 * statistics.fmean(latencies)
    time_std_ms = 1000 * statistics.pstdev(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

  def run_benchmark(self):
    """Collect size, latency and accuracy metrics for this pipeline.

    Returns:
      `{optim_type: {...}}` where the inner dict merges the results of
      `compute_size`, `time_pipeline` and `compute_accuracy`.
    """
    metrics = {}
    metrics[self.optim_type] = self.compute_size()
    metrics[self.optim_type].update(self.time_pipeline())
    metrics[self.optim_type].update(self.compute_accuracy())

    return metrics