## Machine Learning Models

In [None]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
import matplotlib.pyplot as plt

from transformers import ASTModel, ASTConfig
import torchvision.transforms as T
from PIL import Image

###############################################################################
# MPS DEVICE CHECK
###############################################################################
# Pick the Apple-Silicon MPS backend when PyTorch reports it, otherwise the CPU,
# and announce which one was selected.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(
    "Using MPS device for GPU acceleration on Apple Silicon."
    if DEVICE.type == "mps"
    else "MPS not available. Falling back to CPU."
)


###############################################################################
# HYPERPARAMETERS / GLOBAL SETTINGS
###############################################################################
# Single configuration mapping read by the data, model and training code.
HYPERPARAMS = dict(
    # --- data / output locations ---
    OUTPUT_DIR="data/trainingdataoutput",
    DB_LIMIT=100000,              # upper bound on the number of records kept

    # --- optimisation ---
    BATCH_SIZE=8,
    EPOCHS=10,                    # maximum number of training epochs
    LEARNING_RATE_MAIN=5e-5,      # for AST layers (if we unfreeze them)
    LEARNING_RATE_HEAD=1e-4,      # for scalar/quantum/final layers
    WEIGHT_DECAY=1e-4,            # L2 regularization
    DEVICE=DEVICE,
    NUM_AST_LAYERS_UNFROZEN=4,    # partial unfreezing of the AST encoder

    # --- early stopping ---
    PATIENCE=5,                   # epochs without improvement before stopping
    STOP_LIMIT=5,                 # must train at least this many epochs

    # --- AST model input shape for audio: (freq=128, time=1024) by default ---
    AST_FREQ=128,
    AST_TIME=1024,
)


###############################################################################
# DATABASE FETCH LOGIC (FETCH ONLY)
###############################################################################
class QuantumMusicDBFetchOnly:
    """
    A minimal, read-only database helper for the Postgres table
    'audio_analysis'.

    The connection is opened by the constructor. If connecting fails, the
    error is printed and `self.conn` stays None; the fetch method then returns
    an empty list instead of raising.
    """
    def __init__(self, db_name="quantummusic_csef", host="localhost", user="postgres", password="postgres"):
        # Imported here so the module can be loaded without psycopg2 installed.
        import psycopg2
        self.psycopg2 = psycopg2
        self.db_name = db_name
        self.host = host
        self.user = user
        self.password = password
        self.conn = None
        self.connect()

    def connect(self):
        """Open the connection; on failure print the error and keep going."""
        try:
            self.conn = self.psycopg2.connect(
                dbname=self.db_name,
                host=self.host,
                user=self.user,
                password=self.password
            )
            print(f"Connected to database {self.db_name}. (fetch-only)")
        except Exception as e:
            # Best-effort by design: callers check for a missing connection.
            print(f"Error connecting to database: {e}")

    def close(self):
        """Close the connection if one is open and forget it."""
        if self.conn:
            self.conn.close()
            # Reset so later fetches take the "no connection" path instead of
            # failing on a closed handle.
            self.conn = None
            print("Database connection closed.")

    def fetch_all_thaat_records(self, limit=1000):
        """
        Fetch rows from 'audio_analysis' whose file_name contains 'thaat'
        (case-insensitive), ordered by id.

        Args:
            limit: maximum number of rows to return (default 1000, the cap
                that used to be hard-coded). Pass None for no cap.

        Returns:
            A list of (id, analysis_data) tuples; an empty list when there is
            no open connection.
        """
        if self.conn is None:
            print("No open database connection; returning no records.")
            return []

        # The pattern and the limit are passed as query parameters, so no
        # manual '%' escaping is needed. ORDER BY makes the result (and which
        # rows survive the limit) deterministic between calls.
        query = """
            SELECT id, analysis_data
            FROM audio_analysis
            WHERE file_name ILIKE %s
            ORDER BY id
        """
        params = ["%thaat%"]
        if limit is not None:
            query += " LIMIT %s"
            params.append(int(limit))

        with self.conn.cursor() as cur:
            cur.execute(query, params)
            return cur.fetchall()


###############################################################################
# HELPER FUNCTIONS
###############################################################################
def convert_counts_to_probs_feature(counts_dict, max_bits=10):
    """
    Turn quantum measurement counts into a float32 probability vector with
    2^max_bits entries.

    Each key of `counts_dict` is read as a binary string: keys longer than
    `max_bits` keep only their last `max_bits` characters, shorter keys are
    left-padded with '0'. Keys that map to the same index accumulate. When
    the counts sum to zero, the all-zero vector is returned.
    """
    probs = np.zeros(2**max_bits, dtype=np.float32)

    shots = sum(counts_dict.values())
    if shots == 0:
        return probs

    for key, hits in counts_dict.items():
        bits = key[-max_bits:] if len(key) > max_bits else key.rjust(max_bits, '0')
        probs[int(bits, 2)] += hits / shots
    return probs


