# Edge Fluency Classifier - Inference Demo

This notebook demonstrates how to load a trained fluency classifier and run inference on a speech clip from the speechocean762 dataset.

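**Prerequisites:**
1. Run `make setup` to install dependencies.
2. Run `make download` and `make features` to prepare the data.
3. Run `make train_teacher` (or another training target) to generate model checkpoints.
4. Start Jupyter so that this notebook's working directory is a direct subdirectory of the repository root (e.g. `notebooks/`); the setup cell locates `src/` relative to it.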

In [None]:
import sys
import os
from pathlib import Path
import json
import torch
import numpy as np
import pandas as pd
import IPython.display as ipd

# Add src to path.
# Locate the repository root as the nearest of the working directory and its
# ancestors that contains a `src/` folder. This works whether the kernel starts
# in a subdirectory such as `notebooks/` (Jupyter's default is the notebook's
# own folder) or at the repo root. If no `src/` is found, fall back to the
# parent of the working directory.
project_root = next(
    (candidate for candidate in (Path.cwd(), *Path.cwd().parents)
     if (candidate / "src").is_dir()),
    Path(os.getcwd()).parent,
)
src_dir = str(project_root / "src")
if src_dir not in sys.path:
    # Prepend, not append, so the project's generically named packages
    # (utils, models, features, audio) win over same-named installed packages.
    sys.path.insert(0, src_dir)

from utils.config import load_config
from models.mlp import build_mlp
from features.extractor import extract_features
from audio.processing import load_audio, preprocess

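## 1. Load Configuration and Model
Load the global configuration, the label map, and the trained MLP checkpoint. If no checkpoint is found, the model keeps its random initial weights and its predictions are meaningless.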

In [None]:
# Load config.
# All project files are resolved against `project_root`, not the kernel's
# working directory. The notebook's cwd is a subdirectory of the repo, so bare
# relative paths such as "config/default.yaml" would not point at the repo's
# files, and every lookup below would silently fall back.
cfg = load_config(str(project_root / "config" / "default.yaml"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load label map (label name -> class index).
label_map_path = project_root / "experiments" / "exports" / "label_map.json"
if not label_map_path.exists():
    # Fallback if export hasn't run
    label_map = {"poor": 0, "moderate": 1, "good": 2}
else:
    with open(label_map_path) as f:
        label_map = json.load(f)

id2label = {v: k for k, v in label_map.items()}
print(f"Label Map: {label_map}")

# Flattened feature length the MLP is built with. predict() pads/truncates
# every clip to this size, so it must match the value used at training time.
# NOTE(review): hard-coded. Per the original note it is typically 30000 for
# 1-5 s clips with the default feature settings; confirm against training.
input_dim = 30000

# Initialize Model
model_name = cfg["training"].get("model", "mlp_small")
model = build_mlp(model_name, input_dim, len(label_map), cfg["models"])
model.to(device)

# Load Checkpoint
ckpt_path = project_root / "experiments" / "checkpoints" / f"{model_name}.pt"
if ckpt_path.exists():
    # torch.load unpickles the file: only load checkpoints you produced yourself.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["state_dict"])
    print(f"Loaded checkpoint from {ckpt_path}")
else:
    print(f"WARNING: Checkpoint not found at {ckpt_path}. Using random weights.")

# Always switch to inference mode, even when running on random weights, so any
# train-only layer behaviour (e.g. dropout) is disabled during prediction.
model.eval()

## 2. Load Sample Data
Take the first clip listed in the test manifest, show its reference label and transcript, and play it back.

In [None]:
# Pick the first clip from the test manifest. The manifest is resolved against
# the repository root because the kernel's cwd is a subdirectory of the repo.
test_manifest = project_root / "experiments" / "manifests" / "test.csv"
audio_path = None
if not test_manifest.exists():
    print("Test manifest not found. Please run `make download` and `make features`.")
else:
    df = pd.read_csv(test_manifest)
    if df.empty:
        print(f"Test manifest {test_manifest} contains no rows.")
    else:
        sample = df.iloc[0]
        # Relative manifest paths may not resolve from the notebook's cwd. In
        # that case, retry relative to the repo root.
        # NOTE(review): assumes relative paths are repo-root-relative — confirm.
        resolved = Path(sample["path"])
        if not resolved.is_absolute() and not resolved.exists():
            resolved = project_root / resolved
        audio_path = str(resolved)
        print(f"Selected sample: {audio_path}")
        print(f"True Label: {sample['label']}")
        print(f"Text: {sample.get('text', 'N/A')}")

        # Play Audio
        ipd.display(ipd.Audio(audio_path))

## 3. Run Inference
Preprocess the audio, extract features, and predict the fluency class together with the per-class probabilities.

In [None]:
def predict(path, model, cfg, target_dim=None, index_to_label=None):
    """Classify the fluency of a single audio clip.

    Args:
        path: Path to the audio file. A falsy value (``None`` / ``""``)
            short-circuits and returns ``None``.
        model: Classifier whose output is indexed as ``(batch, num_classes)``
            logits. Inputs are moved to the device of its parameters.
        cfg: Config mapping providing ``cfg["data"]["sample_rate"]``,
            ``cfg["data"]["clip_seconds"]`` (``[min, max]`` seconds) and
            ``cfg["features"]``.
        target_dim: Flattened feature length the model expects. Defaults to
            the notebook-level ``input_dim``.
        index_to_label: Mapping from class index to label name. Defaults to
            the notebook-level ``id2label``.

    Returns:
        ``(label, probs)``, where ``probs`` is a 1-D numpy array of class
        probabilities, or ``None`` when ``path`` is falsy.
    """
    if not path:
        return None

    # Resolve the optional overrides lazily so re-running the notebook's model
    # cell (which redefines these globals) is always picked up.
    target_dim = input_dim if target_dim is None else target_dim
    index_to_label = id2label if index_to_label is None else index_to_label
    # Follow the model's own device rather than a separate global.
    model_device = next(
        (param.device for param in model.parameters()), torch.device("cpu")
    )

    data_cfg = cfg["data"]
    sample_rate = data_cfg["sample_rate"]

    # 1. Preprocess (load, resample, trim)
    wav = preprocess(
        path,
        sample_rate=sample_rate,
        min_sec=data_cfg["clip_seconds"][0],
        max_sec=data_cfg["clip_seconds"][1],
    )

    # 2. Extract features
    feats = extract_features(wav, sample_rate, cfg["features"])

    # 3. Flatten. No normalisation is applied here. If the model was trained
    # on CMVN-normalised features, the global stats must be applied first.
    x = np.asarray(feats).flatten()

    # Zero-pad or truncate to the fixed input length the MLP was built with.
    if x.size < target_dim:
        x = np.pad(x, (0, target_dim - x.size))
    else:
        x = x[:target_dim]

    # 4. Inference on a batch of one
    x_tensor = torch.from_numpy(x).float().unsqueeze(0).to(model_device)
    with torch.no_grad():
        logits = model(x_tensor)
        probs = torch.softmax(logits, dim=1)[0]
        # argmax over the class axis of the single sample
        pred_idx = int(torch.argmax(probs, dim=-1).item())

    return index_to_label[pred_idx], probs.cpu().numpy()

# Run the classifier on the selected clip and report the result. Probabilities
# are shown keyed by label name: the raw array order only has meaning
# alongside the label map.
if audio_path:
    pred_label, probabilities = predict(audio_path, model, cfg)
    print(f"\nPredicted Label: {pred_label.upper()}")
    print(f"Confidence: {probabilities[label_map[pred_label]]:.4f}")
    class_probs = {
        id2label.get(idx, str(idx)): round(float(prob), 4)
        for idx, prob in enumerate(probabilities)
    }
    print(f"Class Probabilities: {class_probs}")