# ## Machine Learning Models
#
# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram
import librosa
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import random
from transformers import ASTModel

# ---------------------------------------------------------------------------
# Global training settings
# ---------------------------------------------------------------------------
# Directory joined with analysis_data["file_name"] to locate the audio files.
OUTPUT_DIR = "data/trainingdataoutput"
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE_MAIN = 0.00005  # optimizer LR for the AST backbone parameter group
LEARNING_RATE_HEAD = 0.0001   # optimizer LR for the fusion-head parameter group
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

###########################################################################
# Minimal DB to fetch analysis_data
###########################################################################
class QuantumMusicDBFetchOnly:
    """Minimal read-only helper that fetches ``analysis_data`` rows for ML.

    The connection is opened eagerly in ``__init__``. A failed connection is
    reported with ``print`` rather than raised (best-effort), leaving
    ``self.conn`` as ``None``. ``fetch_all_analysis_data`` then raises a clear
    ``RuntimeError`` instead of an opaque ``AttributeError`` on ``None``.

    The class can also be used as a context manager so the connection is
    always closed::

        with QuantumMusicDBFetchOnly() as db:
            rows = db.fetch_all_analysis_data()
    """

    def __init__(self, db_name="quantummusic", host="localhost", user="postgres", password="postgres"):
        # psycopg2 is imported here rather than at module level, so importing this
        # module does not require it. The module object is kept for connect().
        import psycopg2
        self.psycopg2 = psycopg2
        self.db_name = db_name
        self.host = host
        self.user = user
        # NOTE(review): credentials default to hard-coded values; consider
        # sourcing them from environment variables or a config file.
        self.password = password
        self.conn = None
        self.connect()

    def connect(self):
        """Open a connection and store it in ``self.conn``.

        Errors are printed, not raised. On failure ``self.conn`` keeps its
        previous value (``None`` on the first attempt).
        """
        try:
            self.conn = self.psycopg2.connect(
                dbname=self.db_name,
                host=self.host,
                user=self.user,
                password=self.password,
            )
            print(f"Connected to database {self.db_name}. (fetch-only)")
        except Exception as e:
            print(f"Error connecting to database: {e}")

    def close(self):
        """Close the connection if one is open; safe to call more than once."""
        if self.conn is not None:
            self.conn.close()
            # Drop the reference so a second close() is a no-op.
            self.conn = None
            print("Database connection closed.")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions from the with-body

    def fetch_all_analysis_data(self):
        """Return all rows of ``SELECT analysis_data FROM audio_analysis``.

        Each row is a 1-tuple, as returned by the cursor's ``fetchall()``.

        Raises:
            RuntimeError: if there is no open connection (e.g. connect failed).
        """
        if self.conn is None:
            raise RuntimeError(
                f"Not connected to database {self.db_name}; cannot fetch analysis_data."
            )
        with self.conn.cursor() as cur:
            cur.execute("SELECT analysis_data FROM audio_analysis")
            return cur.fetchall()


###########################################################################
# Convert quantum measurement counts -> probability vector
###########################################################################
def convert_counts_to_probs_feature(counts_dict, max_bits=5):
    """Convert a measurement-counts mapping into a fixed-length probability vector.

    Args:
        counts_dict: Mapping of bitstring -> count, e.g. ``{"01011": 12}``.
            Whitespace inside a key (e.g. register separators such as
            ``"01 011"``) is ignored. ``None`` or an empty mapping yields an
            all-zero vector.
        max_bits: Number of bits kept per bitstring; the result has
            ``2**max_bits`` entries. Longer bitstrings keep their rightmost
            ``max_bits`` characters, so distinct outcomes may share an index.
            Shorter ones are left-padded with zeros.

    Returns:
        ``np.float32`` array of shape ``(2**max_bits,)``. Its entries sum to ~1
        when any counts are present; otherwise it is all zeros.
    """
    feature_vec = np.zeros(2**max_bits, dtype=np.float32)
    if not counts_dict:
        return feature_vec
    total_counts = sum(counts_dict.values())
    if total_counts == 0:
        return feature_vec

    for bitstring, c in counts_dict.items():
        bits = "".join(str(bitstring).split())  # drop register separators
        if len(bits) > max_bits:
            # Explicit start index: bits[-0:] would keep the whole string.
            truncated = bits[len(bits) - max_bits:]
        else:
            truncated = bits.rjust(max_bits, "0")
        # An empty string (only possible when max_bits == 0) maps to index 0.
        idx = int(truncated or "0", 2)
        feature_vec[idx] += c / total_counts
    return feature_vec


