# Sentiment Analysis with DistilBERT
Explore the IMDB dataset and evaluate a saved, fine-tuned DistilBERT model for binary sentiment classification (positive/negative).


## Install & Import Libraries

In [None]:
!pip install transformers datasets scikit-learn torch matplotlib seaborn -q

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix


## Load Dataset & Quick EDA

In [None]:
# Load the IMDB dataset and print its summary.
dataset = load_dataset("imdb")
print(dataset)

# Check class distribution
# Use the dataset's built-in pandas export rather than letting pandas iterate
# the train split example by example.
train_df = dataset["train"].to_pandas()
sns.countplot(x="label", data=train_df)
plt.title("IMDB Train Dataset Class Distribution")
plt.show()


## Tokenization Example

In [None]:
# Short sentence used to illustrate what the tokenizer produces.
sample_text = "This movie was amazing!"

# Tokenizer for the base DistilBERT checkpoint; later cells reuse it for inference.
# NOTE(review): the classifier is loaded from a separate saved-model path — confirm
# it was trained with this same tokenizer.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Encode the sample and print whatever the tokenizer returns for it.
tokens = tokenizer(sample_text)
print(tokens)


## Load Trained Model

In [None]:
# Load the saved sequence-classification model from disk.
model = AutoModelForSequenceClassification.from_pretrained(
    "../src/model",  # path to saved model
)
# Put the model in evaluation mode before running inference.
model.eval()


## Inference Demo

In [None]:
from torch.nn.functional import softmax

def predict(text: str) -> tuple[str, float]:
    """Classify a single review and report the confidence of the prediction.

    Args:
        text: Raw review text. It is encoded with the notebook-level
            ``tokenizer`` (truncation enabled) and scored by the
            notebook-level ``model``.

    Returns:
        A ``(label, confidence)`` pair. ``label`` is ``"positive"`` when the
        highest-scoring class index is 1 and ``"negative"`` otherwise;
        ``confidence`` is the softmax probability of that class as a float.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Inference only: disable autograd so no gradient-tracking state is built
    # for a forward pass that will never be back-propagated.
    with torch.no_grad():
        outputs = model(**inputs)
        probs = softmax(outputs.logits, dim=-1)
    pred_class = torch.argmax(probs, dim=-1).item()
    return "positive" if pred_class == 1 else "negative", probs[0][pred_class].item()

# Run the classifier on two hand-written reviews and print each prediction.
sample_reviews = [
    "I loved this movie!",
    "It was terrible and boring",
]
for review in sample_reviews:
    label, confidence = predict(review)
    print(f"Review: {review}\nPredicted: {label} (Confidence: {confidence:.2f})\n")


## Evaluation on Test Subset

In [None]:
# Evaluate on a fixed 500-example sample of the test split (shuffled with seed 42).
shuffled_test = dataset["test"].shuffle(seed=42)
test_subset = shuffled_test.select(range(500))

all_preds = []
all_labels = []
for item in test_subset:
    label, _ = predict(item["text"])
    # Map the string label back to the integer encoding: "positive" -> 1, anything else -> 0.
    all_preds.append(int(label == "positive"))
    all_labels.append(item["label"])

# Score the predictions against the reference labels.
acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
print(f"Accuracy: {acc:.4f}, F1 Score: {f1:.4f}")


## Confusion Matrix Visualization

In [None]:
# Tabulate reference labels against predictions from the evaluation cell.
cm = confusion_matrix(all_labels, all_preds)

# Draw the counts as an annotated heatmap and label it through the returned axes.
ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title("Confusion Matrix")
plt.show()


: 