# Test: evaluate the CNN instrument classifier

In [None]:
import sys
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
# Walk up from the kernel's working directory until we reach the folder that
# contains `src/` (the repo root), so the notebook runs from any subdirectory.
ROOT = Path.cwd()
while ROOT != ROOT.parent and not (ROOT / "src").exists():
    ROOT = ROOT.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Project imports must come *after* the sys.path fix above: importing `src.*`
# any earlier fails unless the kernel happened to start at the repo root.
from src.models.CNN import CNN
from src.preprocessing import load_audio_stereo, ensure_duration, calc_fft_hop, mel_stereo2_from_stereo
from src.test.utils import get_prediction, load_and_preprocess, parse_ground_truth, display_formatted_results, run_inference

# Checkpoint evaluated by this notebook (edit to test different weights).
MODEL_WEIGHTS = ROOT / "src/models/saved_weights/CNN_v0/best_micro_f1.pt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")



## Run inference on the *A Touch of Zen* test set

In [17]:
test_manifest_csv = ROOT / "data/test/a-touch-of-zen.csv"

# Score every manifest entry with the CNN; `preds_arr` holds the per-class
# scores that are thresholded later, `gts_arr` the matching targets.
(
    preds_arr,
    gts_arr,
    sample_ids,
    audio_cfg,
    valid_labels,
    label_to_idx,
) = run_inference(
    model_cls=CNN,
    model_kwargs={"in_ch": 2},
    model_weights_path=MODEL_WEIGHTS,
    device=DEVICE,
    test_manifest_csv=test_manifest_csv,
    root=ROOT,
)


Running inference on 85 samples against 11 classes...


100%|██████████| 85/85 [00:03<00:00, 22.87it/s]


## Tune the classification probability threshold

In [None]:
from src.test.utils import find_best_threshold

# Sweep candidate thresholds; the printed table reports micro/macro F1 and
# subset accuracy at each one. The evaluation cell below picks its own
# threshold, so pass `best_t` there to evaluate at the tuned value.
best_t = find_best_threshold(
    preds_arr,
    gts_arr,
    valid_labels,
)

Threshold  | Micro F1   | Macro F1   | Subset Acc
--------------------------------------------------
0.05       | 0.1633     | 0.1095     | 2.35%     
0.10       | 0.1413     | 0.0877     | 3.53%     
0.15       | 0.1339     | 0.0838     | 4.71%     
0.20       | 0.1333     | 0.0876     | 5.88%     
0.25       | 0.1327     | 0.0865     | 7.06%     
0.30       | 0.1116     | 0.0720     | 7.06%     
0.35       | 0.1053     | 0.0717     | 7.06%     
0.40       | 0.0891     | 0.0677     | 7.06%     
0.45       | 0.0800     | 0.0515     | 7.06%     
0.50       | 0.0804     | 0.0518     | 7.06%     
0.55       | 0.0513     | 0.0244     | 5.88%     
0.60       | 0.0532     | 0.0256     | 8.24%     
0.65       | 0.0535     | 0.0256     | 8.24%     
0.70       | 0.0543     | 0.0269     | 10.59%    
0.75       | 0.0556     | 0.0283     | 10.59%    
0.80       | 0.0565     | 0.0292     | 10.59%    
0.85       | 0.0465     | 0.0263     | 11.76%    
0.90       | 0.0357     | 0.0160     | 11.76%    

In [18]:

import numpy as np
from sklearn.metrics import classification_report, accuracy_score, hamming_loss

import numpy as np
from sklearn.metrics import classification_report, accuracy_score, hamming_loss


