In [None]:
import sys
sys.path.append("/home/luke/VIU/09MIAR/euterpe")

import training.main as main

# Inputs for the training run: the datasets directory and the CSV that
# (per its name) lists the valid files.
datasets_path = "/home/luke/VIU/09MIAR/datasets"
valid_files_csv_path = "/home/luke/valid_files.csv"

# Hand both paths to the project's GAN training entry point.
main.train_gan(datasets_path, valid_files_csv_path)

In [None]:
import sys
sys.path.append("/home/luke/VIU/09MIAR/euterpe")

import os
import json
from datetime import datetime

import torch
import librosa
import librosa.display
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from scipy.ndimage import gaussian_filter, median_filter
from scipy.signal import lfilter

from training.config import Config
from training.gan.gan import GAN
from training.gan.GAN_AI_model_wrapper import GANAIModelWrapper
from training.main import get_dataloader


def listar_generos_disponibles(valid_files_csv_path):
    """Return the sorted distinct genre values found in the valid-files CSV.

    The CSV is read with ``;`` as the field separator. The first column that
    exists among ``genre_id``, ``genre`` and ``label`` (checked in that
    order) is used as the genre column.

    Missing values are dropped before sorting: a NaN mixed with string
    genres cannot be ordered (``sorted`` raises ``TypeError``) and a NaN is
    never a usable genre anyway.

    Args:
        valid_files_csv_path: Path to the semicolon-separated CSV file.

    Returns:
        Sorted list with the unique, non-missing values of the genre column.

    Raises:
        ValueError: If none of the candidate columns is present in the CSV.
    """
    df = pd.read_csv(valid_files_csv_path, sep=";")
    for col in ("genre_id", "genre", "label"):
        if col in df.columns:
            return sorted(df[col].dropna().unique().tolist())
    raise ValueError("No se encontró ninguna columna de género válida en el CSV.")


def normalizar_audio(audio, target_dbfs=-20.0):
    """Scale ``audio`` so that its RMS level matches ``target_dbfs``.

    The target level in dBFS is converted to a linear amplitude
    (``10 ** (target_dbfs / 20)``) and divided by the signal's current RMS.
    A small epsilon in the denominator avoids a division by zero for an
    all-zero signal.

    Args:
        audio: Array of audio samples.
        target_dbfs: Desired RMS level in dBFS.

    Returns:
        The input multiplied by the computed gain.
    """
    target_amplitude = 10 ** (target_dbfs / 20)
    current_rms = np.sqrt(np.mean(audio**2))
    gain = target_amplitude / (current_rms + 1e-9)
    return audio * gain


