## Machine Learning Models

In [None]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram
import librosa
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import random

from transformers import ASTModel, ASTConfig
import torchvision.transforms as T
from PIL import Image

################################################################################
# 1) MPS DEVICE CHECK
################################################################################
# Use the MPS backend when PyTorch reports it as available, otherwise the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(
    "Using MPS device for GPU acceleration on Apple Silicon."
    if DEVICE.type == "mps"
    else "MPS not available. Falling back to CPU."
)

################################################################################
# HYPERPARAMETERS / GLOBAL SETTINGS
################################################################################
HYPERPARAMS = {
    "OUTPUT_DIR": "data/trainingdataoutput",

    # Value handed to the DB fetch as its `limit` argument.
    "DB_LIMIT": 1000,

    # --- Training config ---
    "BATCH_SIZE": 8,
    # Maximum number of training epochs.
    "EPOCHS": 10,
    # Learning rate intended for the (unfrozen) AST layers.
    "LEARNING_RATE_MAIN": 5e-5,
    # Learning rate intended for the scalar / quantum / final layers.
    "LEARNING_RATE_HEAD": 1e-4,
    # L2 regularization factor (optimizer weight decay).
    "WEIGHT_DECAY": 1e-4,
    "DEVICE": DEVICE,
    # Number of trailing AST encoder layers left trainable.
    "NUM_AST_LAYERS_UNFROZEN": 4,

    # --- Early stopping ---
    # Epochs without test-loss improvement before training stops.
    "PATIENCE": 5,
    # Minimum number of epochs to run before early stopping may trigger.
    "STOP_LIMIT": 5,

    # --- AST image sizing (spectrogram images are resized to FREQ x TIME) ---
    "AST_FREQ": 128,
    "AST_TIME": 1024,
}

################################################################################
# Minimal DB to fetch analysis_data
################################################################################
class QuantumMusicDBFetchOnly:
    """
    Minimal read-only accessor for the 'audio_analysis' table, used by the ML code.

    The connection is opened in the constructor. A failed connection attempt is
    reported on stdout and leaves ``self.conn`` as None; the fetch methods then
    raise RuntimeError with an explicit message.
    """
    def __init__(self,
                 db_name="quantummusic",
                 host="localhost",
                 user="postgres",
                 password="postgres"):
        # Imported here rather than at module level; kept on the instance so
        # connect() can use it.
        import psycopg2
        self.psycopg2 = psycopg2
        self.db_name = db_name
        self.host = host
        self.user = user
        self.password = password
        self.conn = None
        self.connect()

    def connect(self):
        """Open the connection; on failure print the error and leave conn as None."""
        try:
            self.conn = self.psycopg2.connect(
                dbname=self.db_name,
                host=self.host,
                user=self.user,
                password=self.password
            )
            print(f"Connected to database {self.db_name}. (fetch-only)")
        except Exception as e:
            # Best-effort by design: the failure surfaces again, with a clear
            # message, the first time a fetch method is called.
            self.conn = None
            print(f"Error connecting to database: {e}")

    def close(self):
        """Close the connection if one is open."""
        if self.conn:
            self.conn.close()
            # Drop the handle so later fetches fail with the explicit
            # "no connection" error instead of a driver-level one.
            self.conn = None
            print("Database connection closed.")

    def _cursor(self):
        """Return a new cursor, or raise RuntimeError when not connected."""
        if self.conn is None:
            raise RuntimeError(
                f"No open connection to database {self.db_name!r}; "
                "connect() failed or close() was already called."
            )
        return self.conn.cursor()

    def fetch_limited_analysis_data(self, limit=42, rows_per_group=42):
        """
        Fetch a random, per-group-capped sample of (id, analysis_data) rows.

        Rows are grouped by the first run of digits in 'file_name'; only groups
        2, 3, 4 and 5 are read, and at most ``rows_per_group`` randomly chosen
        rows are returned from each group.

        Args:
            limit: Accepted for backward compatibility only. It has never been
                applied to the query; the per-group cap is what bounds the result.
            rows_per_group: Maximum number of rows returned per group (default 42,
                the value that used to be hard-coded in the SQL).

        Returns:
            List of (id, analysis_data) tuples.

        Raises:
            RuntimeError: if there is no open database connection.
        """
        # '\\d+' reaches PostgreSQL as the regex \d+ (first run of digits).
        query = """
            WITH cte AS (
                SELECT
                    id,
                    analysis_data,
                    -- Extract the digits as an integer
                    substring(file_name FROM '\\d+')::int AS file_num,
                    -- Assign a random order within each file_num group
                    ROW_NUMBER() OVER (
                        PARTITION BY substring(file_name FROM '\\d+')::int
                        ORDER BY random()
                    ) AS rn
                FROM audio_analysis
                WHERE substring(file_name FROM '\\d+')::int IN (2, 3, 4, 5)
            )
            SELECT id, analysis_data
            FROM cte
            WHERE rn <= %s;
            """
        with self._cursor() as cur:
            # The cap is passed as a bound parameter, never formatted into the SQL.
            cur.execute(query, (rows_per_group,))
            rows = cur.fetchall()
        return rows

    def fetch_single_record(self, record_id):
        """
        Fetch one (id, analysis_data) row by primary key.

        Returns:
            The row tuple, or None when no row has that id.

        Raises:
            RuntimeError: if there is no open database connection.
        """
        with self._cursor() as cur:
            query = "SELECT id, analysis_data FROM audio_analysis WHERE id = %s"
            cur.execute(query, (record_id,))
            row = cur.fetchone()
        return row

