<a href="https://colab.research.google.com/github/mdsenelen/Melody-Generation-VAE-for-Musician-s-Assistant/blob/main/vae_melody_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
# import torch
# torch.cuda.is_available(), torch.cuda.get_device_name(0)

In [16]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))


2.6.0+cu124
True
Tesla T4


In [17]:
# !nvcc --version


In [18]:
# !pip install torch==2.0.0+cu118 torchvision --extra-index-url https://download.pytorch.org/whl/cu118


In [19]:
# !pip uninstall nvidia-cudnn-cu12


In [20]:
# !pip uninstall -y torch torchvision torchaudio
# !pip install torch==2.0.1+cu118 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118


In [30]:
# -*- coding: utf-8 -*-
"""Musician's Assistant VAE Trainer with Thesis Visualizations

Combined notebook for:
1. Web-optimized VAE training with consistent 128x256 spectrogram outputs
2. Comprehensive visualizations for academic research
"""

# 1. IMPORT REQUIRED LIBRARIES
!pip install torchaudio librosa numpy matplotlib plotly scikit-learn soundfile tensorboard seaborn umap-learn fpdf

import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, Audio, display
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import librosa
import librosa.display
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
from tqdm import tqdm, trange
import zipfile
import csv
from datetime import datetime
from google.colab import drive
from fpdf import FPDF
import glob

# 2. CONFIGURATION (Web-App Compatible)
CONFIG = {
    "audio": {
        "sample_rate": 22050,
        "n_fft": 2048,
        "hop_length": 512,
        "win_length": 1024,
        "n_mels": 128,
        "fmin": 30,
        "fmax": 8000,
        "max_frames": 256
    },
    "model": {
        "latent_dim": 256,
        "input_shape": [1, 128, 256],
        "batch_size": 32,
        "init_lr": 3e-4,
        "num_epochs": 100
    }
}

# 3. FIXED MODEL ARCHITECTURE
class WebVAE(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder with fixed output dimensions
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)

        # Fixed latent space
        self.fc_mu = nn.Linear(256*16*32, CONFIG['model']['latent_dim'])
        self.fc_logvar = nn.Linear(256*16*32, CONFIG['model']['latent_dim'])

        # Decoder with fixed output (128x256)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 1, kernel_size=(5,4), stride=(1,2), padding=(2,1)),
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        z = z.view(-1, 256, 16, 32)
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# 4. DATA LOADING AND PROCESSING
class MelSpectrogramDataset(Dataset):
    def __init__(self, mel_array):
        self.mels = mel_array

    def __len__(self):
        return len(self.mels)

    def __getitem__(self, idx):
        mel = self.mels[idx]
        return torch.tensor(mel, dtype=torch.float32).unsqueeze(0)

def load_dataset(data_path):
    print("📁 Loading and processing audio files...")
    mel_list = []

    wav_files = [f for f in os.listdir(data_path) if f.lower().endswith('.wav')]

    for file in tqdm(wav_files, desc="Processing Audio Files"):
        try:
            filepath = os.path.join(data_path, file)
            y, _ = librosa.load(filepath, sr=CONFIG['audio']['sample_rate'])
            y = librosa.util.normalize(y) * 0.707  # -3dB normalizasyon

            mel = librosa.feature.melspectrogram(
                y=y, sr=CONFIG['audio']['sample_rate'],
                n_mels=CONFIG['audio']['n_mels'],
                n_fft=CONFIG['audio']['n_fft'],
                hop_length=CONFIG['audio']['hop_length'],
                win_length=CONFIG['audio']['win_length'],
                fmin=CONFIG['audio']['fmin'],
                fmax=CONFIG['audio']['fmax'],
                power=1.0)

            mel_db = librosa.amplitude_to_db(mel, ref=np.max, top_db=80)
            mel_db = np.clip(mel_db, -40, 0)
            mel_norm = (mel_db + 40) / 40

            if mel_norm.shape[1] < CONFIG['audio']['max_frames']:
                pad_width = CONFIG['audio']['max_frames'] - mel_norm.shape[1]
                mel_norm = np.pad(mel_norm, ((0, 0), (0, pad_width)),
                                 mode='constant', constant_values=-40/40+1)
            else:
                mel_norm = mel_norm[:, :CONFIG['audio']['max_frames']]

            mel_list.append(mel_norm)

        except Exception as e:
            print(f"\n⚠️ Error processing {file}: {str(e)}")

    return np.array(mel_list)

