In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys

In [None]:
# Repository root: the parent of the directory the notebook is run from,
# normalised so that derived paths do not carry a literal '..' segment.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
DATA_DIR = os.path.join(PROJECT_ROOT, 'data')
RAW_DATA_SUBDIR = os.path.join(DATA_DIR, 'raw')
PROCESSED_DATA_SUBDIR = os.path.join(DATA_DIR, 'processed')
# Make the `src` package importable. The guard keeps the cell idempotent:
# re-running it no longer appends the same entry to sys.path again.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
"""
EXAMPLE ANNOTATION FILE FORMAT:

Selection	View	Channel	Begin Time (s)	End Time (s)	Low Freq (Hz)	High Freq (Hz)	Inband Power (dB FS)	Species	Call type	Rating	Reference
1	Spectrogram 1	1	1.789365731	1.892598370	5401.800	7438.600	-29.01	LW	AA	A	
2	Spectrogram 1	1	2.691422357	2.794654995	5401.800	7438.600	-60.31		noise		1
3	Spectrogram 1	1	20.557990213	20.631727812	6110.200	8589.800	-33.82	LW	AA	A	
4	Spectrogram 1	1	26.138457751	26.212195350	6110.200	8589.800	-58.35		noise		3

COLUMNS TO KEEP:
Begin Time (s), End Time (s), Low Freq (Hz), High Freq (Hz), Inband Power (dB FS), Species, Call type, Rating
"""
SELECTED_CALL_DIR = os.path.join(RAW_DATA_SUBDIR, "weddells_saddleBack_tamarin__LW")

# Recordings: every directory entry whose name ends in .wav (case-insensitive), sorted by name.
recordings_files = sorted(
    entry for entry in os.listdir(SELECTED_CALL_DIR)
    if entry.lower().endswith('.wav')
)
# Annotation tables: every directory entry whose name ends in .txt (case-insensitive), sorted by name.
annotations_files = sorted(
    entry for entry in os.listdir(SELECTED_CALL_DIR)
    if entry.lower().endswith('.txt')
)

## Using the New Parsing Utility Functions

The code above has been refactored into reusable functions in the `parsing.py` module. Here's how to use them:

In [None]:
# Import the parsing utilities
from src.banana_net.utils.parsing import (
    load_and_process_annotations,
    print_dataset_summary,
    plot_call_type_distribution,
    save_processed_dataset
)

# Process the same data using the new functions
SELECTED_CALL_DIR = os.path.join(RAW_DATA_SUBDIR, "weddells_saddleBack_tamarin__LW")

# Load and process all annotations with the pipeline function
processed_dataset = load_and_process_annotations(
    directory=SELECTED_CALL_DIR,
    filter_top_n=5  # Keep only the top 5 most common call types (going by the parameter name — confirm in parsing.py)
)

# Print summary
print_dataset_summary(processed_dataset)

# Plot distribution (the figure path is handed to the helper via save_path)
plot_save_path = os.path.join(PROCESSED_DATA_SUBDIR, 'call_type_distribution_new.png')
plot_call_type_distribution(processed_dataset, save_path=plot_save_path)

# Save the processed dataset; a later cell reads this same CSV path back with pd.read_csv
save_path = os.path.join(PROCESSED_DATA_SUBDIR, 'call_dataset_new.csv')
save_processed_dataset(processed_dataset, save_path)

## Transform to YOLO y-target Tensor

### Example: Process Weddell's Tamarin data

| index | begin_time | end_time	| low_freq	| high_freq	| inband_power	| species	| call_type |	recording_file      |
|-------|------------|----------|-----------|-----------|---------------|-----------|-----------|-----------------------|
| 0	    | 1.336330	 | 1.865889	| 7560.000	| 10080.000 |	-46.04	    | lw	    | cs        |	20240117_162607.wav |
| 1	    | 1.950137	 | 2.449607	| 7766.599	| 10801.822 |	-41.65	    | lw	    | cs        |	20240117_162607.wav |
| 2	    | 2.527837	 | 3.057396	| 8212.955	| 10444.737 |	-41.20	    | lw	    | cs        |	20240117_162607.wav |
| 3	    | 3.135626	 | 3.611026	| 8034.413	| 10266.194 |	-40.34	    | lw	    | cs        |	20240117_162607.wav |
| 4	    | 3.767486	 | 4.182708	| 7588.057	| 9552.024  |	-41.43	    | lw	    | cs        |	20240117_162607.wav |

YOLO-like tensor format (quoting the YOLOv1 paper):

> Our system models detection as a regression problem. It divides the image into an S x S grid and for each
> grid cell predicts B bounding boxes, confidence for those boxes,
> and C class probabilities. These predictions are encoded as an
> S x S x (B * 5 + C) tensor.

The 5 in the B * 5 term corresponds to the bounding box coordinates (x, y, width, height) and the confidence score.

In [None]:
# Location of the processed call table inside the processed-data directory
CALL_CSV = os.path.join(PROCESSED_DATA_SUBDIR, 'call_dataset_new.csv')

# Load the processed dataset for further analysis
call_data = pd.read_csv(CALL_CSV)