################################################################################
# Convert quantum measurement counts -> probability vector
################################################################################
def convert_counts_to_probs_feature(counts_dict, max_bits=10):
    """
    Turn a {bitstring: count} mapping into a float32 probability vector.

    The vector has 2**max_bits entries. Each bitstring is left-padded with '0'
    up to ``max_bits`` characters (or cut down to its last ``max_bits``
    characters), read as a base-2 index, and its share of the total count is
    added to that entry. An all-zero vector is returned when the counts sum
    to zero.
    """
    probs = np.zeros(2**max_bits, dtype=np.float32)
    shots = sum(counts_dict.values())
    if shots == 0:
        return probs

    for bits, count in counts_dict.items():
        # Slicing keeps the trailing max_bits characters of long strings and is
        # a no-op for short ones, which rjust then pads on the left.
        key = bits[-max_bits:].rjust(max_bits, '0')
        probs[int(key, 2)] += count / shots
    return probs

################################################################################
# parse_raga_and_quality
################################################################################
def parse_raga_and_quality(fname: str):
    """
    Split a file name of the form '<letters><digit 1-5><anything>' into labels.

    Every '.wav' occurrence is removed first. Returns (raga, quality) where
    raga is the leading run of ASCII letters and quality is the following
    digit as an int, or (None, None) when the name does not have that shape.
    """
    stem = fname.replace(".wav", "")
    parsed = re.match(r"^([A-Za-z]+)([1-5])(.*)", stem)
    if parsed is None:
        return None, None
    raga_name, quality_digit, _rest = parsed.groups()
    return raga_name, int(quality_digit)

################################################################################
# load_image_as_tensor
################################################################################
# Image -> tensor conversion shared by every image load in this file.
transform_img = T.Compose([T.ToTensor()])


def load_image_as_tensor(image_path: str):
    """
    Load an image file as a single-channel tensor.

    Returns None when ``image_path`` does not exist; otherwise the image is
    converted to grayscale ("L" mode) and passed through ``transform_img``.
    """
    if os.path.exists(image_path):
        # The file handle is released by the `with`; the converted grayscale
        # copy is what gets transformed.
        with Image.open(image_path) as opened:
            grayscale = opened.convert("L")
        return transform_img(grayscale)
    return None

################################################################################
# fetch_training_data_v2
################################################################################
def _scalar_features_from_analysis(analysis_data):
    """Return the 10 scalar features of one analysis record (0.0 where absent)."""
    res = analysis_data.get("results", {})
    dyn = analysis_data.get("dynamics_summary", {})
    adv = analysis_data.get("quantum_analysis", {}).get("advanced_stats", {})
    return [
        res.get("average_dev_cents", 0.0),
        res.get("std_dev_cents", 0.0),
        res.get("avg_praat_hnr", 0.0),
        res.get("avg_tnr", 0.0),
        dyn.get("rms_db", {}).get("mean", 0.0),
        dyn.get("lufs", {}).get("mean", 0.0),
        adv.get("avg_jitter", 0.0),
        adv.get("avg_shimmer", 0.0),
        adv.get("avg_vibrato_rate", 0.0),
        adv.get("avg_F1", 0.0),
    ]