def load_image_as_tensor(image_path: str):
    """
    Read the image at `image_path`, convert it to grayscale ("L" mode),
    resize it to 128x128 and return it as a tensor of shape [1,128,128].
    Returns None when the path does not exist.
    """
    to_gray_tensor = T.Compose([T.Resize((128, 128)), T.ToTensor()])

    if os.path.exists(image_path):
        with Image.open(image_path) as picture:
            return to_gray_tensor(picture.convert("L"))
    return None


def parse_thaat_and_raga(file_name: str):
    """
    Extract (thaat, raga) from a file name such as
    'Kalyan_thaat_Bhoopali_xyz.wav'.

    - Every '.wav' occurrence is removed first.
    - The marker 'thaat_' must occur exactly once; otherwise (None, None)
      is returned.
    - thaat = text before the marker with trailing underscores stripped
      => 'Kalyan'
    - raga  = text after the marker up to the next underscore
      => 'Bhoopali'
    """
    marker = "thaat_"
    stem = file_name.replace(".wav", "")

    # Zero occurrences or more than one both count as a parse failure.
    if stem.count(marker) != 1:
        return None, None

    before, _, after = stem.partition(marker)
    raga_str, _, _ = after.partition("_")
    return before.rstrip("_"), raga_str


###############################################################################
# MAIN DATA FETCH + PARSE
###############################################################################
def _num_or_zero(value):
    """Return `value` as a float, mapping None / other falsy values to 0.0."""
    return float(value or 0.0)


def _mean_or_zero(values):
    """Return the mean of `values` as a float, or 0.0 for an empty list."""
    return float(np.mean(values)) if values else 0.0


def _scalar_features(analysis_data):
    """Build the 10 scalar features for one record; missing/null entries count as 0.0."""
    # `or {}` (instead of a .get default) also covers keys stored as JSON null.
    summary = analysis_data.get("summary") or {}
    pitch_dev = summary.get("pitch_deviation") or {}
    tnr = summary.get("tone_to_noise_ratio") or {}
    praat = summary.get("praat") or {}
    dynamics = summary.get("dynamics") or {}
    rms_db = dynamics.get("rms_db") or {}
    lufs = dynamics.get("lufs") or {}

    time_matrices = analysis_data.get("time_matrices") or {}
    frames = time_matrices.get("time_matrix_small") or []

    return [
        _num_or_zero(pitch_dev.get("mean")),
        _num_or_zero(pitch_dev.get("std")),
        _num_or_zero(praat.get("hnr_mean")),
        _num_or_zero(tnr.get("mean")),
        _num_or_zero(rms_db.get("mean")),
        _num_or_zero(lufs.get("mean")),
        _mean_or_zero([_num_or_zero(f.get("jitter")) for f in frames]),
        _mean_or_zero([_num_or_zero(f.get("shimmer")) for f in frames]),
        _mean_or_zero([_num_or_zero(f.get("vibrato_rate")) for f in frames]),
        _mean_or_zero([_num_or_zero((f.get("formants") or {}).get("F1")) for f in frames]),
    ]


def _quantum_features(analysis_data, num_angles=10, max_bits=10):
    """Concatenate the first `num_angles` scaled angles (zero-padded) with the measurement distribution."""
    quantum = analysis_data.get("quantum_analysis") or {}

    angle_arr = np.zeros(num_angles, dtype=np.float32)
    for i, angle in enumerate((quantum.get("scaled_angles") or [])[:num_angles]):
        angle_arr[i] = _num_or_zero(angle)

    counts = quantum.get("measurement_counts") or {}
    dist_vec = convert_counts_to_probs_feature(counts, max_bits=max_bits)
    return np.concatenate([angle_arr, dist_vec], axis=0)