# 5. VISUALIZATION TOOLS
class Visualizer:
    def __init__(self, log_dir):
        self.log_dir = log_dir
        os.makedirs(os.path.join(log_dir, "visualizations"), exist_ok=True)

    def plot_training_curves(self, train_losses, val_losses, metrics):
        plt.figure(figsize=(18, 12))

        plt.subplot(2, 2, 1)
        plt.plot(train_losses, label='Train Loss', alpha=0.8)
        plt.plot(val_losses, label='Val Loss', alpha=0.8)
        plt.title('Training and Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.subplot(2, 2, 2)
        plt.plot(metrics['psnr'], label='PSNR', color='green')
        plt.title('PSNR Over Time')
        plt.xlabel('Epoch')
        plt.ylabel('PSNR (dB)')
        plt.grid(True)

        plt.subplot(2, 2, 3)
        plt.plot(metrics['ssim'], label='SSIM', color='purple')
        plt.title('SSIM Over Time')
        plt.xlabel('Epoch')
        plt.ylabel('SSIM')
        plt.grid(True)

        plt.subplot(2, 2, 4)
        ax = plt.gca()
        ax.plot(metrics['psnr'], color='green', label='PSNR')
        ax.set_ylabel('PSNR (dB)', color='green')
        ax2 = ax.twinx()
        ax2.plot(metrics['ssim'], color='purple', label='SSIM')
        ax2.set_ylabel('SSIM', color='purple')
        plt.title('Quality Metrics Correlation')
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, "visualizations", "training_metrics.png"))
        plt.close()

    def plot_latent_space(self, latent_vectors):
        pca = PCA(n_components=3).fit_transform(latent_vectors)
        tsne = TSNE(n_components=3).fit_transform(latent_vectors)
        umap_emb = umap.UMAP(n_components=3).fit_transform(latent_vectors)

        fig = make_subplots(rows=1, cols=3,
                           specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]],
                           subplot_titles=('PCA', 't-SNE', 'UMAP'))

        for i, (emb, name) in enumerate(zip([pca, tsne, umap_emb], ['PCA', 't-SNE', 'UMAP'])):
            fig.add_trace(
                go.Scatter3d(
                    x=emb[:, 0], y=emb[:, 1], z=emb[:, 2],
                    mode='markers',
                    marker=dict(size=4, opacity=0.8),
                    name=name
                ),
                row=1, col=i+1
            )

        fig.update_layout(height=600, width=1800, title_text="3D Latent Space Projections")
        fig.write_html(os.path.join(self.log_dir, "visualizations", "3d_latent.html"))

        plt.figure(figsize=(18, 5))
        for i, (emb, name) in enumerate(zip([pca[:,:2], tsne[:,:2], umap_emb[:,:2]], ['PCA', 't-SNE', 'UMAP'])):
            plt.subplot(1, 3, i+1)
            sns.kdeplot(x=emb[:, 0], y=emb[:, 1], cmap="viridis", fill=True, thresh=0.1)
            plt.title(f'{name} Projection')
        plt.savefig(os.path.join(self.log_dir, "visualizations", "2d_latent.png"))
        plt.close()

    def plot_spectrogram_comparison(self, original, reconstructed, epoch=None):
        fig = plt.figure(figsize=(20, 10))

        plt.subplot(3, 1, 1)
        librosa.display.specshow(original, sr=CONFIG['audio']['sample_rate'],
                               hop_length=CONFIG['audio']['hop_length'],
                               x_axis='time', y_axis='mel', cmap='magma')
        plt.colorbar(format='%+2.0f dB')
        plt.title('Original Spectrogram')

        plt.subplot(3, 1, 2)
        librosa.display.specshow(reconstructed, sr=CONFIG['audio']['sample_rate'],
                               hop_length=CONFIG['audio']['hop_length'],
                               x_axis='time', y_axis='mel', cmap='magma')
        plt.colorbar(format='%+2.0f dB')
        plt.title('Reconstructed Spectrogram')

        plt.subplot(3, 1, 3)
        diff = np.abs(original - reconstructed)
        librosa.display.specshow(diff, sr=CONFIG['audio']['sample_rate'],
                               hop_length=CONFIG['audio']['hop_length'],
                               x_axis='time', y_axis='mel', cmap='coolwarm')
        plt.colorbar(format='%+2.0f dB')
        plt.title('Absolute Difference')

        plt.tight_layout()
        fname = f"spectrogram_comparison{'_epoch'+str(epoch) if epoch else ''}.png"
        plt.savefig(os.path.join(self.log_dir, "visualizations", fname))
        plt.close()

    def create_spectrogram_animation(self, spectrograms, filename="spectrogram_evolution.gif"):
        fig = plt.figure(figsize=(10, 6))
        ax = plt.gca()

        def update(frame):
            ax.clear()
            librosa.display.specshow(spectrograms[frame],
                                   sr=CONFIG['audio']['sample_rate'],
                                   hop_length=CONFIG['audio']['hop_length'],
                                   x_axis='time', y_axis='mel', cmap='magma')
            plt.title(f'Epoch {frame+1}')

        anim = FuncAnimation(fig, update, frames=len(spectrograms), interval=500)
        anim.save(os.path.join(self.log_dir, "visualizations", filename), writer='pillow')
        return HTML(anim.to_jshtml())

