# Joint Slot Attention for Sound Source Localization

Implementation of "Improving Sound Source Localization with Joint Slot Attention on Image and Audio"

This notebook implements nearly the full paper architecture (using a dummy dataset), including:
- Joint Slot Attention mechanism
- Cross-modal attention matching
- All loss functions (contrastive, matching, divergence, reconstruction)
- False negative mitigation
- Image-Query based Refinement (IQR) at inference

In [None]:
# Cell 1: Install required packages
# The leading "!" makes IPython/Jupyter run the line as a shell command; "-q" keeps pip's output quiet.
!pip install torch torchvision torchaudio librosa soundfile scikit-learn matplotlib seaborn -q

In [None]:
# Cell 2: Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torchaudio
import numpy as np
import librosa
import soundfile as sf
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import average_precision_score, auc
from torch.utils.data import Dataset, DataLoader
import os
import random
from typing import Tuple, List, Dict
from collections import defaultdict

# Set random seeds for reproducibility
# Seeds PyTorch, NumPy's global RNG and Python's `random` module with the same fixed value.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a module-level name that the cells below rely on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Cell 3: Create Dummy Dataset for Testing

We'll create a synthetic dataset with paired images and audio

In [None]:
class DummySoundSourceDataset(Dataset):
    """Synthetic paired image/audio dataset for testing sound source localization.

    Every sample is generated on the fly and deterministically from its index.
    The image is low-intensity noise with one filled, category-coloured disc
    (the "sound source"); the audio is a harmonic tone whose base frequency
    depends on the same category, mixed with Gaussian noise.

    Sample generation draws from private ``np.random.RandomState`` instances
    seeded from the index, so indexing the dataset never reseeds NumPy's
    global RNG (the generated values are the same as with global seeding).

    Args:
        num_samples: Number of samples reported by ``len()``.
        image_size: ``(height, width)`` of the generated image. Each side must
            be larger than 100 because the disc centre is drawn from
            ``[50, side - 50)``.
        audio_length: Audio duration in seconds.
        sample_rate: Audio sample rate in Hz.
    """

    # RGB colour of the synthetic sound source for each category index.
    _CATEGORY_COLORS = (
        (255, 0, 0),      # Red
        (0, 255, 0),      # Green
        (0, 0, 255),      # Blue
        (255, 255, 0),    # Yellow
        (255, 0, 255),    # Magenta
        (0, 255, 255),    # Cyan
        (255, 128, 0),    # Orange
        (128, 0, 255),    # Purple
        (0, 128, 255),    # Light Blue
        (255, 192, 203),  # Pink
    )

    def __init__(self, num_samples=500, image_size=(224, 224), audio_length=5, sample_rate=16000):
        self.num_samples = num_samples
        self.image_size = image_size
        self.audio_length = audio_length
        self.sample_rate = sample_rate
        self.num_classes = 10  # 10 different sound source categories

        # Image transforms: resize, convert to tensor, normalise per channel.
        self.image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"Created dummy dataset with {num_samples} samples")

    def __len__(self):
        return self.num_samples

    def generate_synthetic_image(self, idx):
        """Generate a noise image with one coloured disc marking the sound source.

        Args:
            idx: Sample index; seeds the generator and selects the category
                (``idx % num_classes``).

        Returns:
            Tuple ``(image, gt_mask, category)``: a PIL image, a float32
            ``(height, width)`` array that is 1 inside the disc and 0
            elsewhere, and the integer category.
        """
        # Private generator: deterministic per index, no global side effects.
        rng = np.random.RandomState(idx)

        # Draw order (background, centre y, centre x, radius) is part of the
        # deterministic contract; keep it stable.
        image = rng.randint(50, 100, (*self.image_size, 3), dtype=np.uint8)

        category = idx % self.num_classes

        center_y = rng.randint(50, self.image_size[0] - 50)
        center_x = rng.randint(50, self.image_size[1] - 50)
        size = rng.randint(30, 80)

        color = self._CATEGORY_COLORS[category]

        # Boolean disc mask built from broadcast row/column coordinates.
        y, x = np.ogrid[:self.image_size[0], :self.image_size[1]]
        mask = (x - center_x) ** 2 + (y - center_y) ** 2 <= size ** 2
        image[mask] = color

        gt_mask = mask.astype(np.float32)

        return Image.fromarray(image), gt_mask, category

    def generate_synthetic_audio(self, idx, category):
        """Generate a noisy harmonic tone whose pitch depends on the category.

        Args:
            idx: Sample index; ``idx + 1000`` seeds the generator so audio and
                image noise for the same sample use different streams.
            category: Integer category; the base frequency is
                ``200 + 100 * category`` Hz.

        Returns:
            1-D float array of ``sample_rate * audio_length`` samples, scaled
            so that its peak absolute value is (just under) 1.
        """
        rng = np.random.RandomState(idx + 1000)

        samples = int(self.sample_rate * self.audio_length)
        audio = rng.randn(samples) * 0.1

        base_freq = 200 + category * 100  # Different base frequency per category
        t = np.linspace(0, self.audio_length, samples)

        # Fundamental plus two harmonics with 1/harmonic amplitude decay.
        for harmonic in range(1, 4):
            freq = base_freq * harmonic
            audio += 0.3 * np.sin(2 * np.pi * freq * t) / harmonic

        # Second, weaker noise layer.
        audio += rng.randn(samples) * 0.05

        # Peak-normalise; the epsilon guards against division by zero.
        audio = audio / (np.max(np.abs(audio)) + 1e-8)

        return audio

    def compute_spectrogram(self, audio):
        """Convert a waveform into a min-max normalised log-magnitude spectrogram.

        Args:
            audio: 1-D waveform array.

        Returns:
            Float tensor of shape ``(1, 257, num_frames)``.
        """
        stft = librosa.stft(audio, n_fft=512, hop_length=160, win_length=400)
        spectrogram = np.abs(stft)

        # dB scale relative to the loudest bin.
        log_spec = librosa.amplitude_to_db(spectrogram, ref=np.max)

        # Min-max normalise; the epsilon guards against a constant spectrogram.
        log_spec = (log_spec - log_spec.min()) / (log_spec.max() - log_spec.min() + 1e-8)

        # Force exactly 257 frequency bins by zero-padding or truncating.
        target_freq_bins = 257
        if log_spec.shape[0] < target_freq_bins:
            pad_size = target_freq_bins - log_spec.shape[0]
            log_spec = np.pad(log_spec, ((0, pad_size), (0, 0)), mode='constant')
        else:
            log_spec = log_spec[:target_freq_bins, :]

        return torch.FloatTensor(log_spec).unsqueeze(0)  # Add channel dim

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys image/audio/gt_mask/category/idx."""
        image, gt_mask, category = self.generate_synthetic_image(idx)
        audio = self.generate_synthetic_audio(idx, category)
        spectrogram = self.compute_spectrogram(audio)

        image_tensor = self.image_transform(image)

        return {
            'image': image_tensor,
            'audio': spectrogram,
            'gt_mask': torch.FloatTensor(gt_mask),
            'category': category,
            'idx': idx
        }

# Test dataset
# Smoke test: build a tiny dataset and print the shapes of one generated sample.
test_dataset = DummySoundSourceDataset(num_samples=10)
sample = test_dataset[0]
print(f"Image shape: {sample['image'].shape}")
print(f"Audio shape: {sample['audio'].shape}")
print(f"GT mask shape: {sample['gt_mask'].shape}")
print(f"Category: {sample['category']}")

## Cell 4: Encoder Networks

Image and Audio encoders using ResNet-18

In [None]:
class ImageEncoder(nn.Module):
    """Image encoder: ImageNet-pretrained ResNet-18 trunk plus a 1x1 projection.

    The average-pool and fully connected head of ResNet-18 are dropped so the
    encoder returns a spatial feature map rather than a single vector.

    Args:
        output_dim: Number of channels of the projected feature map.
    """

    def __init__(self, output_dim=512):
        super().__init__()

        # Load pretrained ResNet-18. `pretrained=True` is deprecated in recent
        # torchvision releases in favour of the `weights=` enum; fall back to
        # the legacy flag only when the enum does not exist (old torchvision).
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet18(pretrained=True)

        # Remove the final FC layer and avgpool (the last two children).
        self.features = nn.Sequential(*list(resnet.children())[:-2])

        # 1x1 conv projects the 512 trunk channels to `output_dim`.
        self.projection = nn.Conv2d(512, output_dim, kernel_size=1)

    def forward(self, x):
        """Encode images ``(B, 3, H, W)`` into ``(B, output_dim, H/32, W/32)``."""
        features = self.features(x)  # (B, 512, H/32, W/32)
        features = self.projection(features)  # (B, output_dim, H/32, W/32)
        return features


class AudioEncoder(nn.Module):
    """Audio encoder: randomly initialised ResNet-18 adapted to 1-channel spectrograms.

    The first convolution is replaced to accept a single input channel, the
    average-pool and fully connected head are dropped, and the frequency axis
    of the resulting feature map is adaptively max-pooled to a fixed length.

    Args:
        output_dim: Number of channels of the projected feature map.
    """

    def __init__(self, output_dim=512):
        super().__init__()

        # Randomly initialised ResNet-18. Calling without arguments avoids the
        # deprecated `pretrained=` keyword and means "no pretrained weights"
        # under both the legacy and the current torchvision API.
        resnet = models.resnet18()

        # Modify first conv to accept 1 channel (spectrogram) instead of RGB.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Remove the final FC layer and avgpool (the last two children).
        self.features = nn.Sequential(*list(resnet.children())[:-2])

        # Adaptive max pooling: forces the frequency axis to length 16 and
        # leaves the time axis untouched (`None`).
        self.freq_pool = nn.AdaptiveMaxPool2d((16, None))

        # 1x1 conv projects the 512 trunk channels to `output_dim`.
        self.projection = nn.Conv2d(512, output_dim, kernel_size=1)

    def forward(self, x):
        """Encode spectrograms ``(B, 1, F, T)`` into ``(B, output_dim, 16, ~T/32)``."""
        features = self.features(x)  # (B, 512, ~F/32, ~T/32)
        features = self.freq_pool(features)  # (B, 512, 16, ~T/32)
        features = self.projection(features)  # (B, output_dim, 16, ~T/32)
        return features

# Test encoders
# Smoke test: instantiate both encoders on the selected device and check output shapes.
image_encoder = ImageEncoder(output_dim=512).to(device)
audio_encoder = AudioEncoder(output_dim=512).to(device)

# Test with dummy data
# Batch of 2: RGB images of 224x224 and 1-channel spectrograms of 257 bins x 160 frames.
dummy_image = torch.randn(2, 3, 224, 224).to(device)
dummy_audio = torch.randn(2, 1, 257, 160).to(device)

# No gradients are needed for a shape check.
with torch.no_grad():
    image_feat = image_encoder(dummy_image)
    audio_feat = audio_encoder(dummy_audio)
    
print(f"Image features shape: {image_feat.shape}")  # Should be (2, 512, 7, 7)
print(f"Audio features shape: {audio_feat.shape}")  # Should be (2, 512, 16, 5)

## Cell 5: Joint Slot Attention Module

The core module that decomposes features into target and off-target slots

In [None]:
class JointSlotAttention(nn.Module):
    """Slot attention module that decomposes a feature map into a few slots.

    The same module (and therefore the same learned initial slots and
    projections) can be applied to feature maps of different modalities.

    Args:
        dim: Channel dimension of the input features and of each slot.
        num_slots: Number of slots.
        num_iterations: Number of iterative refinement steps.
        hidden_dim: Hidden width of the slot-refinement MLP.
        eps: Small constant added to the attention weights before the
            weighted-mean normalisation, to avoid division by zero.
    """

    def __init__(
        self,
        dim=512,
        num_slots=2,
        num_iterations=5,
        hidden_dim=256,
        eps=1e-8
    ):
        super().__init__()

        self.dim = dim
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        self.eps = eps

        # Learned initial slots, broadcast over the batch in forward().
        self.slots_init = nn.Parameter(torch.randn(1, num_slots, dim) * 0.05)

        # Layer normalization for inputs, slots, and the pre-MLP slots.
        self.norm_input = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)
        self.norm_pre_ff = nn.LayerNorm(dim)

        # Keys/values come from the input features, queries from the slots.
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_q = nn.Linear(dim, dim, bias=False)

        # GRU cell performs the recurrent slot update.
        self.gru = nn.GRUCell(dim, dim)

        # Residual MLP refines the slots after each GRU update.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, features):
        """Run iterative slot attention over a 4-D feature map.

        Args:
            features: ``(B, C, H, W)`` (or ``(B, C, F, T)``) feature map with
                ``C == dim``.

        Returns:
            Tuple ``(slots, queries, final_attn, k, v)``:
              - slots: ``(B, num_slots, C)`` final slot representations.
              - queries: ``(B, num_slots, C)`` queries projected from the
                final slots.
              - final_attn: ``(B, num_slots, N)`` attention of each final
                query over the ``N = H*W`` positions (each row sums to 1).
              - k: ``(B, N, C)`` keys of the input positions.
              - v: ``(B, N, C)`` values of the input positions.

        Raises:
            ValueError: If ``features`` is not 4-dimensional.
        """
        if features.dim() != 4:
            raise ValueError("Features must be 4D (B, C, H, W)")

        B = features.shape[0]
        scale = self.dim ** 0.5

        # (B, C, H, W) -> (B, N, C) with N = H*W.
        features_flat = features.flatten(2).permute(0, 2, 1)

        features_norm = self.norm_input(features_flat)
        k = self.to_k(features_norm)  # (B, N, C)
        v = self.to_v(features_norm)  # (B, N, C)

        slots = self.slots_init.expand(B, -1, -1).contiguous()  # (B, num_slots, C)

        for _ in range(self.num_iterations):
            slots_prev = slots

            q = self.to_q(self.norm_slots(slots))  # (B, num_slots, C)

            # (B, num_slots, N) similarity between every slot and position.
            attn_logits = torch.einsum('bnc,bmc->bnm', q, k) / scale

            # Softmax over the SLOT axis so that slots compete for each input
            # position. A softmax over positions would make the weighted-mean
            # normalisation below a no-op and remove the competition.
            attn = F.softmax(attn_logits, dim=1) + self.eps

            # Weighted mean: renormalise each slot's weights over positions.
            attn = attn / attn.sum(dim=-1, keepdim=True)
            updates = torch.einsum('bnm,bmc->bnc', attn, v)  # (B, num_slots, C)

            # GRUCell expects 2-D inputs, so fold the slot axis into the batch.
            slots = self.gru(
                updates.reshape(-1, self.dim),
                slots_prev.reshape(-1, self.dim)
            ).reshape(B, self.num_slots, self.dim)

            # MLP with residual connection.
            slots = slots + self.mlp(self.norm_pre_ff(slots))

        # Queries of the final slots and their attention over positions
        # (normalised over positions, i.e. a spatial attention map per slot).
        queries = self.to_q(self.norm_slots(slots))
        attn_logits = torch.einsum('bnc,bmc->bnm', queries, k) / scale
        final_attn = F.softmax(attn_logits, dim=-1)

        return slots, queries, final_attn, k, v

# Test Joint Slot Attention
# Smoke test: run the module on random features shaped like the image encoder output.
jsa = JointSlotAttention(dim=512, num_slots=2, num_iterations=5).to(device)

# Test with dummy features
dummy_features = torch.randn(2, 512, 7, 7).to(device)
# `slots` is reused by the decoder smoke test in the next cell.
slots, queries, attn, keys, values = jsa(dummy_features)

print(f"Slots shape: {slots.shape}")  # (2, 2, 512)
print(f"Queries shape: {queries.shape}")  # (2, 2, 512)
print(f"Attention shape: {attn.shape}")  # (2, 2, 49)
print(f"Keys shape: {keys.shape}")  # (2, 49, 512)
print(f"Values shape: {values.shape}")  # (2, 49, 512)


## Cell 6: Slot Decoders

Reconstruction decoders for slot reconstruction loss

In [None]:
class SlotDecoder(nn.Module):
    """Three-layer MLP that maps each slot to a reconstruction vector."""

    def __init__(self, slot_dim=512, hidden_dim=1024, output_dim=512):
        super().__init__()

        # Layer widths: slot_dim -> hidden_dim -> hidden_dim -> output_dim.
        widths = (slot_dim, hidden_dim, hidden_dim, output_dim)

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU(inplace=True))
        stack.pop()  # the output projection is not followed by an activation

        self.mlp = nn.Sequential(*stack)

    def forward(self, slots):
        """Decode slots.

        Args:
            slots: (B, num_slots, slot_dim)
        Returns:
            reconstruction: (B, num_slots, output_dim)
        """
        reconstruction = self.mlp(slots)
        return reconstruction

# Test decoder
# Smoke test: decode the `slots` produced by the slot-attention smoke test above.
decoder = SlotDecoder(slot_dim=512, output_dim=512).to(device)
reconstruction = decoder(slots)
print(f"Reconstruction shape: {reconstruction.shape}")  # (2, 2, 512)

## Cell 7: Complete Model Architecture

Putting everything together

In [None]:
class JointSlotAttentionSSL(nn.Module):
    """Complete sound source localization model built around joint slot attention.

    Images and spectrograms are encoded separately, then both feature maps are
    decomposed by the SAME slot-attention module, so the two modalities share
    initial slots and projections. Slot 0 is treated as the target slot and
    slot 1 as the off-target slot, which requires ``num_slots >= 2``.

    Args:
        feature_dim: Channel dimension of encoder outputs and slots.
        num_slots: Number of slots per modality.
        num_iterations: Number of slot-attention refinement steps.
        temperature: Stored on the module; not used by ``forward`` itself.
    """

    def __init__(
        self,
        feature_dim=512,
        num_slots=2,
        num_iterations=5,
        temperature=0.03
    ):
        super().__init__()

        self.feature_dim = feature_dim
        self.num_slots = num_slots
        self.temperature = temperature

        # Encoders
        self.image_encoder = ImageEncoder(output_dim=feature_dim)
        self.audio_encoder = AudioEncoder(output_dim=feature_dim)

        # Joint Slot Attention (one module, applied to both modalities)
        self.slot_attention = JointSlotAttention(
            dim=feature_dim,
            num_slots=num_slots,
            num_iterations=num_iterations
        )

        # Decoders for reconstruction loss
        self.image_decoder = SlotDecoder(slot_dim=feature_dim, output_dim=feature_dim)
        self.audio_decoder = SlotDecoder(slot_dim=feature_dim, output_dim=feature_dim)

    def encode(self, images, audios):
        """Encode images and audios into ``(B, C, H, W)`` and ``(B, C, F, T)`` feature maps."""
        image_features = self.image_encoder(images)  # (B, C, H, W)
        audio_features = self.audio_encoder(audios)  # (B, C, F, T)
        return image_features, audio_features

    def decompose(self, image_features, audio_features):
        """Decompose both feature maps into slots with the shared slot attention.

        Returns:
            Dict with ``{image,audio}_{slots,queries,keys,values}``. The
            per-slot attention maps returned by the slot-attention module are
            not part of the result.
        """
        # Image slot attention
        image_slots, image_queries, _, image_keys, image_values = self.slot_attention(image_features)

        # Audio slot attention (shares the same initial slots)
        audio_slots, audio_queries, _, audio_keys, audio_values = self.slot_attention(audio_features)

        return {
            'image_slots': image_slots,  # (B, num_slots, C)
            'image_queries': image_queries,
            'image_keys': image_keys,
            'image_values': image_values,
            'audio_slots': audio_slots,
            'audio_queries': audio_queries,
            'audio_keys': audio_keys,
            'audio_values': audio_values
        }

    def forward(self, images, audios):
        """Encode, decompose, and expose target/off-target slots.

        Args:
            images: ``(B, 3, H, W)`` image batch.
            audios: ``(B, 1, F, T)`` spectrogram batch.

        Returns:
            Dict containing everything from ``decompose`` plus the raw feature
            maps, the target (slot 0) and off-target (slot 1) slots of each
            modality, and ``feature_shapes`` with the encoded map sizes.
        """
        # Encode
        image_features, audio_features = self.encode(images, audios)

        # Decompose
        decomposition = self.decompose(image_features, audio_features)

        # Extract target slots (first slot is target)
        image_target_slot = decomposition['image_slots'][:, 0]  # (B, C)
        image_off_target_slot = decomposition['image_slots'][:, 1]
        audio_target_slot = decomposition['audio_slots'][:, 0]
        audio_off_target_slot = decomposition['audio_slots'][:, 1]

        # Feature map shapes. The audio sizes get their own names so the local
        # variables do not shadow the module-level `F` alias.
        B, C, H, W = image_features.shape
        _, _, freq_bins, time_steps = audio_features.shape

        return {
            **decomposition,
            'image_features': image_features,
            'audio_features': audio_features,
            'image_target_slot': image_target_slot,
            'image_off_target_slot': image_off_target_slot,
            'audio_target_slot': audio_target_slot,
            'audio_off_target_slot': audio_off_target_slot,
            'feature_shapes': {'B': B, 'C': C, 'H': H, 'W': W, 'F': freq_bins, 'T': time_steps}
        }

# Test complete model
# Smoke test: one forward pass of the full model on random inputs.
model = JointSlotAttentionSSL(feature_dim=512, num_slots=2, num_iterations=5).to(device)

dummy_images = torch.randn(2, 3, 224, 224).to(device)
dummy_audios = torch.randn(2, 1, 257, 160).to(device)

# `output` is reused by the loss smoke test below; it is computed without gradients.
with torch.no_grad():
    output = model(dummy_images, dummy_audios)

print("Model output keys:", list(output.keys()))
print(f"Image target slot shape: {output['image_target_slot'].shape}")
print(f"Audio target slot shape: {output['audio_target_slot'].shape}")

## Cell 8: Loss Functions

Implementing all four losses: Contrastive, Attention Matching, Slot Divergence, and Reconstruction

In [None]:
class SSLoss(nn.Module):
    """Combined loss for sound source localization.

    ``total = contrastive + lambda_match * matching + lambda_div * divergence
    + lambda_recon * reconstruction``.

    Args:
        temperature: Temperature dividing the cosine similarities in the
            contrastive loss.
        lambda_match: Weight of the attention matching loss.
        lambda_div: Weight of the slot divergence loss.
        lambda_recon: Weight of the reconstruction loss.
    """

    def __init__(
        self,
        temperature=0.03,
        lambda_match=100.0,
        lambda_div=0.1,
        lambda_recon=0.1
    ):
        super().__init__()

        self.temperature = temperature
        self.lambda_match = lambda_match
        self.lambda_div = lambda_div
        self.lambda_recon = lambda_recon

    def cosine_similarity(self, x, y):
        """Return the ``(Bx, By)`` cosine-similarity matrix divided by the temperature."""
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return torch.matmul(x, y.T) / self.temperature

    def contrastive_loss(self, image_target_slots, audio_target_slots, false_negatives_mask=None):
        """Symmetric InfoNCE loss between image and audio target slots.

        Args:
            image_target_slots: (B, C)
            audio_target_slots: (B, C)
            false_negatives_mask: Optional (B, B) boolean mask; True marks
                image/audio pairs to exclude from the negatives. Diagonal
                entries are ignored so that positives can never be masked out.

        Returns:
            Scalar tensor: mean of the image->audio and audio->image losses.
        """
        B = image_target_slots.shape[0]

        logits = self.cosine_similarity(image_target_slots, audio_target_slots)  # (B, B)

        # Diagonal elements are the positives.
        labels = torch.arange(B, device=logits.device)

        if false_negatives_mask is not None:
            positives = torch.eye(B, dtype=torch.bool, device=logits.device)
            exclude = false_negatives_mask.to(device=logits.device, dtype=torch.bool) & ~positives
            # Use the most negative finite value of the logits' dtype: a fixed
            # literal such as -1e9 is not representable in half precision.
            logits = logits.masked_fill(exclude, torch.finfo(logits.dtype).min)

        loss_i2a = F.cross_entropy(logits, labels)
        loss_a2i = F.cross_entropy(logits.T, labels)

        return (loss_i2a + loss_a2i) / 2

    @staticmethod
    def _target_attention(query, keys):
        """Scaled dot-product attention of one query per sample over its keys.

        Args:
            query: (B, C)
            keys: (B, N, C)

        Returns:
            (B, N) attention weights that sum to 1 over N.
        """
        scores = torch.einsum('bc,bnc->bn', query, keys)
        return F.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)

    def attention_matching_loss(
        self,
        image_queries,
        audio_queries,
        image_keys,
        audio_keys
    ):
        """Cross-modal attention matching loss.

        Pulls the cross-modal attention (audio query over image keys, and
        image query over audio keys) towards the corresponding intra-modal
        attention, which is treated as a fixed target (stop-gradient).

        Args:
            image_queries: (B, num_slots, C); slot 0 is the target query.
            audio_queries: (B, num_slots, C); slot 0 is the target query.
            image_keys: (B, N_image, C)
            audio_keys: (B, N_audio, C)

        Returns:
            Scalar tensor: mean of the two MSE matching terms.
        """
        image_target_query = image_queries[:, 0]  # (B, C)
        audio_target_query = audio_queries[:, 0]

        # Attention over image positions: intra (image query) vs cross (audio query).
        intra_image_attn = self._target_attention(image_target_query, image_keys)
        cross_audio2image_attn = self._target_attention(audio_target_query, image_keys)

        # Attention over audio positions: intra (audio query) vs cross (image query).
        intra_audio_attn = self._target_attention(audio_target_query, audio_keys)
        cross_image2audio_attn = self._target_attention(image_target_query, audio_keys)

        # Stop-gradient on the intra-modal attention: only the cross-modal
        # side is pushed towards it.
        loss_match_image = F.mse_loss(cross_audio2image_attn, intra_image_attn.detach())
        loss_match_audio = F.mse_loss(cross_image2audio_attn, intra_audio_attn.detach())

        return (loss_match_image + loss_match_audio) / 2

    def slot_divergence_loss(self, image_slots, audio_slots):
        """Slot divergence loss: mean cosine similarity between slot 0 and slot 1.

        Minimising it pushes the target and off-target slots of each modality
        apart.

        Args:
            image_slots: (B, num_slots, C) with at least two slots.
            audio_slots: (B, num_slots, C) with at least two slots.
        """
        image_slots_norm = F.normalize(image_slots, dim=-1)
        audio_slots_norm = F.normalize(audio_slots, dim=-1)

        # Per-sample cosine similarity between target and off-target slots.
        image_sim = torch.einsum('bc,bc->b', image_slots_norm[:, 0], image_slots_norm[:, 1])
        audio_sim = torch.einsum('bc,bc->b', audio_slots_norm[:, 0], audio_slots_norm[:, 1])

        loss_div = (image_sim.mean() + audio_sim.mean()) / 2

        return loss_div

    def reconstruction_loss(
        self,
        image_slots,
        audio_slots,
        image_features,
        audio_features,
        image_decoder,
        audio_decoder
    ):
        """Slot reconstruction loss.

        The decoded TARGET slot of each modality is regressed (MSE) onto the
        spatially averaged encoder features of that modality.

        Args:
            image_slots, audio_slots: (B, num_slots, C)
            image_features, audio_features: 4-D feature maps (B, C, *, *).
            image_decoder, audio_decoder: Callables mapping slots to
                (B, num_slots, C) reconstructions.
        """
        image_recon = image_decoder(image_slots)  # (B, num_slots, C)
        audio_recon = audio_decoder(audio_slots)

        # Average over both spatial axes to get one (B, C) vector per sample.
        image_features_pooled = image_features.mean(dim=(2, 3))
        audio_features_pooled = audio_features.mean(dim=(2, 3))

        loss_recon_image = F.mse_loss(image_recon[:, 0], image_features_pooled)
        loss_recon_audio = F.mse_loss(audio_recon[:, 0], audio_features_pooled)

        return (loss_recon_image + loss_recon_audio) / 2

    def forward(
        self,
        model_output,
        image_decoder,
        audio_decoder,
        false_negatives_mask=None
    ):
        """Compute all losses.

        Args:
            model_output: Dict with the keys read below (target slots,
                queries, keys, slots and feature maps of both modalities).
            image_decoder, audio_decoder: Slot decoders for the
                reconstruction loss.
            false_negatives_mask: Optional (B, B) boolean mask forwarded to
                the contrastive loss.

        Returns:
            Dict with keys ``total``, ``contrastive``, ``matching``,
            ``divergence`` and ``reconstruction`` (scalar tensors).
        """
        loss_cotr = self.contrastive_loss(
            model_output['image_target_slot'],
            model_output['audio_target_slot'],
            false_negatives_mask
        )

        loss_match = self.attention_matching_loss(
            model_output['image_queries'],
            model_output['audio_queries'],
            model_output['image_keys'],
            model_output['audio_keys']
        )

        loss_div = self.slot_divergence_loss(
            model_output['image_slots'],
            model_output['audio_slots']
        )

        loss_recon = self.reconstruction_loss(
            model_output['image_slots'],
            model_output['audio_slots'],
            model_output['image_features'],
            model_output['audio_features'],
            image_decoder,
            audio_decoder
        )

        total_loss = (
            loss_cotr
            + self.lambda_match * loss_match
            + self.lambda_div * loss_div
            + self.lambda_recon * loss_recon
        )

        return {
            'total': total_loss,
            'contrastive': loss_cotr,
            'matching': loss_match,
            'divergence': loss_div,
            'reconstruction': loss_recon
        }

# Test loss computation
# Smoke test on the `output` of the model smoke test above. That output was produced
# under torch.no_grad(), so these values are for inspection only.
criterion = SSLoss(temperature=0.03, lambda_match=100.0, lambda_div=0.1, lambda_recon=0.1)
losses = criterion(output, model.image_decoder, model.audio_decoder)

print("Loss components:")
for key, value in losses.items():
    print(f"  {key}: {value.item():.4f}")

## Cell 9: False Negative Mitigation

Implementation of k-reciprocal nearest neighbors

In [None]:
def _reciprocal_knn(features, k):
    """Return a (B, B) boolean matrix of mutual k-nearest-neighbour pairs.

    Entry ``[i, j]`` is True when ``j`` is among the ``k`` most cosine-similar
    other samples of ``i`` AND ``i`` is among those of ``j``.
    """
    normed = F.normalize(features, dim=-1)
    sim = torch.matmul(normed, normed.T)  # (B, B)

    # Exclude each sample from its own neighbour list explicitly. Dropping
    # "the first top-k hit" is not reliable when duplicates tie with self.
    sim.fill_diagonal_(float('-inf'))
    neighbours = sim.topk(k, dim=-1).indices  # (B, k)

    adjacency = torch.zeros_like(sim, dtype=torch.bool)
    adjacency.scatter_(1, neighbours, True)

    return adjacency & adjacency.T


def compute_false_negatives(image_target_slots, audio_target_slots, k=20):
    """
    Compute false negatives using k-reciprocal nearest neighbors.

    A pair ``(i, j)`` is flagged when ``i`` and ``j`` are mutual k-nearest
    neighbours in BOTH the image-slot space and the audio-slot space.

    Args:
        image_target_slots: (B, C)
        audio_target_slots: (B, C)
        k: number of nearest neighbors; clamped to ``B - 1``. Note that when
            ``k >= B - 1`` every other sample is a neighbour, so every
            off-diagonal pair is flagged.

    Returns:
        false_negatives_mask: (B, B) boolean mask where True indicates false
        negatives. The mask is symmetric and its diagonal is always False.
    """
    B = image_target_slots.shape[0]
    slot_device = image_target_slots.device

    # Clamp k to be at most B-1 (can't have more neighbors than other samples)
    k = min(k, B - 1)

    if k <= 0:
        # With at most one sample there are no negatives to flag.
        return torch.zeros(B, B, dtype=torch.bool, device=slot_device)

    # The mask is a training target, not something to backpropagate through.
    with torch.no_grad():
        image_mutual = _reciprocal_knn(image_target_slots, k)
        audio_mutual = _reciprocal_knn(audio_target_slots, k)
        false_negatives_mask = image_mutual & audio_mutual

    # Don't mask diagonal (positives should remain)
    false_negatives_mask.fill_diagonal_(False)

    return false_negatives_mask


## Cell 10: Inference and Localization

Generate sound source localization maps with IQR refinement

In [None]:
class InferenceModule:
    """Inference wrapper that turns model outputs into localization maps.

    Constructing it switches the wrapped model to eval mode.

    Args:
        model: Model whose forward returns a dict with ``image_features``,
            ``audio_queries``, ``image_queries`` and ``image_keys``.
        threshold: Threshold applied to the refined map to get a binary mask.
            NOTE(review): the maps are softmax weights that sum to 1 over all
            positions, so a threshold of 0.5 is only exceeded by very peaked
            maps - confirm this is intended.
        alpha: Weight of the cross-modal map in the refinement; the
            intra-modal map gets ``1 - alpha``.
    """

    def __init__(self, model, threshold=0.5, alpha=0.6):
        self.model = model
        self.threshold = threshold
        self.alpha = alpha  # Balance parameter for IQR
        self.model.eval()

    @staticmethod
    def _attention_map(query, keys, scale, map_shape):
        """Softmax attention of ``query`` (B, C) over ``keys`` (B, N, C), reshaped to ``map_shape``."""
        weights = torch.einsum('bc,bnc->bn', query, keys) / scale
        weights = F.softmax(weights, dim=-1)  # (B, N)
        return weights.view(*map_shape)

    @torch.no_grad()
    def localize(self, images, audios, return_raw=False):
        """
        Perform sound source localization.

        Args:
            images: (B, 3, H, W)
            audios: (B, 1, F, T)
            return_raw: if True, return the attention maps at feature-map
                resolution instead of the upsampled result.

        Returns:
            If ``return_raw`` is False: tuple ``(refined_map, binary_mask)``,
            both ``(B, H, W)`` at the spatial size of ``images``.
            If ``return_raw`` is True: dict with ``cross_modal_attention``,
            ``intra_modal_attention``, ``refined_attention`` and
            ``binary_mask``, all at feature-map resolution.
        """
        output = self.model(images, audios)

        # Feature-map geometry; the channel count sets the attention scale.
        B, C, H, W = output['image_features'].shape
        scale = C ** 0.5
        map_shape = (B, H, W)

        audio_target_query = output['audio_queries'][:, 0]  # (B, C)
        image_target_query = output['image_queries'][:, 0]
        image_keys = output['image_keys']  # (B, H*W, C)

        # Cross-modal attention (audio target -> image positions).
        cross_attn_map = self._attention_map(audio_target_query, image_keys, scale, map_shape)

        # Intra-modal attention (image target -> image positions), used for IQR.
        intra_attn_map = self._attention_map(image_target_query, image_keys, scale, map_shape)

        # Image-Query based Refinement (IQR): convex blend of the two maps.
        refined_map = self.alpha * cross_attn_map + (1 - self.alpha) * intra_attn_map

        if return_raw:
            return {
                'cross_modal_attention': cross_attn_map,
                'intra_modal_attention': intra_attn_map,
                'refined_attention': refined_map,
                'binary_mask': (refined_map > self.threshold).float()
            }

        # Upsample to the spatial size of the input images (not a fixed
        # 224x224), so the maps line up with images of any resolution.
        refined_map_upsampled = F.interpolate(
            refined_map.unsqueeze(1),
            size=tuple(images.shape[-2:]),
            mode='bilinear',
            align_corners=False
        ).squeeze(1)

        # Threshold to get binary mask
        binary_mask = (refined_map_upsampled > self.threshold).float()

        return refined_map_upsampled, binary_mask

# Test inference
# Smoke test on the random inputs from the model smoke test above.
# Note: constructing InferenceModule puts `model` into eval mode.
inference = InferenceModule(model, threshold=0.5, alpha=0.6)
loc_map, binary_mask = inference.localize(dummy_images, dummy_audios)

print(f"Localization map shape: {loc_map.shape}")
print(f"Binary mask shape: {binary_mask.shape}")

## Cell 11: Training Loop

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, epoch, use_false_negative_mitigation=True, k_neighbors=20):
    """Run one optimisation pass over ``dataloader``.

    Returns a dict mapping each loss name reported by ``criterion`` to its
    mean over the batches of this epoch.
    """
    model.train()
    running = defaultdict(float)
    num_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        images = batch['image'].to(device)
        audios = batch['audio'].to(device)

        output = model(images, audios)

        # The false-negative mask is only rebuilt on every 5th batch (for
        # efficiency); all other batches train without a mask.
        if use_false_negative_mitigation and batch_idx % 5 == 0:
            false_negatives_mask = compute_false_negatives(
                output['image_target_slot'],
                output['audio_target_slot'],
                k=k_neighbors,
            )
        else:
            false_negatives_mask = None

        losses = criterion(output, model.image_decoder, model.audio_decoder, false_negatives_mask)

        # Standard step: clear stale gradients, backpropagate, update.
        optimizer.zero_grad()
        losses['total'].backward()
        optimizer.step()

        for name, tensor in losses.items():
            running[name] += tensor.item()
        num_batches += 1

        # Progress line on every 10th batch.
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}/{len(dataloader)}, "
                  f"Total Loss: {losses['total'].item():.4f}")

    avg_losses = {}
    for name, accumulated in running.items():
        avg_losses[name] = accumulated / num_batches
    return avg_losses


@torch.no_grad()
def evaluate(model, dataloader, device, threshold=0.5, alpha=0.6):
    """Score the model's localization maps against ground-truth masks.

    Args:
        model: Trained network; switched to eval mode and wrapped in an
            ``InferenceModule``.
        dataloader: Iterable of batches with ``'image'``, ``'audio'`` and
            ``'gt_mask'`` tensors. ``'gt_mask'`` is converted with ``.numpy()``,
            so it is expected to live on the CPU.
        device: Device the image and audio tensors are moved to.
        threshold: Cut-off used both by the inference module and to binarise
            the predictions for the IoU computation. Defaults to the value
            that was previously hard-coded (0.5).
        alpha: Forwarded to ``InferenceModule``. Defaults to the value that was
            previously hard-coded (0.6).

    Returns:
        ``{'AP': ..., 'cIoU': ...}``. Both metrics are computed over the pixels
        of all samples pooled together: AP via ``average_precision_score`` on
        the continuous maps, and 'cIoU' as one global intersection-over-union
        of the thresholded predictions (not a per-image average).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    inference = InferenceModule(model, threshold=threshold, alpha=alpha)

    all_preds = []
    all_gts = []

    for batch in dataloader:
        images = batch['image'].to(device)
        audios = batch['audio'].to(device)
        gt_masks = batch['gt_mask'].numpy()

        # Only the continuous map is scored; the binary mask returned alongside
        # it is re-derived below from the pooled predictions.
        loc_maps, _ = inference.localize(images, audios)
        loc_maps_np = loc_maps.cpu().numpy()

        # Flatten per sample so predictions and ground truth line up pixel by
        # pixel regardless of any singleton channel dimension.
        for pred, gt in zip(loc_maps_np, gt_masks):
            all_preds.append(pred.flatten())
            all_gts.append(gt.flatten())

    if not all_preds:
        # np.concatenate would fail with a far less helpful message.
        raise ValueError("evaluate() received a dataloader that yielded no samples")

    # Compute metrics
    all_preds = np.concatenate(all_preds)
    all_gts = np.concatenate(all_gts)

    # Average Precision
    ap = average_precision_score(all_gts, all_preds)

    # Pooled IoU at the same threshold the inference module uses
    binary_preds = (all_preds > threshold).astype(float)
    intersection = np.sum((binary_preds == 1) & (all_gts == 1))
    union = np.sum((binary_preds == 1) | (all_gts == 1))
    ciou = intersection / (union + 1e-8)  # epsilon guards an all-empty union

    return {'AP': ap, 'cIoU': ciou}