def _quantum_features_from_analysis(analysis_data, max_len=10, max_bits=10):
    """Return [zero-padded scaled angles | measurement probability vector]."""
    quantum_dict = analysis_data.get("quantum_analysis", {})
    angles = quantum_dict.get("scaled_angles", [])
    angle_arr = np.zeros(max_len, dtype=np.float32)
    for i, angle in enumerate(angles[:max_len]):
        angle_arr[i] = angle

    counts_d = quantum_dict.get("measurement_counts", {})
    dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=max_bits)
    return np.concatenate([angle_arr, dist_vec], axis=0)


def _stack_images_with_placeholders(images, fallback_shape):
    """
    Stack image tensors, replacing None entries with all-zero images.

    Placeholders copy the shape and dtype of the first real image so that
    torch.stack sees equally sized tensors; ``fallback_shape`` is used only
    when no real image exists at all.
    """
    reference = next((img for img in images if img is not None), None)
    if reference is not None:
        shape, dtype = tuple(reference.shape), reference.dtype
    else:
        shape, dtype = tuple(fallback_shape), torch.float32
    return torch.stack([
        img if img is not None else torch.zeros(shape, dtype=dtype)
        for img in images
    ])


def fetch_training_data_v2(limit=None):
    """
    Read analysis rows from the DB and assemble the training arrays.

    For every row whose file name yields a raga and quality label, this loads
    the MFCC and log-spectrogram PNGs from data/analysisoutput, extracts the
    scalar and quantum feature vectors, and encodes the labels.

    Args:
        limit: Passed to the DB fetch as ``limit``; defaults to
            HYPERPARAMS["DB_LIMIT"] when None.

    Returns:
        A dict with keys "images_mfcc", "images_log" (stacked tensors),
        "scalar_feats", "quantum_feats" (float32 arrays), "label_raga_idx",
        "label_quality" (int64 arrays; quality is shifted to start at 0),
        "raga_to_idx" and "unique_ragas" — or None when the DB returns no
        rows or no row has a parseable file name.

    Notes:
        A missing image file is replaced by an all-zero image of the same
        shape as the other images of that kind. Real images are assumed to
        share one size; torch.stack raises otherwise.
    """
    db = QuantumMusicDBFetchOnly()
    limit_to_use = limit if limit is not None else HYPERPARAMS["DB_LIMIT"]
    try:
        rows = db.fetch_limited_analysis_data(limit=limit_to_use)
    finally:
        # Release the connection even when the query fails.
        db.close()

    if not rows:
        print("No data found in DB.")
        return None

    images_mfcc = []
    images_log = []
    scalar_feats = []
    quantum_feats = []
    labels_raga = []
    labels_quality = []

    for (record_id, analysis_data) in rows:
        wav_fname = analysis_data["file_name"]
        base_no_ext = wav_fname.replace(".wav", "")

        # Rows without a parseable raga/quality label are skipped entirely.
        raga, quality = parse_raga_and_quality(wav_fname)
        if raga is None or quality is None:
            continue
        labels_raga.append(raga)
        labels_quality.append(quality)

        # None marks a missing file; it becomes a zero image at stacking time.
        mfcc_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mfcc.png")
        log_path  = os.path.join("data", "analysisoutput", f"{base_no_ext}_log_spectrogram.png")
        images_mfcc.append(load_image_as_tensor(mfcc_path))
        images_log.append(load_image_as_tensor(log_path))

        scalar_feats.append(_scalar_features_from_analysis(analysis_data))
        quantum_feats.append(_quantum_features_from_analysis(analysis_data))

    if not labels_raga:
        # Every row was skipped; torch.stack would fail on an empty list.
        print("No usable records found in DB.")
        return None

    # Only used when not a single image of a kind could be loaded.
    fallback_shape = (1, HYPERPARAMS["AST_FREQ"], HYPERPARAMS["AST_TIME"])
    images_mfcc = _stack_images_with_placeholders(images_mfcc, fallback_shape)
    images_log  = _stack_images_with_placeholders(images_log, fallback_shape)
    scalar_feats = np.array(scalar_feats, dtype=np.float32)
    quantum_feats = np.array(quantum_feats, dtype=np.float32)

    # label encodings: ragas are indexed in sorted order
    unique_ragas = sorted(set(labels_raga))
    raga_to_idx = {r: i for i, r in enumerate(unique_ragas)}
    labels_raga_idx = [raga_to_idx[r] for r in labels_raga]

    # quality 1..5 -> class index 0..4
    labels_quality_arr = np.array([q - 1 for q in labels_quality], dtype=np.int64)

    data_dict = {
        "images_mfcc": images_mfcc,
        "images_log": images_log,
        "scalar_feats": scalar_feats,
        "quantum_feats": quantum_feats,
        "label_raga_idx": np.array(labels_raga_idx, dtype=np.int64),
        "label_quality": labels_quality_arr,
        "raga_to_idx": raga_to_idx,
        "unique_ragas": unique_ragas
    }
    return data_dict