# Describe duration and frequency range for each call type
def describe_call_types(call_data):
    """Summarise duration and frequency bounds for every call type.

    Args:
        call_data: DataFrame with the columns ``call_type``, ``begin_time``,
            ``end_time``, ``low_freq`` and ``high_freq``.

    Returns:
        DataFrame with one row per distinct ``call_type`` (in order of first
        appearance) and mean/min/max columns for the duration
        (``end_time - begin_time``), the low frequency and the high frequency.
    """
    summary = {}

    for label in call_data['call_type'].unique():
        rows = call_data[call_data['call_type'] == label]
        # Duration of each annotated call; computed once and reused for all three statistics.
        span = rows['end_time'] - rows['begin_time']
        lows = rows['low_freq']
        highs = rows['high_freq']

        summary[label] = {
            'average_duration': span.mean(),
            'min_duration': span.min(),
            'max_duration': span.max(),
            'average_low_frequency': lows.mean(),
            'min_low_frequency': lows.min(),
            'max_low_frequency': lows.max(),
            'average_high_frequency': highs.mean(),
            'min_high_frequency': highs.min(),
            'max_high_frequency': highs.max(),
        }

    # Outer dict keys become columns, so transpose to get one row per call type.
    return pd.DataFrame(summary).T
# Build the per-call-type summary table from the loaded dataset
call_type_descriptions = describe_call_types(call_data)
# Print the call type descriptions (plain-text rendering of the DataFrame)
print("\nCall Type Descriptions:")
print(call_type_descriptions)

In [None]:
# Exclude every row annotated with the 'cc' call type.
call_data = call_data.loc[call_data['call_type'] != 'cc']
# Persist the filtered table alongside the other processed outputs.
cleaned_save_path = os.path.join(PROCESSED_DATA_SUBDIR, 'call_dataset_cleaned.csv')
save_processed_dataset(call_data, cleaned_save_path)
# Report what remains after the filter.
print("\nCleaned Dataset Summary:")
print_dataset_summary(call_data)

In [None]:
from src.banana_net.utils.audio_clip_processing import (
    process_dataset_to_clips,
)

# Process the dataset into audio clips
# NOTE(review): clips_dir is not passed to process_dataset_to_clips below and is not
# used by any other visible cell — confirm whether clips are meant to be written there.
clips_dir = os.path.join(PROCESSED_DATA_SUBDIR, 'audio_clips')

# Maps (species, call_type) pairs to integer class ids.
# NOTE(review): only ids 0 and 2 are used (1 is skipped) and three call types share id 2,
# while later cells use len(class_map) == 4 as the number of classes. Confirm whether the
# ids should be contiguous (0..C-1) and whether merging tr/ta/tj into one class is intended.
class_map = {
    ('lw', 'cs'): 0,
    ('lw', 'tr'): 2,
    ('lw', 'ta'): 2,
    ('lw', 'tj'): 2,
}

# Later cells treat all_clips as a sequence of clip objects and all_tensors as a
# mapping keyed by clip name.
# The same max_freq_hz / S / B values are repeated in the visualisation and model cells.
all_clips, all_tensors = process_dataset_to_clips(
    df = call_data,
    clip_duration = 5.0, 
    overlap = 1.0, 
    max_freq_hz = 24000.0, 
    S = 7, 
    B = 2, 
    class_map = class_map,
)

## Visualizing Spectrograms with Bounding Boxes

Ahora vamos a verificar si los tensores tienen la forma correcta visualizando el espectrograma de algunos clips y dibujando los bounding boxes de los eventos detectados.

In [None]:
import librosa
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import os
from typing import Dict, List
from src.banana_net.utils.audio_clip_processing import AudioClip

# Take the first (clip name, tensor) pair and look up the clip object with the same name.
sample_clip_name, sample_tensor = list(all_tensors.items())[0]
sample_clip = next(filter(lambda candidate: candidate.clip_name == sample_clip_name, all_clips))

for report_line in (
    f"Sample clip: {sample_clip_name}",
    f"Tensor shape: {sample_tensor.shape}",
    f"Number of annotations: {sample_clip.num_annotations}",
):
    print(report_line)

In [None]:
def get_audio_path(clip_name: str) -> str:
    """Return the path of the recording a clip name refers to.

    Everything before the first ``'_clip_'`` marker is taken as the recording
    file name (the whole string if the marker is absent) and joined onto the
    Weddell's saddleback tamarin raw-data directory.
    """
    recording_file, _, _ = clip_name.partition('_clip_')
    return os.path.join(RAW_DATA_SUBDIR, "weddells_saddleBack_tamarin__LW", recording_file)

def load_and_extract_clip(clip: AudioClip) -> "tuple[np.ndarray, int]":
    """Load a clip's source recording and cut out the clip's time window.

    Args:
        clip: Clip descriptor. Its ``original_file``, ``start_time`` and
            ``end_time`` attributes are read; the times are multiplied by the
            sample rate, i.e. treated as seconds into the recording.

    Returns:
        ``(samples, sample_rate)``. The recording is loaded at its native
        sample rate (``sr=None``). If anything in the load/slice step raises,
        the error is printed and a best-effort placeholder of 5 seconds of
        zeros at 22050 Hz is returned instead, so callers cannot tell a failed
        load from silence without checking the printed message.
    """
    # Placeholder returned when the recording cannot be read.
    fallback_sr = 22050
    fallback_seconds = 5

    audio_path = get_audio_path(clip.original_file)
    try:
        y, sr = librosa.load(audio_path, sr=None)
        start_sample = int(clip.start_time * sr)
        end_sample = int(clip.end_time * sr)
        clip_audio = y[start_sample:end_sample]
        return clip_audio, sr
    except Exception as e:
        print(f"Error loading audio: {e}")
        # Return a dummy signal
        return np.zeros(fallback_sr * fallback_seconds), fallback_sr

