Chinese Instrument Classification — Multi-label Training on Mel + CQT Features


In [1]:
import sys
from pathlib import Path

# Walk up from the kernel's working directory to the repository root (the first
# ancestor containing a `src/` directory) so `src.*` imports resolve no matter
# which sub-folder the notebook is launched from.
root = Path.cwd()
while root != root.parent and not (root / "src").exists():
    root = root.parent
# The loop also stops at the filesystem root; fail loudly there instead of
# putting an unrelated directory on sys.path and getting an opaque import error.
if not (root / "src").exists():
    raise FileNotFoundError(f"Could not find a 'src' directory in {Path.cwd()} or any parent")
if str(root) not in sys.path:
    sys.path.insert(0, str(root))

import torch
import yaml
import matplotlib.pyplot as plt

# Project imports must come after the sys.path tweak above.
from src.train.utils_mel_cqt import multi_label_train_loop

print("Repo root:", root)


Repo root: d:\qingchaolaopian\Instrument Sound\GitHub\ml-based-analysis-of-sound


In [2]:
# Run identifier; checkpoints for this run are written to / read from WEIGHTS_DIR.
TRAIN_RUN = "Chinese_mel_cqt_v1"
WEIGHTS_DIR = Path(f"../models/saved_weights/{TRAIN_RUN}")
USE_CKPT = False  # True to resume from last.pt

# Relative paths resolve against the kernel's current working directory
# (typically the notebook's own folder in Jupyter).
# NOTE(review): data paths go up two levels while models/configs go up one —
# confirm this matches the repository layout.
MANIFEST_CSV = [
    "../../data/processed/train_mels.csv",
    "../../data/processed/train_mels_mixed.csv",
]
LABELS_YAML = "../configs/labels.yaml"
AUDIO_CONFIG_YAML = "../configs/audio_params.yaml"

# Hyperparameters forwarded to multi_label_train_loop in the training cell.
CONFIG = {
    "batch_size": 64,
    "lr": 1e-3,
    "epochs": 300,
    "patience": 30,
    "weight_decay": 1e-4,
    "dropout": 0.5,
    "val_frac": 0.2,
    "seed": 1337,
    "threshold": 0.5
}





In [3]:
# Load the audio front-end parameters and the list of target classes.
with open(AUDIO_CONFIG_YAML, 'r', encoding='utf-8') as f:
    audio_params = yaml.safe_load(f)
with open(LABELS_YAML, 'r', encoding='utf-8') as f:
    # safe_load returns None for an empty file; treat that as "no labels".
    label_config = yaml.safe_load(f) or {}

# Normalize names so YAML casing/whitespace differences don't create distinct classes.
classes = [c.strip().lower() for c in label_config.get('train_labels', [])]
# Fail fast: an empty or duplicated label list cannot define a sensible set of targets.
assert classes, f"No 'train_labels' found in {LABELS_YAML}"
duplicate_classes = sorted({c for c in classes if classes.count(c) > 1})
assert not duplicate_classes, f"Duplicate labels after normalization: {duplicate_classes}"
print(f"Loaded {len(classes)} classes: {', '.join(classes)}")

# Resume only when requested AND a previous checkpoint actually exists.
resume_ckpt = WEIGHTS_DIR / "last.pt" if USE_CKPT else None
if resume_ckpt is None:
    print("Starting fresh (resume disabled).")
elif not resume_ckpt.exists():
    resume_ckpt = None
    print("Starting fresh. No previous weights found.")
else:
    print(f"Existing weights detected. Resuming from {resume_ckpt}")

# Run the training
results = multi_label_train_loop(
    manifest_csv=MANIFEST_CSV,
    classes=classes,
    ckpt_dir=WEIGHTS_DIR,
    epochs=CONFIG["epochs"],
    batch_size=CONFIG["batch_size"],
    lr=CONFIG["lr"],
    weight_decay=CONFIG["weight_decay"],
    val_frac=CONFIG["val_frac"],
    dropout=CONFIG["dropout"],
    patience=CONFIG["patience"],
    num_workers=0,
    threshold=CONFIG["threshold"],
    seed=CONFIG["seed"],
    audio_cfg=audio_params['audio'],
    resume_from=resume_ckpt,
    save_best_stamped=False,
)
history = results["history"]



Loaded 15 classes: strings, brass, percussion, woodwind, sheng, dizi, timpani, erhu, pipa, suona, guzheng, piano, guqin, xiao, yangqin
Starting fresh (resume disabled).
[1/300] Loss: 0.2802/0.2230 | Val MicroF1: 0.5571 | Time: 31.0s
[2/300] Loss: 0.2284/0.1883 | Val MicroF1: 0.6532 | Time: 31.7s
[3/300] Loss: 0.2081/0.1762 | Val MicroF1: 0.7018 | Time: 32.5s
[4/300] Loss: 0.1940/0.1548 | Val MicroF1: 0.7374 | Time: 35.0s
[5/300] Loss: 0.1840/0.1469 | Val MicroF1: 0.7547 | Time: 35.7s
[6/300] Loss: 0.1744/0.1410 | Val MicroF1: 0.7699 | Time: 36.3s
[7/300] Loss: 0.1677/0.1286 | Val MicroF1: 0.7899 | Time: 35.6s
[8/300] Loss: 0.1600/0.1252 | Val MicroF1: 0.7975 | Time: 35.7s
[9/300] Loss: 0.1567/0.1237 | Val MicroF1: 0.8025 | Time: 35.5s
[10/300] Loss: 0.1537/0.1236 | Val MicroF1: 0.7968 | Time: 35.7s
[11/300] Loss: 0.1513/0.1201 | Val MicroF1: 0.8067 | Time: 35.6s
[12/300] Loss: 0.1489/0.1207 | Val MicroF1: 0.8040 | Time: 35.5s
[13/300] Loss: 0.1473/0.1149 | Val MicroF1: 0.8117 | Time: 3

KeyboardInterrupt: 

In [None]:
from src.train.utils import plot_metrics

# Reload the checkpoint written by the training run (the same last.pt the
# resume option uses) so the curves can be inspected without re-training.
MODEL_WEIGHTS = WEIGHTS_DIR / "last.pt"
if not MODEL_WEIGHTS.exists():
    raise FileNotFoundError(f"No checkpoint at {MODEL_WEIGHTS}; run the training cell first.")
# NOTE(review): this checkpoint stores non-tensor objects (history, audio config).
# Recent torch versions default torch.load to weights_only=True; if loading fails,
# pass weights_only=False — only for checkpoints you produced yourself.
ckpt_loaded = torch.load(MODEL_WEIGHTS, map_location="cpu")
# Distinct name so the YAML-loaded `audio_params` dict keeps its original meaning.
ckpt_audio_config = ckpt_loaded['audio_config']
history = ckpt_loaded["history"]
plot_metrics(history)

print("Audio Config used during training:")
print(ckpt_audio_config)

