---
execute:
  eval: false
---








# Method 2: SetFit

When working with limited labelled data^[See the page on data labelling for our advice on knowing when you have limited data; the rule of thumb here is fewer than 100 posts per class.], [traditional fine-tuning methods](./vanilla_finetuning.qmd) can be resource-intensive and less effective. SetFit is a framework that offers an efficient, prompt-free approach to few-shot text classification. It leverages sentence transformers and contrastive learning to achieve high accuracy with minimal data.

At a high level, the SetFit process involves:

1. **Model Selection**: Choosing a pre-trained sentence transformer (e.g., paraphrase-mpnet-base-v2).
2. **Data Preparation**: Formatting your labelled examples.
3. **Contrastive Fine-Tuning**: Generating sentence pairs and fine-tuning the model using contrastive learning.
4. **Classifier Training**: Training a classification head on the embeddings produced by the fine-tuned model.
5. **Evaluation**: Assessing model performance on a validation set and, finally, a test set.
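
To make the pair-generation step more concrete, here is a minimal sketch in plain Python. It is a simplification, not SetFit's actual sampling code: the helper `make_contrastive_pairs` and the toy sentences are made up for illustration, and the library may balance or oversample pairs differently.

```python
from itertools import combinations


def make_contrastive_pairs(texts, labels):
    """Pair up every two examples; target is 1.0 if labels match, else 0.0.

    Hypothetical helper for illustration only (not part of the setfit API).
    """
    pairs = []
    for (text_a, label_a), (text_b, label_b) in combinations(zip(texts, labels), 2):
        target = 1.0 if label_a == label_b else 0.0
        pairs.append((text_a, text_b, target))
    return pairs


# Toy few-shot data: two positive (1) and two negative (0) sentences
texts = ["great film", "loved it", "dull plot", "waste of time"]
labels = [1, 1, 0, 0]

pairs = make_contrastive_pairs(texts, labels)
for pair in pairs:
    print(pair)
```

Four labelled sentences already yield six training pairs, which is why a handful of examples per class can go a long way during contrastive fine-tuning.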

Put another way, SetFit fine-tunes Sentence Transformers for classification tasks with limited labelled data by combining four components:

1. Pre-trained Sentence Transformer: Starting with a model like SBERT (Sentence-BERT).
2. Few-Shot Learning: Using a small number of labelled examples per class.
3. Contrastive Learning: Fine-tuning the model using contrastive loss functions.
4. Classification Head: Adding a simple classifier on top of the embeddings.
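
As an illustration of what a contrastive loss rewards, the sketch below scores a single sentence pair with a cosine-similarity loss: the squared difference between the pair's cosine similarity and its target (1.0 for same class, 0.0 for different classes). This is one common choice and is assumed here for illustration; the two-dimensional "embeddings" and the function names are hypothetical.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def cosine_pair_loss(u, v, target):
    """Squared error between the pair's cosine similarity and its target."""
    return (cosine_similarity(u, v) - target) ** 2


# Toy embeddings with a cosine similarity of roughly 0.6
emb_a = [1.0, 0.0]
emb_b = [0.6, 0.8]

loss_if_same_class = cosine_pair_loss(emb_a, emb_b, target=1.0)
loss_if_diff_class = cosine_pair_loss(emb_a, emb_b, target=0.0)
print(loss_if_same_class, loss_if_diff_class)
```

Minimising this loss pulls same-class sentences together and pushes different-class sentences apart in embedding space, which is what makes a simple classifier on top sufficient.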

**Benefits**

* Data Efficiency: Achieves good performance with as few as 8 examples per class.
* Computationally Light: Fine-tunes quickly and requires less computational power.
* No Prompts Needed: Eliminates the need for hand-crafted prompts.

**Limitations**

* Performance ceiling: May not match the performance of models fine-tuned on large datasets.
* Dependence on pre-trained model quality: The quality of embeddings is tied to the pre-trained model used.

**When to Use SetFit?**

* If you have a small amount of labelled data.
* For quick prototyping and iterative development. If time allows, try SetFit first; if it looks promising, it is worth labelling more data to perform [vanilla fine-tuning](./vanilla_finetuning.qmd).

## How to fine-tune a model with SetFit?

Let's dive into fine-tuning a model using SetFit. This section will get you started quickly. Feel free to run the code, experiment, and learn by doing. After this walkthrough, we'll provide a more detailed explanation of each step.

Start by installing the required packages/modules...


In [None]:
!pip install setfit datasets

... before loading in our dataset. For this example, we'll use the `sst2` (Stanford Sentiment Treebank) dataset, which suits sentiment analysis because it consists of single sentences extracted from movie reviews, each annotated as either positive or negative.


In [None]:
from datasets import load_dataset

dataset = load_dataset("SetFit/sst2")

### Prepare the data

Now that we have loaded the data, let's prepare it for the SetFit framework. The benefit of SetFit is that it can fine-tune a model with very little labelled data. To simulate a few-shot learning scenario, we take the dataset we loaded from the Hugging Face Hub and sample it so that we keep only 8 (yes, 8!) instances of each label for fine-tuning. Note that the dataset is already split into training, validation, and test sets; we sample only the training set and leave the validation and test sets untouched so that evaluation stays reliable.


In [None]:
from setfit import sample_dataset

# Use 8 examples per class for training
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)

# Obtain the validation and test datasets
validation_dataset = dataset["validation"]
test_dataset = dataset["test"]
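
To show what sampling a fixed number of examples per class amounts to, here is a rough plain-Python equivalent. It is a sketch under assumptions, not the library's implementation: `sample_per_class` and the synthetic rows are hypothetical, and `sample_dataset` may differ in details such as shuffling.

```python
import random
from collections import Counter, defaultdict


def sample_per_class(rows, label_key="label", num_samples=8, seed=42):
    """Keep at most `num_samples` randomly chosen rows for each label.

    Hypothetical stand-in for illustration, not the setfit implementation.
    """
    by_label = defaultdict(list)
    for row in rows:
        by_label[row[label_key]].append(row)

    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    sampled = []
    for label in sorted(by_label):
        group = by_label[label]
        sampled.extend(rng.sample(group, min(num_samples, len(group))))
    return sampled


# Synthetic dataset: 100 rows alternating between labels 0 and 1
rows = [{"text": f"review {i}", "label": i % 2} for i in range(100)]

few_shot = sample_per_class(rows, num_samples=8)
print(len(few_shot))  # 16 rows in total
print(Counter(row["label"] for row in few_shot))  # 8 rows per label
```

Fixing the seed matters in few-shot settings: with so few examples, a different sample can noticeably change the results.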

### Loading a Pre-trained SetFit Model

Then initialise a SetFit model using a Sentence Transformer model of our choice. For this example we will use `BAAI/bge-small-en-v1.5`:


In [None]:
from setfit import SetFitModel

# Load the Sentence Transformer body and name the two classes
model = SetFitModel.from_pretrained(
    "BAAI/bge-small-en-v1.5",
    labels=["negative", "positive"],
)

Now we prepare `TrainingArguments` for training. The most frequently used arguments (hyperparameters) are `num_epochs` and `max_steps`, which affect the total number of training steps. We then initialise the `Trainer` and run the training.


In [None]:
from setfit import Trainer, TrainingArguments

# Training hyperparameters
args = TrainingArguments(
    batch_size=32,
    num_epochs=10,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
)

# Run the fine-tuning
trainer.train()
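
As a back-of-the-envelope illustration of how `num_epochs` and `max_steps` interact, the sketch below estimates the number of optimisation steps. It rests on assumptions: `total_training_steps` is a hypothetical helper, the pair count of 120 is made up, and `max_steps` is treated as a simple upper cap that is ignored when it is not positive. The library's exact bookkeeping may differ, so check its documentation before relying on these numbers.

```python
import math


def total_training_steps(num_pairs, batch_size, num_epochs, max_steps=-1):
    """Estimate optimisation steps (hypothetical helper, simplified semantics)."""
    steps_per_epoch = math.ceil(num_pairs / batch_size)
    total = steps_per_epoch * num_epochs
    # Assumption: a positive max_steps caps the total; otherwise it is ignored
    return min(total, max_steps) if max_steps > 0 else total


# Say contrastive pair generation produced 120 pairs (an assumed figure)
print(total_training_steps(num_pairs=120, batch_size=32, num_epochs=10))
print(total_training_steps(num_pairs=120, batch_size=32, num_epochs=10, max_steps=25))
```

With a batch size of 32, 120 pairs take 4 steps per epoch, so 10 epochs come to 40 steps, or 25 once the cap applies.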

### Evaluating the Model

After training, evaluate the model on the validation dataset.


In [None]:
metrics = trainer.evaluate(validation_dataset)
print(metrics)
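
The printed metrics are a dictionary of scores. Assuming the metric is accuracy (the share of examples whose predicted label matches the true label), the calculation can be sketched by hand as follows. The `accuracy` function and the toy labels are illustrative only and not part of the walkthrough.

```python
def accuracy(predictions, references):
    """Fraction of predictions that match the reference labels."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    correct = sum(pred == ref for pred, ref in zip(predictions, references))
    return correct / len(references)


# Toy example: three of the four predictions are correct
toy_metrics = {"accuracy": accuracy([1, 0, 1, 1], [1, 0, 0, 1])}
print(toy_metrics)  # {'accuracy': 0.75}
```

Accuracy is easy to read on a balanced dataset like this one; with imbalanced classes, a metric such as F1 is usually more informative.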

Finally, once we are happy with the model's performance on the validation data, we can evaluate it on the test dataset.


In [None]:
trainer.evaluate(test_dataset)

Now we can save (or load) the model as needed.


In [None]:
model.save_pretrained("setfit-bge-small-v1.5-sst-8-shot") # Save to a local directory

model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst-8-shot") # Load from a local directory

Once a SetFit model has been trained, it can be used for inference straight away using `SetFitModel.predict()`.


In [None]:
texts = [
    "I love this product! It's fantastic.",
    "Terrible customer service. Very disappointed.",
]

predictions = trainer.model.predict(texts)
print(predictions)

Congratulations! You've fine-tuned a SetFit model for sentiment analysis. Feel free to tweak the code, try different datasets, and explore further.

## Detailed overview

Now that you've got a taste of how SetFit works, let's delve deeper into each step.

### Setting Up the Environment

Although SetFit is lightweight, we still recommend running it in a cloud environment such as Google Colab so that you have access to GPUs.

Make sure you are connected to a GPU. We recommend a T4, as it offers a good balance between speed and cost.

::: {.callout-tip collapse="true"}
# How do I do this?

To use a GPU in Colab, go to `Runtime` > `Change runtime type` and select a GPU under the hardware accelerator option.
:::
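
Once the runtime is set, a quick sanity check can confirm that a GPU is actually visible. The sketch below uses only the standard library and assumes an NVIDIA GPU whose driver ships the `nvidia-smi` tool (as is typically the case on Colab GPU runtimes); `gpu_visible` is a hypothetical helper name.

```python
import shutil
import subprocess


def gpu_visible():
    """Return True if `nvidia-smi` is on the PATH and exits successfully."""
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    except OSError:
        return False
    return result.returncode == 0


print("GPU visible:", gpu_visible())
```

If this prints `False`, revisit the runtime settings before starting a training run.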

### Install the required packages and modules


In [None]:
%%capture
# Install the necessary packages
!pip install setfit datasets evaluate

# Imports
from datasets import load_dataset

In [None]:
dataset = load_dataset("SetFit/sst2")