# Compressing SetFit models with Knowledge Distillation

## Setup

If you're running this notebook on Colab or another cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:

In [None]:
# %pip install setfit

To be able to share your model with the community, there are a few more steps to follow.

First, you have to store your authentication token from the Hugging Face Hub (sign up [here](https://huggingface.co/join) if you haven't already!). To do so, execute the following cell and input an [access token](https://huggingface.co/docs/hub/security-tokens) associated with your account:

In [None]:
from huggingface_hub import notebook_login

notebook_login()

Then you need to install Git-LFS, which you can do by uncommenting and running the following command:

In [None]:
# !apt install git-lfs

Finally, you may need to configure Git on your system by providing details about who you are:

In [1]:
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"

This notebook is designed to work with any multiclass [text classification dataset](https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads) and pretrained [Sentence Transformer](https://huggingface.co/models?library=sentence-transformers&sort=downloads) on the Hub. Change the values below to try a different dataset / model!

In [21]:
dataset_id = "ag_news"
teacher_model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
student_model_id = "sentence-transformers/paraphrase-MiniLM-L3-v2"

## Loading and sampling the dataset

In [22]:
from datasets import load_dataset

dataset = load_dataset(dataset_id)

Using custom data configuration default


Downloading and preparing dataset ag_news/default (download: 29.88 MiB, generated: 30.23 MiB, post-processed: Unknown size, total: 60.10 MiB) to /home/lewis_huggingface_co/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548...


Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Dataset ag_news downloaded and prepared to /home/lewis_huggingface_co/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548. Subsequent calls will reuse this data.


## Creating a performance benchmark

In [None]:
from pathlib import Path
from time import perf_counter

import evaluate
import numpy as np
import torch

# Accuracy metric from the `evaluate` library
accuracy_score = evaluate.load("accuracy")


class PerformanceBenchmark:
    def __init__(self, model, dataset, optim_type="BERT baseline"):
        self.model = model
        self.dataset = dataset
        self.optim_type = optim_type

    def compute_accuracy(self):
        # A SetFit model takes a list of texts and returns one predicted label id per text
        preds = [int(pred) for pred in self.model(self.dataset["text"])]
        labels = self.dataset["label"]
        accuracy = accuracy_score.compute(predictions=preds, references=labels)
        print(f"Accuracy on test set - {accuracy['accuracy']:.3f}")
        return accuracy

    def compute_size(self):
        # Only the Sentence Transformer body is measured; the classification head is not included
        state_dict = self.model.model_body.state_dict()
        tmp_path = Path("model.pt")
        torch.save(state_dict, tmp_path)
        # Calculate size in megabytes
        size_mb = tmp_path.stat().st_size / (1024 * 1024)
        # Delete temporary file
        tmp_path.unlink()
        print(f"Model size (MB) - {size_mb:.2f}")
        return {"size_mb": size_mb}

    def time_pipeline(self, query="What is the pin number for my account?"):
        latencies = []
        # Warmup (the query is wrapped in a list because the model expects a batch of texts)
        for _ in range(10):
            _ = self.model([query])
        # Timed run
        for _ in range(100):
            start_time = perf_counter()
            _ = self.model([query])
            latency = perf_counter() - start_time
            latencies.append(latency)
        # Compute run statistics
        time_avg_ms = 1000 * np.mean(latencies)
        time_std_ms = 1000 * np.std(latencies)
        print(f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}")
        return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

    def run_benchmark(self):
        # Collect size, latency, and accuracy under a single key
        metrics = {}
        metrics[self.optim_type] = self.compute_size()
        metrics[self.optim_type].update(self.time_pipeline())
        metrics[self.optim_type].update(self.compute_accuracy())
        return metrics

## Knowledge distillation

In [23]:
from setfit import SetFitModel, SetFitTrainer, DistillationSetFitTrainer
from setfit.data import sample_dataset

In [24]:
# Split the original training set so the teacher and student draw from disjoint examples
train_dataset = dataset["train"].train_test_split(seed=42)

In [29]:
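# Few-shot training set for the teacher: a few labeled examples per class
train_dataset_teacher = sample_dataset(train_dataset["train"])
# 1,000 examples from the held-out split for distilling into the student
train_dataset_student = train_dataset["test"].select(range(1000))
# The original test split is used to evaluate both models
test_dataset = dataset["test"]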

Loading cached shuffled indices for dataset at /home/lewis_huggingface_co/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-326f3e05a0ac45e9.arrow
Loading cached processed dataset at /home/lewis_huggingface_co/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-4e356bd67c5864e3.arrow
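
As a rough illustration of the few-shot sampling step, the sketch below keeps a fixed number of examples per label from a toy dataset. `sample_per_class` and the toy data are hypothetical and written in plain Python; this is not the `setfit.data.sample_dataset` implementation, and the default of 8 examples per class is an assumption.

```python
import random
from collections import defaultdict


def sample_per_class(examples, num_samples=8, seed=42):
    """Hypothetical helper: keep up to `num_samples` examples per label."""
    rng = random.Random(seed)
    # Group the examples by their label
    by_label = defaultdict(list)
    for example in examples:
        by_label[example["label"]].append(example)
    # Shuffle each group and keep the first `num_samples` examples
    sampled = []
    for label in sorted(by_label):
        group = by_label[label]
        rng.shuffle(group)
        sampled.extend(group[:num_samples])
    return sampled


# Toy stand-in for a 4-class dataset such as AG News (25 examples per class)
toy = [{"text": f"example {i}", "label": i % 4} for i in range(100)]
few_shot = sample_per_class(toy)
print(len(few_shot))  # 4 classes x 8 examples = 32
```

With 4 classes, such a sample contains 32 labeled examples, which is the scale of data the teacher is trained on here.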


### Train teacher

In [32]:
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss

teacher_model = SetFitModel.from_pretrained(teacher_model_id)

# Create trainer
teacher_trainer = SetFitTrainer(
    model=teacher_model,
    train_dataset=train_dataset_teacher,
    eval_dataset=test_dataset,
)

# Train and evaluate
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
teacher_metrics

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
***** Running training *****
  Num examples = 1280
  Num epochs = 1
  Total optimization steps = 80
  Total train batch size = 16


***** Running evaluation *****


{'accuracy': 0.829078947368421}
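
The `Num examples = 1280` line in the training log counts sentence pairs, not labeled texts. It is consistent with SetFit generating two pairs per example per iteration, assuming 20 iterations and 8 sampled examples for each of the 4 AG News classes (both are assumed to be library defaults here, not values set in this notebook). `expected_training_size` is a hypothetical helper for checking that arithmetic:

```python
import math


def expected_training_size(num_examples, num_iterations=20, batch_size=16):
    """Hypothetical helper: sentence-pair count and optimizer steps for one epoch."""
    # Assumption: two sentence pairs are generated per example per iteration
    num_pairs = 2 * num_iterations * num_examples
    num_steps = math.ceil(num_pairs / batch_size)
    return num_pairs, num_steps


num_classes, samples_per_class = 4, 8  # assumed few-shot setup
teacher_pairs, teacher_steps = expected_training_size(num_classes * samples_per_class)
print(teacher_pairs, teacher_steps)  # 1280 80
```

The same arithmetic for the 1,000 examples selected for the student gives 40,000 pairs and 2,500 optimization steps.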

In [33]:
student_model = SetFitModel.from_pretrained(student_model_id)

student_trainer = DistillationSetFitTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataset=train_dataset_student,
    eval_dataset=test_dataset,
)

student_trainer.train()
student_metrics = student_trainer.evaluate()
student_metrics

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
***** Running training *****
  Num examples = 40000
  Num epochs = 1
  Total optimization steps = 2500
  Total train batch size = 16


***** Running evaluation *****


{'accuracy': 0.8281578947368421}
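
To put the two scores side by side, the snippet below computes the accuracy gap and the fraction of the teacher's accuracy that the student retains, using the values reported above. The test-set size of 7,600 is taken from the dataset loading output; the variable names are illustrative.

```python
# Accuracies reported by the two `evaluate()` calls above
teacher_accuracy = 0.829078947368421
student_accuracy = 0.8281578947368421
test_size = 7600  # size of the AG News test split

gap = teacher_accuracy - student_accuracy
retention = student_accuracy / teacher_accuracy
# Convert the accuracy gap into a count of test examples
extra_errors = round(gap * test_size)

print(f"Accuracy gap: {gap:.4f}")
print(f"Accuracy retained by the student: {retention:.2%}")
print(f"Additional misclassified test examples: {extra_errors}")
```

In this run the student misclassifies 7 more test examples than the teacher, so nearly all of the teacher's accuracy carries over to the smaller model.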