def decode_yolo_tensor(tensor: np.ndarray, S: int, B: int, class_map: Dict,
                       conf_threshold: float = 0.5) -> List[dict]:
    """Decode a YOLO-style target tensor into a list of bounding boxes.

    The tensor is indexed as ``tensor[row, col, k]``. For each cell the first
    ``B * 5`` values hold ``B`` boxes as ``(x, y, w, h, confidence)`` and the
    remaining values are the cell's class scores.

    Args:
        tensor: Array indexable as ``[row, col, channel]`` with at least
            ``B * 5`` channels.
        S: Number of grid rows and columns to scan.
        B: Number of boxes stored per cell.
        class_map: Mapping ``(species, call_type) -> class id``. ``None`` or an
            empty mapping is accepted; every box is then labelled "unknown".
            If several keys share one id, the last key in iteration order wins.
        conf_threshold: A box is kept only if its confidence is strictly
            greater than this value.

    Returns:
        One dict per kept box with the keys ``center_x``, ``center_y`` (cell
        offset plus in-cell position, divided by ``S``), ``width``, ``height``
        (copied unchanged from the tensor), ``confidence``, ``class_id``,
        ``species`` and ``call_type``. ``class_id`` is ``-1`` when the tensor
        has no class channels.
    """
    boxes = []
    # The plotting helpers default class_map to None, so tolerate it here.
    classes_map_reverse = {v: k for k, v in (class_map or {}).items()}
    # Class scores start after all bounding boxes.
    class_offset = B * 5

    for row in range(S):
        for col in range(S):
            cell = tensor[row, col]
            # The class of a cell is shared by all of its boxes; resolve it at
            # most once per cell, and only if some box passes the threshold.
            cell_label = None

            for b in range(B):
                box_offset = b * 5
                confidence = cell[box_offset + 4]

                # Written as "not >" so that NaN confidences are rejected.
                if not confidence > conf_threshold:
                    continue

                if cell_label is None:
                    class_probabilities = cell[class_offset:]
                    if len(class_probabilities) > 0:
                        class_id = np.argmax(class_probabilities)
                    else:
                        class_id = -1
                    species, call_type = classes_map_reverse.get(
                        class_id, ("unknown", "unknown")
                    )
                    cell_label = (class_id, species, call_type)
                class_id, species, call_type = cell_label

                x_cell, y_cell, w, h = cell[box_offset:box_offset + 4]

                boxes.append({
                    # Centre relative to the entire grid.
                    "center_x": (col + x_cell) / S,
                    "center_y": (row + y_cell) / S,
                    "width": w,
                    "height": h,
                    "confidence": confidence,
                    "class_id": class_id,
                    "species": species,
                    "call_type": call_type
                })

    return boxes

In [None]:
def visualize_spectrogram_with_boxes(clip: AudioClip, tensor: np.ndarray, max_freq_hz: float = 24000.0, 
                                  S: int = 7, B: int = 2, class_map: Dict = None):
    """Plot a clip's spectrogram twice: plain, and overlaid with the boxes decoded from its YOLO tensor.

    Args:
        clip: Clip to load; ``clip.clip_name`` and ``clip.duration`` are read here.
        tensor: YOLO tensor for the clip, decoded with ``decode_yolo_tensor``.
        max_freq_hz: Frequency that a normalised vertical coordinate of 1.0 is scaled to.
        S: Grid size passed to the decoder.
        B: Boxes per cell passed to the decoder.
        class_map: ``(species, call_type) -> class id`` mapping; ``None`` is
            treated as an empty mapping (labels become "unknown").

    Returns:
        The list of decoded box dicts. The figure is left open on the current
        pyplot state; the caller decides when to show it.
    """
    # `librosa.display` is a submodule; import it explicitly so the calls below
    # do not depend on the installed librosa exposing it on a bare `import librosa`.
    import librosa.display

    # Load audio clip
    audio_data, sr = load_and_extract_clip(clip)

    # Create figure
    plt.figure(figsize=(12, 8))

    # Magnitude spectrogram in dB relative to its own maximum
    D = librosa.amplitude_to_db(np.abs(librosa.stft(audio_data)), ref=np.max)

    # Top panel: plain spectrogram
    plt.subplot(2, 1, 1)
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title(f'Spectrogram: {clip.clip_name}')

    # Bottom panel: spectrogram with bounding boxes
    plt.subplot(2, 1, 2)
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='viridis', vmax=0)
    plt.colorbar(format='%+2.0f dB')

    # Get bounding boxes from tensor. `class_map or {}` keeps the default
    # class_map=None from reaching the decoder's `.items()` call.
    boxes = decode_yolo_tensor(tensor, S, B, class_map or {})
    print(f"Found {len(boxes)} boxes in tensor")

    # Axes of the bottom panel (the most recently created subplot)
    ax = plt.gca()

    # Draw boxes
    for box in boxes:
        # Convert normalized coordinates to time and frequency
        center_time = box["center_x"] * clip.duration
        center_freq = box["center_y"] * max_freq_hz
        width_time = box["width"] * clip.duration
        height_freq = box["height"] * max_freq_hz

        # Rectangle is anchored at its lower-left corner
        rect_time = center_time - (width_time / 2)
        rect_freq = center_freq - (height_freq / 2)

        rect = patches.Rectangle(
            (rect_time, rect_freq), width_time, height_freq,
            linewidth=2, edgecolor='r', facecolor='none'
        )
        ax.add_patch(rect)

        # Label 500 Hz above the top edge of the box
        label = f"{box['species']}-{box['call_type']} ({box['confidence']:.2f})"
        plt.text(center_time, center_freq + height_freq/2 + 500, label, 
                 color='white', fontsize=10, bbox=dict(facecolor='red', alpha=0.5))

    plt.title(f'Spectrogram with Annotations: {len(boxes)} detections')
    plt.tight_layout()

    return boxes

