# Downloading the Data

In [None]:
!wget https://zenodo.org/api/records/4783391/files-archive

In [None]:
!unzip files-archive

In [None]:
!7z x /content/clotho_audio_development.7z
!7z x /content/clotho_audio_evaluation.7z
!7z x /content/clotho_audio_validation.7z

# Dataset

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import os
import numpy as np
import random
from typing import List, Dict, Tuple, Optional, Union
import librosa

In [None]:
class AudioAugmenter:
    """Randomised waveform- and spectrogram-level audio augmentations.

    Waveform augmentations (time stretch, pitch shift, additive gaussian
    noise) are each applied independently with probability ``aug_prob``.
    Spectrogram augmentation applies torchaudio time and frequency masking.
    """

    def __init__(self,
                 sample_rate: int = 22050,
                 time_stretch_range: Tuple[float, float] = (0.9, 1.1),
                 pitch_shift_range: Tuple[int, int] = (-2, 2),
                 noise_factor_range: Tuple[float, float] = (0.001, 0.005),
                 time_mask_param: int = 40,
                 freq_mask_param: int = 8,
                 n_time_masks: int = 1,
                 n_freq_masks: int = 1,
                 aug_prob: float = 0.5):
        """
        Args:
            sample_rate: Audio sample rate
            time_stretch_range: Range for time stretching factor
            pitch_shift_range: Range for pitch shift in semitones
            noise_factor_range: Range for the factor of gaussian noise to add
            time_mask_param: Maximum time mask length
            freq_mask_param: Maximum frequency mask length
            n_time_masks: Number of time masks to apply
            n_freq_masks: Number of frequency masks to apply
            aug_prob: Probability with which each requested waveform
                augmentation is applied
        """
        self.sample_rate = sample_rate
        self.time_stretch_range = time_stretch_range
        self.pitch_shift_range = pitch_shift_range
        self.noise_factor_range = noise_factor_range
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param
        self.n_time_masks = n_time_masks
        self.n_freq_masks = n_freq_masks
        self.aug_prob = aug_prob

        # Initialize time and frequency masking transforms
        self.time_masking = T.TimeMasking(time_mask_param=time_mask_param)
        self.freq_masking = T.FrequencyMasking(freq_mask_param=freq_mask_param)

    def time_stretch(self, waveform: torch.Tensor) -> torch.Tensor:
        """Apply random time stretching using librosa.

        Returns the input unchanged (after printing the error) if librosa fails.
        """
        stretch_factor = random.uniform(*self.time_stretch_range)

        # detach()/cpu() so tensors that require grad or live on another
        # device can still be handed to librosa as a numpy array.
        waveform_np = waveform.detach().cpu().numpy().squeeze()

        try:
            # Use librosa's time stretching which works on raw audio
            stretched = librosa.effects.time_stretch(waveform_np, rate=stretch_factor)
        except Exception as e:
            # Best effort: a failed augmentation must not abort data loading.
            print(f"Time stretch augmentation failed: {e}")
            return waveform

        # Keep the caller's dtype/device so later concatenation with padding works.
        return torch.as_tensor(
            stretched, dtype=waveform.dtype, device=waveform.device
        ).unsqueeze(0)

    def pitch_shift(self, waveform: torch.Tensor) -> torch.Tensor:
        """Apply random pitch shifting (integer number of semitones) using librosa.

        Returns the input unchanged (after printing the error) if librosa fails.
        """
        n_steps = random.randint(*self.pitch_shift_range)

        # Convert to numpy for pitch shift
        waveform_np = waveform.detach().cpu().numpy().squeeze()

        try:
            shifted = librosa.effects.pitch_shift(
                waveform_np,
                sr=self.sample_rate,
                n_steps=n_steps
            )
        except Exception as e:
            # Best effort: a failed augmentation must not abort data loading.
            print(f"Pitch shift augmentation failed: {e}")
            return waveform

        # Convert back to torch tensor
        return torch.as_tensor(
            shifted, dtype=waveform.dtype, device=waveform.device
        ).unsqueeze(0)

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        """Add gaussian noise scaled by a factor drawn from ``noise_factor_range``."""
        noise_factor = random.uniform(*self.noise_factor_range)
        noise = torch.randn_like(waveform) * noise_factor
        return waveform + noise

    def apply_time_freq_mask(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """Apply ``n_time_masks`` time masks then ``n_freq_masks`` frequency masks.

        Works on a copy; the input spectrogram is not modified.
        """
        aug_spec = spectrogram.clone()

        # Apply time masking
        for _ in range(self.n_time_masks):
            aug_spec = self.time_masking(aug_spec)

        # Apply frequency masking
        for _ in range(self.n_freq_masks):
            aug_spec = self.freq_masking(aug_spec)

        return aug_spec

    def apply_waveform_augmentations(self, waveform: torch.Tensor, max_length: int,
                                    augment_list: Optional[List[str]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the requested waveform augmentations and fit to ``max_length``.

        Args:
            waveform: Audio tensor whose last dimension is time
            max_length: Number of samples the returned waveform must have
            augment_list: Any of 'time_stretch', 'pitch_shift', 'add_noise';
                ``None`` means no augmentation

        Returns:
            (waveform, mask): the waveform truncated or zero-padded to
            ``max_length`` and a mask of length ``max_length`` holding 1 for
            real samples and 0 for padding.
        """
        # None is treated as "no augmentations" but the length is still
        # normalised, so waveform and mask always agree in length.
        requested = () if augment_list is None else augment_list

        aug_waveform = waveform.clone()

        if 'time_stretch' in requested and random.random() < self.aug_prob:
            aug_waveform = self.time_stretch(aug_waveform)

        if 'pitch_shift' in requested and random.random() < self.aug_prob:
            aug_waveform = self.pitch_shift(aug_waveform)

        if 'add_noise' in requested and random.random() < self.aug_prob:
            aug_waveform = self.add_noise(aug_waveform)

        # Ensure consistent length after augmentation
        current_length = aug_waveform.shape[-1]
        mask = torch.ones(max_length)

        if current_length > max_length:
            # Truncate if longer than max_length
            aug_waveform = aug_waveform[..., :max_length]
        elif current_length < max_length:
            # Pad if shorter than max_length (matching channels, dtype and device)
            padding = aug_waveform.new_zeros(
                (*aug_waveform.shape[:-1], max_length - current_length)
            )
            aug_waveform = torch.cat([aug_waveform, padding], dim=-1)
            mask[current_length:] = 0

        return aug_waveform, mask

In [None]:
class ClothoBaseDataset(Dataset):
    """Base class for Clotho datasets with shared functionality.

    Handles caption CSV loading, audio path resolution, audio loading with
    mono down-mix / resampling / fixed-length padding, and batch collation.
    Subclasses provide ``__getitem__`` and the feature extraction.
    """

    def __init__(self,
                 base_dir: str,
                 split: str = 'train',
                 sample_rate: int = 22050,
                 max_length_seconds: int = 30,
                 augmentations: Optional[List[str]] = None):
        """
        Args:
            base_dir: Root directory containing the Clotho dataset
            split: 'train' (combines dev+val); any other value loads the
                evaluation split
            sample_rate: Target sample rate for audio
            max_length_seconds: Maximum length for audio in seconds
            augmentations: List of augmentation strategies to apply
        """
        self.base_dir = base_dir
        self.split = split
        self.sample_rate = sample_rate
        self.max_length = int(max_length_seconds * sample_rate)  # Convert seconds to samples
        self.augmentations = augmentations if augmentations else []

        # Initialize augmenter if needed
        if self.augmentations:
            self.augmenter = AudioAugmenter(sample_rate=sample_rate)

        # For training, merge development and validation
        if split == 'train':
            # Load and merge caption files
            dev_captions = pd.read_csv(os.path.join(base_dir, 'clotho_captions_development.csv'))
            val_captions = pd.read_csv(os.path.join(base_dir, 'clotho_captions_validation.csv'))
            self.captions_df = pd.concat([dev_captions, val_captions], ignore_index=True)

            # Create mapping of file names to their full paths
            self.audio_paths = {}
            for file in dev_captions['file_name'].unique():
                self.audio_paths[file] = os.path.join(base_dir, 'development', file)
            for file in val_captions['file_name'].unique():
                self.audio_paths[file] = os.path.join(base_dir, 'validation', file)

        else:  # Evaluation split for validation/testing
            self.captions_df = pd.read_csv(os.path.join(base_dir, 'clotho_captions_evaluation.csv'))
            self.audio_paths = {
                file: os.path.join(base_dir, 'evaluation', file)
                for file in self.captions_df['file_name'].unique()
            }

        # Prepare file list - get unique file names
        self.file_list = list(self.audio_paths.keys())
        print(f"Loaded {len(self.file_list)} unique audio files for {split}")
        print(f"Using max length of {max_length_seconds} seconds ({self.max_length} samples)")

        if self.augmentations:
            print(f"Using audio augmentations: {', '.join(self.augmentations)}")

    def __len__(self):
        """Number of unique audio files in the split."""
        return len(self.file_list)

    def _load_and_process_audio(self, file_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load an audio file and return it at a fixed length with its mask.

        The audio is down-mixed to mono, resampled to ``self.sample_rate``,
        truncated to ``self.max_length``, optionally augmented (training
        split only), and finally zero-padded to ``self.max_length``.

        Returns:
            (waveform, mask): waveform of shape (1, max_length) and a mask of
            length max_length with 1 for real samples and 0 for padding.
        """
        audio_path = self.audio_paths[file_name]

        try:
            # Load audio
            waveform, sample_rate = torchaudio.load(audio_path)

            # Convert to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample if needed
            if sample_rate != self.sample_rate:
                resampler = T.Resample(sample_rate, self.sample_rate)
                waveform = resampler(waveform)
        except Exception as e:
            # Best effort: one unreadable file must not abort an epoch.
            print(f"Error loading audio file {file_name}: {e}")
            # Create empty audio as fallback
            waveform = torch.zeros(1, self.max_length)

        # Truncate first so augmentations never process more than max_length samples
        waveform = waveform[:, :self.max_length]

        # Augment BEFORE padding. The augmenter derives the mask from the
        # length of the waveform it is given, so handing it an already padded
        # waveform would mark the padding as real audio (and add noise to it).
        if self.split == 'train' and hasattr(self, 'augmenter') and self.augmentations:
            waveform_aug_list = [aug for aug in self.augmentations
                                 if aug in ('time_stretch', 'pitch_shift', 'add_noise')]
            if waveform_aug_list:
                # The augmenter pads/truncates to max_length and builds the mask
                waveform, mask = self.augmenter.apply_waveform_augmentations(
                    waveform, self.max_length, waveform_aug_list
                )
                return waveform, mask

        # No waveform augmentation: pad here and build the attention mask
        # (1 for real data, 0 for padding)
        real_length = waveform.shape[1]
        mask = torch.ones(self.max_length)
        if real_length < self.max_length:
            padding = torch.zeros(1, self.max_length - real_length)
            waveform = torch.cat([waveform, padding], dim=1)
            mask[real_length:] = 0

        return waveform, mask

    def _get_captions(self, file_name: str) -> List[str]:
        """Retrieve the captions for a given file.

        Reads columns ``caption_1`` .. ``caption_5`` (those that exist) from
        the first row matching ``file_name``. Returns five placeholder
        strings when no caption is found.
        """
        placeholder = ["No caption available"] * 5

        file_captions = self.captions_df[self.captions_df['file_name'] == file_name]
        if file_captions.empty:
            return placeholder

        first_row = file_captions.iloc[0]
        caption_columns = [f'caption_{i}' for i in range(1, 6)]  # 5 captions per file
        caption_list = [first_row[col] for col in caption_columns
                        if col in file_captions.columns]

        return caption_list if caption_list else placeholder

    def collate_fn(self, batch: List[Dict]) -> Dict:
        """Collate dataset items into a batch.

        One caption per item is drawn at random for ``'captions'``; the full
        caption lists are kept under ``'all_captions'``. ``'features'`` is a
        stacked tensor, or a list of stacked tensors (one per scale) when the
        items carry a list of per-scale features.
        """
        audio = torch.stack([item['audio'] for item in batch])
        features = [item['features'] for item in batch]
        masks = torch.stack([item['mask'] for item in batch])
        file_names = [item['file_name'] for item in batch]

        # Each item has a list of captions; sample one per item
        captions_list = [item['captions'] for item in batch]
        captions = [random.choice(captions) for captions in captions_list]

        if isinstance(features[0], list):
            # For multiscale, we have a list of tensors for each item
            collated_features = [
                torch.stack([item[scale_idx] for item in features])
                for scale_idx in range(len(features[0]))
            ]
        else:
            # Regular spectrogram - stack all features
            collated_features = torch.stack(features)

        return {
            'audio': audio,
            'features': collated_features,
            'masks': masks,
            'file_names': file_names,
            'captions': captions,
            'all_captions': captions_list
        }

In [None]:
class LogMelSpectrogramDataset(ClothoBaseDataset):
    """Clotho dataset with standard log-mel spectrogram features"""

    def __init__(self,
                 base_dir: str,
                 split: str = 'train',
                 sample_rate: int = 22050,
                 max_length_seconds: int = 30,
                 n_fft: int = 1024,
                 hop_length: int = 512,
                 n_mels: int = 64,
                 augmentations: Optional[List[str]] = None):
        """
        Args:
            base_dir: Root directory containing the Clotho dataset
            split: 'train' (combines dev+val) or 'eval'
            sample_rate: Target sample rate for audio
            max_length_seconds: Maximum length for audio in seconds
            n_fft: FFT size
            hop_length: Hop length for STFT
            n_mels: Number of mel bands
            augmentations: List of augmentation strategies to apply
        """
        super().__init__(base_dir, split, sample_rate, max_length_seconds, augmentations)

        # Setup feature extraction
        self.feature_extractor = T.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels
        )
        self.to_db = T.AmplitudeToDB()

        print(f"Using log-mel spectrogram with n_fft={n_fft}, hop_length={hop_length}, n_mels={n_mels}")

    def _apply_spec_augment(self, features: torch.Tensor) -> torch.Tensor:
        """Apply only the spectrogram maskings named in ``self.augmentations``.

        'time_mask' and 'freq_mask' are honoured independently, so asking for
        one of them no longer silently applies the other. Nothing is applied
        outside the training split or when no augmenter was created.
        """
        if not (self.split == 'train' and hasattr(self, 'augmenter') and self.augmentations):
            return features

        # Same order as the augmenter's combined routine: time masks, then frequency masks.
        if 'time_mask' in self.augmentations:
            for _ in range(self.augmenter.n_time_masks):
                features = self.augmenter.time_masking(features)

        if 'freq_mask' in self.augmentations:
            for _ in range(self.augmenter.n_freq_masks):
                features = self.augmenter.freq_masking(features)

        return features

    def __getitem__(self, idx: int) -> Dict:
        """Return audio, log-mel features, sample mask, file name and captions for one file."""
        file_name = self.file_list[idx]

        # Load and process audio
        waveform, mask = self._load_and_process_audio(file_name)

        # Extract features: mel spectrogram converted to decibels
        features = self.to_db(self.feature_extractor(waveform))

        # Apply spectrogram augmentations if in training mode
        features = self._apply_spec_augment(features)

        # Get captions
        caption_list = self._get_captions(file_name)

        return {
            'audio': waveform,
            'features': features,
            'mask': mask,
            'file_name': file_name,
            'captions': caption_list
        }

In [None]:
class MultiscaleLogMelDataset(ClothoBaseDataset):
    """Clotho dataset with multiscale log-mel spectrogram features"""

    def __init__(self,
                 base_dir: str,
                 split: str = 'train',
                 sample_rate: int = 22050,
                 max_length_seconds: int = 30,
                 n_fft_scales: List[int] = [512, 1024, 2048],
                 hop_length_scales: List[int] = [256, 512, 1024],
                 n_mels: int = 64,
                 augmentations: Optional[List[str]] = None):
        """
        Args:
            base_dir: Root directory containing the Clotho dataset
            split: 'train' (combines dev+val) or 'eval'
            sample_rate: Target sample rate for audio
            max_length_seconds: Maximum length for audio in seconds
            n_fft_scales: List of FFT sizes for different scales
            hop_length_scales: List of hop lengths for different scales
            n_mels: Number of mel bands
            augmentations: List of augmentation strategies to apply
        """
        super().__init__(base_dir, split, sample_rate, max_length_seconds, augmentations)

        # NOTE: the list defaults above are only read, never mutated.
        assert len(n_fft_scales) == len(hop_length_scales), "Number of FFT scales must match hop length scales"

        # Setup feature extractors for each scale
        self.feature_extractors = [
            T.MelSpectrogram(
                sample_rate=self.sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels
            )
            for n_fft, hop_length in zip(n_fft_scales, hop_length_scales)
        ]

        self.to_db = T.AmplitudeToDB()
        self.n_scales = len(n_fft_scales)

        print(f"Using {self.n_scales} spectral scales with FFT sizes: {n_fft_scales} and hop lengths: {hop_length_scales}")

    def _apply_spec_augment(self, features: torch.Tensor) -> torch.Tensor:
        """Apply only the spectrogram maskings named in ``self.augmentations``.

        'time_mask' and 'freq_mask' are honoured independently, so asking for
        one of them no longer silently applies the other.
        """
        # Same order as the augmenter's combined routine: time masks, then frequency masks.
        if 'time_mask' in self.augmentations:
            for _ in range(self.augmenter.n_time_masks):
                features = self.augmenter.time_masking(features)

        if 'freq_mask' in self.augmentations:
            for _ in range(self.augmenter.n_freq_masks):
                features = self.augmenter.freq_masking(features)

        return features

    def __getitem__(self, idx: int) -> Dict:
        """Return audio, per-scale log-mel features, sample mask, file name and captions."""
        file_name = self.file_list[idx]

        # Load and process audio
        waveform, mask = self._load_and_process_audio(file_name)

        # Whether spectrogram augmentation applies does not depend on the
        # scale, so decide once instead of once per extractor.
        use_spec_augment = bool(
            self.split == 'train' and hasattr(self, 'augmenter') and self.augmentations
        )

        # Extract features at multiple scales; each scale draws its own random masks
        multiscale_features = []
        for extractor in self.feature_extractors:
            features = self.to_db(extractor(waveform))
            if use_spec_augment:
                features = self._apply_spec_augment(features)
            multiscale_features.append(features)

        # Get captions
        caption_list = self._get_captions(file_name)

        return {
            'audio': waveform,
            'features': multiscale_features,  # List of features at different scales
            'mask': mask,
            'file_name': file_name,
            'captions': caption_list
        }

In [None]:
# Smoke test: build both dataset variants on the training split, pull one
# batch from each through a DataLoader and print the resulting shapes.
base_dir = "/content"
augmentations = ['time_stretch', 'pitch_shift', 'add_noise', 'time_mask', 'freq_mask']

# --- Standard log-mel dataset ---
print("\nTesting Standard Log-Mel Dataset:")
standard_dataset = LogMelSpectrogramDataset(
    base_dir=base_dir, split='train', max_length_seconds=30, augmentations=augmentations
)
standard_loader = DataLoader(
    standard_dataset, batch_size=4, shuffle=True, collate_fn=standard_dataset.collate_fn
)

# Inspect a single batch
batch = next(iter(standard_loader))
for label, value in (
    ("Audio batch shape", batch['audio'].shape),
    ("Features shape", batch['features'].shape),
    ("Mask shape", batch['masks'].shape),
    ("First caption", batch['captions'][0]),
):
    print(f"{label}: {value}")
print()

# --- Multiscale log-mel dataset ---
print("\nTesting Multiscale Log-Mel Dataset:")
multiscale_dataset = MultiscaleLogMelDataset(
    base_dir=base_dir,
    split='train',
    max_length_seconds=30,
    n_fft_scales=[512, 1024, 2048],
    hop_length_scales=[256, 512, 1024],
    augmentations=augmentations,
)
multiscale_loader = DataLoader(
    multiscale_dataset, batch_size=4, shuffle=True, collate_fn=multiscale_dataset.collate_fn
)

# Inspect a single batch; 'features' is a list with one tensor per scale
batch = next(iter(multiscale_loader))
print(f"Audio batch shape: {batch['audio'].shape}")
print(f"Number of feature scales: {len(batch['features'])}")

for i, scale_features in enumerate(batch['features']):
    print(f"Features scale {i} shape: {scale_features.shape}")

print(f"Mask shape: {batch['masks'].shape}")
print(f"First caption: {batch['captions'][0]}")

# Loss Functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCELoss(nn.Module):
    """Symmetric InfoNCE contrastive loss over paired audio/text embeddings."""

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, audio_embeddings, text_embeddings):
        """
        Compute the InfoNCE loss for a batch of matching pairs.

        Row ``i`` of ``audio_embeddings`` is the positive for row ``i`` of
        ``text_embeddings``; every other row in the batch acts as a negative.

        Args:
            audio_embeddings: Tensor of shape [batch_size, embedding_dim]
            text_embeddings: Tensor of shape [batch_size, embedding_dim]

        Returns:
            loss: Mean of the text-to-audio and audio-to-text cross entropies
        """
        # Unit-length rows turn the dot product into a cosine similarity
        audio_unit = F.normalize(audio_embeddings, dim=1)
        text_unit = F.normalize(text_embeddings, dim=1)

        # Temperature-scaled logits: rows index text, columns index audio
        logits_t2a = (text_unit @ audio_unit.T) / self.temperature

        # The correct class for row i is column i
        n_pairs = audio_unit.shape[0]
        targets = torch.arange(n_pairs).to(audio_unit.device)

        text_to_audio = self.criterion(logits_t2a, targets)
        audio_to_text = self.criterion(logits_t2a.T, targets)

        return (text_to_audio + audio_to_text) / 2.0

In [None]:
class VICRegLoss(nn.Module):
    """VICReg loss: weighted sum of invariance, variance and covariance terms."""

    def __init__(self, sim_weight=25.0, var_weight=25.0, cov_weight=1.0, epsilon=1e-4):
        super().__init__()
        self.sim_weight = sim_weight
        self.var_weight = var_weight
        self.cov_weight = cov_weight
        self.epsilon = epsilon

    def _variance_term(self, centered):
        """Hinge penalty on per-dimension standard deviations falling below 1."""
        std = torch.sqrt(centered.var(dim=0) + self.epsilon)
        return torch.mean(F.relu(1.0 - std))

    @staticmethod
    def _covariance_term(centered, n_samples, n_dims):
        """Sum of squared off-diagonal covariance entries, divided by the dimension."""
        cov = (centered.T @ centered) / (n_samples - 1)
        # Remove the diagonal so only cross-dimension covariance is penalised
        off_diag = cov - torch.diag(torch.diag(cov))
        return torch.sum(off_diag ** 2) / n_dims

    def forward(self, audio_embeddings, text_embeddings):
        """
        Compute the VICReg loss for paired audio and text embeddings.

        Args:
            audio_embeddings: Tensor of shape [batch_size, embedding_dim]
            text_embeddings: Tensor of shape [batch_size, embedding_dim]

        Returns:
            loss: sim_weight * invariance + var_weight * variance
                  + cov_weight * covariance
        """
        n_samples, n_dims = audio_embeddings.shape[0], audio_embeddings.shape[1]

        # Invariance: mean squared error between the paired embeddings
        invariance = F.mse_loss(audio_embeddings, text_embeddings)

        # Subtract the batch mean of every dimension
        audio_centered = audio_embeddings - audio_embeddings.mean(dim=0)
        text_centered = text_embeddings - text_embeddings.mean(dim=0)

        # Variance: audio term plus text term
        variance = self._variance_term(audio_centered) + self._variance_term(text_centered)

        # Covariance: audio term plus text term
        covariance = (
            self._covariance_term(audio_centered, n_samples, n_dims)
            + self._covariance_term(text_centered, n_samples, n_dims)
        )

        return self.sim_weight * invariance + self.var_weight * variance + self.cov_weight * covariance

In [None]:
class CosineLoss(nn.Module):
    """Cosine similarity loss for matching audio and text embeddings."""

    def __init__(self):
        super().__init__()

    def forward(self, audio_embeddings, text_embeddings):
        """
        Calculate the cosine loss.

        Only the matching (audio_i, text_i) pairs contribute, so the
        similarity is computed row by row instead of building the full
        batch_size x batch_size similarity matrix and reading its diagonal.

        Args:
            audio_embeddings: Tensor of shape [batch_size, embedding_dim]
            text_embeddings: Tensor of shape [batch_size, embedding_dim]

        Returns:
            loss: Negative mean cosine similarity of the positive pairs
        """
        # Normalize embeddings for cosine similarity
        audio_embeddings = F.normalize(audio_embeddings, dim=1)
        text_embeddings = F.normalize(text_embeddings, dim=1)

        # Row-wise dot product of unit vectors == cosine similarity of each pair
        positive_similarities = (text_embeddings * audio_embeddings).sum(dim=1)

        # Negative because we want to maximize similarity (minimize negative similarity)
        return -torch.mean(positive_similarities)

# Text Encoder

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer


class BertTextEncoder(nn.Module):
    """
    Text encoder using a frozen pre-trained BERT model.

    BERT's parameters have ``requires_grad=False`` and the forward pass
    through BERT runs under ``torch.no_grad()``. Only the optional linear
    projection (created when ``output_dim`` differs from BERT's hidden size)
    is trainable.
    """

    # Pooling strategies understood by forward()
    _POOLING_STRATEGIES = ("cls", "mean")

    def __init__(self,
                 bert_model_name="bert-base-uncased",
                 output_dim=512,
                 pooling_strategy="cls",
                 max_length=77):
        """
        Initialize the BERT text encoder.

        Args:
            bert_model_name: Pre-trained BERT model name
            output_dim: Dimension of the output embedding
            pooling_strategy: Strategy to pool token embeddings ("cls", "mean")
            max_length: Maximum number of tokens per caption; longer captions
                are truncated by the tokenizer

        Raises:
            ValueError: If ``pooling_strategy`` is not "cls" or "mean"
        """
        super().__init__()

        # Fail fast: an unknown strategy would otherwise only surface as an
        # UnboundLocalError during the first forward pass.
        if pooling_strategy not in self._POOLING_STRATEGIES:
            raise ValueError(
                f"Unknown pooling_strategy {pooling_strategy!r}; "
                f"expected one of {self._POOLING_STRATEGIES}"
            )

        # Initialize BERT model and tokenizer
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.pooling_strategy = pooling_strategy
        self.max_length = max_length

        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False
        # Frozen feature extractor: keep dropout disabled
        self.bert.eval()

        # Get BERT hidden dimension
        self.bert_dim = self.bert.config.hidden_size

        # Projection layer if output_dim is different from BERT hidden size
        self.use_projection = (output_dim != self.bert_dim)
        if self.use_projection:
            self.projection = nn.Linear(self.bert_dim, output_dim)

    def train(self, mode=True):
        """Switch train/eval mode for this module while keeping BERT in eval mode.

        BERT is frozen, so its dropout layers must stay disabled even when
        the surrounding model is put into training mode.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, captions):
        """
        Process text captions through BERT.

        Args:
            captions: List of caption strings

        Returns:
            embeddings: Tensor of shape [batch_size, output_dim]

        Raises:
            ValueError: If ``self.pooling_strategy`` was changed to an
                unsupported value after construction
        """
        # Tokenize captions and move the tensors to BERT's device
        encoding = self.tokenizer(
            captions,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).to(next(self.bert.parameters()).device)

        # Pass through BERT
        with torch.no_grad():  # No gradients since BERT is frozen
            outputs = self.bert(**encoding)

        # Get embeddings based on pooling strategy
        if self.pooling_strategy == "cls":
            # Use the embedding of the first token ([CLS])
            embeddings = outputs.last_hidden_state[:, 0, :]
        elif self.pooling_strategy == "mean":
            # Mean pooling over non-padding tokens
            # (attention mask: 1 for tokens, 0 for padding)
            attention_mask = encoding['attention_mask']
            token_embeddings = outputs.last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            # Clamp so an all-padding row cannot divide by zero
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            embeddings = sum_embeddings / sum_mask
        else:
            raise ValueError(
                f"Unknown pooling_strategy {self.pooling_strategy!r}; "
                f"expected one of {self._POOLING_STRATEGIES}"
            )

        # Project to desired output dimension if needed
        if self.use_projection:
            embeddings = self.projection(embeddings)

        return embeddings

    def encode_text(self, captions):
        """
        Public method to encode text for inference.

        Args:
            captions: List of caption strings

        Returns:
            embeddings: Tensor of shape [batch_size, output_dim]
        """
        return self.forward(captions)

# Audio Encoder

In [None]:
class AudioEncoderCRNN(nn.Module):
    """
    CRNN-based audio encoder for single or multi-scale mel spectrograms.

    A stack of strided Conv2d blocks is followed by averaging over the height
    axis, a bidirectional GRU over the width axis, and a two-layer MLP. The
    output embedding is L2-normalised.
    """
    def __init__(
        self,
        output_dim=512,
        cnn_channels=[64, 128, 256, 512],
        kernel_size=3,
        stride=2,
        gru_hidden_size=512,
        gru_num_layers=2,
        dropout_rate=0.3,
        use_multiscale=False
    ):
        """
        Initialize the CRNN audio encoder.

        Args:
            output_dim: Final embedding dimension
            cnn_channels: List of channel dimensions for CNN layers
            kernel_size: Kernel size for CNN layers
            stride: Stride for CNN layers
            gru_hidden_size: Hidden size for GRU
            gru_num_layers: Number of GRU layers
            dropout_rate: Dropout rate
            use_multiscale: Whether to use multiscale input processing
        """
        super().__init__()
        self.output_dim = output_dim
        self.use_multiscale = use_multiscale

        # CNN for feature extraction (same architecture for all scales)
        cnn_layers = []
        in_channels = 1  # Mel spectrograms have 1 channel

        for out_channels in cnn_channels:
            # Convolutional block with batch norm and dropout
            cnn_layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=kernel_size//2),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                    nn.Dropout2d(dropout_rate)
                )
            )
            in_channels = out_channels

        self.cnn = nn.Sequential(*cnn_layers)

        # NOTE(review): this layer is never called in the forward pass. It is
        # kept so the module's parameter names (state_dict keys) stay unchanged.
        self.projection = nn.Linear(cnn_channels[-1], cnn_channels[-1])

        # GRU for sequence modeling; its input is the CNN channel dimension
        self.gru_input_size = cnn_channels[-1]
        self.gru = nn.GRU(
            input_size=self.gru_input_size,
            hidden_size=gru_hidden_size,
            num_layers=gru_num_layers,
            batch_first=True,
            dropout=dropout_rate if gru_num_layers > 1 else 0,
            bidirectional=True
        )

        # Final FC layers with dropout
        self.fc = nn.Sequential(
            nn.Linear(gru_hidden_size * 2, 768),  # *2 for bidirectional
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(768, output_dim)
        )

    def _process_single_scale(self, x):
        """Encode one spectrogram batch [B, 1, H, W] into [B, output_dim] (not normalised)."""
        # CNN feature extraction
        x = self.cnn(x)  # (B, C, H', W')

        # Average over the height axis so any input height is supported
        x = x.mean(dim=2)  # (B, C, W')

        # The width axis becomes the sequence axis for the GRU: (B, W', C)
        x = x.permute(0, 2, 1)

        # RNN sequence modeling
        self.gru.flatten_parameters()
        _, h_n = self.gru(x)  # h_n: (num_layers * 2, B, hidden_size)

        # Summarise the sequence with the FINAL hidden state of each direction
        # of the last layer. The last time step of the GRU output would pair
        # the final forward state with the backward state that has seen only
        # a single frame.
        x = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (B, hidden_size * 2)

        # FC layers
        return self.fc(x)

    def forward(self, x):
        """
        Process audio features through the CRNN.

        Args:
            x: For standard dataset: Tensor of shape [batch_size, 1, height, width]
               For multiscale dataset: List (or tuple) of tensors, each of shape
               [batch_size, 1, height, width]

        Returns:
            embeddings: L2-normalised tensor of shape [batch_size, output_dim]

        Raises:
            ValueError: If ``use_multiscale`` is set and ``x`` is not a list or tuple
        """
        if self.use_multiscale:
            if not isinstance(x, (list, tuple)):
                raise ValueError("Expected a list of tensors for multiscale input")

            # Every scale goes through the same (shared-weight) network
            scale_embeddings = [self._process_single_scale(scale_x) for scale_x in x]

            # Average the embeddings from different scales
            embeddings = torch.stack(scale_embeddings, dim=0).mean(dim=0)
        else:
            # Process single-scale input
            embeddings = self._process_single_scale(x)

        # Normalize output embeddings
        return F.normalize(embeddings, p=2, dim=1)

# Training Function

In [None]:
import torch
import torch.optim as optim
import numpy as np
import time
from torch.utils.data import DataLoader
import os


def train_model(
    audio_encoder,
    text_encoder,
    loss_function,
    train_dataset,
    val_dataset,
    batch_size=8,
    num_epochs=30,
    learning_rate=1e-4,
    weight_decay=1e-5,
    patience=5,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    num_workers=4
):
    """
    Train an audio encoder against a frozen text encoder with early stopping.

    Only ``audio_encoder``'s parameters are optimised (AdamW). The text encoder
    is kept in eval mode and evaluated under ``torch.no_grad()``.

    Args:
        audio_encoder: Audio encoder model to train
        text_encoder: Text encoder model (frozen); called with the batch's
            ``'captions'`` entry
        loss_function: Callable ``(audio_embeddings, text_embeddings) -> loss``
        train_dataset: Training dataset whose batches are dicts with
            ``'features'`` and ``'captions'``; its ``collate_fn`` attribute is
            used when present
        val_dataset: Validation dataset with the same batch format
        batch_size: Batch size
        num_epochs: Maximum number of epochs
        learning_rate: Learning rate
        weight_decay: Weight decay for optimizer
        patience: Number of consecutive non-improving epochs before stopping
        device: Device to train on
        num_workers: Number of DataLoader worker processes (0 loads in the
            main process)

    Returns:
        best_model: ``audio_encoder`` (same object) loaded with the weights
            from the epoch with the lowest validation loss
        history: Dictionary with ``'train_loss'``, ``'val_loss'`` and
            ``'epochs'`` lists, one entry per completed epoch
    """

    def _to_device(features):
        # Multiscale datasets deliver a list of tensors (one per scale) and a
        # plain list has no .to(); move each element individually.
        if isinstance(features, (list, tuple)):
            return [scale.to(device) for scale in features]
        return features.to(device)

    def _make_loader(dataset, shuffle):
        # Use the dataset's own collate function when it defines one.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=getattr(dataset, 'collate_fn', None)
        )

    def _mean(values):
        # Avoid numpy's empty-slice warning when a loader yields no batches.
        return float(np.mean(values)) if values else float('nan')

    # Move models to device
    audio_encoder = audio_encoder.to(device)
    text_encoder = text_encoder.to(device)

    # Create data loaders
    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)

    # Set up optimizer (only for audio encoder as text encoder is frozen)
    optimizer = optim.AdamW(
        audio_encoder.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Early stopping variables
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    # History dictionary to track losses
    history = {
        'train_loss': [],
        'val_loss': [],
        'epochs': []
    }

    # Training loop
    for epoch in range(1, num_epochs + 1):
        # Training phase
        audio_encoder.train()
        text_encoder.eval()  # Text encoder is always in eval mode since it's frozen

        train_losses = []

        for batch in train_loader:
            # Extract batch data
            audio_inputs = _to_device(batch['features'])
            captions = batch['captions']

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            audio_embeddings = audio_encoder(audio_inputs)

            with torch.no_grad():
                text_embeddings = text_encoder(captions)

            # Calculate loss
            loss = loss_function(audio_embeddings, text_embeddings)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        # Calculate average training loss
        avg_train_loss = _mean(train_losses)

        # Validation phase
        audio_encoder.eval()
        val_losses = []

        with torch.no_grad():
            for batch in val_loader:
                # Extract batch data
                audio_inputs = _to_device(batch['features'])
                captions = batch['captions']

                # Forward pass
                audio_embeddings = audio_encoder(audio_inputs)
                text_embeddings = text_encoder(captions)

                # Calculate loss
                loss = loss_function(audio_embeddings, text_embeddings)

                val_losses.append(loss.item())

        # Calculate average validation loss
        avg_val_loss = _mean(val_losses)

        # Add to history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['epochs'].append(epoch)

        # Print formatted output for this epoch
        print(f"Epoch {epoch}: Train Loss: {avg_train_loss:.4f} Val Loss: {avg_val_loss:.4f}")

        # Check for improvement
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors and
            # dict.copy() is shallow, so the optimizer would keep mutating the
            # "snapshot" in place. Clone every tensor to freeze the best
            # weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in audio_encoder.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch} epochs")
            break

    # Load best model state. It is None when no epoch ran or the validation
    # loss never compared below infinity (e.g. it was NaN); in that case the
    # current weights are kept.
    if best_model_state is not None:
        audio_encoder.load_state_dict(best_model_state)

    return audio_encoder, history

In [None]:
def evaluate_retrieval(
    audio_encoder,
    text_encoder,
    test_dataset,
    device,
    batch_size=8,
    top_k=10,
    num_workers=4
):
    """
    Evaluate text-to-audio retrieval on a dataset.

    Every item in ``test_dataset`` provides one text query (its caption) whose
    single relevant audio is the item's own file. The gallery contains each
    distinct file name once.

    Args:
        audio_encoder: Trained audio encoder model
        text_encoder: Text encoder model; called with a one-element list of
            captions
        test_dataset: Test/evaluation dataset whose batches are dicts with
            ``'features'``, ``'captions'`` and ``'file_names'``
        device: Device to run evaluation on
        batch_size: Batch size for encoding audio
        top_k: Cut-off used for the average-precision computation
        num_workers: Number of DataLoader worker processes

    Returns:
        metrics: Dictionary with R@1, R@5, R@10, and mAP@10 scores (all 0 when
            the dataset yields no items)
    """
    recall_ks = (1, 5, 10)

    def _to_device(features):
        # Multiscale datasets deliver a list of tensors (one per scale).
        if isinstance(features, (list, tuple)):
            return [scale.to(device) for scale in features]
        return features.to(device)

    audio_encoder.eval()
    text_encoder.eval()

    loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=getattr(test_dataset, 'collate_fn', None)
    )

    # Single pass over the data: encode the audio AND collect the queries.
    # Captions and file names come from the same batches, so no second pass
    # (which would reload every audio file) is needed.
    gallery_embeds = []
    gallery_names = []
    known_names = set()
    query_captions = []
    query_targets = []

    with torch.no_grad():
        for batch in loader:
            file_names = list(batch['file_names'])
            audio_embed = audio_encoder(_to_device(batch['features'])).cpu()

            query_captions.extend(batch['captions'])
            query_targets.extend(file_names)

            # Keep one gallery entry per file. If the dataset yields the same
            # audio several times (presumably once per caption - verify
            # against the dataset class), duplicates would occupy several
            # top-k slots and distort both recall@k and AP.
            for row, name in enumerate(file_names):
                if name not in known_names:
                    known_names.add(name)
                    gallery_names.append(name)
                    gallery_embeds.append(audio_embed[row])

    if not gallery_embeds:
        return {'R@1': 0, 'R@5': 0, 'R@10': 0, 'mAP@10': 0}

    gallery = torch.stack(gallery_embeds, dim=0).to(device)

    # Retrieve enough candidates for the largest recall cut-off even when the
    # caller passes a smaller top_k, but never more than the gallery holds.
    n_retrieve = min(max(top_k, max(recall_ks)), gallery.size(0))

    # Initialize metrics
    recalls = {k: [] for k in recall_ks}
    aps = []

    # For each query, rank the gallery by similarity
    with torch.no_grad():
        for query_caption, ground_truth_file in zip(query_captions, query_targets):
            # Encode text query
            text_embed = text_encoder([query_caption]).to(device)

            # Dot-product similarity with every gallery audio
            similarities = torch.matmul(text_embed, gallery.T)[0]

            _, top_indices = torch.topk(similarities, k=n_retrieve)
            ranked_files = [gallery_names[idx] for idx in top_indices.cpu().tolist()]

            # Recall@k: is the ground truth among the first k results? With a
            # gallery smaller than k this checks the whole gallery.
            for k in recall_ks:
                hit = ground_truth_file in ranked_files[:k]
                recalls[k].append(1.0 if hit else 0.0)

            # Average precision with exactly one relevant item: 1 / rank when
            # it is within the first top_k results, otherwise 0.
            ap = 0.0
            for rank, name in enumerate(ranked_files[:top_k], start=1):
                if name == ground_truth_file:
                    ap = 1.0 / rank
                    break

            aps.append(ap)

    # Calculate final metrics
    metrics = {
        'R@1': sum(recalls[1]) / len(recalls[1]) if recalls[1] else 0,
        'R@5': sum(recalls[5]) / len(recalls[5]) if recalls[5] else 0,
        'R@10': sum(recalls[10]) / len(recalls[10]) if recalls[10] else 0,
        'mAP@10': sum(aps) / len(aps) if aps else 0
    }

    return metrics

# Runs

In [None]:
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

# Prefer the GPU when PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_dir = "/content"

# Hyperparameters shared by every run below
batch_size, num_epochs, patience = 8, 100, 5
learning_rate, weight_decay = 1e-4, 1e-5
embedding_dim = 512

# Augmentation names handed to the training datasets only
augmentations = ['time_stretch', 'pitch_shift', 'add_noise', 'time_mask', 'freq_mask']

# Per-run results, keyed by "<loss>_<feature type>"
all_histories = {}

# One text encoder instance is reused by all runs.
text_encoder = BertTextEncoder(
    pooling_strategy="cls",
    output_dim=embedding_dim,
    bert_model_name="bert-base-uncased",
).to(device)

# Single-scale log-mel datasets: augmented 'train' split, un-augmented 'eval' split.
standard_train_dataset, standard_val_dataset = (
    LogMelSpectrogramDataset(
        base_dir=base_dir,
        split=split_name,
        max_length_seconds=30,
        augmentations=split_augmentations,
    )
    for split_name, split_augmentations in (('train', augmentations), ('eval', None))
)

# Multiscale counterparts: three (n_fft, hop_length) resolutions per clip.
multiscale_train_dataset, multiscale_val_dataset = (
    MultiscaleLogMelDataset(
        base_dir=base_dir,
        split=split_name,
        max_length_seconds=30,
        n_fft_scales=[512, 1024, 2048],
        hop_length_scales=[256, 512, 1024],
        augmentations=split_augmentations,
    )
    for split_name, split_augmentations in (('train', augmentations), ('eval', None))
)

In [None]:
print("\n=== Training with InfoNCE Loss and Standard Mel Spectrograms ===\n")

# Fresh single-scale CRNN audio encoder for this run
audio_encoder = AudioEncoderCRNN(
    use_multiscale=False,
    output_dim=embedding_dim,
    gru_hidden_size=512,
    gru_num_layers=2,
    cnn_channels=[64, 128, 256, 512],
    kernel_size=3,
    stride=2,
    dropout_rate=0.3,
).to(device)

# InfoNCE objective with temperature 0.07
loss_function = InfoNCELoss(temperature=0.07)

# Fit on the single-scale datasets; returns the encoder and its loss history
best_model, history = train_model(
    audio_encoder,
    text_encoder,
    loss_function,
    standard_train_dataset,
    standard_val_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    patience=patience,
    device=device,
)

# Register this run's loss curves
all_histories['infonce_standard'] = history

# Retrieval metrics on the single-scale validation dataset
metrics = evaluate_retrieval(
    audio_encoder=best_model,
    text_encoder=text_encoder,
    test_dataset=standard_val_dataset,
    device=device,
)
print("\nRetrieval Metrics:")
print(f"R@1: {metrics['R@1']:.4f}, R@5: {metrics['R@5']:.4f}, R@10: {metrics['R@10']:.4f}, mAP@10: {metrics['mAP@10']:.4f}")

# `history` is the dict stored in all_histories above, so this attaches the
# metrics to the registered entry.
history['metrics'] = metrics

In [None]:
print("\n=== Training with InfoNCE Loss and Multiscale Mel Spectrograms ===\n")

# Fresh multiscale CRNN audio encoder for this run
audio_encoder = AudioEncoderCRNN(
    use_multiscale=True,
    output_dim=embedding_dim,
    gru_hidden_size=512,
    gru_num_layers=2,
    cnn_channels=[64, 128, 256, 512],
    kernel_size=3,
    stride=2,
    dropout_rate=0.3,
).to(device)

# InfoNCE objective with temperature 0.07
loss_function = InfoNCELoss(temperature=0.07)

# Fit on the multiscale datasets; returns the encoder and its loss history
best_model, history = train_model(
    audio_encoder,
    text_encoder,
    loss_function,
    multiscale_train_dataset,
    multiscale_val_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    patience=patience,
    device=device,
)

# Register this run's loss curves
all_histories['infonce_multiscale'] = history

# Retrieval metrics on the multiscale validation dataset
metrics = evaluate_retrieval(
    audio_encoder=best_model,
    text_encoder=text_encoder,
    test_dataset=multiscale_val_dataset,
    device=device,
)
print("\nRetrieval Metrics:")
print(f"R@1: {metrics['R@1']:.4f}, R@5: {metrics['R@5']:.4f}, R@10: {metrics['R@10']:.4f}, mAP@10: {metrics['mAP@10']:.4f}")

# `history` is the dict stored in all_histories above, so this attaches the
# metrics to the registered entry.
history['metrics'] = metrics

In [None]:
print("\n=== Training with VICReg Loss and Standard Mel Spectrograms ===\n")

# Fresh single-scale CRNN audio encoder for this run
audio_encoder = AudioEncoderCRNN(
    use_multiscale=False,
    output_dim=embedding_dim,
    gru_hidden_size=512,
    gru_num_layers=2,
    cnn_channels=[64, 128, 256, 512],
    kernel_size=3,
    stride=2,
    dropout_rate=0.3,
).to(device)

# VICReg objective with its three term weights
loss_function = VICRegLoss(sim_weight=25.0, var_weight=25.0, cov_weight=1.0)

# Fit on the single-scale datasets; returns the encoder and its loss history
best_model, history = train_model(
    audio_encoder,
    text_encoder,
    loss_function,
    standard_train_dataset,
    standard_val_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    patience=patience,
    device=device,
)

# Register this run's loss curves
all_histories['vicreg_standard'] = history

# Retrieval metrics on the single-scale validation dataset
metrics = evaluate_retrieval(
    audio_encoder=best_model,
    text_encoder=text_encoder,
    test_dataset=standard_val_dataset,
    device=device,
)
print("\nRetrieval Metrics:")
print(f"R@1: {metrics['R@1']:.4f}, R@5: {metrics['R@5']:.4f}, R@10: {metrics['R@10']:.4f}, mAP@10: {metrics['mAP@10']:.4f}")

# `history` is the dict stored in all_histories above, so this attaches the
# metrics to the registered entry.
history['metrics'] = metrics

In [None]:
print("\n=== Training with VICReg Loss and Multiscale Mel Spectrograms ===\n")

# Fresh multiscale CRNN audio encoder for this run
audio_encoder = AudioEncoderCRNN(
    use_multiscale=True,
    output_dim=embedding_dim,
    gru_hidden_size=512,
    gru_num_layers=2,
    cnn_channels=[64, 128, 256, 512],
    kernel_size=3,
    stride=2,
    dropout_rate=0.3,
).to(device)

# VICReg objective with its three term weights
loss_function = VICRegLoss(sim_weight=25.0, var_weight=25.0, cov_weight=1.0)

# Fit on the multiscale datasets; returns the encoder and its loss history
best_model, history = train_model(
    audio_encoder,
    text_encoder,
    loss_function,
    multiscale_train_dataset,
    multiscale_val_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    patience=patience,
    device=device,
)

# Register this run's loss curves
all_histories['vicreg_multiscale'] = history

# Retrieval metrics on the multiscale validation dataset
metrics = evaluate_retrieval(
    audio_encoder=best_model,
    text_encoder=text_encoder,
    test_dataset=multiscale_val_dataset,
    device=device,
)
print("\nRetrieval Metrics:")
print(f"R@1: {metrics['R@1']:.4f}, R@5: {metrics['R@5']:.4f}, R@10: {metrics['R@10']:.4f}, mAP@10: {metrics['mAP@10']:.4f}")

# `history` is the dict stored in all_histories above, so this attaches the
# metrics to the registered entry.
history['metrics'] = metrics

In [None]:
print("\n=== Training with Cosine Loss and Standard Mel Spectrograms ===\n")

# Fresh single-scale CRNN audio encoder for this run
audio_encoder = AudioEncoderCRNN(
    use_multiscale=False,
    output_dim=embedding_dim,
    gru_hidden_size=512,
    gru_num_layers=2,
    cnn_channels=[64, 128, 256, 512],
    kernel_size=3,
    stride=2,
    dropout_rate=0.3,
).to(device)

# Cosine objective with its default settings
loss_function = CosineLoss()

# Fit on the single-scale datasets; returns the encoder and its loss history
best_model, history = train_model(
    audio_encoder,
    text_encoder,
    loss_function,
    standard_train_dataset,
    standard_val_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    patience=patience,
    device=device,
)

# Register this run's loss curves
all_histories['cosine_standard'] = history

# Retrieval metrics on the single-scale validation dataset
metrics = evaluate_retrieval(
    audio_encoder=best_model,
    text_encoder=text_encoder,
    test_dataset=standard_val_dataset,
    device=device,
)
print("\nRetrieval Metrics:")
print(f"R@1: {metrics['R@1']:.4f}, R@5: {metrics['R@5']:.4f}, R@10: {metrics['R@10']:.4f}, mAP@10: {metrics['mAP@10']:.4f}")

# `history` is the dict stored in all_histories above, so this attaches the
# metrics to the registered entry.
history['metrics'] = metrics

In [None]:
print("\n=== Training with Cosine Loss and Multiscale Mel Spectrograms ===\n")

# Fresh multiscale CRNN audio encoder for this run
audio_encoder = AudioEncoderCRNN(
    use_multiscale=True,
    output_dim=embedding_dim,
    gru_hidden_size=512,
    gru_num_layers=2,
    cnn_channels=[64, 128, 256, 512],
    kernel_size=3,
    stride=2,
    dropout_rate=0.3,
).to(device)

# Cosine objective with its default settings
loss_function = CosineLoss()

# Fit on the multiscale datasets; returns the encoder and its loss history
best_model, history = train_model(
    audio_encoder,
    text_encoder,
    loss_function,
    multiscale_train_dataset,
    multiscale_val_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    patience=patience,
    device=device,
)

# Register this run's loss curves
all_histories['cosine_multiscale'] = history

# Retrieval metrics on the multiscale validation dataset
metrics = evaluate_retrieval(
    audio_encoder=best_model,
    text_encoder=text_encoder,
    test_dataset=multiscale_val_dataset,
    device=device,
)
print("\nRetrieval Metrics:")
print(f"R@1: {metrics['R@1']:.4f}, R@5: {metrics['R@5']:.4f}, R@10: {metrics['R@10']:.4f}, mAP@10: {metrics['mAP@10']:.4f}")

# `history` is the dict stored in all_histories above, so this attaches the
# metrics to the registered entry.
history['metrics'] = metrics

In [None]:
# Plotting function
def plot_training_histories(histories, save_path='training_results.png'):
    """
    Draw the loss curves of every known run on a 2x3 grid of subplots.

    Rows are the feature types (standard, multiscale) and columns are the loss
    functions (InfoNCE, VICReg, Cosine). Keys in ``histories`` that are not
    part of the grid are ignored. When a run carries a ``'metrics'`` entry its
    retrieval scores are written into the panel. The figure is saved to
    ``save_path`` and then shown.
    """
    # Grid cell and panel title for each recognised run key
    layout = {
        'infonce_standard': ((0, 0), 'InfoNCE - Standard Mel'),
        'infonce_multiscale': ((1, 0), 'InfoNCE - Multiscale Mel'),
        'vicreg_standard': ((0, 1), 'VICReg - Standard Mel'),
        'vicreg_multiscale': ((1, 1), 'VICReg - Multiscale Mel'),
        'cosine_standard': ((0, 2), 'Cosine - Standard Mel'),
        'cosine_multiscale': ((1, 2), 'Cosine - Multiscale Mel'),
    }
    metric_names = ('R@1', 'R@5', 'R@10', 'mAP@10')

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training Results for Different Loss Functions and Feature Types', fontsize=16)

    for run_key, run in histories.items():
        if run_key not in layout:
            continue

        (row, col), panel_title = layout[run_key]
        panel = axes[row, col]
        epochs = run['epochs']

        # Loss curves
        panel.plot(epochs, run['train_loss'], label='Training Loss', marker='o')
        panel.plot(epochs, run['val_loss'], label='Validation Loss', marker='x')

        # Title, axis labels, legend and grid
        panel.set_title(panel_title)
        panel.set_xlabel('Epoch')
        panel.set_ylabel('Loss')
        panel.legend()
        panel.grid(True, linestyle='--', alpha=0.7)

        # Retrieval scores, if this run has been evaluated
        if 'metrics' not in run:
            continue
        scores = run['metrics']
        metrics_text = "\n".join(
            f"{metric}: {scores[metric]:.3f}" for metric in metric_names
        )
        panel.text(0.05, 0.95, metrics_text, transform=panel.transAxes, fontsize=9,
                   verticalalignment='top', bbox=dict(boxstyle='round', alpha=0.1))

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave room for the suptitle
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Plot all results
plot_training_histories(all_histories)