## Machine Learning Models

In [1]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
import matplotlib.pyplot as plt

from transformers import ASTModel, ASTConfig
import torchvision.transforms as T
from PIL import Image

###############################################################################
# MPS DEVICE CHECK
###############################################################################
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("Using MPS device for GPU acceleration on Apple Silicon.")
else:
    DEVICE = torch.device("cpu")
    print("MPS not available. Falling back to CPU.")


###############################################################################
# HYPERPARAMETERS / GLOBAL SETTINGS
###############################################################################
HYPERPARAMS = {
    "OUTPUT_DIR": "data/trainingdataoutput",
    "DB_LIMIT": 1000,               # We'll read up to this many records
    "BATCH_SIZE": 8,
    "EPOCHS": 10,                   # up to 10 or 20 epochs
    "LEARNING_RATE_MAIN": 5e-5,     # for AST layers (if we unfreeze them)
    "LEARNING_RATE_HEAD": 1e-4,     # for scalar/quantum/final layers
    "WEIGHT_DECAY": 1e-4,           # L2 regularization
    "DEVICE": DEVICE,
    "NUM_AST_LAYERS_UNFROZEN": 4,   # partial unfreezing
    "PATIENCE": 5,                  # early stopping patience
    "STOP_LIMIT": 5,                # must train at least this many epochs

    # AST model input shape for audio: (freq=128, time=1024) by default
    "AST_FREQ": 128,
    "AST_TIME": 1024
}


###############################################################################
# DATABASE FETCH LOGIC (FETCH ONLY)
###############################################################################
class QuantumMusicDBFetchOnly:
    """
    A minimal database utility class which fetches data from a Postgres table
    named 'audio_analysis'. This specialized version only supports reading.
    """
    def __init__(self, db_name="quantummusic_csef", host="localhost", user="postgres", password="postgres"):
        import psycopg2
        self.psycopg2 = psycopg2
        self.db_name = db_name
        self.host = host
        self.user = user
        self.password = password
        self.conn = None
        self.connect()

    def connect(self):
        try:
            import psycopg2
            self.conn = psycopg2.connect(
                dbname=self.db_name,
                host=self.host,
                user=self.user,
                password=self.password
            )
            print(f"Connected to database {self.db_name}. (fetch-only)")
        except Exception as e:
            print(f"Error connecting to database: {e}")

    def close(self):
        if self.conn:
            self.conn.close()
            print("Database connection closed.")

    def fetch_limited_analysis_data(self, limit=70):
        """
        Pull a balanced subset:
        - Labels 2, 3, 4: all records for valid ragas
        - Label 5: limit by singer for each raga, then overall limit
        Returns only (id, analysis_data).
        """
        with self.conn.cursor() as cur:
            query = """
                WITH parsed AS (
                    -------------------------------------------------------------------
                    -- 1) Parse label, raga, and singer from file_name
                    -------------------------------------------------------------------
                    SELECT 
                        id,
                        analysis_data,
                        file_name,
                        
                        -- label = first run of digits
                        CAST(substring(file_name FROM '\\d+') AS INT) AS label,
                        
                        -- raga = everything before first digit
                        substring(file_name FROM '^[^0-9]+') AS raga,
                        
                        -- singer = substring after those digits, up to next digit or end
                        substring(file_name FROM '^[^0-9]+\\d+([^0-9]+)') AS singer
                    FROM audio_analysis
                ),
                valid_ragas AS (
                    -------------------------------------------------------------------
                    -- 2) Identify all ragas used by labels 2, 3, or 4
                    -------------------------------------------------------------------
                    SELECT DISTINCT raga
                    FROM parsed
                    WHERE label IN (2, 3, 4)
                ),
                small_labels AS (
                    -------------------------------------------------------------------
                    -- 3) For labels 2, 3, 4, include ALL rows for valid ragas
                    -------------------------------------------------------------------
                    SELECT 
                        id,
                        analysis_data,
                        label,
                        raga,
                        singer
                    FROM parsed
                    WHERE label IN (2, 3, 4)
                      AND raga IN (SELECT raga FROM valid_ragas)
                ),
                label_5_all AS (
                    -------------------------------------------------------------------
                    -- 4) For label=5, only rows with ragas in valid_ragas
                    -------------------------------------------------------------------
                    SELECT 
                        id,
                        analysis_data,
                        label,
                        raga,
                        singer
                    FROM parsed
                    WHERE label = 5
                      AND raga IN (SELECT raga FROM valid_ragas)
                ),
                label_5_singer_limited AS (
                    -------------------------------------------------------------------
                    -- 5) Limit how many recordings we take per (raga, singer)
                    --    e.g. keep up to 5 per (raga, singer)
                    -------------------------------------------------------------------
                    SELECT
                        id,
                        analysis_data,
                        label,
                        raga,
                        singer,
                        ROW_NUMBER() OVER (
                            PARTITION BY raga, singer
                            ORDER BY RANDOM()
                        ) AS rn_singer
                    FROM label_5_all
                ),
                label_5_filtered AS (
                    SELECT 
                        id,
                        analysis_data,
                        label,
                        raga,
                        singer
                    FROM label_5_singer_limited
                    WHERE rn_singer <= 5  -- e.g., keep up to 5 per (raga, singer)
                ),
                label_5_final AS (
                    -------------------------------------------------------------------
                    -- 6) Of the remaining label=5 rows, pick up to 'limit' total
                    -------------------------------------------------------------------
                    SELECT
                        id,
                        analysis_data,
                        label,
                        raga,
                        singer,
                        ROW_NUMBER() OVER (ORDER BY RANDOM()) AS rn
                    FROM label_5_filtered
                )
                -----------------------------------------------------------------------
                -- 7) Combine:
                --    - all from small_labels (labels 2, 3, 4)
                --    - up to 'limit' from label_5_final
                -----------------------------------------------------------------------
                SELECT id, analysis_data
                FROM small_labels

                UNION ALL

                SELECT id, analysis_data
                FROM label_5_final
                WHERE rn <= %(limit_param)s
            """
            
            # Execute the query with the provided limit
            cur.execute(query, {"limit_param": limit})
            
            # Fetch all rows
            rows = cur.fetchall()
        
        return rows
    def fetch_single_record(self, record_id):
        with self.conn.cursor() as cur:
            query = "SELECT id, analysis_data FROM audio_analysis WHERE id = %s"
            cur.execute(query, (record_id,))
            row = cur.fetchone()
        return row


