# Suggestions for Data Annotation with SetFit in Zero-Shot Text Classification

Suggestions make data annotation easier and faster. These preselected options make the labeling process more efficient, since annotators only need to correct the suggestions instead of labeling every record from scratch.

In this example, we will implement a zero-shot approach using SetFit to get initial suggestions for an Argilla dataset that combines two text classification tasks: a `LabelQuestion` and a `MultiLabelQuestion`.

[**Argilla**](https://github.com/argilla-io/argilla) is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. Using Argilla, everyone can build robust language models through faster data curation using both human and machine feedback.

**Feedback** is a crucial part of the data curation process, and Argilla provides a way to manage and visualize it so that the curated data can be later used to improve a language model.

In this example, we will make our annotators' job easier by providing them with suggestions. We will learn how to train zero-shot sentiment and topic classifiers using SetFit and then use them to suggest labels for the dataset.

## Setup

In [None]:
!pip install -qU argilla setfit transformers huggingface_hub sentence-transformers

In [None]:
import argilla as rg

from datasets import load_dataset
from setfit import SetFitModel, Trainer, get_templated_dataset

If we run Argilla using a Docker quickstart image or HuggingFace Spaces, we need to initialize the Argilla client with the API URL and API key:

In [None]:
# Replace api_url with your url if using Docker
# Replace api_key if you configured a custom API key
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)

## Configure the dataset

We will load the [`banking77`](https://huggingface.co/datasets/legacy-datasets/banking77) dataset, a popular open-source dataset of customer requests in the banking domain.

In [None]:
data = load_dataset(
    'PolyAI/banking77',
    split='test'
)

Argilla works with the `Dataset` class, which lets us easily create a dataset and manage both the data and the feedback.

First, we need to configure the `Dataset`. In the `Settings`, we can specify the *guidelines*, the *fields* where the data to be annotated will be added, and the *questions* for the annotators.

In our use case, we need a text field and two different questions. We will use the original labels of this dataset to make a multi-label classification of the topics mentioned in the request, and we will also set up a label question to classify the sentiment of the request as either "positive", "neutral", or "negative".

In [None]:
settings = rg.Settings(
    fields=[rg.TextField(name='text')],
    questions=[
        rg.MultiLabelQuestion(
            name='topics',
            title='Select the topic(s) of the request',
            labels=data.info.features['label'].names,
            visible_labels=10
        ),
        rg.LabelQuestion(
            name='sentiment',
            title='What is the sentiment of the message?',
            labels=['positive', 'neutral', 'negative']
        )
    ]
)

dataset = rg.Dataset(
    name='setfit_tutorial_dataset',
    settings=settings
)
dataset.create()
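Because the label lists are typed by hand, a duplicated label is easy to miss; it is at best confusing for annotators and may be rejected when the dataset is created. Below is a minimal pre-flight check we could run on any label list before calling `dataset.create()`. The helper `find_duplicate_labels` is hypothetical and uses only the standard library; it is not part of Argilla.

```python
from collections import Counter


def find_duplicate_labels(labels):
    """Hypothetical helper: return labels that appear more than once, in first-seen order."""
    counts = Counter(labels)
    # Counter keeps insertion order, so duplicates are reported in the order they first appear
    return [label for label in counts if counts[label] > 1]


# Check the sentiment labels before creating the dataset
sentiment_labels = ['positive', 'neutral', 'negative']
duplicates = find_duplicate_labels(sentiment_labels)
if duplicates:
    raise ValueError(f'Duplicated labels: {duplicates}')
```

The same check can be applied to `data.info.features['label'].names` for the topics question.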

## Train the models

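Now we will train a zero-shot text classification model for each of the questions in our dataset. Because the approach is zero-shot, training only uses the labels we configured for each question, not the banking77 examples; we will use the data we loaded from HuggingFace afterwards to make predictions.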

We will use the [`SetFit`](https://github.com/huggingface/setfit) framework for few-shot fine-tuning of Sentence Transformers for both classifiers. As the base model, we will use [`all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), a sentence embedding model fine-tuned with a contrastive objective on a dataset of 1B sentence pairs.
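SetFit's `get_templated_dataset` turns zero-shot into few-shot by filling a template with each candidate label and repeating it `sample_size` times. The sketch below is a hypothetical pure-Python stand-in (`build_templated_examples` is not a SetFit function) that roughly mirrors what we understand the helper to produce: an integer label id for single-label questions and a one-hot vector for multi-label ones.

```python
def build_templated_examples(candidate_labels, template, sample_size=8, multi_label=False):
    """Hypothetical stand-in for setfit.get_templated_dataset, for illustration only."""
    examples = []
    for idx, label in enumerate(candidate_labels):
        if multi_label:
            # One-hot target: 1 for the label used in the template, 0 elsewhere
            target = [int(i == idx) for i in range(len(candidate_labels))]
        else:
            # Single-label target: the label's position in candidate_labels
            target = idx
        # Repeat the same synthetic sentence sample_size times
        for _ in range(sample_size):
            examples.append({'text': template.format(label), 'label': target})
    return examples


examples = build_templated_examples(
    ['positive', 'neutral', 'negative'], 'This message is {}', sample_size=2
)
print(len(examples))  # 6
print(examples[0])    # {'text': 'This message is positive', 'label': 0}
```

With `sample_size=8`, this scheme yields 3 × 8 = 24 synthetic sentences for the sentiment question and 77 × 8 = 616 for the 77 topic labels.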

In [None]:
def train_model(question_name, template, multi_label=False):
    # Build a synthetic training set by filling the template with each
    # candidate label of the question (sample_size examples per label)
    train_dataset = get_templated_dataset(
        candidate_labels=dataset.questions[question_name].labels,
        sample_size=8,
        template=template,
        multi_label=multi_label
    )

    # Load the base Sentence Transformer; multi-label questions use a
    # one-vs-rest head that scores each label independently
    if multi_label:
        model = SetFitModel.from_pretrained(
            'sentence-transformers/all-MiniLM-L6-v2',
            multi_target_strategy='one-vs-rest'
        )
    else:
        model = SetFitModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

    # Fine-tune the model on the templated training set
    trainer = Trainer(
        model=model,
        train_dataset=train_dataset
    )
    trainer.train()

    return model
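The `multi_target_strategy='one-vs-rest'` switch changes what the scores mean. With a single-label head, the label probabilities compete and sum to 1; with one-vs-rest, each label gets its own binary probability, so several labels can score high at once. The toy sketch below illustrates the difference with a softmax versus independent sigmoids over the same made-up logits; it is an illustration of the idea, not SetFit's exact computation.

```python
import math


def softmax(logits):
    """Single-label case: scores compete and sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def independent_sigmoids(logits):
    """One-vs-rest case: each label gets its own binary probability."""
    return [1 / (1 + math.exp(-x)) for x in logits]


toy_logits = [2.0, 1.5, -3.0]  # hypothetical scores for three labels
single = softmax(toy_logits)
multi = independent_sigmoids(toy_logits)
print(round(sum(single), 6))  # 1.0
print(sum(multi) > 1)         # True
```

This is why the sentiment suggestion is taken as the single highest-scoring label, while the topic suggestions are filtered with a threshold.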

In [None]:
topic_model = train_model(
    question_name='topics',
    template='The customer request is about {}',
    multi_label=True
)

In [None]:
topic_model.save_pretrained('topic_model')

In [None]:
sentiment_model = train_model(
    question_name='sentiment',
    template='This message is {}',
    multi_label=False
)

In [None]:
sentiment_model.save_pretrained('sentiment_model')

## Make predictions

Once training is complete, we can use both models to make predictions on our data.

In [None]:
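def get_predictions(texts, model, question_name):
    # One row of probabilities per text, one column per label
    probas = model.predict_proba(texts, as_numpy=True)
    labels = dataset.questions[question_name].labels

    # Pair each label with its score; this assumes the model's class order
    # matches the label order of the question used for training
    for pred in probas:
        yield [
            {'label': label, 'score': score}
            for label, score in zip(labels, pred)
        ]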

In [None]:
data = data.map(
    lambda batch: {
        'topics': list(get_predictions(batch['text'], topic_model, 'topics')),
        'sentiment': list(get_predictions(batch['text'], sentiment_model, 'sentiment'))
    },
    batched=True
)

In [None]:
data.to_pandas().head()

## Log the records to Argilla

With the data and the predictions we have produced, we can now build records (the data items that the annotator team will annotate) that include the suggestions from our models.

For the `LabelQuestion`, we will use the label that received the highest probability score, and for the `MultiLabelQuestion`, we will include all labels with a score at or above a certain threshold. In this case, we chose `2 / len(labels)` (about 0.026 for the 77 topic labels), but we can experiment with our data and choose a more restrictive or more lenient threshold.

In [None]:
def add_suggestions(record):
    suggestions = []

    # Get label with max score for sentiment question
    sentiment = max(record['sentiment'], key=lambda x: x['score'])['label']
    suggestions.append(rg.Suggestion(question_name='sentiment', value=sentiment))

    # Get all labels at or above a threshold for the topics question
    threshold = 2 / len(dataset.questions['topics'].labels)
    topics = [
        label['label']
        for label in record['topics'] if label['score'] >= threshold
    ]
    if topics:
        suggestions.append(rg.Suggestion(question_name='topics', value=topics))

    return suggestions
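To experiment with the threshold, it can help to pull the topic filter out into a small function with a tunable factor, so that `factor=2.0` reproduces the `2 / len(labels)` rule used above. The sketch below assumes, as in `get_predictions`, that each record has one score entry per label, so `len(topic_scores)` equals the number of labels. The helper name, the `factor` parameter, and the scores are hypothetical.

```python
def select_topics(topic_scores, factor=2.0):
    """Hypothetical helper: keep labels whose score is at least factor / number_of_labels."""
    threshold = factor / len(topic_scores)
    return [item['label'] for item in topic_scores if item['score'] >= threshold]


# Made-up scores for a 4-label question, so the threshold is factor / 4
scores = [
    {'label': 'topic_a', 'score': 0.60},
    {'label': 'topic_b', 'score': 0.40},
    {'label': 'topic_c', 'score': 0.30},
    {'label': 'topic_d', 'score': 0.05},
]
for factor in (1.0, 2.0, 3.0):
    print(factor, select_topics(scores, factor))
```

With these numbers, a factor of 1.0 (threshold 0.25) keeps three topics, 2.0 (threshold 0.5) keeps only `topic_a`, and 3.0 (threshold 0.75) keeps none, in which case `add_suggestions` would add no topic suggestion for that record.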

In [None]:
records = [
    rg.Record(
        fields={'text': record['text']},
        suggestions=add_suggestions(record)
    )
    for record in data
]

Once we are happy with the result, we can log the records to the dataset that we configured above. We can then open the dataset in Argilla and review the suggestions.

In [None]:
dataset.records.log(records)

We can also save our Argilla dataset to the HuggingFace Hub and load it back from there.

In [None]:
# Export (replace repo_id with a Hub repository you have write access to)
dataset.to_hub(repo_id='argilla/my_setfit_dataset')

In [None]:
# Import
dataset = rg.Dataset.from_hub(repo_id='argilla/my_setfit_dataset')