# 6. TRAINING AND EVALUATION
def compute_metrics(target, recon):
    target_np = target.squeeze().cpu().numpy()
    recon_np = recon.squeeze().cpu().numpy()
    psnr_score = psnr(target_np, recon_np, data_range=1.0)
    ssim_score = ssim(target_np, recon_np, data_range=1.0, channel_axis=0)
    return psnr_score, ssim_score

def spectral_convergence_loss(input, target):
    return torch.norm(target - input, p='fro') / (torch.norm(target, p='fro') + 1e-6)

def spectral_magnitude_loss(input, target):
    return F.l1_loss(torch.log1p(input), torch.log1p(target))

def compute_loss(recon, target, mu, logvar):
    mse_loss = F.mse_loss(recon, target)
    sc_loss = spectral_convergence_loss(recon, target)
    sm_loss = spectral_magnitude_loss(recon, target)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return {
        'loss': mse_loss + 0.5*sc_loss + 0.5*sm_loss + 0.001*kl_loss,
        'mse': mse_loss,
        'sc': sc_loss,
        'sm': sm_loss,
        'kl': kl_loss
    }

# 7. MAIN TRAINING FUNCTION
def train_for_web_and_thesis(data_path, output_dir):
    # Initialize
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load data
    mels = load_dataset(data_path)
    split_idx = int(0.8 * len(mels))
    train_dataset = MelSpectrogramDataset(mels[:split_idx])
    val_dataset = MelSpectrogramDataset(mels[split_idx:])

    train_loader = DataLoader(train_dataset,
                            batch_size=CONFIG['model']['batch_size'],
                            shuffle=True)
    val_loader = DataLoader(val_dataset,
                          batch_size=CONFIG['model']['batch_size'])

    # Model and optimizer
    model = WebVAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['model']['init_lr'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    # Setup logging
    log_dir = os.path.join(output_dir, datetime.now().strftime("%Y%m%d_%H%M%S"))
    os.makedirs(log_dir, exist_ok=True)
    visualizer = Visualizer(log_dir)

    # Training loop
    train_losses = []
    val_losses = []
    val_psnrs = []
    val_ssims = []
    best_samples = []

    for epoch in trange(CONFIG['model']['num_epochs'], desc="Training"):
        model.train()
        epoch_train_losses = []

        for batch in train_loader:
            batch = batch.to(device)
            if batch.shape[-2:] != (128, 256):
                batch = F.interpolate(batch, size=(128,256))

            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss_dict = compute_loss(recon, batch, mu, logvar)
            loss_dict['loss'].backward()
            optimizer.step()

            epoch_train_losses.append(loss_dict['loss'].item())

        # Validation
        model.eval()
        epoch_val_losses = []
        epoch_val_recons = []
        epoch_val_targets = []
        latent_vectors = []

        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                if batch.shape[-2:] != (128, 256):
                    batch = F.interpolate(batch, size=(128,256))

                recon, mu, logvar = model(batch)
                loss_dict = compute_loss(recon, batch, mu, logvar)
                epoch_val_losses.append(loss_dict['loss'].item())
                epoch_val_recons.append(recon.cpu())
                epoch_val_targets.append(batch.cpu())
                latent_vectors.append(mu.cpu().numpy())

        # Calculate metrics
        avg_train_loss = np.mean(epoch_train_losses)
        avg_val_loss = np.mean(epoch_val_losses)
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        if epoch_val_recons:
            recon_all = torch.cat(epoch_val_recons)
            target_all = torch.cat(epoch_val_targets)
            psnr_score, ssim_score = compute_metrics(target_all, recon_all)
            val_psnrs.append(psnr_score)
            val_ssims.append(ssim_score)

            # Store best sample
            best_idx = np.argmin([F.mse_loss(r, t) for r, t in zip(epoch_val_recons, epoch_val_targets)])
            best_samples.append(epoch_val_recons[best_idx][0].squeeze().numpy())

        # Visualizations
        if epoch % 5 == 0 and epoch_val_recons:
            sample_idx = 0
            original = epoch_val_targets[0][sample_idx].squeeze().numpy()
            recon = epoch_val_recons[0][sample_idx].squeeze().numpy()
            visualizer.plot_spectrogram_comparison(original, recon, epoch)

            if latent_vectors:
                visualizer.plot_latent_space(np.concatenate(latent_vectors))

        scheduler.step(avg_val_loss)

    # Final visualizations
    visualizer.create_spectrogram_animation(best_samples)
    visualizer.plot_training_curves(
        train_losses,
        val_losses,
        {'psnr': val_psnrs, 'ssim': val_ssims}
    )

    # Save final model and artifacts
    torch.save({
        'state_dict': model.state_dict(),
        'config': CONFIG,
        'class_name': 'WebVAE'
    }, os.path.join(log_dir, "web_model.pt"))

    with open(os.path.join(log_dir, "web_audio_params.json"), "w") as f:
        json.dump(CONFIG['audio'], f)

    # Generate PDF report
    generate_report(log_dir)

    print(f"✅ Training completed! Results saved to: {log_dir}")

def generate_report(log_dir):
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)
    pdf.add_page()
    pdf.set_font("Arial", 'B', 16)
    pdf.cell(0, 10, "VAE Melody Generation - Thesis Results", ln=1, align='C')

    vis_dir = os.path.join(log_dir, "visualizations")
    images = glob.glob(os.path.join(vis_dir, "*.png")) + glob.glob(os.path.join(vis_dir, "*.jpg"))

    for img in sorted(images):
        pdf.add_page()
        pdf.set_font("Arial", size=12)
        pdf.cell(0, 10, os.path.basename(img).replace('_', ' ').replace('.png', ''), ln=1)
        pdf.image(img, x=10, y=30, w=180)

    pdf.output(os.path.join(log_dir, "thesis_results_report.pdf"))