###########################################################################
# Fetch Training Data
###########################################################################
def _mel_feature_fixed_length(y, mel_transform, n_frames):
    """Return a ``(1, n_mels, n_frames)`` float32 mel spectrogram of waveform ``y``.

    Longer clips are cropped and shorter ones right-padded with zeros, so every
    sample has the same shape and the batch can be stacked into one array.
    """
    mel = mel_transform(torch.as_tensor(y, dtype=torch.float32)).numpy()  # (n_mels, frames)
    frames = mel.shape[-1]
    if frames >= n_frames:
        mel = mel[:, :n_frames]
    else:
        mel = np.pad(mel, ((0, 0), (0, n_frames - frames)))
    return np.expand_dims(mel.astype(np.float32), axis=0)


def _padded_angles(angles, max_len=10):
    """Copy up to ``max_len`` angles into a zero-padded float32 vector."""
    angle_arr = np.zeros(max_len, dtype=np.float32)
    kept = list(angles or [])[:max_len]
    angle_arr[:len(kept)] = kept
    return angle_arr


def fetch_training_data(sample_rate=16000, n_frames=1024):
    """Load audio, scalar and quantum features for every DB row whose file exists.

    Args:
        sample_rate: Rate the audio is resampled to on load. The same value
            configures ``MelSpectrogram``, so the mel filterbank matches the
            signal (the original code loaded at the native rate while the
            transform assumed its 16 kHz default).
        n_frames: Fixed number of spectrogram frames per sample (crop/pad).

    Returns:
        ``(audio, scalar, quantum, labels)`` numpy arrays with shapes
        ``(N, 1, 128, n_frames)``, ``(N, 6)``, ``(N, 42)`` and ``(N,)``.
        Returns ``(None, None, None, None)`` when the DB is empty or none of
        its files exist under ``OUTPUT_DIR``.
    """
    db = QuantumMusicDBFetchOnly()
    try:
        rows = db.fetch_all_analysis_data()
    finally:
        db.close()  # release the connection even if the query fails
    if not rows:
        print("No data found in DB.")
        return None, None, None, None

    # Build the transform once instead of once per file.
    mel_transform = MelSpectrogram(sample_rate=sample_rate, n_mels=128)
    # NOTE(review): the pretrained AST checkpoint was trained on log-mel
    # filterbank features normalized by its feature extractor; these are raw
    # power mel values. Confirm whether log scaling/normalization is wanted.

    audio_feats, scalar_feats, quantum_feats, labels = [], [], [], []
    for (analysis_data,) in rows:
        fpath = os.path.join(OUTPUT_DIR, analysis_data["file_name"])
        if not os.path.exists(fpath):
            continue

        y, _ = librosa.load(fpath, sr=sample_rate)
        audio_feats.append(_mel_feature_fixed_length(y, mel_transform, n_frames))

        # `or {}` guards against JSON null sub-objects.
        res = analysis_data.get("results") or {}
        scalar_feats.append([
            res.get("average_dev_cents", 0.0),
            res.get("std_dev_cents", 0.0),
            res.get("avg_praat_hnr", 0.0),
            res.get("avg_tnr", 0.0),
            0.0, 0.0,  # placeholder slots, kept so the feature width stays 6
        ])

        angles = (analysis_data.get("quantum_analysis") or {}).get("scaled_angles")
        counts_d = (analysis_data.get("quantum_analysis_variational") or {}).get("counts") or {}
        quantum_feats.append(np.concatenate([
            _padded_angles(angles, max_len=10),                   # (10,)
            convert_counts_to_probs_feature(counts_d, max_bits=5),  # (32,)
        ]))

        # TODO: labels are random placeholders. Any model trained on them
        # learns nothing meaningful; replace with real targets.
        labels.append(random.randint(0, 1))

    if not labels:
        print("No audio files referenced by the DB were found.")
        return None, None, None, None

    return (
        np.stack(audio_feats).astype(np.float32),
        np.array(scalar_feats, dtype=np.float32),
        np.array(quantum_feats, dtype=np.float32),
        np.array(labels, dtype=np.int64),
    )