In [None]:
# Render the sample clip together with the boxes decoded from its tensor.
boxes = visualize_spectrogram_with_boxes(
    sample_clip, sample_tensor,
    max_freq_hz=24000.0, S=7, B=2, class_map=class_map,
)

In [None]:
def visualize_multiple_clips(num_clips: int = 3):
    """Plot up to ``num_clips`` randomly chosen annotated clips with their decoded boxes.

    Reads the notebook-level ``all_clips``, ``all_tensors`` and ``class_map``.
    Only clips with ``num_annotations > 0`` are candidates; if there are none,
    a message is printed and nothing is drawn.
    """
    import random

    annotated = [candidate for candidate in all_clips if candidate.num_annotations > 0]
    if not annotated:
        print("No clips with annotations found.")
        return

    # Never ask random.sample for more items than are available.
    how_many = min(num_clips, len(annotated))
    for chosen in random.sample(annotated, how_many):
        visualize_spectrogram_with_boxes(
            chosen,
            all_tensors[chosen.clip_name],
            max_freq_hz=24000.0,
            S=7,
            B=2,
            class_map=class_map,
        )
        plt.show()

# Visualize up to 5 annotated clips; the selection is random and unseeded, so it differs between runs
visualize_multiple_clips(5)

## Comparación de Anotaciones Originales vs Representación Tensorial

Para verificar si los tensores tienen la forma correcta, vamos a comparar las anotaciones originales con la representación tensorial para un clip específico.