# 8. WEB-COMPATIBLE AUDIO RECONSTRUCTION
def web_reconstruct(mel):
    """Optimized for your web app backend"""
    mel = (mel * 40) - 40  # Denormalize
    mel = librosa.db_to_amplitude(mel)
    audio = librosa.griffinlim(
        librosa.feature.inverse.mel_to_stft(mel),
        n_iter=32,  # Faster for web
        hop_length=CONFIG['audio']['hop_length'],
        win_length=CONFIG['audio']['win_length']
    )
    return audio

# 9. RUN TRAINING
if __name__ == "__main__":
    # Mount Google Drive if needed
    drive.mount('/content/drive')

    # Set your paths
    DATA_DIR = "/content/drive/MyDrive/vae_project/recordings"
    OUTPUT_DIR = "/content/drive/MyDrive/vae_project/logs"

    # Start training
    train_for_web_and_thesis(DATA_DIR, OUTPUT_DIR)

SyntaxError: '(' was never closed (ipython-input-30-3919162377.py, line 66)

In [34]:
# -*- coding: utf-8 -*-
"""vae_melody_training.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1ufRp8OqWN5n9VRDYzxSC6uxmJv6KigAS
"""

# 1. REQUIRED LIBRARIES
!pip install torchaudio librosa numpy matplotlib plotly scikit-learn soundfile