################################################################################
# MultiLabelASTDataset
################################################################################
class MultiLabelASTDataset(Dataset):
    """
    Dataset over the dict returned by the training-data fetch.

    Each item is the tuple
    (mfcc image, log-spectrogram image, scalar features, quantum features,
    raga label, quality label); features are float32 tensors and labels are
    long tensors. Images are returned as stored.
    """
    _FIELDS = (
        ("mfcc_imgs", "images_mfcc"),
        ("log_imgs", "images_log"),
        ("scalars", "scalar_feats"),
        ("quants", "quantum_feats"),
        ("raga_lbl", "label_raga_idx"),
        ("qual_lbl", "label_quality"),
    )

    def __init__(self, data_dict):
        # Bind each required data_dict entry to its attribute name.
        for attr_name, dict_key in self._FIELDS:
            setattr(self, attr_name, data_dict[dict_key])

    def __len__(self):
        # The raga label array defines the number of samples.
        return len(self.raga_lbl)

    def __getitem__(self, idx):
        features = [
            torch.tensor(source[idx], dtype=torch.float32)
            for source in (self.scalars, self.quants)
        ]
        targets = [
            torch.tensor(source[idx], dtype=torch.long)
            for source in (self.raga_lbl, self.qual_lbl)
        ]
        return (self.mfcc_imgs[idx], self.log_imgs[idx], *features, *targets)