In [None]:
def compare_original_annotations_with_tensor(clip, tensor, max_freq_hz=24000.0, S=7, B=2, class_map=None):
    """Draw a clip's original annotations and its tensor-decoded boxes on two stacked spectrograms.

    Args:
        clip: Clip object; ``annotations`` (a DataFrame with the columns
            ``begin_time_clip``, ``end_time_clip``, ``low_freq``, ``high_freq``,
            ``species`` and ``call_type``), ``clip_name`` and ``duration`` are read.
        tensor: YOLO tensor for the clip, decoded with ``decode_yolo_tensor``.
        max_freq_hz: Frequency that a normalised vertical coordinate of 1.0 is scaled to.
        S: Grid size; also used to draw the grid lines on the lower panel.
        B: Boxes per cell.
        class_map: ``(species, call_type) -> class id`` mapping; ``None`` is
            treated as an empty mapping (labels become "unknown").

    Returns:
        ``(original_annotations, decoded_boxes)``: a copy of the clip's
        annotation table and the list of decoded box dicts.
    """
    # `librosa.display` is a submodule; import it explicitly so the calls below
    # do not depend on the installed librosa exposing it on a bare `import librosa`.
    import librosa.display

    # Work on a copy so the clip's own table is never touched
    original_annotations = clip.annotations.copy()

    # Decode the boxes from the tensor. `class_map or {}` keeps the default
    # class_map=None from reaching the decoder's `.items()` call.
    decoded_boxes = decode_yolo_tensor(tensor, S, B, class_map or {})

    # Two stacked panels: originals on top, tensor-derived boxes below
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))

    # Load audio data for the spectrogram background
    audio_data, sr = load_and_extract_clip(clip)
    D = librosa.amplitude_to_db(np.abs(librosa.stft(audio_data)), ref=np.max)

    # Plot spectrogram with original annotations
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='viridis', ax=ax1)
    ax1.set_title(f'Original Annotations: {clip.clip_name}')

    # Draw original annotation boxes (already in clip-relative seconds and Hz)
    for _, ann in original_annotations.iterrows():
        begin_time_clip = ann['begin_time_clip']
        end_time_clip = ann['end_time_clip']
        low_freq = ann['low_freq']
        high_freq = ann['high_freq']
        species = ann['species']
        call_type = ann['call_type']

        width = end_time_clip - begin_time_clip
        height = high_freq - low_freq

        rect = patches.Rectangle(
            (begin_time_clip, low_freq), width, height,
            linewidth=2, edgecolor='blue', facecolor='none'
        )
        ax1.add_patch(rect)

        # Label centred 500 Hz above the box
        ax1.text(begin_time_clip + width/2, high_freq + 500, 
                f"{species}-{call_type}",
                color='white', fontsize=9, ha='center',
                bbox=dict(facecolor='blue', alpha=0.5))

    # Plot spectrogram with tensor-derived annotations
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='viridis', ax=ax2)
    ax2.set_title(f'Tensor-derived Annotations: {clip.clip_name}')

    # Draw tensor-derived annotation boxes
    for box in decoded_boxes:
        # Convert normalized coordinates to time and frequency
        center_time = box["center_x"] * clip.duration
        center_freq = box["center_y"] * max_freq_hz
        width_time = box["width"] * clip.duration
        height_freq = box["height"] * max_freq_hz

        # Rectangle is anchored at its lower-left corner
        rect_time = center_time - (width_time / 2)
        rect_freq = center_freq - (height_freq / 2)

        rect = patches.Rectangle(
            (rect_time, rect_freq), width_time, height_freq,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        ax2.add_patch(rect)

        # Label centred 500 Hz above the box
        ax2.text(center_time, center_freq + height_freq/2 + 500, 
                f"{box['species']}-{box['call_type']} ({box['confidence']:.2f})", 
                color='white', fontsize=9, ha='center',
                bbox=dict(facecolor='red', alpha=0.5))

    # Interior grid lines to visualize the YOLO grid cells on the lower panel
    for i in range(1, S):
        ax2.axhline(y=(i/S) * max_freq_hz, color='gray', linestyle=':', alpha=0.5)
        ax2.axvline(x=(i/S) * clip.duration, color='gray', linestyle=':', alpha=0.5)

    plt.tight_layout()
    return original_annotations, decoded_boxes

In [None]:
# Select a clip with multiple annotations (at least 3) for comparison
clips_with_multiple_annotations = [clip for clip in all_clips if clip.num_annotations >= 3]
if clips_with_multiple_annotations:
    comparison_clip = clips_with_multiple_annotations[0]  # Take first clip with multiple annotations
    comparison_tensor = all_tensors[comparison_clip.clip_name]
    # One value for both the plot and the printout below, so they cannot drift apart.
    comparison_max_freq_hz = 24000.0

    print(f"Comparing annotations for clip: {comparison_clip.clip_name}")
    print(f"Number of original annotations: {comparison_clip.num_annotations}")

    original_anns, decoded_boxes = compare_original_annotations_with_tensor(
        comparison_clip,
        comparison_tensor,
        max_freq_hz=comparison_max_freq_hz,
        S=7,
        B=2,
        class_map=class_map
    )

    # Print detailed comparison
    print("\nOriginal annotations:")
    # Number the rows 1..n with enumerate: the DataFrame index label is not
    # guaranteed to be a 0-based integer position.
    for i, (_, ann) in enumerate(original_anns.iterrows(), start=1):
        print(f"  {i}. {ann['species']}-{ann['call_type']}: Time [{ann['begin_time_clip']:.2f}-{ann['end_time_clip']:.2f}], Freq [{ann['low_freq']:.1f}-{ann['high_freq']:.1f}] Hz")

    print("\nTensor-decoded boxes:")
    for i, box in enumerate(decoded_boxes, start=1):
        # Normalised centre/size -> seconds and Hz, then to min/max edges
        center_time = box["center_x"] * comparison_clip.duration
        center_freq = box["center_y"] * comparison_max_freq_hz
        width_time = box["width"] * comparison_clip.duration
        height_freq = box["height"] * comparison_max_freq_hz
        t_min = center_time - width_time/2
        t_max = center_time + width_time/2
        f_min = center_freq - height_freq/2
        f_max = center_freq + height_freq/2
        print(f"  {i}. {box['species']}-{box['call_type']}: Time [{t_min:.2f}-{t_max:.2f}], Freq [{f_min:.1f}-{f_max:.1f}] Hz (conf: {box['confidence']:.2f})")
else:
    print("No clips with multiple annotations found.")

## Visualización del Grid YOLO

Para entender mejor cómo funciona la representación YOLO, visualizamos el grid S×S y cómo las anotaciones se asignan a las celdas del grid.

In [None]:
def visualize_yolo_grid(clip, tensor, max_freq_hz=24000.0, S=7, B=2, class_map=None):
    """Visualize the YOLO grid structure and how decoded boxes fall into grid cells.

    Draws the clip's spectrogram, the ``S`` x ``S`` grid with ``(col,row)``
    labels in every cell, and each decoded box coloured by the cell that
    contains its centre.

    Args:
        clip: Clip object; ``clip_name`` and ``duration`` are read.
        tensor: YOLO tensor for the clip, decoded with ``decode_yolo_tensor``.
        max_freq_hz: Frequency that a normalised vertical coordinate of 1.0 is scaled to.
        S: Grid size.
        B: Boxes per cell.
        class_map: ``(species, call_type) -> class id`` mapping; ``None`` is
            treated as an empty mapping (labels become "unknown").

    Returns:
        The list of decoded box dicts.
    """
    # `librosa.display` is a submodule; import it explicitly so the call below
    # does not depend on the installed librosa exposing it on a bare `import librosa`.
    import librosa.display

    # Load audio data
    audio_data, sr = load_and_extract_clip(clip)
    D = librosa.amplitude_to_db(np.abs(librosa.stft(audio_data)), ref=np.max)

    # Create figure
    plt.figure(figsize=(12, 8))

    # Plot spectrogram
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title(f'YOLO Grid Visualization ({S}×{S} grid): {clip.clip_name}')

    # Get axis for drawing
    ax = plt.gca()

    # Palette cycled over cells so neighbouring cells get different box colours
    grid_colors = ['white', 'yellow', 'cyan', 'magenta']

    # Draw horizontal and vertical grid lines (S + 1 of each, including the borders)
    for i in range(S+1):
        # Horizontal lines (frequency divisions)
        y_pos = (i/S) * max_freq_hz
        ax.axhline(y=y_pos, color='white', linestyle='-', alpha=0.5)

        # Vertical lines (time divisions)
        x_pos = (i/S) * clip.duration
        ax.axvline(x=x_pos, color='white', linestyle='-', alpha=0.5)

        # Label every cell of row i with its (col,row) coordinates
        if i < S:
            for j in range(S):
                cell_center_x = (j + 0.5) * (clip.duration / S)
                cell_center_y = (i + 0.5) * (max_freq_hz / S)
                ax.text(cell_center_x, cell_center_y, f"({j},{i})", 
                        color='white', fontsize=9, ha='center', va='center',
                        bbox=dict(facecolor='black', alpha=0.5))

    # Decode the boxes from the tensor. `class_map or {}` keeps the default
    # class_map=None from reaching the decoder's `.items()` call.
    boxes = decode_yolo_tensor(tensor, S, B, class_map or {})

    # Draw boxes with different colors based on which grid cell they belong to
    for box in boxes:
        # Convert normalized coordinates to time and frequency
        center_time = box["center_x"] * clip.duration
        center_freq = box["center_y"] * max_freq_hz
        width_time = box["width"] * clip.duration
        height_freq = box["height"] * max_freq_hz

        # Rectangle is anchored at its lower-left corner
        rect_time = center_time - (width_time / 2)
        rect_freq = center_freq - (height_freq / 2)

        # Cell containing the centre. Clamped to [0, S-1]: a centre sitting
        # exactly on the upper/right border would otherwise map to cell S.
        grid_col = min(max(int(box["center_x"] * S), 0), S - 1)
        grid_row = min(max(int(box["center_y"] * S), 0), S - 1)
        color_idx = (grid_row + grid_col) % len(grid_colors)

        rect = patches.Rectangle(
            (rect_time, rect_freq), width_time, height_freq,
            linewidth=2, edgecolor=grid_colors[color_idx], facecolor='none'
        )
        ax.add_patch(rect)

        # Label indicating the grid cell assignment, 500 Hz above the box
        label = f"{box['species']}-{box['call_type']} (cell: {grid_col},{grid_row})"
        plt.text(center_time, center_freq + height_freq/2 + 500, label,
                 color='white', fontsize=9, ha='center',
                 bbox=dict(facecolor=grid_colors[color_idx], alpha=0.5))

    plt.tight_layout()
    return boxes

In [None]:
# Visualize the YOLO grid for the comparison clip selected above
if clips_with_multiple_annotations:
    # Grid parameters used for the plot, the printout and the per-cell count below.
    grid_S, grid_B = 7, 2

    grid_boxes = visualize_yolo_grid(
        comparison_clip,
        comparison_tensor,
        max_freq_hz=24000.0,
        S=grid_S,
        B=grid_B,
        class_map=class_map
    )

    # Print information about tensor shape and format
    # NOTE(review): the "(S, S, B*5 + C)" figure uses len(class_map) as C; confirm it
    # matches the actual comparison_tensor.shape printed next to it.
    print(f"\nYOLO tensor format explanation:")
    print(f"- Grid size (S): {grid_S}x{grid_S}")
    print(f"- Bounding boxes per cell (B): {grid_B}")
    print(f"- Number of classes (C): {len(class_map)}")
    print(f"- Tensor shape: {comparison_tensor.shape} = (S, S, B*5 + C) = ({grid_S}, {grid_S}, {grid_B*5 + len(class_map)})")
    print("- Each bounding box has 5 values: (x, y, width, height, confidence)")

    # Count how many decoded boxes have their centre in each grid cell
    grid_cell_counts = np.zeros((grid_S, grid_S), dtype=int)
    for box in grid_boxes:
        # Clamp to [0, S-1]: a centre exactly on the upper/right border would
        # otherwise index one past the end of the array.
        grid_col = min(max(int(box["center_x"] * grid_S), 0), grid_S - 1)
        grid_row = min(max(int(box["center_y"] * grid_S), 0), grid_S - 1)
        grid_cell_counts[grid_row, grid_col] += 1

    print("\nObjects detected per grid cell:")
    print(grid_cell_counts)

## Arquitectura del Modelo YOLO-like para Detección de Audio

Ahora vamos a definir la arquitectura del modelo basada en YOLO (You Only Look Once) para la detección de eventos de audio en espectrogramas. La arquitectura consistirá en bloques convolucionales seguidos de capas fully-connected para generar las predicciones en formato de tensor YOLO.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Basic convolutional block: Conv2d -> BatchNorm2d -> LeakyReLU(0.1)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply convolution, batch normalisation and the activation, in that order."""
        return self.activation(self.bn(self.conv(x)))

class YOLOAudioModel(nn.Module):
    """YOLO-style model for event detection in audio spectrograms.

    Five ConvBlock + 2x2 max-pool stages (16, 32, 64, 128, 256 channels) are
    followed by an adaptive average pool to an S x S map and a 1x1 convolution
    that emits ``B*5 + C`` values per cell. ``forward`` returns that
    convolution's output directly, i.e. a channels-first tensor of shape
    ``(batch, B*5 + C, S, S)``.
    """

    def __init__(self, in_channels=1, S=7, B=2, C=3):
        """
        Build the audio YOLO model.

        Args:
            in_channels: Number of input channels (1 for audio spectrograms)
            S: Grid size (S x S)
            B: Number of bounding boxes per cell
            C: Number of classes
        """
        super().__init__()

        self.S = S
        self.B = B
        self.C = C
        # Values predicted per cell: B * (x, y, w, h, conf) + C class scores
        self.output_size = B*5 + C

        # Convolutional stages inspired by YOLOv1 but smaller, for spectrograms.
        # Each pool halves height and width (256 -> 128 -> 64 -> 32 -> 16 -> 8
        # for a 256x256 input). Registered as layer1/pool1 ... layer5/pool5.
        stage_widths = (16, 32, 64, 128, 256)
        previous = in_channels
        for number, width in enumerate(stage_widths, start=1):
            setattr(self, f"layer{number}",
                    ConvBlock(previous, width, kernel_size=3, stride=1, padding=1))
            setattr(self, f"pool{number}", nn.MaxPool2d(2, 2))
            previous = width

        # Final adjustment so the feature map is exactly S x S
        self.adaptive_pool = nn.AdaptiveAvgPool2d((S, S))

        # Final prediction layer
        self.conv_final = nn.Conv2d(256, self.output_size, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the five conv/pool stages, pool to S x S and apply the 1x1 prediction head."""
        stages = (
            (self.layer1, self.pool1),
            (self.layer2, self.pool2),
            (self.layer3, self.pool3),
            (self.layer4, self.pool4),
            (self.layer5, self.pool5),
        )
        for conv_block, pool in stages:
            x = pool(conv_block(x))

        # Fit to the final grid size, then predict
        return self.conv_final(self.adaptive_pool(x))

# Instantiate the model for our case
model = YOLOAudioModel(
    in_channels=1,  # 1 channel for spectrograms
    S=7,            # 7x7 grid
    B=2,            # 2 bounding boxes per cell
    C=len(class_map) # Number of classes, taken as the number of class_map entries
)

print(model)

# Model summary
# NOTE(review): the shape printed below is channels-last, but YOLOAudioModel.forward returns
# the output of conv_final, i.e. (batch, B*5 + C, S, S) — confirm which layout is intended.
print(f"\nResumen del modelo:")
print(f"- Grid size (S): 7x7")
print(f"- Bounding boxes per cell (B): 2")
print(f"- Number of classes (C): {len(class_map)}")
print(f"- Output tensor shape: (S, S, B*5 + C) = (7, 7, {2*5 + len(class_map)})")

In [None]:
class YOLOLoss(nn.Module):
    """Custom loss for the audio YOLO model.

    Based on the original YOLO loss, it penalises:
    - coordinate error (x, y, w, h)
    - confidence (objectness) error
    - classification error

    All terms are sum-reduced squared errors. A target box slot counts as
    "containing an object" when its confidence is greater than 0.
    """
    def __init__(self, S=7, B=2, C=3, lambda_coord=5.0, lambda_noobj=0.5):
        """
        Args:
            S: Grid size (S x S)
            B: Number of bounding boxes per cell
            C: Number of classes
            lambda_coord: Weight of the coordinate (xy and wh) terms
            lambda_noobj: Weight of the confidence term for box slots without an object
        """
        super(YOLOLoss, self).__init__()
        self.S = S
        self.B = B
        self.C = C
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, predictions, targets):
        """Compute the YOLO loss between predictions and targets.

        Args:
            predictions: Tensor of shape (batch_size, S, S, B*5+C). A
                channels-first tensor (batch_size, B*5+C, S, S), as produced by
                a Conv2d prediction head, is also accepted and is permuted to
                channels-last first; reshaping it directly would scramble
                values across cells. If S happens to equal B*5+C the two
                layouts cannot be told apart and channels-last is assumed.
            targets: Tensor of shape (batch_size, S, S, B*5+C)

        Returns:
            Total loss (scalar tensor)
        """
        depth = self.B * 5 + self.C

        # Bring channels-first predictions into the (batch, S, S, depth) layout.
        if (predictions.dim() == 4
                and predictions.shape[-1] != depth
                and predictions.shape[1] == depth):
            predictions = predictions.permute(0, 2, 3, 1)

        predictions = predictions.reshape(-1, self.S, self.S, depth)
        targets = targets.reshape(-1, self.S, self.S, depth)

        # Split into boxes and class scores.
        # Each bounding box is [x, y, w, h, conf].
        pred_boxes = predictions[..., :self.B*5].reshape(-1, self.S, self.S, self.B, 5)
        pred_classes = predictions[..., self.B*5:]

        target_boxes = targets[..., :self.B*5].reshape(-1, self.S, self.S, self.B, 5)
        target_classes = targets[..., self.B*5:]

        # Box slots that contain an object (target confidence > 0), and the rest
        obj_mask = target_boxes[..., 4] > 0
        noobj_mask = ~obj_mask

        # 1. Coordinate loss (only for box slots with objects)
        xy_loss = self.mse(torch.flatten(pred_boxes[..., :2][obj_mask]),
                           torch.flatten(target_boxes[..., :2][obj_mask]))

        # For w, h use square roots so errors on large boxes are penalised less.
        # Predictions may be negative, hence the signed square root.
        wh_pred = torch.sign(pred_boxes[..., 2:4]) * torch.sqrt(torch.abs(pred_boxes[..., 2:4]) + 1e-6)
        wh_target = torch.sqrt(target_boxes[..., 2:4] + 1e-6)
        wh_loss = self.mse(torch.flatten(wh_pred[obj_mask]),
                           torch.flatten(wh_target[obj_mask]))

        # 2. Confidence loss
        # Box slots with objects
        conf_obj_loss = self.mse(torch.flatten(pred_boxes[..., 4][obj_mask]),
                                 torch.flatten(target_boxes[..., 4][obj_mask]))

        # Box slots without objects
        conf_noobj_loss = self.mse(torch.flatten(pred_boxes[..., 4][noobj_mask]),
                                   torch.flatten(target_boxes[..., 4][noobj_mask]))

        # 3. Classification loss (only for cells where any box slot has an object)
        cell_has_obj = obj_mask.any(dim=3)
        class_loss = self.mse(torch.flatten(pred_classes[cell_has_obj]),
                              torch.flatten(target_classes[cell_has_obj]))

        # Total loss
        loss = (
            self.lambda_coord * xy_loss +
            self.lambda_coord * wh_loss +
            conf_obj_loss +
            self.lambda_noobj * conf_noobj_loss +
            class_loss
        )

        return loss

