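# Test the `Chinese_mel_cqt_v1` model (Mel+CQT)

Evaluates the `Chinese_mel_cqt_v1` CNN checkpoint (`best_val.pt`) on multi-label instrument detection: run inference on a test manifest, tune the decision threshold, then report overall and per-class metrics.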


In [1]:
import sys
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# Project root: the first directory, starting at the current working
# directory and walking upwards, that contains a `src/` folder. If none is
# found, this falls back to the filesystem root (the last ancestor).
ROOT = next(
    (folder for folder in (Path.cwd(), *Path.cwd().parents) if (folder / "src").exists()),
    Path(Path.cwd().anchor),
)

# Make `src.*` importable regardless of where the kernel was started.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.models.CNN import CNN
from src.test.utils import  display_formatted_results, find_best_threshold
from src.test.utils_mel_cqt import run_inference

# Checkpoint under evaluation, resolved against the discovered project root.
MODEL_WEIGHTS = ROOT / "src/models/saved_weights/Chinese_mel_cqt_v1/best_val.pt"
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



## Run on A Touch of Zen Dataset (Mel+CQT)


In [2]:
# Score every clip listed in the A Touch of Zen test manifest.
test_manifest_csv = ROOT / "data/test/a-touch-of-zen.csv"

# NOTE(review): in_ch=4 presumably stacks the Mel and CQT feature channels;
# confirm against the training config of Chinese_mel_cqt_v1.
(
    preds_arr,
    gts_arr,
    sample_ids,
    audio_cfg,
    valid_labels,
    label_to_idx,
) = run_inference(
    model_cls=CNN,
    model_kwargs={"in_ch": 4},
    model_weights_path=MODEL_WEIGHTS,
    test_manifest_csv=test_manifest_csv,
    root=ROOT,
    device=DEVICE,
)



Running inference on 85 samples against 15 classes...


100%|██████████| 85/85 [01:13<00:00,  1.16it/s]


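## Tune the classification probability threshold

`find_best_threshold` sweeps candidate thresholds and reports micro F1, macro F1 and subset accuracy for each one.

Optimising macro F1 means:
- Every instrument class counts equally, so rare and common instruments carry the same weight

Optimising micro F1 means:
- You care about overall instrument detection performance
- Every correctly detected instrument occurrence counts equally
- Common instruments therefore dominate the score

**Caveat:** the threshold is tuned on the same test set it is then evaluated on, so the scores below are optimistically biased. For an unbiased estimate, tune the threshold on a held-out validation split and only apply it here.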

In [3]:
from src.test.utils import find_best_threshold
# Sweep candidate thresholds over the test-set predictions and keep the best one.
best_t = find_best_threshold(preds_arr, gts_arr, valid_labels)
print(f"Best threshold found: {best_t:.2f}")


Threshold  | Micro F1   | Macro F1   | Subset Acc
--------------------------------------------------
0.05       | 0.2812     | 0.1517     | 0.00%     
0.10       | 0.2748     | 0.1488     | 0.00%     
0.15       | 0.2541     | 0.1331     | 1.18%     
0.20       | 0.2479     | 0.1258     | 1.18%     
0.25       | 0.2423     | 0.1199     | 1.18%     
0.30       | 0.2262     | 0.1068     | 1.18%     
0.35       | 0.2243     | 0.1070     | 1.18%     
0.40       | 0.2217     | 0.1041     | 1.18%     
0.45       | 0.2201     | 0.1030     | 1.18%     
0.50       | 0.2165     | 0.0991     | 1.18%     
0.55       | 0.2100     | 0.0971     | 1.18%     
0.60       | 0.2110     | 0.0971     | 1.18%     
0.65       | 0.1863     | 0.0796     | 1.18%     
0.70       | 0.1750     | 0.0761     | 1.18%     
0.75       | 0.1667     | 0.0604     | 1.18%     
0.80       | 0.1590     | 0.0580     | 0.00%     
0.85       | 0.1530     | 0.0560     | 0.00%     
0.90       | 0.1344     | 0.0480     | 0.00%     

In [4]:

import numpy as np
from sklearn.metrics import classification_report, accuracy_score, hamming_loss


def _print_per_class_table(classes, gts, preds, report):
    """Print a per-class precision / recall / F1 / support / predicted-count table.

    Args:
        classes: Normalised class names, in the column order of ``gts``/``preds``.
        gts: Binary ground-truth matrix, one column per class.
        preds: Binary prediction matrix with the same shape as ``gts``.
        report: ``classification_report(..., output_dict=True)`` keyed by ``classes``.
    """
    support_per_class = gts.sum(axis=0)
    pred_per_class = preds.sum(axis=0)

    print(f"{'Instrument':<15} | {'Prec':>6} | {'Recall':>6} | {'F1':>6} | {'Support':>7} | {'Pred':>5}")
    print("-" * 70)

    for i, name in enumerate(classes):
        support = int(support_per_class[i])
        pred_count = int(pred_per_class[i])

        # With no ground-truth positives, recall (and therefore F1) is undefined on
        # this split, so show n/a rather than the zero_division fill value.
        if support == 0:
            print(f"{name:<15} | {'  n/a':>6} | {'  n/a':>6} | {'  n/a':>6} | {support:>7} | {pred_count:>5}")
            continue

        stats = report[name]
        print(
            f"{name:<15} | {stats['precision']:6.2f} | {stats['recall']:6.2f} | "
            f"{stats['f1-score']:6.2f} | {support:>7} | {pred_count:>5}"
        )


def _print_debug_examples(classes, gts, preds, sample_ids):
    """Print predicted vs. actual label names for every sample with at least one GT label."""
    print("\n--- DEBUG: Examples where GT has at least one label ---")
    for idx in np.flatnonzero(gts.sum(axis=1) > 0):
        pred_names = [classes[j] for j in np.flatnonzero(preds[idx])]
        gt_names = [classes[j] for j in np.flatnonzero(gts[idx])]
        print(f"ID: {sample_ids[idx]}")
        print(f"  Predicted: {pred_names if pred_names else '(none)'}")
        print(f"  Actual:    {gt_names if gt_names else '(none)'}")
        print("-" * 30)


def evaluate_multilabel_performance(
    all_preds,
    all_gt,
    class_list,
    sample_ids=None,
    threshold=0.5,
    debug=False,
    zero_division=0,
):
    """Binarise multi-label probabilities at ``threshold`` and print an evaluation summary.

    The summary covers overall positive counts, hamming and subset accuracy,
    micro/macro F1, micro confusion totals and a per-class table. Optionally it
    also lists predicted vs. actual label names per sample.

    Args:
        all_preds: Per-class probabilities, array-like of shape (samples, classes).
        all_gt: Ground-truth indicators with the same shape (cast to int).
        class_list: Class names in column order; each is stripped and lower-cased.
        sample_ids: Optional per-sample identifiers, only used when ``debug`` is True.
        threshold: A label is predicted when its probability is ``>= threshold``.
        debug: If True and ``sample_ids`` is given, list every sample that has
            at least one ground-truth label.
        zero_division: Forwarded to ``sklearn.metrics.classification_report``.

    Returns:
        The ``classification_report`` dict (``output_dict=True``), keyed by the
        normalised class names plus sklearn's average rows ('micro avg',
        'macro avg', ...).

    Raises:
        ValueError: If preds/GT shapes differ or are not 2-D, if there are no
            samples or classes, or if ``class_list`` does not match the number
            of columns or has duplicate names after normalisation. Also raised
            if ``sample_ids`` (when used) has the wrong length.
    """
    classes = [c.strip().lower() for c in class_list]
    probs = np.asarray(all_preds)
    gts = np.asarray(all_gt).astype(int)

    # ---- Input validation (fail early with clear messages) ----
    if probs.shape != gts.shape:
        raise ValueError(f"Shape mismatch: preds {probs.shape} vs gts {gts.shape}")
    if probs.ndim != 2:
        raise ValueError(f"Expected 2-D (samples, classes) arrays, got shape {probs.shape}")

    num_samples, num_classes = probs.shape
    if num_samples == 0 or num_classes == 0:
        raise ValueError(f"No samples or classes to evaluate: got shape {probs.shape}")
    if len(classes) != num_classes:
        raise ValueError(
            f"class_list has {len(classes)} names but predictions have {num_classes} columns"
        )
    # The report is keyed by class name, so duplicate names would silently
    # merge rows and corrupt the per-class table.
    if len(set(classes)) != len(classes):
        duplicates = sorted({c for c in classes if classes.count(c) > 1})
        raise ValueError(f"Duplicate class names after strip/lower: {duplicates}")
    if debug and sample_ids is not None and len(sample_ids) != num_samples:
        raise ValueError(
            f"sample_ids has {len(sample_ids)} entries but there are {num_samples} samples"
        )

    preds = (probs >= threshold).astype(int)

    # ---- Key counts ----
    total_pos_gt = int(gts.sum())
    total_pos_pred = int(preds.sum())
    total_entries = int(num_samples * num_classes)

    # Confusion totals across ALL labels (micro)
    tp = int(((preds == 1) & (gts == 1)).sum())
    fp = int(((preds == 1) & (gts == 0)).sum())
    fn = int(((preds == 0) & (gts == 1)).sum())
    tn = int(((preds == 0) & (gts == 0)).sum())

    # ---- "None" prediction rate: samples where no label cleared the threshold ----
    num_none = int((preds.sum(axis=1) == 0).sum())

    # ---- Metrics ----
    subset_acc = accuracy_score(gts, preds)                      # exact match
    hamming_acc = 1.0 - hamming_loss(gts, preds)                 # label-wise accuracy

    report = classification_report(
        gts,
        preds,
        target_names=classes,
        output_dict=True,
        zero_division=zero_division,
    )

    # ---- Print summary ----
    print(f"Classification threshold probability: {threshold}")
    print(f"Samples: {num_samples} | Classes: {num_classes} | Decisions: {total_entries}")
    print(f"GT positives: {total_pos_gt} ({total_pos_gt/total_entries:.2%} of all decisions)")
    print(f"Pred positives: {total_pos_pred} ({total_pos_pred/total_entries:.2%} of all decisions)")
    print(f"Predicted 'None' (all-zero): {num_none} ({num_none/num_samples:.2%})")
    print("")
    print(f"Hamming accuracy (label-wise): {hamming_acc:.2%}")
    print(f"Subset accuracy (exact match): {subset_acc:.2%}")
    print(f"Micro F1:  {report['micro avg']['f1-score']:.4f}")
    print(f"Macro F1:  {report['macro avg']['f1-score']:.4f}")
    print("")

    print(f"TP={tp} FP={fp} FN={fn} TN={tn}")
    # Heuristic: far fewer predicted positives than GT positives usually means the
    # threshold sits above most of the model's output probabilities.
    if total_pos_pred < max(5, 0.02 * total_pos_gt):
        print("WARNING: Very few positive predictions relative to GT positives.")
        print("         Your threshold is likely too high, or logits are miscalibrated.\n")

    # ---- Per-class table ----
    _print_per_class_table(classes, gts, preds, report)

    # ---- Debug examples where GT had positives ----
    if debug and sample_ids is not None:
        _print_debug_examples(classes, gts, preds, sample_ids)

    return report

# Binarise at the threshold returned by find_best_threshold rather than the
# function's 0.5 default; assign a literal here (e.g. 0.5 or 0.9) to trade
# precision against recall.
threshold_probability = best_t

results = evaluate_multilabel_performance(
    all_preds=preds_arr,
    all_gt=gts_arr,
    class_list=valid_labels,
    sample_ids=sample_ids,
    debug=True,
    threshold=threshold_probability,
)


Classification threshold probability: 0.05
Samples: 85 | Classes: 15 | Decisions: 1275
GT positives: 284 (22.27% of all decisions)
Pred positives: 356 (27.92% of all decisions)
Predicted 'None' (all-zero): 0 (0.00%)

Hamming accuracy (label-wise): 63.92%
Subset accuracy (exact match): 0.00%
Micro F1:  0.2812
Macro F1:  0.1517

TP=90 FP=266 FN=194 TN=725
Instrument      |   Prec | Recall |     F1 | Support |  Pred
----------------------------------------------------------------------
strings         |   0.00 |   0.00 |   0.00 |      46 |     0
brass           |   0.39 |   0.91 |   0.54 |      32 |    75
percussion      |   0.33 |   0.83 |   0.47 |      29 |    73
woodwind        |   0.00 |   0.00 |   0.00 |      30 |     0
sheng           |   0.36 |   0.77 |   0.49 |      30 |    64
dizi            |   0.14 |   0.17 |   0.15 |      12 |    14
timpani         |   0.00 |   0.00 |   0.00 |      26 |     0
erhu            |   0.27 |   0.30 |   0.29 |      23 |    26
pipa            |   0.33

### Classification Report


In [5]:
# Render the classification_report dict returned by evaluate_multilabel_performance
# as a detailed per-class table.
display_formatted_results(results)


--- Detailed Classification Report ---


Unnamed: 0,precision,recall,f1-score,support
strings,0.0,0.0,0.0,46
brass,0.3867,0.9062,0.5421,32
percussion,0.3288,0.8276,0.4706,29
woodwind,0.0,0.0,0.0,30
sheng,0.3594,0.7667,0.4894,30
dizi,0.1429,0.1667,0.1538,12
timpani,0.0,0.0,0.0,26
erhu,0.2692,0.3043,0.2857,23
pipa,0.3333,0.0345,0.0625,29
suona,0.0,0.0,0.0,4


In [6]:
# NOTE(review): legacy manual inference loop, kept for reference only. Every line
# is commented out, and its job is now done by `run_inference` in the inference
# cell above. It is stale:
#   - it builds CNN(in_ch=2), whereas the Mel+CQT model here uses in_ch=4;
#   - it relies on names (TEST_MANIFEST_CSV, parse_ground_truth,
#     load_and_preprocess, get_prediction) that are not defined or imported in
#     the visible cells.
# Consider deleting this cell.
# ckpt = torch.load(MODEL_WEIGHTS, map_location=DEVICE)
# audio_cfg = ckpt["audio_config"]
# valid_labels = [c.strip().lower() for c in ckpt["classes"]]
# label_to_idx = {name: i for i, name in enumerate(valid_labels)}
# print(audio_cfg)

# model = CNN(in_ch=2, num_classes=len(valid_labels)).to(DEVICE)
# model.load_state_dict(ckpt['model_state'])

# df = pd.read_csv(TEST_MANIFEST_CSV)

# def _resolve_path(p):
#     p = Path(p)
#     return p if p.is_absolute() else (ROOT / p).resolve()

# df["wav_path"] = df["wav_path"].apply(lambda p: str(_resolve_path(p)))
# df["txt_path"] = df["txt_path"].apply(lambda p: str(_resolve_path(p)))
# all_preds, all_gt, sample_ids = [], [], []

# print(f"Evaluating {len(df)} samples against {len(valid_labels)} classes...")

# for _, row in tqdm(df.iterrows(), total=len(df)):
#     gt_vec = parse_ground_truth(row['txt_path'], label_to_idx)
#     mel = load_and_preprocess(row['wav_path'], audio_cfg)
#     probs = get_prediction(model, mel, DEVICE)

#     all_preds.append(probs)
#     all_gt.append(gt_vec)
#     sample_ids.append(Path(row['wav_path']).stem)


# preds_arr = np.array(all_preds)
# gts_arr = np.array(all_gt)