def fetch_training_data_thaat_raga(limit=None):
    """
    Fetch the 'thaat' records from the database and turn them into arrays.

    Steps:
      1) Fetch DB records whose file_name contains 'thaat'.
      2) Parse (thaat, raga) from each file_name; records that cannot be
         parsed (or have no file_name) are dropped.
      3) Optionally keep only the first `limit` remaining records.
      4) Per record, build:
           - the mel-spectrogram image, read from
             data/analysisoutput/<name>_mel.png (all-zero image if missing)
           - 10 scalar features
           - quantum features (10 angles + 2^10 probabilities)
           - integer-encoded thaat and raga labels

    The structure of `analysis_data` is expected to be:
        {
          "file_name": "...",
          "summary": { ... },
          "time_matrices": { ... },
          "quantum_analysis": { ... }
        }

    Args:
        limit: maximum number of parsed records to keep, or None for all.

    Returns:
        (data_dict, used_ids). `used_ids[i]` is the DB id of sample i.
        Returns (None, []) when no usable data is available.
    """
    db = QuantumMusicDBFetchOnly()
    try:
        if db.conn is None:
            print("Database connection unavailable; cannot fetch training data.")
            return None, []
        rows = db.fetch_all_thaat_records()
    finally:
        # Always release the connection, even if the fetch raised.
        db.close()

    if not rows:
        print("No 'thaat' data found in DB.")
        return None, []

    # Parse every file name exactly once and keep the result with the record.
    parsed = []
    for record_id, analysis_data in rows:
        fname = (analysis_data or {}).get("file_name")
        if not fname:
            continue
        thaat, raga = parse_thaat_and_raga(fname)
        if thaat is not None and raga is not None:
            parsed.append((record_id, analysis_data, fname, thaat, raga))

    # If needed, limit how many total records we keep
    if limit is not None and len(parsed) > limit:
        parsed = parsed[:limit]

    if not parsed:
        print("No data left after filtering.")
        return None, []

    print(f"Total included records: {len(parsed)}")

    images_mel = []
    scalar_feats = []
    quantum_feats = []
    thaat_labels = []
    raga_labels = []
    used_ids = []

    for record_id, analysis_data, fname, thaat, raga in parsed:
        used_ids.append(record_id)

        # Mel-spectrogram image, with an all-zero fallback of the same shape.
        base_no_ext = fname.replace(".wav", "")
        mel_path = os.path.join("data", "analysisoutput", f"{base_no_ext}_mel.png")
        mel_img = load_image_as_tensor(mel_path)
        if mel_img is None:
            mel_img = torch.zeros((1, 128, 128))
        images_mel.append(mel_img)

        scalar_feats.append(_scalar_features(analysis_data))
        quantum_feats.append(_quantum_features(analysis_data))

        # Raw label strings; mapped to integers once all classes are known.
        thaat_labels.append(thaat)
        raga_labels.append(raga)

    # Sorted class lists give a stable string -> index mapping.
    unique_thaats = sorted(set(thaat_labels))
    unique_ragas = sorted(set(raga_labels))
    thaat_to_idx = {t: i for i, t in enumerate(unique_thaats)}
    raga_to_idx = {r: i for i, r in enumerate(unique_ragas)}

    data_dict = {
        "images_mel": torch.stack(images_mel),  # shape => (N, 1, 128, 128)
        "scalar_feats": np.array(scalar_feats, dtype=np.float32),
        "quantum_feats": np.array(quantum_feats, dtype=np.float32),
        "label_thaat_idx": np.array([thaat_to_idx[t] for t in thaat_labels], dtype=np.int64),
        "label_raga_idx": np.array([raga_to_idx[r] for r in raga_labels], dtype=np.int64),
        "thaat_to_idx": thaat_to_idx,
        "raga_to_idx": raga_to_idx,
        "unique_thaats": unique_thaats,
        "unique_ragas": unique_ragas
    }

    return data_dict, used_ids


###############################################################################
# DATASET & MODELS
###############################################################################
class MultiLabelThaatRagaDataset(Dataset):
    """
    Dataset over the arrays stored in `data_dict`.

    Indexing yields a 5-tuple:
       (mel image, scalar features, quantum features, thaat label, raga label)
    with the features as float32 tensors and the labels as long tensors.
    """
    def __init__(self, data_dict):
        self.mel_imgs = data_dict["images_mel"]
        self.scalars = data_dict["scalar_feats"]
        self.quants = data_dict["quantum_feats"]
        self.thaat_lbl = data_dict["label_thaat_idx"]
        self.raga_lbl = data_dict["label_raga_idx"]

    def __len__(self):
        # One thaat label per sample, so its length is the dataset size.
        return len(self.thaat_lbl)

    def __getitem__(self, idx):
        features = tuple(
            torch.tensor(source[idx], dtype=torch.float32)
            for source in (self.scalars, self.quants)
        )
        targets = tuple(
            torch.tensor(source[idx], dtype=torch.long)
            for source in (self.thaat_lbl, self.raga_lbl)
        )
        return (self.mel_imgs[idx], *features, *targets)


