# Few-shot classification using SetFit

In [1]:
!pip install setfit
# !pip install huggingface_hub

Collecting setfit
  Obtaining dependency information for setfit from https://files.pythonhosted.org/packages/a4/cb/53ac2baccb2291613f06d5ab6653655c31e391b9691c0c976e9d995548ea/setfit-1.0.2-py3-none-any.whl.metadata
  Downloading setfit-1.0.2-py3-none-any.whl.metadata (11 kB)
Collecting datasets>=2.3.0 (from setfit)
  Obtaining dependency information for datasets>=2.3.0 from https://files.pythonhosted.org/packages/ec/93/454ada0d1b289a0f4a86ac88dbdeab54921becabac45da3da787d136628f/datasets-2.16.1-py3-none-any.whl.metadata
  Downloading datasets-2.16.1-py3-none-any.whl.metadata (20 kB)
Collecting sentence-transformers>=2.2.1 (from setfit)
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
[?25hCollecting evaluate>=0.3.0 (from setfit)
  Obtaining dependency information for evaluate>=0.3.0 from 

In [2]:
import warnings

# Silence all warnings for the rest of the notebook so library noise does not
# flood the cell outputs. A blanket "ignore" already covers DeprecationWarning,
# so no additional per-category filter is required.
# NOTE(review): this also hides pandas/sklearn warnings that may be informative;
# narrow the filter while debugging.
warnings.filterwarnings("ignore")

import torch
import pandas as pd
from sklearn.model_selection import train_test_split

from datasets import Dataset, DatasetDict, load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from sentence_transformers.losses import CosineSimilarityLoss


In [3]:
# SetFit logs to Weights & Biases by default when wandb is installed, so log in first.
import wandb
from kaggle_secrets import UserSecretsClient

# Read the W&B API key from the Kaggle secret named "wandb_key" instead of hardcoding it.
user_secrets = UserSecretsClient()
wandb_token = user_secrets.get_secret("wandb_key") 
# Last expression of the cell: its return value is displayed as the cell output.
wandb.login(key=wandb_token)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Load data

In [4]:
# Read the annotated comments from the Kaggle input dataset.
# (Local alternative: "data/lmd_ukraine_annotated.parquet")
filepath = "/kaggle/input/lmd-ukraine-annotated-v3/lmd_ukraine_annotated_V3.parquet"
data = pd.read_parquet(filepath)

# Column dtypes first, then a rich preview of the first three rows.
print(data.dtypes)
display(data.iloc[:3])

article_id             int64
url                   object
title                 object
desc                  object
date          datetime64[ns]
author                object
comment               object
comment_id            object
label_text            object
dtype: object


Unnamed: 0,article_id,url,title,desc,date,author,comment,comment_id,label_text
0,3259703,https://www.lemonde.fr/actualite-medias/articl...,"Le conflit russo-ukrainien, qui mobilise les m...",Au Festival de journalisme de Couthures : la g...,2022-07-16,Ricardo Uztarroz,La question qui vaille et qui n'est pas posée...,e7206b56918f694f,pro_russia
1,3259703,https://www.lemonde.fr/actualite-medias/articl...,"Le conflit russo-ukrainien, qui mobilise les m...",Au Festival de journalisme de Couthures : la g...,2022-07-16,Ricardo Uztarroz,Salandre : les documents dont vous faîtes ét...,d904e44906dfb957,pro_russia
2,3259703,https://www.lemonde.fr/actualite-medias/articl...,"Le conflit russo-ukrainien, qui mobilise les m...",Au Festival de journalisme de Couthures : la g...,2022-07-16,Correcteur,« C’est l’affaire des russes »? C’est donc vot...,1c03f54daeffd1ca,pro_ukraine


In [5]:
# Overview of the annotation effort: total rows, per-class counts,
# then how many rows are annotated vs. still unlabeled.
label_column = data["label_text"]

print(len(data))
print(label_column.value_counts())
print(label_column.notnull().sum())
print(label_column.isnull().sum())

175353
label_text
other          180
pro_ukraine    165
pro_russia     117
Name: count, dtype: int64
462
174891


## Prepare Dataset (labels, optional sample, split)

In [6]:
# Annotated rows are divided into train and eval splits.
# Rows without an annotation form the test set, kept for later distillation.
is_annotated = data["label_text"].notnull()
with_labels = data[is_annotated]
test_df = data[~is_annotated]
print(len(with_labels), len(test_df))

# Stratified 70/30 split so each split keeps the class proportions; fixed seed for repeatability.
train_df, eval_df = train_test_split(
    with_labels,
    test_size=0.3,
    stratify=with_labels['label_text'],
    random_state=40,
)

# Size and class distribution of each split, train first.
for split_df in (train_df, eval_df):
    print(len(split_df))
    print(split_df.label_text.value_counts())

462 174891
323
label_text
other          126
pro_ukraine    115
pro_russia      82
Name: count, dtype: int64
139
label_text
other          54
pro_ukraine    50
pro_russia     35
Name: count, dtype: int64


  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):


