In [1]:
#!/usr/bin/env python3
"""
Train a multi-label movie-genre classifier on IMDb overviews and save artifacts.

Inputs:
  - imdb_movies.csv  (must include columns: 'title', 'overview', 'genre')

Outputs (saved under ./artifacts):
  - vectorizer.joblib         (fitted TfidfVectorizer)
  - mlb.joblib                (fitted MultiLabelBinarizer)
  - model.joblib              (fitted OneVsRest LogisticRegression)
  - thresholds.npy            (per-class decision thresholds)
  - labeled_overviews.npz     (sparse TF-IDF matrix of labeled overviews)
  - labeled_titles.json       (list of titles aligned with rows in the matrix)
  - classes.json              (list of genre class names in mlb order)
"""

"\nTrain a multi-label movie-genre classifier on IMDb overviews and save artifacts.\n\nInputs:\n  - imdb_movies.csv  (must include columns: 'title', 'overview', 'genre')\n\nOutputs (saved under ./artifacts):\n  - vectorizer.joblib         (fitted TfidfVectorizer)\n  - mlb.joblib                (fitted MultiLabelBinarizer)\n  - model.joblib              (fitted OneVsRest LogisticRegression)\n  - thresholds.npy            (per-class decision thresholds)\n  - labeled_overviews.npz     (sparse TF-IDF matrix of labeled overviews)\n  - labeled_titles.json       (list of titles aligned with rows in the matrix)\n  - classes.json              (list of genre class names in mlb order)\n"

In [2]:
import os
import re
import json
import numpy as np
import pandas as pd
from pathlib import Path

# NLTK
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

# Sklearn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import f1_score, classification_report
from sklearn.utils import Bunch

# Persistence
import joblib
from scipy import sparse



DATA_PATH = "imdb_movies.csv"
ARTIFACT_DIR = Path("artifacts")
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
TEST_SIZE = 0.3

In [3]:
def ensure_nltk_data():
    """Download required NLTK packages if missing."""
    required = {
        'punkt': 'tokenizers/punkt',
        'stopwords': 'corpora/stopwords',
        'wordnet': 'corpora/wordnet',
        'omw-1.4': 'corpora/omw-1.4',
    }
    for pkg, res in required.items():
        try:
            nltk.data.find(res)
        except LookupError:
            print(f"Downloading NLTK resource: {pkg}")
            nltk.download(pkg, quiet=False)

ensure_nltk_data()
STOP_WORDS = set(stopwords.words("english"))
LEMM = WordNetLemmatizer()

Downloading NLTK resource: wordnet
Downloading NLTK resource: omw-1.4