###############################################################################
# HELPER FUNCTIONS
###############################################################################
def convert_counts_to_probs_feature(counts_dict, max_bits=10):
    """
    Convert a dictionary of quantum measurement counts into a probability
    distribution vector of length 2^max_bits.
    """
    total_counts = sum(counts_dict.values())
    if total_counts == 0:
        return np.zeros(2**max_bits, dtype=np.float32)

    feature_vec = np.zeros(2**max_bits, dtype=np.float32)
    for bitstring, c in counts_dict.items():
        if len(bitstring) > max_bits:
            truncated = bitstring[-max_bits:]
        else:
            truncated = bitstring.rjust(max_bits, '0')
        idx = int(truncated, 2)
        feature_vec[idx] += c / total_counts
    return feature_vec


def load_image_as_tensor(image_path: str):
    """
    Loads an image as a 1x128x128 tensor (grayscale). Returns None if not found.
    """
    transform_img = T.Compose([
        T.Resize((128, 128)),
        T.ToTensor()
    ])
    if not os.path.exists(image_path):
        return None
    with Image.open(image_path) as img:
        img = img.convert("L")
        return transform_img(img)  # => shape [1,128,128]


def parse_quality(fname: str):
    """
    Using filename format: <Something><rating>[optional extras].wav
    e.g. Bhairavi4_run2.wav => rating=4

    Returns an integer rating in [1..5]. If not found, returns None.
    """
    base = fname.replace(".wav", "")
    m = re.match(r"^([A-Za-z]+)([1-5])(.*)", base)
    if not m:
        return None
    rating = int(m.group(2))
    return rating