In [7]:
# Encode the string annotations as integer class ids.
# SetFit wants integer labels, so this is done on the fully annotated splits only:
# a missing annotation would map to NaN and turn the whole column into floats.
label_mapping = {'pro_ukraine': 0, 'pro_russia': 1, 'other': 2}


def add_integer_label(frame, mapping):
    """Return a copy of ``frame`` with an integer ``label`` column.

    Args:
        frame: DataFrame holding a ``label_text`` column of class names.
        mapping: dict from class name to integer class id.

    Returns:
        A new DataFrame (the input is left untouched) with an int64 ``label`` column.

    Raises:
        ValueError: if any ``label_text`` value has no entry in ``mapping``.
    """
    encoded = frame['label_text'].map(mapping)
    unmapped = encoded.isnull()
    if unmapped.any():
        unknown = sorted(frame.loc[unmapped, 'label_text'].astype(str).unique())
        raise ValueError(f"label_text values missing from label_mapping: {unknown}")
    return frame.assign(label=encoded.astype('int64'))


# Rebind the names instead of assigning the column in place: the split frames are
# derived from a slice of `data`, and in-place assignment on such frames is what
# pandas' SettingWithCopyWarning guards against. Re-running this cell is idempotent.
train_df = add_integer_label(train_df, label_mapping)
eval_df = add_integer_label(eval_df, label_mapping)

In [8]:
# Wrap each pandas split in a Hugging Face Dataset (train, eval, test in that order).
train_dataset, eval_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_df, eval_df, test_df)
)

# Bundle the three splits under their conventional names.
dataset = DatasetDict({
    'train': train_dataset,
    'validation': eval_dataset,
    'test': test_dataset,
})

# Number of distinct classes seen in the training split; displayed as the cell output.
num_classes = len(train_dataset.unique("label"))
num_classes

  if _pandas_api.is_sparse(col):


3

## Modeling, using *Sklearn LogisticRegression* head

Note: in our own tests, as well as in those reported by the SetFit authors, a LogisticRegression head gives better results than a differentiable torch head.  
The model, the classification head type (random forest, GBM...) and its parameters, and the training hyperparameters were chosen after multiple experiments.  
See the hyperparameter optimization notebook.

In [9]:
# Optional : sample dataset, X number of examples per class
# train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=80, seed=40)

In [10]:
# Class names ordered by their integer id in `label_mapping`, so the names the
# model reports always line up with the ids stored in the `label` column
# (evaluates to ["pro_ukraine", "pro_russia", "other"]).
class_names = sorted(label_mapping, key=label_mapping.get)

# Multilingual sentence-transformer body with a scikit-learn LogisticRegression head.
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    labels=class_names,
    # Head parameters selected during the hyperparameter search.
    head_params={
        "solver": "liblinear",
        "max_iter": 137,
    },
)

  self.comm = Comm(**args)