################################################################################
# HybridASTModelV2
################################################################################
class HybridASTModelV2(nn.Module):
    """
    Two-headed classifier fusing an AST embedding with scalar and quantum features.

    The pretrained "MIT/ast-finetuned-audioset-10-10-0.4593" AST encoder embeds
    the log-spectrogram image; two small MLPs embed the scalar and quantum
    feature vectors (64 dims each). The concatenation feeds a raga head and a
    quality head.

    Args:
        num_ragas: Number of raga classes (width of the raga logits).
        scalar_dim: Width of the scalar feature vector.
        quantum_dim: Width of the quantum feature vector. The default (42) is
            kept for compatibility; the training code in this file always
            passes the real feature width explicitly.
        num_quality: Number of quality classes.
        num_unfrozen_layers: How many of the last AST encoder layers stay
            trainable; all other AST parameters are frozen. Values larger than
            the encoder depth unfreeze every encoder layer.
    """
    def __init__(self, num_ragas, scalar_dim=10, quantum_dim=42,
                 num_quality=5, num_unfrozen_layers=0):
        super().__init__()
        self.config = ASTConfig.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.ast_model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593",
                                                  config=self.config)

        # partial unfreeze: freeze everything, then re-enable the last N layers
        for param in self.ast_model.parameters():
            param.requires_grad = False
        if num_unfrozen_layers > 0:
            encoder_layers = self.ast_model.encoder.layer
            # Depth comes from the loaded model instead of a hard-coded 12.
            total_layers = len(encoder_layers)
            start_layer = max(0, total_layers - num_unfrozen_layers)
            for layer_idx in range(start_layer, total_layers):
                for param in encoder_layers[layer_idx].parameters():
                    param.requires_grad = True

        # Dropout in the feature branches and heads for extra regularization.
        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.quantum_fc = nn.Sequential(
            nn.Linear(quantum_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # AST embedding width comes from the config (768 for this checkpoint).
        combined_dim = self.config.hidden_size + 64 + 64
        self.raga_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_ragas)
        )
        self.quality_head = nn.Sequential(
            nn.Linear(combined_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_quality)
        )

    def forward(self, mfcc_img, log_img, scal_in, quan_in):
        """
        Compute raga and quality logits for one batch.

        Args:
            mfcc_img: Accepted for interface symmetry; not used by this model.
            log_img: Log-spectrogram images of shape (B, 1, H, W).
            scal_in: Scalar features of shape (B, scalar_dim).
            quan_in: Quantum features of shape (B, quantum_dim).

        Returns:
            (logits_raga, logits_quality).

        Raises:
            ValueError: if ``log_img`` is not a 4-D single-channel batch.
        """
        if log_img.dim() != 4 or log_img.shape[1] != 1:
            raise ValueError(
                "log_img must have shape (B, 1, H, W), got "
                f"{tuple(log_img.shape)}"
            )
        freq = HYPERPARAMS["AST_FREQ"]
        time = HYPERPARAMS["AST_TIME"]

        log_resized = F.interpolate(
            log_img,
            size=(freq, time),
            mode='bilinear',
            align_corners=False
        )
        # NOTE(review): the Hugging Face AST documentation describes
        # input_values as (batch, max_length, num_mel_bins), i.e. time before
        # frequency, whereas this passes (B, AST_FREQ, AST_TIME). Left
        # unchanged so existing checkpoints keep producing the same outputs;
        # confirm the intended orientation before retraining.
        log_for_ast = log_resized.squeeze(1)  # (B, freq, time)

        # Only the final hidden state is consumed, so intermediate hidden
        # states are not requested.
        outputs = self.ast_model(input_values=log_for_ast)
        hidden = outputs.last_hidden_state  # (B, seq_len, hidden_size)
        ast_embedding = hidden[:, 0, :]     # first token of the sequence

        emb_scal = self.scalar_fc(scal_in)
        emb_quan = self.quantum_fc(quan_in)
        fused = torch.cat([ast_embedding, emb_scal, emb_quan], dim=1)

        logits_raga = self.raga_head(fused)
        logits_quality = self.quality_head(fused)
        return logits_raga, logits_quality

################################################################################
# train_model_v2 with Minimum EPOCHS + Early Stopping, L2, Confusion Plots
################################################################################
def _batch_to_device(batch, device):
    """Move every tensor of one DataLoader batch to *device*, keeping order."""
    return tuple(item.to(device) for item in batch)


def _multitask_loss(model, batch, criterion):
    """Return the summed raga + quality loss for a batch already on the model's device."""
    mfcc_img, log_img, scal, quan, rag_lbl, qual_lbl = batch
    logits_raga, logits_quality = model(mfcc_img, log_img, scal, quan)
    return criterion(logits_raga, rag_lbl) + criterion(logits_quality, qual_lbl)


def _split_trainable_params(model):
    """Return (ast_params, head_params): trainable parameters split by owner."""
    ast_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith("ast_model."):
            ast_params.append(param)
        else:
            head_params.append(param)
    return ast_params, head_params