from torch.utils.data import Dataset, DataLoader
from IPython.display import Audio, display
import plotly.express as px
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import librosa.display # Added import for librosa.display
import unittest
from tqdm import tqdm, trange
from google.colab import drive
from datetime import datetime
import zipfile
import soundfile as sf
import librosa
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch
import torchaudio  # Added missing import
import csv
import os

# Configuration
USE_DRIVE = True
DATA_DIR = "/content/drive/MyDrive/vae_project/recordings"
OUTPUT_DIR = "/content/drive/MyDrive/vae_project/logs"

# Mount Google Drive
if USE_DRIVE:
    drive.mount('/content/drive')

# CALLBACK BASE
class Callback:
    def on_epoch_end(self, epoch, logs):
        pass

class EarlyStopping(Callback):
    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs):
        val_loss = logs['val_loss']
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stopped_epoch = epoch
                logs['stop_training'] = True

class ModelCheckpoint(Callback):
    def __init__(self, model, filepath):
        self.model = model
        self.filepath = filepath
        self.best_loss = float('inf')

    def on_epoch_end(self, epoch, logs):
        val_loss = logs['val_loss']
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            torch.save(self.model.state_dict(), self.filepath)
            print(f"📌 Best model saved (val_loss={val_loss:.4f})")

class CSVLogger(Callback):
    def __init__(self, filename):
        self.filename = filename
        with open(self.filename, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch', 'train_loss', 'val_loss'])

    def on_epoch_end(self, epoch, logs):
        with open(self.filename, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch + 1, logs['train_loss'], logs['val_loss']])

def load_dataset(data_path, sr=22050, n_mels=64, max_len=256):
    """
    Loads audio files and converts them to normalized, fixed-size mel spectrograms.
    """
    print("📁 Loading data...")
    mel_list = []

    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data path not found: {data_path}")

    wav_files = [f for f in os.listdir(data_path) if f.lower().endswith('.wav')]
    if not wav_files:
        raise ValueError("No .wav files found in the specified directory")

    for file in tqdm(wav_files, desc="Processing Audio Files"):
        try:
            filepath = os.path.join(data_path, file)
            y, _ = librosa.load(filepath, sr=sr)
            mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
            mel_db = librosa.power_to_db(mel, ref=np.max)

            # Normalize to [0, 1]
            mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min())

            # Pad or truncate to fixed time steps
            if mel_db.shape[1] < max_len:
                pad_width = max_len - mel_db.shape[1]
                mel_db = np.pad(mel_db, ((0, 0), (0, pad_width)), mode='constant')
            else:
                mel_db = mel_db[:, :max_len]

            mel_list.append(mel_db)
        except Exception as e:
            print(f"\n⚠️ Error processing {file}: {str(e)}")

    if not mel_list:
        raise ValueError("No audio files could be processed")

    return np.array(mel_list)  # shape: (N, 64, max_len)

class MelSpectrogramDataset(Dataset):
    def __init__(self, mel_array):
        self.mels = mel_array

    def __len__(self):
        return len(self.mels)

    def __getitem__(self, idx):
        mel = self.mels[idx]
        mel_tensor = torch.tensor(mel, dtype=torch.float32).unsqueeze(0)  # shape: (1, 64, time)
        return mel_tensor

class EnhancedVAE(nn.Module):
    def __init__(self, input_shape=(1, 64, 256), latent_dim=128):
        super().__init__()
        self.input_shape = input_shape
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # → (32, H/2, W/2)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # → (64, H/4, W/4)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # → (128, H/8, W/8)
            nn.ReLU(),
        )

        # Compute conv output shape dynamically
        self._conv_output_shape = self._get_conv_output_shape()
        conv_output_dim = int(np.prod(self._conv_output_shape))

        # Latent space
        self.fc_mu = nn.Linear(conv_output_dim, latent_dim)
        self.fc_logvar = nn.Linear(conv_output_dim, latent_dim)

        # Decoder input and reshape
        self.decoder_input = nn.Linear(latent_dim, conv_output_dim)

        # Decoder (mirror of encoder)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # ×2
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # ×2
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),    # ×2
            nn.Sigmoid()
        )

    def _get_conv_output_shape(self):
        with torch.no_grad():
            dummy_input = torch.zeros(1, *self.input_shape)
            output = self.encoder(dummy_input)
            return output.shape[1:]  # (channels, H', W')

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(-1, *self._conv_output_shape)
        return self.decoder(x)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

