# Attention in Audio Processing

This notebook explores how attention mechanisms are applied in audio processing tasks. We'll cover:

1. Audio feature extraction
2. Attention in speech recognition
3. Attention in audio classification
4. Visualizing audio attention patterns

## Why Attention in Audio?

Audio processing presents unique challenges:

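1. **Long Sequences**: Audio signals are very long. One second of 16 kHz audio is 16,000 samples, and even a log-mel spectrogram with a 10 ms hop yields about 100 frames per second.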
2. **Temporal Dependencies**: Important information can be spread across time
3. **Variable Length**: Different audio samples have different durations

Attention helps by:

- Focusing on relevant parts of the audio signal
- Handling variable-length sequences
- Capturing long-range dependencies

In [None]:
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Fix the seeds of both random number generators used in this notebook
# (PyTorch and NumPy) so that reruns produce the same results.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## Audio Feature Extraction

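First, let's implement a simple feature extractor that converts raw waveforms into log-mel spectrograms. By default it uses 80 mel bins and 25 ms windows with a 10 ms hop at 16 kHz: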

In [None]:
class AudioFeatureExtractor(nn.Module):
    """Convert raw waveforms into log-scaled mel spectrogram features.

    Applies ``torchaudio.transforms.MelSpectrogram`` followed by
    ``torchaudio.transforms.AmplitudeToDB``.

    Args:
        n_mels: Number of mel filterbank bins (the feature dimension).
        n_fft: FFT size in samples.
        hop_length: Number of samples between successive frames.
        win_length: Analysis window size in samples.
        sample_rate: Sample rate (Hz) of the waveforms given to ``forward``;
            it determines where the mel filterbank is placed. Defaults to
            16000, so existing callers behave exactly as before.
    """

    def __init__(
        self,
        n_mels: int = 80,
        n_fft: int = 400,
        hop_length: int = 160,
        win_length: int = 400,
        sample_rate: int = 16000,
    ):
        super().__init__()

        # Keep the configuration accessible for downstream code.
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sample_rate = sample_rate

        # Mel spectrogram transform (power spectrogram by torchaudio default).
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mels
        )

        # Convert the (power) mel spectrogram to decibels.
        self.log_transform = torchaudio.transforms.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the log-mel spectrogram of ``waveform``.

        Args:
            waveform: Audio samples at ``self.sample_rate``; the torchaudio
                transforms operate on the last (time) dimension.

        Returns:
            Log-mel spectrogram in dB, shaped ``(..., n_mels, frames)`` as
            produced by torchaudio's ``MelSpectrogram``.
        """
        mel_spec = self.mel_transform(waveform)
        return self.log_transform(mel_spec)

## Audio Attention Model

Now, let's implement an attention-based audio processing model:

In [None]:
class AudioAttentionModel(nn.Module):
    """Single-layer self-attention classifier over a sequence of feature frames.

    Pipeline: linear projection -> multi-head self-attention -> LayerNorm ->
    mean pooling over time -> MLP classifier.

    Args:
        input_dim: Size of each input frame's feature vector.
        hidden_dim: Model width; ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        num_classes: Number of output classes.
        dropout: Dropout probability used in attention and in the classifier.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_classes: int,
        dropout: float = 0.1
    ):
        super().__init__()

        # Feature projection
        self.projection = nn.Linear(input_dim, hidden_dim)

        # Multi-head attention (default batch_first=False: expects seq-first input)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout
        )

        # Layer normalization
        self.norm = nn.LayerNorm(hidden_dim)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Classify a batch of feature sequences.

        Args:
            x: Batch-first features ``(batch, seq_len, input_dim)``; they are
                transposed internally because the attention layer is seq-first.
            mask: Optional ``key_padding_mask`` of shape ``(batch, seq_len)``.
                For a boolean mask, ``True`` marks padded frames. Those frames
                are ignored as attention keys and excluded from the pooled
                average. Non-boolean masks are forwarded to the attention layer
                unchanged, and pooling falls back to a plain mean.

        Returns:
            ``(logits, attn_weights)``: logits of shape ``(batch, num_classes)``
            and the attention weights returned by ``nn.MultiheadAttention``.
        """
        # Project features
        x = self.projection(x)

        # Transpose for attention (sequence first)
        x = x.transpose(0, 1)

        # Apply attention
        attn_output, attn_weights = self.attention(x, x, x, key_padding_mask=mask)

        # Layer normalization
        x = self.norm(attn_output)

        # Transpose back (batch first): (batch, seq_len, hidden_dim)
        x = x.transpose(0, 1)

        if mask is not None and mask.dtype == torch.bool:
            # Padded positions still produce outputs (their queries attend to
            # real keys). Averaging them in would make the pooled vector depend
            # on how much padding a sequence has, so average valid frames only.
            valid = (~mask).unsqueeze(-1).to(x.dtype)
            # clamp guards the divisor for fully padded rows.
            x = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)
        else:
            # Global average pooling
            x = x.mean(dim=1)

        # Classification
        logits = self.classifier(x)

        return logits, attn_weights

## Visualizing Audio Attention

Let's create functions to visualize attention patterns in audio:

In [None]:
def plot_audio_attention(
    waveform: torch.Tensor,
    attention_weights: torch.Tensor,
    sample_rate: int = 16000,
    title: str = "Audio Attention Pattern"
) -> None:
    """Plot an audio waveform above a heatmap of attention weights.

    Args:
        waveform: Audio samples; expected 1-D (mono), since its length is
            used as the sample count for the time axis.
        attention_weights: 2-D attention matrix (query steps x key steps).
        sample_rate: Sample rate of ``waveform`` in Hz, used for the time axis.
        title: Overall figure title.
    """
    # detach()/cpu() so tensors that require grad or live on an accelerator
    # can be converted; a bare .numpy() raises for both.
    samples = waveform.detach().cpu().numpy()
    weights = attention_weights.detach().cpu().numpy()

    # Create figure
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
    fig.suptitle(title)

    # Plot waveform against time in seconds
    time = np.arange(len(samples)) / sample_rate
    ax1.plot(time, samples)
    ax1.set_title("Audio Waveform")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Amplitude")

    # Plot attention weights (tick label every 100 steps)
    sns.heatmap(
        weights,
        ax=ax2,
        cmap='viridis',
        xticklabels=100,
        yticklabels=100
    )
    ax2.set_title("Attention Weights")
    ax2.set_xlabel("Time Step")
    ax2.set_ylabel("Time Step")

    plt.tight_layout()
    plt.show()

def plot_mel_attention(
    mel_spec: torch.Tensor,
    attention_weights: torch.Tensor,
    title: str = "Mel Spectrogram with Attention"
) -> None:
    """Plot a mel spectrogram above an attention-weight matrix.

    Args:
        mel_spec: 2-D spectrogram, rendered with ``imshow`` so that rows
            (presumably mel bins) run up the y-axis and columns along time.
        attention_weights: 2-D attention matrix.
        title: Overall figure title.
    """
    # detach()/cpu() so tensors that require grad or live on an accelerator
    # can be converted; a bare .numpy() raises for both.
    spec = mel_spec.detach().cpu().numpy()
    weights = attention_weights.detach().cpu().numpy()

    # Create figure
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
    fig.suptitle(title)

    # Plot mel spectrogram (origin='lower' puts the first row at the bottom)
    im1 = ax1.imshow(
        spec,
        aspect='auto',
        origin='lower',
        cmap='viridis'
    )
    ax1.set_title("Mel Spectrogram")
    ax1.set_xlabel("Time Step")
    ax1.set_ylabel("Mel Bin")
    plt.colorbar(im1, ax=ax1)

    # Plot attention weights
    im2 = ax2.imshow(
        weights,
        aspect='auto',
        origin='lower',
        cmap='viridis'
    )
    ax2.set_title("Attention Weights")
    ax2.set_xlabel("Time Step")
    ax2.set_ylabel("Time Step")
    plt.colorbar(im2, ax=ax2)

    plt.tight_layout()
    plt.show()

## Real-World Example: Speech Recognition

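Let's analyze attention patterns in a pre-trained speech recognition model (wav2vec 2.0). Its self-attention operates over encoder frames, roughly one per 20 ms of audio, not over raw samples. The attention matrix is therefore much smaller than the waveform is long: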

In [None]:
# Hugging Face checkpoint used by the speech-recognition example below.
model_name: str = 'facebook/wav2vec2-base-960h'

# The processor prepares raw audio for the model and decodes predicted ids;
# the CTC model produces per-frame logits (and attentions on request).
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
# from_pretrained already returns the model in eval mode; stated explicitly
# because dropout must be off when inspecting attention.
model.eval()

def analyze_speech_attention(audio_path: str) -> None:
    """Transcribe an audio file and plot last-layer attention of the model.

    Uses the module-level ``processor`` and ``model``, plots the waveform
    together with the first head's attention from the last layer, and prints
    the greedy (argmax) CTC transcription.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
    """
    target_rate = 16000  # sampling rate passed to the processor

    # Load audio: torchaudio returns (channels, frames) and the file's rate.
    waveform, sample_rate = torchaudio.load(audio_path)

    # Downmix to mono. squeeze() alone would leave (channels, frames) for
    # multi-channel files, which the processor treats as a batch of clips
    # and which breaks the time axis of the plot. For mono input this is
    # identical to squeeze().
    waveform = waveform.mean(dim=0)

    # Resample if necessary
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
        waveform = resampler(waveform)

    # Process audio
    inputs = processor(
        waveform.numpy(),
        sampling_rate=target_rate,
        return_tensors="pt"
    )

    # Get model output with attention
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Last layer's attention: index first batch item, first head.
    attention_weights = outputs.attentions[-1][0, 0]

    # Plot attention patterns
    plot_audio_attention(
        waveform,
        attention_weights,
        sample_rate=target_rate,
        title="Speech Recognition Attention Pattern"
    )

    # Decode transcription (greedy: most likely token per frame)
    predicted_ids = torch.argmax(outputs.logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    print(f"Transcription: {transcription[0]}")

# Example usage (replace with your audio file)
# analyze_speech_attention('path_to_audio.wav')

## Audio Classification Example

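Let's write a training loop for the attention-based classifier defined above. It expects a DataLoader that yields `(features, labels)` batches, where the features are already frame-level sequences (e.g. log-mel frames) rather than raw waveforms: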

In [None]:
def train_audio_classifier(
    model: AudioAttentionModel,
    train_loader: torch.utils.data.DataLoader,
    num_epochs: int,
    learning_rate: float = 0.001,
    log_every: int = 10,
) -> List[float]:
    """Train ``model`` with Adam and cross-entropy loss.

    Args:
        model: Classifier whose forward returns ``(logits, attention_weights)``.
        train_loader: Iterable of ``(inputs, labels)`` batches. ``inputs`` are
            passed straight to ``model``, so they must already be in the form
            the model expects. For ``AudioAttentionModel`` that is batch-first
            features ``(batch, seq_len, input_dim)``, not raw waveforms.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        log_every: Print the average loss every ``log_every`` epochs; a value
            of ``0`` or less disables printing.

    Returns:
        The average training loss of each epoch, in order.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Enable dropout even if the model was left in eval mode (e.g. after
    # visualizing attention).
    model.train()

    losses: List[float] = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        # Count batches rather than using len(train_loader): this works for
        # unsized iterables and lets us detect an empty loader explicitly.
        num_batches = 0

        for inputs, labels in train_loader:
            # Forward pass (attention weights are not needed for training)
            logits, _ = model(inputs)

            # Compute loss
            loss = criterion(logits, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # Record average loss
        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)

        if log_every > 0 and (epoch + 1) % log_every == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    return losses

## Conclusion

In this notebook, we've explored:

1. Audio feature extraction
2. Attention-based audio processing
3. Visualization of audio attention patterns
4. Real-world applications in speech recognition

Key takeaways:

- Attention is particularly useful for audio due to its sequential nature
- Different attention patterns can capture different aspects of audio
- Visualization helps understand how the model processes audio

In the next notebook, we'll explore attention in multimodal systems that combine audio with other modalities.