# Cell-level confirmation that train_epoch and evaluate are now defined.
print("Training functions defined successfully")

## Cell 12: Main Training Execution

In [None]:
# Hyperparameters (from paper)
CONFIG = dict(
    feature_dim=512,
    num_slots=2,
    num_iterations=5,
    temperature=0.03,
    lambda_match=100.0,
    lambda_div=0.1,
    lambda_recon=0.1,
    learning_rate=5e-5,
    weight_decay=1e-2,
    batch_size=16,  # Reduced for Kaggle GPU
    num_epochs=5,   # Reduced for quick testing
    k_neighbors=20,
    threshold=0.5,
    alpha=0.6,
)

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Synthetic train/validation splits
train_dataset = DummySoundSourceDataset(num_samples=500)
val_dataset = DummySoundSourceDataset(num_samples=100)

# Training loader: shuffled, single-process loading (num_workers=0 for Kaggle),
# with pinned host memory only when a GPU is present.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

# Validation loader: fixed order so metrics are comparable between epochs.
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=0,
)

print(f"\nTrain dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

In [None]:
# Build the network from the shared hyperparameters and move it to the device
model = JointSlotAttentionSSL(
    **{
        name: CONFIG[name]
        for name in ('feature_dim', 'num_slots', 'num_iterations', 'temperature')
    }
).to(device)

# Loss shares the temperature with the model and takes the three loss weights
criterion = SSLoss(
    **{
        name: CONFIG[name]
        for name in ('temperature', 'lambda_match', 'lambda_div', 'lambda_recon')
    }
)

# AdamW over every model parameter
optimizer = torch.optim.AdamW(
    model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay']
)

# Report total vs. trainable (requires_grad) parameter counts
print(f"Model parameters: {sum(weight.numel() for weight in model.parameters()):,}")
print(f"Trainable parameters: {sum(weight.numel() for weight in model.parameters() if weight.requires_grad):,}")

In [None]:
# Training loop: one train_epoch + one evaluate per epoch, both recorded in `history`
print("\n" + "="*50, "Starting Training", "="*50 + "\n", sep="\n")

history = dict(train_loss=[], val_metrics=[])

for epoch in range(1, CONFIG['num_epochs'] + 1):
    print(f"\nEpoch {epoch}/{CONFIG['num_epochs']}", "-" * 30, sep="\n")

    # Optimise for one epoch with false-negative mitigation switched on
    train_losses = train_epoch(
        model, train_loader, criterion, optimizer, device, epoch,
        use_false_negative_mitigation=True,
        k_neighbors=CONFIG['k_neighbors'],
    )
    print(f"\nTrain Losses - Total: {train_losses['total']:.4f}, "
          f"Contr: {train_losses['contrastive']:.4f}, "
          f"Match: {train_losses['matching']:.4f}")
    history['train_loss'].append(train_losses)

    # Score the current weights on the validation split
    val_metrics = evaluate(model, val_loader, device)
    print(f"Validation - AP: {val_metrics['AP']:.4f}, cIoU: {val_metrics['cIoU']:.4f}")
    history['val_metrics'].append(val_metrics)

print("\n" + "="*50, "Training Complete!", "="*50, sep="\n")

## Cell 13: Visualization

In [None]:
def visualize_results(model, dataset, num_samples=4):
    """Plot image, ground truth, predicted heatmap and overlay for the first samples.

    One figure row is drawn for each of ``dataset[0] .. dataset[num_samples - 1]``.
    The figure is written to ``ssl_results.png`` and then shown.

    Args:
        model: Trained network; switched to eval mode and wrapped in an
            ``InferenceModule`` (threshold 0.5, alpha 0.6).
        dataset: Indexable dataset whose items are mappings with ``'image'``,
            ``'audio'`` and ``'gt_mask'`` tensors.
        num_samples: Number of leading samples to draw; must be at least 1 and
            no larger than the dataset.
    """
    model.eval()
    inference = InferenceModule(model, threshold=0.5, alpha=0.6)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[i, j] indexing below also works when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples), squeeze=False)

    for i in range(num_samples):
        # Get sample
        sample = dataset[i]
        image = sample['image'].unsqueeze(0).to(device)
        audio = sample['audio'].unsqueeze(0).to(device)
        gt_mask = sample['gt_mask'].numpy()

        # Denormalize image for visualization (channels-first -> channels-last).
        # NOTE(review): the constants are presumably the inverse of the
        # Normalize step in the dataset transform — confirm against the dataset.
        img_vis = sample['image'].permute(1, 2, 0).cpu().numpy()
        img_vis = img_vis * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_vis = np.clip(img_vis, 0, 1)

        # Get localization
        with torch.no_grad():
            results = inference.localize(image, audio, return_raw=True)
            refined_map = results['refined_attention'].squeeze().cpu().numpy()

        # Upsample the raw attention map to this image's own height/width rather
        # than a fixed 224x224, so the overlay aligns for any image size.
        refined_map_up = F.interpolate(
            torch.FloatTensor(refined_map).unsqueeze(0).unsqueeze(0),
            size=tuple(sample['image'].shape[-2:]),
            mode='bilinear',
            align_corners=False
        ).squeeze().numpy()

        # Plot
        axes[i, 0].imshow(img_vis)
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(gt_mask, cmap='jet')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(refined_map_up, cmap='jet')
        axes[i, 2].set_title('Predicted Heatmap')
        axes[i, 2].axis('off')

        # Overlay: heatmap drawn semi-transparently on top of the image
        axes[i, 3].imshow(img_vis)
        axes[i, 3].imshow(refined_map_up, cmap='jet', alpha=0.5)
        axes[i, 3].set_title('Overlay')
        axes[i, 3].axis('off')

    plt.tight_layout()
    plt.savefig('ssl_results.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Results saved to ssl_results.png")

# Visualize results: draw the first four validation samples with the current model
visualize_results(model, val_dataset, num_samples=4)

In [None]:
# Plot training history: losses on the left panel, validation metrics on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One x-position per recorded epoch
epochs = range(1, len(history['train_loss']) + 1)

# Loss curves
for key, label, marker in (
    ('total', 'Total', 'o'),
    ('contrastive', 'Contrastive', 's'),
    ('matching', 'Matching', '^'),
):
    axes[0].plot(epochs, [entry[key] for entry in history['train_loss']], label=label, marker=marker)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Losses')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Validation metrics
for key, marker in (('AP', 'o'), ('cIoU', 's')):
    axes[1].plot(epochs, [entry[key] for entry in history['val_metrics']], label=key, marker=marker)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Metric Value')
axes[1].set_title('Validation Metrics')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
plt.show()
print("Training history saved to training_history.png")

## Cell 14: Cross-Modal Retrieval Demo

Demonstrate the model's ability to retrieve matching audio given an image and vice versa

In [None]:
def cross_modal_retrieval(model, dataset, num_queries=3, top_k=3):
    """Show image-to-audio retrieval using the model's target slots.

    Target slots are extracted for every item in ``dataset``, L2-normalised and
    compared by dot product. For each query image the ``top_k`` most similar
    audio slots are looked up, and each hit is displayed through the image that
    belongs to the same dataset item. The figure is written to
    ``cross_modal_retrieval.png`` and then shown.

    Args:
        model: Trained network called as ``model(images, audios)``; its output
            must contain ``'image_target_slot'`` and ``'audio_target_slot'``.
        dataset: Sized, indexable dataset whose items are mappings with
            ``'image'``, ``'audio'`` and ``'category'``.
        num_queries: Number of query rows; queries are spread evenly over the
            dataset by index.
        top_k: Number of retrieved items per query. Values larger than the
            dataset are reduced to the dataset size.
    """
    model.eval()

    def _to_displayable(image_tensor):
        """Return the channels-last, de-normalised image clipped to [0, 1]."""
        # NOTE(review): the constants are presumably the inverse of the
        # Normalize step in the dataset transform — confirm against the dataset.
        arr = image_tensor.permute(1, 2, 0).cpu().numpy()
        arr = arr * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        return np.clip(arr, 0, 1)

    # Extract all features
    all_image_slots = []
    all_audio_slots = []

    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    with torch.no_grad():
        for batch in dataloader:
            images = batch['image'].to(device)
            audios = batch['audio'].to(device)

            output = model(images, audios)
            all_image_slots.append(output['image_target_slot'].cpu())
            all_audio_slots.append(output['audio_target_slot'].cpu())

    # presumably (N, C) each: one target slot per dataset item — TODO confirm
    all_image_slots = torch.cat(all_image_slots, dim=0)
    all_audio_slots = torch.cat(all_audio_slots, dim=0)

    # Normalize so the dot product below is a cosine similarity
    image_norm = F.normalize(all_image_slots, dim=-1)
    audio_norm = F.normalize(all_audio_slots, dim=-1)

    # Compute similarity matrix: rows index images, columns index audio clips
    sim_matrix = torch.matmul(image_norm, audio_norm.T)

    # torch.topk raises when k exceeds the number of candidates, so cap it.
    top_k = min(top_k, sim_matrix.shape[1])

    # squeeze=False keeps `axes` two-dimensional even for a single query row,
    # so the axes[q, k] indexing below also works when num_queries == 1.
    fig, axes = plt.subplots(
        num_queries, top_k + 1, figsize=(3*(top_k+1), 3*num_queries), squeeze=False
    )

    for q in range(num_queries):
        # Evenly spaced query indices across the dataset
        query_idx = q * (len(dataset) // num_queries)
        query_sample = dataset[query_idx]

        # Show query image
        axes[q, 0].imshow(_to_displayable(query_sample['image']))
        axes[q, 0].set_title(f'Query (Cat: {query_sample["category"]})')
        axes[q, 0].axis('off')

        # Get top-k retrieved audio indices
        similarities = sim_matrix[query_idx]
        top_indices = torch.topk(similarities, top_k).indices

        # Show retrieved images (corresponding to top audio matches)
        for k, idx in enumerate(top_indices):
            retrieved_sample = dataset[idx.item()]

            # A hit counts as correct when its category equals the query's
            match_status = "✓" if retrieved_sample['category'] == query_sample['category'] else "✗"
            axes[q, k+1].imshow(_to_displayable(retrieved_sample['image']))
            axes[q, k+1].set_title(f'Rank {k+1} {match_status}\n(Cat: {retrieved_sample["category"]})')
            axes[q, k+1].axis('off')

    plt.tight_layout()
    plt.savefig('cross_modal_retrieval.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Cross-modal retrieval results saved to cross_modal_retrieval.png")

# Run retrieval demo: three evenly spaced validation queries, top-3 matches each
cross_modal_retrieval(model, val_dataset, num_queries=3, top_k=3)


*The cross-modal retrieval visualization gives an intuitive view of how the model connects what it sees with what it hears. Given a query image, the model focuses on the main object that is likely to produce sound and uses this target-slot representation to search for the most similar audio examples; each retrieved audio clip is shown through the image it is paired with.*

*When the retrieved sounds match the visual object, it suggests that the model has learned a meaningful audio–visual association at the object level. Mismatches are also informative, as they often show that the model is responding to shared acoustic characteristics rather than true semantic correspondence. Overall, the visualization supports the idea that separating target and off-target slots helps isolate object-specific information while reducing the influence of background content.*

## Cell 15: Save Model

In [None]:
# Bundle weights, optimizer state, hyperparameters and metric history in one file
checkpoint = dict(
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    config=CONFIG,
    history=history,
)

torch.save(checkpoint, 'joint_slot_attention_ssl.pth')
print("Model saved to joint_slot_attention_ssl.pth")

# Final summary: metrics of the last recorded validation pass
print("\n" + "="*50, "FINAL SUMMARY", "="*50, sep="\n")
print(f"Final Validation AP: {history['val_metrics'][-1]['AP']:.4f}")
print(f"Final Validation cIoU: {history['val_metrics'][-1]['cIoU']:.4f}")
print(f"Total Training Epochs: {CONFIG['num_epochs']}")
print("="*50)

## Cell 16: Ablation Study 

Test the impact of each loss component

In [None]:
def run_ablation_study(base_config, ablation_configs, train_dataset, val_dataset, device, num_epochs=2):
    """Train and evaluate one fresh model per ablation configuration.

    Args:
        base_config: Full hyperparameter dict; every ablation starts from it.
        ablation_configs: Mapping of run name to a dict of overrides that is
            merged on top of ``base_config``.
        train_dataset: Dataset used to train each variant.
        val_dataset: Dataset used to score each variant.
        device: Device the models are placed on and trained with.
        num_epochs: Training epochs per variant. Defaults to 2, a short
            schedule meant for quick comparison.

    Returns:
        Dict mapping each run name to the metrics dict returned by ``evaluate``.
    """
    results = {}

    for config_name, config_changes in ablation_configs.items():
        print(f"\n{'='*50}")
        print(f"Running: {config_name}")
        print(f"{'='*50}")

        # Merge configs: overrides win over the base values
        config = {**base_config, **config_changes}

        # Create model. The temperature is passed explicitly so each variant is
        # built the same way as the main model and honours the merged config.
        model = JointSlotAttentionSSL(
            feature_dim=config['feature_dim'],
            num_slots=config['num_slots'],
            num_iterations=config['num_iterations'],
            temperature=config['temperature']
        ).to(device)

        # Create loss with ablation (a weight of 0.0 switches that term off)
        criterion = SSLoss(
            temperature=config['temperature'],
            lambda_match=config.get('lambda_match', 100.0),
            lambda_div=config.get('lambda_div', 0.1),
            lambda_recon=config.get('lambda_recon', 0.1)
        )

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

        for epoch in range(1, num_epochs + 1):
            # Use the configured neighbour count instead of train_epoch's default
            train_losses = train_epoch(
                model, train_loader, criterion, optimizer, device, epoch,
                k_neighbors=config.get('k_neighbors', 20)
            )
            print(f"Epoch {epoch}: Loss = {train_losses['total']:.4f}")

        # Evaluate
        val_metrics = evaluate(model, val_loader, device)
        results[config_name] = val_metrics
        print(f"Final - AP: {val_metrics['AP']:.4f}, cIoU: {val_metrics['cIoU']:.4f}")

    return results

# Define ablation configurations.
# Each entry lists the overrides applied on top of the base config; setting a
# loss weight to 0.0 removes that term from the total loss.
ablation_configs = {
    'Full Model': {},
    'No Matching Loss': {'lambda_match': 0.0},
    'No Divergence Loss': {'lambda_div': 0.0},
    # Must be 0.0: a value of 0.1 equals the base weight and would simply
    # repeat the full-model run.
    'No Reconstruction Loss': {'lambda_recon': 0.0},
    'Contrastive Only': {'lambda_match': 0.0, 'lambda_div': 0.0, 'lambda_recon': 0.0},
}




In [None]:
# Train and score every ablation variant starting from the shared CONFIG
ablation_results = run_ablation_study(CONFIG, ablation_configs, train_dataset, val_dataset, device)

# Tabulate the metrics, one line per variant (name padded to 25 characters)
print("\n" + "="*50, "ABLATION STUDY RESULTS", "="*50, sep="\n")
for name, metrics in ablation_results.items():
    print(f"{name:25s} - AP: {metrics['AP']:.4f}, cIoU: {metrics['cIoU']:.4f}")

Even with the simplified dataset and limited training, the ablation study shows consistent patterns. All model variants report zero cIoU, suggesting that the model has not yet learned precise spatial localization. This outcome is more likely due to coarse heatmaps, strict evaluation thresholds, and insufficient training, rather than a complete lack of learning.

At the same time, the variation in AP across configurations indicates that the model does capture meaningful audio–visual relationships at a coarse level. The full model performs best, supporting the idea that the losses work better in combination. The most noticeable drop occurs when the matching loss is removed, underscoring its role in cross-modal alignment. The divergence loss provides additional improvement by encouraging distinct slot representations, while the reconstruction loss appears to act mainly as a mild regularizer.