## Machine Learning Models

In [None]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram
import librosa
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import random

from transformers import ASTModel, ASTConfig
import torchvision.transforms as T
from PIL import Image


# MPS DEVICE CHECK

# Pick the compute device once at import time: use the MPS backend when
# PyTorch reports it as available, otherwise run on the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
if DEVICE.type == "mps":
    print("Using MPS device for GPU acceleration on Apple Silicon.")
else:
    print("MPS not available. Falling back to CPU.")



# HYPERPARAMETERS / GLOBAL SETTINGS

# Run-wide configuration read by the data-loading, model and training code below.
HYPERPARAMS = dict(
    OUTPUT_DIR="data/trainingdataoutput",
    DB_LIMIT=1000,                # row limit handed to the DB fetch helper
    BATCH_SIZE=8,
    EPOCHS=10,                    # upper bound on the number of training epochs
    LEARNING_RATE_MAIN=5e-5,      # intended for the AST layers
    LEARNING_RATE_HEAD=1e-4,      # intended for the scalar/quantum/final layers
    WEIGHT_DECAY=1e-4,            # L2 regularization strength
    DEVICE=DEVICE,
    NUM_AST_LAYERS_UNFROZEN=4,    # number of trailing AST encoder layers left trainable
    PATIENCE=5,                   # epochs without improvement tolerated before stopping
    STOP_LIMIT=5,                 # early stopping cannot trigger before this many epochs
    AST_FREQ=128,                 # height the spectrogram image is resized to for AST
    AST_TIME=1024,                # width the spectrogram image is resized to for AST
)


# Minimal DB

class QuantumMusicDBFetchOnly:
    """Read-only PostgreSQL accessor for the ``audio_analysis`` table.

    A connection is attempted when the object is constructed. A failed
    connection is reported with ``print`` and leaves ``self.conn`` as ``None``
    (best-effort); the fetch methods then fail when they try to open a cursor.
    """

    # File names are expected to look like <letters><rating 1-5><anything>.
    _RATED_NAME_RE = re.compile(r"^([A-Za-z]+)([1-5])(.*)")

    def __init__(self, db_name="quantummusic", host="localhost", user="postgres", password="postgres"):
        import psycopg2
        self.psycopg2 = psycopg2
        self.db_name = db_name
        self.host = host
        self.user = user
        self.password = password
        self.conn = None
        self.connect()

    def connect(self):
        """Open the connection; on any failure print the error and keep ``conn`` unset."""
        try:
            import psycopg2
            self.conn = psycopg2.connect(
                dbname=self.db_name,
                host=self.host,
                user=self.user,
                password=self.password
            )
            print(f"Connected to database {self.db_name}. (fetch-only)")
        except Exception as e:
            print(f"Error connecting to database: {e}")

    def close(self):
        """Close the connection if one is open; safe to call more than once."""
        if self.conn:
            self.conn.close()
            # Drop the handle so a closed connection is never reused by mistake.
            self.conn = None
            print("Database connection closed.")

    def fetch_limited_analysis_data(self, limit=1000):
        """Return up to ``limit`` ``(id, analysis_data)`` rows.

        The query keeps rows whose file name's first digit run is 2, 3, 4 or 5
        and samples at most 42 random rows for each of those numbers.
        ``limit`` caps the total number of rows returned; ``None`` means no cap
        (it is sent to PostgreSQL as ``LIMIT NULL``).

        You can customize this query as needed.
        """
        with self.conn.cursor() as cur:
            query = """
                WITH cte AS (
                    SELECT
                        id,
                        analysis_data,
                        substring(file_name FROM '\\d+')::int AS file_num,
                        ROW_NUMBER() OVER (
                            PARTITION BY substring(file_name FROM '\\d+')::int
                            ORDER BY random()
                        ) AS rn
                    FROM audio_analysis
                    WHERE substring(file_name FROM '\\d+')::int IN (2, 3, 4, 5)
                )
                SELECT id, analysis_data
                FROM cte
                WHERE rn <= 42
                LIMIT %s;
            """
            # The limit is passed as a bound parameter, never formatted into the SQL.
            cur.execute(query, (limit,))
            rows = cur.fetchall()
        return rows

    def fetch_single_record(self, record_id):
        """Return the ``(id, analysis_data)`` row for ``record_id``, or ``None``."""
        with self.conn.cursor() as cur:
            query = "SELECT id, analysis_data FROM audio_analysis WHERE id = %s"
            cur.execute(query, (record_id,))
            row = cur.fetchone()
        return row

    def fetch_leftover_records(self, exclude_ids):
        """Return rows not in ``exclude_ids`` whose file-name rating is >= 4.

        The whole table is fetched and filtered in Python (adjust for a large
        DB if needed). The rating is the digit 1-5 that follows the leading
        letters of ``analysis_data["file_name"]``; rows whose name does not
        match that pattern are dropped.
        """
        with self.conn.cursor() as cur:
            cur.execute("SELECT id, analysis_data FROM audio_analysis")
            rows = cur.fetchall()

        # Build the lookup once so each membership test is O(1).
        excluded = set(exclude_ids)

        leftover = []
        for r_id, analysis_data in rows:
            if r_id in excluded:
                continue
            base = analysis_data["file_name"].replace(".wav", "")
            m = self._RATED_NAME_RE.match(base)
            if m and int(m.group(2)) >= 4:
                leftover.append((r_id, analysis_data))
        return leftover


# Convert quantum measurement counts -> probability vector

def convert_counts_to_probs_feature(counts_dict, max_bits=10):
    """Turn a ``{bitstring: count}`` mapping into a probability vector.

    The result is a float32 array of length ``2**max_bits``. Each bitstring is
    reduced to exactly ``max_bits`` characters (longer strings keep their last
    ``max_bits`` characters, shorter ones are left-padded with ``'0'``), read
    as a binary index, and its share of the total count is accumulated at that
    index. An all-zero vector is returned when the counts sum to zero.
    """
    probs = np.zeros(2**max_bits, dtype=np.float32)
    total = sum(counts_dict.values())
    if total == 0:
        return probs

    for bits, count in counts_dict.items():
        # Normalize the key width before interpreting it as a binary number.
        key = bits[-max_bits:] if len(bits) > max_bits else bits.rjust(max_bits, '0')
        probs[int(key, 2)] += count / total
    return probs


# parse_raga_and_quality

