In [1]:
# Load the lab_black IPython extension (presumably auto-formats each cell with black on run).
%load_ext lab_black

In [2]:
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from sentence_transformers.SentenceTransformer import SentenceTransformer
from setfit import SetFitModel, SetFitTrainer
from tqdm.auto import tqdm

from setfit_ig.html_text_colorizer import WordImportanceColorsSetFit
from setfit_ig.integrated_gradients import integrated_gradients_on_text
from setfit_ig.model_head import BinaryLogisticRegressionModel
from setfit_ig.setfit_extensions import SetFitGrad, SetFitModelWithTorchHead

from IPython.display import HTML

In [3]:
# Few-shot setup: a stratified sample of 20 labelled reviews for training and
# 300 held-out reviews for evaluation, both drawn from rotten_tomatoes' train split.
SEED = 42  # pins the split, head initialisation and trainer so Restart & Run All is reproducible

rotten_tomatoes = load_dataset("rotten_tomatoes")
splits = rotten_tomatoes["train"].train_test_split(
    train_size=20,
    test_size=300,
    stratify_by_column="label",
    shuffle=True,
    seed=SEED,
)

train = splits["train"]
test = splits["test"]


# all-MiniLM-L6-v2 is a small 6-layer sentence-transformer, so CPU is workable here.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_body = SentenceTransformer(model_name)

# The head is (per its wrapper's name) a torch module created here, before the
# trainer runs; seed torch now so its initial weights are reproducible.
torch.manual_seed(SEED)
model = SetFitModelWithTorchHead(
    model_body=model_body,
    model_head=BinaryLogisticRegressionModel(
        # Size the head from the body's embedding width instead of hard-coding 384,
        # so changing `model_name` cannot silently mismatch the head.
        input_dimension=model_body.get_sentence_embedding_dimension(),
        lr=0.01,
        number_of_epochs=10000,
        device="cpu",
    ),
)

trainer = SetFitTrainer(
    model=model,
    train_dataset=train,
    eval_dataset=test,
    loss_class=CosineSimilarityLoss,
    batch_size=10,
    num_epochs=2,
    num_iterations=20,
    seed=SEED,
)

Using custom data configuration default
Reusing dataset rotten_tomatoes (/Users/kostis/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Fine-tune `model` with the settings configured on `trainer`
# (CosineSimilarityLoss, batch_size=10, num_epochs=2, num_iterations=20).
trainer.train()

***** Running training *****
  Num examples = 800
  Num epochs = 1
  Total optimization steps = 80
  Total train batch size = 10


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/80 [00:00<?, ?it/s]

In [5]:
# Wrap the trained model for gradient-based attribution (presumably integrated
# gradients, given the integration_steps argument below), then build the
# colouriser that renders per-word importances as HTML.
grd = SetFitGrad(model)
m = WordImportanceColorsSetFit(grd)


N = 0
# Index the row first: `test["text"][N]` would materialise the whole column
# as a list just to read a single entry.
example = test[N]
test_text, test_label = example["text"], example["label"]
# `df` presumably holds the per-token attributions; `_` is unused here.
colors, df, prob, _ = m.show_colors_for_sentence(test_text, integration_steps=100)
print(f"true label: {test_label}")
print(f"class probability: {prob:1.2f}")
HTML(colors)

Remember to use:
from IPython.display import HTML
HTML(colored_text)


100%|███████████████████████████████████████████████████████████████████| 100/100 [00:22<00:00,  4.44it/s]


0
class probability: 0.51


In [6]:
# Explain a second held-out example with the same colouriser `m`.
N = 1
# Index the row first: `test["text"][N]` would materialise the whole column
# as a list just to read a single entry.
example = test[N]
test_text, test_label = example["text"], example["label"]
colors, df, prob, _ = m.show_colors_for_sentence(test_text, integration_steps=100)
print(f"true label: {test_label}")
print(f"class probability: {prob:1.2f}")
HTML(colors)

100%|███████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 14.04it/s]

0
class probability: 0.50