def evaluate_multilabel_performance(
    all_preds,
    all_gt,
    class_list,
    sample_ids=None,
    threshold=0.5,
    debug=False,
    zero_division=0,
    max_debug_examples=20,
):
    """
    Threshold multi-label scores, print summary/per-class metrics, return the report.

    Note: If ground-truth has no positives, metrics are meaningless and we warn loudly.

    Parameters
    ----------
    all_preds : array-like, shape (n_samples, n_classes)
        Per-class scores; an entry is a positive prediction when ``>= threshold``.
    all_gt : array-like, shape (n_samples, n_classes)
        Ground-truth indicator matrix, cast with ``astype(int)`` (expected 0/1).
    class_list : sequence of str
        Class names in column order. They are stripped and lower-cased and then
        used as the per-class keys of the returned report.
    sample_ids : sequence, optional
        Row identifiers, only used by the ``debug`` printout.
    threshold : float
        Decision threshold applied to ``all_preds``.
    debug : bool
        If True and ``sample_ids`` is given, print predicted vs. actual labels for
        up to ``max_debug_examples`` rows that have at least one true label.
    zero_division :
        Forwarded unchanged to ``sklearn.metrics.classification_report``.
    max_debug_examples : int
        Cap on the number of rows printed in debug mode (default 20).

    Returns
    -------
    dict
        ``classification_report(..., output_dict=True)``: one entry per class name
        plus aggregate entries such as ``'micro avg'`` and ``'macro avg'``.

    Raises
    ------
    ValueError
        If the arrays differ in shape, are not 2-D, have a column count that
        does not match ``class_list``, or contain no samples.
    """
    # ---- 1) Setup & validation ----
    classes = [c.strip().lower() for c in class_list]
    probs = np.asarray(all_preds)
    gts = np.asarray(all_gt).astype(int)

    if probs.shape != gts.shape:
        raise ValueError(f"Shape mismatch: preds {probs.shape} vs gts {gts.shape}")
    if probs.ndim != 2:
        raise ValueError(
            f"Expected 2-D (n_samples, n_classes) arrays, got shape {probs.shape}"
        )
    if probs.shape[1] != len(classes):
        # Catch label/column misalignment here instead of deep inside sklearn.
        raise ValueError(
            f"Got {probs.shape[1]} score columns but {len(classes)} class names"
        )
    num_total = probs.shape[0]
    if num_total == 0:
        # The 'None' rate below would otherwise divide by zero.
        raise ValueError("No samples to evaluate (0 rows).")

    # ---- 2) Threshold ----
    preds = (probs >= threshold).astype(int)

    # ---- 3) Sanity checks ----
    total_pos = int(gts.sum())
    pos_per_class = gts.sum(axis=0)

    if total_pos == 0:
        print("WARNING: Ground-truth contains ZERO positive labels across all samples.")
        print("This usually means your test labels are not being mapped into label_to_idx")
        print("(e.g. label mismatch or parse_ground_truth drops unknown labels).")
        print("Any 'accuracy' you see will be dominated by true negatives and is not meaningful.\n")

    # ---- 4) 'None' prediction rate (informational only) ----
    none_pred_mask = preds.sum(axis=1) == 0
    num_none = int(none_pred_mask.sum())

    # ---- 5) Global metrics on FULL set ----
    exact_match = accuracy_score(gts, preds)                 # subset accuracy
    hacc = 1.0 - hamming_loss(gts, preds)                    # hamming "accuracy"

    report = classification_report(
        gts,
        preds,
        target_names=classes,
        output_dict=True,
        zero_division=zero_division,
    )

    print(f"Classification threshold probability: {threshold}")
    print(f"Total Samples: {num_total}")
    print(f"Predicted 'None' (all-zero): {num_none}  ({(num_none/num_total):.2%})")
    print("")
    print(f"Hamming accuracy: {hacc:.2%}")
    print(f"Subset accuracy (exact match): {exact_match:.2%}")
    print(f"Micro-Average F1: {report['micro avg']['f1-score']:.4f}")
    print(f"Macro-Average F1: {report['macro avg']['f1-score']:.4f}")
    print("")

    # ---- 6) Per-class: Precision / Recall / F1 / Support ----
    # (Recall is what will go to 0 if the class appears but you never predict it.)
    print(f"{'Instrument':<15} | {'Prec':>6} | {'Recall':>6} | {'F1':>6} | {'Support':>7}")
    print("-" * 55)

    for i, name in enumerate(classes):
        support = int(pos_per_class[i])

        if support == 0:
            # No positives in GT for this class -> can't interpret metrics
            print(f"{name:<15} | {'  n/a':>6} | {'  n/a':>6} | {'  n/a':>6} | {support:>7}")
            continue

        prec = report[name]["precision"]
        rec = report[name]["recall"]
        f1 = report[name]["f1-score"]
        print(f"{name:<15} | {prec:6.2f} | {rec:6.2f} | {f1:6.2f} | {support:>7}")

    # ---- 7) Optional debug: show some examples where GT had positives ----
    if debug and sample_ids is not None:
        print("\n--- DEBUG: Examples where GT has at least one label ---")
        gt_nonzero = np.where(gts.sum(axis=1) > 0)[0]
        for idx in gt_nonzero[:max_debug_examples]:
            pred_names = [classes[j] for j, v in enumerate(preds[idx]) if v]
            gt_names = [classes[j] for j, v in enumerate(gts[idx]) if v]
            print(f"ID: {sample_ids[idx]}")
            print(f"  Predicted: {pred_names if pred_names else '(none)'}")
            print(f"  Actual:    {gt_names if gt_names else '(none)'}")
            print("-" * 30)

    return report

