# Module 11.6: Multimodal Understanding

**Goal**: Work with vision-language models and cross-modal understanding

**Time**: 90 minutes

**Concepts Covered**:
- Vision-language model architecture
- CLIP encoder integration
- Cross-attention implementation
- Image-text training pipeline
- Edge deployment for VLMs (MiniCPM-V, MobileVLM)

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
# Vision-Language Model Architecture
import torch
import torch.nn as nn

class VisionLanguageModel(nn.Module):
    """Simple VLM architecture"""
    def __init__(self, vision_dim=768, text_dim=768, hidden_dim=1024):
        super().__init__()
        
        # Vision encoder (simulated - in practice use CLIP/ViT)
        self.vision_encoder = nn.Sequential(
            nn.Linear(224*224*3, vision_dim),  # Simplified
            nn.LayerNorm(vision_dim),
            nn.GELU(),
        )
        
        # Text encoder
        self.text_encoder = nn.Embedding(50000, text_dim)
        
        # Cross-attention
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            batch_first=True
        )
        
        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(vision_dim + text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        
        # Language model head
        self.lm_head = nn.Linear(hidden_dim, 50000)
    
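    def forward(self, image, text_ids):
        """Forward pass.

        image: (batch, 3, 224, 224) or already flattened to (batch, 224*224*3)
        text_ids: (batch, seq_len) integer token ids
        returns: logits of shape (batch, seq_len, 50000)
        """
        # Encode image: flatten pixels so the linear stand-in encoder accepts them
        image_features = self.vision_encoder(image.flatten(start_dim=1))  # (batch, vision_dim)

        # Encode text
        text_features = self.text_encoder(text_ids)  # (batch, seq_len, text_dim)

        # Repeat the single image vector at every text position so the two
        # tensors can be concatenated along the feature dimension
        image_features = image_features.unsqueeze(1).expand(-1, text_features.size(1), -1)

        # Fuse features
        fused = self.fusion(torch.cat([image_features, text_features], dim=-1))

        # Attention over the fused sequence. Query, key, and value are all
        # `fused`, so this layer acts as self-attention here.
        attn_out, _ = self.cross_attention(fused, fused, fused)

        # Language modeling
        logits = self.lm_head(attn_out)

        return logits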

print("Vision-Language Model:")
print("- Vision encoder: processes images")
print("- Text encoder: processes text")
print("- Cross-attention: aligns modalities")
print("- LM head: generates text")
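
The "Cross-attention: aligns modalities" idea can be sketched separately from the simplified model above. In the minimal example below, text tokens act as queries and image patch tokens act as keys and values. The class name `TextToImageCrossAttention`, the 49-patch grid, and the residual-plus-LayerNorm wiring are assumptions made for illustration, not part of the module's model.

```python
import torch
import torch.nn as nn

class TextToImageCrossAttention(nn.Module):
    """Hypothetical block: text tokens (queries) attend to image patches (keys/values)."""
    def __init__(self, text_dim=768, vision_dim=768, num_heads=8):
        super().__init__()
        # kdim/vdim allow keys and values to come from a different feature space
        self.attn = nn.MultiheadAttention(
            embed_dim=text_dim,
            num_heads=num_heads,
            kdim=vision_dim,
            vdim=vision_dim,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(text_dim)

    def forward(self, text_tokens, image_tokens):
        # text_tokens: (batch, seq_len, text_dim)
        # image_tokens: (batch, num_patches, vision_dim)
        attended, weights = self.attn(
            query=text_tokens, key=image_tokens, value=image_tokens
        )
        # Residual connection keeps the original text signal
        return self.norm(text_tokens + attended), weights

torch.manual_seed(0)
block = TextToImageCrossAttention()
text_tokens = torch.randn(2, 5, 768)     # 2 samples, 5 text tokens each
image_tokens = torch.randn(2, 49, 768)   # 49 patches, e.g. a 7x7 grid (assumed)
out, weights = block(text_tokens, image_tokens)
print(out.shape)      # torch.Size([2, 5, 768])
print(weights.shape)  # torch.Size([2, 5, 49])
```

Each row of `weights` is a distribution over image patches for one text token, which is what makes the layer cross-modal: the output keeps the text sequence length while drawing its content from the image.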

In [None]:
# CLIP-Style Contrastive Learning
def clip_loss(image_features, text_features, temperature=0.07):
    """Symmetric CLIP-style contrastive loss.

    Both inputs have shape (batch, dim); row i of each tensor is a matched pair.
    """
    # Normalize features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    
    # Compute similarity matrix
    logits = torch.matmul(image_features, text_features.t()) / temperature
    
    # Labels: diagonal (matched pairs)
    batch_size = image_features.shape[0]
    labels = torch.arange(batch_size, device=image_features.device)
    
    # Cross-entropy loss (symmetric)
    loss_img = nn.functional.cross_entropy(logits, labels)
    loss_txt = nn.functional.cross_entropy(logits.t(), labels)
    
    return (loss_img + loss_txt) / 2

print("CLIP Contrastive Learning:")
print("- Learn aligned image-text embeddings")
print("- Maximize similarity for matched pairs")
print("- Minimize similarity for unmatched pairs")
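
To show how this objective fits into an image-text training loop, here is a toy sketch that trains two linear projections on random paired features. The data, the dimensions, the `contrastive_loss` helper (a restatement of `clip_loss` so the block runs on its own), and the optimizer settings are all assumptions for illustration. A real pipeline would use actual encoders and a dataset of image-caption pairs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def contrastive_loss(img, txt, temperature=0.07):
    # Same symmetric objective as clip_loss, restated so this block is self-contained
    img = F.normalize(img, dim=-1)
    txt = F.normalize(txt, dim=-1)
    logits = img @ txt.t() / temperature
    labels = torch.arange(img.shape[0], device=img.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

torch.manual_seed(0)

# Toy "dataset": 16 matched pairs of pre-extracted features (random stand-ins)
raw_images = torch.randn(16, 32)
raw_texts = torch.randn(16, 24)

# Projection heads map each modality into a shared 8-dimensional space
image_proj = nn.Linear(32, 8)
text_proj = nn.Linear(24, 8)
params = list(image_proj.parameters()) + list(text_proj.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)

losses = []
for step in range(200):
    loss = contrastive_loss(image_proj(raw_images), text_proj(raw_texts))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss at step 0: {losses[0]:.3f}, loss at step 199: {losses[-1]:.3f}")
```

Because the batch is fixed and tiny, the projections simply memorize the pairing. The point is only the shape of the loop: encode both modalities, compute the symmetric loss, and update both encoders together.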

## Key Takeaways

✅ **Module Complete**

## Next Steps

Continue to the next module in the course.