def parse_raga_and_quality(fname: str):
    """Split a file name into ``(raga, rating)``.

    Every ``".wav"`` occurrence is removed first. The remaining name must begin
    with one or more ASCII letters (the raga) followed by a digit 1-5 (the
    rating). ``(None, None)`` is returned when the name does not fit.
    """
    stem = fname.replace(".wav", "")
    match = re.match(r"^([A-Za-z]+)([1-5])(.*)", stem)
    if match is None:
        return None, None
    raga_name, rating_digit = match.group(1, 2)
    return raga_name, int(rating_digit)


# load_image_as_tensor

# Shared image -> tensor conversion pipeline.
transform_img = T.Compose([T.ToTensor()])


def load_image_as_tensor(image_path: str):
    """Load ``image_path`` as a single-channel ("L" mode) image tensor.

    Returns ``None`` when the path does not exist.
    """
    if os.path.exists(image_path):
        with Image.open(image_path) as source:
            # convert() yields a separate image object that stays usable
            # after the file handle is closed.
            grayscale = source.convert("L")
        return transform_img(grayscale)
    return None


# fetch_training_data_v2

def _stack_image_tensors(tensors):
    """Stack per-record image tensors, filling missing (``None``) entries with zeros.

    Placeholders copy the shape and dtype of the first available image so
    ``torch.stack`` sees uniformly shaped inputs. If no image is available at
    all, a (1, 1, 1) zero tensor is used for every record.
    """
    reference = next((t for t in tensors if t is not None), None)
    if reference is None:
        placeholder = torch.zeros((1, 1, 1))
    else:
        placeholder = torch.zeros_like(reference)
    return torch.stack([placeholder if t is None else t for t in tensors])


def _extract_scalar_features(analysis_data):
    """Return the 10 scalar features of one record; missing values default to 0.0."""
    results = analysis_data.get("results", {})
    dynamics = analysis_data.get("dynamics_summary", {})
    advanced = analysis_data.get("quantum_analysis", {}).get("advanced_stats", {})
    return [
        results.get("average_dev_cents", 0.0),
        results.get("std_dev_cents", 0.0),
        results.get("avg_praat_hnr", 0.0),
        results.get("avg_tnr", 0.0),
        dynamics.get("rms_db", {}).get("mean", 0.0),
        dynamics.get("lufs", {}).get("mean", 0.0),
        advanced.get("avg_jitter", 0.0),
        advanced.get("avg_shimmer", 0.0),
        advanced.get("avg_vibrato_rate", 0.0),
        advanced.get("avg_F1", 0.0),
    ]


def _extract_quantum_features(analysis_data, max_len=10):
    """Return the first ``max_len`` scaled angles (zero-padded) followed by the
    measurement-count probability vector for one record."""
    quantum_dict = analysis_data.get("quantum_analysis", {})
    angles = quantum_dict.get("scaled_angles", [])
    angle_arr = np.zeros(max_len, dtype=np.float32)
    for i in range(min(max_len, len(angles))):
        angle_arr[i] = angles[i]

    counts_d = quantum_dict.get("measurement_counts", {})
    dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=10)
    return np.concatenate([angle_arr, dist_vec], axis=0)


def fetch_training_data_v2(limit=None):
    """Fetch analysis rows from the DB and assemble the training arrays.

    Args:
        limit: row limit passed to the DB helper; falls back to
            ``HYPERPARAMS["DB_LIMIT"]`` when falsy.

    Returns:
        ``(data_dict, used_ids)``. ``data_dict`` holds the stacked MFCC and
        log-spectrogram image tensors, the scalar and quantum feature arrays,
        the integer raga and quality labels (quality is stored as rating - 1)
        and the raga <-> index mapping. ``used_ids`` lists the id of every
        fetched row so those rows can be excluded later. ``data_dict`` is
        ``None`` when the DB returns no rows or no row has a parsable file name.
    """
    db = QuantumMusicDBFetchOnly()
    try:
        rows = db.fetch_limited_analysis_data(limit=limit or HYPERPARAMS["DB_LIMIT"])
    finally:
        # Release the connection even if the query raises.
        db.close()

    if not rows:
        print("No data found in DB.")
        return None, []

    images_mfcc = []
    images_log = []
    scalar_feats = []
    quantum_feats = []
    labels_raga = []
    labels_quality = []
    used_ids = []

    for (record_id, analysis_data) in rows:
        used_ids.append(record_id)

        fname = analysis_data["file_name"]
        base_no_ext = fname.replace(".wav", "")
        raga, quality = parse_raga_and_quality(fname)
        if raga is None or quality is None:
            continue

        mfcc_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mfcc.png")
        log_path  = os.path.join("data", "analysisoutput", f"{base_no_ext}_log_spectrogram.png")

        # Missing images stay None here and are replaced by shape-matched
        # zero tensors when the lists are stacked.
        images_mfcc.append(load_image_as_tensor(mfcc_path))
        images_log.append(load_image_as_tensor(log_path))

        scalar_feats.append(_extract_scalar_features(analysis_data))
        quantum_feats.append(_extract_quantum_features(analysis_data))
        labels_raga.append(raga)
        labels_quality.append(quality)

    if not labels_raga:
        # Every row was skipped; stacking an empty list would raise.
        print("No usable records found in DB (no file name matched the raga/quality pattern).")
        return None, used_ids

    images_mfcc = _stack_image_tensors(images_mfcc)
    images_log  = _stack_image_tensors(images_log)
    scalar_feats = np.array(scalar_feats, dtype=np.float32)
    quantum_feats = np.array(quantum_feats, dtype=np.float32)

    unique_ragas = sorted(set(labels_raga))
    raga_to_idx = {r: i for i, r in enumerate(unique_ragas)}
    label_raga_idx = [raga_to_idx[r] for r in labels_raga]
    label_quality_idx = np.array([q - 1 for q in labels_quality], dtype=np.int64)

    data_dict = {
        "images_mfcc": images_mfcc,
        "images_log": images_log,
        "scalar_feats": scalar_feats,
        "quantum_feats": quantum_feats,
        "label_raga_idx": np.array(label_raga_idx, dtype=np.int64),
        "label_quality": label_quality_idx,
        "raga_to_idx": raga_to_idx,
        "unique_ragas": unique_ragas
    }

    return data_dict, used_ids


# Datasets / Model classes (unchanged)