###############################################################################
# DATA FETCH + PARSE FOR QUALITY ONLY
###############################################################################
def fetch_training_data_quality_only(limit=None):
    """
    1) Fetch data from the DB (limit # of rows).
    2) Parse the filename for a quality rating in [1..5].
    3) Build arrays for:
        - mel spectrogram image
        - scalar features
        - quantum features
        - quality label (0..4)
    4) Return (data_dict, used_ids).
    """
    db = QuantumMusicDBFetchOnly()
    rows = db.fetch_limited_analysis_data(limit=limit or HYPERPARAMS["DB_LIMIT"])
    db.close()

    if not rows:
        print("No data found in DB.")
        return None, []

    images_mel = []
    scalar_feats = []
    quantum_feats = []
    labels_quality = []
    used_ids = []

    for (record_id, analysis_data) in rows:
        fname = analysis_data["file_name"]
        rating = parse_quality(fname)
        if rating is None:
            # skip if we can't parse a valid rating
            continue

        used_ids.append(record_id)

        # Load mel-spectrogram => shape (1, 128, 128)
        base_no_ext = fname.replace(".wav", "")
        mel_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mel.png")
        mel_img = load_image_as_tensor(mel_path)
        if mel_img is None:
            mel_img = torch.zeros((1, 128, 128))

        images_mel.append(mel_img)

        # Extract scalar features
        summary_dict = analysis_data.get("summary", {})
        pitch_dev = summary_dict.get("pitch_deviation", {})
        tnr_dict  = summary_dict.get("tone_to_noise_ratio", {})
        praat_dict= summary_dict.get("praat", {})
        dynamics  = summary_dict.get("dynamics", {})

        avg_dev = pitch_dev.get("mean", 0.0)
        std_dev = pitch_dev.get("std", 0.0)
        avg_tnr = tnr_dict.get("mean", 0.0)
        avg_hnr = praat_dict.get("hnr_mean", 0.0)

        rms_db_stats = dynamics.get("rms_db", {})
        mean_rms_db  = rms_db_stats.get("mean", 0.0)
        lufs_stats   = dynamics.get("lufs", {})
        mean_lufs    = lufs_stats.get("mean", 0.0)

        time_matrices = analysis_data.get("time_matrices", {})
        time_matrix_small = time_matrices.get("time_matrix_small", [])

        jitter_vals = [x.get("jitter", 0.0) or 0.0 for x in time_matrix_small]
        shimmer_vals= [x.get("shimmer", 0.0) or 0.0 for x in time_matrix_small]
        vib_vals    = [x.get("vibrato_rate", 0.0) or 0.0 for x in time_matrix_small]
        f1_vals     = [x.get("formants", {}).get("F1", 0.0) or 0.0 for x in time_matrix_small]

        avg_jitter  = float(np.mean(jitter_vals))  if jitter_vals else 0.0
        avg_shimmer = float(np.mean(shimmer_vals)) if shimmer_vals else 0.0
        avg_vibrato = float(np.mean(vib_vals))      if vib_vals else 0.0
        avg_formant = float(np.mean(f1_vals))      if f1_vals else 0.0

        scalars = [
            avg_dev, std_dev,
            avg_hnr, avg_tnr,
            mean_rms_db, mean_lufs,
            avg_jitter, avg_shimmer,
            avg_vibrato, avg_formant
        ]

        # Quantum features
        quantum_dict = analysis_data.get("quantum_analysis", {})
        angles = quantum_dict.get("scaled_angles", [])
        max_len = 10
        angle_arr = np.zeros(max_len, dtype=np.float32)
        for i in range(min(max_len, len(angles))):
            angle_arr[i] = angles[i]

        counts_d = quantum_dict.get("measurement_counts", {})
        dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=10)
        combined_q = np.concatenate([angle_arr, dist_vec], axis=0)

        scalar_feats.append(scalars)
        quantum_feats.append(combined_q)

        # rating in [1..5] => label in [0..4]
        labels_quality.append(rating - 1)

    if not scalar_feats:
        print("No valid data after filtering for quality-only training.")
        return None, []

    images_mel    = torch.stack(images_mel)
    scalar_feats  = np.array(scalar_feats, dtype=np.float32)
    quantum_feats = np.array(quantum_feats, dtype=np.float32)
    labels_quality= np.array(labels_quality, dtype=np.int64)

    data_dict = {
        "images_mel": images_mel,
        "scalar_feats": scalar_feats,
        "quantum_feats": quantum_feats,
        "label_quality": labels_quality
    }

    return data_dict, used_ids


