# Notebook de pruebas

In [11]:
import torch
import torchaudio
import numpy as np
import torch.nn as nn
import wget
import os
import librosa
import csv
import sys
sys.path.append('./audioset_tagging_cnn/pytorch')

# ======= CONFIGURACIÓN =======
MODEL_URL = 'https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth?download=1'
CHECKPOINT_PATH = 'Cnn14_16k.pth'  # Usa el nombre original
LABELS_CSV = 'class_labels_indices.csv'
AUDIO_PATH = 'About-the-Blues-Feel-the-Groove-_feat.-Priscilla-Zamborlini_-bass-E-minor-95bpm-440hz.wav'  # Cambia esto por tu archivo .wav
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# ======= DESCARGAR PESOS Y LABELS =======
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(
        f"No se encontró el modelo {CHECKPOINT_PATH}. "
        "Por favor descárgalo manualmente desde:\n"
        "https://zenodo.org/record/3987831\n"
        "y colócalo en la misma carpeta con el nombre 'Cnn14_16k.pth'"
    )


if not os.path.exists(LABELS_CSV):
    print('Descargando etiquetas...')
    wget.download("https://raw.githubusercontent.com/qiuqiangkong/audioset_tagging_cnn/master/metadata/class_labels_indices.csv", LABELS_CSV)

# ======= FUNCIONES UTILES =======
def load_labels(csv_path):
    with open(csv_path) as f:
        reader = csv.DictReader(f)
        return [row['display_name'] for row in reader]

def preprocess_audio(wav_path, target_sample_rate=16000):
    waveform, sr = torchaudio.load(wav_path)
    waveform = waveform.mean(dim=0)  # convertir a mono
    waveform = torchaudio.functional.resample(waveform, sr, target_sample_rate)
    return waveform.unsqueeze(0)  # [1, T]

# ======= MODELO CNN14 =======
class Cnn14(torch.nn.Module):
    def __init__(self, sample_rate=16000, window_size=512, hop_size=160,
                 mel_bins=64, fmin=50, fmax=8000, classes_num=527):
        super().__init__()
        import sys
        sys.path.append('.')  # Asegura que podemos importar localmente
        from audioset_tagging_cnn.pytorch.models import Cnn14  # Asume que copiaste el modelo desde el repo oficial
        self.model = Cnn14(
            sample_rate=sample_rate,
            window_size=window_size,
            hop_size=hop_size,
            mel_bins=mel_bins,
            fmin=fmin,
            fmax=fmax,
            classes_num=classes_num
        )
        self.model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE)['model'])
        self.model.to(DEVICE).eval()

    def forward(self, waveform):
        with torch.no_grad():
            output = self.model(waveform.to(DEVICE))
        return output['clipwise_output'][0], output['embedding'][0]
    
def softmax_on_instruments(scores, labels, instrument_keywords=["piano", "guitar", "bass"]):
    indices = []
    instrument_labels = []

    for i, label in enumerate(labels):
        label_lower = label.lower()
        if any(kw in label_lower for kw in instrument_keywords):
            indices.append(i)
            instrument_labels.append(label)

    if not indices:
        return []

    selected_scores = scores[indices]
    softmaxed = torch.softmax(selected_scores, dim=0)

    return list(zip(instrument_labels, softmaxed.tolist()))

def mostrar_resultados_filtrados(scores, csv_path):
    instrumentos_validos = ["piano", "guitar", "bass"]
    
    # Cargar etiquetas
    with open(csv_path, newline='') as f:
        reader = csv.DictReader(f)
        labels = [row["display_name"] for row in reader]
    
    resultados = []
    for i, score in enumerate(scores):
        nombre = labels[i].lower()
        if any(instr in nombre for instr in instrumentos_validos):
            resultados.append((labels[i], score.item()))

    # Ordenar por score descendente y mostrar top
    resultados.sort(key=lambda x: x[1], reverse=True)

    print("\n🎵 Instrumentos detectados (solo piano, guitarra y bajo):")
    if resultados:
        for label, score in resultados[:1]:
            print(f"  - {label}: {score:.3f}")
    else:
        print("  - Ninguno de los instrumentos especificados fue detectado.")


# ======= MAIN =======
if __name__ == '__main__':
    print('Preparando audio...')
    audio = preprocess_audio(AUDIO_PATH)
    print('Cargando modelo...')
    model = Cnn14()

    print('Realizando predicción...')
    scores, embedding = model(audio)
    mostrar_resultados_filtrados(scores, LABELS_CSV)
    
    labels = load_labels(LABELS_CSV)
    top_k = torch.topk(scores, k=5)
    print('\nInstrumentos detectados:')
    for i in range(5):
        print(f'  - {labels[top_k.indices[i]]}: {top_k.values[i].item():.3f}')

    print(f'\nEmbedding shape: {embedding.shape}')
    np.save("embedding.npy", embedding.cpu().numpy())
    print('Embedding guardado como embedding.npy')


Preparando audio...
Cargando modelo...
Realizando predicción...

🎵 Instrumentos detectados (solo piano, guitarra y bajo):
  - Bass guitar: 0.779

Instrumentos detectados:
  - Music: 0.896
  - Bass guitar: 0.779
  - Musical instrument: 0.658
  - Guitar: 0.645
  - Plucked string instrument: 0.507

Embedding shape: torch.Size([2048])
Embedding guardado como embedding.npy


In [8]:
audio_guitar = preprocess_audio('guitar_audio.wav')

scores, embedding = model(audio_guitar)
mostrar_resultados_filtrados(scores, LABELS_CSV)


🎵 Instrumentos detectados (solo piano, guitarra y bajo):
  - Guitar: 0.280


In [9]:
drum_audio = preprocess_audio('drum_audio.wav')

scores, embedding = model(drum_audio)
mostrar_resultados_filtrados(scores, LABELS_CSV)


🎵 Instrumentos detectados (solo piano, guitarra y bajo):
  - Bass drum: 0.225


In [10]:
piano_audio = preprocess_audio('piano_audio.wav')

scores, embedding = model(piano_audio)
mostrar_resultados_filtrados(scores, LABELS_CSV)


🎵 Instrumentos detectados (solo piano, guitarra y bajo):
  - Piano: 0.069


In [12]:
piano_audio2 = preprocess_audio('piano_audio2.wav')

scores, embedding = model(piano_audio2)
mostrar_resultados_filtrados(scores, LABELS_CSV)


🎵 Instrumentos detectados (solo piano, guitarra y bajo):
  - Electric piano: 0.149
