<a href="https://colab.research.google.com/github/mrm8488/shared_colab_notebooks/blob/master/HF_SetFit_spam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SetFit for Text Classification | Spam or Ham

SetFit is an efficient, prompt-free framework for few-shot fine-tuning of [Sentence Transformers](https://sbert.net/). It achieves high accuracy with little labeled data: for instance, with only 8 labeled examples per class on the Customer Reviews (CR) sentiment dataset, SetFit is competitive with fine-tuning RoBERTa Large on the full training set of 3k examples 🤯!

- [Post](https://huggingface.co/blog/setfit)

- [Paper](https://arxiv.org/abs/2209.11055)

> Created by [Manu Romero](https://twitter.com/mrm8488), [Narrativa](https://narrativa.com/) NLP/G Senior Engineer and Hugging Face 🤗 Fellow and Ambassador.

In [1]:
! nvidia-smi

Mon Oct 10 17:43:21 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
! pip install -q transformers datasets setfit

In [3]:
! huggingface-cli login


        To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/tokens .
        
Token: 
Login successful
Your token has been saved to /root/.huggingface/token
[1m[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in yo

In [3]:
! git config --global credential.helper store

In [4]:
from datasets import load_dataset, concatenate_datasets
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer

In [5]:
dataset = load_dataset("sms_spam")

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['sms', 'label'],
        num_rows: 5574
    })
})

In [7]:
dataset = dataset['train'].train_test_split(test_size=0.15, shuffle=True, seed=49)



In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['sms', 'label'],
        num_rows: 4737
    })
    test: Dataset({
        features: ['sms', 'label'],
        num_rows: 837
    })
})
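
The split sizes above follow from rounding the 15% test fraction up: ceil(0.15 × 5574) = 837 test messages, leaving 5574 − 837 = 4737 for training. Because this split is random rather than stratified, it is worth checking how many spam messages each side received. The sketch below assumes scikit-learn-style rounding (test size rounded up) and uses hypothetical helpers, `split_sizes` and `label_distribution`, on plain Python lists; in the notebook you would pass `dataset["train"]["label"]`. Recent `datasets` releases also accept `stratify_by_column="label"` in `train_test_split` if you want the class proportions preserved.

```python
import math
from collections import Counter

def split_sizes(n_samples, test_size):
    """Assumed rounding: test size rounded up, train gets the remainder."""
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test

def label_distribution(labels, names=("ham", "spam")):
    """Return {class name: (count, fraction)} for a list of integer labels."""
    counts = Counter(labels)
    total = len(labels)
    return {names[i]: (counts[i], counts[i] / total) for i in range(len(names))}

print(split_sizes(5574, 0.15))  # (4737, 837), matching the DatasetDict above
# Toy labels; in the notebook you would pass dataset["train"]["label"] instead
print(label_distribution([0, 0, 0, 1]))  # {'ham': (3, 0.75), 'spam': (1, 0.25)}
```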

In [9]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    """Display `num_examples` distinct random rows, decoding ClassLabel ids into names."""
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Draw distinct row indices without replacement
    picks = random.sample(range(len(dataset)), k=num_examples)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [10]:
show_random_elements(dataset["train"])

|   | sms | label |
|---|-----|-------|
| 0 | Did I forget to tell you ? I want you , I need you, I crave you ... But most of all ... I love you my sweet Arabian steed ... Mmmmmm ... Yummy | ham |
| 1 | Just wondering, the others just took off | ham |
| 2 | Okay same with me. Well thanks for the clarification | ham |
| 3 | My fri ah... Okie lor,goin 4 my drivin den go shoppin after tt... | ham |
| 4 | Cool, want me to go to kappa or should I meet you outside mu | ham |


In [11]:
# Select N examples per class (N = 8 here); in sms_spam, label 0 = ham and label 1 = spam
shuffled_train = dataset["train"].shuffle(seed=49)
train_ds_ham = shuffled_train.filter(lambda example: example["label"] == 0).select(range(8))
train_ds_spam = shuffled_train.filter(lambda example: example["label"] == 1).select(range(8))
train_ds = concatenate_datasets([train_ds_ham, train_ds_spam])
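
The cell above builds the few-shot training set by shuffling once, keeping the first 8 examples of each label, and concatenating the two groups. A plain-Python analogue makes the logic explicit; `sample_per_class` is a hypothetical helper, and because it uses Python's own RNG it will not pick the same messages as `Dataset.shuffle(seed=49)`.

```python
import random
from collections import defaultdict

def sample_per_class(examples, n_per_class, seed=49):
    """Pick n_per_class (text, label) pairs per label after one seeded shuffle."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    buckets = defaultdict(list)
    for text, label in shuffled:
        # Keep only the first n_per_class examples seen for each label
        if len(buckets[label]) < n_per_class:
            buckets[label].append((text, label))
    for label, picked in buckets.items():
        if len(picked) < n_per_class:
            raise ValueError(f"label {label} has only {len(picked)} examples")
    # Concatenate class by class, like concatenate_datasets([ham, spam])
    return [ex for label in sorted(buckets) for ex in buckets[label]]

toy = [(f"ham {i}", 0) for i in range(10)] + [(f"spam {i}", 1) for i in range(5)]
few_shot = sample_per_class(toy, n_per_class=2)
print(len(few_shot), [label for _, label in few_shot])  # 4 [0, 0, 1, 1]
```

Raising an error when a class is too small mirrors what `.select(range(8))` would do if a label had fewer than 8 examples after filtering.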



In [12]:
train_ds = train_ds.shuffle(seed=21)



In [13]:
show_random_elements(train_ds)

|   | sms | label |
|---|-----|-------|
| 0 | FreeMsg: Txt: CALL to No: 86888 & claim your reward of 3 hours talk time to use from your phone now! Subscribe6GBP/mnth inc 3hrs 16 stop?txtStop | spam |
| 1 | EASTENDERS TV Quiz. What FLOWER does DOT compare herself to? D= VIOLET E= TULIP F= LILY txt D E or F to 84025 NOW 4 chance 2 WIN £100 Cash WKENT/150P16+ | spam |
| 2 | Valentines Day Special! Win over £1000 in our quiz and take your partner on the trip of a lifetime! Send GO to 83600 now. 150p/msg rcvd. CustCare:08718720201 | spam |
| 3 | Dear good morning how you feeling dear | ham |
| 4 | Thanx 4 sending me home... | ham |


In [14]:
# SetFitTrainer expects "text" and "label" columns by default, so rename "sms"
train_ds = train_ds.rename_column('sms', 'text')
test_ds = dataset['test'].rename_column('sms', 'text')

In [15]:
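# Load a pretrained Sentence Transformer from the Hub as the SetFit body
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create the trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,           # 16 labeled examples (8 ham, 8 spam)
    eval_dataset=test_ds,             # full 837-message test split
    loss_class=CosineSimilarityLoss,  # contrastive loss on sentence pairs
    batch_size=16,
    num_epochs=20,                    # epochs of contrastive fine-tuning of the body
)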

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
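
The warning is expected: the classification head (a scikit-learn logistic regression by default in SetFit) is only fitted after the contrastive stage. During that stage, `CosineSimilarityLoss` from sentence-transformers embeds both sentences of a pair, computes their cosine similarity, and regresses it onto the pair's target with mean squared error. With SetFit's targets of 1.0 for same-class pairs and 0.0 for different-class pairs, the body learns to pull ham messages together and push them away from spam. The plain-Python sketch below reproduces only that computation on toy 2-d vectors standing in for MPNet embeddings; it omits batching and gradients.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def cosine_similarity_loss(pairs):
    """Mean squared error between cos(u, v) and each pair's target (1.0 same class, 0.0 different)."""
    errors = [(cosine_similarity(u, v) - target) ** 2 for u, v, target in pairs]
    return sum(errors) / len(errors)

# Toy 2-d "embeddings"; real ones come from the sentence encoder
pairs = [
    ([1.0, 0.0], [1.0, 0.0], 1.0),  # same class, identical direction: error 0
    ([1.0, 0.0], [0.0, 1.0], 0.0),  # different class, orthogonal: error 0
    ([1.0, 0.0], [1.0, 1.0], 0.0),  # different class but cos ≈ 0.707: penalised
]
print(round(cosine_similarity_loss(pairs), 4))  # 0.1667
```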


In [16]:
# Train the model (evaluation runs in the next cell)
trainer.train()

***** Running training *****
  Num examples = 640
  Num epochs = 20
  Total optimization steps = 40
  Total train batch size = 16
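
The log reports 640 examples although only 16 labeled messages were passed in, because SetFit fine-tunes the body on generated sentence pairs. Assuming `SetFitTrainer`'s default `num_iterations=20`, each of the 16 messages yields one same-class and one different-class pair per iteration: 16 × 20 × 2 = 640 pairs, and 640 / 16 = 40 batches per epoch, which matches the 40 optimization steps in the log. The sketch below is modeled on that pair-generation scheme; `generate_pairs` is a hypothetical helper, and unlike some SetFit versions it never pairs a message with itself.

```python
import random

def generate_pairs(texts, labels, num_iterations, seed=0):
    """For every example and iteration, emit one positive and one negative pair."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_iterations):
        for i, (text, label) in enumerate(zip(texts, labels)):
            # Other examples sharing the anchor's label (excluding the anchor itself)
            same = [j for j, other in enumerate(labels) if other == label and j != i]
            # Examples with a different label
            diff = [j for j, other in enumerate(labels) if other != label]
            pairs.append((text, texts[rng.choice(same)], 1.0))  # positive pair
            pairs.append((text, texts[rng.choice(diff)], 0.0))  # negative pair
    return pairs

texts = [f"ham {i}" for i in range(8)] + [f"spam {i}" for i in range(8)]
labels = [0] * 8 + [1] * 8
pairs = generate_pairs(texts, labels, num_iterations=20)
batch_size = 16
print(len(pairs), len(pairs) // batch_size)  # 640 40
```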



In [17]:
trainer.evaluate()

***** Running evaluation *****


{'accuracy': 0.9438470728793309}
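
An accuracy of 0.9438 corresponds to 790 of the 837 test messages (790 / 837 ≈ 0.94385). Spam is the minority class (747 of the 5,574 messages in the SMS Spam Collection, about 13%), so a classifier that always answers "ham" would already reach roughly 87% accuracy on the full collection; per-class precision and recall say more. The stdlib sketch below, built around a hypothetical `binary_metrics` helper, computes them for the spam class from true and predicted labels.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Accuracy plus precision/recall/F1 for the positive ('spam') class."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(y_true), "precision": precision, "recall": recall, "f1": f1}

# Toy example: a classifier that always answers "ham" (0)
y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]
y_pred = [0] * 10
print(binary_metrics(y_true, y_pred))
# {'accuracy': 0.8, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
```

To apply it here, one option (assuming the `SetFitModel.predict` method of the version used in this notebook) is `preds = trainer.model.predict(test_ds["text"]).tolist()` followed by `binary_metrics(test_ds["label"], preds)`.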

In [18]:
! git config --global user.email "user@email.com"

In [19]:
trainer.push_to_hub("setfit-mpnet-base-v2-finetuned-spam-detection")

Cloning https://huggingface.co/mrm8488/setfit-mpnet-base-v2-finetuned-spam-detection into local empty directory.


remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/mrm8488/setfit-mpnet-base-v2-finetuned-spam-detection
   8a485d3..80bae7a  main -> main




'https://huggingface.co/mrm8488/setfit-mpnet-base-v2-finetuned-spam-detection/commit/80bae7a965d19b1001c14d8a3be72bccf91b995e'