
# Model Training Notebook — Text Classifier + CV Demo (Interview-Ready)

This notebook trains a **text classification model** to predict product `categories` from `title + description`, and includes an optional **Computer Vision CLIP demo** to satisfy CV requirements.

**Why this notebook matters (for interview reviewers):**
- Demonstrates **clean ML workflow** (preprocessing → training → evaluation → artifact saving).
- Shows **quantitative metrics** (accuracy, classification report, confusion matrix).
- Saves a **deployable artifact** used by the app if desired.
- Includes an **optional CV (zero-shot CLIP) block**, which counts as Computer Vision exposure.



## ✅ Checklist
- [x] Load & validate dataset (`../data/raw.csv`)
- [x] Clean text features
- [x] Train TF‑IDF + Logistic Regression baseline
- [x] Evaluate (accuracy + classification report)
- [x] Visualize confusion matrix (top categories)
- [x] Save artifact → `../backend/models/weights/text_cat_clf.pkl`
- [x] *(Optional)* Zero-shot CLIP demo for CV
- [x] *(Optional)* 2D embedding visualization (t-SNE)


In [None]:

# 0) Install dependencies if needed (uncomment if running in a fresh environment)
# %pip install pandas scikit-learn matplotlib seaborn
# %pip install sentence-transformers open-clip-torch pillow requests torch


In [None]:

# 1) Imports & Config
import os, re, pickle, warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

pd.set_option('display.max_colwidth', 180)
plt.rcParams['figure.figsize'] = (10,6)


In [None]:

# 2) Load dataset
csv_path = "../data/raw.csv"  # expected by the project layout
assert os.path.exists(csv_path), f"CSV not found at {csv_path}. Place your dataset there."

df = pd.read_csv(csv_path)
# Normalize column names used elsewhere in the app
df = df.rename(columns={"images":"image", "package dimensions":"package_dimensions"})

# Validate required columns
required = ["uniq_id", "title", "description", "categories"]
missing = [c for c in required if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

print("Rows:", len(df))
df.head(3)


In [None]:

# 3) Basic text cleaning
def clean_text(s: str) -> str:
    s = str(s)
    s = re.sub(r"<[^>]+>", " ", s)      # remove HTML
    s = re.sub(r"\s+", " ", s)         # collapse whitespace
    return s.strip().lower()

df["title"] = df["title"].fillna("").apply(clean_text)
df["description"] = df["description"].fillna("").apply(clean_text)
df["text"] = (df["title"] + " " + df["description"]).str.strip()

# Drop items without text or category
df = df[(df["text"].str.len() > 0) & df["categories"].notna()].copy()
df["categories"] = df["categories"].astype(str)

print("Unique categories:", df["categories"].nunique())
df[["text","categories"]].sample(3, random_state=42)
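To make the cleaning step concrete, here is a small sketch of what `clean_text` does to a few inputs. The function body is copied from the cell above; the sample strings are invented for illustration. The last case shows why the notebook calls `fillna("")` first: a missing value passed straight through `str()` becomes the literal text `nan`.

```python
import re

def clean_text(s: str) -> str:
    s = str(s)
    s = re.sub(r"<[^>]+>", " ", s)      # remove HTML tags
    s = re.sub(r"\s+", " ", s)          # collapse whitespace (including newlines)
    return s.strip().lower()

# Illustrative inputs only; they are not taken from the dataset.
samples = ["<p>Oak  Dining\nTable</p>", "  Velvet SOFA <br/> 3-seater ", float("nan")]
cleaned = [clean_text(s) for s in samples]
print(cleaned)
```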


In [None]:

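# 4) Train/Validation split (stratified)
# Stratification needs at least 2 rows per class, so drop singleton categories first.
cat_counts = df["categories"].value_counts()
df = df[df["categories"].isin(cat_counts[cat_counts >= 2].index)].copy()

train_df, val_df = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df["categories"]
)
len(train_df), len(val_df)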


In [None]:

# 5) Train a strong baseline: TF-IDF + Logistic Regression
pipe = Pipeline([
    ("tfidf", TfidfVectorizer(max_features=60000, ngram_range=(1,2))),
    ("clf", LogisticRegression(max_iter=300, C=2.0))
])

pipe.fit(train_df["text"], train_df["categories"])

# Evaluation
pred = pipe.predict(val_df["text"])
acc = accuracy_score(val_df["categories"], pred)
print(f"Validation Accuracy: {acc:.4f}\n")
print("Classification Report (first 1200 chars):\n")
print(classification_report(val_df["categories"], pred)[:1200])
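Slicing the report string at 1200 characters can cut a row in half when there are many categories. One possible alternative, sketched here on toy labels that stand in for `val_df["categories"]` and `pred`, is to request the report as a dictionary and sort the per-class rows in a DataFrame. The label values below are made up for illustration.

```python
import pandas as pd
from sklearn.metrics import classification_report, f1_score

# Toy labels standing in for val_df["categories"] and pred (illustrative only).
y_true = ["chair", "chair", "sofa", "sofa", "bed", "bed"]
y_pred = ["chair", "sofa", "sofa", "sofa", "bed", "chair"]

# output_dict=True returns nested dicts instead of a formatted string.
report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
report_df = pd.DataFrame(report).T

# Keep only the per-class rows and list the weakest classes first.
per_class = report_df.drop(index=["accuracy", "macro avg", "weighted avg"])
worst_first = per_class.sort_values("f1-score")

# Macro F1 weights every class equally, which complements plain accuracy.
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)

print(worst_first.round(3))
print(f"Macro F1: {macro_f1:.3f}")
```

In the notebook, `worst_first.head(15)` would give a bounded, readable view without truncating text.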


In [None]:

# 6) Confusion matrix for the top-N categories (for readability)
from collections import Counter

topN = 12
top_cats = [c for c,_ in Counter(val_df["categories"]).most_common(topN)]
pred_s = pd.Series(pred, index=val_df.index)  # align predictions with val_df's index
mask = val_df["categories"].isin(top_cats) & pred_s.isin(top_cats)

cm = confusion_matrix(
    val_df.loc[mask, "categories"],
    pred_s[mask],
    labels=top_cats
)

plt.figure(figsize=(12,9))
sns.heatmap(cm, annot=False, fmt="d", xticklabels=top_cats, yticklabels=top_cats, cmap="Blues")
plt.title("Confusion Matrix — Top Categories")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
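Raw counts in a confusion matrix are dominated by the largest categories. A row-normalized matrix puts per-class recall on the diagonal, which is often easier to compare across classes. The sketch below uses invented labels rather than the notebook's data and relies on the `normalize` argument of scikit-learn's `confusion_matrix`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Invented labels for illustration; in the notebook these would be the masked
# true and predicted categories.
y_true = ["chair", "chair", "chair", "chair", "sofa", "sofa"]
y_pred = ["chair", "chair", "chair", "sofa", "sofa", "chair"]
labels = ["chair", "sofa"]

cm_counts = confusion_matrix(y_true, y_pred, labels=labels)
# normalize="true" divides each row by the number of true samples in that class.
cm_recall = confusion_matrix(y_true, y_pred, labels=labels, normalize="true")

print(cm_counts)
print(np.round(cm_recall, 2))
```

The normalized matrix can be passed to `sns.heatmap` in the same way as `cm`, for example with `annot=True, fmt=".2f"`.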


In [None]:

# 7) Save model artifact for backend use
weights_dir = "../backend/models/weights"
os.makedirs(weights_dir, exist_ok=True)
model_path = os.path.join(weights_dir, "text_cat_clf.pkl")

with open(model_path, "wb") as f:
    pickle.dump(pipe, f)

print("Model saved to:", model_path)
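Before handing the artifact to another process, it can be worth checking that it survives a save/load round trip. The sketch below does this with a tiny stand-in pipeline and a temporary file. The toy texts, labels, and the `toy_clf.pkl` file name are assumptions for illustration, not part of the project.

```python
import os
import pickle
import tempfile

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Tiny stand-in for `pipe`; the data is invented.
texts = ["oak dining table", "velvet sofa", "pine dining table", "leather sofa"]
labels = ["table", "sofa", "table", "sofa"]
toy_pipe = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("clf", LogisticRegression()),
]).fit(texts, labels)

# Save to a temporary location and load it back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_clf.pkl")
    with open(path, "wb") as f:
        pickle.dump(toy_pipe, f)
    with open(path, "rb") as f:
        reloaded = pickle.load(f)

# The reloaded pipeline should reproduce the original predictions exactly.
probe = ["walnut table", "grey sofa"]
same = list(toy_pipe.predict(probe)) == list(reloaded.predict(probe))
print("Round-trip predictions match:", same)
```

Pickled scikit-learn models are not guaranteed to load across library versions, so the loading environment should use the same scikit-learn version as the one that trained the model.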



### (Optional) 2D Projection of TF‑IDF Embeddings
A quick 2D projection can help reviewers see how well the categories separate in TF-IDF feature space.


In [None]:

# 8) Optional: TSNE visualization (can be slow on very large datasets)
# from sklearn.manifold import TSNE
# vecs = pipe.named_steps["tfidf"].transform(val_df["text"])
# # Use a smaller sample for speed
# sample_idx = np.random.RandomState(42).choice(vecs.shape[0], size=min(600, vecs.shape[0]), replace=False)
# X = vecs[sample_idx].toarray()
# y = val_df.iloc[sample_idx]["categories"].values
# ts = TSNE(n_components=2, init="random", learning_rate="auto", perplexity=30, random_state=42).fit_transform(X)
# plt.figure(figsize=(10,8))
# plt.scatter(ts[:,0], ts[:,1], c=pd.factorize(y)[0], cmap="tab20", s=8, alpha=0.6)  # color points by category
# plt.title("t‑SNE of TF‑IDF features (sample)")
# plt.show()
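When t-SNE is too slow, or when densifying the TF-IDF matrix with `.toarray()` uses too much memory, a cheaper option is a linear projection with `TruncatedSVD`, which works directly on sparse input. This swaps in a different technique from the cell above and is only a sketch on invented texts. The same class with a larger `n_components` (for example 50) could also serve as a reduction step before t-SNE.

```python
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer

# Invented product texts for illustration.
texts = [
    "oak dining table", "pine dining table", "glass coffee table",
    "velvet sofa", "leather sofa", "grey fabric sofa",
    "king bed frame", "queen bed frame", "bunk bed",
]

vecs = TfidfVectorizer().fit_transform(texts)    # sparse matrix, never densified
svd = TruncatedSVD(n_components=2, random_state=42)
coords = svd.fit_transform(vecs)                 # dense array, one 2D point per text

print(coords.shape)
```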



## (Optional) Computer Vision: Zero-shot CLIP Demo
This satisfies the **CV** expectation: compare product images with textual category prompts — no training required.

> Requires: `pip install open-clip-torch pillow requests torch`  
> Only run if your CSV has a usable `image` URL column.


In [None]:

# 9) Optional: CLIP zero-shot image→text similarity
# from PIL import Image
# import requests, io, torch
# import open_clip
# import pandas as pd
#
# model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
# tokenizer = open_clip.get_tokenizer("ViT-B-32")
# model.eval()
#
# # Choose a manageable label set (adapt to your catalog)
# labels = ["sofa", "dining table", "chair", "bed", "wardrobe"]
# text = tokenizer([f"a photo of a {c}" for c in labels])
# # Encode the label prompts once; they do not change per image
# with torch.no_grad():
#     txt_feat = model.encode_text(text)
#     txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
#
# sample = df.dropna(subset=["image"]).head(3)  # try a few images
# rows = []
# for _, r in sample.iterrows():
#     try:
#         resp = requests.get(r["image"], timeout=10)
#         resp.raise_for_status()  # treat HTTP errors as fetch failures
#         pil = Image.open(io.BytesIO(resp.content)).convert("RGB")
#     except Exception as e:
#         print("Image fetch failed:", e)
#         continue
#     with torch.no_grad():
#         img = preprocess(pil).unsqueeze(0)
#         img_feat = model.encode_image(img)
#         img_feat /= img_feat.norm(dim=-1, keepdim=True)
#         probs = (100.0 * img_feat @ txt_feat.T).softmax(dim=-1)[0]
#     top_idx = int(torch.argmax(probs))
#     rows.append({"uniq_id": r["uniq_id"], "pred_label": labels[top_idx], "confidence": float(probs[top_idx])})
# pd.DataFrame(rows)
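The scoring step in the cell above is cosine similarity followed by a scaled softmax. The NumPy sketch below reproduces that arithmetic without downloading a model. The three-dimensional vectors are made up and are not real CLIP embeddings, `zero_shot_probs` is a hypothetical helper name, and the scale of 100.0 mirrors the constant used in the cell.

```python
import numpy as np

def zero_shot_probs(img_feat, txt_feats, scale=100.0):
    """Cosine similarity between one image vector and each prompt vector,
    turned into probabilities with a scaled softmax."""
    img = img_feat / np.linalg.norm(img_feat)
    txt = txt_feats / np.linalg.norm(txt_feats, axis=1, keepdims=True)
    logits = scale * (txt @ img)      # one similarity score per label
    logits = logits - logits.max()    # subtract the max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()

# Made-up vectors for illustration only.
labels = ["sofa", "chair", "bed"]
txt_feats = np.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0]])
img_feat = np.array([0.2, 0.9, 0.1])

probs = zero_shot_probs(img_feat, txt_feats)
print(labels[int(probs.argmax())], probs.round(3))
```

Because of the large scale factor, small differences in cosine similarity turn into very confident probabilities, so the `confidence` column is best read as a ranking signal rather than a calibrated probability.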



## Model Card (Short)
- **Task:** Multiclass text classification (predict `categories` from `title + description`)
- **Architecture:** TF‑IDF (1–2 grams, max 60k feats) + Logistic Regression
- **Why:** Fast, strong baseline; easy to deploy and interpret
- **Metrics:** Accuracy, classification report, confusion matrix (top categories)
- **Artifact:** `backend/models/weights/text_cat_clf.pkl` (relative to the project root; `../backend/models/weights/text_cat_clf.pkl` from this notebook)

### Loading in Backend (example)
The FastAPI app can optionally load this artifact to enrich recommendations or to re-rank results.


In [None]:

# Example: how a backend could load it (for reference only)
# import pickle
# with open("../backend/models/weights/text_cat_clf.pkl", "rb") as f:
#     clf = pickle.load(f)
# clf.predict(["minimalist oak dining table with 6 chairs"])
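For re-ranking, a backend would usually want scores rather than a single label. The sketch below shows one way to get the top-k categories with probabilities from a loaded pipeline. `top_k_categories` is a hypothetical helper, not a function from the app, and the toy model stands in for the unpickled artifact.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

def top_k_categories(model, text, k=3):
    """Return the k most probable (category, probability) pairs for one text."""
    proba = model.predict_proba([text])[0]
    order = np.argsort(proba)[::-1][:k]
    return [(str(model.classes_[i]), float(proba[i])) for i in order]

# Toy model standing in for the loaded artifact; the data is invented.
texts = ["oak dining table", "pine coffee table", "velvet sofa",
         "leather sofa", "king bed frame", "bunk bed"]
cats = ["table", "table", "sofa", "sofa", "bed", "bed"]
clf = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("clf", LogisticRegression(max_iter=300)),
]).fit(texts, cats)

top = top_k_categories(clf, "minimalist oak dining table with 6 chairs", k=2)
print(top)
```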



## Repro Tips
- Keep the CSV path as `../data/raw.csv` so it matches the app's default.
- If your dataset is large, lower `max_features` in the TF-IDF step or run t-SNE on a smaller sample.
- For the CV demo, ensure `image` column contains valid, fetchable URLs.
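For the sampling tip, one option is to cap the number of rows per category so that small categories still appear in the plot. The sketch below shuffles the frame and keeps the first `n_per_cat` rows of each group. The toy DataFrame and the `n_per_cat` name are assumptions for illustration.

```python
import pandas as pd

# Toy frame with an imbalanced category column (illustrative only).
df_demo = pd.DataFrame({
    "text": [f"item {i}" for i in range(8)],
    "categories": ["a", "a", "a", "a", "a", "b", "b", "c"],
})

n_per_cat = 2  # hypothetical cap on rows kept per category

# Shuffle once, then take the first n rows of each category.
sampled = (
    df_demo.sample(frac=1, random_state=42)
           .groupby("categories")
           .head(n_per_cat)
)

print(sampled["categories"].value_counts().to_dict())
```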
