# Few-Shot Sentiment Analysis with SetFit + ONNX Export

This notebook walks through:
1. **Training** a sentiment classifier with only ~10 examples per class using [SetFit](https://github.com/huggingface/setfit)
2. **Exporting** the trained model to [ONNX](https://onnx.ai/) for fast, portable inference
3. **Running inference** with ONNX Runtime — no PyTorch needed in production

**Requirements:** Python 3.9+, ~5 minutes on CPU

In [None]:
!pip install setfit datasets torch transformers onnxruntime

---
## What is SetFit?

**SetFit** (Sentence Transformer Fine-tuning) is a Hugging Face framework for **few-shot text classification**.

### Why SetFit over traditional fine-tuning?

| Feature | SetFit | Traditional Fine-tuning |
|---------|--------|------------------------|
| Training examples needed | **8–64 per class** | Hundreds to thousands |
| Training time | **1–5 min on CPU** | Hours on GPU |
| Prompts required | No | Sometimes |

### How it works (two phases)

1. **Contrastive fine-tuning** — Generates pairs of texts. Same-class pairs are pushed closer in embedding space; different-class pairs are pushed apart.
2. **Head training** — A logistic regression is trained on the resulting embeddings.

```
Text → [ Sentence Transformer ] → Embedding (384-dim) → [ Logistic Regression ] → Label
            (fine-tuned)                                    (trained on embeddings)
```
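
The pair-generation idea in phase 1 can be sketched in plain Python. This is a simplified illustration, not SetFit's actual implementation: the function name `generate_pairs`, the toy texts, and the sampling scheme (one same-class and one different-class partner per example per iteration) are assumptions made for the example.

```python
import random


def generate_pairs(texts, labels, num_iterations, seed=0):
    """Build (text_a, text_b, target) pairs: 1.0 = same class, 0.0 = different."""
    rng = random.Random(seed)
    indexed = list(enumerate(zip(texts, labels)))
    pairs = []
    for _ in range(num_iterations):
        for i, (text, label) in indexed:
            # Candidates from the same class, excluding the anchor itself
            same = [t for j, (t, l) in indexed if l == label and j != i]
            # Candidates from any other class
            diff = [t for _, (t, l) in indexed if l != label]
            pairs.append((text, rng.choice(same), 1.0))
            pairs.append((text, rng.choice(diff), 0.0))
    return pairs


# Toy data: 3 classes with 2 examples each (illustrative only)
texts = ["great", "love it", "awful", "hate it", "it arrived", "it is fine"]
labels = [0, 0, 1, 1, 2, 2]

pairs = generate_pairs(texts, labels, num_iterations=2)
print(len(pairs))  # 2 iterations * 6 examples * 2 pairs each
for text_a, text_b, target in pairs[:4]:
    print(f"{target:.0f}  {text_a!r} <-> {text_b!r}")
```

Under this scheme the number of pairs grows as `2 * num_iterations * len(texts)`, which is why a few dozen labeled examples can still yield a useful amount of contrastive training signal.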

In [None]:
import os
import torch
import pickle
import numpy as np
from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss

---
## 1. Prepare the Training Data

We only need **~10 examples per class** — that's the power of few-shot learning.

Our task: classify text as **positive** (0), **negative** (1), or **neutral** (2).

In [None]:
# -------------------------------------------------------------------
# Few-shot training data: 10 examples per class (30 total).
# SetFit generates contrastive pairs from these to learn the
# decision boundaries between classes.
# -------------------------------------------------------------------

training_data = {
    "text": [
        # --- Positive (label 0) ---
        "I absolutely love this product, it exceeded my expectations!",
        "The customer service was outstanding and very helpful.",
        "Best purchase I've made all year, highly recommend it.",
        "The quality is amazing for the price, very satisfied.",
        "This app is fantastic, it makes everything so much easier.",
        "I'm impressed by how well this works, great job!",
        "The food was delicious and the atmosphere was wonderful.",
        "Excellent experience from start to finish, will come back.",
        "This is exactly what I needed, works perfectly.",
        "I can't stop recommending this to all my friends.",

        # --- Negative (label 1) ---
        "Terrible product, broke after just two days of use.",
        "The worst customer service experience I've ever had.",
        "Complete waste of money, I want a full refund.",
        "The quality is awful, nothing like what was advertised.",
        "This app crashes constantly, it's unusable.",
        "I'm extremely disappointed with this purchase.",
        "The food was cold and the service was incredibly slow.",
        "Horrible experience, I will never shop here again.",
        "This doesn't work at all, total scam.",
        "I regret buying this, it's cheaply made junk.",

        # --- Neutral (label 2) ---
        "The package arrived on the expected delivery date.",
        "It works as described, nothing more nothing less.",
        "The product is okay, it serves its basic purpose.",
        "Standard quality for this price range.",
        "I received the item and it matches the description.",
        "It's an average product, does what it's supposed to do.",
        "The service was normal, no complaints or praise.",
        "Delivery was on time and the item was as expected.",
        "It's fine for everyday use, nothing special though.",
        "The product meets the basic requirements I had.",
    ],
    "label": [0] * 10 + [1] * 10 + [2] * 10,
}

# Human-readable label mapping
ID_TO_LABEL = {0: "positive", 1: "negative", 2: "neutral"}

# Convert to HuggingFace Dataset (required by SetFit)
train_dataset = Dataset.from_dict(training_data)

print(f"Training samples : {len(train_dataset)}")
print(f"Classes          : {list(ID_TO_LABEL.values())}")
print(f"Samples per class: {len(train_dataset) // len(ID_TO_LABEL)}")

---
## 2. Train the Model

We use `all-MiniLM-L6-v2` as the base sentence transformer — small (80 MB),
fast, and produces 384-dimensional embeddings.

**Key parameters:**
| Parameter | What it does |
|---|---|
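| `num_iterations` | How many rounds of contrastive pair generation to run. In each round, every training example is paired with one same-class and one different-class example, so more rounds give the model more pairs to learn from. |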
| `num_epochs` | Training passes over the generated pairs. 1 is usually enough for few-shot. |
| `batch_size` | Pairs processed per gradient step. |
| `loss_class` | `CosineSimilarityLoss` pulls same-class embeddings together and pushes different-class apart. |
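
To make the `loss_class` row concrete, here is a minimal pure-Python sketch of the quantity `CosineSimilarityLoss` minimizes with its default settings: the mean squared error between the cosine similarity of two embeddings and a target of 1.0 (same class) or 0.0 (different class). The helper names and the 2-dimensional toy vectors are illustrative assumptions; the real loss operates on batched tensors.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def cosine_similarity_loss(pairs):
    """Mean squared error between cosine similarity and the pair's target."""
    errors = [(cosine_similarity(u, v) - target) ** 2 for u, v, target in pairs]
    return sum(errors) / len(errors)


# Embeddings that already match their targets: loss is 0
well_separated = [
    ([1.0, 0.0], [2.0, 0.0], 1.0),  # same class, same direction
    ([1.0, 0.0], [0.0, 3.0], 0.0),  # different class, orthogonal
]

# Embeddings that contradict their targets: loss is 1
confused = [
    ([1.0, 0.0], [0.0, 3.0], 1.0),  # same class, but orthogonal
    ([1.0, 0.0], [2.0, 0.0], 0.0),  # different class, but same direction
]

good_loss = cosine_similarity_loss(well_separated)
bad_loss = cosine_similarity_loss(confused)
print(good_loss, bad_loss)
```

Gradient descent on this loss moves same-class embeddings toward a shared direction and different-class embeddings toward orthogonal ones.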

In [None]:
# Load a pre-trained sentence transformer.
# SetFit will fine-tune its weights so that texts with the same
# sentiment produce similar embeddings.
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

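# Note: setfit >= 1.0 deprecates SetFitTrainer in favor of Trainer +
# TrainingArguments; it still runs there but emits a deprecation warning.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    # Each iteration pairs every example with one positive and one negative:
    # 20 iterations * 30 examples * 2 = 1200 pairs
    num_iterations=20,
    num_epochs=1,
)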

# ~1-3 minutes on CPU
trainer.train()
print("Training complete!")

---
## 3. Test the Model

Evaluate on examples the model has **never** seen during training.

In [None]:
test_examples = [
    # Positive
    ("This restaurant has the best pasta I've ever tasted!", "positive"),
    ("I'm so happy with my new phone, it's incredible.", "positive"),
    ("The team did an amazing job on this project.", "positive"),
    # Negative
    ("The hotel room was dirty and the staff was rude.", "negative"),
    ("This software is full of bugs, very frustrating.", "negative"),
    ("I waited 3 hours and nobody helped me.", "negative"),
    # Neutral
    ("The meeting is scheduled for 3 PM tomorrow.", "neutral"),
    ("The store opens at 9 AM and closes at 6 PM.", "neutral"),
    ("I ordered the blue version of the product.", "neutral"),
]

correct = 0
print(f"{'Text':<55} | {'Predicted':<10} | {'Actual':<10} | OK?")
print("-" * 90)

for text, actual in test_examples:
    pred_id = model.predict([text])[0]
    predicted = ID_TO_LABEL[int(pred_id)]
    match = predicted == actual
    correct += match
    short = (text[:52] + "...") if len(text) > 52 else text
    print(f"{short:<55} | {predicted:<10} | {actual:<10} | {'Y' if match else 'N'}")

print(f"\nAccuracy: {correct}/{len(test_examples)} ({correct / len(test_examples) * 100:.0f}%)")
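
With only nine test sentences, a single accuracy number hides which class the errors fall in. The sketch below shows one way to get a per-class breakdown using only the standard library. The helper names are hypothetical and the predictions are made up for illustration; they are not output from the trained model.

```python
from collections import Counter


def confusion_counts(actual_labels, predicted_labels):
    """Count how often each (actual, predicted) combination occurs."""
    return Counter(zip(actual_labels, predicted_labels))


def per_class_recall(actual_labels, predicted_labels):
    """Fraction of each actual class that was predicted correctly."""
    counts = confusion_counts(actual_labels, predicted_labels)
    totals = Counter(actual_labels)
    return {label: counts[(label, label)] / total for label, total in totals.items()}


# Made-up labels for illustration only (not real model output)
actual = ["positive", "positive", "negative", "negative", "neutral", "neutral"]
predicted = ["positive", "positive", "negative", "neutral", "neutral", "neutral"]

recall = per_class_recall(actual, predicted)
for label, value in recall.items():
    print(f"{label:<10} recall: {value:.2f}")
```

In the notebook, the same helpers could be fed the `actual` labels from `test_examples` and the predicted label names collected in the loop above.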

---
## 4. Save the Trained Model

In [None]:
MODEL_DIR = "sentiment_setfit_model"
model.save_pretrained(MODEL_DIR)
print(f"Model saved to '{MODEL_DIR}/'")

# Reload later with:
# model = SetFitModel.from_pretrained(MODEL_DIR)

---

## What is ONNX?

**ONNX** (Open Neural Network Exchange) is an **open standard format** for representing
machine learning models. Think of it as a *"PDF for ML models"* — a universal file that
any compatible runtime can execute.

### The Problem

You train in **PyTorch**, but need to deploy on:
- A C++ backend server
- A mobile app (iOS / Android)
- A cloud function with minimal dependencies

Each framework uses its own format (PyTorch `.pt`, TensorFlow `.pb`, etc.).
ONNX bridges them all with **one universal `.onnx` file**.

### What's Inside an `.onnx` File?

```
+---------------------------------------------+
|              model.onnx                      |
|                                              |
|  +----------+   +-----------+   +----------+ |
|  |  Embed   |-->| Attention |-->|  Linear  | |
|  |  Lookup  |   |  Layers   |   |  Output  | |
|  +----------+   +-----------+   +----------+ |
|                                              |
|  + Trained weights (learned parameters)      |
|  + Input / output specs (shapes, dtypes)     |
+---------------------------------------------+
```

The file stores the **computation graph** (every operation: matmul, softmax,
layer norm, etc.) plus the trained weights.

### Why Use ONNX?

| Benefit | Details |
|---------|---------|
| **Speed** | ONNX Runtime applies graph optimizations (operator fusion, constant folding), which often makes CPU inference faster than eager PyTorch; the gain depends on the model, hardware, and batch size |
| **Portability** | Deploy anywhere: Linux, Windows, macOS, ARM, WebAssembly |
| **Small footprint** | Production only needs `onnxruntime` (~50 MB) instead of PyTorch (~2 GB) |
| **Hardware accel.** | Built-in CUDA, TensorRT, DirectML, OpenVINO, CoreML support |

### ONNX Runtime

**ONNX Runtime** is the inference engine that executes `.onnx` files. Developed
by Microsoft, it's battle-tested in production at Bing, Office, and Xbox.

---
## 5. Export to ONNX

A SetFit model has **two components**:

| Component | What it does | Size | Export strategy |
|-----------|-------------|------|----------------|
| **Body** (Sentence Transformer) | Text -> embedding vector | ~80 MB | Export to ONNX |
| **Head** (Logistic Regression) | Embedding -> label | ~1 KB | Pickle (sklearn) |

We export the **body** (where nearly all of the compute happens) to ONNX,
and pickle the lightweight head separately.

```
Text -> Tokenizer -> [ Transformer Body ] -> Mean Pool -> Normalize -> [ Head ] -> Label
                      ^^^^^^^^^^^^^^^^^^                                ^^^^^^
                      Export to ONNX                                    Pickle
```

In [None]:
from transformers import AutoModel, AutoTokenizer

# ===================================================================
# Step 1: Extract the raw transformer from the SetFit model
# ===================================================================
# SetFit's model_body is a SentenceTransformer containing:
#   [0] Transformer  — the HuggingFace model (does the heavy work)
#   [1] Pooling      — mean pooling over token embeddings
#   [2] Normalize    — L2 normalization (model-dependent)
#
# We need the raw transformer for ONNX export.

transformer_model = model.model_body[0].auto_model
tokenizer = model.model_body.tokenizer

# Save separately so we can reload cleanly
TRANSFORMER_DIR = "sentiment_transformer"
transformer_model.save_pretrained(TRANSFORMER_DIR)
tokenizer.save_pretrained(TRANSFORMER_DIR)

print(f"Transformer saved to '{TRANSFORMER_DIR}/'")

# ===================================================================
# Step 2: Export the transformer to ONNX
# ===================================================================
# torch.onnx.export works by:
#   1. Feeding a dummy input through the model
#   2. Tracing every operation (matmul, softmax, layernorm, ...)
#   3. Recording the traced graph + weights into a .onnx file

ONNX_PATH = "sentiment_model.onnx"

export_model = AutoModel.from_pretrained(TRANSFORMER_DIR)
export_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_DIR)
export_model.eval()

# ONNX export requires tuple outputs, not dicts
export_model.config.return_dict = False

# Dummy input for tracing (actual values don't matter, only shapes do)
dummy = export_tokenizer(
    "Sample sentence for tracing.",
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=128,
)

with torch.no_grad():
    torch.onnx.export(
        export_model,
        # Model inputs (positional args to forward())
        (dummy["input_ids"], dummy["attention_mask"]),
        # Output path
        ONNX_PATH,
        # Name the inputs/outputs so we can reference them later
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state", "pooler_output"],
        # dynamic_axes lets the model accept variable-length inputs
        # at runtime (not locked to the dummy input's shape)
        dynamic_axes={
            "input_ids":        {0: "batch_size", 1: "seq_len"},
            "attention_mask":   {0: "batch_size", 1: "seq_len"},
            "last_hidden_state": {0: "batch_size", 1: "seq_len"},
            "pooler_output":    {0: "batch_size"},
        },
        opset_version=14,
        do_constant_folding=True,  # fold constant ops at export time
    )

onnx_mb = os.path.getsize(ONNX_PATH) / (1024 * 1024)
print(f"ONNX model exported: {ONNX_PATH} ({onnx_mb:.1f} MB)")

# ===================================================================
# Step 3: Save the classification head (sklearn LogisticRegression)
# ===================================================================

HEAD_PATH = "sentiment_head.pkl"
with open(HEAD_PATH, "wb") as f:
    pickle.dump(model.model_head, f)

print(f"Head saved: {HEAD_PATH}")
print(f"\nFiles needed for production inference (no PyTorch!):")
print(f"  1. {ONNX_PATH}")
print(f"  2. {HEAD_PATH}")
print(f"  3. {TRANSFORMER_DIR}/  (tokenizer files)")
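
Pickling the head keeps scikit-learn as a production dependency, and pickle files can execute arbitrary code when loaded. One possible alternative is sketched below. It assumes the head is a fitted scikit-learn `LogisticRegression` with three or more classes, in which case `predict` reduces to an argmax over `X @ coef_.T + intercept_`. The function names and the small arrays standing in for `head.coef_`, `head.intercept_`, and `head.classes_` are hypothetical.

```python
import io

import numpy as np


def export_head(coef, intercept, classes, file):
    """Store the linear head's parameters as plain arrays (.npz format)."""
    np.savez(file, coef=coef, intercept=intercept, classes=classes)


def predict_with_exported_head(embeddings, file):
    """Multiclass linear prediction: argmax over per-class scores."""
    params = np.load(file)
    scores = embeddings @ params["coef"].T + params["intercept"]
    return params["classes"][scores.argmax(axis=1)]


# Hypothetical stand-ins for head.coef_, head.intercept_, head.classes_
coef = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])  # (n_classes, n_features)
intercept = np.zeros(3)
classes = np.array([0, 1, 2])

# An in-memory buffer keeps the example self-contained; a file path works too
buffer = io.BytesIO()
export_head(coef, intercept, classes, buffer)
buffer.seek(0)

embeddings = np.array([[0.9, 0.1], [0.1, 0.9], [-0.7, -0.7]])
predicted = predict_with_exported_head(embeddings, buffer)
print(predicted)
```

For a binary head, `coef_` has a single row and the decision rule is a sign check rather than an argmax, so this sketch would need adjusting.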

---
## 6. Inference with ONNX Runtime

Now we can classify text **without PyTorch**. Only `onnxruntime`, `transformers` (for the tokenizer),
`numpy`, and `scikit-learn` (to unpickle the head) are needed.

The pipeline reproduces what the SentenceTransformer does internally:

```
1. Tokenize          text -> input_ids, attention_mask   (AutoTokenizer)
2. Transformer       input_ids -> hidden_states          (ONNX Runtime)
3. Mean pooling      hidden_states -> sentence_embedding (NumPy)
4. L2 normalize      normalize the embedding             (NumPy)
5. Classify          embedding -> label                  (sklearn head)
```
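
Steps 3 and 4 are the parts reimplemented by hand, so they are worth checking in isolation. The toy example below is a sketch under assumed shapes (one sentence, 4-dimensional hidden states) and uses a hypothetical helper name. It shows that masked mean pooling gives the same embedding whether or not the sentence is padded.

```python
import numpy as np


def mean_pool_and_normalize(hidden_states, attention_mask):
    """Average token vectors where mask == 1, then scale to unit length."""
    mask = attention_mask[..., None].astype(np.float32)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)          # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)       # avoid division by zero
    embeddings = summed / counts
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.maximum(norms, 1e-12)


# One sentence with two real tokens
tokens = np.array([[[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]]], dtype=np.float32)
mask = np.array([[1, 1]])

# The same sentence padded to length 4 with junk vectors that must be ignored
junk = np.full((1, 2, 4), 99.0, dtype=np.float32)
padded_tokens = np.concatenate([tokens, junk], axis=1)
padded_mask = np.array([[1, 1, 0, 0]])

unpadded = mean_pool_and_normalize(tokens, mask)
padded = mean_pool_and_normalize(padded_tokens, padded_mask)

print(unpadded)
print(np.allclose(unpadded, padded))
```

Without the mask, the padding vectors would be averaged into the embedding and predictions would change with batch composition.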

In [None]:
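import pickle
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Redefined here so this cell also works as a standalone script
ID_TO_LABEL = {0: "positive", 1: "negative", 2: "neutral"}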

# -------------------------------------------------------------------
# Load the three inference components
# -------------------------------------------------------------------

# 1. Tokenizer (text -> token IDs)
tokenizer = AutoTokenizer.from_pretrained("sentiment_transformer")

# 2. ONNX session (token IDs -> hidden states)
session = ort.InferenceSession(
    "sentiment_model.onnx",
    # For GPU, use "CUDAExecutionProvider" (requires the onnxruntime-gpu package)
    providers=["CPUExecutionProvider"],
)

# 3. Classification head (embedding -> label)
with open("sentiment_head.pkl", "rb") as f:
    head = pickle.load(f)


# -------------------------------------------------------------------
# Full inference function
# -------------------------------------------------------------------

def predict_sentiment(texts):
    """Predict sentiment using the ONNX-exported model."""

    # 1) Tokenize — return NumPy arrays (no PyTorch needed)
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="np",
    )

    # 2) Run the ONNX transformer
    outputs = session.run(
        None,  # None = return all outputs
        {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        },
    )
    hidden_states = outputs[0]  # (batch, seq_len, 384)

    # 3) Mean pooling — average token vectors, ignoring padding
    mask = np.expand_dims(inputs["attention_mask"], axis=-1).astype(np.float32)
    embeddings = (hidden_states * mask).sum(axis=1) / mask.sum(axis=1)  # (batch, 384)

    # 4) L2 normalize — same as the SentenceTransformer's Normalize module
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / np.maximum(norms, 1e-12)

    # 5) Classify with the sklearn head
    predictions = head.predict(embeddings)
    return [ID_TO_LABEL[int(p)] for p in predictions]


# -------------------------------------------------------------------
# Test it
# -------------------------------------------------------------------

test_texts = [
    "This movie was absolutely fantastic, I loved every minute!",
    "The product broke on the first day, total garbage.",
    "The meeting has been moved to Tuesday at 2 PM.",
    "I'm so grateful for the amazing support team!",
    "Worst experience ever, never coming back.",
]

results = predict_sentiment(test_texts)

for text, sentiment in zip(test_texts, results):
    print(f"  {sentiment:>8}  |  {text}")

---
## 7. Benchmark: PyTorch vs ONNX Runtime

Let's measure the speed difference on a batch of 150 texts. Treat the result as a rough indication: it depends on hardware, thread settings, and batch size.

In [None]:
import time

benchmark_texts = [
    "The service was exceptional and the staff was very friendly.",
    "I regret purchasing this item, it doesn't work as advertised.",
    "The order arrived on the scheduled delivery date.",
] * 50  # 150 texts

# --- PyTorch (SetFit) ---
setfit_model = SetFitModel.from_pretrained("sentiment_setfit_model")

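# Warm up both backends so one-time setup cost is excluded from the timing
_ = setfit_model.predict(benchmark_texts[:3])
_ = predict_sentiment(benchmark_texts[:3])

start = time.perf_counter()
_ = setfit_model.predict(benchmark_texts)
pytorch_time = time.perf_counter() - start

# --- ONNX Runtime ---
start = time.perf_counter()
_ = predict_sentiment(benchmark_texts)
onnx_time = time.perf_counter() - start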

# --- Results ---
print(f"Texts processed : {len(benchmark_texts)}")
print(f"PyTorch (SetFit) : {pytorch_time:.3f}s  ({len(benchmark_texts) / pytorch_time:.0f} texts/sec)")
print(f"ONNX Runtime     : {onnx_time:.3f}s  ({len(benchmark_texts) / onnx_time:.0f} texts/sec)")
print(f"Speedup          : {pytorch_time / onnx_time:.1f}x")
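
A single timed call is sensitive to noise from caches, thread start-up, and other processes. The helper below is a small sketch for more stable measurements: it runs a few untimed warm-up calls and reports the median of several timed runs. `time_call` and `fake_predict` are hypothetical names, and `fake_predict` is a stand-in workload, not either of the real models.

```python
import statistics
import time


def time_call(fn, *args, warmup=1, repeats=5):
    """Return the median wall-clock seconds of fn(*args) over several runs."""
    for _ in range(warmup):
        fn(*args)  # untimed: absorbs one-time setup costs
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


def fake_predict(texts):
    """Stand-in workload so the example runs without any model."""
    return [len(text) for text in texts]


median_seconds = time_call(fake_predict, ["short text", "another text"] * 75)
print(f"median: {median_seconds:.6f}s")
```

In this notebook the same helper could be called as `time_call(setfit_model.predict, benchmark_texts)` and `time_call(predict_sentiment, benchmark_texts)` to compare the two backends on equal terms.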

---
## Summary

| Step | What we did |
|------|------------|
| **1. Data** | Created 30 labeled examples (10 per class) |
| **2. Train** | Fine-tuned `all-MiniLM-L6-v2` with SetFit in ~2 minutes on CPU |
| **3. Export** | Used `torch.onnx.export` to convert the transformer body to `.onnx` |
| **4. Inference** | Ran the ONNX model with `onnxruntime` — no PyTorch required |
| **5. Benchmark** | Measured the speedup from ONNX Runtime optimizations |

### Production deployment checklist

You only need **three artifacts** plus four packages:

```
Files:
  sentiment_model.onnx          # transformer body
  sentiment_head.pkl            # classification head
  sentiment_transformer/        # tokenizer files

Packages:
  pip install onnxruntime transformers scikit-learn numpy
```

No PyTorch (~2 GB) needed in production.
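
A service can fail fast at start-up if any artifact from the checklist is missing. The check below is a minimal sketch that assumes all three artifacts sit in one directory. `REQUIRED_ARTIFACTS` and `missing_artifacts` are hypothetical names introduced for illustration.

```python
from pathlib import Path

# Artifact names produced earlier in this notebook
REQUIRED_ARTIFACTS = [
    "sentiment_model.onnx",
    "sentiment_head.pkl",
    "sentiment_transformer",
]


def missing_artifacts(base_dir="."):
    """Return the required artifact names that do not exist under base_dir."""
    base = Path(base_dir)
    return [name for name in REQUIRED_ARTIFACTS if not (base / name).exists()]


missing = missing_artifacts()
if missing:
    print("Missing artifacts:", ", ".join(missing))
else:
    print("All artifacts present.")
```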

### Alternative: `optimum` (one-liner export)

For a higher-level approach, Hugging Face's [optimum](https://huggingface.co/docs/optimum) library
can export and run ONNX models with less manual work:

```python
from optimum.onnxruntime import ORTModelForFeatureExtraction

ort_model = ORTModelForFeatureExtraction.from_pretrained(
    "sentiment_transformer", export=True
)
ort_model.save_pretrained("sentiment_onnx")
```

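This handles the export and dynamic axes for the transformer body. Mean pooling, normalization, and the
classification head still have to be applied to its output, and the export step itself requires PyTorch
and the `optimum[onnxruntime]` extra to be installed.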