In [2]:
# IMPROVED SINGLE-CELL: Semantic text matching + better CNN + intelligent blending
# Paste & run in Google Colab. Uncomment pip install line on first run.

# !pip install -q faiss-cpu torchaudio soundfile pyloudnorm ipywidgets tqdm scikit-learn umap-learn seaborn librosa sentence-transformers
# IMPROVED SINGLE-CELL: Semantic text matching + better CNN + intelligent blending
# Paste & run in Google Colab. Uncomment pip install line on first run.
# ==============================================================
# FINAL: EMOTIFY BOOK-TO-MUSIC MATCHER (Interactive, Audio Play)
# ==============================================================

# NOTE: Run in Google Colab or Jupyter with audio support.
# Installs, dataset download, model loading, interactive UI, and
# 10 generated example paragraphs (>=150 words each) included.

!pip install -q sentence-transformers pandas numpy scikit-learn kagglehub ipywidgets tqdm
!pip install -q faiss-cpu torchaudio soundfile pyloudnorm ipywidgets tqdm scikit-learn umap-learn seaborn librosa sentence-transformers
import os
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from IPython.display import Audio, display, HTML, clear_output
import ipywidgets as widgets
import kagglehub
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

from google.colab import files
files.upload()

# -----------------------------
# 3) (Optional) Download EMOTIFY dataset (Kaggle)
#    Uses kagglehub cache ‚Äî may be fast in Colab
# -----------------------------
print("üì¶ Downloading EMOTIFY dataset (may use cache)...")
dataset_path = kagglehub.dataset_download("yash9439/emotify-emotion-classificaiton-in-songs")
print("‚úÖ Dataset download reference:", dataset_path)
!mv /root/.cache/kagglehub/datasets/yash9439/emotify-emotion-classificaiton-in-songs/versions/1 /content/emotify_dataset

import os, math, random, time, warnings, sys
warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchaudio
import soundfile as sf
from tqdm.notebook import tqdm
import ipywidgets as widgets
from IPython.display import Audio, display, clear_output
import pyloudnorm as pyln
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import pandas as pd

try:
    import faiss
except Exception:
    faiss = None

import librosa

# NEW: Semantic text embedding
try:
    from sentence_transformers import SentenceTransformer
    semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
    SEMANTIC_AVAILABLE = True
except:
    print("sentence-transformers not available, using fallback")
    SEMANTIC_AVAILABLE = False
    semantic_model = None

# -------------------------- USER TOGGLES --------------------------
DEFAULT_USE_GPU_FAST = True
use_librosa_for_report_only = True
SAMPLE_REPORT_N = 200
FEATURE_CACHE_DIR = "/content/feature_cache"
os.makedirs(FEATURE_CACHE_DIR, exist_ok=True)

# -------------------------- EMOTION DESCRIPTIONS --------------------------
EMOTION_DESCRIPTIONS = {
    'amazement': 'This text expresses wonder, awe, astonishment, breathtaking discovery, miraculous revelation, epic grandeur, magnificent spectacle, inspiring triumph, jaw-dropping beauty, cosmic scale, overwhelming majesty.',
    'solemnity': 'This text expresses solemnity, reverence, dignity, sacred ceremony, spiritual depth, formal gravity, serious contemplation, religious devotion, ritual importance, profound respect, weighty significance.',
    'tenderness': 'This text expresses tenderness, warmth, gentleness, intimate affection, caring love, emotional closeness, soft vulnerability, protective nurturing, delicate sensitivity, heartfelt compassion, sweet devotion.',
    'nostalgia': 'This text expresses nostalgia, longing for the past, bittersweet memories, wistful reminiscence, yearning for old times, sentimental reflection, missing what was, remembering yesterday, looking back fondly, homesickness, lost innocence.',
    'calmness': 'This text expresses calmness, peace, tranquility, serenity, stillness, relaxation, quietness, meditation, gentle ease, untroubled mind, soothing comfort, restful contentment, peaceful silence.',
    'power': 'This text expresses power, strength, energy, boldness, heroism, triumph, victory, confidence, determination, dominance, intensity, explosive force, unstoppable drive, fierce courage, commanding presence.',
    'joyful_activation': 'This text expresses joy, happiness, excitement, playfulness, fun, celebration, cheerfulness, enthusiasm, lively energy, dancing spirit, positive vibes, exuberance, delightful pleasure, upbeat optimism.',
    'tension': 'This text expresses tension, stress, anxiety, nervousness, unease, worry, suspense, fear, restlessness, discomfort, apprehension, dread, agitation, uncertainty, foreboding danger, tightness, pressure.',
    'sadness': 'This text expresses sadness, sorrow, grief, melancholy, depression, heartbreak, loss, despair, loneliness, pain, crying, suffering, hopelessness, emptiness, mourning, tearful anguish, emotional hurt.'
}

EMO_COLS = [" amazement"," solemnity"," tenderness"," nostalgia"," calmness"," power"," joyful_activation"," tension"," sadness"]
EMO_KEYS = ['amazement', 'solemnity', 'tenderness', 'nostalgia', 'calmness', 'power', 'joyful_activation', 'tension', 'sadness']

# Build emotion description embeddings once
if SEMANTIC_AVAILABLE:
    EMOTION_DESC_EMBEDDINGS = {k: semantic_model.encode(v, convert_to_numpy=True) for k, v in EMOTION_DESCRIPTIONS.items()}
else:
    EMOTION_DESC_EMBEDDINGS = {}

# -------------------------- PATHS & CONFIG --------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

AUDIO_ROOT = "/content/emotify_dataset"
CSV_PATH = os.path.join(AUDIO_ROOT, "data.csv")
OUTPUT_DIR = "/content/outputs"; os.makedirs(OUTPUT_DIR, exist_ok=True)
CACHE_DIR = "/content/mel_cache"; os.makedirs(CACHE_DIR, exist_ok=True)

SAMPLE_RATE = 22050
AUDIO_DURATION = 20.0
SAMPLES = int(SAMPLE_RATE * AUDIO_DURATION)
N_MELS = 256
N_FFT = 2048
HOP_LENGTH = 512

BATCH_SIZE = 64
EPOCHS = 50
LR = 5e-4
EMBED_DIM = 512
EMOTION_DIM = 9
NUM_WORKERS = 4
AMP = True
EARLY_STOPPING_PATIENCE = 10
WARMUP_EPOCHS = 3
MIN_LR = 1e-6

CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "cnn_emotion_checkpoint.pth")
BEST_PATH = os.path.join(OUTPUT_DIR, "cnn_emotion_best.pth")

meter = pyln.Meter(SAMPLE_RATE)

# -------------------------- UTILS --------------------------
def safe_mkdir(p): os.makedirs(p, exist_ok=True)
def measure_loudness(audio):
    try:
        return meter.integrated_loudness(audio)
    except:
        rms = np.sqrt(np.mean(audio**2)); db = 20*math.log10(max(rms,1e-9)); return db - 20.0
def normalize_to_lufs(audio, target_lufs=-16.0):
    cur = measure_loudness(audio); db_diff = target_lufs - cur; gain = 10 ** (db_diff / 20.0); return audio * gain