class HybridASTModelThaatRaga(nn.Module):
    """
    Hybrid AST model predicting two labels per sample:
      1) Thaat
      2) Raga
    (no quality label).

    A pretrained AST encoder is loaded fully frozen, then its last
    `num_unfrozen_layers` encoder layers are made trainable. The AST
    embedding is concatenated with the outputs of two small MLPs (scalar
    features, quantum features) and fed to two classification heads.

    Args:
        num_thaats: number of thaat classes.
        num_ragas: number of raga classes.
        scalar_dim: width of the scalar feature vector.
        quantum_dim: width of the quantum feature vector.
        num_unfrozen_layers: how many trailing AST encoder layers to train
            (0 keeps the whole AST frozen).
    """
    def __init__(self, num_thaats, num_ragas, scalar_dim=10, quantum_dim=(10 + 2**10), num_unfrozen_layers=0):
        super().__init__()
        checkpoint_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
        self.config = ASTConfig.from_pretrained(checkpoint_name)
        self.ast_model = ASTModel.from_pretrained(checkpoint_name, config=self.config)

        # Freeze the AST
        for param in self.ast_model.parameters():
            param.requires_grad = False

        # Unfreeze the last N encoder layers. Slicing from the end uses the
        # encoder's real depth; asking for more layers than exist simply
        # unfreezes all of them.
        if num_unfrozen_layers > 0:
            for layer in self.ast_model.encoder.layer[-num_unfrozen_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True

        # MLP for scalar features
        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # MLP for quantum features
        self.quantum_fc = nn.Sequential(
            nn.Linear(quantum_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # AST embedding width comes from the loaded config (768 for this
        # checkpoint), plus 64 from each MLP.
        combined_dim = self.config.hidden_size + 64 + 64

        # HEADS
        self.thaat_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_thaats)
        )
        self.raga_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_ragas)
        )

    def forward(self, mel_img, scal_in, quan_in):
        """
        Args:
            mel_img: image batch of shape (B,1,H,W).
            scal_in: scalar features, (B, scalar_dim).
            quan_in: quantum features, (B, quantum_dim).

        Returns:
            (logits_thaat, logits_raga) with shapes (B, num_thaats) and
            (B, num_ragas).
        """
        target_size = (HYPERPARAMS["AST_FREQ"], HYPERPARAMS["AST_TIME"])

        # Resize to (B,1,AST_FREQ,AST_TIME), then drop the channel dim.
        mel_resized = F.interpolate(
            mel_img, size=target_size,
            mode='bilinear', align_corners=False
        )
        mel_for_ast = mel_resized.squeeze(1)

        # NOTE(review): the Hugging Face AST docs describe `input_values` as
        # (batch, max_length, num_mel_bins), i.e. time before frequency, while
        # this passes (B, AST_FREQ, AST_TIME). Left as-is so existing
        # checkpoints keep their behaviour -- confirm, and transpose if needed.
        #
        # Only the final hidden state is used, so intermediate hidden states
        # are not requested.
        outputs = self.ast_model(input_values=mel_for_ast)
        ast_embedding = outputs.last_hidden_state[:, 0, :]  # first token => (B, hidden)

        emb_scal = self.scalar_fc(scal_in)    # (B,64)
        emb_quan = self.quantum_fc(quan_in)   # (B,64)

        fused = torch.cat([ast_embedding, emb_scal, emb_quan], dim=1)  # (B, hidden+64+64)

        logits_thaat = self.thaat_head(fused)  # (B, num_thaats)
        logits_raga  = self.raga_head(fused)   # (B, num_ragas)
        return logits_thaat, logits_raga


class HybridASTModelThaatRagaNoQuantum(nn.Module):
    """
    Variant of the hybrid AST model with NO quantum features fused in.

    A pretrained AST encoder is loaded fully frozen, then its last
    `num_unfrozen_layers` encoder layers are made trainable. The AST
    embedding is concatenated with a 64-dim scalar-feature MLP output and fed
    to two classification heads (thaat, raga).

    Args:
        num_thaats: number of thaat classes.
        num_ragas: number of raga classes.
        scalar_dim: width of the scalar feature vector.
        num_unfrozen_layers: how many trailing AST encoder layers to train
            (0 keeps the whole AST frozen).
    """
    def __init__(self, num_thaats, num_ragas, scalar_dim=10, num_unfrozen_layers=0):
        super().__init__()
        checkpoint_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
        self.config = ASTConfig.from_pretrained(checkpoint_name)
        self.ast_model = ASTModel.from_pretrained(checkpoint_name, config=self.config)

        # Freeze the AST
        for param in self.ast_model.parameters():
            param.requires_grad = False

        # Unfreeze the last N encoder layers if requested. Slicing from the
        # end uses the encoder's real depth; asking for more layers than exist
        # simply unfreezes all of them.
        if num_unfrozen_layers > 0:
            for layer in self.ast_model.encoder.layer[-num_unfrozen_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True

        # Only scalar MLP
        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # AST embedding width comes from the loaded config (768 for this
        # checkpoint), plus 64 from the scalar MLP.
        combined_dim = self.config.hidden_size + 64

        self.thaat_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_thaats)
        )
        self.raga_head = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_ragas)
        )

    def forward(self, mel_img, scalar_input):
        """
        Args:
            mel_img: image batch of shape (B,1,H,W).
            scalar_input: scalar features, (B, scalar_dim).

        Returns:
            (logits_thaat, logits_raga) with shapes (B, num_thaats) and
            (B, num_ragas).
        """
        target_size = (HYPERPARAMS["AST_FREQ"], HYPERPARAMS["AST_TIME"])

        # Resize to (B,1,AST_FREQ,AST_TIME), then drop the channel dim.
        mel_resized = F.interpolate(
            mel_img, size=target_size,
            mode='bilinear', align_corners=False
        )
        mel_for_ast = mel_resized.squeeze(1)

        # NOTE(review): the Hugging Face AST docs describe `input_values` as
        # (batch, max_length, num_mel_bins), i.e. time before frequency, while
        # this passes (B, AST_FREQ, AST_TIME). Left as-is so existing
        # checkpoints keep their behaviour -- confirm, and transpose if needed.
        #
        # Only the final hidden state is used, so intermediate hidden states
        # are not requested.
        outputs = self.ast_model(input_values=mel_for_ast)
        ast_embedding = outputs.last_hidden_state[:, 0, :]  # first token => (B, hidden)

        emb_scal = self.scalar_fc(scalar_input)  # (B,64)
        fused = torch.cat([ast_embedding, emb_scal], dim=1)  # (B, hidden+64)

        logits_thaat = self.thaat_head(fused)
        logits_raga  = self.raga_head(fused)
        return logits_thaat, logits_raga