class MultiLabelASTDataset(Dataset):
    """Dataset over the arrays produced for training.

    Each item is the 6-tuple
    ``(mfcc_image, log_image, scalar_feats, quantum_feats, raga_label, quality_label)``;
    the two feature vectors are float32 tensors and the two labels are long tensors.
    """

    def __init__(self, data_dict):
        # Keep references to the per-sample collections; nothing is copied here.
        self.mfcc_imgs = data_dict["images_mfcc"]
        self.log_imgs  = data_dict["images_log"]
        self.scalars   = data_dict["scalar_feats"]
        self.quants    = data_dict["quantum_feats"]
        self.raga_lbl  = data_dict["label_raga_idx"]
        self.qual_lbl  = data_dict["label_quality"]

    def __len__(self):
        # The raga label array defines the number of samples.
        return len(self.raga_lbl)

    def __getitem__(self, idx):
        def as_float(values):
            return torch.tensor(values[idx], dtype=torch.float32)

        def as_label(values):
            return torch.tensor(values[idx], dtype=torch.long)

        sample = (
            self.mfcc_imgs[idx],
            self.log_imgs[idx],
            as_float(self.scalars),
            as_float(self.quants),
            as_label(self.raga_lbl),
            as_label(self.qual_lbl),
        )
        return sample

class HybridASTModelV2(nn.Module):
    """AST backbone fused with scalar and quantum feature MLPs.

    The pretrained AST encoder embeds the log-spectrogram image; two small
    MLPs embed the scalar and quantum feature vectors (64 dims each). The
    three embeddings are concatenated and fed to two classification heads,
    one for the raga and one for the quality rating.

    Args:
        num_ragas: number of raga classes.
        scalar_dim: length of the scalar feature vector.
        quantum_dim: length of the quantum feature vector.
        num_quality: number of quality classes.
        num_unfrozen_layers: how many trailing AST encoder layers are left
            trainable; every other AST parameter is frozen.
    """

    _AST_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"

    def __init__(self, num_ragas, scalar_dim=10, quantum_dim=42,
                 num_quality=5, num_unfrozen_layers=0):
        super().__init__()
        self.config = ASTConfig.from_pretrained(self._AST_CHECKPOINT)
        self.ast_model = ASTModel.from_pretrained(self._AST_CHECKPOINT, config=self.config)

        # Freeze the whole backbone, then re-enable only the last N encoder layers.
        for param in self.ast_model.parameters():
            param.requires_grad = False
        if num_unfrozen_layers > 0:
            encoder_layers = self.ast_model.encoder.layer
            # Read the depth from the loaded model instead of assuming 12.
            total_layers = len(encoder_layers)
            start_layer = max(0, total_layers - num_unfrozen_layers)
            for layer_idx in range(start_layer, total_layers):
                for param in encoder_layers[layer_idx].parameters():
                    param.requires_grad = True

        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.quantum_fc = nn.Sequential(
            nn.Linear(quantum_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # AST embedding width (768 for this checkpoint) plus the two 64-dim MLP outputs.
        combined_dim = self.config.hidden_size + 64 + 64
        self.raga_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_ragas)
        )
        self.quality_head = nn.Sequential(
            nn.Linear(combined_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_quality)
        )

    def forward(self, mfcc_img, log_img, scal_in, quan_in):
        """Return ``(logits_raga, logits_quality)`` for a batch.

        ``mfcc_img`` is accepted for interface compatibility but is not used.
        ``log_img`` must be a 4-D tensor; the channel axis is squeezed away
        after resizing, so a single channel is expected.

        Raises:
            ValueError: if ``log_img`` is not 4-D.
        """
        if log_img.dim() != 4:
            raise ValueError(
                f"log_img must be 4-D (batch, channel, height, width); got {tuple(log_img.shape)}"
            )

        target_size = (HYPERPARAMS["AST_FREQ"], HYPERPARAMS["AST_TIME"])
        log_resized = F.interpolate(log_img, size=target_size, mode='bilinear', align_corners=False)
        log_for_ast = log_resized.squeeze(1)  # (B, AST_FREQ, AST_TIME)
        # NOTE(review): the Hugging Face docs describe ASTModel.input_values as
        # (batch, max_length, num_mel_bins), i.e. time before frequency, while
        # this passes (B, AST_FREQ, AST_TIME). Left unchanged so existing
        # checkpoints behave the same; confirm the intended orientation.

        # Only last_hidden_state is consumed, so intermediate hidden states
        # are not requested.
        outputs = self.ast_model(input_values=log_for_ast)
        ast_embedding = outputs.last_hidden_state[:, 0, :]  # first token of the sequence

        emb_scal = self.scalar_fc(scal_in)
        emb_quan = self.quantum_fc(quan_in)
        fused = torch.cat([ast_embedding, emb_scal, emb_quan], dim=1)

        logits_raga = self.raga_head(fused)
        logits_quality = self.quality_head(fused)
        return logits_raga, logits_quality



# Train Model, Return used_ids in checkpoint

