In [None]:
!pip install -q --upgrade torch torchvision transformers datasets


In [None]:
from datasets import load_dataset

# Load the emotion dataset ("SetFit/emotion"). Later cells use its
# train / validation / test splits and the text, label and label_text columns.
dataset = load_dataset("SetFit/emotion")
print(dataset)

In [None]:
from transformers import AutoTokenizer
# Checkpoint name shared by this tokenizer and the models loaded in later cells.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    """Tokenize the ``text`` field of a batch of examples.

    Args:
        examples: mapping with a ``'text'`` entry holding the raw strings
            (the batch handed over by ``Dataset.map(..., batched=True)``).

    Returns:
        Whatever the module-level ``tokenizer`` returns for those strings.
    """
    # Deliberately no padding here: padding='max_length' would pad every row
    # to the model's maximum length, which is wasted compute for short
    # sentences. Batches are padded to their own longest sequence at
    # collation time because the tokenizer is handed to the Trainer.
    return tokenizer(examples['text'], truncation=True)

# Tokenize every split; batched=True hands preprocess_function a batch of rows per call.
tokenized_datasets = dataset.map(preprocess_function, batched=True)

In [None]:
from transformers import AutoModelForSequenceClassification
# Derive the label set from the data instead of hard-coding the class count,
# and store the id <-> name mapping in the model config so saved checkpoints
# report emotion names rather than LABEL_0 ... LABEL_n.
train_split = dataset["train"]
id2label = dict(sorted(set(zip(train_split["label"], train_split["label_text"]))))
label2id = {name: idx for idx, name in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

In [None]:
from transformers import TrainingArguments, Trainer
# Fine-tuning hyperparameters; evaluation runs once at the end of every epoch.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # `tokenizer=` was deprecated in transformers 4.46 in favour of
    # `processing_class=`; this notebook installs the latest release, so use
    # the current name. The Trainer uses it to pad each batch when collating.
    processing_class=tokenizer,
)

In [None]:
# Fine-tune the model with the Trainer configured in the previous cell.
trainer.train()

In [None]:
# Evaluate on the Trainer's eval_dataset (the validation split).
trainer.evaluate()

In [None]:
from sklearn.metrics import classification_report
import numpy as np

# Score the fine-tuned model on the held-out test split.
predictions = trainer.predict(tokenized_datasets['test'])
predicted_labels = np.argmax(predictions.predictions, axis=1)
# Take the gold labels that predict() returns next to the logits: they are a
# plain NumPy array in the same row order as the predictions, so there is no
# need to re-read the dataset column (which newer `datasets` releases return
# as a lazy Column object rather than a list).
true_labels = predictions.label_ids

print(classification_report(true_labels, predicted_labels))


In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Plain DistilBERT backbone (no classification head), used only as a sentence
# encoder for the few-shot experiment; eval() puts it in inference mode.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name).eval()

# Two hand-written example sentences ("shots") per emotion.
shots = dict(
    joy=["I absolutely loved this!", "This makes me so happy."],
    sadness=["I'm feeling really down today.", "This is so depressing."],
    anger=["This makes me furious!", "I can't stand this."],
    fear=["I'm terrified of what's next.", "This scares me so much."],
    surprise=["Wow, I did not see that coming!", "That's a shocking turn of events."],
    love=["I adore you.", "My heart is full of love."],
)


In [None]:
def embed_texts(texts):
    """Embed texts by mean-pooling the encoder's last hidden states.

    Args:
        texts: batch of strings; tokenized with the module-level ``tokenizer``
            (truncated, padded to the longest item in the batch).

    Returns:
        A tensor with one row per input text: the attention-mask-weighted sum
        of ``last_hidden_state`` over the token axis divided by the mask sum.
    """
    batch = tokenizer(texts, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():
        hidden = encoder(**batch).last_hidden_state

    # Add a trailing axis so the mask broadcasts across the hidden dimension
    # and masked-out positions contribute nothing to the sum.
    keep = batch.attention_mask.unsqueeze(-1)
    pooled_sum = torch.sum(hidden * keep, dim=1)
    token_count = torch.sum(keep, dim=1)
    return pooled_sum / token_count


# One prototype per emotion: the mean of its shot embeddings. keepdim=True
# keeps the leading axis so each prototype stays a single-row 2-D tensor.
prototypes = {
    emotion: embed_texts(sentences).mean(dim=0, keepdim=True)
    for emotion, sentences in shots.items()
}


In [None]:



text = "I can't believe what a wonderful surprise!"
pred_label, similarities = classify(text)
print(f"→ Predicted emotion: {pred_label}")
print(" Similarities:", similarities)


In [None]:


test_texts = dataset['test']['text'][:30]
true_labels = dataset['test']['label_text'][:30]

def classify(text):
    q_emb = embed_texts([text])
    sims = {
        label: F.cosine_similarity(q_emb, proto).item()
        for label, proto in prototypes.items()
    }

    return max(sims, key=sims.get), sims

# Classify every sampled sentence and sort the outcomes into hits and misses.
correct, incorrect = [], []

for text, true_label in zip(test_texts, true_labels):
    pred_label, similarities = classify(text)
    entry = dict(
        text=text,
        true_label=true_label,
        pred_label=pred_label,
        similarities=similarities,
    )
    bucket = correct if pred_label == true_label else incorrect
    bucket.append(entry)

# Print both groups in the same layout: hits first, then misses.
for heading, entries in (
    ("=== CORRECTLY CLASSIFIED ===", correct),
    ("=== MISCLASSIFIED ===", incorrect),
):
    print(heading)
    for record in entries:
        print(f"Text: {record['text']}")
        print(f"True label: {record['true_label']}, Predicted: {record['pred_label']}")
        print(f"Similarities: {record['similarities']}\n")