def reconstruir_y_visualizar(
    model_path: str,
    genre_id: int,
    num_samples: int = 1,
    output_dir: str = "samples",
    datasets_path: str = "/home/luke/VIU/09MIAR/datasets",
    valid_files_csv_path: str = "/home/luke/valid_files.csv",
    max_retries_per_sample: int = 10
) -> None:
    """Generate spectrograms with a saved GAN and export them as PNG and WAV.

    Steps, in order:
      1. Load ``config.json`` from two directory levels above ``model_path``
         and overlay its values on a fresh ``Config``.
      2. Build the model, load the checkpoint entries whose name and shape
         match the model, and switch to eval mode.
      3. Draw one latent vector per sample via
         ``model.generate_z_from_real_sample`` (retrying on ``RuntimeError``),
         falling back to ``GANAIModelWrapper.generate_z_for_genre``.
      4. Run the generator, map its output to dB, smooth it, drop the
         lowest-energy frames, save a plot, invert to audio and write a
         stereo WAV.

    Output files are named
    ``<timestamp>_GAN_<KIND_OF_SPECTROGRAM>_g<genre>_s<index>.{png,wav}``.

    Args:
        model_path: Path to the checkpoint file.
        genre_id: Genre to condition on. If it is not among the genres
            listed in the CSV, the first available genre is used instead
            (a warning is emitted).
        num_samples: Number of samples to generate; nothing is generated
            when it is zero or negative.
        output_dir: Directory for the PNG/WAV files (created if missing).
        datasets_path: Dataset root forwarded to ``get_dataloader``.
        valid_files_csv_path: CSV forwarded to ``get_dataloader`` and used
            to list the available genres.
        max_retries_per_sample: Attempts at sampling a latent vector from a
            real sample before using the fallback generator.

    Raises:
        ValueError: If the CSV yields no genres at all.
    """
    import warnings

    os.makedirs(output_dir, exist_ok=True)
    if num_samples <= 0:
        # Nothing to generate; avoids loading the model only to fail later
        # in torch.cat on an empty list.
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The training config is stored two directory levels above the checkpoint.
    config_path = os.path.join(os.path.dirname(os.path.dirname(model_path)), "config.json")
    with open(config_path, encoding="utf-8") as f:
        saved_cfg = json.load(f)
    cfg = Config()
    cfg.__dict__.update(saved_cfg)

    generos_disponibles = listar_generos_disponibles(valid_files_csv_path)
    if not generos_disponibles:
        raise ValueError("No hay géneros disponibles en el CSV.")
    if genre_id not in generos_disponibles:
        warnings.warn(
            f"genre_id={genre_id!r} no está en el CSV; "
            f"se usará {generos_disponibles[0]!r}.",
            stacklevel=2,
        )
        genre_id = generos_disponibles[0]

    model = GANAIModelWrapper(GAN(), is_eval=True).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True if the checkpoint allows it).
    checkpoint = torch.load(model_path, map_location=device)
    # Keep only entries that exist in the model with an identical shape.
    # state_dict() is built once instead of once per checkpoint key.
    expected_state = model.state_dict()
    filtered = {
        k: v
        for k, v in checkpoint.items()
        if k in expected_state and v.shape == expected_state[k].shape
    }
    model.load_state_dict(filtered, strict=False)
    model.eval()

    _, dataset = get_dataloader(datasets_path, valid_files_csv_path)
    from torch.utils.data import DataLoader
    dataloader = DataLoader(dataset, batch_size=cfg.TRAIN_BATCH_SIZE, shuffle=True)
    # Minimal stand-in exposing trainer.datamodule.train_dataloader();
    # presumably what generate_z_from_real_sample reads — TODO confirm.
    dummy_datamodule = type("DummyDM", (), {"train_dataloader": lambda self: dataloader})()
    model.trainer = type("DummyTrainer", (), {"datamodule": dummy_datamodule})()

    # === Latent vectors: one per requested sample
    single_genre_tensor = torch.tensor([genre_id], device=device)
    z_list = []
    for _ in range(num_samples):
        for _attempt in range(max_retries_per_sample):
            try:
                # Presumably forces a fresh real sample on each attempt — TODO confirm.
                model.spectrogram_cache = None
                zi = model.generate_z_from_real_sample(single_genre_tensor)
            except RuntimeError:
                continue
            z_list.append(zi)
            break
        else:
            # Every attempt failed (or no attempts were allowed).
            warnings.warn(
                "No se pudo generar z a partir de una muestra real tras "
                f"{max_retries_per_sample} intentos; se usa generate_z_for_genre.",
                stacklevel=2,
            )
            zi = GANAIModelWrapper.generate_z_for_genre(single_genre_tensor, cfg)
            z_list.append(zi)

    z = torch.cat(z_list, dim=0)
    genre_tensor = torch.full((num_samples,), genre_id, device=device, dtype=torch.long)

    with torch.no_grad():
        specs = model(z, genre_tensor).cpu()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    for i, spec in enumerate(specs):
        # Index 0 of the first axis; assumed to be the channel axis — TODO confirm.
        spec = spec[0].numpy()
        # Linear map to dB, assuming generator output in [-1, 1] — TODO confirm.
        # The 0.95 factor keeps the top of the range slightly below DB_MAX.
        spec_db = ((spec + 1.0) / 2.0) * (cfg.DB_MAX - cfg.DB_MIN) * 0.95 + cfg.DB_MIN
        spec_db = np.clip(spec_db, cfg.DB_MIN, cfg.DB_MAX)

        # === Gaussian filter + horizontal (1x3) median filter
        spec_db = gaussian_filter(spec_db, sigma=(1, 1))
        spec_db = median_filter(spec_db, size=(1, 3))

        # === Energy-based suppression: drop frames at or below the 5th percentile
        # NOTE(review): the measure squares dB values; if the dB range is
        # negative this ranks the loudest frames lowest — confirm intent.
        frame_energy = (spec_db ** 2).mean(axis=0)
        mask = frame_energy > np.percentile(frame_energy, 5)
        if mask.any():
            spec_db = spec_db[:, mask]
        # else: all frames have equal energy; keep them all rather than
        # producing an empty spectrogram.

        # === Plot and save the spectrogram
        save_prefix = f"{timestamp}_GAN_{cfg.KIND_OF_SPECTROGRAM}_g{genre_id}_s{i}"
        img_path = os.path.join(output_dir, f"{save_prefix}.png")
        fig = plt.figure(figsize=(10, 4))
        try:
            librosa.display.specshow(
                spec_db,
                sr=cfg.SAMPLE_RATE,
                hop_length=cfg.HOP_LENGTH,
                y_axis='mel' if cfg.KIND_OF_SPECTROGRAM == 'MEL' else 'linear',
                x_axis='time'
            )
            plt.colorbar(format="%+2.0f dB")
            plt.title(f"Espectrograma generado - sample {i}")
            plt.tight_layout()
            plt.savefig(img_path)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        # === Audio reconstruction
        magnitude = librosa.db_to_amplitude(spec_db)

        if cfg.KIND_OF_SPECTROGRAM == 'MEL':
            audio = librosa.feature.inverse.mel_to_audio(
                magnitude,
                sr=cfg.SAMPLE_RATE,
                n_fft=cfg.N_FFT,
                hop_length=cfg.HOP_LENGTH,
                win_length=cfg.N_FFT,
                window="hann",
                center=True,
                n_iter=512
            )
        else:
            audio = librosa.griffinlim(
                magnitude,
                n_iter=512,
                hop_length=cfg.HOP_LENGTH,
                win_length=cfg.N_FFT,
                window="hann",
                center=True
            )

        # === Normalization, smoothing, first-order emphasis filter and stereo
        # Peak-normalize only when there is signal; dividing by a zero peak
        # would fill the output with NaN.
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio / peak
        audio = normalizar_audio(audio, -20)
        # 5-tap moving average.
        audio = lfilter(np.ones(5) / 5.0, 1, audio)
        # FIR filter y[n] = x[n] - 0.97 * x[n-1].
        audio = lfilter([1, -0.97], [1], audio)
        # Duplicate the mono signal into two channels, shape (samples, 2).
        audio_stereo = np.stack([audio, audio], axis=0).T

        wav_path = os.path.join(output_dir, f"{save_prefix}.wav")
        sf.write(wav_path, audio_stereo, samplerate=cfg.SAMPLE_RATE)


# === Per-genre generation: two samples for each genre id in 0..4
for genre_id in range(5):
    reconstruir_y_visualizar(
        "/home/luke/logs/20250422_0828_GAN_STFT/checkpoints/gan_last_11.pt",
        genre_id,
        output_dir="samples",
        num_samples=2,
    )
