# Attention in Multimodal Systems

This notebook explores how attention mechanisms are applied in multimodal systems that combine different types of data. We'll cover:

1. Cross-modal attention
2. Multimodal fusion
3. Visual question answering
4. Audio-visual learning

## Why Attention in Multimodal Learning?

Multimodal learning presents unique challenges:

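1. **Different Modalities**: Each modality has its own structure and statistics (e.g. image patches, audio frames, text tokens), often with different sequence lengths and feature sizes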
2. **Alignment**: Need to align information across modalities
3. **Fusion**: Need to combine information effectively

Attention helps by:

- Learning cross-modal relationships
- Focusing on relevant parts of each modality
- Enabling flexible fusion strategies

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional, List
from transformers import (
    ViTFeatureExtractor,
    ViTModel,
    Wav2Vec2Processor,
    Wav2Vec2Model
)

# Seed NumPy's global RNG and PyTorch's default generator with the same
# fixed value so that repeated runs of the notebook are reproducible.
np.random.seed(42)
torch.manual_seed(42)

## Cross-Modal Attention

Let's implement a cross-modal attention mechanism, in which the queries come from one modality while the keys and values come from another:

In [None]:
class CrossModalAttention(nn.Module):
    """Multi-head scaled dot-product attention between two modalities.

    Queries come from one input; keys and values come from another. The key
    and value inputs are projected to ``query_dim``, so the two modalities may
    have different feature sizes. The key and value sequences must have the
    same length.

    Args:
        query_dim: Feature size of ``query``. It is also the width of the
            attention space and of the output. Must be divisible by
            ``num_heads``.
        key_dim: Feature size of ``key``.
        value_dim: Feature size of ``value``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``query_dim``. Otherwise the per-head reshape would fail or
            silently scramble sequence positions across heads.
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: int,
        value_dim: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()

        if num_heads <= 0 or query_dim % num_heads != 0:
            raise ValueError(
                f"query_dim ({query_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = query_dim // num_heads

        # Linear projections: keys and values are mapped into the query space
        self.q_proj = nn.Linear(query_dim, query_dim)
        self.k_proj = nn.Linear(key_dim, query_dim)
        self.v_proj = nn.Linear(value_dim, query_dim)
        self.out_proj = nn.Linear(query_dim, query_dim)

        self.dropout = nn.Dropout(dropout)
        # Scores are divided by sqrt(head_dim). Using ``** 0.5`` avoids
        # depending on ``math``, which this notebook never imports.
        self.scale = self.head_dim ** 0.5

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: ``(batch, query_len, query_dim)`` tensor.
            key: ``(batch, key_len, key_dim)`` tensor.
            value: ``(batch, key_len, value_dim)`` tensor.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, query_len, key_len)``. Positions where it
                equals 0 are excluded from attention. For example, a key
                padding mask has shape ``(batch, 1, 1, key_len)``.

        Returns:
            ``(output, attention_weights)``:
            - ``output`` has shape ``(batch, query_len, query_dim)``.
            - ``attention_weights`` has shape
              ``(batch, num_heads, query_len, key_len)`` and is taken after
              dropout.
        """
        batch_size = query.size(0)

        # Project Q, K, V
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Split the feature dimension into heads: (batch, heads, seq, head_dim)
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (batch, heads, query_len, key_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        if mask is not None:
            # Fill with the dtype's most negative finite value instead of
            # -inf. A row whose keys are all masked then softmaxes to a
            # uniform distribution rather than NaN, which would poison the
            # output. Partially masked rows still get exactly zero weight on
            # masked keys.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values, then merge heads back into one feature axis
        context = torch.matmul(attention_weights, v)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.num_heads * self.head_dim
        )

        output = self.out_proj(context)

        return output, attention_weights

## Multimodal Fusion Model

Now, let's implement a multimodal fusion model using attention:

In [None]:
class MultimodalFusionModel(nn.Module):
    """Classifier that fuses visual, audio and text features with cross-modal attention.

    Each modality is first linearly projected to ``hidden_dim``. Three
    cross-modal attention streams are then computed:

    - visual queries over audio keys/values,
    - audio queries over visual keys/values,
    - text queries over visual keys/values.

    Each stream is layer-normalized with one shared ``LayerNorm`` and
    mean-pooled over its sequence dimension. The three pooled vectors are
    concatenated and classified.

    Args:
        visual_dim: Feature size of the ``visual`` input.
        audio_dim: Feature size of the ``audio`` input.
        text_dim: Feature size of the ``text`` input.
        hidden_dim: Shared width of the projected features. Should be
            divisible by ``num_heads``.
        num_heads: Number of heads in each cross-modal attention layer.
        num_classes: Number of output classes.
        dropout: Dropout probability, used both in the classifier head and on
            the attention weights.
    """

    def __init__(
        self,
        visual_dim: int,
        audio_dim: int,
        text_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_classes: int,
        dropout: float = 0.1
    ):
        super().__init__()

        # Modality-specific encoders: project every modality to hidden_dim
        self.visual_encoder = nn.Linear(visual_dim, hidden_dim)
        self.audio_encoder = nn.Linear(audio_dim, hidden_dim)
        self.text_encoder = nn.Linear(text_dim, hidden_dim)

        # Cross-modal attention. ``dropout`` is forwarded so the configured
        # rate also applies to the attention weights; previously these layers
        # silently used CrossModalAttention's default rate.
        self.visual_audio_attention = CrossModalAttention(hidden_dim, hidden_dim, hidden_dim, num_heads, dropout)
        self.audio_visual_attention = CrossModalAttention(hidden_dim, hidden_dim, hidden_dim, num_heads, dropout)
        self.text_visual_attention = CrossModalAttention(hidden_dim, hidden_dim, hidden_dim, num_heads, dropout)

        # One LayerNorm shared by all three attention streams
        self.norm = nn.LayerNorm(hidden_dim)

        # Classification head over the three concatenated pooled streams
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(
        self,
        visual: torch.Tensor,
        audio: torch.Tensor,
        text: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Classify one (visual, audio, text) example per batch row.

        Args:
            visual: Expected ``(batch, visual_len, visual_dim)``.
            audio: Expected ``(batch, audio_len, audio_dim)``.
            text: Expected ``(batch, text_len, text_dim)``.
            mask: Optional attention mask. The same mask is passed to all
                three attention calls, whose keys are audio, visual and visual
                respectively. It must therefore broadcast against every score
                tensor, e.g. ``None``, or a key mask when audio and visual
                have the same length.

        Returns:
            ``(logits, [va_weights, av_weights, tv_weights])``, where
            ``logits`` has shape ``(batch, num_classes)``.
        """
        # Encode modalities
        visual_features = self.visual_encoder(visual)
        audio_features = self.audio_encoder(audio)
        text_features = self.text_encoder(text)

        # Cross-modal attention: first argument supplies the queries
        visual_audio, va_weights = self.visual_audio_attention(visual_features, audio_features, audio_features, mask)
        audio_visual, av_weights = self.audio_visual_attention(audio_features, visual_features, visual_features, mask)
        text_visual, tv_weights = self.text_visual_attention(text_features, visual_features, visual_features, mask)

        # Layer normalization
        visual_audio = self.norm(visual_audio)
        audio_visual = self.norm(audio_visual)
        text_visual = self.norm(text_visual)

        # Mean-pool each stream over its own sequence length, so streams of
        # different lengths can be concatenated.
        visual_audio = visual_audio.mean(dim=1)
        audio_visual = audio_visual.mean(dim=1)
        text_visual = text_visual.mean(dim=1)

        combined = torch.cat([visual_audio, audio_visual, text_visual], dim=1)

        logits = self.classifier(combined)

        return logits, [va_weights, av_weights, tv_weights]

## Visual Question Answering Example

Let's implement a simple visual question answering system:

In [None]:
class VisualQuestionAnswering(nn.Module):
    """Answer classifier over an image and a question, using cross-modal attention.

    How it works:

    - Images are encoded with a pretrained ViT.
    - Questions arrive as precomputed feature vectors and are linearly
      projected to ``hidden_dim``.
    - Image patches attend to question tokens, and question tokens attend to
      image patches.
    - The two streams have different sequence lengths, so each is mean-pooled
      before concatenation. The result is mapped to scores over
      ``vocab_size`` candidate answers.

    Args:
        visual_dim: Unused; kept for API compatibility. The ViT feature size
            is read from the encoder's config instead.
        text_dim: Feature size of the ``question`` input.
        hidden_dim: Width of the attention layers. Should be divisible by
            ``num_heads``.
        num_heads: Number of heads in each cross-modal attention layer.
        vocab_size: Number of candidate answers.
        dropout: Dropout probability, used in the answer head and on the
            attention weights.
    """

    def __init__(
        self,
        visual_dim: int,
        text_dim: int,
        hidden_dim: int,
        num_heads: int,
        vocab_size: int,
        dropout: float = 0.1
    ):
        super().__init__()

        # Visual encoder: pretrained ViT (downloaded on first use)
        self.visual_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')
        vit_dim = self.visual_encoder.config.hidden_size
        # ViT emits ``config.hidden_size`` channels, but the attention layers
        # expect ``hidden_dim``. Project when the sizes differ; use Identity
        # when they already match, so no extra parameters are added.
        self.visual_proj = (
            nn.Identity() if vit_dim == hidden_dim else nn.Linear(vit_dim, hidden_dim)
        )

        # Question encoder: a linear projection of precomputed question features
        self.text_encoder = nn.Linear(text_dim, hidden_dim)

        # Cross-modal attention in both directions
        self.visual_question_attention = CrossModalAttention(hidden_dim, hidden_dim, hidden_dim, num_heads, dropout)
        self.question_visual_attention = CrossModalAttention(hidden_dim, hidden_dim, hidden_dim, num_heads, dropout)

        # Answer prediction over the two concatenated pooled streams
        self.answer_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, vocab_size)
        )

    def forward(
        self,
        image: torch.Tensor,
        question: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Score candidate answers for each (image, question) pair.

        Args:
            image: Pixel values accepted by the ViT encoder. Presumably
                ``(batch, 3, 224, 224)``; TODO confirm against preprocessing.
            question: ``(batch, question_len, text_dim)`` question features.
            mask: Optional mask over question tokens (0 marks tokens to
                ignore), broadcastable to
                ``(batch, num_heads, num_patches, question_len)``, e.g.
                ``(batch, 1, 1, question_len)``. It is applied only where
                question tokens are the keys.

        Returns:
            ``(logits, [vq_weights, qv_weights])``, where ``logits`` has shape
            ``(batch, vocab_size)``.
        """
        # Encode image and project to the attention width
        visual_features = self.visual_proj(self.visual_encoder(image).last_hidden_state)

        # Encode question
        question_features = self.text_encoder(question)

        # Image patches attend to question tokens; the question mask applies here.
        visual_question, vq_weights = self.visual_question_attention(
            visual_features, question_features, question_features, mask
        )
        # Question tokens attend to image patches. The keys are ViT tokens,
        # which the question mask does not describe, so no mask is passed.
        question_visual, qv_weights = self.question_visual_attention(
            question_features, visual_features, visual_features
        )

        # The streams have different lengths (ViT tokens vs. question tokens),
        # so pool each before concatenating.
        # NOTE(review): this pooling also averages over padded question
        # positions; consider a mask-aware mean.
        combined = torch.cat(
            [visual_question.mean(dim=1), question_visual.mean(dim=1)], dim=-1
        )

        logits = self.answer_predictor(combined)

        return logits, [vq_weights, qv_weights]

## Audio-Visual Learning Example

Let's implement an audio-visual learning model:

In [None]:
class AudioVisualLearning(nn.Module):
    """Audio-visual classifier built on pretrained encoders and cross-modal attention.

    How it works:

    - Images are encoded with a pretrained ViT and audio with a pretrained
      Wav2Vec2.
    - Both feature sequences are brought to ``hidden_dim``.
    - Visual tokens attend to audio frames, and audio frames attend to visual
      tokens.
    - The two streams have different sequence lengths, so each is mean-pooled
      before concatenation and classification.

    Args:
        visual_dim: Unused; kept for API compatibility. The ViT feature size
            is read from the encoder's config.
        audio_dim: Unused; kept for API compatibility. The Wav2Vec2 feature
            size is read from the encoder's config.
        hidden_dim: Width of the attention layers. Should be divisible by
            ``num_heads``.
        num_heads: Number of heads in each cross-modal attention layer.
        num_classes: Number of output classes.
        dropout: Dropout probability, used in the classifier head and on the
            attention weights.
    """

    def __init__(
        self,
        visual_dim: int,
        audio_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_classes: int,
        dropout: float = 0.1
    ):
        super().__init__()

        # Visual encoder: pretrained ViT (downloaded on first use)
        self.visual_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # Audio encoder: pretrained Wav2Vec2 (downloaded on first use)
        self.audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')

        # Map each encoder's ``config.hidden_size`` to the attention width.
        # Use Identity when the sizes already match, so no parameters are added.
        vit_dim = self.visual_encoder.config.hidden_size
        wav_dim = self.audio_encoder.config.hidden_size
        self.visual_proj = (
            nn.Identity() if vit_dim == hidden_dim else nn.Linear(vit_dim, hidden_dim)
        )
        self.audio_proj = (
            nn.Identity() if wav_dim == hidden_dim else nn.Linear(wav_dim, hidden_dim)
        )

        # Cross-modal attention in both directions
        self.visual_audio_attention = CrossModalAttention(hidden_dim, hidden_dim, hidden_dim, num_heads, dropout)
        self.audio_visual_attention = CrossModalAttention(hidden_dim, hidden_dim, hidden_dim, num_heads, dropout)

        # Classification head over the two concatenated pooled streams
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(
        self,
        image: torch.Tensor,
        audio: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Classify each (image, audio) pair.

        Args:
            image: Pixel values accepted by the ViT encoder. Presumably
                ``(batch, 3, 224, 224)``; TODO confirm.
            audio: Input accepted by the Wav2Vec2 encoder. Presumably raw
                waveforms of shape ``(batch, num_samples)``; TODO confirm.
            mask: Optional mask over the encoder's audio frames (not raw
                samples), with 0 marking frames to ignore. It must be
                broadcastable to ``(batch, num_heads, num_visual_tokens,
                num_frames)``, and it is applied only where audio frames are
                the keys.

        Returns:
            ``(logits, [va_weights, av_weights])``, where ``logits`` has shape
            ``(batch, num_classes)``.
        """
        # Encode image and audio, then project both to the attention width
        visual_features = self.visual_proj(self.visual_encoder(image).last_hidden_state)
        audio_features = self.audio_proj(self.audio_encoder(audio).last_hidden_state)

        # Visual tokens attend to audio frames; the audio-frame mask applies here.
        visual_audio, va_weights = self.visual_audio_attention(
            visual_features, audio_features, audio_features, mask
        )
        # Audio frames attend to visual tokens. The keys are ViT tokens, which
        # the audio mask does not describe, so no mask is passed.
        audio_visual, av_weights = self.audio_visual_attention(
            audio_features, visual_features, visual_features
        )

        # The streams have different lengths (ViT tokens vs. audio frames),
        # so pool each before concatenating.
        combined = torch.cat(
            [visual_audio.mean(dim=1), audio_visual.mean(dim=1)], dim=-1
        )

        logits = self.classifier(combined)

        return logits, [va_weights, av_weights]

## Visualizing Cross-Modal Attention

Let's create functions to visualize cross-modal attention patterns:

In [None]:
def plot_cross_modal_attention(
    attention_weights: torch.Tensor,
    source_labels: List[str],
    target_labels: List[str],
    title: str = "Cross-Modal Attention"
) -> None:
    """Plot a 2-D cross-modal attention matrix as a heatmap.

    Args:
        attention_weights: 2-D matrix. Rows are labeled with
            ``source_labels`` (the querying positions) and columns with
            ``target_labels`` (the attended positions). Torch tensors are
            detached, upcast to float32 and moved to the CPU first, so CUDA,
            requires-grad and bfloat16 tensors are accepted. Array-likes are
            passed through unchanged.
        source_labels: Tick labels for the rows.
        target_labels: Tick labels for the columns.
        title: Figure title.
    """
    if isinstance(attention_weights, torch.Tensor):
        # seaborn converts its input to a NumPy array. That implicit
        # conversion fails for CUDA tensors, tensors that require grad, and
        # bfloat16 (which NumPy cannot represent).
        attention_weights = attention_weights.detach().float().cpu().numpy()

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention_weights,
        xticklabels=target_labels,
        yticklabels=source_labels,
        cmap='viridis'
    )
    plt.title(title)
    plt.xlabel('Target Modality')
    plt.ylabel('Source Modality')
    plt.show()

def visualize_multimodal_attention(
    model: nn.Module,
    image: torch.Tensor,
    audio: torch.Tensor,
    text: Optional[torch.Tensor] = None
) -> None:
    """Run a multimodal model once and plot each returned attention map.

    The model is called as ``model(image, audio, text)`` when ``text`` is
    given, otherwise as ``model(image, audio)``. It must return
    ``(logits, attention_weights_list)``. For each weight tensor, the first
    batch element's first head is plotted.

    The model is temporarily switched to eval mode. Otherwise dropout would be
    applied to the attention weights being visualized. The model's previous
    train/eval mode is restored afterwards, even if the forward pass raises.

    Args:
        model: Multimodal model following the call convention above.
        image: First positional input to ``model``.
        audio: Second positional input to ``model``.
        text: Optional third positional input to ``model``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if text is not None:
                _, attention_weights = model(image, audio, text)
            else:
                _, attention_weights = model(image, audio)
    finally:
        model.train(was_training)

    # Each weight tensor is indexed as (batch, head, query_pos, key_pos)
    for i, weights in enumerate(attention_weights):
        plot_cross_modal_attention(
            weights[0, 0],  # First batch, first head
            source_labels=[f"Source {j}" for j in range(weights.size(2))],
            target_labels=[f"Target {j}" for j in range(weights.size(3))],
            title=f"Attention Pattern {i+1}"
        )

## Conclusion

In this notebook, we've explored:

1. Cross-modal attention mechanisms
2. Multimodal fusion using attention
3. Visual question answering
4. Audio-visual learning

Key takeaways:

- Attention provides a powerful way to model relationships between modalities
- Cross-modal attention helps align information across different types of data
- Visualizing attention weights helps us understand how the model relates information across modalities

This concludes our exploration of attention mechanisms in various domains. The next notebook will focus on advanced topics and future directions.