# Create an instance of the loss function, sized with the same S, B and class count as the model
yolo_loss = YOLOLoss(S=7, B=2, C=len(class_map))
print("Función de pérdida personalizada para YOLO creada correctamente.")

## Preparación de Datos para Entrenamiento

Para entrenar nuestro modelo YOLO para audio, necesitamos preparar los datos de entrenamiento. Convertiremos nuestros espectrogramas en tensores PyTorch y dividiremos el conjunto de datos en entrenamiento y prueba.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
import random
from sklearn.model_selection import train_test_split

class AudioSpectrogramDataset(Dataset):
    """Dataset of audio spectrograms and their YOLO target tensors.

    Each item is ``(spectrogram, yolo_target)`` where ``spectrogram`` is a
    float32 tensor of shape ``[1, input_size, input_size]`` scaled to [0, 1]
    and ``yolo_target`` is the stored YOLO tensor as float32.
    """

    def __init__(self, clips, tensors, max_freq_hz=24000.0, transform=None, input_size=256):
        """
        Args:
            clips: Iterable of clip objects exposing ``clip_name``.
            tensors: Mapping ``clip name -> YOLO tensor``; its keys define the
                items of the dataset and their order.
            max_freq_hz: Frequency the spectrogram's vertical axis is cropped
                or padded to before resizing.
                # assumes the targets are normalised against this same 0..max_freq_hz range — TODO confirm in process_dataset_to_clips
            transform: Optional callable applied to the spectrogram tensor.
            input_size: Height and width of the spectrogram handed to the network.
        """
        self.clips = clips
        self.tensors = tensors
        self.max_freq_hz = max_freq_hz
        self.transform = transform
        self.input_size = input_size
        # Clip names, for access by index
        self.clip_names = list(self.tensors.keys())
        # Name -> clip index built once, replacing a linear scan of `clips` on
        # every __getitem__. setdefault keeps the first clip for a duplicated
        # name, matching the previous first-match lookup.
        self._clip_by_name = {}
        for candidate in self.clips:
            self._clip_by_name.setdefault(candidate.clip_name, candidate)

    def __len__(self):
        return len(self.clip_names)

    def _spectrogram_to_input(self, spec_db, sr):
        """Turn a (freq_bins, frames) dB spectrogram into a [1, input_size, input_size] tensor in [0, 1].

        The frequency axis is first made to span 0..max_freq_hz (rows above it
        are dropped; if the spectrogram stops below it, rows filled with the
        spectrogram's minimum are appended). The whole map is then resized by
        area averaging, so the full clip duration and frequency range are kept
        instead of only the first ``input_size`` bins and frames.
        """
        spec_db = np.asarray(spec_db, dtype=np.float32)
        n_bins = spec_db.shape[0]

        if n_bins > 1 and sr and self.max_freq_hz:
            # Rows are taken to run linearly from 0 Hz to sr/2 (STFT layout).
            bin_hz = (sr / 2.0) / (n_bins - 1)
            wanted_bins = max(int(round(self.max_freq_hz / bin_hz)) + 1, 1)
            if wanted_bins < n_bins:
                spec_db = spec_db[:wanted_bins]
            elif wanted_bins > n_bins:
                spec_db = np.pad(
                    spec_db,
                    ((0, wanted_bins - n_bins), (0, 0)),
                    mode='constant',
                    constant_values=float(spec_db.min()),
                )

        # interpolate expects (batch, channels, H, W)
        as_tensor = torch.from_numpy(np.ascontiguousarray(spec_db))[None, None]
        resized = torch.nn.functional.interpolate(
            as_tensor, size=(self.input_size, self.input_size), mode='area'
        )[0]  # -> [1, H, W]

        # Normalise the values to the range [0, 1]
        lowest, highest = resized.min(), resized.max()
        return (resized - lowest) / (highest - lowest + 1e-8)

    def __getitem__(self, idx):
        """Return ``(spectrogram, yolo_target)`` for the idx-th clip name.

        Raises:
            ValueError: if no clip object with that name was supplied.
        """
        clip_name = self.clip_names[idx]
        # YOLO tensor for this clip
        yolo_tensor = self.tensors[clip_name]

        # Matching clip object
        clip = self._clip_by_name.get(clip_name)
        if clip is None:
            raise ValueError(f"No se encontró el objeto AudioClip para {clip_name}")

        # Load the audio and convert it to a dB-magnitude spectrogram
        audio_data, sr = load_and_extract_clip(clip)
        D = librosa.stft(audio_data)
        D_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

        # Fixed-size, normalised network input with a channel dimension
        spectrogram = self._spectrogram_to_input(D_db, sr)

        # Apply additional transformations, if any
        if self.transform:
            spectrogram = self.transform(spectrogram)

        # Convert the YOLO tensor to a PyTorch tensor
        yolo_target = torch.tensor(yolo_tensor, dtype=torch.float32)

        return spectrogram, yolo_target

