In [None]:
"""
Vision-Language Model (VLM) Pretraining Script
===============================================
Modern implementation for pretraining vision-language models from scratch:
- Vision Transformer (ViT) for image encoding
- Cross-attention mechanism for vision-language fusion
- Flash Attention 2 for efficient attention
- Mixed precision training (bfloat16)
- Contrastive learning objectives (CLIP-style)
- Next-token prediction for language generation
- Q-Former architecture for efficient vision-language alignment

Dataset: defaults to the HuggingFaceM4/COCO image-caption dataset from the HuggingFace Hub
(configurable through VLMTrainingConfig.dataset_name)
"""

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from datasets import load_dataset
from PIL import Image
import torchvision.transforms as T
from transformers import PreTrainedTokenizerFast, AutoTokenizer
from transformers.optimization import get_cosine_schedule_with_warmup
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple
import wandb
from tqdm.auto import tqdm
import logging
from io import BytesIO
import requests

# Configure the root logger at INFO level and create the module-level logger
# that the dataset and training code below use for progress messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


In [None]:

@dataclass
class VLMConfig:
    """Architecture hyperparameters for the Vision-Language Model.

    A single instance is handed to every model component (vision encoder,
    Q-Former, language decoder, contrastive heads); each component reads the
    fields of its own group.
    """

    # Vision encoder (ViT)
    image_size: int = 224  # input resolution; number of patches is (image_size // patch_size) ** 2
    patch_size: int = 14  # kernel size and stride of the patch-projection convolution
    num_channels: int = 3  # input channels of the patch-projection convolution
    vision_hidden_size: int = 1024  # token width inside the ViT
    vision_num_layers: int = 24  # number of ViT transformer layers
    vision_num_heads: int = 16  # attention heads per ViT layer
    vision_intermediate_size: int = 4096  # hidden width of the ViT MLP

    # Q-Former (Query Transformer for vision-language alignment)
    num_query_tokens: int = 32  # learnable queries; also the number of visual tokens fed to the decoder
    qformer_hidden_size: int = 768  # token width inside the Q-Former (its MLP uses 4x this)
    qformer_num_layers: int = 12  # number of Q-Former layers
    qformer_num_heads: int = 12  # attention heads per Q-Former layer

    # Language model
    vocab_size: int = 50304  # size of the token embedding and LM head
    max_seq_length: int = 512  # NOTE(review): not read by any model class visible in this file
    hidden_size: int = 2048  # token width of the language decoder
    num_hidden_layers: int = 24  # number of decoder layers
    num_attention_heads: int = 16  # attention heads per decoder layer
    num_key_value_heads: int = 8  # NOTE(review): not read by any model class visible in this file
    intermediate_size: int = 5632  # hidden width of the decoder MLP

    # General
    hidden_act: str = "silu"  # NOTE(review): not read in this file; the decoder MLP hard-codes nn.SiLU
    rms_norm_eps: float = 1e-6  # passed as `eps` to the nn.LayerNorm modules of the decoder
    rope_theta: float = 10000.0  # NOTE(review): not read by any model class visible in this file
    attention_dropout: float = 0.0  # dropout passed to every nn.MultiheadAttention
    use_flash_attention: bool = True  # NOTE(review): not read by any model class visible in this file
    gradient_checkpointing: bool = True  # NOTE(review): not read by any model class visible in this file

    # Vision-Language fusion
    projection_dim: int = 512  # Width of the shared image/text space used for contrastive learning


In [None]:

@dataclass
class VLMTrainingConfig:
    """Run-level settings consumed by `train_vlm`: data, optimisation, precision, checkpointing and logging."""

    # Data
    dataset_name: str = "HuggingFaceM4/COCO"  # Multi-modal COCO dataset
    dataset_split: str = "train"
    image_size: int = 224  # images are resized to (image_size, image_size) by the dataset
    max_seq_length: int = 512  # captions are padded/truncated to this many tokens

    # Training
    batch_size: int = 8  # per-process micro-batch size
    gradient_accumulation_steps: int = 16  # Effective batch size = 8 * 16 = 128 per process
    num_epochs: int = 3
    max_steps: Optional[int] = None  # when set, overrides the epoch-derived step count and stops training early

    # Optimization
    learning_rate: float = 1e-4  # LR for everything except the vision encoder
    vision_lr: float = 5e-5  # Lower LR for vision encoder
    weight_decay: float = 0.05
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0  # gradient-norm clipping threshold
    warmup_steps: int = 1000  # warmup steps of the cosine schedule

    # Loss weights
    contrastive_loss_weight: float = 1.0
    language_loss_weight: float = 1.0

    # Mixed precision
    use_fp16: bool = False
    use_bf16: bool = True  # takes precedence over use_fp16 when both are set

    # Checkpointing
    save_steps: int = 5000  # checkpoint every N optimizer steps
    output_dir: str = "./checkpoints/vlm"

    # Logging
    logging_steps: int = 10  # log every N optimizer steps
    use_wandb: bool = False
    wandb_project: str = "vlm-pretraining"

    # Tokenizer
    tokenizer_path: str = "./tokenizer"  # From LLM pretraining

    # Distributed
    # NOTE(review): train_vlm takes rank information from the RANK / WORLD_SIZE /
    # LOCAL_RANK environment variables; these two fields appear to be unused there.
    local_rank: int = -1
    world_size: int = 1


In [None]:

class PatchEmbedding(nn.Module):
    """
    Turn an image batch into a sequence of patch tokens.

    One strided convolution, with kernel size and stride both equal to the
    patch size, cuts the image into non-overlapping patches and linearly
    projects each patch in the same operation.
    """

    def __init__(self, config: VLMConfig):
        """
        Build the patch projection.

        Args:
            config: VLM configuration (reads image_size, patch_size,
                num_channels and vision_hidden_size)
        """
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Square grid of patches; integer division drops any partial patch.
        patches_per_side = config.image_size // config.patch_size
        self.num_patches = patches_per_side ** 2

        self.projection = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.vision_hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Project an image batch to patch tokens.

        Args:
            pixel_values: Image tensor of shape (batch, channels, height, width)

        Returns:
            Patch embeddings of shape (batch, num_patches, hidden_size)
        """
        # (batch, hidden_size, grid_h, grid_w)
        feature_map = self.projection(pixel_values)
        # Collapse the spatial grid into one axis: (batch, hidden_size, num_patches)
        tokens = feature_map.flatten(start_dim=2)
        # Put the token axis before the channel axis: (batch, num_patches, hidden_size)
        return tokens.permute(0, 2, 1)


In [None]:

class VisionTransformerEncoder(nn.Module):
    """
    ViT image encoder (architecture from "An Image is Worth 16x16 Words").

    Images become patch tokens, a learnable CLS token is prepended, learnable
    position embeddings are added, and the sequence runs through a stack of
    transformer layers followed by a final LayerNorm.
    """

    def __init__(self, config: VLMConfig):
        """
        Build the encoder.

        Args:
            config: VLM configuration
        """
        super().__init__()
        self.config = config
        width = config.vision_hidden_size

        # Image -> patch tokens
        self.patch_embedding = PatchEmbedding(config)

        # Learnable CLS token and position table, both starting at zero.
        # The table has one slot per patch plus one for the CLS token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        sequence_length = 1 + self.patch_embedding.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, sequence_length, width))

        # Transformer stack
        blocks = []
        for _ in range(config.vision_num_layers):
            blocks.append(VisionTransformerLayer(config))
        self.layers = nn.ModuleList(blocks)

        self.layernorm = nn.LayerNorm(width, eps=1e-6)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of images.

        Args:
            pixel_values: Image tensor of shape (batch, channels, height, width)

        Returns:
            Encoded features of shape (batch, num_patches+1, hidden_size);
            index 0 along the sequence axis is the CLS token
        """
        patch_tokens = self.patch_embedding(pixel_values)

        # One copy (view) of the CLS token per image, placed in front of the patches.
        cls = self.cls_token.expand(pixel_values.shape[0], -1, -1)
        tokens = torch.cat([cls, patch_tokens], dim=1)

        # The position table broadcasts over the batch axis.
        tokens = tokens + self.position_embeddings

        for block in self.layers:
            tokens = block(tokens)

        return self.layernorm(tokens)


In [None]:

class VisionTransformerLayer(nn.Module):
    """Single pre-norm ViT layer: self-attention then MLP, each wrapped in a residual connection."""

    def __init__(self, config: VLMConfig):
        """
        Initialize ViT layer.

        Args:
            config: VLM configuration (reads vision_hidden_size,
                vision_num_heads, vision_intermediate_size and
                attention_dropout)
        """
        super().__init__()
        self.attention = nn.MultiheadAttention(
            config.vision_hidden_size,
            config.vision_num_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(config.vision_hidden_size, config.vision_intermediate_size),
            nn.GELU(),
            nn.Linear(config.vision_intermediate_size, config.vision_hidden_size),
        )
        self.layernorm1 = nn.LayerNorm(config.vision_hidden_size, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(config.vision_hidden_size, eps=1e-6)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with residual connections.

        Args:
            hidden_states: Input tensor of shape (batch, seq_len, vision_hidden_size)

        Returns:
            Output tensor with the same shape as the input
        """
        # Self-attention (unmasked: every token attends to every token)
        residual = hidden_states
        hidden_states = self.layernorm1(hidden_states)
        # The attention weights are never used, so do not ask for them:
        # need_weights=False skips building the averaged weight matrix and
        # lets PyTorch use its fused scaled-dot-product attention kernel.
        hidden_states, _ = self.attention(
            hidden_states,
            hidden_states,
            hidden_states,
            need_weights=False,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.layernorm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


In [None]:

class QFormerLayer(nn.Module):
    """
    Q-Former layer for vision-language alignment.

    Each layer applies, in order and each with a pre-LayerNorm and a residual
    connection: self-attention among the query tokens, cross-attention from
    the queries to the vision features, and an MLP.
    """

    def __init__(self, config: VLMConfig):
        """
        Initialize Q-Former layer.

        Args:
            config: VLM configuration (reads qformer_hidden_size,
                qformer_num_heads and attention_dropout)
        """
        super().__init__()
        self.hidden_size = config.qformer_hidden_size

        # Self-attention on queries
        self.self_attention = nn.MultiheadAttention(
            config.qformer_hidden_size,
            config.qformer_num_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )

        # Cross-attention with vision features. No kdim/vdim is given, so the
        # vision features must already have qformer_hidden_size channels.
        self.cross_attention = nn.MultiheadAttention(
            config.qformer_hidden_size,
            config.qformer_num_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )

        # MLP
        self.mlp = nn.Sequential(
            nn.Linear(config.qformer_hidden_size, config.qformer_hidden_size * 4),
            nn.GELU(),
            nn.Linear(config.qformer_hidden_size * 4, config.qformer_hidden_size),
        )

        self.layernorm1 = nn.LayerNorm(config.qformer_hidden_size, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(config.qformer_hidden_size, eps=1e-6)
        self.layernorm3 = nn.LayerNorm(config.qformer_hidden_size, eps=1e-6)

    def forward(
        self,
        query_embeds: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of Q-Former layer.

        Args:
            query_embeds: Query embeddings of shape
                (batch, num_queries, qformer_hidden_size)
            encoder_hidden_states: Vision features of shape
                (batch, vision_seq_len, qformer_hidden_size)

        Returns:
            Updated query embeddings, same shape as `query_embeds`
        """
        # Attention weights are discarded in both attention calls below, so
        # need_weights=False avoids computing them and allows the fused
        # scaled-dot-product attention kernel.

        # Self-attention
        residual = query_embeds
        query_embeds = self.layernorm1(query_embeds)
        query_embeds, _ = self.self_attention(
            query_embeds,
            query_embeds,
            query_embeds,
            need_weights=False,
        )
        query_embeds = residual + query_embeds

        # Cross-attention with vision features (queries read, vision features are keys/values)
        residual = query_embeds
        query_embeds = self.layernorm2(query_embeds)
        query_embeds, _ = self.cross_attention(
            query_embeds,
            encoder_hidden_states,
            encoder_hidden_states,
            need_weights=False,
        )
        query_embeds = residual + query_embeds

        # MLP
        residual = query_embeds
        query_embeds = self.layernorm3(query_embeds)
        query_embeds = self.mlp(query_embeds)
        query_embeds = residual + query_embeds

        return query_embeds


In [None]:

class QFormer(nn.Module):
    """
    Query Transformer that bridges the vision encoder and the language model.

    A fixed set of learnable query tokens reads the image features through a
    stack of Q-Former layers; the resulting tokens are projected to the
    language model width. Introduced in the BLIP-2 paper:
    https://arxiv.org/abs/2301.12597
    """

    def __init__(self, config: VLMConfig):
        """
        Build the Q-Former.

        Args:
            config: VLM configuration
        """
        super().__init__()
        self.config = config
        width = config.qformer_hidden_size

        # Learnable queries, drawn from N(0, 0.02^2)
        initial_queries = torch.zeros(1, config.num_query_tokens, width)
        self.query_tokens = nn.Parameter(initial_queries)
        nn.init.normal_(self.query_tokens, std=0.02)

        # Bring vision features to the Q-Former width
        self.vision_proj = nn.Linear(config.vision_hidden_size, width)

        # Q-Former stack
        stack = []
        for _ in range(config.qformer_num_layers):
            stack.append(QFormerLayer(config))
        self.layers = nn.ModuleList(stack)

        # Bring the query outputs to the language model width
        self.proj_to_language = nn.Linear(width, config.hidden_size)

    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
        """
        Summarise image features into a fixed number of query tokens.

        Args:
            vision_outputs: Vision encoder outputs of shape (batch, seq_len, vision_hidden_size)

        Returns:
            Query outputs of shape (batch, num_queries, hidden_size)
        """
        image_features = self.vision_proj(vision_outputs)

        # Same learned queries for every image in the batch (expanded view, no copy)
        queries = self.query_tokens.expand(vision_outputs.shape[0], -1, -1)

        for block in self.layers:
            queries = block(queries, image_features)

        return self.proj_to_language(queries)


In [None]:

class VisionLanguageModel(nn.Module):
    """
    Complete Vision-Language Model combining:
    - Vision Transformer for image encoding
    - Q-Former for vision-language alignment
    - Language model for text generation

    The forward pass produces two training signals: a CLIP-style contrastive
    loss between pooled image and text features, and a next-token prediction
    loss on the text conditioned on the Q-Former's visual tokens.
    """

    def __init__(self, config: VLMConfig):
        """
        Initialize VLM.

        Args:
            config: VLM configuration
        """
        super().__init__()
        self.config = config

        # Vision encoder
        self.vision_encoder = VisionTransformerEncoder(config)

        # Q-Former for vision-language alignment
        self.qformer = QFormer(config)

        # Language model components (simplified decoder)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Contrastive learning heads
        self.vision_projection = nn.Linear(config.vision_hidden_size, config.projection_dim)
        self.text_projection = nn.Linear(config.hidden_size, config.projection_dim)

        # Learned log-temperature for the contrastive loss, initialised to log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Conv2d/Embedding weights from N(0, 0.02^2) and zero their biases.

        Only these module types are touched; other modules and bare
        parameters keep the values set by their constructors.
        """
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def encode_image(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode images through vision encoder and Q-Former.

        Args:
            pixel_values: Image tensor of shape (batch, channels, height, width)

        Returns:
            Tuple of (vision_embeds for contrastive, query_outputs for generation):
            the CLS feature of shape (batch, vision_hidden_size) and the
            Q-Former tokens of shape (batch, num_query_tokens, hidden_size)
        """
        # Encode image
        vision_outputs = self.vision_encoder(pixel_values)

        # The encoder puts the CLS token first in the sequence
        vision_embeds = vision_outputs[:, 0, :]

        # Get query outputs for language generation
        query_outputs = self.qformer(vision_outputs)

        return vision_embeds, query_outputs

    def encode_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Encode text for contrastive learning.

        Args:
            input_ids: Token IDs of shape (batch, seq_len)

        Returns:
            Text embeddings of shape (batch, hidden_size)
        """
        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)

        # Only the first four decoder layers are used here, for efficiency
        for layer in self.layers[:4]:
            hidden_states = layer(hidden_states, encoder_hidden_states=None)

        # Mean pooling over the sequence axis.
        # NOTE(review): every position is averaged, including any padding
        # positions present in input_ids; no attention mask is available here.
        text_embeds = hidden_states.mean(dim=1)

        return text_embeds

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass computing both contrastive and generation losses.

        Args:
            pixel_values: Image tensor of shape (batch, channels, height, width)
            input_ids: Token IDs of shape (batch, seq_len)
            labels: Optional next-token targets with the same shape as
                input_ids; positions set to -100 are ignored by the loss

        Returns:
            Dictionary with "contrastive_loss", "language_loss" (None when
            `labels` is None) and "logits" of shape
            (batch, num_query_tokens + seq_len, vocab_size)
        """
        batch_size = pixel_values.shape[0]

        # Encode images
        vision_embeds, query_outputs = self.encode_image(pixel_values)

        # Encode text
        text_embeds = self.encode_text(input_ids)

        # Project for contrastive learning (unit-norm features -> cosine similarity)
        vision_features = F.normalize(self.vision_projection(vision_embeds), dim=-1)
        text_features = F.normalize(self.text_projection(text_embeds), dim=-1)

        # Compute contrastive loss (CLIP-style). The learned temperature is
        # bounded so logits are never scaled by more than 100, as in CLIP;
        # an unbounded scale can grow until the softmax saturates and
        # training becomes unstable.
        logit_scale = torch.clamp(self.logit_scale, max=math.log(100.0)).exp()
        logits_per_image = logit_scale * vision_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # Contrastive labels (diagonal): image i matches text i
        contrastive_labels = torch.arange(batch_size, device=pixel_values.device)

        contrastive_loss = (
            F.cross_entropy(logits_per_image, contrastive_labels) +
            F.cross_entropy(logits_per_text, contrastive_labels)
        ) / 2

        # Language modeling
        # Visual query tokens are prepended to the text embeddings
        text_embeddings = self.embed_tokens(input_ids)
        combined_embeds = torch.cat([query_outputs, text_embeddings], dim=1)

        # Pass through language model.
        # NOTE(review): no attention mask is passed from here, so whether a
        # position can attend to later positions is decided entirely inside
        # TransformerDecoderLayer.
        hidden_states = combined_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, encoder_hidden_states=None)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        # Compute language modeling loss
        language_loss = None
        if labels is not None:
            # Shift for next-token prediction (skip visual tokens): the logit
            # at text position t is scored against the label at position t + 1.
            num_visual_tokens = query_outputs.shape[1]
            shift_logits = logits[:, num_visual_tokens:-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            language_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return {
            "contrastive_loss": contrastive_loss,
            "language_loss": language_loss,
            "logits": logits,
        }


In [None]:

class TransformerDecoderLayer(nn.Module):
    """Simplified transformer decoder layer for language modeling."""

    def __init__(self, config: VLMConfig):
        """Initialize decoder layer."""
        super().__init__()
        self.self_attention = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.SiLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )
        self.layernorm1 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layernorm2 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass."""
        # Self-attention
        residual = hidden_states
        hidden_states = self.layernorm1(hidden_states)
        hidden_states, _ = self.self_attention(hidden_states, hidden_states, hidden_states)
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.layernorm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

In [None]:


class VisionLanguageDataset(Dataset):
    """Dataset for vision-language pretraining.

    Wraps a HuggingFace dataset of image/caption samples and returns, per
    sample, a normalised image tensor plus fixed-length token IDs and labels.
    """

    # Seconds to wait for a remote image before giving up and using the fallback.
    _REQUEST_TIMEOUT = 10

    def __init__(
        self,
        dataset_name: str,
        tokenizer: PreTrainedTokenizerFast,
        image_size: int = 224,
        max_length: int = 512,
        split: str = "train",
    ):
        """
        Initialize VL dataset.

        Args:
            dataset_name: HuggingFace dataset name
            tokenizer: Tokenizer instance
            image_size: Size to resize images
            max_length: Maximum text sequence length
            split: Dataset split
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Image transforms: resize to a square, convert to a tensor, then
        # normalise with the per-channel mean/std below (3 channels).
        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        logger.info(f"Loading dataset {dataset_name}...")
        self.dataset = load_dataset(dataset_name, split=split, streaming=False)
        logger.info(f"Dataset loaded with {len(self.dataset)} samples")

    def __len__(self) -> int:
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def _load_image(self, image):
        """Return `image` as an RGB PIL image, loading it from a URL or path if it is a string."""
        if isinstance(image, str):
            source = image
            try:
                if source.startswith("http"):
                    # A timeout keeps one dead host from stalling a DataLoader worker forever.
                    response = requests.get(source, timeout=self._REQUEST_TIMEOUT)
                    response.raise_for_status()
                    with Image.open(BytesIO(response.content)) as opened:
                        return opened.convert("RGB")
                with Image.open(source) as opened:
                    return opened.convert("RGB")
            except Exception as exc:
                # Best-effort by design: one unreadable image must not abort
                # training. Unlike a bare `except`, this lets KeyboardInterrupt
                # and SystemExit through, and the failure is logged.
                logger.warning("Could not load image %r (%s); using a blank image", source, exc)
                # The transform resizes this to the configured image size.
                return Image.new("RGB", (224, 224))

        # Already-decoded images may be grayscale, palette, RGBA or CMYK; the
        # 3-channel normalisation in the transform needs exactly three channels.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            return image.convert("RGB")
        return image

    @staticmethod
    def _extract_text(sample) -> str:
        """Return the caption text of a sample as one string."""
        text = sample.get("sentences", sample.get("caption", ""))
        if isinstance(text, dict):
            # A single sentence record (or a dict of lists): keep its raw text.
            text = text.get("raw", "")
        if isinstance(text, list):
            text = " ".join([s["raw"] if isinstance(s, dict) else s for s in text])
        if text is None:
            text = ""
        return text

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single image-text pair.

        Args:
            idx: Sample index

        Returns:
            Dictionary with pixel_values, input_ids, and labels. `labels` is
            a copy of `input_ids` with padding positions set to -100 so the
            language-modeling loss ignores them.
        """
        sample = self.dataset[idx]

        # Process image
        pixel_values = self.transform(self._load_image(sample["image"]))

        # Process text (use sentences/captions field based on dataset structure)
        text = self._extract_text(sample)

        # Tokenize to a fixed length so samples can be stacked into a batch
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze(0)

        # Do not train on padding: positions the tokenizer marks as padding
        # (attention_mask == 0) get the loss's ignore index.
        labels = input_ids.clone()
        if "attention_mask" in encoding:
            attention_mask = encoding["attention_mask"].squeeze(0)
            labels[attention_mask == 0] = -100

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "labels": labels,
        }


In [None]:

def train_vlm(config: VLMTrainingConfig):
    """
    Main VLM training function.

    Sets up (optionally distributed) training, builds the tokenizer, model,
    dataset, optimizer and cosine schedule, then runs the training loop with
    gradient accumulation, optional mixed precision, periodic logging and
    checkpointing, and finally saves the model under `config.output_dir`.

    Distributed mode is enabled when the RANK environment variable is set;
    WORLD_SIZE and LOCAL_RANK are then read as well.

    Args:
        config: Training configuration
    """
    # Setup distributed
    rank, world_size, local_rank = 0, 1, -1
    if "RANK" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank)

    is_main_process = rank == 0

    if is_main_process:
        logger.info("Starting VLM pretraining...")
        if config.use_wandb:
            wandb.init(project=config.wandb_project, config=config.__dict__)

    # Device
    device = torch.device(f"cuda:{local_rank}" if local_rank >= 0 else "cuda" if torch.cuda.is_available() else "cpu")

    # Load tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_path)

    # Create model. The data-dependent sizes come from the training config so
    # the model always matches what the dataset produces; with the defaults
    # this yields the same model configuration as VLMConfig's own defaults.
    model_config = VLMConfig(
        vocab_size=len(tokenizer),
        image_size=config.image_size,
        max_seq_length=config.max_seq_length,
    )
    base_model = VisionLanguageModel(model_config).to(device)
    model = base_model

    if is_main_process:
        total_params = sum(p.numel() for p in base_model.parameters())
        logger.info(f"Model parameters: {total_params:,} ({total_params / 1e9:.2f}B)")

    # DDP
    if world_size > 1:
        model = DDP(base_model, device_ids=[local_rank])

    # Dataset
    dataset = VisionLanguageDataset(
        config.dataset_name,
        tokenizer,
        config.image_size,
        config.max_seq_length,
        config.dataset_split,
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) if world_size > 1 else None
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=4,
        pin_memory=True,
    )

    # Optimizer with different learning rates for vision and language.
    # Groups are built from the unwrapped model; DDP shares the same
    # parameter objects, so the optimizer updates the wrapped model too.
    vision_params = list(base_model.vision_encoder.parameters())
    other_params = [
        p for n, p in base_model.named_parameters()
        if not n.startswith("vision_encoder.")
    ]

    optimizer = torch.optim.AdamW(
        [
            {"params": vision_params, "lr": config.vision_lr},
            {"params": other_params, "lr": config.learning_rate},
        ],
        weight_decay=config.weight_decay,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon,  # previously ignored; AdamW's default equals the config default
    )

    # Scheduler
    total_steps = config.max_steps if config.max_steps else len(dataloader) * config.num_epochs // config.gradient_accumulation_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, config.warmup_steps, total_steps)

    # Mixed precision. bf16 takes precedence over fp16 when both are set.
    if config.use_bf16:
        dtype = torch.bfloat16
    elif config.use_fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32
    # Autocast only on CUDA and only for a reduced-precision dtype; on CPU
    # the run stays in full precision.
    use_autocast = device.type == "cuda" and dtype != torch.float32
    # Loss scaling is only needed for fp16.
    scaler = torch.cuda.amp.GradScaler() if use_autocast and dtype == torch.float16 else None

    # Training loop
    global_step = 0  # optimizer steps taken
    micro_step = 0  # batches processed; NOT reset per epoch (see below)
    model.zero_grad()

    for epoch in range(config.num_epochs):
        if sampler:
            sampler.set_epoch(epoch)

        progress_bar = tqdm(dataloader, disable=not is_main_process)

        for batch in progress_bar:
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass
            with torch.autocast(
                device_type=device.type,
                dtype=dtype if use_autocast else None,
                enabled=use_autocast,
            ):
                outputs = model(pixel_values, input_ids, labels)

                # Combined loss, divided so accumulated gradients average over the micro-batches
                loss = (
                    config.contrastive_loss_weight * outputs["contrastive_loss"] +
                    config.language_loss_weight * outputs["language_loss"]
                ) / config.gradient_accumulation_steps

            # Backward
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Update. The boundary is counted across epochs: a per-epoch
            # counter would carry leftover gradients from the tail of one
            # epoch into the first update of the next, and would take fewer
            # optimizer steps than `total_steps` assumes.
            micro_step += 1
            if micro_step % config.gradient_accumulation_steps == 0:
                if scaler:
                    # Gradients must be unscaled before clipping by norm
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # Logging (values are those of the last micro-batch of this step)
                if global_step % config.logging_steps == 0 and is_main_process:
                    loss_val = loss.item() * config.gradient_accumulation_steps
                    progress_bar.set_postfix({"loss": f"{loss_val:.4f}"})

                    if config.use_wandb:
                        wandb.log({
                            "train/total_loss": loss_val,
                            "train/contrastive_loss": outputs["contrastive_loss"].item(),
                            "train/language_loss": outputs["language_loss"].item(),
                            "train/learning_rate": scheduler.get_last_lr()[0],
                            "train/step": global_step,
                        })

                # Save checkpoint
                if global_step % config.save_steps == 0 and is_main_process:
                    checkpoint_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                    os.makedirs(checkpoint_path, exist_ok=True)

                    torch.save({
                        "model": base_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "step": global_step,
                        "config": model_config,
                    }, os.path.join(checkpoint_path, "pytorch_model.bin"))

                    logger.info(f"Checkpoint saved at step {global_step}")

                if config.max_steps and global_step >= config.max_steps:
                    break

        if config.max_steps and global_step >= config.max_steps:
            break

    # Save final model
    if is_main_process:
        final_path = os.path.join(config.output_dir, "final")
        os.makedirs(final_path, exist_ok=True)

        torch.save({
            "model": base_model.state_dict(),
            "config": model_config,
        }, os.path.join(final_path, "pytorch_model.bin"))

        logger.info(f"Final model saved to {final_path}")

        if config.use_wandb:
            wandb.finish()

    if world_size > 1:
        dist.destroy_process_group()


In [None]:

# Start training only when this file is executed directly (script or notebook
# cell), not when it is imported. The guard also matters for the DataLoader
# worker processes: on platforms that start workers with "spawn" the main
# module is re-imported in every worker, which would otherwise launch
# training again from inside each worker.
if __name__ == "__main__":
    config = VLMTrainingConfig()
    train_vlm(config)