config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/690 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.10k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/402 [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [11]:
# Training configuration, grouped by concern.
args = TrainingArguments(
    # --- optimisation ---
    batch_size=32,
    body_learning_rate=3e-7,
    num_epochs=2,
    max_steps=2350,  # the training log reports 2350 total optimization steps
    sampling_strategy="oversampling",
    # --- evaluation and checkpointing, both every 50 steps ---
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=1,
    load_best_model_at_end=True,
    # --- experiment tracking ---
    report_to="wandb",
    run_name="setfit_optimized_v2",
)

In [12]:
# Tie together the model, its training arguments and the train/eval splits.
trainer = Trainer(
    model=model,
    args=args,
    metric="accuracy",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # The trainer expects `text` and `label` columns: the text lives in `comment`,
    # while `label` already carries the expected name.
    column_mapping={"comment": "text", "label": "label"},
)

Applying column mapping to the training dataset
Applying column mapping to the evaluation dataset


Map:   0%|          | 0/323 [00:00<?, ? examples/s]

In [13]:
# Run training with the arguments and datasets configured on `trainer`;
# progress is tracked in W&B because `report_to` is set to 'wandb'.
trainer.train()

***** Running training *****
  Num examples = 2141
  Num epochs = 2
  Total optimization steps = 2350
  Total train batch size = 32
[34m[1mwandb[0m: Currently logged in as: [33mvionmatthieu[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.16.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.16.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240116_164354-wbqd5isp[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33msetfit_optimized_v2[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/vionmatthieu/setfit[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/vionmatthieu/setfit/runs/wbqd5isp[0m


Step,Training Loss,Validation Loss,Embedding Loss,Rate
50,No log,No log,0.2636,0.0
100,No log,No log,0.2611,0.0
150,No log,No log,0.2572,0.0
200,No log,No log,0.2546,0.0
250,No log,No log,0.2505,0.0
300,No log,No log,0.2473,0.0
350,No log,No log,0.2453,0.0
400,No log,No log,0.2425,0.0
450,No log,No log,0.2411,0.0
500,No log,No log,0.2387,0.0


  self.comm = Comm(**args)


  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

  0%|          | 0/397 [00:00<?, ?it/s]

Loading best SentenceTransformer model from step 1950.


Batches:   0%|          | 0/11 [00:00<?, ?it/s]

## Evaluate

In [14]:
# Score the trained model on the held-out eval split using the trainer's metric (accuracy).
metrics = trainer.evaluate(eval_dataset)
print(metrics)

Applying column mapping to the evaluation dataset
***** Running evaluation *****
  self.comm = Comm(**args)


Batches:   0%|          | 0/5 [00:00<?, ?it/s]

  self.comm = Comm(**args)


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

{'accuracy': 0.762589928057554}


## Save model to local dir

In [15]:
# Persist the trained model to a local directory through the public
# `save_pretrained` API rather than the private `_save_pretrained` hook,
# whose signature is an implementation detail and may change between releases.
save_directory = "/kaggle/working/"
trainer.model.save_pretrained(save_directory)

In [16]:
# Reload the model from the local directory and classify a few hand-written
# comments as a quick smoke test.
model = SetFitModel.from_pretrained(save_directory)

sample_comments = [
    "La Russie va gagner cette guerre, ils ont plus de ressources",
    "les journalistes sont corrompus, le traitement est partial",
    "les pauvres ukrainiens se font anéantir et subissent des crimes de guerre",
    "La France doit donner plus d'armes à l'ukraine",
]
preds = model.predict(sample_comments)
print(preds)

  self.comm = Comm(**args)


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

['pro_russia', 'other', 'pro_ukraine', 'pro_ukraine']


## Export model to huggingface hub

In [17]:
# Hugging Face Hub repository id ("<user>/<model name>") used both to push and to reload the model.
filepath_model = "gentilrenard/setfit-paraphrase-multi-mpnet-base-v2-lemonde"

In [18]:
# Optional: push the trained model to the Hugging Face Hub.

from kaggle_secrets import UserSecretsClient

# Read the Hub access token from the Kaggle secret named "hf_key" instead of hardcoding it.
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("hf_key") 

# Last expression of the cell: the returned repository URL is displayed as the cell output.
# NOTE(review): `use_auth_token` appears to be the legacy spelling of `token` in
# huggingface_hub — confirm against the installed version before renaming it.
trainer.push_to_hub(filepath_model, use_auth_token=hf_token)

  self.comm = Comm(**args)


Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model_head.pkl:   0%|          | 0.00/19.3k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

'https://huggingface.co/gentilrenard/setfit-paraphrase-multi-mpnet-base-v2-lemonde/tree/main/'

## Load from hub / inference

In [19]:
# Download the published model from the Hub; this rebinds `model`, replacing the locally loaded copy.
model = SetFitModel.from_pretrained(filepath_model)

config.json:   0%|          | 0.00/742 [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/21.4k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/742 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

config_setfit.json:   0%|          | 0.00/103 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

model_head.pkl:   0%|          | 0.00/19.3k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.34k [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_setfit.json:   0%|          | 0.00/103 [00:00<?, ?B/s]

model_head.pkl:   0%|          | 0.00/19.3k [00:00<?, ?B/s]

In [20]:
# Run inference with the Hub-loaded model on the same hand-written comments
# used for the local smoke test; the predictions should match.
hub_sample_comments = [
    "La Russie va gagner cette guerre, ils ont plus de ressources",
    "les journalistes sont corrompus, le traitement est partial",
    "les pauvres ukrainiens se font anéantir et subissent des crimes de guerre",
    "La France doit donner plus d'armes à l'ukraine",
]
preds = model.predict(hub_sample_comments)
print(preds)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

['pro_russia', 'other', 'pro_ukraine', 'pro_ukraine']


In [21]:
# best 72.66 but lower F1