# -------------------------- TEXT TO EMOTION VECTOR --------------------------
def text_to_emotion_vector(text, top_k=3):
    """Convert text to 9D emotion vector using semantic similarity"""
    if not SEMANTIC_AVAILABLE or not text.strip():
        # Fallback: neutral vector
        return np.ones(EMOTION_DIM, dtype=np.float32) / np.sqrt(EMOTION_DIM)

    text_emb = semantic_model.encode(text, convert_to_numpy=True)

    # Compute similarity to each emotion description
    scores = []
    for key in EMO_KEYS:
        desc_emb = EMOTION_DESC_EMBEDDINGS[key]
        sim = np.dot(text_emb, desc_emb) / (np.linalg.norm(text_emb) * np.linalg.norm(desc_emb) + 1e-9)
        scores.append(max(0.0, sim))  # ReLU

    scores = np.array(scores, dtype=np.float32)

    # SPARSIFICATION: Keep only top-K emotions
    if top_k > 0 and top_k < len(scores):
        top_indices = np.argsort(scores)[-top_k:]
        mask = np.zeros_like(scores)
        mask[top_indices] = 1.0
        scores = scores * mask

    # Apply temperature to sharpen distribution (only on non-zero)
    if scores.sum() > 0:
        temp = 0.25  # Very sharp
        exp_scores = np.exp(scores / temp)
        emotion_vec = exp_scores / (exp_scores.sum() + 1e-9)
    else:
        emotion_vec = np.ones(EMOTION_DIM) / EMOTION_DIM

    # Apply power transformation to further emphasize peaks
    emotion_vec = emotion_vec ** 2.0  # More aggressive (was 1.5)

    # Normalize to unit vector
    emotion_vec = emotion_vec / (np.linalg.norm(emotion_vec) + 1e-9)
    return emotion_vec

# -------------------------- TRACK ID DECODER --------------------------
def decode_track_id(tid):
    """Convert global track ID to (genre, local_id) format like '21 rock' """
    if 1 <= tid <= 100:
        genre = "classical"; local_id = tid
    elif 101 <= tid <= 200:
        genre = "rock"; local_id = tid - 100
    elif 201 <= tid <= 300:
        genre = "electronic"; local_id = tid - 200
    elif 301 <= tid <= 400:
        genre = "pop"; local_id = tid - 300
    else:
        # Fallback to just showing the ID
        return f"#{tid}"
    return f"{local_id} {genre}"

# -------------------------- LOAD CSV EMOTION LABELS --------------------------
def load_emotion_vectors(csv_path):
    if not os.path.isfile(csv_path):
        print("Warning: CSV not found:", csv_path)
        return {}
    df = pd.read_csv(csv_path)
    required = ["track id"] + EMO_COLS
    for c in required:
        if c not in df.columns:
            print("CSV missing column:", c)
            return {}
    g = df.groupby("track id")[EMO_COLS].mean()
    vecs = g.values.astype(np.float32)
    norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-9
    vecs = vecs / norms
    track_to_vec = {}
    for tid, vec in zip(g.index.astype(int), vecs):
        track_to_vec[int(tid)] = vec
    print("Loaded emotion vectors for", len(track_to_vec), "tracks from CSV")
    return track_to_vec

emotion_vectors = load_emotion_vectors(CSV_PATH)

# -------------------------- DISCOVER AUDIO DB --------------------------
def discover_audio_db(audio_root=AUDIO_ROOT):
    db = {}
    found = []
    for root, dirs, files in os.walk(audio_root):
        for f in sorted(files):
            if f.lower().endswith(('.mp3','.wav','.flac','.ogg','.m4a')):
                found.append(os.path.join(root,f))
    if not found:
        print("No audio found under", audio_root)
        return {}
    tid = 1
    for p in found:
        g = os.path.basename(os.path.dirname(p))
        vec = emotion_vectors.get(tid, None)
        if vec is None:
            try:
                fname = os.path.splitext(os.path.basename(p))[0]
                fnum = int(fname)
                vec = emotion_vectors.get(fnum, None)
            except:
                vec = None
        if vec is None:
            vec = np.random.rand(EMOTION_DIM).astype(np.float32)
            vec = vec / (np.linalg.norm(vec) + 1e-12)
        db[tid] = {'audio_path': p, 'genre': g, 'emotion_vector': vec}
        tid += 1
    print(f"Discovered {len(db)} audio files under {audio_root}")
    return db

music_emotion_db = discover_audio_db(AUDIO_ROOT)
if len(music_emotion_db) == 0:
    raise RuntimeError("No audio files found. Put audio under AUDIO_ROOT and re-run.")

# -------------------------- MEL CACHE --------------------------
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS, center=True, power=2.0
)
to_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80.0)

def compute_and_cache_mels(db, cache_dir=CACHE_DIR, force=False):
    safe_mkdir(cache_dir)
    processed = 0
    for tid, entry in tqdm(db.items(), desc="Caching mels"):
        path = entry.get('audio_path')
        if not path or not os.path.isfile(path):
            continue
        outp = os.path.join(cache_dir, f"{tid}.npz")
        if os.path.exists(outp) and not force:
            continue
        try:
            y, sr = torchaudio.load(path)
            if sr != SAMPLE_RATE:
                y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=SAMPLE_RATE)
            if y.ndim > 1:
                y = y.mean(dim=0)
            y = y.numpy().astype(np.float32)
            if len(y) < SAMPLES:
                y = np.pad(y, (0, SAMPLES - len(y)))
            else:
                y = y[:SAMPLES]
            mel = mel_transform(torch.from_numpy(y).unsqueeze(0)).squeeze(0)
            mel = to_db(mel).numpy()
            mel = (mel - mel.mean()) / (mel.std() + 1e-9)
            np.savez_compressed(outp, mel=mel, path=path, genre=entry.get('genre',''), emotion=entry.get('emotion_vector'))
            processed += 1
        except Exception as e:
            print("cache fail:", path, ":", str(e))
    print(f"Mel cache computed: {processed} files cached to {cache_dir}")

print("Precomputing mel cache (skip if done)...")
compute_and_cache_mels(music_emotion_db, CACHE_DIR, force=False)

# -------------------------- Dataset --------------------------
class MelDataset(Dataset):
    def __init__(self, db, cache_dir=CACHE_DIR, segment_seconds=5.0):
        self.items = []
        self.db = db
        self.cache_dir = cache_dir
        target_frames = int(segment_seconds * SAMPLE_RATE / HOP_LENGTH)
        self.target_frames = target_frames
        for tid, entry in db.items():
            cachep = os.path.join(cache_dir, f"{tid}.npz")
            if os.path.exists(cachep):
                self.items.append((tid, cachep))
        if not self.items:
            raise RuntimeError("No cached mels found. Run compute_and_cache_mels first.")
    def __len__(self): return len(self.items)
    def __getitem__(self, idx):
        tid, cachep = self.items[idx]
        data = np.load(cachep, allow_pickle=True)
        mel = data['mel'].astype(np.float32)
        emo = self.db[tid]['emotion_vector'].astype(np.float32)
        tdim = mel.shape[1]
        if tdim < self.target_frames:
            mel = np.pad(mel, ((0,0),(0,self.target_frames-tdim)))
        elif tdim > self.target_frames:
            start = (tdim - self.target_frames)//2
            mel = mel[:, start:start+self.target_frames]
        mel = torch.from_numpy(mel).unsqueeze(0)
        return mel, torch.from_numpy(emo), int(tid), cachep

# -------------------------- IMPROVED MODEL --------------------------
class ConvBN(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x): return self.act(self.bn(self.conv(x)))