def main():
    # 4. MODEL AND DATASET PREPARATION
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data loading
    try:
        mels = load_dataset(DATA_DIR, max_len=256)
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return

    # Split and wrap with Dataset
    split_idx = int(0.8 * len(mels))
    train_dataset = MelSpectrogramDataset(mels[:split_idx])
    val_dataset = MelSpectrogramDataset(mels[split_idx:])

    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Dynamically infer input shape from one sample
    example_input = next(iter(train_loader))
    input_shape = tuple(example_input.shape[1:])  # (1, 64, 256)

    # Model
    model = EnhancedVAE(input_shape=input_shape).to(device)
    print(f"Model initialized with input shape: {input_shape}")

    num_epochs = 10

    # 5. TRAINING LOOP
    train_losses = []
    val_losses = []
    val_psnrs = []
    val_ssims = []
    latent_vectors = []

    # Initialize callbacks
    log_dir = os.path.join(OUTPUT_DIR, datetime.now().strftime("%Y%m%d_%H%M%S"))
    os.makedirs(log_dir, exist_ok=True)
    model_filepath = os.path.join(log_dir, "best_model.pt")
    csv_filepath = os.path.join(log_dir, "training_log.csv")
    latent_path = os.path.join(log_dir, "latent_vectors.npy")

    callbacks = [
        EarlyStopping(patience=3),
        ModelCheckpoint(model, model_filepath),
        CSVLogger(csv_filepath)
    ]

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    def smooth_curve(points, factor=0.8):
        smoothed_points = []
        for point in points:
            if smoothed_points:
                previous = smoothed_points[-1]
                smoothed_points.append(previous * factor + point * (1 - factor))
            else:
                smoothed_points.append(point)
        return smoothed_points

    def compute_metrics(target, recon):
        target_np = target.squeeze().cpu().numpy()
        recon_np = recon.squeeze().cpu().numpy()

        psnr_score = 0
        ssim_score = 0
        count = 0
        for i in range(target_np.shape[0]):
            img1 = target_np[i]
            img2 = recon_np[i]
            max_val = 1.0
            psnr_score += psnr(img1, img2, data_range=max_val)
            ssim_score += ssim(img1, img2, data_range=max_val)
            count += 1

        return psnr_score / count, ssim_score / count

    # Training loop
    if len(train_dataset) > 0:
        for epoch in trange(num_epochs, desc="Training Epochs"):
            model.train()
            train_loss = 0
            num_batches = 0

            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()

                recon, mu, logvar = model(batch)
                loss = F.mse_loss(recon, batch) + 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                num_batches += 1

            avg_train_loss = train_loss / max(1, num_batches)
            train_losses.append(avg_train_loss)

            # Validation
            model.eval()
            val_loss_accum = 0
            val_batches = 0
            val_recons = []
            val_targets = []
            val_mus = []

            with torch.no_grad():
                for val_batch in val_loader:
                    val_batch = val_batch.to(device)
                    recon, mu, logvar = model(val_batch)
                    loss = F.mse_loss(recon, val_batch) + 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
                    val_loss_accum += loss.item()
                    val_batches += 1
                    val_recons.append(recon.cpu())
                    val_targets.append(val_batch.cpu())
                    val_mus.append(mu.cpu())

            avg_val_loss = val_loss_accum / max(1, val_batches)
            val_losses.append(avg_val_loss)

            # PSNR / SSIM
            if val_targets:
                recon_all = torch.cat(val_recons, dim=0)
                target_all = torch.cat(val_targets, dim=0)
                psnr_score, ssim_score = compute_metrics(target_all, recon_all)
                val_psnrs.append(psnr_score)
                val_ssims.append(ssim_score)
                current_epoch_latents = torch.cat(val_mus, dim=0).numpy()
                latent_vectors.append(current_epoch_latents)

            logs = {
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'stop_training': False
            }

            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
                  f"PSNR: {val_psnrs[-1] if val_psnrs else 0:.2f}, SSIM: {val_ssims[-1] if val_ssims else 0:.3f}")

            for cb in callbacks:
                cb.on_epoch_end(epoch, logs)

            if logs['stop_training']:
                print("Early stopping triggered.")
                break

            # Plotting with smoothing
            plt.figure(figsize=(8, 5))
            plt.plot(smooth_curve(train_losses), label='Smoothed Train Loss', marker='o')
            plt.plot(smooth_curve(val_losses), label='Smoothed Val Loss', marker='x')
            plt.title('Training and Validation Loss (Smoothed)')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(log_dir, "loss_plot.png"))
            plt.close()

        # After training loop
        if latent_vectors:
            np.save(latent_path, latent_vectors[-1])
            print(f"Latent vectors saved to: {latent_path}")

            # Visualization
            latents = latent_vectors[-1]
            if latents.shape[0] >= 2:
                pca = PCA(n_components=2).fit_transform(latents)
                tsne = TSNE(n_components=2, perplexity=min(30, latents.shape[0]-1)).fit_transform(latents)

                fig_pca = px.scatter(x=pca[:, 0], y=pca[:, 1], title="PCA of Latent Space")
                fig_tsne = px.scatter(x=tsne[:, 0], y=tsne[:, 1], title="t-SNE of Latent Space")
                fig_pca.write_html(os.path.join(log_dir, "pca_latent.html"))
                fig_tsne.write_html(os.path.join(log_dir, "tsne_latent.html"))
                fig_pca.show()
                fig_tsne.show()

        # Spectrogram Reconstruction Visuals
        if val_recons and val_targets:
            n = min(3, val_recons[-1].shape[0])
            fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(8, n*3))
            samples = val_targets[-1][:n].cpu().squeeze(dim=1).numpy() # Squeeze channel dimension here
            recons = val_recons[-1][:n].cpu().squeeze(dim=1).numpy() # Squeeze channel dimension here

            print(f"Shape of samples before plotting: {samples.shape}") # Debug print
            print(f"Shape of recons before plotting: {recons.shape}")   # Debug print

            for i in range(n):
                # Ensure data is 2D (64, 256) for specshow
                sample_to_plot = samples[i]
                recon_to_plot = recons[i]

                if n > 1:
                    librosa.display.specshow(sample_to_plot, ax=axes[i, 0])
                    axes[i, 0].set_title(f"Original {i}")
                    librosa.display.specshow(recon_to_plot, ax=axes[i, 1])
                    axes[i, 1].set_title(f"Reconstruction {i}")
                else:
                    # Handle the case when n=1, axes is a 1D array of Axes objects
                    librosa.display.specshow(sample_to_plot, ax=axes[0])
                    axes[0].set_title(f"Original {i}")
                    librosa.display.specshow(recon_to_plot, ax=axes[1])
                    axes[1].set_title(f"Reconstruction {i}")


            plt.tight_layout()
            plt.savefig(os.path.join(log_dir, "spectrogram_reconstructions.png"))
            plt.close()

        # Audio Reconstruction
        if val_recons:
            audio_files = []
            n = min(3, val_recons[-1].shape[0])
            recons_last_batch = val_recons[-1][:n].cpu().squeeze(dim=1).numpy()

            for i in range(n):
                mel = recons_last_batch[i]
                if mel.ndim == 1:
                    mel = mel.reshape(-1, 1)
                mel_db = librosa.power_to_db(mel, ref=np.max)
                audio = librosa.feature.inverse.mel_to_audio(
                    librosa.db_to_power(mel_db), sr=22050, n_iter=32)
                out_path = os.path.join(log_dir, f"reconstructed_audio_{i}.wav")
                sf.write(out_path, audio, 22050)
                audio_files.append(out_path)
                print(f"Reconstructed audio saved: {out_path}")
                display(Audio(out_path))

            if audio_files:
                zip_path = os.path.join(log_dir, "reconstructed_audios.zip")
                with zipfile.ZipFile(zip_path, 'w') as zipf:
                    for file in audio_files:
                        zipf.write(file, arcname=os.path.basename(file))
                print(f"Zipped audio files: {zip_path}")

        # Save the final model
        save_model(model, train_dataset, val_dataset, {
            'epochs': num_epochs,
            'notes': 'VAE training',
            'loss': 'MSE + KL Divergence',
            'device': str(device)
        })

    else:
        print("⛔ Training could not start: No valid training data")