###############################################################################
# DATASET CLASSES
###############################################################################
class QualityOnlyASTDataset(Dataset):
    """
    Each sample has:
      1) mel_img
      2) scalar feats
      3) quantum feats
      4) quality label (0â€“4)
    """
    def __init__(self, data_dict):
        self.mel_imgs   = data_dict["images_mel"]
        self.scalars    = data_dict["scalar_feats"]
        self.quants     = data_dict["quantum_feats"]
        self.labels_qual= data_dict["label_quality"]

    def __len__(self):
        return len(self.labels_qual)

    def __getitem__(self, idx):
        mel = self.mel_imgs[idx]
        scal= torch.tensor(self.scalars[idx], dtype=torch.float32)
        quan= torch.tensor(self.quants[idx], dtype=torch.float32)
        lbl = torch.tensor(self.labels_qual[idx], dtype=torch.long)
        return (mel, scal, quan, lbl)


###############################################################################
# MODEL CLASSES
###############################################################################
class QualityOnlyASTModel(nn.Module):
    """
    A single-output model (5 classes) for "quality".
    We fuse AST(768) + scalar(64) + quantum(64) => 896 => single head (5 outputs).
    """
    def __init__(self, scalar_dim=10, quantum_dim=(10 + 2**10), num_quality=5, num_unfrozen_layers=0):
        super().__init__()
        self.config = ASTConfig.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.ast_model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", config=self.config)

        # Freeze AST
        for param in self.ast_model.parameters():
            param.requires_grad = False

        # Unfreeze last N layers if requested
        if num_unfrozen_layers > 0:
            total_layers = 12
            start_layer = max(0, total_layers - num_unfrozen_layers)
            for layer_idx in range(start_layer, total_layers):
                for param in self.ast_model.encoder.layer[layer_idx].parameters():
                    param.requires_grad = True

        # Scalar MLP
        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Quantum MLP
        self.quantum_fc = nn.Sequential(
            nn.Linear(quantum_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        combined_dim = 768 + 64 + 64  # AST(768) + Scalar(64) + Quantum(64)

        self.quality_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_quality)
        )

    def forward(self, mel_img, scal_in, quan_in):
        B, C, H, W = mel_img.shape
        freq = HYPERPARAMS["AST_FREQ"]
        time = HYPERPARAMS["AST_TIME"]

        # Resize mel => (B,1,128,1024)
        mel_resized = F.interpolate(
            mel_img, size=(freq, time),
            mode='bilinear', align_corners=False
        )
        # (B,128,1024)
        mel_for_ast = mel_resized.squeeze(1)

        outputs = self.ast_model(input_values=mel_for_ast, output_hidden_states=True)
        hidden = outputs.last_hidden_state  # (B, seq_len, 768)
        ast_embedding = hidden[:, 0, :]     # [CLS] => (B,768)

        emb_scal = self.scalar_fc(scal_in)   # (B,64)
        emb_quan = self.quantum_fc(quan_in)  # (B,64)

        fused = torch.cat([ast_embedding, emb_scal, emb_quan], dim=1)  # (B,896)

        logits_quality = self.quality_head(fused)
        return logits_quality


class QualityOnlyASTModelNoQuantum(nn.Module):
    """
    A single-output model (5 classes) for "quality", but **no** quantum features.
    We fuse AST(768) + scalar(64) => 832 => single head (5 outputs).
    """
    def __init__(self, scalar_dim=10, num_quality=5, num_unfrozen_layers=0):
        super().__init__()
        self.config = ASTConfig.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.ast_model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", config=self.config)

        # Freeze AST
        for param in self.ast_model.parameters():
            param.requires_grad = False

        # Unfreeze last N layers if requested
        if num_unfrozen_layers > 0:
            total_layers = 12
            start_layer = max(0, total_layers - num_unfrozen_layers)
            for layer_idx in range(start_layer, total_layers):
                for param in self.ast_model.encoder.layer[layer_idx].parameters():
                    param.requires_grad = True

        # Scalar MLP
        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        combined_dim = 768 + 64  # AST(768) + Scalar(64)

        self.quality_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_quality)
        )

    def forward(self, mel_img, scal_in):
        B, C, H, W = mel_img.shape
        freq = HYPERPARAMS["AST_FREQ"]
        time = HYPERPARAMS["AST_TIME"]

        # Resize mel => (B,1,128,1024)
        mel_resized = F.interpolate(
            mel_img, size=(freq, time),
            mode='bilinear', align_corners=False
        )
        mel_for_ast = mel_resized.squeeze(1)  # (B,128,1024)

        outputs = self.ast_model(input_values=mel_for_ast, output_hidden_states=True)
        hidden = outputs.last_hidden_state  # (B, seq_len, 768)
        ast_embedding = hidden[:, 0, :]     # [CLS] => (B,768)

        emb_scal = self.scalar_fc(scal_in)  # (B,64)

        fused = torch.cat([ast_embedding, emb_scal], dim=1)  # (B,832)

        logits_quality = self.quality_head(fused)
        return logits_quality