###############################################################################
# TRAINING FUNCTIONS
###############################################################################
def _run_thaat_raga_batch(model, batch, include_quantum, device):
    """Move one batch to `device`, run `model`, and return (logits_thaat, logits_raga, t_lbl, r_lbl)."""
    mel_img, scal, quan, t_lbl, r_lbl = batch
    mel_img = mel_img.to(device)
    scal = scal.to(device)
    t_lbl = t_lbl.to(device)
    r_lbl = r_lbl.to(device)

    if include_quantum:
        logits_thaat, logits_raga = model(mel_img, scal, quan.to(device))
    else:
        # The no-quantum model takes no quantum input, so it is not transferred.
        logits_thaat, logits_raga = model(mel_img, scal)
    return logits_thaat, logits_raga, t_lbl, r_lbl


def train_thaat_raga_model(include_quantum=True):
    """
    Train and evaluate a Thaat+Raga classifier, then save it.

    If `include_quantum=True`, HybridASTModelThaatRaga is used, otherwise
    HybridASTModelThaatRagaNoQuantum.

    Behaviour:
      - 80/20 train/test split drawn with a seeded generator
        (HYPERPARAMS["SPLIT_SEED"], default 42), so repeated calls on the
        same data produce the same split.
      - Adam with two parameter groups: trainable AST parameters use
        LEARNING_RATE_MAIN, everything else uses LEARNING_RATE_HEAD.
      - Early stopping on test loss after at least STOP_LIMIT epochs and
        PATIENCE epochs without improvement; the best weights are restored
        before the final evaluation.
      - Shows confusion matrices and loss curves, prints accuracies, and
        writes a checkpoint to data/modeloutput/. The checkpoint also
        records the DB ids of the held-out samples under "test_ids".

    Returns:
        (model, data_dict), or (None, None) when there is not enough data.
    """
    prefix = "[Quantum]" if include_quantum else "[NoQuantum]"
    device = HYPERPARAMS["DEVICE"]

    # Fetch data
    data_dict, used_ids = fetch_training_data_thaat_raga(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("No data to train.")
        return None, None

    # Create dataset
    dataset_full = MultiLabelThaatRagaDataset(data_dict)
    n_samples = len(dataset_full)
    if n_samples < 2:
        print(f"{prefix} Need at least 2 samples for a train/test split; got {n_samples}.")
        return None, None

    # At least one test sample, so the test loader is never empty.
    test_size = max(1, int(0.2 * n_samples))
    train_size = n_samples - test_size
    split_seed = HYPERPARAMS.get("SPLIT_SEED", 42)
    split_gen = torch.Generator().manual_seed(split_seed)
    train_ds, test_ds = torch.utils.data.random_split(
        dataset_full, [train_size, test_size], generator=split_gen
    )
    train_dl = DataLoader(train_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=True)
    test_dl  = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    # DB ids of the held-out samples (used_ids is aligned with dataset order).
    if len(used_ids) == n_samples:
        test_ids = [used_ids[i] for i in test_ds.indices]
    else:
        test_ids = None

    scalar_feats = data_dict["scalar_feats"]
    quantum_feats= data_dict["quantum_feats"]
    num_thaats   = len(data_dict["thaat_to_idx"])
    num_ragas    = len(data_dict["raga_to_idx"])

    # Build model
    if include_quantum:
        model = HybridASTModelThaatRaga(
            num_thaats=num_thaats,
            num_ragas=num_ragas,
            scalar_dim=scalar_feats.shape[1],
            quantum_dim=quantum_feats.shape[1],
            num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
        ).to(device)
    else:
        model = HybridASTModelThaatRagaNoQuantum(
            num_thaats=num_thaats,
            num_ragas=num_ragas,
            scalar_dim=scalar_feats.shape[1],
            num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
        ).to(device)

    # Separate learning rates: unfrozen AST layers vs. the freshly
    # initialised MLPs and heads. Frozen parameters are left out entirely.
    ast_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (ast_params if name.startswith("ast_model.") else head_params).append(param)

    param_groups = [{"params": head_params, "lr": HYPERPARAMS["LEARNING_RATE_HEAD"]}]
    if ast_params:
        param_groups.append({"params": ast_params, "lr": HYPERPARAMS["LEARNING_RATE_MAIN"]})

    optimizer = optim.Adam(
        param_groups,
        lr=HYPERPARAMS["LEARNING_RATE_HEAD"],
        weight_decay=HYPERPARAMS["WEIGHT_DECAY"]
    )
    crit_ce = nn.CrossEntropyLoss()

    EPOCHS = HYPERPARAMS["EPOCHS"]
    train_losses = []
    test_losses = []
    best_test_loss = float("inf")
    best_model_state = None
    patience = HYPERPARAMS["PATIENCE"]
    epochs_no_improve = 0
    stop_limit = HYPERPARAMS["STOP_LIMIT"]

    for epoch in range(EPOCHS):
        # TRAIN
        model.train()
        total_train_loss = 0.0

        for batch in train_dl:
            optimizer.zero_grad()
            logits_thaat, logits_raga, t_lbl, r_lbl = _run_thaat_raga_batch(
                model, batch, include_quantum, device
            )
            loss = crit_ce(logits_thaat, t_lbl) + crit_ce(logits_raga, r_lbl)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dl)
        train_losses.append(avg_train_loss)

        # EVAL
        model.eval()
        total_test_loss = 0.0
        with torch.no_grad():
            for batch in test_dl:
                logits_thaat, logits_raga, t_lbl, r_lbl = _run_thaat_raga_batch(
                    model, batch, include_quantum, device
                )
                loss = crit_ce(logits_thaat, t_lbl) + crit_ce(logits_raga, r_lbl)
                total_test_loss += loss.item()

        avg_test_loss = total_test_loss / len(test_dl)
        test_losses.append(avg_test_loss)

        print(f"{prefix} Epoch {epoch+1}/{EPOCHS}: train_loss={avg_train_loss:.4f}, test_loss={avg_test_loss:.4f}")

        # Early Stopping
        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            # state_dict() hands back the live tensors, which later optimizer
            # steps overwrite in place; snapshot copies so "best" stays best.
            best_model_state = {
                key: value.detach().cpu().clone()
                for key, value in model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if (epoch + 1) >= stop_limit and epochs_no_improve >= patience:
                print(f"{prefix} No improvement for {patience} epochs after epoch {stop_limit}. Stopping.")
                break

    print(f"{prefix} Training complete or early stopped.")

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Final confusion matrices + accuracies
    all_thaat_preds, all_thaat_truth = [], []
    all_raga_preds, all_raga_truth   = [], []

    model.eval()
    with torch.no_grad():
        for batch in test_dl:
            logits_thaat, logits_raga, t_lbl, r_lbl = _run_thaat_raga_batch(
                model, batch, include_quantum, device
            )
            all_thaat_preds.extend(logits_thaat.argmax(dim=1).cpu().numpy())
            all_thaat_truth.extend(t_lbl.cpu().numpy())
            all_raga_preds.extend(logits_raga.argmax(dim=1).cpu().numpy())
            all_raga_truth.extend(r_lbl.cpu().numpy())

    # Compute confusion matrix for thaat. Only classes that occur in the
    # truth or the predictions get a row/column, so the labels are limited
    # to those classes as well.
    cm_thaat = confusion_matrix(all_thaat_truth, all_thaat_preds)
    used_thaat_indices = sorted(set(all_thaat_truth) | set(all_thaat_preds))
    idx_to_thaat = {v: k for k, v in data_dict["thaat_to_idx"].items()}
    used_thaat_labels = [idx_to_thaat[i] for i in used_thaat_indices]

    disp_thaat = ConfusionMatrixDisplay(cm_thaat, display_labels=used_thaat_labels)
    disp_thaat.plot(cmap=plt.cm.Blues)
    plt.title(f"Thaat Confusion Matrix {prefix}")
    plt.show()

    # Compute confusion matrix for raga
    cm_raga = confusion_matrix(all_raga_truth, all_raga_preds)
    used_raga_indices = sorted(set(all_raga_truth) | set(all_raga_preds))
    idx_to_raga = {v: k for k, v in data_dict["raga_to_idx"].items()}
    used_raga_labels = [idx_to_raga[i] for i in used_raga_indices]

    disp_raga = ConfusionMatrixDisplay(cm_raga, display_labels=used_raga_labels)
    disp_raga.plot(cmap=plt.cm.Blues)
    plt.title(f"Raga Confusion Matrix {prefix}")
    plt.show()

    # Compute accuracies
    thaat_acc = accuracy_score(all_thaat_truth, all_thaat_preds)
    raga_acc  = accuracy_score(all_raga_truth, all_raga_preds)
    print(f"{prefix} Thaat Accuracy: {thaat_acc:.4f}")
    print(f"{prefix} Raga Accuracy:  {raga_acc:.4f}")

    # Plot training curves
    plt.figure()
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Training Curves {prefix}")
    plt.legend()
    plt.show()

    # Save model
    os.makedirs("data/modeloutput", exist_ok=True)
    if include_quantum:
        checkpoint_path = os.path.join("data", "modeloutput", "trained_model_thaat_raga_quantum.pt")
    else:
        checkpoint_path = os.path.join("data", "modeloutput", "trained_model_thaat_raga_noquantum.pt")

    # Save relevant metadata. "test_ids" / "split_seed" are additive keys that
    # let a later evaluation identify the samples this model never saw.
    checkpoint_dict = {
        "model_state": model.state_dict(),
        "thaat_to_idx": data_dict["thaat_to_idx"],
        "raga_to_idx": data_dict["raga_to_idx"],
        "scalar_dim": scalar_feats.shape[1],
        "quantum_dim": quantum_feats.shape[1],
        "num_thaats": num_thaats,
        "num_ragas": num_ragas,
        "used_ids": used_ids,
        "test_ids": test_ids,
        "split_seed": split_seed
    }
    torch.save(checkpoint_dict, checkpoint_path)
    print(f"{prefix} Model + metadata saved to: {checkpoint_path}")

    return model, data_dict


###############################################################################
# COMPARISON FUNCTION: QUANTUM VS. NO-QUANTUM
###############################################################################
def compare_thaat_raga_models():
    """
    Evaluate the saved quantum and no-quantum checkpoints on one shared test
    set and print their thaat / raga accuracies.

    1) Re-fetch the dataset (using fetch_training_data_thaat_raga).
    2) Load the quantum and the no-quantum model checkpoints.
    3) Choose the test samples:
         - if both checkpoints carry "test_ids", use the records held out by
           BOTH of them (matched by DB id);
         - otherwise fall back to a seeded 20% split
           (HYPERPARAMS["SPLIT_SEED"], default 42) and warn that it may
           overlap the training data of the checkpoints.
    4) Evaluate both models on those samples and print the accuracies and
       their difference in percentage points.

    Returns None; prints a message and returns early when data, a
    checkpoint, or test samples are missing.
    """
    device = HYPERPARAMS["DEVICE"]

    data_dict, used_ids = fetch_training_data_thaat_raga(limit=HYPERPARAMS["DB_LIMIT"])
    if data_dict is None:
        print("[Compare] No data for comparison.")
        return

    dataset_full = MultiLabelThaatRagaDataset(data_dict)
    n_samples = len(dataset_full)

    # Load quantum model
    q_ckpt_path = "data/modeloutput/trained_model_thaat_raga_quantum.pt"
    if not os.path.exists(q_ckpt_path):
        print(f"[Compare] Quantum model not found at {q_ckpt_path}")
        return
    q_ckpt = torch.load(q_ckpt_path, map_location=device)

    q_model = HybridASTModelThaatRaga(
        num_thaats=q_ckpt["num_thaats"],
        num_ragas=q_ckpt["num_ragas"],
        scalar_dim=q_ckpt["scalar_dim"],
        quantum_dim=q_ckpt["quantum_dim"],
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    q_model.load_state_dict(q_ckpt["model_state"])
    q_model.to(device)
    q_model.eval()

    # Load no-quantum model
    nq_ckpt_path = "data/modeloutput/trained_model_thaat_raga_noquantum.pt"
    if not os.path.exists(nq_ckpt_path):
        print(f"[Compare] No-Quantum model not found at {nq_ckpt_path}")
        return
    nq_ckpt = torch.load(nq_ckpt_path, map_location=device)

    nq_model = HybridASTModelThaatRagaNoQuantum(
        num_thaats=nq_ckpt["num_thaats"],
        num_ragas=nq_ckpt["num_ragas"],
        scalar_dim=nq_ckpt["scalar_dim"],
        num_unfrozen_layers=HYPERPARAMS["NUM_AST_LAYERS_UNFROZEN"]
    )
    nq_model.load_state_dict(nq_ckpt["model_state"])
    nq_model.to(device)
    nq_model.eval()

    # The integer labels are only comparable to the model outputs when the
    # label maps built from the re-fetched data match the training-time maps.
    for tag, ckpt in (("Quantum", q_ckpt), ("No-Quantum", nq_ckpt)):
        if (ckpt["thaat_to_idx"] != data_dict["thaat_to_idx"]
                or ckpt["raga_to_idx"] != data_dict["raga_to_idx"]):
            print(f"[Compare] WARNING: label maps of the {tag} checkpoint differ from the current data; accuracies may be meaningless.")

    # Pick the test samples. A fresh unseeded random split would mostly land
    # on samples the models were trained on, so prefer the ids the
    # checkpoints recorded as held out.
    q_test_ids = q_ckpt.get("test_ids")
    nq_test_ids = nq_ckpt.get("test_ids")
    if q_test_ids is not None and nq_test_ids is not None:
        held_out = set(q_test_ids) & set(nq_test_ids)
        test_indices = [i for i, rid in enumerate(used_ids) if rid in held_out]
        print(f"[Compare] Evaluating on {len(test_indices)} records held out by both checkpoints.")
    else:
        print("[Compare] Checkpoints carry no 'test_ids'; using a seeded 20% split that may overlap their training data.")
        test_size = int(0.2 * n_samples)
        split_gen = torch.Generator().manual_seed(HYPERPARAMS.get("SPLIT_SEED", 42))
        _, seeded_test = torch.utils.data.random_split(
            dataset_full, [n_samples - test_size, test_size], generator=split_gen
        )
        test_indices = list(seeded_test.indices)

    if not test_indices:
        print("[Compare] No test samples available for comparison.")
        return

    # We only evaluate here; nothing is re-trained.
    test_ds = torch.utils.data.Subset(dataset_full, test_indices)
    test_dl = DataLoader(test_ds, batch_size=HYPERPARAMS["BATCH_SIZE"], shuffle=False)

    # Evaluate both
    all_true_thaats, all_true_ragas = [], []
    preds_thaat_q, preds_raga_q = [], []
    preds_thaat_nq, preds_raga_nq = [], []

    with torch.no_grad():
        for mel_img, scal, quan, t_lbl, r_lbl in test_dl:
            mel_img = mel_img.to(device)
            scal    = scal.to(device)
            quan    = quan.to(device)

            # Quantum
            logits_thaat_q, logits_raga_q = q_model(mel_img, scal, quan)
            preds_thaat_q.extend(logits_thaat_q.argmax(dim=1).cpu().numpy())
            preds_raga_q.extend(logits_raga_q.argmax(dim=1).cpu().numpy())

            # NoQuantum
            logits_thaat_nq, logits_raga_nq = nq_model(mel_img, scal)
            preds_thaat_nq.extend(logits_thaat_nq.argmax(dim=1).cpu().numpy())
            preds_raga_nq.extend(logits_raga_nq.argmax(dim=1).cpu().numpy())

            # Labels are only compared on the host, so they stay on the CPU.
            all_true_thaats.extend(t_lbl.numpy())
            all_true_ragas.extend(r_lbl.numpy())

    # Accuracy
    thaat_acc_q  = accuracy_score(all_true_thaats, preds_thaat_q)
    thaat_acc_nq = accuracy_score(all_true_thaats, preds_thaat_nq)
    raga_acc_q   = accuracy_score(all_true_ragas, preds_raga_q)
    raga_acc_nq  = accuracy_score(all_true_ragas, preds_raga_nq)

    print("\n=== Thaat-Raga: Quantum vs. No-Quantum ===")
    print(f"Thaat Accuracy (Quantum):     {thaat_acc_q:.4f}")
    print(f"Thaat Accuracy (No-Quantum):  {thaat_acc_nq:.4f}")
    print(f"Raga  Accuracy (Quantum):     {raga_acc_q:.4f}")
    print(f"Raga  Accuracy (No-Quantum):  {raga_acc_nq:.4f}")

    d_thaat = (thaat_acc_q - thaat_acc_nq) * 100
    d_raga  = (raga_acc_q - raga_acc_nq) * 100
    print(f"Difference in Thaat Accuracy = {d_thaat:.2f} % points")
    print(f"Difference in Raga Accuracy  = {d_raga:.2f} % points")


###############################################################################
# EXAMPLE MAIN
###############################################################################
if __name__ == "__main__":
    # Script entry point: train both model variants, then compare them.
    # Each training call fetches its own data and writes a checkpoint under
    # data/modeloutput/; the comparison step loads those checkpoints from disk
    # rather than using the objects returned here.

    # 1) Train with quantum
    model_q, data_q = train_thaat_raga_model(include_quantum=True)

    # 2) Train no-quantum
    model_nq, data_nq = train_thaat_raga_model(include_quantum=False)

    # 3) Compare
    compare_thaat_raga_models()