def save_model(model, train_data=None, val_data=None, metadata={}):
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        model_name = f"melody_vae_{timestamp}.pt"
        save_path = os.path.join(OUTPUT_DIR, model_name)

        torch.save({
            'state_dict': model.state_dict(),
            'config': {
                'input_shape': (1, 64, 256),
                'latent_dim': model.latent_dim
            },
            'metadata': {
                'train_samples': len(train_data) if train_data else 0,
                'val_samples': len(val_data) if val_data else 0,
                'saved_at': timestamp,
                **metadata
            }
        }, save_path)

        print(f"✅ Model successfully saved: {save_path}")
        return save_path
    except Exception as e:
        print(f"❌ Save error: {str(e)}")
        return None

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
📁 Loading data...


Processing Audio Files: 100%|██████████| 5/5 [00:02<00:00,  2.03it/s]


Model initialized with input shape: (1, 64, 256)


Training Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10, Train Loss: 0.2980, Val Loss: 13.7451, PSNR: 14.72, SSIM: 0.446
📌 Best model saved (val_loss=13.7451)


Training Epochs:  10%|█         | 1/10 [00:00<00:02,  3.36it/s]

Epoch 2/10, Train Loss: 54.5382, Val Loss: 0.1457, PSNR: 14.86, SSIM: 0.431
📌 Best model saved (val_loss=0.1457)