###########################################################################
# PyTorch Dataset
###########################################################################
class AudioDataset(Dataset):
    """In-memory dataset yielding ``(audio, scalar, quantum, label)`` tuples.

    Every input is copied into a tensor once, at construction. The three
    feature arrays become ``float32`` tensors and the labels become ``long``
    (int64), the dtype ``nn.CrossEntropyLoss`` expects for class targets.
    """

    def __init__(self, aud, scal, quan, labs):
        self.aud, self.scal, self.quan = (
            torch.tensor(arr, dtype=torch.float32) for arr in (aud, scal, quan)
        )
        self.labs = torch.tensor(labs, dtype=torch.long)

    def __len__(self):
        # The number of labels defines the dataset size.
        return len(self.labs)

    def __getitem__(self, idx):
        columns = (self.aud, self.scal, self.quan, self.labs)
        return tuple(column[idx] for column in columns)


###########################################################################
# Hybrid AST Model
###########################################################################
class HybridASTModel(nn.Module):
    """Classifier that fuses a pretrained Audio Spectrogram Transformer with two MLPs.

    The model concatenates three embeddings and feeds them to one linear
    layer that produces ``output_dim`` logits:

    - the AST ``pooler_output`` embedding,
    - a 64-d embedding of the scalar features,
    - a 64-d embedding of the quantum features.

    Args:
        scalar_dim: Width of the scalar feature vector.
        quantum_dim: Width of the quantum feature vector.
        output_dim: Number of output logits.
        freeze_ast: If True, the AST backbone parameters get ``requires_grad=False``.
    """

    def __init__(self, scalar_dim, quantum_dim, output_dim=2, freeze_ast=True):
        super().__init__()
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

        if freeze_ast:
            for p in self.ast.parameters():
                p.requires_grad = False

        self.scalar_fc = nn.Sequential(
            nn.Linear(scalar_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.quantum_fc = nn.Sequential(
            nn.Linear(quantum_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )

        # Take the pooled-embedding width from the backbone config (768 for
        # this checkpoint) instead of hard-coding it.
        self.final_fc = nn.Linear(self.ast.config.hidden_size + 64 + 64, output_dim)

    @staticmethod
    def _to_ast_input(audio_input, max_length, num_mel_bins):
        """Convert a spectrogram batch to the ``(B, max_length, num_mel_bins)`` AST layout.

        Accepts ``(B, 1, n_mels, frames)``, the layout built by
        ``fetch_training_data``, or an already time-major
        ``(B, frames, n_mels)`` tensor. The time axis is cropped, or
        zero-padded at the end, to exactly ``max_length`` frames, because the
        AST position embeddings are sized for that length.

        Raises:
            ValueError: on an unsupported rank, channel count or mel-bin count.
        """
        x = audio_input
        if x.dim() == 4:
            if x.shape[1] != 1:
                raise ValueError(f"expected a single channel, got shape {tuple(x.shape)}")
            x = x[:, 0].transpose(1, 2)  # (B, n_mels, frames) -> (B, frames, n_mels)
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D or 4-D spectrogram batch, got shape {tuple(x.shape)}")
        if x.shape[-1] != num_mel_bins:
            raise ValueError(f"expected {num_mel_bins} mel bins on the last axis, got shape {tuple(x.shape)}")
        frames = x.shape[1]
        if frames > max_length:
            x = x[:, :max_length, :]
        elif frames < max_length:
            # F.pad pads the last dims first: (mel_left, mel_right, time_left, time_right).
            x = nn.functional.pad(x, (0, 0, 0, max_length - frames))
        return x

    def forward(self, audio_input, scalar_input, quantum_input):
        cfg = self.ast.config
        ast_in = self._to_ast_input(audio_input, cfg.max_length, cfg.num_mel_bins)
        ast_out = self.ast(ast_in).pooler_output
        s = self.scalar_fc(scalar_input)
        q = self.quantum_fc(quantum_input)
        fused = torch.cat([ast_out, s, q], dim=1)
        return self.final_fc(fused)


###########################################################################
# Training Pipeline
###########################################################################
def _split_and_scale(audio_features, scalar_features, quantum_features, labels):
    """Split 80/20 (random_state=42), then standardize tabular features on train stats only.

    Each ``StandardScaler`` is fit on the training split alone and then applied
    unchanged to the test split, so no test-set statistics leak into training.
    Returns ``(train, test)``, each an ``(audio, scalar, quantum, labels)`` tuple.
    """
    aud_tr, aud_te, sc_tr, sc_te, q_tr, q_te, y_tr, y_te = train_test_split(
        audio_features, scalar_features, quantum_features, labels,
        test_size=0.2, random_state=42,
    )
    scalar_scaler = StandardScaler().fit(sc_tr)
    quantum_scaler = StandardScaler().fit(q_tr)
    train = (aud_tr, scalar_scaler.transform(sc_tr), quantum_scaler.transform(q_tr), y_tr)
    test = (aud_te, scalar_scaler.transform(sc_te), quantum_scaler.transform(q_te), y_te)
    return train, test


def _batch_to_device(aud, sc, qu, lab):
    """Move one batch of tensors to ``DEVICE``."""
    return aud.to(DEVICE), sc.to(DEVICE), qu.to(DEVICE), lab.to(DEVICE)


def _train_one_epoch(model, loader, optimizer, crit):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    running = 0.0
    for batch in loader:
        aud, sc, qu, lab = _batch_to_device(*batch)
        optimizer.zero_grad()
        loss = crit(model(aud, sc, qu), lab)
        loss.backward()
        optimizer.step()
        running += loss.item()
    return running / max(len(loader), 1)


def _eval_loss(model, loader, crit):
    """Return the mean batch loss over ``loader`` in eval mode, without gradients."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for batch in loader:
            aud, sc, qu, lab = _batch_to_device(*batch)
            running += crit(model(aud, sc, qu), lab).item()
    return running / max(len(loader), 1)


def _predict(model, loader):
    """Return ``(predicted class ids, true labels)`` for every sample in ``loader``."""
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for aud, sc, qu, lab in loader:
            out = model(aud.to(DEVICE), sc.to(DEVICE), qu.to(DEVICE))
            preds.extend(out.argmax(dim=1).cpu().numpy())
            labels.extend(lab.numpy())
    return preds, labels


def _plot_results(cm, train_losses, test_losses):
    """Show the confusion matrix and the per-epoch train/test loss curves."""
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()

    plt.figure()
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Curves")
    plt.show()


def train_model():
    """Train ``HybridASTModel`` on DB data, plot diagnostics and return the model.

    Returns:
        The trained model, or ``None`` when there is no data to train on.
    """
    data = fetch_training_data()
    if data is None or data[0] is None:
        print("No data to train.")
        return None
    audio_features, scalar_features, quantum_features, labels = data

    (Xaud_tr, Xs_tr, Xq_tr, y_tr), (Xaud_te, Xs_te, Xq_te, y_te) = _split_and_scale(
        audio_features, scalar_features, quantum_features, labels
    )

    train_dl = DataLoader(AudioDataset(Xaud_tr, Xs_tr, Xq_tr, y_tr), batch_size=BATCH_SIZE, shuffle=True)
    test_dl = DataLoader(AudioDataset(Xaud_te, Xs_te, Xq_te, y_te), batch_size=BATCH_SIZE, shuffle=False)

    model = HybridASTModel(
        scalar_dim=scalar_features.shape[1],
        quantum_dim=quantum_features.shape[1],
        output_dim=2,
        freeze_ast=False,
    ).to(DEVICE)

    # Backbone and head get separate learning rates. Frozen backbone params,
    # if any, are left out of the optimizer.
    ast_params = [p for n, p in model.named_parameters() if n.startswith("ast.") and p.requires_grad]
    head_params = [p for n, p in model.named_parameters() if not n.startswith("ast.")]
    param_groups = [{"params": head_params, "lr": LEARNING_RATE_HEAD}]
    if ast_params:
        param_groups.insert(0, {"params": ast_params, "lr": LEARNING_RATE_MAIN})
    optimizer = optim.Adam(param_groups)
    crit = nn.CrossEntropyLoss()

    train_losses, test_losses = [], []
    for epoch in range(EPOCHS):
        avg_tr_loss = _train_one_epoch(model, train_dl, optimizer, crit)
        avg_te_loss = _eval_loss(model, test_dl, crit)
        train_losses.append(avg_tr_loss)
        test_losses.append(avg_te_loss)
        print(f"Epoch {epoch+1}/{EPOCHS}, train_loss={avg_tr_loss:.4f}, test_loss={avg_te_loss:.4f}")

    print("Training complete.")

    all_preds, all_labels = _predict(model, test_dl)
    # Fix the label set so the matrix is always 2x2 and matches display_labels,
    # even if the test split contains only one class.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    _plot_results(cm, train_losses, test_losses)

    return model

# Example usage if you want to run it directly:
# if __name__ == "__main__":
#     trained_model = train_model()


import re

# Leading letters = raga name, followed by a single quality digit 1-5.
_RAGA_QUALITY_RE = re.compile(r"^([A-Za-z]+)([1-5])")


def parse_raga_and_quality(fname: str):
    """Parse the raga name and quality score from an audio file name.

    The expected pattern is ``<Raga><quality digit 1-5><anything>``, for example
    ``'Bhairavi5VenkateshKumar_500s_520s_Unknown.wav'`` -> ``('Bhairavi', 5)``.
    Leading directory components are ignored, so full paths are accepted.
    Everything after the digit, including any extension, is ignored.

    Returns:
        ``(raga, quality)`` as ``(str, int)``, or ``(None, None)`` when the
        name is empty/``None`` or does not follow the pattern.
    """
    if not fname:
        return None, None
    match = _RAGA_QUALITY_RE.match(os.path.basename(fname))
    if match is None:
        return None, None
    return match.group(1), int(match.group(2))



import os
from PIL import Image
import torch
import torchvision.transforms as T

# Shared preprocessing for the precomputed plot images: resize to a fixed
# 224x224 and convert to a tensor.
transform_img = T.Compose(
    [T.Resize(size=(224, 224)), T.ToTensor()]
)

def load_image_as_tensor(image_path: str):
    """Load an image as RGB and preprocess it with ``transform_img``.

    Returns:
        The transformed tensor, or ``None`` when ``image_path`` does not exist,
        so callers can substitute a placeholder.
    """
    if not os.path.exists(image_path):
        return None
    # The context manager closes the file handle promptly instead of leaving it
    # for garbage collection. convert() loads the pixel data into a new image,
    # so it remains valid after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform_img(rgb)



def fetch_training_data_v2(image_dir=os.path.join("data", "analysisoutput")):
    """Build image, scalar, quantum and label data for every DB row with a parseable name.

    Rows whose ``file_name`` does not parse with ``parse_raga_and_quality``
    are skipped. A missing plot image is replaced by a ``(3, 224, 224)``
    zero tensor.

    Args:
        image_dir: Directory holding the precomputed ``<base>_mfcc.png`` and
            ``<base>_log_spectrogram.png`` images.

    Returns:
        A dict with stacked image tensors, ``(N, 10)`` scalar features,
        ``(N, 42)`` quantum features, string/int raga labels, the raw quality
        labels (1..5) and the raga-to-index mapping. Returns ``None`` when the
        DB is empty or no row could be parsed.
    """
    db = QuantumMusicDBFetchOnly()
    try:
        rows = db.fetch_all_analysis_data()
    finally:
        db.close()  # release the connection even if the query fails
    if not rows:
        print("No data found in DB.")
        return None

    def _sub(mapping, key):
        # JSON values may be null or a non-object; treat those as empty dicts.
        value = mapping.get(key)
        return value if isinstance(value, dict) else {}

    images_mfcc, images_log = [], []
    # Jitter/shimmer/formant/vibrato images could be gathered the same way.
    scalar_feats, quantum_feats = [], []
    labels_raga, labels_quality = [], []

    for (analysis_data,) in rows:
        wav_fname = analysis_data["file_name"]  # e.g. 'Bhairavi5VenkateshKumar_500s_520s_Unknown.wav'
        # Strip only a trailing ".wav"; str.replace would also hit mid-name matches.
        base_no_ext = wav_fname[:-len(".wav")] if wav_fname.endswith(".wav") else wav_fname
        base_no_ext = os.path.basename(base_no_ext)

        # 1) Labels parsed from the file name.
        raga, quality = parse_raga_and_quality(wav_fname)
        if raga is None or quality is None:
            continue
        labels_raga.append(raga)
        labels_quality.append(quality)

        # 2) Precomputed images, with zero placeholders when missing.
        mfcc_img = load_image_as_tensor(os.path.join(image_dir, f"{base_no_ext}_mfcc.png"))
        log_img = load_image_as_tensor(os.path.join(image_dir, f"{base_no_ext}_log_spectrogram.png"))
        images_mfcc.append(mfcc_img if mfcc_img is not None else torch.zeros((3, 224, 224)))
        images_log.append(log_img if log_img is not None else torch.zeros((3, 224, 224)))

        # 3) Scalar features from results, dynamics_summary and advanced stats.
        res = _sub(analysis_data, "results")
        dyn = _sub(analysis_data, "dynamics_summary")
        quantum_dict = _sub(analysis_data, "quantum_analysis")
        # NOTE(review): these stats are described as "advanced_vocal_stats", but
        # the key read is quantum_analysis.advanced_stats. Confirm against the writer.
        adv = _sub(quantum_dict, "advanced_stats")
        scalar_feats.append([
            res.get("average_dev_cents", 0.0),
            res.get("std_dev_cents", 0.0),
            res.get("avg_praat_hnr", 0.0),
            res.get("avg_tnr", 0.0),
            _sub(dyn, "rms_db").get("mean", 0.0),
            _sub(dyn, "lufs").get("mean", 0.0),
            adv.get("avg_jitter", 0.0),
            adv.get("avg_shimmer", 0.0),
            adv.get("avg_vibrato_rate", 0.0),
            adv.get("avg_F1", 0.0),  # first formant only
        ])

        # 4) Quantum features: up to 10 angles plus a 32-bin count distribution.
        angle_arr = np.zeros(10, dtype=np.float32)
        angles = list(quantum_dict.get("scaled_angles") or [])[:10]
        angle_arr[:len(angles)] = angles
        counts_d = quantum_dict.get("measurement_counts") or {}
        dist_vec = convert_counts_to_probs_feature(counts_d, max_bits=5)  # shape (32,)
        quantum_feats.append(np.concatenate([angle_arr, dist_vec], axis=0))  # shape (42,)

    if not labels_raga:
        # torch.stack raises on an empty list, so bail out explicitly.
        print("No rows with a parseable raga/quality file name.")
        return None

    # Map each unique raga name to a stable integer index (alphabetical order).
    raga_to_idx = {r: i for i, r in enumerate(sorted(set(labels_raga)))}

    return {
        "images_mfcc": torch.stack(images_mfcc),  # shape (N, 3, 224, 224)
        "images_log": torch.stack(images_log),
        "scalar_feats": np.array(scalar_feats, dtype=np.float32),
        "quantum_feats": np.array(quantum_feats, dtype=np.float32),
        "label_raga_str": labels_raga,      # string labels
        "label_raga_idx": np.array([raga_to_idx[r] for r in labels_raga], dtype=np.int64),
        # Quality stays 1..5; shift to 0..4 before using it as a CrossEntropy target.
        "label_quality": np.array(labels_quality, dtype=np.int64),
        "raga_to_idx": raga_to_idx,
    }