[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/manjunathpopuri/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/manjunathpopuri/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [4]:
# Data loading & cleaning

def load_data(csv_path: str):
    df = pd.read_csv(csv_path)
    required_cols = {"names", "overview", "genre"}
    if not required_cols.issubset(df.columns):
        raise ValueError("CSV must contain columns: 'names', 'overview', 'genre'")

    unlabeled_mask = df["genre"].isna()
    labeled_df = df[~unlabeled_mask].copy()

    # Normalize genres: split/comma -> list
    labeled_df["genre"] = labeled_df["genre"].str.split(",").apply(label_preprocess)

    # Basic text cleaning on overviews
    labeled_df["overview"] = labeled_df["overview"].fillna("").apply(preprocess_text)
    return labeled_df


def label_preprocess(labels):
    out = []
    for lab in labels:
        lab = lab.strip().lower().replace("\u00A0", "")
        if lab:
            out.append(lab)
    return out


def preprocess_text(text: str) -> str:
    x = re.sub(r"[^\w\s]", " ", text or "")
    x = x.lower()
    words = word_tokenize(x)
    words = [w for w in words if w not in STOP_WORDS and w.isalpha()]
    words = [LEMM.lemmatize(w) for w in words]
    return " ".join(words)

In [5]:
# Train / Evaluate / Save
def tune_thresholds(y_true: np.ndarray, y_proba: np.ndarray) -> np.ndarray:
    """Choose per-class threshold that maximizes F1 on the PR curve."""
    from sklearn.metrics import precision_recall_curve

    n_labels = y_true.shape[1]
    best_thresholds = np.empty(n_labels, dtype=float)

    for i in range(n_labels):
        p, r, t = precision_recall_curve(y_true[:, i], y_proba[:, i])
        # The last precision/recall point has no threshold
        f1 = 2 * p * r / (p + r + 1e-12)
        f1 = f1[:-1]
        if t.size == 0 or np.all(np.isnan(f1)):
            best_thresholds[i] = 0.5
        else:
            j = np.nanargmax(f1)
            best_thresholds[i] = t[j]
    return best_thresholds

In [6]:
def main():
    df = load_data(DATA_PATH)
    X_text = df["overview"].astype(str).tolist()
    y_raw = df["genre"].tolist()
    titles = df["names"].astype(str).tolist()

    # Vectorize
    vectorizer = TfidfVectorizer(
        ngram_range=(1, 2),
        max_features=50_000,
        min_df=2
    )
    X = vectorizer.fit_transform(X_text)

    # Labels
    mlb = MultiLabelBinarizer()
    Y = mlb.fit_transform(y_raw)

    # Split
    X_train, X_test, y_train, y_test, titles_train, titles_test = train_test_split(
        X, Y, titles, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    # Model
    lr = LogisticRegression(
        C=1.0,
        max_iter=1000,
        class_weight="balanced",
        n_jobs=-1,
        solver="lbfgs"
    )
    clf = OneVsRestClassifier(lr)
    clf.fit(X_train, y_train)

    # Raw probs on test
    y_proba = clf.predict_proba(X_test)

    # Per-class threshold tuning
    thresholds = tune_thresholds(y_test, y_proba)

    # Apply thresholds
    y_pred = (y_proba >= thresholds).astype(int)

    # Metrics
    micro = f1_score(y_test, y_pred, average="micro")
    macro = f1_score(y_test, y_pred, average="macro")
    print(f"\nMicro-F1: {micro:.4f}")
    print(f"Macro-F1: {macro:.4f}\n")
    print("Per-class report:")
    print(classification_report(y_test, y_pred, target_names=mlb.classes_))

    # Persist artifacts
    joblib.dump(vectorizer, ARTIFACT_DIR / "vectorizer.joblib")
    joblib.dump(mlb, ARTIFACT_DIR / "mlb.joblib")
    joblib.dump(clf, ARTIFACT_DIR / "model.joblib")
    np.save(ARTIFACT_DIR / "thresholds.npy", thresholds)

    # Save labeled TF-IDF matrix and aligned titles for similarity search
    sparse.save_npz(ARTIFACT_DIR / "labeled_overviews.npz", X)
    with open(ARTIFACT_DIR / "labeled_titles.json", "w", encoding="utf-8") as f:
        json.dump(titles, f)
    with open(ARTIFACT_DIR / "classes.json", "w", encoding="utf-8") as f:
        json.dump(list(mlb.classes_), f)

    print("\nArtifacts saved to ./artifacts")
    print("Done.")


if __name__ == "__main__":
    main()


Micro-F1: 0.6279
Macro-F1: 0.5762

Per-class report:
                 precision    recall  f1-score   support

         action       0.66      0.74      0.70       842
      adventure       0.51      0.70      0.59       564
      animation       0.61      0.72      0.66       471
         comedy       0.55      0.74      0.63       897
          crime       0.50      0.73      0.60       409
    documentary       0.65      0.52      0.57        60
          drama       0.58      0.76      0.66      1111
         family       0.63      0.80      0.70       433
        fantasy       0.61      0.58      0.59       408
        history       0.45      0.42      0.43       114
         horror       0.67      0.69      0.68       466
          music       0.55      0.54      0.55        94
        mystery       0.37      0.49      0.42       258
        romance       0.54      0.70      0.61       491
science fiction       0.66      0.64      0.65       397
       thriller       0.53      0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Artifacts saved to ./artifacts
Done.


Exception ignored in: <function ResourceTracker.__del__ at 0x10497dbc0>
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.13/multiprocessing/resource_tracker.py", line 82, in __del__
  File "/opt/anaconda3/lib/python3.13/multiprocessing/resource_tracker.py", line 91, in _stop
  File "/opt/anaconda3/lib/python3.13/multiprocessing/resource_tracker.py", line 116, in _stop_locked
ChildProcessError: [Errno 10] No child processes