def train_model_v2():
    """Train ``HybridASTModelV2`` and save the best checkpoint.

    Steps: fetch the data, split it 80/20 at random, train with Adam on the
    sum of the raga and quality cross-entropy losses, keep the weights of the
    epoch with the lowest test loss, plot both confusion matrices and the loss
    curves, then write the checkpoint (weights, raga mapping, feature
    dimensions and the used DB ids) to ``data/modeloutput/trained_model.pt``.

    Returns:
        ``(model, raga_to_idx)``, or ``(None, None)`` when no data is available.
    """
    data_dict, used_ids = fetch_training_data_v2(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("No data to train.")
        return None, None

    device = HYPERPARAMS["DEVICE"]

    def batch_to_device(batch):
        # Move every tensor of a dataset batch to the training device.
        return [tensor.to(device) for tensor in batch]

    scalar_feats = data_dict["scalar_feats"]
    quantum_feats = data_dict["quantum_feats"]
    raga_to_idx = data_dict["raga_to_idx"]
    num_ragas = len(raga_to_idx)

    dataset_full = MultiLabelASTDataset(data_dict)
    n_samples = len(dataset_full)
    test_size = int(0.2 * n_samples)
    train_size = n_samples - test_size
    train_ds, test_ds = torch.utils.data.random_split(dataset_full, [train_size, test_size])

    train_dl = DataLoader(train_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=True)
    test_dl  = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    model = HybridASTModelV2(
        num_ragas=num_ragas,
        scalar_dim=scalar_feats.shape[1],
        quantum_dim=quantum_feats.shape[1],
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    ).to(device)

    # NOTE(review): every trainable parameter, including the unfrozen AST
    # layers, is optimized at LEARNING_RATE_HEAD; LEARNING_RATE_MAIN is not
    # used here.
    optimizer = optim.Adam(
        model.parameters(),
        lr=HYPERPARAMS["LEARNING_RATE_HEAD"],
        weight_decay=HYPERPARAMS["WEIGHT_DECAY"]
    )
    crit_ce = nn.CrossEntropyLoss()

    EPOCHS = HYPERPARAMS["EPOCHS"]
    train_losses = []
    test_losses = []
    best_test_loss = float("inf")
    best_model_state = None
    patience = HYPERPARAMS["PATIENCE"]
    epochs_no_improve = 0
    stop_limit = HYPERPARAMS["STOP_LIMIT"]

    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0.0

        for batch in train_dl:
            mfcc_img, log_img, scal, quan, rag_lbl, qual_lbl = batch_to_device(batch)

            optimizer.zero_grad()
            logits_raga, logits_quality = model(mfcc_img, log_img, scal, quan)
            loss = crit_ce(logits_raga, rag_lbl) + crit_ce(logits_quality, qual_lbl)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dl)
        train_losses.append(avg_train_loss)

        # Evaluate
        model.eval()
        total_test_loss = 0.0

        with torch.no_grad():
            for batch in test_dl:
                mfcc_img, log_img, scal, quan, rag_lbl, qual_lbl = batch_to_device(batch)

                logits_raga, logits_quality = model(mfcc_img, log_img, scal, quan)
                loss_raga = crit_ce(logits_raga, rag_lbl)
                loss_qual = crit_ce(logits_quality, qual_lbl)
                total_test_loss += (loss_raga + loss_qual).item()

        avg_test_loss = total_test_loss / len(test_dl)
        test_losses.append(avg_test_loss)

        print(f"Epoch {epoch+1}/{EPOCHS}: train_loss={avg_train_loss:.4f}, test_loss={avg_test_loss:.4f}")

        # Early stopping logic
        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; clone them so the snapshot
            # really holds this epoch's weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if (epoch + 1) >= stop_limit and epochs_no_improve >= patience:
                print(f"No improvement for {patience} epochs after min epoch {stop_limit}. Stopping at epoch {epoch+1}.")
                break

    print("Training complete or early stopped.")

    if best_model_state is not None:
        print("Loading best model state (lowest test loss).")
        model.load_state_dict(best_model_state)

    # Evaluate final => confusion matrix
    all_raga_preds, all_raga_truth = [], []
    all_qual_preds, all_qual_truth = [], []

    model.eval()
    with torch.no_grad():
        for batch in test_dl:
            mfcc_img, log_img, scal, quan, rag_lbl, qual_lbl = batch_to_device(batch)

            logits_raga, logits_quality = model(mfcc_img, log_img, scal, quan)
            pred_raga = logits_raga.argmax(dim=1).cpu().numpy()
            pred_qual = logits_quality.argmax(dim=1).cpu().numpy()

            all_raga_preds.extend(pred_raga)
            all_raga_truth.extend(rag_lbl.cpu().numpy())
            all_qual_preds.extend(pred_qual)
            all_qual_truth.extend(qual_lbl.cpu().numpy())

    # Label the matrix axes with only the classes that actually occur.
    cm_raga = confusion_matrix(all_raga_truth, all_raga_preds)
    used_raga_indices = sorted(set(all_raga_truth) | set(all_raga_preds))
    idx_to_raga = {v: k for k, v in raga_to_idx.items()}
    used_raga_labels = [idx_to_raga[i] for i in used_raga_indices]

    disp_raga = ConfusionMatrixDisplay(cm_raga, display_labels=used_raga_labels)
    disp_raga.plot(cmap=plt.cm.Blues)
    plt.title("Raga Confusion Matrix (Best Model)")
    plt.show()

    cm_qual = confusion_matrix(all_qual_truth, all_qual_preds)
    used_qual_indices = sorted(set(all_qual_truth) | set(all_qual_preds))
    used_qual_labels  = [q+1 for q in used_qual_indices]  # back to 1-based ratings
    disp_qual = ConfusionMatrixDisplay(cm_qual, display_labels=used_qual_labels)
    disp_qual.plot(cmap=plt.cm.Blues)
    plt.title("Quality Confusion Matrix (Best Model)")
    plt.show()

    plt.figure()
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Curves")
    plt.legend()
    plt.show()

    # Save best model + used_ids
    os.makedirs("data/modeloutput", exist_ok=True)
    checkpoint_path = os.path.join("data", "modeloutput", "trained_model.pt")
    checkpoint_dict = {
        "model_state": model.state_dict(),
        "raga_to_idx": raga_to_idx,
        "scalar_dim": scalar_feats.shape[1],
        "quantum_dim": quantum_feats.shape[1],
        "num_ragas": num_ragas,
        "used_ids": used_ids  # lets later evaluation exclude these rows
    }
    torch.save(checkpoint_dict, checkpoint_path)
    print(f"Best model + metadata + used_ids saved to: {checkpoint_path}")

    return model, raga_to_idx




###############################################################################
# 1) NO-QUANTUM MODEL CLASS
###############################################################################
class HybridASTModelV2NoQuantum(nn.Module):
    """AST backbone fused with the scalar feature MLP only (no quantum branch).

    The pretrained AST encoder embeds the log-spectrogram image and a small
    MLP embeds the scalar feature vector (64 dims). The two embeddings are
    concatenated and fed to a raga head and a quality head.

    Args:
        num_ragas: number of raga classes.
        scalar_dim: length of the scalar feature vector.
        num_quality: number of quality classes.
        num_unfrozen_layers: how many trailing AST encoder layers are left
            trainable; every other AST parameter is frozen.
    """

    _AST_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"

    def __init__(self, num_ragas, scalar_dim=10, num_quality=5, num_unfrozen_layers=0):
        super().__init__()
        # Load AST config + model
        self.config = ASTConfig.from_pretrained(self._AST_CHECKPOINT)
        self.ast_model = ASTModel.from_pretrained(self._AST_CHECKPOINT, config=self.config)

        # Freeze the whole backbone, then re-enable only the last N encoder layers.
        for param in self.ast_model.parameters():
            param.requires_grad = False
        if num_unfrozen_layers > 0:
            encoder_layers = self.ast_model.encoder.layer
            # Read the depth from the loaded model instead of assuming 12.
            total_layers = len(encoder_layers)
            start_layer = max(0, total_layers - num_unfrozen_layers)
            for layer_idx in range(start_layer, total_layers):
                for param in encoder_layers[layer_idx].parameters():
                    param.requires_grad = True

        # Only scalar MLP; no quantum MLP
        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # AST embedding width (768 for this checkpoint) plus the 64-dim scalar embedding.
        combined_dim = self.config.hidden_size + 64

        self.raga_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_ragas)
        )
        self.quality_head = nn.Sequential(
            nn.Linear(combined_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_quality)
        )

    def forward(self, mfcc_img, log_img, scalar_input):
        """Return ``(logits_raga, logits_quality)`` for a batch.

        ``mfcc_img`` is accepted for interface compatibility but is not used.
        ``log_img`` must be a 4-D tensor; the channel axis is squeezed away
        after resizing, so a single channel is expected.

        Raises:
            ValueError: if ``log_img`` is not 4-D.
        """
        if log_img.dim() != 4:
            raise ValueError(
                f"log_img must be 4-D (batch, channel, height, width); got {tuple(log_img.shape)}"
            )

        # Resize the image to (AST_FREQ, AST_TIME), then drop the channel axis.
        target_size = (HYPERPARAMS["AST_FREQ"], HYPERPARAMS["AST_TIME"])
        log_resized = F.interpolate(
            log_img, size=target_size,
            mode='bilinear', align_corners=False
        )
        log_for_ast = log_resized.squeeze(1)  # (B, AST_FREQ, AST_TIME)
        # NOTE(review): the Hugging Face docs describe ASTModel.input_values as
        # (batch, max_length, num_mel_bins), i.e. time before frequency, while
        # this passes (B, AST_FREQ, AST_TIME). Left unchanged so existing
        # checkpoints behave the same; confirm the intended orientation.

        # Only last_hidden_state is consumed, so intermediate hidden states
        # are not requested.
        outputs = self.ast_model(input_values=log_for_ast)
        ast_embedding = outputs.last_hidden_state[:, 0, :]  # [CLS] token embedding

        emb_scal = self.scalar_fc(scalar_input)  # (B,64)
        fused = torch.cat([ast_embedding, emb_scal], dim=1)

        logits_raga = self.raga_head(fused)
        logits_quality = self.quality_head(fused)
        return logits_raga, logits_quality