# Split the clips into training and test sets
clip_names = list(all_tensors.keys())
random.seed(42)  # For reproducibility

# NOTE(review): the split is made per clip. process_dataset_to_clips was called with
# overlap=1.0, so if consecutive clips share audio, overlapping material can end up in
# both sets — confirm whether the split should be per recording instead.
train_names, test_names = train_test_split(clip_names, test_size=0.2, random_state=42)

print(f"Total de clips: {len(clip_names)}")
print(f"Clips para entrenamiento: {len(train_names)}")
print(f"Clips para prueba: {len(test_names)}")

# Create separate dictionaries for training and test
train_tensors = {name: all_tensors[name] for name in train_names}
test_tensors = {name: all_tensors[name] for name in test_names}

# Create the datasets (both receive the full clip list; the tensors dict selects the items)
train_dataset = AudioSpectrogramDataset(
    clips=all_clips,
    tensors=train_tensors,
    max_freq_hz=24000.0
)

test_dataset = AudioSpectrogramDataset(
    clips=all_clips,
    tensors=test_tensors,
    max_freq_hz=24000.0
)

# Create dataloaders
batch_size = 8  # Small batch size due to the complexity of the model

# NOTE(review): num_workers > 0 with a Dataset class defined inside the notebook may fail
# on platforms that start workers with "spawn" — confirm, or fall back to num_workers=0.
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    num_workers=4  # Adjust to your CPU
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    shuffle=False,
    num_workers=4
)

# Inspect one batch of data (shapes only)
print("\nVisualización de un batch de datos:")
for images, targets in train_dataloader:
    print(f"Batch de espectrogramas: {images.shape}")
    print(f"Batch de tensores YOLO: {targets.shape}")
    break  # Only the first batch