class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention"""
    def __init__(self, ch, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, ch//reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch//reduction, ch, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return x * self.fc(x)

class ResidualBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.net = nn.Sequential(
            ConvBN(ch, ch, 3,1,1),
            ConvBN(ch, ch, 3,1,1)
        )
        self.se = SEBlock(ch)
    def forward(self, x):
        return x + self.se(self.net(x))

class ImprovedResNet(nn.Module):
    """Deeper ResNet with attention"""
    def __init__(self, emb_dim=EMBED_DIM, emotion_dim=EMOTION_DIM):
        super().__init__()
        self.stem = nn.Sequential(
            ConvBN(1,64,3,1,1),
            ResidualBlock(64),
            ConvBN(64,128,3,2,1),
            ResidualBlock(128),
            ResidualBlock(128),
            ConvBN(128,256,3,2,1),
            ResidualBlock(256),
            ResidualBlock(256),
            ConvBN(256,512,3,2,1),
            ResidualBlock(512),
            ResidualBlock(512)
        )
        self.gap = nn.AdaptiveAvgPool2d((1,1))
        self.emb = nn.Sequential(
            nn.Linear(512, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )
        self.head = nn.Linear(emb_dim, emotion_dim)

    def forward(self, x):
        h = self.stem(x)
        h = self.gap(h).view(h.size(0), -1)
        z = self.emb(h)
        z = F.normalize(z, dim=1)
        out = self.head(z)
        return z, out

# -------------------------- TRAIN --------------------------
def train_model(db, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, checkpoint_path=CHECKPOINT_PATH, best_path=BEST_PATH):
    ds = MelDataset(db)
    n = len(ds)
    val_n = max(16, int(0.08*n))
    train_n = n - val_n
    train_ds, val_ds = random_split(ds, [train_n, val_n])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    model = ImprovedResNet().to(DEVICE)

    # Optimizer with weight decay
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, betas=(0.9, 0.999))

    # Cosine annealing with warmup
    def lr_lambda(epoch):
        if epoch < WARMUP_EPOCHS:
            return (epoch + 1) / WARMUP_EPOCHS
        else:
            progress = (epoch - WARMUP_EPOCHS) / (epochs - WARMUP_EPOCHS)
            return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    # ReduceLROnPlateau as backup
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='min', factor=0.5, patience=4, min_lr=MIN_LR
    )

    scaler = torch.cuda.amp.GradScaler(enabled=AMP)
    start_ep = 0
    hist = {'train_loss':[], 'val_loss':[], 'lr':[]}
    best_val = 1e9
    no_improve = 0

    if os.path.exists(checkpoint_path):
        try:
            chk = torch.load(checkpoint_path, map_location=DEVICE)
            model.load_state_dict(chk['model_state_dict'])
            hist = chk.get('hist', hist)
            # Ensure 'lr' key exists for backward compatibility
            if 'lr' not in hist:
                hist['lr'] = []
            start_ep = len(hist.get('train_loss',[]))
            print("Resumed from checkpoint at epoch", start_ep)
        except Exception as e:
            print("Failed to resume:", e)

    for ep in range(start_ep, epochs):
        model.train()
        running = []
        current_lr = opt.param_groups[0]['lr']
        pbar = tqdm(train_dl, desc=f"Train ep{ep+1}/{epochs} lr={current_lr:.2e}")

        for mel, emo, tid, _ in pbar:
            mel = mel.to(DEVICE, non_blocking=True)
            emo = emo.to(DEVICE, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=AMP):
                z, pred = model(mel)
                loss = F.mse_loss(pred, emo)

            opt.zero_grad()
            scaler.scale(loss).backward()

            # Gradient clipping
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(opt)
            scaler.update()

            running.append(float(loss.item()))
            pbar.set_postfix({'loss':np.mean(running), 'lr': current_lr})

        train_loss = float(np.mean(running))

        # Validation
        model.eval()
        vrun = []
        with torch.no_grad():
            for mel, emo, tid, _ in val_dl:
                mel = mel.to(DEVICE)
                emo = emo.to(DEVICE)
                with torch.cuda.amp.autocast(enabled=AMP):
                    _, pred = model(mel)
                    l = F.mse_loss(pred, emo)
                vrun.append(float(l.item()))

        val_loss = float(np.mean(vrun)) if vrun else train_loss

        # Step schedulers
        scheduler.step()
        plateau_scheduler.step(val_loss)

        hist['train_loss'].append(train_loss)
        hist['val_loss'].append(val_loss)
        hist['lr'].append(current_lr)

        print(f"Epoch {ep+1}: train {train_loss:.4f} val {val_loss:.4f} lr {current_lr:.2e}")

        # Save checkpoint
        torch.save({
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'hist':hist,
            'epoch': ep
        }, checkpoint_path)

        # Save best model
        if val_loss < best_val:
            best_val = val_loss
            torch.save({
                'model_state_dict':model.state_dict(),
                'hist':hist,
                'epoch': ep
            }, best_path)
            no_improve = 0
            print(f"  ‚≠ê New best model! val_loss: {best_val:.4f}")
        else:
            no_improve += 1

        if no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping: {EARLY_STOPPING_PATIENCE} epochs without improvement.")
            break

    print("Training done. Best val:", best_val)
    return model, hist

# -------------------------- EXTRACT EMBS --------------------------
@torch.no_grad()
def extract_embeddings(model, db, batch_size=64, apply_pca=True, pca_dim=128):
    model.eval()
    ds = MelDataset(db)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)
    embs = []; ids = []; paths = []
    for mel, emo, tid, cachep in tqdm(dl, desc="Extract embs"):
        mel = mel.to(DEVICE)
        z, _ = model(mel)
        arr = z.detach().cpu().numpy(); embs.append(arr)
        if torch.is_tensor(tid):
            tid_list = tid.cpu().numpy().astype(int).tolist()
        elif isinstance(tid, (list,tuple,np.ndarray)):
            tid_list = [int(t) for t in tid]
        else:
            tid_list = [int(tid)]
        ids.extend(tid_list)
        if isinstance(cachep, (list,tuple)):
            cache_list = cachep
        else:
            cache_list = [cachep]
        for cp in cache_list:
            try:
                tid_here = int(Path(cp).stem); paths.append(db[tid_here]['audio_path'])
            except:
                paths.append(str(cp))
    if len(embs)==0:
        return np.zeros((0, EMBED_DIM)), [], []
    embs = np.vstack(embs)

    # Apply PCA to reduce dimensionality and fight curse of dimensionality
    if apply_pca and embs.shape[0] > pca_dim:
        print(f"Applying PCA: {embs.shape[1]}D ‚Üí {pca_dim}D to improve nearest neighbor quality")
        from sklearn.decomposition import PCA
        pca = PCA(n_components=pca_dim, whiten=True)  # whiten=True normalizes variance
        embs_reduced = pca.fit_transform(embs)
        variance_retained = pca.explained_variance_ratio_.sum()
        print(f"  Variance retained: {variance_retained*100:.1f}%")
        return embs_reduced, ids, paths

    return embs, ids, paths

def build_faiss_index(embs):
    if faiss is None:
        print("faiss not available")
        return None
    emb = embs.astype('float32'); faiss.normalize_L2(emb)
    d = emb.shape[1]; index = faiss.IndexFlatIP(d); index.add(emb); return index

# -------------------------- RETRIEVAL WITH SCORE PRINTING --------------------------
def retrieve_candidates_for_text(text, faiss_index, ids, embs, top_k=20, print_scores=False):
    target_em = text_to_emotion_vector(text)

    if print_scores:
        print(f"\nüìù Text: '{text[:100]}...'")
        print(f"üé≠ Emotion vector: {' '.join([f'{EMO_KEYS[i]}={target_em[i]:.2f}' for i in range(len(EMO_KEYS))])}")

    scores=[]
    for tid, entry in music_emotion_db.items():
        path = entry.get('audio_path')
        if not path or not os.path.isfile(path): continue
        em = entry.get('emotion_vector', np.ones(EMOTION_DIM)/EMOTION_DIM)
        cos = float(np.dot(em, target_em)/(np.linalg.norm(em)*np.linalg.norm(target_em)+1e-12))

        # Penalize tracks with very flat emotion profiles (generic tracks)
        em_variance = float(np.var(em))
        if em_variance < 0.01:  # Very flat profile
            cos *= 0.7  # Penalize by 30%

        scores.append((tid,cos))

    if not scores: return []
    scores.sort(key=lambda x:x[1], reverse=True)

    # Remove duplicate track IDs if they appear multiple times
    seen = set()
    unique_scores = []
    for tid, score in scores:
        if tid not in seen:
            unique_scores.append((tid, score))
            seen.add(tid)

    if print_scores:
        print(f"üéµ Top 10 emotion matches:")
        for i, (tid, score) in enumerate(unique_scores[:10]):
            track_label = decode_track_id(tid)
            em = music_emotion_db[tid]['emotion_vector']
            # Show top 3 emotions for this track
            top_emo_idx = np.argsort(em)[-3:][::-1]
            em_str = ' '.join([f'{EMO_KEYS[top_emo_idx[j]]}={em[top_emo_idx[j]]:.2f}' for j in range(3)])
            print(f"  {i+1}. [{track_label:18s}] score: {score:.3f} | {em_str}")

    candidates = [t for t,_ in unique_scores[:max(200, top_k)]]
    if faiss_index is None or embs is None or len(ids)==0:
        return candidates[:top_k]
    anchor = next((c for c in candidates if music_emotion_db.get(c) and os.path.isfile(music_emotion_db[c]['audio_path'])), None)
    if anchor is None: return candidates[:top_k]
    try:
        idx = ids.index(anchor)
        anchor_emb = embs[idx:idx+1].astype('float32')
        faiss.normalize_L2(anchor_emb)
        D,I = faiss_index.search(anchor_emb, min(300, len(ids)))
        ranked=[]
        for ii in I[0]:
            tid = ids[ii]
            if tid in candidates: ranked.append(tid)
        if ranked: return ranked[:top_k]
    except Exception:
        pass
    return candidates[:top_k]

# -------------------------- IMPROVED AUDIO ANALYSIS & BLENDING --------------------------
def compute_chroma(y, sr=SAMPLE_RATE):
    try:
        chroma = librosa.feature.chroma_stft(y=y, sr=sr, n_fft=N_FFT, hop_length=HOP_LENGTH)
        mean = np.mean(chroma, axis=1)
        if np.linalg.norm(mean) < 1e-9: return np.ones(12)/12.0
        return mean/(np.linalg.norm(mean)+1e-12)
    except:
        return np.ones(12)/12.0

def detect_key(chroma_mean):
    major_profile = np.array([6.35,2.23,3.48,2.33,4.38,4.09,2.52,5.19,2.39,3.66,2.29,2.88])
    minor_profile = np.array([6.33,2.68,3.52,5.38,2.60,3.53,2.54,4.75,3.98,2.69,3.34,3.17])
    best=(0,-10.0,'major'); chroma=np.asarray(chroma_mean)
    for r in range(12):
        maj=np.roll(major_profile,r); minr=np.roll(minor_profile,r)
        cmaj=np.corrcoef(chroma,maj)[0,1]; cmin=np.corrcoef(chroma,minr)[0,1]
        if cmaj>best[1]: best=(r,cmaj,'major')
        if cmin>best[1]: best=(r,cmin,'minor')
    return best[0], best[2]

def semitone_diff(a,b):
    diff = (a-b) % 12
    if diff>6: diff-=12
    return int(diff)

def pitch_shift(y, sr, n_steps):
    if n_steps==0: return y
    try: return librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps)
    except: return y

def time_stretch(y, rate):
    rate = max(0.55, min(1.8, rate))
    try: return librosa.effects.time_stretch(y, rate)
    except: return y

def harmonic_percussive(y):
    try: return librosa.effects.hpss(y, margin=3.0)
    except: return y, np.zeros_like(y)

def dynamic_eq_match(source, target, sr=SAMPLE_RATE):
    """Match spectral envelope of source to target"""
    try:
        source_spec = np.abs(librosa.stft(source))
        target_spec = np.abs(librosa.stft(target))

        source_env = np.mean(source_spec, axis=1, keepdims=True)
        target_env = np.mean(target_spec, axis=1, keepdims=True)

        gain_curve = target_env / (source_env + 1e-6)
        gain_curve = np.clip(gain_curve, 0.3, 3.0)

        adjusted_spec = source_spec * gain_curve
        result = librosa.istft(adjusted_spec * np.exp(1j * np.angle(librosa.stft(source))))

        if len(result) != len(source):
            result = librosa.util.fix_length(result, size=len(source))

        return result
    except:
        return source

def intelligent_crossfade(seg1, seg2, sr=SAMPLE_RATE, fade_sec=2.5):
    """Smart crossfade that preserves transients and matches energy"""
    fade_samps = int(fade_sec * sr)
    fade_samps = min(fade_samps, len(seg1), len(seg2))

    if fade_samps <= 0:
        return np.concatenate([seg1, seg2])

    # Extract overlap regions
    seg1_end = seg1[-fade_samps:]
    seg2_start = seg2[:fade_samps]

    # Detect transients to preserve attacks
    try:
        onset1 = librosa.onset.onset_detect(y=seg1_end, sr=sr, units='samples')
        onset2 = librosa.onset.onset_detect(y=seg2_start, sr=sr, units='samples')
    except:
        onset1, onset2 = [], []

    # Create adaptive fade curves
    fade_out = np.linspace(1, 0, fade_samps) ** 1.5
    fade_in = np.linspace(0, 1, fade_samps) ** 1.5

    # Preserve transients by reducing fade at onset points
    for o in onset1:
        if 0 <= o < fade_samps:
            window = int(0.05 * sr)
            start = max(0, o - window//2)
            end = min(fade_samps, o + window//2)
            fade_out[start:end] = np.maximum(fade_out[start:end], 0.7)

    for o in onset2:
        if 0 <= o < fade_samps:
            window = int(0.05 * sr)
            start = max(0, o - window//2)
            end = min(fade_samps, o + window//2)
            fade_in[start:end] = np.maximum(fade_in[start:end], 0.7)

    # Match spectral characteristics
    seg2_matched = dynamic_eq_match(seg2_start, seg1_end, sr=sr)

    # Crossfade
    mixed = seg1_end * fade_out + seg2_matched * fade_in

    # Normalize to prevent clipping
    peak = np.max(np.abs(mixed))
    if peak > 0:
        mixed = mixed / peak * 0.95

    result = np.concatenate([seg1[:-fade_samps], mixed, seg2[fade_samps:]])
    return result

def overlap_add_intelligent(segments, sr=SAMPLE_RATE, crossfade_sec=2.5):
    """Improved overlap-add with intelligent crossfading"""
    if not segments: return None
    proc = []
    for s in segments:
        arr = s
        if arr.ndim == 1: arr = arr.reshape(-1,1)
        if arr.shape[1] == 1: arr = np.column_stack([arr[:,0], arr[:,0]])
        proc.append(arr)

    if len(proc) == 1:
        return proc[0]

    # Process left and right channels separately for better stereo
    result_L = proc[0][:,0]
    result_R = proc[0][:,1]

    for i in range(1, len(proc)):
        seg_L = proc[i][:,0]
        seg_R = proc[i][:,1]

        result_L = intelligent_crossfade(result_L, seg_L, sr=sr, fade_sec=crossfade_sec)
        result_R = intelligent_crossfade(result_R, seg_R, sr=sr, fade_sec=crossfade_sec)

    result = np.column_stack([result_L, result_R])
    peak = np.max(np.abs(result)) + 1e-9
    return result / peak * 0.95

def fade_in_out(n, sr=SAMPLE_RATE, fi=0.12, fo=0.5):
    env = np.ones(n, dtype=np.float32)
    fi_samps = int(min(n, fi*sr)); fo_samps = int(min(n, fo*sr))
    if fi_samps > 0: env[:fi_samps] = np.linspace(0.0, 1.0, fi_samps) ** 2
    if fo_samps > 0: env[-fo_samps:] = np.linspace(1.0, 0.0, fo_samps) ** 2
    return env

def stereo_widen(audio, amount=0.22):
    if audio.ndim == 2 and audio.shape[1] == 2: return audio
    L = audio
    # Haas effect with slight delay
    delay = int(0.015 * SAMPLE_RATE)
    if delay >= len(audio):
        R = audio * (1.0 - amount)
    else:
        R = np.concatenate([np.zeros(delay), audio[:-delay]]) * (1.0 - amount)
    stereo = np.column_stack([L, R])
    maxv = np.max(np.abs(stereo)) + 1e-9
    stereo = stereo / maxv * 0.95
    return stereo

def simple_master(stereo, sr=SAMPLE_RATE):
    if stereo is None: return None
    y = stereo.copy()
    mono = 0.5*(y[:,0] + y[:,1])

    # Soft saturation
    y[:,0] = np.tanh(y[:,0] * 1.05)
    y[:,1] = np.tanh(y[:,1] * 1.05)

    # Dynamic compression
    window = int(0.05*sr)
    window = max(1, window)
    env = np.sqrt(np.convolve(mono**2, np.ones(window)/window, mode='same') + 1e-12)
    threshold = 0.15
    ratio = 2.0
    over = env - threshold
    gain = np.where(over > 0, 1.0 / (1.0 + (ratio-1.0)*(over/(threshold+1e-9))), 1.0)

    for ch in range(2):
        y[:,ch] *= gain

    # Final loudness normalization
    final_mono = normalize_to_lufs(0.5*(y[:,0]+y[:,1]), target_lufs=-16.0)
    side = (y[:,0] - y[:,1]) * 0.5
    stereo_final = np.column_stack([final_mono + side, final_mono - side])

    peak = np.max(np.abs(stereo_final)) + 1e-9
    stereo_final = stereo_final / peak * 0.95
    return stereo_final

# -------------------------- GPU FEATURE EXTRACTION --------------------------
GPU_CHROMA_NFFT = N_FFT
GPU_CHROMA_HOP = HOP_LENGTH
GPU_CHROMA_WINDOW_SEC = 6.0
GPU_CHROMA_HOP_SEC = 2.0
GPU_MAX_WINDOWS = 12
GPU_USE_CACHE = True
GPU_CHROMA_BINS = 12
_torch_device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
_stft_window = torch.hann_window(GPU_CHROMA_NFFT, device=_torch_device)

def build_chroma_filterbank(sr=SAMPLE_RATE, n_fft=GPU_CHROMA_NFFT, n_chroma=12):
    freqs = np.fft.rfftfreq(n_fft, 1.0/sr)
    C0 = 16.35
    chroma_mat = np.zeros((n_chroma, len(freqs)), dtype=np.float32)
    for i,f in enumerate(freqs):
        if f < 30: continue
        s = 12.0 * np.log2(f / C0)
        k = int(np.round(s)) % n_chroma
        w = np.exp(-0.5 * (s - np.round(s))**2 / (0.6**2))
        chroma_mat[k, i] = w
    chroma_mat = chroma_mat / (chroma_mat.sum(axis=1, keepdims=True)+1e-12)
    return torch.from_numpy(chroma_mat).to(_torch_device)

_chroma_filterbank = build_chroma_filterbank(sr=SAMPLE_RATE, n_fft=GPU_CHROMA_NFFT, n_chroma=GPU_CHROMA_BINS)

def waveform_to_power(y_np, sr=SAMPLE_RATE):
    y = torch.from_numpy(y_np).to(_torch_device)
    if y.dim() == 1: y = y.unsqueeze(0)
    spec = torch.stft(y, n_fft=GPU_CHROMA_NFFT, hop_length=GPU_CHROMA_HOP, win_length=GPU_CHROMA_NFFT, window=_stft_window, center=True, return_complex=True)
    power = (spec.abs() ** 2).squeeze(0)
    return power

def power_to_chroma(power_spec):
    chroma = torch.matmul(_chroma_filterbank, power_spec)
    chroma = chroma / (chroma.norm(dim=0, keepdim=True) + 1e-9)
    return chroma

def power_to_centroid(power_spec, sr=SAMPLE_RATE, n_fft=GPU_CHROMA_NFFT):
    freqs = torch.fft.rfftfreq(n_fft, 1.0/sr).to(_torch_device)
    num = (freqs.unsqueeze(1) * power_spec).sum(dim=0)
    den = power_spec.sum(dim=0) + 1e-12
    centroid = (num/den).cpu().numpy()
    return centroid

def rms_frames(y_np, frame_length=GPU_CHROMA_NFFT, hop_length=GPU_CHROMA_HOP):
    y = torch.from_numpy(y_np).to(_torch_device)
    spec = torch.stft(y, n_fft=frame_length, hop_length=hop_length, win_length=frame_length, window=_stft_window, center=True, return_complex=True)
    mag = spec.abs()
    rms = torch.sqrt((mag**2).mean(dim=1)).squeeze(0).cpu().numpy()
    return rms

_gpu_track_cache = {}
def compute_track_gpu_features(tid, force=False):
    if (not GPU_USE_CACHE) or (not force and tid in _gpu_track_cache):
        return _gpu_track_cache.get(tid, None)
    entry = music_emotion_db.get(tid)
    path = entry.get('audio_path')
    if not path or not os.path.isfile(path):
        return None
    y, sr = sf.read(path, dtype='float32')
    if y.ndim > 1: y = y.mean(axis=1)
    if sr != SAMPLE_RATE:
        y_t = torch.from_numpy(y).unsqueeze(0)
        y_t = torchaudio.functional.resample(y_t, orig_freq=sr, new_freq=SAMPLE_RATE).squeeze(0).cpu().numpy()
        y = y_t; sr = SAMPLE_RATE
    if len(y) < int(GPU_CHROMA_WINDOW_SEC * sr):
        padlen = int(GPU_CHROMA_WINDOW_SEC * sr) - len(y)
        y = np.pad(y, (0, padlen))
    power = waveform_to_power(y, sr=sr)
    chroma = power_to_chroma(power)
    centroid = power_to_centroid(power, sr=sr)
    rms = rms_frames(y, frame_length=GPU_CHROMA_NFFT, hop_length=GPU_CHROMA_HOP)
    cache = {"power": power, "chroma": chroma, "centroid": centroid, "rms": rms, "length": len(y), "sr": sr}
    _gpu_track_cache[tid] = cache
    return cache

def fast_pick_best_segments(anchor_tid, candidate_tid, top_k=2, segment_duration=20.0, max_windows=GPU_MAX_WINDOWS):
    acache = compute_track_gpu_features(anchor_tid)
    ccache = compute_track_gpu_features(candidate_tid)
    if acache is None or ccache is None: return []
    anchor_chroma = acache["chroma"].mean(dim=1).cpu().numpy()
    anchor_centroid = float(np.mean(acache["centroid"]))
    anchor_rms = float(np.mean(acache["rms"]))
    cand_chroma = ccache["chroma"]
    cand_centroid = ccache["centroid"]
    cand_rms = ccache["rms"]
    frames_per_sec = SAMPLE_RATE / GPU_CHROMA_HOP
    win_frames = int(segment_duration * frames_per_sec)
    total_frames = cand_chroma.shape[1]
    if total_frames <= win_frames:
        y = sf.read(music_emotion_db[candidate_tid]["audio_path"], dtype='float32')[0]
        if y.ndim>1: y = y.mean(axis=1)
        return [(0.0, min(len(y)/SAMPLE_RATE, segment_duration), y[:int(segment_duration*SAMPLE_RATE)], 0.5)]
    max_possible = max(1, total_frames - win_frames + 1)
    if max_possible <= max_windows:
        starts = list(range(0, max_possible))
    else:
        idxs = np.linspace(0, max_possible-1, num=max_windows, dtype=int)
        starts = idxs.tolist()
    scores = []
    for s in starts:
        e = s + win_frames
        if e > total_frames:
            e = total_frames
            s = e - win_frames
        window_chroma = cand_chroma[:, s:e]
        wmean = window_chroma.mean(dim=1)
        wmean_np = wmean.cpu().numpy()
        chroma_sim = float(np.dot(anchor_chroma, wmean_np) / (np.linalg.norm(anchor_chroma) * (np.linalg.norm(wmean_np)+1e-12)))
        chroma_sim = (chroma_sim + 1.0) / 2.0
        seg_cent = float(np.mean(cand_centroid[max(0, s):min(len(cand_centroid), e)]))
        cent_sim = math.exp(-abs(anchor_centroid - seg_cent) / max(1.0, anchor_centroid) * 1.2)
        seg_rms = float(np.mean(cand_rms[max(0, int(s/frames_per_sec)): min(len(cand_rms), int(e/frames_per_sec))]))
        rms_sim = math.exp(-abs(anchor_rms - seg_rms) / max(1e-6, anchor_rms) * 1.2)
        score = 0.6*chroma_sim + 0.25*cent_sim + 0.15*rms_sim
        start_sec = (s * GPU_CHROMA_HOP) / SAMPLE_RATE
        end_sec = min((e * GPU_CHROMA_HOP) / SAMPLE_RATE, ccache["length"]/SAMPLE_RATE)
        scores.append((score, start_sec, end_sec))
    scores.sort(key=lambda x:x[0], reverse=True)
    selected = []
    used_ranges = []
    for sc, ssec, esec in scores:
        overlap = False
        for us, ue in used_ranges:
            if not (esec <= us or ssec >= ue):
                overlap = True
                break
        if not overlap:
            y_full = sf.read(music_emotion_db[candidate_tid]["audio_path"], dtype='float32')[0]
            if y_full.ndim>1: y_full = y_full.mean(axis=1)
            start_sample = int(ssec * SAMPLE_RATE)
            end_sample = int(esec * SAMPLE_RATE)
            segy = y_full[start_sample:end_sample]
            selected.append((ssec, esec, segy, sc))
            used_ranges.append((ssec, esec))
        if len(selected) >= top_k:
            break
    return selected

# -------------------------- IMPROVED GENERATION --------------------------
def generate_improved_soundtrack(text_segments, candidate_ids_per_segment, db, sr=SAMPLE_RATE,
                                 segment_duration=15.0, blend_k=4, target_lufs=-16.0,
                                 enable_key_shift=True, enable_tempo=True, enable_hpss=True,
                                 use_gpu_fast=DEFAULT_USE_GPU_FAST, show_progress=True):
    out_segments = []
    for seg_idx, (text, cand_ids) in enumerate(zip(text_segments, candidate_ids_per_segment)):
        print(f"\nüé¨ Segment {seg_idx+1}/{len(text_segments)} ‚Äî '{text[:120]}...'")
        if not cand_ids:
            print(" no candidates, skip")
            continue
        anchor_id = cand_ids[0]
        anchor_entry = db.get(anchor_id)
        if not anchor_entry:
            print(" anchor missing, skip")
            continue
        anchor_path = anchor_entry['audio_path']
        anchor_y = sf.read(anchor_path, dtype='float32')[0]
        if anchor_y.ndim>1: anchor_y = anchor_y.mean(axis=1)
        if anchor_y is None:
            print(" anchor load fail")
            continue
        anchor_tempo = 90.0
        try:
            onset = librosa.onset.onset_strength(y=anchor_y, sr=sr, hop_length=HOP_LENGTH)
            tt = librosa.beat.tempo(onset_envelope=onset, sr=sr)
            anchor_tempo = float(tt[0]) if hasattr(tt, "__len__") else float(tt)
        except:
            pass
        anchor_chroma = compute_chroma(anchor_y, sr=sr)
        anchor_key, anchor_mode = detect_key(anchor_chroma)
        anchor_label = decode_track_id(anchor_id)
        print(f" üéµ Anchor: [{anchor_label}] | tempo {anchor_tempo:.1f} BPM | key {anchor_key} {anchor_mode}")

        candidate_segments = []
        iterator = cand_ids[:min(len(cand_ids), max(blend_k, 12))]
        if show_progress:
            iterator = tqdm(iterator, desc="Loading candidates")
        for tid in iterator:
            entry = db.get(tid)
            path = entry.get('audio_path') if entry else None
            if path is None or not os.path.isfile(path): continue
            y = sf.read(path, dtype='float32')[0]
            if y.ndim>1: y = y.mean(axis=1)

            if use_gpu_fast and isinstance(tid, int):
                try:
                    segs = fast_pick_best_segments(anchor_id, tid, top_k=2, segment_duration=min(segment_duration,30.0))
                    if segs:
                        segs = [(s,e,seg,score) for (s,e,seg,score) in segs]
                    else:
                        segs = [(0.0, min(len(y)/sr, segment_duration), y[:int(min(len(y), int(segment_duration*sr)))], 0.15)]
                except Exception:
                    segs = [(0.0, min(len(y)/sr, segment_duration), y[:int(min(len(y), int(segment_duration*sr)))], 0.15)]
            else:
                segs = [(0.0, min(len(y)/sr, segment_duration), y[:int(min(len(y), int(segment_duration*sr)))], 0.15)]

            if not segs:
                segs = [(0.0, min(len(y)/sr, segment_duration), y[:int(min(len(y), int(segment_duration*sr)))], 0.15)]
            candidate_segments.append({'tid': tid, 'path': path, 'full_y': y, 'segs': segs})

        if not candidate_segments:
            print(" no loaded candidate segments")
            continue

        chosen = []
        scores = []
        for c in candidate_segments:
            best = max(c['segs'], key=lambda x: x[3])
            sstart, send, seg_y, score = best
            scores.append(score)
            chosen.append({'tid': c['tid'], 'seg_y': seg_y, 'score': score, 'path': c['path']})

        # Weighted selection
        arr = np.array(scores, dtype=np.float64)
        arr = np.maximum(arr, 1e-9)
        ex = np.exp(arr / max(1.0, np.std(arr)))
        weights = ex / (ex.sum() + 1e-12)

        for i, c in enumerate(chosen):
            c['weight'] = float(weights[i])

        for c in chosen:
            if c['tid'] == anchor_id:
                c['weight'] += 0.12

        ssum = sum(c['weight'] for c in chosen)
        if ssum <= 0:
            for c in chosen: c['weight'] = 1.0/len(chosen)
        else:
            for c in chosen: c['weight'] = float(c['weight'] / (ssum + 1e-12))

        # Sort by weight and show top choices
        chosen_sorted = sorted(chosen, key=lambda x: x['weight'], reverse=True)
        print(f" üéØ Blending top {blend_k} tracks (weighted):")
        for i, c in enumerate(chosen_sorted[:blend_k]):
            track_label = decode_track_id(c['tid'])
            print(f"    {i+1}. [{track_label:18s}] weight: {c['weight']:.3f} | score: {c['score']:.3f}")

        processed = []
        for c in chosen_sorted[:blend_k]:
            seg = c['seg_y']

            if enable_key_shift:
                seg_chroma = compute_chroma(seg, sr=sr)
                seg_key, seg_mode = detect_key(seg_chroma)
                n_steps = semitone_diff(seg_key, anchor_key)
                if n_steps != 0:
                    try:
                        seg = pitch_shift(seg, sr, -n_steps)
                    except:
                        pass

            if enable_tempo:
                seg_tempo = 90.0
                try:
                    onset = librosa.onset.onset_strength(y=seg, sr=sr, hop_length=HOP_LENGTH)
                    tt = librosa.beat.tempo(onset_envelope=onset, sr=sr)
                    seg_tempo = float(tt[0]) if hasattr(tt, "__len__") else float(tt)
                except:
                    seg_tempo = 90.0
                if seg_tempo > 0 and anchor_tempo > 0:
                    rate = anchor_tempo / seg_tempo
                    rate = np.clip(rate, 0.85, 1.15)  # More conservative stretching
                    try:
                        seg = time_stretch(seg, rate)
                    except:
                        pass

            if enable_hpss:
                try:
                    h, p = harmonic_percussive(seg)
                except:
                    h, p = seg, np.zeros_like(seg)
            else:
                h, p = seg, np.zeros_like(seg)

            try:
                h = normalize_to_lufs(h, target_lufs=target_lufs - 2.0)
            except:
                h = h
            try:
                p = normalize_to_lufs(p, target_lufs=target_lufs - 8.0)
            except:
                p = p

            processed.append((h, p, c['weight']))

        if not processed: continue

        min_len = min(len(h) for h,_,_ in processed)
        min_len = max(min_len, int(6.0*sr))
        trimmed = []
        for h,p,w in processed:
            if len(h) < min_len:
                h = np.pad(h, (0, min_len - len(h)))
                p = np.pad(p, (0, min_len - len(p)))
            else:
                h = h[:min_len]
                p = p[:min_len]
            trimmed.append((h,p,w))

        harmonic_mix = np.zeros(min_len, dtype=np.float32)
        percussive_mix = np.zeros(min_len, dtype=np.float32)
        for h,p,w in trimmed:
            harmonic_mix += h * float(w)
            percussive_mix += p * (0.15 * float(w))

        if np.max(np.abs(harmonic_mix)) > 0:
            harmonic_mix = harmonic_mix / (np.max(np.abs(harmonic_mix)) + 1e-9) * 0.92
        if np.max(np.abs(percussive_mix)) > 0:
            percussive_mix = percussive_mix / (np.max(np.abs(percussive_mix)) + 1e-9) * 0.92

        combined = harmonic_mix * 0.88 + percussive_mix * 0.18

        env = fade_in_out(len(combined), sr=sr, fi=0.15, fo=0.6)
        combined = combined * env

        stereo = stereo_widen(combined, amount=0.22)

        mono = 0.5*(stereo[:,0] + stereo[:,1])
        mono = normalize_to_lufs(mono, target_lufs=target_lufs)
        stereo = np.column_stack([mono, mono])

        out_segments.append(stereo)
        print(f" ‚úÖ Segment {seg_idx+1} generated: {stereo.shape[0]/sr:.2f}s")

    if not out_segments: return None

    print(f"\nüéöÔ∏è  Joining {len(out_segments)} segments with intelligent crossfades...")
    joined = overlap_add_intelligent(out_segments, sr=SAMPLE_RATE, crossfade_sec=3.0)

    print("üéõÔ∏è  Applying final mastering...")
    mastered = simple_master(joined, sr=SAMPLE_RATE)

    return mastered

# -------------------------- LIBROSA CACHE FOR REPORT --------------------------
def compute_and_cache_librosa_features(paths, cache_dir=FEATURE_CACHE_DIR, force=False):
    safe_mkdir(cache_dir)
    results = {}
    for i,p in tqdm(enumerate(paths), desc="librosa feature cache", total=len(paths)):
        key = str(Path(p).resolve())
        cache_file = os.path.join(cache_dir, f"{abs(hash(key))}.npz")
        if os.path.exists(cache_file) and not force:
            try:
                d = np.load(cache_file)
                results[p] = dict(d)
                continue
            except:
                pass
        try:
            y, sr = librosa.load(p, sr=SAMPLE_RATE, mono=True, duration=30.0)
            centroid = float(np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)))
            contrast = float(np.mean(librosa.feature.spectral_contrast(y=y, sr=sr)))
            chroma = np.mean(librosa.feature.chroma_stft(y=y, sr=sr, n_fft=N_FFT, hop_length=HOP_LENGTH), axis=1)
            tempo = float(librosa.beat.tempo(y=y, sr=sr).mean()) if hasattr(librosa.beat.tempo(y=y, sr=sr), "__len__") else float(librosa.beat.tempo(y=y, sr=sr))
            results[p] = {"centroid": centroid, "contrast": contrast, "tempo": tempo, "chroma": chroma}
            np.savez_compressed(cache_file, centroid=centroid, contrast=contrast, tempo=tempo, chroma=chroma)
        except Exception as e:
            results[p] = {"centroid": 0.0, "contrast": 0.0, "tempo":0.0, "chroma": np.ones(12)/12.0}
            print("librosa feature fail for", p, e)
    return results

# -------------------------- REPORT --------------------------
def extended_sanity_and_report(state, sample_n=SAMPLE_REPORT_N, use_cache=True):
    embs = state.get('embs')
    ids = state.get('ids')
    paths = state.get('paths')
    hist = state.get('history')
    if embs is None:
        print("No embeddings yet. Run extraction first.")
        return
    plt.figure(figsize=(8,4))
    if hist:
        plt.plot(hist.get('train_loss',[]), label='train')
        plt.plot(hist.get('val_loss',[]), label='val')
        plt.title("Loss")
        plt.legend()
        plt.grid()
        plt.show()
    p = PCA(n_components=min(50, embs.shape[1])).fit(embs)
    plt.figure(figsize=(6,4))
    plt.plot(np.cumsum(p.explained_variance_ratio_))
    plt.title("Cumulative PCA variance")
    plt.grid()
    plt.show()
    sample_idx = np.random.choice(len(embs), size=min(sample_n, len(embs)), replace=False)
    ts = TSNE(n_components=2, perplexity=30, learning_rate=200, init="pca").fit_transform(embs[sample_idx])
    plt.figure(figsize=(7,6))
    plt.scatter(ts[:,0], ts[:,1], s=12, alpha=0.7)
    plt.title("t-SNE of embeddings (sample)")
    plt.show()
    norms = np.linalg.norm(embs, axis=1)
    plt.figure(figsize=(6,4))
    plt.hist(norms, bins=50, density=True)
    plt.title("Embedding L2 norm distribution")
    plt.show()
    from sklearn.metrics.pairwise import cosine_similarity
    sim = cosine_similarity(embs[:80])
    plt.figure(figsize=(8,6))
    sns.heatmap(sim, cmap="viridis")
    plt.title("Cosine sim (first 80)")
    plt.show()
    sample_paths = [paths[i] for i in sample_idx]
    if use_cache:
        features = compute_and_cache_librosa_features(sample_paths)
    else:
        features = compute_and_cache_librosa_features(sample_paths, force=True)
    feats = []
    for p in sample_paths:
        d = features.get(p, {"centroid":0.0,"contrast":0.0,"tempo":0.0})
        feats.append([d["centroid"], d["contrast"], d["tempo"]])
    feats = np.array(feats)
    if feats.shape[0] > 2:
        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=min(6, max(2, feats.shape[0]//30)), random_state=0).fit(feats)
        plt.figure(figsize=(7,5))
        plt.scatter(feats[:,0], feats[:,1], c=kmeans.labels_, cmap="tab10", s=12)
        plt.title("Librosa feature clusters (centroid vs contrast)")
        plt.xlabel("centroid")
        plt.ylabel("contrast")
        plt.show()

    # KNN examples
    def show_knn(track_idx, k=5):
        q = embs[track_idx:track_idx+1].astype('float32')
        if faiss is not None and state.get('index') is not None:
            faiss.normalize_L2(q)
            D,I = state['index'].search(q, k+1)
            neigh = I[0][1:]
            scores = D[0][1:]
        else:
            sims = embs @ q.T / (np.linalg.norm(embs,axis=1)*(np.linalg.norm(q)+1e-12))
            order = sims.flatten().argsort()[::-1]
            neigh = order[1:k+1]
            scores = sims.flatten()[neigh]
        print(f"\nNeighbors for {ids[track_idx]} ({os.path.basename(paths[track_idx])}):")
        for r,(nid,sc) in enumerate(zip(neigh,scores)):
            print(f" {r+1}. ID {ids[int(nid)]} - {os.path.basename(paths[int(nid)])} (score {float(sc):.3f})")

    for i in np.random.choice(len(embs), size=min(6,len(embs)), replace=False):
        show_knn(i,5)

    print("\nReport generation complete. (librosa-cached features used)" if use_cache else "\nReport generation complete.")

# -------------------------- UI --------------------------
_state = {'backbone': None, 'embs': None, 'ids': None, 'paths': None, 'index': None, 'history': None}
text_input = widgets.Textarea(value="", placeholder="Paste text here, paragraphs separated by newlines", description="Text:", layout=widgets.Layout(width='95%', height='120px'))
seg_dur_slider = widgets.IntSlider(value=15, min=5, max=60, step=5, description='Seg sec')
blend_slider = widgets.IntSlider(value=4, min=1, max=8, step=1, description='Blend tracks')
lufs_slider = widgets.IntSlider(value=-16, min=-30, max=-6, step=1, description='Target LUFS')
train_button = widgets.Button(description='Train CNN', button_style='primary')
extract_button = widgets.Button(description='Extract Embs', button_style='')
generate_button = widgets.Button(description='Generate Soundtrack', button_style='success')
report_button = widgets.Button(description='Run Sanity & Report', button_style='')
output = widgets.Output(layout={'border': '1px solid black'})

# Auto-load best model if present
if os.path.exists(BEST_PATH):
    try:
        chk = torch.load(BEST_PATH, map_location=DEVICE)
        model = ImprovedResNet().to(DEVICE)
        model.load_state_dict(chk['model_state_dict'])
        _state['backbone'] = model
        _state['history'] = chk.get('hist', None)
        print("‚úÖ Auto-loaded best model from", BEST_PATH)
    except Exception as e:
        print("Failed to auto-load best model:", e)

def on_train(b):
    with output:
        clear_output()
        print("üöÄ Starting training with improved ResNet architecture...")
        model, history = train_model(music_emotion_db, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, checkpoint_path=CHECKPOINT_PATH, best_path=BEST_PATH)
        _state['backbone'] = model
        _state['history'] = history
        print("‚úÖ Training finished and saved.")

def on_extract(b):
    with output:
        clear_output()
        if _state.get('backbone') is None:
            print("‚ö†Ô∏è  No trained model in memory. Loading best checkpoint if available...")
            if os.path.exists(BEST_PATH):
                chk = torch.load(BEST_PATH, map_location=DEVICE)
                model = ImprovedResNet().to(DEVICE)
                model.load_state_dict(chk['model_state_dict'])
                _state['backbone'] = model
                print("‚úÖ Loaded best model")
            else:
                print("‚ùå No model available. Train first.")
                return
        model = _state['backbone']
        print("üîç Extracting embeddings with PCA dimensionality reduction...")
        embs, ids, paths = extract_embeddings(model, music_emotion_db, batch_size=64, apply_pca=True, pca_dim=128)
        _state['embs'] = embs
        _state['ids'] = ids
        _state['paths'] = paths
        if faiss is not None and embs.shape[0]>0:
            idx = build_faiss_index(embs)
            _state['index'] = idx
            print(f"‚úÖ FAISS index built with {idx.ntotal} tracks in {embs.shape[1]}D space")
        else:
            print("‚ö†Ô∏è  FAISS not installed or empty embeddings.")
        print(f"‚úÖ Extraction complete. {len(ids)} tracks embedded.")

def on_generate(b):
    with output:
        clear_output()
        text = text_input.value.strip()
        if not text:
            print("‚ùå Enter text first.")
            return
        segments = [p.strip() for p in text.split('\n') if len(p.strip())>3]
        if not segments:
            segments = [text]
        print(f"üìù Processing {len(segments)} text segments...")

        candidate_lists = []
        for seg in segments:
            print(f"\n{'='*80}")
            cand_ids = retrieve_candidates_for_text(seg, _state.get('index'), _state.get('ids'), _state.get('embs'), top_k=40, print_scores=True)
            candidate_lists.append(cand_ids or [])

        print(f"\n{'='*80}")
        print("üéµ Starting soundtrack generation...")
        final = generate_improved_soundtrack(
            segments,
            candidate_lists,
            music_emotion_db,
            sr=SAMPLE_RATE,
            segment_duration=seg_dur_slider.value,
            blend_k=blend_slider.value,
            target_lufs=lufs_slider.value,
            use_gpu_fast=DEFAULT_USE_GPU_FAST and use_librosa_for_report_only,
            show_progress=True
        )

        if final is None:
            print("‚ùå Failed to generate soundtrack.")
            return

        out_path = os.path.join(OUTPUT_DIR, "final_soundtrack_stereo.wav")
        sf.write(out_path, final, SAMPLE_RATE)
        print(f"\n{'='*80}")
        print(f"‚úÖ Soundtrack saved to: {out_path}")
        print(f"   Duration: {final.shape[0]/SAMPLE_RATE:.2f}s")
        print(f"   Channels: {final.shape[1]}")
        display(Audio(out_path, rate=SAMPLE_RATE))

def on_report(b):
    with output:
        clear_output()
        print("üìä Running extended sanity & report plots...")
        print("   (This may take a minute depending on sample size)")
        extended_sanity_and_report(_state, sample_n=SAMPLE_REPORT_N, use_cache=True)
        print("‚úÖ Report complete.")

train_button.on_click(on_train)
extract_button.on_click(on_extract)
generate_button.on_click(on_generate)
report_button.on_click(on_report)

ui = widgets.VBox([
    text_input,
    widgets.HBox([seg_dur_slider, blend_slider, lufs_slider]),
    widgets.HBox([train_button, extract_button, generate_button, report_button]),
    output
])
display(ui)

print("="*80)
print("üéµ IMPROVED MUSIC GENERATION PIPELINE READY")
print("="*80)
print("üìã Workflow:")
print("   1. Train CNN      ‚Üí Train emotion recognition model")
print("   2. Extract Embs   ‚Üí Extract embeddings for all tracks")
print("   3. Generate       ‚Üí Create soundtrack from text")
print("   4. Report         ‚Üí View analysis & visualizations")
print("="*80)
if SEMANTIC_AVAILABLE:
    print("‚úÖ Semantic text-to-emotion mapping enabled")
else:
    print("‚ö†Ô∏è  sentence-transformers not available, using fallback")
print("="*80)

Saving kaggle.json to kaggle (1).json
üì¶ Downloading EMOTIFY dataset (may use cache)...
Using Colab cache for faster access to the 'emotify-emotion-classificaiton-in-songs' dataset.
‚úÖ Dataset download reference: /kaggle/input/emotify-emotion-classificaiton-in-songs
mv: cannot stat '/root/.cache/kagglehub/datasets/yash9439/emotify-emotion-classificaiton-in-songs/versions/1': No such file or directory
Device: cuda
Loaded emotion vectors for 400 tracks from CSV
Discovered 400 audio files under /content/emotify_dataset
Precomputing mel cache (skip if done)...


Caching mels:   0%|          | 0/400 [00:00<?, ?it/s]

Mel cache computed: 400 files cached to /content/mel_cache


VBox(children=(Textarea(value='', description='Text:', layout=Layout(height='120px', width='95%'), placeholder‚Ä¶

üéµ IMPROVED MUSIC GENERATION PIPELINE READY
üìã Workflow:
   1. Train CNN      ‚Üí Train emotion recognition model
   2. Extract Embs   ‚Üí Extract embeddings for all tracks
   3. Generate       ‚Üí Create soundtrack from text
   4. Report         ‚Üí View analysis & visualizations
‚úÖ Semantic text-to-emotion mapping enabled