###############################################################################
# TRAINING FUNCTION (Quality Only, Quantum vs. NoQuantum)
###############################################################################
def train_quality_model(include_quantum=True):
    """
    If `include_quantum=True`, build QualityOnlyASTModel (with quantum).
    Otherwise, build QualityOnlyASTModelNoQuantum.
    """
    data_dict, used_ids = fetch_training_data_quality_only(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("No data for training.")
        return None

    # Create dataset
    dataset_full = QualityOnlyASTDataset(data_dict)
    n_samples = len(dataset_full)
    test_size = int(0.2 * n_samples)
    train_size = n_samples - test_size
    train_ds, test_ds = torch.utils.data.random_split(dataset_full, [train_size, test_size])
    train_dl = DataLoader(train_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=True)
    test_dl  = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    scalar_dim  = data_dict["scalar_feats"].shape[1]
    quantum_dim = data_dict["quantum_feats"].shape[1]  # 10 + 2**10 = 10 + 1024 = 1034

    # Build model
    if include_quantum:
        model = QualityOnlyASTModel(
            scalar_dim=scalar_dim,
            quantum_dim=quantum_dim,
            num_quality=5,
            num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
        ).to(HYPERPARAMS["DEVICE"])
    else:
        model = QualityOnlyASTModelNoQuantum(
            scalar_dim=scalar_dim,
            num_quality=5,
            num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
        ).to(HYPERPARAMS["DEVICE"])

    optimizer = optim.Adam(
        model.parameters(),
        lr=HYPERPARAMS["LEARNING_RATE_HEAD"],
        weight_decay=HYPERPARAMS["WEIGHT_DECAY"]
    )
    crit_ce = nn.CrossEntropyLoss()

    EPOCHS = HYPERPARAMS["EPOCHS"]
    train_losses = []
    test_losses = []
    best_test_loss = float("inf")
    best_model_state = None
    patience = HYPERPARAMS["PATIENCE"]
    epochs_no_improve = 0
    stop_limit = HYPERPARAMS["STOP_LIMIT"]

    label_str = "Quantum" if include_quantum else "NoQuantum"

    for epoch in range(EPOCHS):
        # TRAIN
        model.train()
        total_train_loss = 0.0
        for mel_img, scal, quan, qual_lbl in train_dl:
            mel_img = mel_img.to(HYPERPARAMS["DEVICE"])
            scal    = scal.to(HYPERPARAMS["DEVICE"])
            quan    = quan.to(HYPERPARAMS["DEVICE"])
            qual_lbl= qual_lbl.to(HYPERPARAMS["DEVICE"])

            optimizer.zero_grad()

            if include_quantum:
                logits_quality = model(mel_img, scal, quan)
            else:
                logits_quality = model(mel_img, scal)

            loss = crit_ce(logits_quality, qual_lbl)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dl)
        train_losses.append(avg_train_loss)

        # EVAL
        model.eval()
        total_test_loss = 0.0
        with torch.no_grad():
            for mel_img, scal, quan, qual_lbl in test_dl:
                mel_img = mel_img.to(HYPERPARAMS["DEVICE"])
                scal    = scal.to(HYPERPARAMS["DEVICE"])
                quan    = quan.to(HYPERPARAMS["DEVICE"])
                qual_lbl= qual_lbl.to(HYPERPARAMS["DEVICE"])

                if include_quantum:
                    logits_quality = model(mel_img, scal, quan)
                else:
                    logits_quality = model(mel_img, scal)

                loss = crit_ce(logits_quality, qual_lbl)
                total_test_loss += loss.item()

        avg_test_loss = total_test_loss / len(test_dl)
        test_losses.append(avg_test_loss)

        print(f"[{label_str}] Epoch {epoch+1}/{EPOCHS}: train_loss={avg_train_loss:.4f}, test_loss={avg_test_loss:.4f}")

        # Early Stopping
        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            best_model_state = model.state_dict()
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if (epoch + 1) >= stop_limit and epochs_no_improve >= patience:
                print(f"[{label_str}] No improvement for {patience} epochs after epoch {stop_limit}. Stopping.")
                break

    print(f"[{label_str}] Training complete or early stopped.")

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Evaluate final => confusion matrix
    all_qual_preds, all_qual_truth = [], []

    model.eval()
    with torch.no_grad():
        for mel_img, scal, quan, qual_lbl in test_dl:
            mel_img = mel_img.to(HYPERPARAMS["DEVICE"])
            scal    = scal.to(HYPERPARAMS["DEVICE"])
            quan    = quan.to(HYPERPARAMS["DEVICE"])
            qual_lbl= qual_lbl.to(HYPERPARAMS["DEVICE"])

            if include_quantum:
                logits_quality = model(mel_img, scal, quan)
            else:
                logits_quality = model(mel_img, scal)

            pred_qual = logits_quality.argmax(dim=1).cpu().numpy()

            all_qual_preds.extend(pred_qual)
            all_qual_truth.extend(qual_lbl.cpu().numpy())

    cm_qual = confusion_matrix(all_qual_truth, all_qual_preds)
    used_qual_indices = sorted(set(all_qual_truth) | set(all_qual_preds))
    # Our labels go 0..4 internally => display as 1..5
    used_qual_labels  = [q+1 for q in used_qual_indices]
    disp_qual = ConfusionMatrixDisplay(cm_qual, display_labels=used_qual_labels)
    disp_qual.plot(cmap=plt.cm.Blues)
    plt.title(f"[{label_str}] Quality Confusion Matrix")
    plt.show()

    # Accuracy
    qual_acc = accuracy_score(all_qual_truth, all_qual_preds)
    print(f"[{label_str}] Final test accuracy: {qual_acc:.4f}")

    # Training curves
    plt.figure()
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"[{label_str}] Training Curves")
    plt.legend()
    plt.show()

    # Save model + used_ids
    os.makedirs("data/modeloutput", exist_ok=True)
    if include_quantum:
        model_fname = "trained_quality_model_quantum.pt"
    else:
        model_fname = "trained_quality_model_noquantum.pt"

    checkpoint_path = os.path.join("data", "modeloutput", model_fname)

    checkpoint_dict = {
        "model_state": model.state_dict(),
        "scalar_dim": scalar_dim,
        "quantum_dim": quantum_dim,
        "used_ids": used_ids
    }
    torch.save(checkpoint_dict, checkpoint_path)
    print(f"[{label_str}] Model + metadata saved to: {checkpoint_path}")

    return model


