In [24]:
# Shell commands, not Python: run these in a terminal before launching Jupyter.
#   conda create -n musicgen python=3.9
#   conda activate musicgen
#   pip install chord-extractor
!pip install datasets
import datasets



ImportError: The pyarrow installation is not built with support for the Parquet file format (DLL load failed while importing _parquet: The specified module could not be found.)

This version adds SingSong-inspired vocal→accompaniment conditioning, plus hooks for Jukebox-style stylization (reference-style embedding and temperature) and Cococo-inspired UI steering (instrument “lanes”, multiple alternatives, mood sliders). It keeps the chord-conditioned Transformer-VAE training intact and adds a second conditional path (vocal-to-accompaniment). It also renders graphs and instant audio in the notebook.

Notes:

No chord-extractor dependency (Colab issues). Chords come from a robust chroma+music21 fallback.

If Demucs is available we’ll use it for stems; otherwise we fall back to fast librosa HPSS to approximate vocal/accompaniment.

If you don’t have stems, the vocal→accomp training will auto-skip; you can still use chord-conditioned training and the UI steering preview.
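
The Demucs-then-HPSS fallback in the notes above can be sketched as a simple availability check. This is a minimal sketch, not the notebook's actual `StemSeparator`: the helper name `pick_separator` and the returned labels are assumptions made for illustration; only `importlib.util.find_spec` is standard-library API.

```python
import importlib.util


def pick_separator(preferred: str = "demucs", fallback: str = "hpss") -> str:
    """Return the label of the separation backend to use.

    Hypothetical helper: reports `preferred` when a package of that name is
    importable in the current environment, otherwise the fallback label.
    """
    if importlib.util.find_spec(preferred) is not None:
        return preferred
    return fallback


backend = pick_separator()
print(f"stem separation backend: {backend}")
```

Checking with `find_spec` avoids importing a heavy package just to learn whether it is installed; the real separation call would then be dispatched on the returned label.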


Chord-conditioned Transformer-VAE with mixed precision (AMP) and gradient clipping.

SingSong-style module: Vocal2AccompTVAE + dataset that uses stems (Demucs if available, else HPSS).

Jukebox-style stylization:

ReferenceStyleEncoder to pull a compact style vector from a short reference clip.

Temperature control to widen/narrow sampling from z.

Cococo-style steering:

Instrument lanes (bass/mids/highs) as mel-band masks so you can regenerate only selected parts.

Multiple alternatives (N candidates) in one call.

Mood/novelty sliders (simple knobs that scale z and decoder dropout).

Thesis graphs: Loss curves, PSNR/SSIM, latent PCA/UMAP, recon grids, difference maps.

Instant audio: IPython.display.Audio for input and generated outputs.

Backend-ready export: weights + audio_params.json in a neat output tree.

SingSong-inspired (vocal→accompaniment)

StemSeparator produces (vocals, accompaniment) stems (Demucs → HPSS fallback).

Vocal2AccompDataset builds mel pairs for training.

Vocal2AccompTVAE encodes vocal mel (+ optional chord per frame) and decodes accompaniment mel.

Train with train_epoch_vocal2accomp, evaluate with PSNR/SSIM just like Option A.
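
For reference, the PSNR figure used in the evaluation can be written out in a few lines of NumPy. This is a sketch of the standard formula, 10 * log10(data_range^2 / MSE), which is also what `skimage.metrics.peak_signal_noise_ratio` computes; the function name `psnr_db`, the toy arrays, and the 80 dB range are assumptions for illustration.

```python
import numpy as np


def psnr_db(reference: np.ndarray, estimate: np.ndarray, data_range: float) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range^2 / MSE)."""
    mse = float(np.mean((reference - estimate) ** 2))
    if mse == 0.0:
        return float("inf")  # identical inputs
    return float(10.0 * np.log10(data_range ** 2 / mse))


# Toy mel "images" in dB with shape (frames, n_mels).
target = np.full((4, 8), -40.0)
recon = target + 8.0  # constant 8 dB error, so MSE = 64
print(psnr_db(target, recon, data_range=80.0))  # 10 * log10(6400 / 64) = 20 dB
```

The value of `data_range` should match the dynamic range of the mel representation being compared; otherwise PSNR numbers are not comparable across runs.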

Jukebox-style stylization

ReferenceStyleEncoder extracts a style vector from a short reference clip; it gates the latent z inside Vocal2AccompTVAE.reparam(...).

Temperature widens/narrows sampling noise.

Mood biases a few latent dims to push bright/dark coloration (simple but effective control).
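
The temperature and mood controls can be illustrated with a NumPy version of the reparameterization step. This is a sketch under stated assumptions, not the notebook's `Vocal2AccompTVAE.reparam(...)`: the function name `sample_latent`, the `mood_dims` parameter, and the choice to bias the first few latent dimensions are hypothetical.

```python
import numpy as np


def sample_latent(mu, logvar, temperature=1.0, mood=0.0, mood_dims=4, rng=None):
    """Reparameterized sample with temperature-scaled noise and a mood offset.

    Assumption: `mood` is added to the first `mood_dims` latent dimensions;
    which dimensions the notebook actually biases is not shown in the text.
    """
    rng = np.random.default_rng() if rng is None else rng
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    z = mu + temperature * std * eps  # temperature 0 collapses to the mean
    z[..., :mood_dims] += mood        # shift a few dims to tilt the output
    return z


mu = np.zeros((2, 8))
logvar = np.zeros((2, 8))
z_narrow = sample_latent(mu, logvar, temperature=0.5, rng=np.random.default_rng(0))
z_wide = sample_latent(mu, logvar, temperature=2.0, rng=np.random.default_rng(0))
print(np.abs(z_narrow).mean() < np.abs(z_wide).mean())
```

With the same random seed, the deviation from `mu` scales linearly with `temperature`, which is what makes it a usable diversity slider.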

Cococo-inspired UI steering

Instrument lanes via mel-band masks (mel_lane_masks + apply_lanes): regenerate only bass/mids/highs while keeping the rest of the original target (or overwrite entirely).

Multiple alternatives: num_alternatives generates N candidates in one go.

Sliders:

temperature → diversity,

mood → timbral/latent tilt,

novelty → boosts decoder dropout slightly to increase variety.
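
The instrument-lane idea can be sketched with plain NumPy masks over the mel axis. The names `mel_lane_masks` and `apply_lanes` come from the text above, but the signatures, the split points in `edges`, and the blending rule below are assumptions for illustration, not the notebook's implementation.

```python
import numpy as np


def mel_lane_masks(n_mels: int, edges=(0.25, 0.6)) -> dict:
    """Split the mel axis into bass/mids/highs binary masks.

    `edges` are assumed split points, given as fractions of n_mels.
    """
    lo, hi = int(n_mels * edges[0]), int(n_mels * edges[1])
    masks = {name: np.zeros(n_mels, dtype=np.float32) for name in ("bass", "mids", "highs")}
    masks["bass"][:lo] = 1.0
    masks["mids"][lo:hi] = 1.0
    masks["highs"][hi:] = 1.0
    return masks


def apply_lanes(original, generated, masks, lanes):
    """Use `generated` inside the selected lanes and `original` elsewhere.

    Both mels have shape (frames, n_mels); the mask broadcasts over frames.
    """
    select = np.zeros(original.shape[-1], dtype=np.float32)
    for name in lanes:
        select = np.maximum(select, masks[name])
    return select * generated + (1.0 - select) * original


masks = mel_lane_masks(128)
original = np.zeros((256, 128))
generated = np.ones((256, 128))
mixed = apply_lanes(original, generated, masks, lanes=["bass"])
print(mixed[:, :32].mean(), mixed[:, 32:].mean())
```

Passing all three lane names overwrites the target entirely, which corresponds to the "overwrite entirely" option mentioned above.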

Where to tweak (quick pointers)

Datasets

Put your large pretrain WAVs in CONFIG["paths"]["big_data_dir"].

Put your personal recordings in CONFIG["paths"]["my_recordings_dir"].

Stems cache goes to CONFIG["paths"]["stems_cache"].

If you have a real stems dataset (e.g., MUSDB18-HQ), point big_data_dir to it—Vocal2AccompDataset will pick it up and skip separation when vocals.wav and accompaniment.wav already exist next to the originals.
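
The "skip separation when stems already exist" rule can be sketched as a small path check. This is an illustrative sketch only: the helper name `has_precomputed_stems` is hypothetical, and it assumes the stems sit in the same folder as the original track, as the text describes.

```python
from pathlib import Path


def has_precomputed_stems(track_path) -> bool:
    """True when vocals.wav and accompaniment.wav sit next to the given track."""
    folder = Path(track_path).parent
    return (folder / "vocals.wav").is_file() and (folder / "accompaniment.wav").is_file()


# Example call on a placeholder path; separation would run only when this is False.
print(has_precomputed_stems("some_track_dir/mixture.wav"))
```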

Batch sizes

Start with batch_size: 12 (works on most 8–12GB GPUs).

If OOM: reduce to 8 or 6; or set n_mels: 96 and/or max_frames: 192.
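
To see how much the suggested reductions buy, the size of one input batch can be worked out directly. This is a rough sketch: it counts only the float32 mel tensor of shape (batch, frames, mels), which is a small part of total GPU memory, so only the ratio between settings is meaningful. The helper name `mel_batch_bytes` is hypothetical.

```python
def mel_batch_bytes(batch_size: int, max_frames: int, n_mels: int, bytes_per_value: int = 4) -> int:
    """Bytes for one mel batch of shape (batch, frames, mels), float32 by default."""
    return batch_size * max_frames * n_mels * bytes_per_value


default = mel_batch_bytes(12, 256, 128)  # settings from the text above
reduced = mel_batch_bytes(8, 192, 96)    # all three reductions applied
print(default, reduced, reduced / default)  # 1572864 589824 0.375
```

Applying all three reductions together shrinks the input batch to 37.5% of its default size; intermediate activations typically shrink along with the batch and frame counts as well.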

AMP

Mixed precision is enabled when CUDA is present (CONFIG['training']['amp']=True). It’s already integrated into both training loops.

Backend export

Final weights and audio_params.json are written into ./musicai_runs/<session>/.

The backend now only needs to load final_cc_tvae.pth (for chord-conditioned tasks) and final_vocal2accomp.pth (for vocal→accompaniment) with the same model classes, and read audio_params.json to match preprocessing.
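
The audio_params.json handoff can be sketched as a JSON round trip. This is a minimal sketch, assuming the file simply mirrors the audio settings used for preprocessing; the helper names `save_audio_params` and `load_audio_params` are hypothetical, and the example values are the audio settings from the notebook's CONFIG.

```python
import json
import tempfile
from pathlib import Path


def save_audio_params(run_dir, audio_cfg: dict) -> Path:
    """Write the preprocessing settings next to the exported weights."""
    run_dir = Path(run_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    out = run_dir / "audio_params.json"
    out.write_text(json.dumps(audio_cfg, indent=2), encoding="utf-8")
    return out


def load_audio_params(run_dir) -> dict:
    """Read the settings back on the backend side."""
    return json.loads((Path(run_dir) / "audio_params.json").read_text(encoding="utf-8"))


audio_cfg = {
    "sample_rate": 22050, "n_fft": 2048, "hop_length": 512,
    "n_mels": 128, "fmin": 30, "fmax": 8000, "max_frames": 256,
}

with tempfile.TemporaryDirectory() as tmp:
    path = save_audio_params(tmp, audio_cfg)
    print(path.name, load_audio_params(tmp) == audio_cfg)
```

Keeping these values in one file means the backend computes mel spectrograms with the same parameters the model was trained on, rather than relying on duplicated constants.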

In [21]:
# =========================
# Option A ++ (Jupyter)
# Chord-Conditioned TVAE + SingSong-inspired Vocal→Accompaniment
# Jukebox-style stylization + Cococo-style steering (lanes, alternatives, sliders)
# Thesis graphs + AMP + instant audio
# =========================

# ---- 0) Setup / installs ----
import sys, subprocess, importlib, os, json, math, random, shutil, glob, io, warnings
from pathlib import Path
from datetime import datetime

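# pip package name -> import name, for packages where the two differ
IMPORT_NAMES = {"umap-learn": "umap", "scikit-image": "skimage"}


def pip_install(pkg: str) -> None:
    """Import a package, installing it with pip only if the import fails."""
    base = pkg.split("==")[0].split("[")[0]
    module = IMPORT_NAMES.get(base, base.replace("-", "_"))
    try:
        importlib.import_module(module)
    except Exception:
        print(f"[install] {pkg}")
        subprocess.run([sys.executable, "-m", "pip", "install", pkg, "-q"], check=False)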

warnings.filterwarnings("ignore", category=UserWarning)

# Install the required libraries
for pkg in [
    "librosa",
    "soundfile",
    "matplotlib",
    "umap-learn",
    "scikit-image",
    "tqdm",
    "music21",
    "plotly",
]:
    pip_install(pkg)

# Try PyAudio and the other optional libraries
try:
    import pyaudio  # type: ignore
except Exception:
    pip_install("pyaudio")

try:
    import demucs  # type: ignore
except Exception:
    pip_install("demucs")

# ---- 1) Imports ----
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display, HTML  # noqa: F401
from tqdm import tqdm

import librosa, librosa.display
import soundfile as sf  # noqa: F401
from sklearn.decomposition import PCA  # noqa: F401
import umap  # noqa: F401
from skimage.metrics import peak_signal_noise_ratio as psnr  # noqa: F401
from skimage.metrics import structural_similarity as ssim  # noqa: F401

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# music21 for robust chord naming from chroma
from music21 import (
    chord as m21chord,
    pitch as m21pitch,
    key as m21key,  # noqa: F401
    note as m21note,  # noqa: F401
    stream as m21stream,  # noqa: F401
    harmony as m21harmony,  # noqa: F401
)

# ---- 2) CONFIG ----
CONFIG = {
    "seed": 42,
    "audio": {
        "sample_rate": 22050,
        "n_fft": 2048,
        "hop_length": 512,
        "win_length": 1024,
        "n_mels": 128,
        "fmin": 30,
        "fmax": 8000,
        "max_frames": 256,
    },
    "model": {
        "latent_dim": 256,
        "batch_size": 12,
        "init_lr": 3e-4,
        "num_epochs_pretrain": 6,
        "num_epochs_finetune": 4,
        "d_model": 256,
        "nhead": 4,
        "num_encoder_layers": 4,
        "num_decoder_layers": 4,
        "dim_feedforward": 1024,
        "dropout": 0.1,
        "grad_clip": 1.0,
    },
    "training": {
        "weight_decay": 1e-5,
        "checkpoint_interval": 2,
        "skip_nonfinite": True,
    },
    "data": {"wav_dir": "./wavs", "pretrain_data": "./pretrain_data"},
}

# Set seeds
random.seed(CONFIG["seed"])
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG["seed"])

# ---- 3) Data Loading / Preprocessing ----
class ChordedMelDataset(Dataset):
    def __init__(self, file_paths, chord_vocab=None, build_vocab=False):
        self.file_paths = [f for f in file_paths if os.path.exists(f)]
        print(f"Using {len(self.file_paths)} existing files out of {len(file_paths)} requested")

        self.data = []
        self.chord_vocab = chord_vocab.copy() if chord_vocab else {}

        if build_vocab or not self.chord_vocab:
            self._build_chord_vocab()
        else:
            if "N" not in self.chord_vocab:
                self.chord_vocab["N"] = len(self.chord_vocab)
            self._load_and_process()

    def _build_chord_vocab(self):
        """Build chord vocabulary from all files"""
        print("Building chord vocabulary...")
        chord_counter = {}

        for fpath in tqdm(self.file_paths, desc="Build chord vocab"):
            try:
                y, sr = librosa.load(fpath, sr=CONFIG["audio"]["sample_rate"])
                chroma = librosa.feature.chroma_cqt(y=y, sr=sr)
                chords = self._extract_chords(chroma)
                for chord in chords:
                    chord_counter[chord] = chord_counter.get(chord, 0) + 1
            except Exception as e:  # pragma: no cover - dataset errors
                print(f"Error processing {fpath}: {e}")
                continue

        # Create vocab with special tokens
        self.chord_vocab = {"<pad>": 0, "<unk>": 1, "N": 2}
        idx = 3
        for chord in sorted(ch for ch in chord_counter.keys() if ch != "N"):
            self.chord_vocab[chord] = idx
            idx += 1

        print(f"Chord vocab size: {len(self.chord_vocab)}")
        self._load_and_process()

    def _load_and_process(self):
        """Load and process all files"""
        print("Processing audio files...")
        for fpath in tqdm(self.file_paths, desc="Index chord/mel"):
            try:
                y, sr = librosa.load(fpath, sr=CONFIG["audio"]["sample_rate"])

                # Extract mel spectrogram
                mel = librosa.feature.melspectrogram(
                    y=y,
                    sr=sr,
                    n_fft=CONFIG["audio"]["n_fft"],
                    hop_length=CONFIG["audio"]["hop_length"],
                    n_mels=CONFIG["audio"]["n_mels"],
                    fmin=CONFIG["audio"]["fmin"],
                    fmax=CONFIG["audio"]["fmax"],
                )
                mel_db = librosa.power_to_db(mel, ref=np.max)

                # Extract chroma and chords
                chroma = librosa.feature.chroma_cqt(y=y, sr=sr)
                chords = self._extract_chords(chroma)
                chord_indices = [
                    self.chord_vocab.get(c, self.chord_vocab["<unk>"])
                    for c in chords
                ]

                # Split into segments
                seg_length = CONFIG["audio"]["max_frames"]
                mel_segments = self._segment_mel(mel_db, seg_length)
                chord_segments = self._segment_chords(
                    chord_indices, seg_length, len(mel_segments)
                )

                for mel_seg, chord_seg in zip(mel_segments, chord_segments):
                    self.data.append(
                        {
                            "mel": torch.FloatTensor(mel_seg),
                            "chord": torch.LongTensor(chord_seg),
                            "file_path": fpath,
                        }
                    )

            except Exception as e:  # pragma: no cover - dataset errors
                print(f"Error processing {fpath}: {e}")
                continue

    def _extract_chords(self, chroma):
        """Extract chord names from chroma using music21"""
        chords = []
        for i in range(chroma.shape[1]):
            frame = chroma[:, i]
            if np.max(frame) < 0.1:  # silence threshold
                chords.append("N")
                continue

            prominent = np.where(frame > 0.5 * np.max(frame))[0]
            if len(prominent) == 0:
                chords.append("N")
                continue

            try:
                pitches = [m21pitch.Pitch(pitchClass=p).name for p in prominent]
                m21_chord = m21chord.Chord(pitches)
                chord_name = m21_chord.root().name + m21_chord.quality
                chords.append(chord_name)
            except Exception:
                chords.append("N")

        return chords

    def _segment_mel(self, mel, seg_length):
        """Split mel into fixed-length segments and transpose to (frames, mels)"""
        segments = []
        n_frames = mel.shape[1]

        for i in range(0, n_frames, seg_length):
            seg = mel[:, i : i + seg_length]
            if seg.shape[1] < seg_length:
                pad_width = seg_length - seg.shape[1]
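                # Pad with the quietest dB value: after power_to_db(ref=np.max),
                # 0 dB is the loudest level, so zero-padding would add full-scale frames.
                seg = np.pad(
                    seg, ((0, 0), (0, pad_width)), mode="constant", constant_values=mel.min()
                )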
            segments.append(seg.T)  # (frames, mels)

        return segments

    def _segment_chords(self, chords, seg_length, num_segments):
        """Split chords into segments aligned with mel segments"""
        segments = []
        for i in range(0, len(chords), seg_length):
            seg = chords[i : i + seg_length]
            if len(seg) < seg_length:
                seg += [self.chord_vocab["N"]] * (seg_length - len(seg))
            segments.append(seg)

        while len(segments) < num_segments:
            segments.append([self.chord_vocab["N"]] * seg_length)
        return segments[:num_segments]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def chord_collate(batch):
    """Collate function for chord-conditioned batches"""
    mels = torch.stack([item["mel"] for item in batch])
    chords = torch.stack([item["chord"] for item in batch])
    file_paths = [item["file_path"] for item in batch]

    return {"mel": mels, "chord": chords, "file_path": file_paths}


# ---- 4) Model Architecture ----
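class TransformerVAE(nn.Module):
    """Chord-conditioned VAE applied independently to each mel frame.

    Note: despite the name, the encoder and decoder here are plain MLPs; the
    Transformer settings in CONFIG["model"] (d_model, nhead, ...) are unused.
    """
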
    def __init__(self, input_dim, chord_vocab_size, latent_dim=256):
        super().__init__()
        self.latent_dim = latent_dim

        # Chord embedding
        self.chord_embedding = nn.Embedding(chord_vocab_size, latent_dim)

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 2 * latent_dim),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
        )

    def encode(self, x, chord_emb):
        conditioned_x = torch.cat([x, chord_emb], dim=-1)
        params = self.encoder(conditioned_x)
        mu, logvar = params.chunk(2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, chord_emb):
        conditioned_z = torch.cat([z, chord_emb], dim=-1)
        return self.decoder(conditioned_z)

    def forward(self, x, chord):
        chord_emb = self.chord_embedding(chord)

        batch_size, seq_len, mel_dim = x.shape
        x_flat = x.view(batch_size * seq_len, mel_dim)
        chord_emb_flat = chord_emb.view(batch_size * seq_len, -1)

        mu, logvar = self.encode(x_flat, chord_emb_flat)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, chord_emb_flat)

        recon = recon.view(batch_size, seq_len, mel_dim)
        mu = mu.view(batch_size, seq_len, self.latent_dim)
        logvar = logvar.view(batch_size, seq_len, self.latent_dim)
        return recon, mu, logvar
# ---- 5) Training Functions ----
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    recon_loss = F.mse_loss(recon_x, x, reduction="sum")
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss, recon_loss, kl_loss


def train_epoch(model, dataloader, optimizer, device, beta=1.0):
    model.train()
    total_loss = 0
    total_recon = 0
    total_kl = 0

    for batch in tqdm(dataloader, desc="Training"):
        mel = batch["mel"].to(device)
        chord = batch["chord"].to(device)

        optimizer.zero_grad()
        recon, mu, logvar = model(mel, chord)

        loss, recon_loss, kl_loss = vae_loss(recon, mel, mu, logvar, beta)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["model"]["grad_clip"])
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon_loss.item()
        total_kl += kl_loss.item()

    n = len(dataloader.dataset)
    return total_loss / n, total_recon / n, total_kl / n


def validate_epoch(model, dataloader, device, beta=1.0):
    model.eval()
    total_loss = 0
    total_recon = 0
    total_kl = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            mel = batch["mel"].to(device)
            chord = batch["chord"].to(device)

            recon, mu, logvar = model(mel, chord)
            loss, recon_loss, kl_loss = vae_loss(recon, mel, mu, logvar, beta)

            total_loss += loss.item()
            total_recon += recon_loss.item()
            total_kl += kl_loss.item()

    n = len(dataloader.dataset)
    return total_loss / n, total_recon / n, total_kl / n


# ---- 6) Main Function ----
def main():
    os.makedirs(CONFIG["data"]["wav_dir"], exist_ok=True)
    os.makedirs(CONFIG["data"]["pretrain_data"], exist_ok=True)

    print(f"Looking for audio files in: {CONFIG['data']['wav_dir']}")

    wav_patterns = ["*.wav", "*.WAV", "*.mp3", "*.MP3", "*.flac", "*.FLAC"]
    my_wavs = []
    for pattern in wav_patterns:
        my_wavs.extend(Path(CONFIG["data"]["wav_dir"]).glob(f"**/{pattern}"))

    pretrain_files = []
    for pattern in wav_patterns:
        pretrain_files.extend(Path(CONFIG["data"]["pretrain_data"]).glob(f"**/{pattern}"))

    print(f"Found {len(my_wavs)} files in wav_dir, {len(pretrain_files)} files in pretrain_data")

    if not my_wavs and not pretrain_files:
        print("ERROR: No audio files found!")
        print("Please place audio files in the directories.")
        return None, None

    chord_vocab = {}
    pre_ds = None

    if pretrain_files:
        print("Building pretrain dataset...")
        pre_ds = ChordedMelDataset([str(p) for p in pretrain_files], build_vocab=True)
        chord_vocab = pre_ds.chord_vocab
        print(f"Pretrain dataset size: {len(pre_ds)}")

    print("Building fine-tuning dataset...")
    ft_ds = ChordedMelDataset(
        [str(p) for p in my_wavs], chord_vocab=chord_vocab, build_vocab=not chord_vocab
    )
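    # Use the dataset's own vocab: it is built here when there is no pretrain
    # data, and otherwise len(chord_vocab) would be 0 when sizing the embedding.
    chord_vocab = ft_ds.chord_vocab
    print(f"Fine-tuning dataset size: {len(ft_ds)}")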

    if (pre_ds is None or len(pre_ds) == 0) and len(ft_ds) == 0:
        print("ERROR: No valid training data!")
        return None, None

    bs = CONFIG["model"]["batch_size"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

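    # NOTE: each "val" loader below iterates the same dataset as its train
    # loader, so the validation metrics are not measured on held-out data.
    dataloaders = {}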

    if pre_ds and len(pre_ds) > 0:
        pre_train = DataLoader(
            pre_ds,
            batch_size=min(bs, len(pre_ds)),
            shuffle=True,
            num_workers=0,
            collate_fn=chord_collate,
        )
        pre_val = DataLoader(
            pre_ds,
            batch_size=min(bs, len(pre_ds)),
            shuffle=False,
            num_workers=0,
            collate_fn=chord_collate,
        )
        dataloaders["pretrain"] = (pre_train, pre_val)

    if len(ft_ds) > 0:
        ft_train = DataLoader(
            ft_ds,
            batch_size=min(bs, len(ft_ds)),
            shuffle=True,
            num_workers=0,
            collate_fn=chord_collate,
        )
        ft_val = DataLoader(
            ft_ds,
            batch_size=min(bs, len(ft_ds)),
            shuffle=False,
            num_workers=0,
            collate_fn=chord_collate,
        )
        dataloaders["finetune"] = (ft_train, ft_val)

    input_dim = CONFIG["audio"]["n_mels"]
    model = TransformerVAE(input_dim, len(chord_vocab), CONFIG["model"]["latent_dim"])
    model = model.to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=CONFIG["model"]["init_lr"],
        weight_decay=CONFIG["training"]["weight_decay"],
    )

    for phase, (train_loader, val_loader) in dataloaders.items():
        print(f"\n=== Starting {phase} phase ===")
        num_epochs = (
            CONFIG["model"]["num_epochs_pretrain"]
            if phase == "pretrain"
            else CONFIG["model"]["num_epochs_finetune"]
        )

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            train_loss, train_recon, train_kl = train_epoch(
                model, train_loader, optimizer, device
            )
            val_loss, val_recon, val_kl = validate_epoch(model, val_loader, device)

            print(
                f"Train Loss: {train_loss:.4f} (Recon: {train_recon:.4f}, KL: {train_kl:.4f})"
            )
            print(
                f"Val Loss: {val_loss:.4f} (Recon: {val_recon:.4f}, KL: {val_kl:.4f})"
            )

    print("\nTraining completed!")
    return model, chord_vocab


# ---- 7) Run Main ----
if __name__ == "__main__":
    try:
        model, chord_vocab = main()
        if model is not None:
            print("Model training successful!")
        else:
            print("Training failed - no data available")
    except Exception as e:  # pragma: no cover - runtime errors
        print(f"Error during training: {e}")
        import traceback

        traceback.print_exc()

[install] umap-learn
[install] scikit-image
[install] music21
[install] fpdf
[install] demucs
Device: cpu

== Build chord dataset ==


Build chord vocab: 0it [00:00, ?it/s]
Index chord/mel: 0it [00:00, ?it/s]


[vocab] chords: 1


Index chord/mel: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.23s/it]


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
…/checkpoints/
  final_cc_tvae.pth
  chord_vocab.json
  audio_params.json

In [None]:
# =====  A) DATA STREAMING  =====
def stream_wavs_from_urls(url_list, cache_dir, max_files=None):
    """
    Downloads by URL only. Writes each WAV to a temporary file and yields it;
    you can delete it when done or keep an LRU cache instead.
    """
    ...

def build_manifest_from_hf_or_repo(url_roots) -> list[str]:
    """
    Builds a list of .wav URLs from Hugging Face 'resolve' URLs or GitHub raw
    paths (ready-made URLs may also be passed in).
    """
    ...

def iter_dataset_from_stream(manifest, tokenizer, mode, max_files=None):
    """
    mode: 'pretrain' or 'finetune'
    Each wav -> mel, per_frame_chords, chord_seq, event_seq.
    Yields the objects needed for the VAE and BC/IfO.
    """
    ...

# =====  B) NEXT-CHORD BC =====
class BCChordModel(nn.Module):
    ...  # use the existing model


def train_bc_chords(model, seqs, **kwargs):
    ...  # same logic as the existing train_bc


def predict_next_chord(model, context_ids, topk=1):
    ...

# =====  C) RL / PPO =====
import gym  # assumption: the classic Gym package; with Gymnasium use `import gymnasium as gym`


class SequenceEnv(gym.Env):
    """
    mode: 'events' or 'chords'
    - events: action = event token id
    - chords: action = chord id (next-chord policy)
    reward = w1*novelty_ngram + w2*consonance + w3*style_bc (+ w4*critic) - w5*KL_ngram
    """
    ...

def compute_ngram_stats(seqs, n=4):
    """n-gram distribution of the real corpus"""
    ...

def ngram_kl_penalty(gen_counts, ref_counts, eps=1e-8):
    """ KL(gen||ref) """
    ...

class TinyCritic(nn.Module):
    """
    Lightweight MLP: input -> (n-gram histogram / averaged token embeddings / simple features)
    output: realism score in [0, 1]
    """
    ...

def train_tiny_critic(real_seqs, fake_seqs):
    ...

# =====  D) INFERENCE ENDPOINT HELPERS =====
def vae_generate_from_chord_seq(model, chord_seq_ids, audio_params):
    """
    Given a chord sequence, spread it across frames (equal split) and generate with the VAE.
    """
    ...
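
The two n-gram stubs in section C above could be filled in along the following lines. This is one possible sketch, not the author's implementation: it keeps the stub signatures, but the counting scheme and the way `eps` smooths both distributions are assumptions.

```python
import math
from collections import Counter


def compute_ngram_stats(seqs, n=4):
    """Count every length-n window across a corpus of token-id sequences."""
    counts = Counter()
    for seq in seqs:
        for i in range(len(seq) - n + 1):
            counts[tuple(seq[i:i + n])] += 1
    return counts


def ngram_kl_penalty(gen_counts, ref_counts, eps=1e-8):
    """KL(gen || ref) over n-gram distributions, with eps smoothing on both."""
    keys = set(gen_counts) | set(ref_counts)
    gen_total = sum(gen_counts.values()) + eps * len(keys)
    ref_total = sum(ref_counts.values()) + eps * len(keys)
    kl = 0.0
    for key in keys:
        p = (gen_counts.get(key, 0) + eps) / gen_total
        q = (ref_counts.get(key, 0) + eps) / ref_total
        kl += p * math.log(p / q)
    return kl


ref = compute_ngram_stats([[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]], n=2)
gen = compute_ngram_stats([[1, 1, 1, 1, 2, 3]], n=2)
print(ngram_kl_penalty(ref, ref), ngram_kl_penalty(gen, ref) > 0)
```

Smoothing both distributions over the union of observed n-grams keeps the penalty finite when the generated sequence contains n-grams the reference corpus never produced.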


In [None]:
# %% Module Imports
from src.utils.endpoints import generate_accompaniment
from src.rl.train import train_vfo_value
from src.utils.mode_router import route_generation
import numpy as np


In [None]:
# %% Mini Tests
# Generate accompaniment
out_path = generate_accompaniment('dummy.wav')
print('accompaniment saved to', out_path)

# VFO overfit test
states = np.array([0.,1.,2.,3.])
labels = np.array([0.,1.,1.,0.])
model = train_vfo_value(states, labels)
values = model(states)
print('values', values)
print('route', route_generation(True, None, {'modes':{'fallback_if_vocal_missing':True}})[0])