# Evaluate at the conventional 0.5 cut-off by default. Set USE_TUNED_THRESHOLD
# to True to evaluate at `best_t` from the threshold sweep instead; the choice
# trades precision against recall.
USE_TUNED_THRESHOLD = False
threshold_probability = best_t if USE_TUNED_THRESHOLD else 0.5

results = evaluate_multilabel_performance(
    all_preds=preds_arr,
    all_gt=gts_arr,
    class_list=valid_labels,
    threshold=threshold_probability,
    debug=False,
)


Classification threshold probability: 0.5
Total Samples: 85
Predicted 'None' (all-zero): 38  (44.71%)

Hamming accuracy: 80.43%
Subset accuracy (exact match): 7.06%
Micro-Average F1: 0.0804
Macro-Average F1: 0.0518

Instrument      |   Prec | Recall |     F1 | Support
-------------------------------------------------------
pipa            |   0.00 |   0.00 |   0.00 |      29
erhu            |   0.00 |   0.00 |   0.00 |      23
sheng           |   0.00 |   0.00 |   0.00 |      30
dizi            |   0.00 |   0.00 |   0.00 |      12
xiao            |   0.00 |   0.00 |   0.00 |       8
piano           |   0.00 |   0.00 |   0.00 |       2
guqin           |   0.00 |   0.00 |   0.00 |      10
suona           |   0.20 |   0.25 |   0.22 |       4
guzheng         |   0.00 |   0.00 |   0.00 |       2
percussion      |   0.32 |   0.22 |   0.26 |      27
voice           |   0.05 |   0.25 |   0.09 |       4


### Classification Report


In [None]:
# Render the report dict returned by evaluate_multilabel_performance.
# NOTE(review): in the saved output, the rendered rows stop at 'percussion' and
# omit the 'voice' row shown in the per-class printout above. Check how
# display_formatted_results slices the report before relying on this view.
display_formatted_results(results)


--- Detailed Classification Report ---


Unnamed: 0,precision,recall,f1-score,support
pipa,0.0,0.0,0.0,29
erhu,0.0,0.0,0.0,23
sheng,0.0,0.0,0.0,30
dizi,0.0,0.0,0.0,12
xiao,0.0,0.0,0.0,8
piano,0.0,0.0,0.0,2
guqin,0.0,0.0,0.0,10
suona,0.2,0.25,0.2222,4
guzheng,0.0,0.0,0.0,2
percussion,0.3158,0.2222,0.2609,27


In [None]:
# Legacy manual inference loop, kept commented out for reference. It builds the
# same preds_arr / gts_arr / sample_ids that the run_inference cell above
# returns; presumably run_inference now wraps this logic (verify before relying on it).
# ckpt = torch.load(MODEL_WEIGHTS, map_location=DEVICE)
# audio_cfg = ckpt["audio_config"]
# valid_labels = [c.strip().lower() for c in ckpt["classes"]]
# label_to_idx = {name: i for i, name in enumerate(valid_labels)}
# print(audio_cfg)

# model = CNN(in_ch=2, num_classes=len(valid_labels)).to(DEVICE)
# model.load_state_dict(ckpt['model_state'])

# The manifest variable defined earlier in this notebook is lower-case `test_manifest_csv`.
# df = pd.read_csv(test_manifest_csv)

# def _resolve_path(p):
#     p = Path(p)
#     return p if p.is_absolute() else (ROOT / p).resolve()

# df["wav_path"] = df["wav_path"].apply(lambda p: str(_resolve_path(p)))
# df["txt_path"] = df["txt_path"].apply(lambda p: str(_resolve_path(p)))
# all_preds, all_gt, sample_ids = [], [], []

# print(f"Evaluating {len(df)} samples against {len(valid_labels)} classes...")

# for _, row in tqdm(df.iterrows(), total=len(df)):
#     gt_vec = parse_ground_truth(row['txt_path'], label_to_idx)
#     mel = load_and_preprocess(row['wav_path'], audio_cfg)
#     probs = get_prediction(model, mel, DEVICE)

#     all_preds.append(probs)
#     all_gt.append(gt_vec)
#     sample_ids.append(Path(row['wav_path']).stem)


# preds_arr = np.array(all_preds)
# gts_arr = np.array(all_gt)