###############################################################################
# COMPARISON FUNCTION: QUANTUM VS. NO-QUANTUM
###############################################################################
def compare_quality_models():
    """
    1) Fetch the same dataset (quality-only).
    2) Split 20% test.
    3) Load quantum model + no-quantum model from disk.
    4) Evaluate both on the same test set -> confusion matrix, accuracy.
    5) Print comparison.
    """
    data_dict, used_ids = fetch_training_data_quality_only(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("[Compare] No data available for comparison.")
        return

    dataset_full = QualityOnlyASTDataset(data_dict)
    n_samples = len(dataset_full)
    test_size = int(0.2 * n_samples)
    train_size = n_samples - test_size
    # We only want to evaluate on the test set. We do not re-train here.
    _, test_ds = torch.utils.data.random_split(dataset_full, [train_size, test_size])
    test_dl = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    scalar_dim = data_dict["scalar_feats"].shape[1]
    quantum_dim= data_dict["quantum_feats"].shape[1]

    # Load quantum model
    q_ckpt_path = "data/modeloutput/trained_quality_model_quantum.pt"
    if not os.path.exists(q_ckpt_path):
        print(f"[Compare] Quantum model not found at {q_ckpt_path}")
        return
    q_ckpt = torch.load(q_ckpt_path, map_location=HYPERPARAMS["DEVICE"])

    model_q = QualityOnlyASTModel(
        scalar_dim=scalar_dim,
        quantum_dim=quantum_dim,
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    model_q.load_state_dict(q_ckpt["model_state"])
    model_q.to(HYPERPARAMS["DEVICE"])
    model_q.eval()

    # Load no-quantum model
    nq_ckpt_path = "data/modeloutput/trained_quality_model_noquantum.pt"
    if not os.path.exists(nq_ckpt_path):
        print(f"[Compare] No-Quantum model not found at {nq_ckpt_path}")
        return
    nq_ckpt = torch.load(nq_ckpt_path, map_location=HYPERPARAMS["DEVICE"])

    model_nq = QualityOnlyASTModelNoQuantum(
        scalar_dim=scalar_dim,
        num_quality=5,
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    model_nq.load_state_dict(nq_ckpt["model_state"])
    model_nq.to(HYPERPARAMS["DEVICE"])
    model_nq.eval()

    all_qual_truth = []
    preds_qual_q, preds_qual_nq = [], []

    with torch.no_grad():
        for mel_img, scal, quan, qual_lbl in test_dl:
            mel_img = mel_img.to(HYPERPARAMS["DEVICE"])
            scal    = scal.to(HYPERPARAMS["DEVICE"])
            quan    = quan.to(HYPERPARAMS["DEVICE"])
            qual_lbl= qual_lbl.to(HYPERPARAMS["DEVICE"])

            # Quantum model
            logits_q = model_q(mel_img, scal, quan)
            pred_q   = logits_q.argmax(dim=1).cpu().numpy()

            # No-Quantum model
            logits_nq = model_nq(mel_img, scal)
            pred_nq   = logits_nq.argmax(dim=1).cpu().numpy()

            all_qual_truth.extend(qual_lbl.cpu().numpy())
            preds_qual_q.extend(pred_q)
            preds_qual_nq.extend(pred_nq)

    # Accuracy
    acc_q  = accuracy_score(all_qual_truth, preds_qual_q)
    acc_nq = accuracy_score(all_qual_truth, preds_qual_nq)

    print("\n=== Quality-Only: Quantum vs. No-Quantum ===")
    print(f"Quality Accuracy (Quantum):    {acc_q:.4f}")
    print(f"Quality Accuracy (No-Quantum): {acc_nq:.4f}")
    diff = (acc_q - acc_nq)*100
    print(f"Difference in Accuracy = {diff:.2f} % points")


def run_quality_inference(record_id, model_path="data/modeloutput/trained_quality_model_quantum.pt", include_quantum=True):
    """
    Single-record inference using persisted Quality-only model.
    Set include_quantum=True for quantum-enabled models.
    """
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=HYPERPARAMS["DEVICE"])
    model_state = checkpoint["model_state"]
    scalar_dim  = checkpoint["scalar_dim"]
    quantum_dim = checkpoint.get("quantum_dim", 0)

    # Build the appropriate model
    if include_quantum:
        model = QualityOnlyASTModel(
            scalar_dim=scalar_dim,
            quantum_dim=quantum_dim,
            num_quality=5,
            num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
        )
    else:
        model = QualityOnlyASTModelNoQuantum(
            scalar_dim=scalar_dim,
            num_quality=5,
            num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
        )

    model.load_state_dict(model_state)
    model.to(HYPERPARAMS["DEVICE"])
    model.eval()

    # Fetch single record from DB
    db = QuantumMusicDBFetchOnly()
    row = db.fetch_single_record(record_id)
    db.close()

    if not row:
        print(f"No record found with ID={record_id}")
        return None

    _, analysis_data = row
    fname = analysis_data["file_name"]
    true_quality = parse_quality(fname)

    # Load mel-spectrogram
    base_no_ext = fname.replace(".wav", "")
    mel_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mel.png")
    mel_img = load_image_as_tensor(mel_path)
    if mel_img is None:
        mel_img = torch.zeros((1, 128, 128))

    # Extract scalar features
    summary_dict = analysis_data.get("summary", {})
    pitch_dev = summary_dict.get("pitch_deviation", {})
    tnr_dict  = summary_dict.get("tone_to_noise_ratio", {})
    praat_dict= summary_dict.get("praat", {})
    dynamics  = summary_dict.get("dynamics", {})

    avg_dev = pitch_dev.get("mean", 0.0)
    std_dev = pitch_dev.get("std", 0.0)
    avg_tnr = tnr_dict.get("mean", 0.0)
    avg_hnr = praat_dict.get("hnr_mean", 0.0)

    rms_db_stats = dynamics.get("rms_db", {})
    mean_rms_db  = rms_db_stats.get("mean", 0.0)
    lufs_stats   = dynamics.get("lufs", {})
    mean_lufs    = lufs_stats.get("mean", 0.0)

    time_matrices = analysis_data.get("time_matrices", {})
    time_matrix_small = time_matrices.get("time_matrix_small", [])

    jitter_vals = [x.get("jitter", 0.0) or 0.0 for x in time_matrix_small]
    shimmer_vals= [x.get("shimmer", 0.0) or 0.0 for x in time_matrix_small]
    vib_vals    = [x.get("vibrato_rate", 0.0) or 0.0 for x in time_matrix_small]
    f1_vals     = [x.get("formants", {}).get("F1", 0.0) or 0.0 for x in time_matrix_small]

    avg_jitter  = float(np.mean(jitter_vals))  if jitter_vals else 0.0
    avg_shimmer = float(np.mean(shimmer_vals)) if shimmer_vals else 0.0
    avg_vibrato = float(np.mean(vib_vals))     if vib_vals else 0.0
    avg_formant = float(np.mean(f1_vals))      if f1_vals else 0.0

    scalar_features = [
        avg_dev, std_dev, avg_hnr, avg_tnr,
        mean_rms_db, mean_lufs,
        avg_jitter, avg_shimmer,
        avg_vibrato, avg_formant
    ]

    # Quantum features (if applicable)
    combined_quantum = None
    if include_quantum:
        quantum_dict = analysis_data.get("quantum_analysis", {})
        angles = quantum_dict.get("scaled_angles", [])
        max_len = 10
        angle_arr = np.zeros(max_len, dtype=np.float32)
        for i in range(min(max_len, len(angles))):
            angle_arr[i] = angles[i]

        counts_d = quantum_dict.get("measurement_counts", {})
        dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=10)
        combined_quantum = np.concatenate([angle_arr, dist_vec], axis=0)

    # Prepare inputs
    device = HYPERPARAMS["DEVICE"]
    mel_img_tensor = mel_img.unsqueeze(0).to(device)  # [1,1,128,128]
    scalar_tensor  = torch.tensor(scalar_features, dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        if include_quantum:
            quantum_tensor = torch.tensor(combined_quantum, dtype=torch.float32).unsqueeze(0).to(device)
            logits_quality = model(mel_img_tensor, scalar_tensor, quantum_tensor)
        else:
            logits_quality = model(mel_img_tensor, scalar_tensor)

        predicted_quality_idx = logits_quality.argmax(dim=1).item()
        predicted_quality = predicted_quality_idx + 1  # label (0..4) => rating (1..5)

    print(f"Inference for Record ID={record_id}:")
    print(f"  True Quality Rating: {true_quality}")
    print(f"  Predicted Quality Rating: {predicted_quality}")

    return predicted_quality




###############################################################################
# EXAMPLE MAIN
###############################################################################
if __name__ == "__main__":
    # 1) Train quantum
    #model_q = train_quality_model(include_quantum=True)

    # 2) Train no-quantum
    #model_nq = train_quality_model(include_quantum=False)

    # 3) Compare
    #compare_quality_models()


    pred_quality = run_quality_inference(
    record_id=2661, 
    model_path="data/modeloutput/trained_quality_model_quantum_final.pt",
    include_quantum=True
    )
    print(f"Predicted Quality Rating: {pred_quality}")


Using MPS device for GPU acceleration on Apple Silicon.


  checkpoint = torch.load(model_path, map_location=HYPERPARAMS["DEVICE"])


Connected to database quantummusic_csef. (fetch-only)
Database connection closed.
Inference for Record ID=2661:
  True Quality Rating: 2
  Predicted Quality Rating: 2
Predicted Quality Rating: 2
