In [None]:
"""
03_error_analysis.py

Notebook/script for detailed error analysis of the meter classifier.

It:
- loads the enriched dataset
- loads one of the trained models (baseline or MLP)
- computes predictions
- uses src.eval_tools to print confusion matrix & top confusions
- surfaces concrete misclassified mantras for manual inspection
"""

In [None]:
import os

import pandas as pd
from sklearn.metrics import classification_report

from src.eval_tools import print_confusion_matrix, print_top_confusions
from src.model_utils import load_model, features_to_model_input, BASELINE_MODEL_NAME, MLP_MODEL_NAME


In [None]:
# Resolve the project root (the parent of the directory holding this file).
# `__file__` only exists when the file runs as a script; inside a Jupyter/IPython
# kernel it is undefined and a bare reference raises NameError. abspath() also
# guards the script case where `__file__` is a bare relative name, for which
# dirname(dirname(...)) would collapse to "" instead of the parent directory.
try:
    _script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Notebook kernel. NOTE(review): assumes the kernel's working directory is
    # the folder containing this notebook (one level below the project root).
    _script_dir = os.getcwd()

BASE_DIR = os.path.dirname(_script_dir)
DATA_DIR = os.path.join(BASE_DIR, "data", "processed")
MODEL_DIR = os.path.join(BASE_DIR, "models")

# Load the enriched dataset and print a quick summary of its size and columns.
dataset_path = os.path.join(DATA_DIR, "dataset_enriched.csv")
print("Dataset path:", dataset_path)

df = pd.read_csv(dataset_path)
print("Rows:", len(df))
print("Columns:", df.columns.tolist())

In [None]:
# Keep only the rows whose gold base-meter label is present; rows where
# "meter_gold_base" is missing cannot be scored and are dropped.
df = df.loc[df["meter_gold_base"].notna()]
print("Rows with gold base meter:", df.shape[0])

In [None]:
# Pick the model under analysis; swap in MLP_MODEL_NAME to analyse the MLP instead.
model_name = BASELINE_MODEL_NAME  # or MLP_MODEL_NAME
model = load_model(model_name)

# A None result is treated as "model not available" and aborts the analysis.
# NOTE(review): MODEL_DIR is not passed to load_model here; the message presumes
# load_model looks in that directory — confirm against src.model_utils.
if model is not None:
    print("Loaded model:", model_name)
else:
    raise RuntimeError(f"Model {model_name} not found in {MODEL_DIR}")

In [None]:
# Gold labels, normalised to plain strings for comparison with the predictions.
y = df["meter_gold_base"].astype(str)

# Feature columns handed to the model (intended to mirror the training setup).
# NOTE(review): features_to_model_input is imported but not used here; confirm
# that the loaded model accepts these raw columns directly.
X = df.loc[:, ["L_G_sequence", "source_veda", "has_pluti", "has_stobha"]]

y_pred = model.predict(X)

In [None]:
# Overall classification report: gold labels (y) vs. model predictions (y_pred).
# `classification_report` is imported in the notebook's top import cell instead of
# mid-notebook, so every dependency is visible up front and re-runs stay reliable.
print("=== Classification Report ===")
print(classification_report(y, y_pred))

In [None]:
# Confusion matrix of gold labels (y) against model predictions (y_pred),
# printed by the project helper src.eval_tools.print_confusion_matrix.
print_confusion_matrix(y, y_pred)

In [None]:
# Build a separate frame that carries the predictions in a "meter_pred" column,
# leaving the filtered dataset `df` itself untouched.
df_pred = df.assign(meter_pred=y_pred)

In [None]:
# Top confusions list: hand the prediction frame to the project helper and ask
# for 15 entries (top_k=15).
# NOTE(review): presumably the helper reads "meter_gold_base" and "meter_pred"
# from df_pred — confirm against src.eval_tools.
print_top_confusions(df_pred, top_k=15)

In [None]:
# Inspect some concrete errors per meter pair
def inspect_confusion(true_meter: str, pred_meter: str, n: int = 5, frame=None):
    """Print misclassified examples for one (gold meter, predicted meter) pair.

    Selects the rows whose "meter_gold_base" equals ``true_meter`` and whose
    "meter_pred" equals ``pred_meter``, then prints a header followed by the
    details of the first ``n`` such rows.

    Args:
        true_meter: Gold base-meter label to match.
        pred_meter: Predicted meter label to match.
        n: Maximum number of examples to print.
        frame: Optional DataFrame to search. When omitted, the notebook-level
            ``df_pred`` is used (looked up at call time), which is the original
            behaviour. Passing a frame explicitly allows inspecting another
            model's predictions without rebinding the global.

    Raises:
        KeyError: If the frame lacks one of the columns read here
            ("meter_gold_base", "meter_pred", "id", "source_veda",
            "meter_gold_raw", "text_dev_original", "L_G_sequence",
            "syllable_count_per_pada").
    """
    if frame is None:
        frame = df_pred

    is_match = (frame["meter_gold_base"] == true_meter) & (frame["meter_pred"] == pred_meter)
    shown = frame[is_match].head(n)

    # len(shown) equals min(n, number of matches) for any non-negative n, and
    # always reflects the number of rows actually printed below.
    print(f"\n=== Examples where true={true_meter}, pred={pred_meter} (showing {len(shown)}) ===")
    for _, row in shown.iterrows():
        print("ID:", row["id"])
        print("source_veda:", row["source_veda"])
        print("Chanda raw:", row["meter_gold_raw"])
        print("Text:", row["text_dev_original"])
        print("L/G:", row["L_G_sequence"])
        print("syllable_count_per_pada:", row["syllable_count_per_pada"])
        print("-" * 60)


# Example: look at both directions of the triṣṭubh ↔ jagatī confusion,
# printing up to five examples for each direction.
inspect_confusion(true_meter="trishtubh", pred_meter="jagati", n=5)
inspect_confusion(true_meter="jagati", pred_meter="trishtubh", n=5)

# %%
# Export every misclassified row to CSV for manual review in a spreadsheet.
# NOTE(review): the gold column is compared here without the astype(str)
# normalisation applied when building `y`; confirm both agree for this dataset.
OUT_ERRORS = os.path.join(BASE_DIR, "data", "interim", "meter_errors.csv")
errors = df_pred.loc[df_pred["meter_pred"].ne(df_pred["meter_gold_base"])]

# Create the output directory on first use, then write the rows without the index.
os.makedirs(os.path.dirname(OUT_ERRORS), exist_ok=True)
errors.to_csv(OUT_ERRORS, index=False)
print("Wrote misclassified examples to", OUT_ERRORS)