# 📘 Tutorial: Conformal Prediction for Text Classification with PUNCC

In this tutorial, you will discover how to use the **PUNCC** library for uncertainty quantification on an NLP classification task. The example uses the **AG News** dataset and the **DistilBERT** model.

By the end of this notebook, you will be able to transform your own NLP classification models into conformal predictors and evaluate their performance effectively.

⚡ If you are only interested in the [📏 conformal text classification](#cr-conformal) section, you can execute all cells up to that point and skip the details about data loading, preprocessing and model training.

-------

**Table of contents**

- [⚙️ Setup](#cr-setup)
- [📚 Dataset, Model, Tokenizer](#cr-data)
- [🧼 Preprocessing](#cr-preprocessing)
- [🚀 Training](#cr-training)
- [📏 Conformal Text Classification](#cr-conformal)

**Links**
- [<img src="https://github.githubassets.com/images/icons/emoji/octocat.png" width=20> Github](https://github.com/deel-ai/puncc)
- [📘 Documentation](https://deel-ai.github.io/puncc/index.html)

# ⚙️ Setup: Install and Import Libraries <a class="anchor" id="cr-setup"></a>
In addition to **PUNCC**, we will be using the following libraries from HuggingFace:
- **transformers**: Provides pre-trained NLP models and training utilities.
- **datasets**: For downloading benchmark datasets like AG News.

In [None]:
!pip install puncc datasets transformers[torch]



We import the general-purpose modules that will be used throughout the tutorial:

In [61]:
%load_ext autoreload
%autoreload 2

import warnings
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict

warnings.filterwarnings("ignore")

# 📚 Dataset, Model, Tokenizer <a class="anchor" id="cr-data"></a>
We load the **AG News** dataset, a 4-class dataset of short news articles labeled as World, Sports, Business, or Sci/Tech.
We also load a pretrained **DistilBERT** base model (with a newly initialized 4-way classification head) and its tokenizer.

In [62]:
dataset = load_dataset("ag_news")
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# 🧼 Preprocess the Data  <a class="anchor" id="cr-preprocessing"></a>
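We split the original training split into a proper training set (20,000 examples) and a disjoint calibration set (5,000 examples), which will be used to conformalize the model. The original test split is kept for evaluation.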

In [63]:
# Split train dataset: 20k train, 5k calibration
split_dataset = dataset["train"].train_test_split(test_size=5_000, train_size=20_000, seed=42)

# Rename keys
split_dataset = DatasetDict({
    "train": split_dataset["train"],
    "calib": split_dataset["test"],
    "test": dataset["test"]
})
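Why does the size of the calibration set matter? In split conformal prediction, the threshold used to build prediction sets is the ⌈(n+1)(1−α)⌉-th smallest of the n calibration scores. The sketch below uses a hypothetical helper, `conformal_rank` (not part of PUNCC), to compute this rank for the 5,000-example calibration set and the `alpha=0.01` used later in this tutorial. It follows the textbook split-conformal formula; PUNCC's internal quantile computation may differ in minor details.

```python
import math

def conformal_rank(n_calib: int, alpha: float) -> int:
    """Rank (1-indexed) of the calibration score used as the conformal threshold.

    Hypothetical helper implementing the textbook split-conformal rule
    k = ceil((n + 1) * (1 - alpha)). If k > n, the threshold is +infinity,
    i.e. every label ends up in every prediction set.
    """
    return math.ceil((n_calib + 1) * (1 - alpha))

n_calib, alpha = 5_000, 0.01
k = conformal_rank(n_calib, alpha)
print(k, k / n_calib)  # rank of the threshold and its empirical quantile level

# Smallest calibration set for which the threshold is finite at alpha = 0.01
min_n = next(n for n in range(1, 1_000) if conformal_rank(n, alpha) <= n)
print(min_n)
```

In words: with 5,000 calibration points the threshold is the 4,951st smallest score, and fewer than 99 calibration examples would make every prediction set contain all four classes at this α.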

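The preprocessing cell below:
- Tokenizes the text, padding or truncating every example to exactly 128 tokens.
- Renames the `label` column to `labels` (the argument name the model's forward pass expects) and formats the dataset as PyTorch tensors for use with `Trainer`.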

In [64]:
def preprocess(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

encoded_dataset = split_dataset.map(preprocess, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
encoded_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
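To make `padding="max_length"` and `truncation=True` concrete, here is a toy, tokenizer-free sketch of what they do to a list of token ids. The function `pad_or_truncate` is hypothetical, and it assumes a padding id of 0 (the `[PAD]` id in the BERT uncased vocabulary used by DistilBERT). Unlike the real tokenizer, it truncates naively and does not keep the final `[SEP]` token.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), each with exactly max_length entries.

    Hypothetical illustration: longer sequences are cut, shorter ones are
    right-padded with pad_id. The attention mask is 1 for real tokens and
    0 for padding, so the model ignores padded positions.
    """
    ids = list(token_ids[:max_length])
    mask = [1] * len(ids)
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad

# 101 and 102 are [CLS] and [SEP]; the other ids are arbitrary placeholders.
print(pad_or_truncate([101, 7592, 102], 5))           # padded to length 5
print(pad_or_truncate([101, 1, 2, 3, 4, 5, 102], 5))  # truncated to length 5
```

Fixed-length padding wastes some compute on short headlines, but it gives every batch the same tensor shape, which keeps the setup simple.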


# 🚀 Training <a class="anchor" id="cr-training"></a>
We set up the training loop using Hugging Face's high-level `Trainer` API, and we train for 1 epoch.

In [65]:
# Define Trainer
training_args = TrainingArguments(
    output_dir="./news_results",
    report_to="none",
    eval_strategy="epoch",  # named `evaluation_strategy` in older transformers releases
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    logging_dir="./logs",
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
)

trainer.train()

| Epoch | Training Loss | Validation Loss |
|------:|--------------:|----------------:|
| 1 | 0.2682 | 0.246922 |


TrainOutput(global_step=1250, training_loss=0.3050960479736328, metrics={'train_runtime': 272.3826, 'train_samples_per_second': 73.426, 'train_steps_per_second': 4.589, 'total_flos': 662360616960000.0, 'train_loss': 0.3050960479736328, 'epoch': 1.0})
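As a sanity check, the `global_step=1250` in the training output follows directly from the configuration: 20,000 training examples in batches of 16 give 1,250 optimizer steps for one epoch. The sketch below assumes a single device and no gradient accumulation (neither is set in `TrainingArguments` above); on several GPUs, the effective batch size would be multiplied by the device count.

```python
import math

n_train = 20_000        # size of the proper training set
per_device_batch = 16   # per_device_train_batch_size
n_devices = 1           # assumption: a single GPU or CPU
grad_accum = 1          # assumption: default gradient_accumulation_steps
epochs = 1

steps_per_epoch = math.ceil(n_train / (per_device_batch * n_devices * grad_accum))
total_steps = steps_per_epoch * epochs
print(total_steps)

# Throughput: samples processed per second of training runtime.
train_runtime = 272.3826  # seconds, taken from the TrainOutput above
print(round(n_train * epochs / train_runtime, 3))
```

Both numbers match the `global_step` and `train_samples_per_second` values reported by `Trainer`.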

Next, we evaluate the model on the test set using standard accuracy and per-class classification metrics.

In [66]:
preds_output = trainer.predict(encoded_dataset["test"])
preds = preds_output.predictions.argmax(axis=1)
labels = preds_output.label_ids

print("Accuracy:", accuracy_score(labels, preds))
print(classification_report(labels, preds))


Accuracy: 0.9181578947368421
              precision    recall  f1-score   support

           0       0.95      0.89      0.92      1900
           1       0.96      0.99      0.97      1900
           2       0.87      0.89      0.88      1900
           3       0.89      0.91      0.90      1900

    accuracy                           0.92      7600
   macro avg       0.92      0.92      0.92      7600
weighted avg       0.92      0.92      0.92      7600
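For reference, the per-class precision and recall reported by `classification_report` can be recomputed from the true and predicted labels alone. The pure-Python sketch below uses a hypothetical helper, `per_class_precision_recall`, on a toy example; on the tutorial's data, you would pass `labels` and `preds` from the cell above.

```python
from collections import Counter

def per_class_precision_recall(y_true, y_pred, n_classes):
    """Return {class: (precision, recall)} computed from label lists."""
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # true positives
    pred_count = Counter(y_pred)  # denominator of precision
    true_count = Counter(y_true)  # denominator of recall (the "support")
    out = {}
    for c in range(n_classes):
        precision = tp[c] / pred_count[c] if pred_count[c] else 0.0
        recall = tp[c] / true_count[c] if true_count[c] else 0.0
        out[c] = (precision, recall)
    return out

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(per_class_precision_recall(y_true, y_pred, n_classes=3))
```

Precision answers "when the model predicts class c, how often is it right?", while recall answers "how many true class-c articles does it find?".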



We apply the softmax function to the raw logits of the calibration and test sets to convert them into class probabilities. These softmax scores are the inputs used for conformal calibration and prediction.

In [67]:
calib_preds = trainer.predict(encoded_dataset["calib"])
calib_logits, calib_labels = calib_preds.predictions, calib_preds.label_ids
calib_softmax = F.softmax(torch.tensor(calib_logits), dim=1).numpy()

test_preds = trainer.predict(encoded_dataset["test"])
test_logits, test_labels = test_preds.predictions, test_preds.label_ids
test_softmax = F.softmax(torch.tensor(test_logits), dim=1).numpy()
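The row-wise computation performed by `F.softmax(..., dim=1)` can be written out in a few lines. The pure-Python sketch below (hypothetical helper `softmax_rows`) subtracts the row maximum before exponentiating: this avoids overflow and leaves the result unchanged, because softmax is invariant to adding the same constant to every logit.

```python
import math

def softmax_rows(logits):
    """Row-wise softmax: turn each row of raw logits into probabilities."""
    probs = []
    for row in logits:
        m = max(row)                           # shift for numerical stability
        exps = [math.exp(z - m) for z in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs

# Shifting all logits by +1000 gives the same probabilities;
# without the max-shift, math.exp(1002.0) would overflow.
print(softmax_rows([[2.0, 1.0, 0.1, -1.0]]))
print(softmax_rows([[1002.0, 1001.0, 1000.1, 999.0]]))
```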

# 📏 Conformal Text Classification with PUNCC <a class="anchor" id="cr-conformal"></a>
We will conformalize our model using the LAC (Least Ambiguous set-valued Classifier) algorithm. We proceed as follows:
- **Instantiate the `IdPredictor` dummy wrapper**. The model is already trained and its softmax outputs are already computed, so PUNCC does not need to train or call it. The `IdPredictor` wrapper simply returns its inputs as predictions.
- **Instantiate the `LAC` conformalizer**. We instantiate the `LAC` conformalizer with the dummy predictor and `train=False`, so that no model training happens during calibration.
- **Fit the conformal predictor.** `fit()` calibrates the conformal predictor, i.e., computes the LAC nonconformity scores (one minus the softmax probability of the true class) on the calibration set.
- **Predict**. `predict()` returns both the original (softmax) model predictions and the **prediction sets** obtained with LAC. With `alpha=0.01`, the sets target a coverage of at least 99%.

In [68]:
from deel.puncc.api.prediction import IdPredictor
from deel.puncc.classification import LAC

dummy_predictor = IdPredictor()
lac_cp = LAC(dummy_predictor, train=False)
lac_cp.fit(X_calib=calib_softmax, y_calib=calib_labels)

y_pred, y_pred_set = lac_cp.predict(X_test=test_softmax, alpha=0.01)
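To see what `fit()` and `predict()` compute, here is a from-scratch sketch of the standard LAC recipe on toy softmax outputs. The nonconformity score of a calibration example is one minus the softmax probability of its true class; the threshold is the ⌈(n+1)(1−α)⌉-th smallest calibration score; and a test prediction set keeps every class whose score does not exceed that threshold. The helper names `lac_fit` and `lac_predict_sets` are hypothetical, and PUNCC's own quantile computation may differ in minor details.

```python
import math

def lac_fit(calib_probs, calib_labels, alpha):
    """Return the LAC threshold q_hat from calibration softmax scores."""
    # Nonconformity score: 1 - probability assigned to the true class.
    scores = sorted(1.0 - p[y] for p, y in zip(calib_probs, calib_labels))
    n = len(scores)
    k = math.ceil((n + 1) * (1 - alpha))  # finite-sample corrected rank
    return math.inf if k > n else scores[k - 1]

def lac_predict_sets(test_probs, q_hat):
    """Keep every class whose score 1 - p_k is at most q_hat."""
    return [[k for k, p_k in enumerate(p) if 1.0 - p_k <= q_hat] for p in test_probs]

calib_probs = [[0.9, 0.05, 0.05], [0.6, 0.3, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
calib_labels = [0, 1, 1, 2]
test_probs = [[0.5, 0.35, 0.15], [0.95, 0.03, 0.02]]

q_hat = lac_fit(calib_probs, calib_labels, alpha=0.25)
print(q_hat, lac_predict_sets(test_probs, q_hat))
```

A smaller α raises the rank k, hence the threshold, hence the set sizes; with too few calibration points (here α = 0.1 with n = 4), k exceeds n and every class is returned. Conversely, LAC sets can be empty when no class clears a low threshold.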

Finally, we compute two standard conformal metrics on the test set: the empirical coverage (the fraction of test samples whose true label lies in its prediction set) and the average size (cardinality) of the prediction sets.

In [69]:
from deel.puncc import metrics

mean_coverage = metrics.classification_mean_coverage(test_labels, y_pred_set)
mean_size = metrics.classification_mean_size(y_pred_set)

print(f"Empirical coverage : {mean_coverage:.3f}")
print(f"Average set size : {mean_size:.3f}")

Empirical coverage : 0.992
Average set size : 1.581
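Both metrics are simple averages over the test set. The sketch below uses hypothetical helpers, `mean_coverage` and `mean_set_size`, on toy data, with prediction sets given as lists of label indices like `y_pred_set`; applied to `test_labels` and `y_pred_set`, they should match the PUNCC values above.

```python
def mean_coverage(y_true, pred_sets):
    """Fraction of samples whose true label belongs to its prediction set."""
    return sum(y in s for y, s in zip(y_true, pred_sets)) / len(y_true)

def mean_set_size(pred_sets):
    """Average number of labels per prediction set."""
    return sum(len(s) for s in pred_sets) / len(pred_sets)

y_true = [0, 1, 2, 3]
pred_sets = [[0], [0, 2], [2, 1], [3, 0, 1]]
print(mean_coverage(y_true, pred_sets), mean_set_size(pred_sets))
```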


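The empirical coverage (0.992) is above the target level 1 − α = 0.99, at the cost of prediction sets containing about 1.58 labels on average. Let us now print a test sample along with the model's point prediction and its conformal prediction set.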

In [77]:
# Pick a test sample index (fixed for reproducibility)
idx = 16

# Extract the sample text
sample_text = split_dataset["test"][idx]["text"]

# Point prediction: the class with the highest softmax probability
point_label = int(y_pred[idx].argmax())

# Extract conformal prediction set
conformal_labels = y_pred_set[idx]

# Map label indexes to class labels
label_mapping = dataset['train'].features['label'].names
point_label = label_mapping[point_label]
conformal_labels = [label_mapping[label] for label in conformal_labels]

# Print results
print("Sample text from test set:")
print(sample_text)
print("Point prediction of the model:")
print(point_label)
print("Conformal prediction set:")
print(conformal_labels)
print("True label:")
print(label_mapping[test_labels[idx]])

Sample text from test set:
Scientists Discover Ganymede has a Lumpy Interior Jet Propulsion Lab -- Scientists have discovered irregular lumps beneath the icy surface of Jupiter's largest moon, Ganymede. These irregular masses may be rock formations, supported by Ganymede's icy shell for billions of years...
Point prediction of the model:
Sci/Tech
Conformal prediction set:
['World', 'Sci/Tech']
True label:
Sci/Tech
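The example above shows a two-label set. To see how often the conformal predictor hedges across the whole test set, one can tally the sizes of the prediction sets, for instance with `collections.Counter`. The sketch below (hypothetical helper `set_size_distribution`) runs on toy sets; passing `y_pred_set` instead would give the distribution behind the average set size of 1.581 reported above.

```python
from collections import Counter

def set_size_distribution(pred_sets):
    """Map each set size to the fraction of samples with that size."""
    counts = Counter(len(s) for s in pred_sets)
    n = len(pred_sets)
    return {size: counts[size] / n for size in sorted(counts)}

toy_sets = [[3], [0, 3], [1], [2, 3], [1], [0, 2, 3], [2], [3]]
print(set_size_distribution(toy_sets))
```

Singleton sets are the inputs on which the predictor commits to one label at level α; larger sets flag inputs, like the Ganymede article above, where more than one topic remains plausible.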