###############################################################################
# 2) TRAINING FUNCTION FOR NO-QUANTUM MODEL
###############################################################################
def train_model_v2_noquantum():
    """Train ``HybridASTModelV2NoQuantum`` and save the best checkpoint.

    Same procedure as the quantum variant, except that the quantum feature
    vector in each batch is ignored: fetch the data, split it 80/20 at random,
    train with Adam on the sum of the raga and quality cross-entropy losses,
    keep the weights of the epoch with the lowest test loss, plot both
    confusion matrices and the loss curves, then write the checkpoint to
    ``data/modeloutput/trained_model_noquantum.pt``.

    Returns:
        ``(model, raga_to_idx)``, or ``(None, None)`` when no data is available.
    """
    # Fetch data & used_ids
    data_dict, used_ids = fetch_training_data_v2(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("[NoQuantum] No data found in DB.")
        return None, None

    device = HYPERPARAMS["DEVICE"]

    def batch_to_device(batch):
        # Move the tensors this model needs to the training device; the
        # quantum features (position 3 of the batch) are deliberately skipped.
        mfcc_img, log_img, scal, _quan, rag_lbl, qual_lbl = batch
        return (
            mfcc_img.to(device),
            log_img.to(device),
            scal.to(device),
            rag_lbl.to(device),
            qual_lbl.to(device),
        )

    scalar_feats = data_dict["scalar_feats"]
    raga_to_idx = data_dict["raga_to_idx"]
    num_ragas = len(raga_to_idx)

    dataset_full = MultiLabelASTDataset(data_dict)
    n_samples = len(dataset_full)
    test_size = int(0.2 * n_samples)
    train_size = n_samples - test_size
    train_ds, test_ds = torch.utils.data.random_split(dataset_full, [train_size, test_size])

    train_dl = DataLoader(train_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=True)
    test_dl  = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    # Build the no-quantum model
    model_nq = HybridASTModelV2NoQuantum(
        num_ragas=num_ragas,
        scalar_dim=scalar_feats.shape[1],
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    ).to(device)

    # Adam with L2 regularization via weight_decay.
    # NOTE(review): every trainable parameter, including the unfrozen AST
    # layers, is optimized at LEARNING_RATE_HEAD; LEARNING_RATE_MAIN is not
    # used here.
    optimizer = optim.Adam(
        model_nq.parameters(),
        lr=HYPERPARAMS["LEARNING_RATE_HEAD"],
        weight_decay=HYPERPARAMS["WEIGHT_DECAY"]
    )
    crit_ce = nn.CrossEntropyLoss()

    EPOCHS = HYPERPARAMS["EPOCHS"]
    train_losses = []
    test_losses = []
    best_test_loss = float("inf")
    best_model_state = None
    patience = HYPERPARAMS["PATIENCE"]
    epochs_no_improve = 0
    stop_limit = HYPERPARAMS["STOP_LIMIT"]

    for epoch in range(EPOCHS):
        model_nq.train()
        total_train_loss = 0.0

        for batch in train_dl:
            mfcc_img, log_img, scal, rag_lbl, qual_lbl = batch_to_device(batch)

            optimizer.zero_grad()
            logits_raga, logits_quality = model_nq(mfcc_img, log_img, scal)
            loss = crit_ce(logits_raga, rag_lbl) + crit_ce(logits_quality, qual_lbl)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dl)
        train_losses.append(avg_train_loss)

        # Evaluate
        model_nq.eval()
        total_test_loss = 0.0

        with torch.no_grad():
            for batch in test_dl:
                mfcc_img, log_img, scal, rag_lbl, qual_lbl = batch_to_device(batch)

                logits_raga, logits_quality = model_nq(mfcc_img, log_img, scal)
                loss_raga = crit_ce(logits_raga, rag_lbl)
                loss_qual = crit_ce(logits_quality, qual_lbl)
                total_test_loss += (loss_raga + loss_qual).item()

        avg_test_loss = total_test_loss / len(test_dl)
        test_losses.append(avg_test_loss)

        print(f"[NoQuantum] Epoch {epoch+1}/{EPOCHS}: train_loss={avg_train_loss:.4f}, test_loss={avg_test_loss:.4f}")

        # Early Stopping
        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; clone them so the snapshot
            # really holds this epoch's weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model_nq.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if (epoch + 1) >= stop_limit and epochs_no_improve >= patience:
                print(f"[NoQuantum] No improvement for {patience} epochs after min epoch {stop_limit}. Stopping.")
                break

    print("[NoQuantum] Training complete or early stopped.")

    # Load best model
    if best_model_state is not None:
        model_nq.load_state_dict(best_model_state)

    # Final confusion matrices for no-quantum approach
    all_raga_preds, all_raga_truth = [], []
    all_qual_preds, all_qual_truth = [], []

    model_nq.eval()
    with torch.no_grad():
        for batch in test_dl:
            mfcc_img, log_img, scal, rag_lbl, qual_lbl = batch_to_device(batch)

            logits_raga, logits_quality = model_nq(mfcc_img, log_img, scal)
            pred_raga = logits_raga.argmax(dim=1).cpu().numpy()
            pred_qual = logits_quality.argmax(dim=1).cpu().numpy()

            all_raga_preds.extend(pred_raga)
            all_raga_truth.extend(rag_lbl.cpu().numpy())
            all_qual_preds.extend(pred_qual)
            all_qual_truth.extend(qual_lbl.cpu().numpy())

    # Raga Confusion Matrix (axes labelled with only the classes that occur)
    cm_raga = confusion_matrix(all_raga_truth, all_raga_preds)
    used_raga_indices = sorted(set(all_raga_truth) | set(all_raga_preds))

    idx_to_raga = {v: k for k, v in raga_to_idx.items()}
    used_raga_labels = [idx_to_raga[i] for i in used_raga_indices]

    disp_raga = ConfusionMatrixDisplay(cm_raga, display_labels=used_raga_labels)
    disp_raga.plot(cmap=plt.cm.Blues)
    plt.title("Raga Confusion Matrix (No-Quantum)")
    plt.show()

    # Quality Confusion Matrix
    cm_qual = confusion_matrix(all_qual_truth, all_qual_preds)
    used_qual_indices = sorted(set(all_qual_truth) | set(all_qual_preds))
    used_qual_labels = [q+1 for q in used_qual_indices]  # back to 1-based ratings

    disp_qual = ConfusionMatrixDisplay(cm_qual, display_labels=used_qual_labels)
    disp_qual.plot(cmap=plt.cm.Blues)
    plt.title("Quality Confusion Matrix (No-Quantum)")
    plt.show()

    # Plot training curves
    plt.figure()
    plt.plot(train_losses, label="Train Loss (NoQuantum)")
    plt.plot(test_losses, label="Test Loss (NoQuantum)")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("No-Quantum Model Training Curves")
    plt.legend()
    plt.show()

    # Save best model + used_ids
    os.makedirs("data/modeloutput", exist_ok=True)
    checkpoint_path = os.path.join("data", "modeloutput", "trained_model_noquantum.pt")
    checkpoint_dict = {
        "model_state": model_nq.state_dict(),
        "raga_to_idx": raga_to_idx,
        "scalar_dim": scalar_feats.shape[1],
        "num_ragas": num_ragas,
        "used_ids": used_ids
    }
    torch.save(checkpoint_dict, checkpoint_path)
    print(f"[NoQuantum] Best model + metadata (no quantum) saved to: {checkpoint_path}")

    return model_nq, raga_to_idx