def train_model_v2():
    """
    Train HybridASTModelV2 on the DB data, plot diagnostics and save a checkpoint.

    Steps: fetch the data, split it 80/20 at random, train for up to
    HYPERPARAMS["EPOCHS"] epochs with early stopping on the test loss (never
    before HYPERPARAMS["STOP_LIMIT"] epochs, after HYPERPARAMS["PATIENCE"]
    epochs without improvement), restore the weights of the best epoch, show
    confusion matrices and loss curves, and write
    data/modeloutput/trained_model.pt.

    The unfrozen AST layers are optimized with LEARNING_RATE_MAIN and all
    other trainable layers with LEARNING_RATE_HEAD; WEIGHT_DECAY applies to
    both groups.

    Returns:
        (model, raga_to_idx), or (None, None) when there is no data or too
        few samples to form a non-empty test split.
    """
    data_dict = fetch_training_data_v2(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("No data to train.")
        return None, None

    device = HYPERPARAMS["DEVICE"]
    scalar_feats = data_dict["scalar_feats"]
    quantum_feats = data_dict["quantum_feats"]
    raga_to_idx = data_dict["raga_to_idx"]
    num_ragas = len(raga_to_idx)

    dataset_full = MultiLabelASTDataset(data_dict)
    n_samples = len(dataset_full)
    test_size = int(0.2 * n_samples)
    train_size = n_samples - test_size
    if test_size == 0:
        # An empty test loader would make the per-epoch average divide by zero.
        print(f"Not enough samples ({n_samples}) for a train/test split.")
        return None, None
    train_ds, test_ds = torch.utils.data.random_split(dataset_full, [train_size, test_size])

    train_dl = DataLoader(train_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=True)
    test_dl  = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    model = HybridASTModelV2(
        num_ragas=num_ragas,
        scalar_dim=scalar_feats.shape[1],
        quantum_dim=quantum_feats.shape[1],
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    ).to(device)

    # Separate learning rates per HYPERPARAMS; frozen parameters are left out.
    ast_params, head_params = _split_trainable_params(model)
    param_groups = [{"params": head_params, "lr": HYPERPARAMS["LEARNING_RATE_HEAD"]}]
    if ast_params:
        param_groups.append({"params": ast_params, "lr": HYPERPARAMS["LEARNING_RATE_MAIN"]})
    # L2 => weight_decay
    optimizer = optim.Adam(
        param_groups,
        lr=HYPERPARAMS["LEARNING_RATE_HEAD"],
        weight_decay=HYPERPARAMS["WEIGHT_DECAY"]
    )

    crit_ce = nn.CrossEntropyLoss()
    EPOCHS = HYPERPARAMS["EPOCHS"]
    train_losses = []
    test_losses = []

    best_test_loss = float("inf")
    best_model_state = None
    patience = HYPERPARAMS["PATIENCE"]
    epochs_no_improve = 0
    stop_limit = HYPERPARAMS["STOP_LIMIT"]  # must do at least this many epochs

    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0.0

        for batch in train_dl:
            batch = _batch_to_device(batch, device)
            optimizer.zero_grad()
            loss = _multitask_loss(model, batch, crit_ce)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dl)
        train_losses.append(avg_train_loss)

        # Evaluate
        model.eval()
        total_test_loss = 0.0
        with torch.no_grad():
            for batch in test_dl:
                batch = _batch_to_device(batch, device)
                total_test_loss += _multitask_loss(model, batch, crit_ce).item()

        avg_test_loss = total_test_loss / len(test_dl)
        test_losses.append(avg_test_loss)

        print(f"Epoch {epoch+1}/{EPOCHS}: train_loss={avg_train_loss:.4f}, test_loss={avg_test_loss:.4f}")

        # Early Stopping Check
        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite; clone to keep a real snapshot.
            best_model_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            # Only consider stopping if we've run at least stop_limit epochs
            if (epoch + 1) >= stop_limit and epochs_no_improve >= patience:
                print(f"No improvement for {patience} epochs after epoch {stop_limit}, "
                      f"stopping early at epoch {epoch+1}.")
                break

    print("Training complete or early stopped.")

    # If we found a best_model_state, load it back
    if best_model_state is not None:
        print("Loading best model state (lowest test loss)...")
        model.load_state_dict(best_model_state)

    # Final pass for confusion matrices
    all_raga_preds, all_raga_truth = [], []
    all_qual_preds, all_qual_truth = [], []

    model.eval()
    with torch.no_grad():
        for batch in test_dl:
            mfcc_img, log_img, scal, quan, rag_lbl, qual_lbl = _batch_to_device(batch, device)
            logits_raga, logits_quality = model(mfcc_img, log_img, scal, quan)

            all_raga_preds.extend(logits_raga.argmax(dim=1).cpu().tolist())
            all_raga_truth.extend(rag_lbl.cpu().tolist())
            all_qual_preds.extend(logits_quality.argmax(dim=1).cpu().tolist())
            all_qual_truth.extend(qual_lbl.cpu().tolist())

    # confusion_matrix orders classes by the sorted union of labels seen in
    # truth and predictions, so the display labels are built the same way.
    cm_raga = confusion_matrix(all_raga_truth, all_raga_preds)
    used_raga_indices = sorted(set(all_raga_truth) | set(all_raga_preds))
    idx_to_raga = {v: k for k, v in data_dict["raga_to_idx"].items()}
    used_raga_labels = [idx_to_raga[i] for i in used_raga_indices]

    disp_raga = ConfusionMatrixDisplay(cm_raga, display_labels=used_raga_labels)
    disp_raga.plot(cmap=plt.cm.Blues)
    plt.title("Raga Confusion Matrix (Best Model)")
    plt.show()

    cm_qual = confusion_matrix(all_qual_truth, all_qual_preds)
    used_qual_indices = sorted(set(all_qual_truth) | set(all_qual_preds))
    used_qual_labels  = [q+1 for q in used_qual_indices]  # class index -> quality 1..5
    disp_qual = ConfusionMatrixDisplay(cm_qual, display_labels=used_qual_labels)
    disp_qual.plot(cmap=plt.cm.Blues)
    plt.title("Quality Confusion Matrix (Best Model)")
    plt.show()

    # Plot training curves
    plt.figure()
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Curves (Min. EPOCHS + Early Stopping + L2)")
    plt.legend()
    plt.show()

    # Save best model + metadata => data/modeloutput/trained_model.pt
    os.makedirs("data/modeloutput", exist_ok=True)
    checkpoint_path = os.path.join("data", "modeloutput", "trained_model.pt")

    checkpoint_dict = {
        "model_state": model.state_dict(),
        "raga_to_idx": data_dict["raga_to_idx"],
        "scalar_dim": scalar_feats.shape[1],
        "quantum_dim": quantum_feats.shape[1],
        "num_ragas": len(data_dict["raga_to_idx"])
    }
    torch.save(checkpoint_dict, checkpoint_path)
    print(f"Best model + metadata saved to: {checkpoint_path}")

    return model, data_dict["raga_to_idx"]