Training Epochs:  20%|██        | 2/10 [00:00<00:02,  3.06it/s]

Epoch 3/10, Train Loss: 0.4664, Val Loss: 0.1171, PSNR: 14.98, SSIM: 0.435
📌 Best model saved (val_loss=0.1171)


Training Epochs:  30%|███       | 3/10 [00:01<00:02,  2.87it/s]

Epoch 4/10, Train Loss: 0.3414, Val Loss: 0.1024, PSNR: 15.06, SSIM: 0.427
📌 Best model saved (val_loss=0.1024)


Training Epochs:  40%|████      | 4/10 [00:01<00:02,  2.72it/s]

Epoch 5/10, Train Loss: 0.2925, Val Loss: 0.0564, PSNR: 15.23, SSIM: 0.419
📌 Best model saved (val_loss=0.0564)


Training Epochs:  50%|█████     | 5/10 [00:01<00:01,  2.67it/s]

Epoch 6/10, Train Loss: 0.1127, Val Loss: 0.0440, PSNR: 15.29, SSIM: 0.425
📌 Best model saved (val_loss=0.0440)


Training Epochs:  60%|██████    | 6/10 [00:02<00:01,  2.49it/s]

Epoch 7/10, Train Loss: 0.0575, Val Loss: 0.0437, PSNR: 15.33, SSIM: 0.430
📌 Best model saved (val_loss=0.0437)


Training Epochs:  70%|███████   | 7/10 [00:02<00:01,  2.36it/s]

Epoch 8/10, Train Loss: 0.0473, Val Loss: 0.0450, PSNR: 15.39, SSIM: 0.428


Training Epochs:  80%|████████  | 8/10 [00:02<00:00,  2.70it/s]

Epoch 9/10, Train Loss: 0.0492, Val Loss: 0.0454, PSNR: 15.22, SSIM: 0.423


Training Epochs:  90%|█████████ | 9/10 [00:03<00:00,  2.94it/s]

Epoch 10/10, Train Loss: 0.0454, Val Loss: 0.0426, PSNR: 15.35, SSIM: 0.432
📌 Best model saved (val_loss=0.0426)


Training Epochs: 100%|██████████| 10/10 [00:03<00:00,  2.67it/s]


Latent vectors saved to: /content/drive/MyDrive/vae_project/logs/20250727_152915/latent_vectors.npy
Shape of samples before plotting: (1, 64, 256)
Shape of recons before plotting: (1, 64, 256)
Reconstructed audio saved: /content/drive/MyDrive/vae_project/logs/20250727_152915/reconstructed_audio_0.wav


Zipped audio files: /content/drive/MyDrive/vae_project/logs/20250727_152915/reconstructed_audios.zip
✅ Model successfully saved: /content/drive/MyDrive/vae_project/logs/melody_vae_20250727_1529.pt