###############################################################################
# 3) COMPARISON FUNCTION: QUANTUM VS. NO-QUANTUM
###############################################################################
def compare_classical_vs_quantum_models(seed=None):
    """
    Evaluate the persisted quantum and no-quantum models on one shared split.

    Steps:
      1) Re-fetch the dataset (same DB_LIMIT as the other entry points).
      2) Draw a 20% evaluation subset with random_split.
      3) Load 'trained_model.pt' (quantum) and 'trained_model_noquantum.pt'
         (no quantum) from data/modeloutput and rebuild both models.
      4) Run both models on the same batches and print raga / quality
         accuracy side by side, plus the gap in percentage points.

    NOTE(review): the 20% subset is drawn here, at comparison time. Nothing in
    this function ties it to the split used while training, so records that a
    model was trained on may be part of the evaluation subset -- verify
    against the training code before quoting these numbers as test accuracy.

    Args:
        seed: Optional int. When given, the split is drawn from a dedicated
            torch.Generator seeded with this value, so repeated calls evaluate
            the same subset. When None (default) the global RNG is used, which
            is the historical behaviour.

    Returns:
        None. Results are printed.
    """
    from sklearn.metrics import accuracy_score

    device = HYPERPARAMS["DEVICE"]

    data_dict, _ = fetch_training_data_v2(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("[Compare] No data for comparison.")
        return

    dataset_full = MultiLabelASTDataset(data_dict)
    n_samples = len(dataset_full)
    test_size = int(0.2 * n_samples)
    if test_size == 0:
        # Fewer than 5 samples: an empty split would leave accuracy undefined.
        print("[Compare] Not enough samples to build a test split.")
        return
    train_size = n_samples - test_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    _, test_ds = torch.utils.data.random_split(
        dataset_full, [train_size, test_size], **split_kwargs
    )
    test_dl = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    # torch.load unpickles the file, so only point these at trusted checkpoints.
    #--- Load quantum model checkpoint
    q_ckpt_path = "data/modeloutput/trained_model.pt"
    q_ckpt = torch.load(q_ckpt_path, map_location=device)

    model_q = HybridASTModelV2(
        num_ragas=q_ckpt["num_ragas"],
        scalar_dim=q_ckpt["scalar_dim"],
        quantum_dim=q_ckpt["quantum_dim"],
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    model_q.load_state_dict(q_ckpt["model_state"])
    model_q.to(device)
    model_q.eval()

    #--- Load no-quantum model checkpoint
    nq_ckpt_path = "data/modeloutput/trained_model_noquantum.pt"
    nq_ckpt = torch.load(nq_ckpt_path, map_location=device)

    model_nq = HybridASTModelV2NoQuantum(
        num_ragas=nq_ckpt["num_ragas"],
        scalar_dim=nq_ckpt["scalar_dim"],
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    model_nq.load_state_dict(nq_ckpt["model_state"])
    model_nq.to(device)
    model_nq.eval()

    # Raga labels in the evaluation subset are indices under the freshly
    # fetched mapping; each model predicts indices under the mapping stored in
    # its checkpoint. If those differ, raga accuracy is not meaningful.
    data_raga_to_idx = data_dict.get("raga_to_idx")
    if data_raga_to_idx is not None:
        for tag, ckpt in (("quantum", q_ckpt), ("no-quantum", nq_ckpt)):
            if ckpt["raga_to_idx"] != data_raga_to_idx:
                print(f"[Compare] WARNING: raga_to_idx of the {tag} checkpoint "
                      "differs from the freshly fetched data; raga accuracy "
                      "for that model is unreliable.")

    #--- Evaluate both on test
    all_true_ragas = []
    all_true_quals = []

    preds_raga_q, preds_qual_q = [], []
    preds_raga_nq, preds_qual_nq = [], []

    with torch.no_grad():
        for mfcc_img, log_img, scal, quan, rag_lbl, qual_lbl in test_dl:
            mfcc_img = mfcc_img.to(device)
            log_img  = log_img.to(device)
            scal     = scal.to(device)
            quan     = quan.to(device)

            # quantum model
            logits_raga_q, logits_quality_q = model_q(mfcc_img, log_img, scal, quan)
            preds_raga_q.extend(logits_raga_q.argmax(dim=1).cpu().tolist())
            preds_qual_q.extend(logits_quality_q.argmax(dim=1).cpu().tolist())

            # no-quantum model (same inputs minus the quantum features)
            logits_raga_nq, logits_quality_nq = model_nq(mfcc_img, log_img, scal)
            preds_raga_nq.extend(logits_raga_nq.argmax(dim=1).cpu().tolist())
            preds_qual_nq.extend(logits_quality_nq.argmax(dim=1).cpu().tolist())

            # Labels are only compared on the host, so they never need to
            # visit the accelerator.
            all_true_ragas.extend(rag_lbl.cpu().tolist())
            all_true_quals.extend(qual_lbl.cpu().tolist())

    rag_acc_q  = accuracy_score(all_true_ragas, preds_raga_q)
    rag_acc_nq = accuracy_score(all_true_ragas, preds_raga_nq)
    qual_acc_q  = accuracy_score(all_true_quals, preds_qual_q)
    qual_acc_nq = accuracy_score(all_true_quals, preds_qual_nq)

    print("\n=== Comparison: Quantum Model vs. No-Quantum Model ===")
    print(f"Raga Accuracy (Quantum):     {rag_acc_q:.4f}")
    print(f"Raga Accuracy (No-Quantum):  {rag_acc_nq:.4f}")
    print(f"Quality Accuracy (Quantum):  {qual_acc_q:.4f}")
    print(f"Quality Accuracy (No-Quant): {qual_acc_nq:.4f}")

    # Positive values mean the quantum model scored higher.
    d_raga = (rag_acc_q - rag_acc_nq)*100
    d_qual = (qual_acc_q - qual_acc_nq)*100
    print(f"\nDifference in Raga Accuracy = {d_raga:.2f} percentage points")
    print(f"Difference in Quality Acc.  = {d_qual:.2f} percentage points\n")







#  SINGLE-RECORD INFERENCE (unchanged)

def run_inference_on_persisted_model(record_id, model_path="data/modeloutput/trained_model.pt"):
    """
    Single-record inference for quick demos, etc.

    Loads the quantum-model checkpoint at `model_path`, fetches one record
    from the DB, rebuilds the image / scalar / quantum inputs from its
    analysis data, and prints the true vs. predicted raga and quality.

    Args:
        record_id: ID passed to QuantumMusicDBFetchOnly.fetch_single_record.
        model_path: Checkpoint file holding "model_state", "raga_to_idx",
            "scalar_dim", "quantum_dim" and "num_ragas".

    Returns:
        (predicted_raga_name, predicted_quality), where quality is the
        predicted class index + 1; (None, None) when the record is not found.
    """
    device = HYPERPARAMS["DEVICE"]

    # torch.load unpickles the file, so only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    raga_to_idx = checkpoint["raga_to_idx"]

    model = HybridASTModelV2(
        num_ragas=checkpoint["num_ragas"],
        scalar_dim=checkpoint["scalar_dim"],
        quantum_dim=checkpoint["quantum_dim"],
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    model.load_state_dict(checkpoint["model_state"])
    model.to(device)
    model.eval()

    db = QuantumMusicDBFetchOnly()
    try:
        row = db.fetch_single_record(record_id)
    finally:
        # Release the connection even when the query raises.
        db.close()

    if not row:
        print(f"No record found with ID={record_id}")
        return None, None

    _, analysis_data = row
    fname = analysis_data["file_name"]
    raga_true, quality_true = parse_raga_and_quality(fname)
    base_no_ext = fname.replace(".wav", "")

    # --- image inputs ---
    mfcc_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mfcc.png")
    log_path  = os.path.join("data", "analysisoutput", f"{base_no_ext}_log_spectrogram.png")

    # A None result from the loader (presumably a missing/unreadable image --
    # confirm) is replaced by a 1x1 zero placeholder.
    mfcc_img = load_image_as_tensor(mfcc_path)
    if mfcc_img is None:
        mfcc_img = torch.zeros((1,1))

    log_img = load_image_as_tensor(log_path)
    if log_img is None:
        log_img = torch.zeros((1,1))

    # --- scalar inputs (missing entries default to 0.0) ---
    res = analysis_data.get("results", {})
    dyn = analysis_data.get("dynamics_summary", {})
    quantum_dict = analysis_data.get("quantum_analysis", {})
    adv = quantum_dict.get("advanced_stats", {})

    # NOTE(review): order and scaling presumably have to match what the model
    # saw during training; no scaler is stored in the checkpoint fields read
    # above, so values are fed raw -- confirm against the training pipeline.
    scalars = [
        res.get("average_dev_cents", 0.0),
        res.get("std_dev_cents", 0.0),
        res.get("avg_praat_hnr", 0.0),
        res.get("avg_tnr", 0.0),
        dyn.get("rms_db", {}).get("mean", 0.0),
        dyn.get("lufs", {}).get("mean", 0.0),
        adv.get("avg_jitter", 0.0),
        adv.get("avg_shimmer", 0.0),
        adv.get("avg_vibrato_rate", 0.0),
        adv.get("avg_F1", 0.0),
    ]

    # --- quantum inputs: first 10 scaled angles (zero-padded) followed by
    # the feature vector derived from the measurement counts ---
    angles = quantum_dict.get("scaled_angles", [])
    max_len = 10
    angle_arr = np.zeros(max_len, dtype=np.float32)
    n_angles = min(max_len, len(angles))
    angle_arr[:n_angles] = angles[:n_angles]

    counts_d = quantum_dict.get("measurement_counts", {})
    dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=10)
    combined_q = np.concatenate([angle_arr, dist_vec], axis=0)

    # Add a leading batch dimension of 1 to every input.
    mfcc_img = mfcc_img.unsqueeze(0).to(device)
    log_img  = log_img.unsqueeze(0).to(device)
    scal_ten = torch.tensor(scalars, dtype=torch.float32).unsqueeze(0).to(device)
    quan_ten = torch.tensor(combined_q, dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        logits_raga, logits_quality = model(mfcc_img, log_img, scal_ten, quan_ten)
        pred_raga_idx = torch.argmax(logits_raga, dim=1).item()
        pred_qual_idx = torch.argmax(logits_quality, dim=1).item()

    inv_map = {v: k for k, v in raga_to_idx.items()}
    pred_raga_str = inv_map.get(pred_raga_idx, "UnknownRaga")
    pred_quality = pred_qual_idx + 1  # class indices are 0-based, ratings 1-based

    print(f"Inference for Record ID={record_id}:")
    print(f"  True raga={raga_true}, True quality={quality_true}")
    print(f"  Predicted raga={pred_raga_str}, predicted quality={pred_quality}")
    return pred_raga_str, pred_quality



# NEW: run_inference_on_persisted_model_v2

def run_inference_on_persisted_model_v2(model_path="data/modeloutput/trained_model.pt"):
    """
    Run the persisted quantum model over records that were not used in training.

    1) Loads the best model + used_ids from the checkpoint.
    2) Fetches leftover records via fetch_leftover_records, excluding used_ids
       (the rating >= 4 filter mentioned in the messages below is presumably
       applied inside that DB helper -- confirm).
    3) Runs inference record by record and prints
       (ID, file_name, actual raga/quality, predicted raga/quality).

    Args:
        model_path: Checkpoint file holding "model_state", "raga_to_idx",
            "scalar_dim", "quantum_dim", "num_ragas" and, optionally,
            "used_ids" (treated as empty when absent).

    Returns:
        None. Results are printed.
    """
    device = HYPERPARAMS["DEVICE"]

    # 1) Load checkpoint
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    raga_to_idx = checkpoint["raga_to_idx"]
    used_ids    = checkpoint.get("used_ids", [])

    # 2) Rebuild model
    model = HybridASTModelV2(
        num_ragas=checkpoint["num_ragas"],
        scalar_dim=checkpoint["scalar_dim"],
        quantum_dim=checkpoint["quantum_dim"],
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    model.load_state_dict(checkpoint["model_state"])
    model.to(device)
    model.eval()

    # 3) Fetch leftover records
    db = QuantumMusicDBFetchOnly()
    try:
        leftover = db.fetch_leftover_records(exclude_ids=set(used_ids))
    finally:
        # Release the connection even when the query raises.
        db.close()

    if not leftover:
        print("No leftover records found with rating >=4 that weren't used in training.")
        return

    # For mapping predicted raga index -> string
    inv_map = {v: k for k, v in raga_to_idx.items()}

    max_len = 10  # number of scaled angles kept per record (zero-padded)

    # 4) Inference loop
    for (rec_id, analysis_data) in leftover:
        fname = analysis_data["file_name"]
        raga_true, qual_true = parse_raga_and_quality(fname)
        base_no_ext = fname.replace(".wav", "")

        # images; a None result from the loader (presumably a missing or
        # unreadable file -- confirm) is replaced by a 1x1 zero placeholder
        mfcc_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mfcc.png")
        log_path  = os.path.join("data", "analysisoutput", f"{base_no_ext}_log_spectrogram.png")
        mfcc_img = load_image_as_tensor(mfcc_path)
        if mfcc_img is None:
            mfcc_img = torch.zeros((1,1))
        log_img = load_image_as_tensor(log_path)
        if log_img is None:
            log_img = torch.zeros((1,1))

        # scalar (missing entries default to 0.0)
        res = analysis_data.get("results", {})
        dyn = analysis_data.get("dynamics_summary", {})
        quantum_dict = analysis_data.get("quantum_analysis", {})
        adv = quantum_dict.get("advanced_stats", {})

        # NOTE(review): order and scaling presumably have to match what the
        # model saw during training; values are fed raw here -- confirm
        # against the training pipeline.
        scalars = [
            res.get("average_dev_cents", 0.0),
            res.get("std_dev_cents", 0.0),
            res.get("avg_praat_hnr", 0.0),
            res.get("avg_tnr", 0.0),
            dyn.get("rms_db", {}).get("mean", 0.0),
            dyn.get("lufs", {}).get("mean", 0.0),
            adv.get("avg_jitter", 0.0),
            adv.get("avg_shimmer", 0.0),
            adv.get("avg_vibrato_rate", 0.0),
            adv.get("avg_F1", 0.0),
        ]

        # quantum: padded angles followed by the measurement-count features
        angles = quantum_dict.get("scaled_angles", [])
        angle_arr = np.zeros(max_len, dtype=np.float32)
        n_angles = min(max_len, len(angles))
        angle_arr[:n_angles] = angles[:n_angles]

        counts_d = quantum_dict.get("measurement_counts", {})
        dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=10)
        combined_q = np.concatenate([angle_arr, dist_vec], axis=0)

        # 5) Model forward pass (each input gets a leading batch dim of 1)
        mfcc_img = mfcc_img.unsqueeze(0).to(device)
        log_img  = log_img.unsqueeze(0).to(device)
        scal_ten = torch.tensor(scalars, dtype=torch.float32).unsqueeze(0).to(device)
        quan_ten = torch.tensor(combined_q, dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            logits_raga, logits_quality = model(mfcc_img, log_img, scal_ten, quan_ten)
            pred_raga_idx = torch.argmax(logits_raga, dim=1).item()
            pred_qual_idx = torch.argmax(logits_quality, dim=1).item()

        pred_raga_str = inv_map.get(pred_raga_idx, "UnknownRaga")
        pred_quality = pred_qual_idx + 1  # class indices are 0-based, ratings 1-based

        # 6) Print
        print(f"\nLeftover Inference => ID={rec_id}, file_name={fname}")
        print(f"   Actual Raga={raga_true}, Quality={qual_true}")
        print(f"   Predicted Raga={pred_raga_str}, Quality={pred_quality}")

    print("\n--- Finished leftover inference (rating>=4) ---")



# Example usage

if __name__ == "__main__":
    # Notebook entry point: uncomment the stage(s) you want to run.

    # Stage 1 - training. The no-quantum run saves its checkpoint to
    # data/modeloutput/trained_model_noquantum.pt
    #model, raga_to_idx = train_model_v2()
    #model_nq, raga_nq = train_model_v2_noquantum()

    # Stage 1b - quantum vs. no-quantum comparison on the saved checkpoints
    #compare_classical_vs_quantum_models()

    # Stage 2 - single-record inference against a specific checkpoint
    demo_record_id = 2150
    demo_checkpoint = "data/modeloutput/trained_model_rohantst1.pt"
    run_inference_on_persisted_model(demo_record_id, model_path=demo_checkpoint)

    # Stage 3 - leftover records with rating >= 4
    #run_inference_on_persisted_model_v2("data/modeloutput/trained_model.pt")

Using MPS device for GPU acceleration on Apple Silicon.


  checkpoint = torch.load(model_path, map_location=HYPERPARAMS["DEVICE"])


Connected to database quantummusic. (fetch-only)
Database connection closed.
Inference for Record ID=2150:
  True raga=Bhairav, True quality=4
  Predicted raga=Bhairav, predicted quality=4