################################################################################
# run_inference_on_persisted_model
################################################################################
def run_inference_on_persisted_model(record_id, model_path="data/modeloutput/trained_model.pt"):
    """
    Load a persisted checkpoint and predict raga + quality for one DB record.

    The checkpoint at ``model_path`` supplies the model weights, the raga
    index mapping and the feature dimensions. Features for DB record
    ``record_id`` are built the same way as for training, and the predictions
    are printed next to the labels parsed from the record's file name.

    Args:
        record_id: Primary key of the 'audio_analysis' row to classify.
        model_path: Path of the checkpoint written by the training function.

    Returns:
        (predicted raga name, predicted quality in 1..5), or (None, None)
        when no record with that id exists.

    Notes:
        A missing image file is replaced by an all-zero image of shape
        (1, AST_FREQ, AST_TIME) so the model still receives a 4-D batch.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=HYPERPARAMS["DEVICE"])
    model_state = checkpoint["model_state"]
    raga_to_idx = checkpoint["raga_to_idx"]
    scalar_dim = checkpoint["scalar_dim"]
    quantum_dim = checkpoint["quantum_dim"]
    num_ragas = checkpoint["num_ragas"]

    model = HybridASTModelV2(
        num_ragas=num_ragas,
        scalar_dim=scalar_dim,
        quantum_dim=quantum_dim,
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    model.load_state_dict(model_state)
    model.to(HYPERPARAMS["DEVICE"])
    model.eval()

    db = QuantumMusicDBFetchOnly()
    try:
        row = db.fetch_single_record(record_id)
    finally:
        # Release the connection even when the query fails.
        db.close()

    if not row:
        print(f"No record found with ID={record_id}")
        return None, None

    _, analysis_data = row
    wav_fname = analysis_data["file_name"]
    base_no_ext = wav_fname.replace(".wav", "")

    # Ground-truth labels are only printed for comparison; either may be None.
    raga_true, quality_true = parse_raga_and_quality(wav_fname)

    # A (1, 1) placeholder becomes 3-D after unsqueeze(0) and is rejected by
    # the model's forward pass; use a full-size single-channel zero image.
    placeholder_shape = (1, HYPERPARAMS["AST_FREQ"], HYPERPARAMS["AST_TIME"])

    mfcc_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mfcc.png")
    log_path  = os.path.join("data", "analysisoutput", f"{base_no_ext}_log_spectrogram.png")
    mfcc_img = load_image_as_tensor(mfcc_path)
    if mfcc_img is None:
        mfcc_img = torch.zeros(placeholder_shape)

    log_img = load_image_as_tensor(log_path)
    if log_img is None:
        log_img = torch.zeros(placeholder_shape)

    # scalar features (same order as in training)
    res = analysis_data.get("results", {})
    dyn = analysis_data.get("dynamics_summary", {})
    adv = analysis_data.get("quantum_analysis", {}).get("advanced_stats", {})

    avg_dev = res.get("average_dev_cents", 0.0)
    std_dev = res.get("std_dev_cents", 0.0)
    avg_hnr = res.get("avg_praat_hnr", 0.0)
    avg_tnr = res.get("avg_tnr", 0.0)

    rms_db_stats = dyn.get("rms_db", {})
    mean_rms_db  = rms_db_stats.get("mean", 0.0)
    lufs_stats   = dyn.get("lufs", {})
    mean_lufs    = lufs_stats.get("mean", 0.0)

    avg_jitter   = adv.get("avg_jitter", 0.0)
    avg_shimmer  = adv.get("avg_shimmer", 0.0)
    avg_vibrato  = adv.get("avg_vibrato_rate", 0.0)
    avg_formant  = adv.get("avg_F1", 0.0)

    scalars = [
        avg_dev, std_dev, avg_hnr, avg_tnr,
        mean_rms_db, mean_lufs,
        avg_jitter, avg_shimmer, avg_vibrato, avg_formant
    ]

    # quantum features: zero-padded angles followed by the probability vector
    quantum_dict = analysis_data.get("quantum_analysis", {})
    angles = quantum_dict.get("scaled_angles", [])
    max_len = 10
    angle_arr = np.zeros(max_len, dtype=np.float32)
    for i, angle in enumerate(angles[:max_len]):
        angle_arr[i] = angle

    counts_d = quantum_dict.get("measurement_counts", {})
    dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=10)
    combined_q = np.concatenate([angle_arr, dist_vec], axis=0)

    # Add the batch dimension and move everything to the model's device.
    device = HYPERPARAMS["DEVICE"]
    mfcc_img = mfcc_img.unsqueeze(0).to(device)
    log_img  = log_img.unsqueeze(0).to(device)
    scal_ten = torch.tensor(scalars, dtype=torch.float32).unsqueeze(0).to(device)
    quan_ten = torch.tensor(combined_q, dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        logits_raga, logits_quality = model(mfcc_img, log_img, scal_ten, quan_ten)
        pred_raga_idx = torch.argmax(logits_raga, dim=1).item()
        pred_qual_idx = torch.argmax(logits_quality, dim=1).item()

    inv_map = {v: k for k, v in raga_to_idx.items()}
    pred_raga_str = inv_map.get(pred_raga_idx, "UnknownRaga")
    pred_quality = pred_qual_idx + 1  # class index 0..4 -> quality 1..5

    print(f"Inference for Record ID={record_id}:")
    print(f"  True raga={raga_true}, True quality={quality_true}")
    print(f"  Predicted raga={pred_raga_str}, predicted quality={pred_quality}")
    return pred_raga_str, pred_quality


################################################################################
# Example usage
################################################################################
if __name__ == "__main__":
    # Step 1 (disabled): retrain the model; early stopping is governed by
    # HYPERPARAMS["STOP_LIMIT"] and HYPERPARAMS["PATIENCE"].
    # model, raga_to_idx = train_model_v2()

    # Step 2: classify one DB record with a previously persisted checkpoint.
    run_inference_on_persisted_model(
        record_id=361,
        model_path="data/modeloutput/trained_model_42_balanced.pt",
    )