In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel

In [21]:
class Siglip2VisionConfig(PretrainedConfig):
    """Configuration holder for the vision encoder.

    Every named argument is stored unchanged as an attribute of the same name;
    anything else in ``**kwargs`` is handed to ``PretrainedConfig``.

    Args:
        hidden_size: Width of the patch embeddings and transformer layers.
        intermediate_size: Feed-forward width inside each transformer layer.
        num_hidden_layers: Number of stacked transformer encoder layers.
        num_attention_heads: Attention heads per transformer layer.
        image_size: Side length, in pixels, of the (square) input image.
        patch_size: Side length, in pixels, of each image patch.
        num_channels: Number of channels in the input image.
    """

    model_type = "siglip2_vision"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        image_size=224,
        patch_size=16,
        num_channels=3,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer geometry
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image geometry
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels

In [22]:
class Siglip2TextConfig(PretrainedConfig):
    """Configuration holder for the text encoder.

    Every named argument is stored unchanged as an attribute of the same name;
    anything else in ``**kwargs`` is handed to ``PretrainedConfig``.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Width of the token embeddings and transformer layers.
        intermediate_size: Feed-forward width inside each transformer layer.
        num_hidden_layers: Number of stacked transformer encoder layers.
        num_attention_heads: Attention heads per transformer layer.
        max_length: Number of positions in the learned positional embedding.
    """

    model_type = "siglip2_text"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_length=64,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vocabulary and sequence limits
        self.vocab_size = vocab_size
        self.max_length = max_length

        # Transformer geometry
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

In [23]:
class Siglip2Config(PretrainedConfig):
    """Composite configuration bundling the vision and text encoder configs.

    Args:
        vision_config: A ``Siglip2VisionConfig``, a plain ``dict`` of its
            constructor arguments, or ``None`` for the defaults.
        text_config: A ``Siglip2TextConfig``, a plain ``dict`` of its
            constructor arguments, or ``None`` for the defaults.
        projection_dim: Width of the shared space both encoders are projected
            into.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "siglip2"

    def __init__(self, vision_config=None, text_config=None, projection_dim=768, **kwargs):
        super().__init__(**kwargs)

        # Sub-configs may arrive as plain dicts (for example after a round
        # trip through a serialized config). Rebuild the config objects so
        # attribute access such as ``config.vision_config.hidden_size`` works.
        if vision_config is None:
            vision_config = Siglip2VisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = Siglip2VisionConfig(**vision_config)

        if text_config is None:
            text_config = Siglip2TextConfig()
        elif isinstance(text_config, dict):
            text_config = Siglip2TextConfig(**text_config)

        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim

In [24]:
class AttentionPooling(nn.Module):
    """Pool a token sequence into one vector with single-head attention.

    A learned query of shape ``[1, 1, hidden_size]`` (initialised to zeros)
    attends over the input sequence; the single attended output is returned.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # One learnable probe vector shared by every item in the batch.
        self.query = nn.Parameter(torch.zeros((1, 1, hidden_size)))
        # batch_first=True, so tensors are laid out [batch, seq, hidden].
        self.attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=1, batch_first=True)

    def forward(self, x, key_padding_mask=None):
        """Return the attention-pooled summary of ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, hidden_size]``.
            key_padding_mask: Optional mask handed to the attention layer as
                its ``key_padding_mask``; ``None`` means nothing is masked.

        Returns:
            Tensor of shape ``[batch_size, hidden_size]``.
        """
        batch_size = x.shape[0]
        # Broadcast the shared probe to one query per batch item.
        probe = self.query.expand(batch_size, -1, -1)
        pooled, _weights = self.attn(
            query=probe, key=x, value=x, key_padding_mask=key_padding_mask
        )
        # Drop the length-1 query axis: [batch, 1, hidden] -> [batch, hidden].
        return pooled[:, 0, :]

In [25]:
class Siglip2VisionModel(PreTrainedModel):
    """Patch-based transformer image encoder with attention pooling.

    The image is cut into non-overlapping patches by a strided convolution,
    a learned positional embedding is added, the tokens pass through a stack
    of ``nn.TransformerEncoderLayer`` modules, and ``AttentionPooling``
    reduces them to one vector per image.
    """

    config_class = Siglip2VisionConfig

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        patch = config.patch_size

        # kernel == stride, so each patch is embedded independently.
        self.patch_embedding = nn.Conv2d(
            config.num_channels, width, kernel_size=patch, stride=patch
        )

        # One positional vector per patch of the square patch grid.
        grid_side = config.image_size // patch
        self.positional_embedding = nn.Parameter(
            torch.zeros(1, grid_side * grid_side, width)
        )

        encoder_layers = []
        for _ in range(config.num_hidden_layers):
            encoder_layers.append(
                nn.TransformerEncoderLayer(
                    d_model=width,
                    nhead=config.num_attention_heads,
                    dim_feedforward=config.intermediate_size,
                    batch_first=True,
                )
            )
        self.transformer = nn.ModuleList(encoder_layers)

        self.map_head = AttentionPooling(width)

    def forward(self, pixel_values):
        """Encode a batch of images.

        Args:
            pixel_values: ``[batch_size, num_channels, image_size, image_size]``.

        Returns:
            Pooled image features of shape ``[batch_size, hidden_size]``.
        """
        patches = self.patch_embedding(pixel_values)  # [batch, hidden, H/p, W/p]
        tokens = patches.flatten(start_dim=2).transpose(1, 2)  # [batch, num_patches, hidden]
        hidden = tokens + self.positional_embedding
        for block in self.transformer:
            hidden = block(hidden)
        # Images carry no padding, so the pooling head gets no mask.
        return self.map_head(hidden)

In [26]:
class Siglip2TextModel(PreTrainedModel):
    """Transformer text encoder with attention pooling.

    Token ids are embedded, a learned positional embedding is added, the
    sequence passes through a stack of ``nn.TransformerEncoderLayer`` modules,
    and ``AttentionPooling`` reduces it to one vector per sequence.
    """

    config_class = Siglip2TextConfig

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size

        self.embeddings = nn.Embedding(config.vocab_size, width)
        # Learned positions for up to ``config.max_length`` tokens.
        self.positional_embeddings = nn.Parameter(
            torch.zeros(1, config.max_length, width)
        )

        encoder_layers = []
        for _ in range(config.num_hidden_layers):
            encoder_layers.append(
                nn.TransformerEncoderLayer(
                    d_model=width,
                    nhead=config.num_attention_heads,
                    dim_feedforward=config.intermediate_size,
                    batch_first=True,
                )
            )
        self.transformer = nn.ModuleList(encoder_layers)

        self.map_head = AttentionPooling(width)

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token sequences.

        Args:
            input_ids: Token ids of shape ``[batch_size, seq_len]``.
            attention_mask: Optional ``[batch_size, seq_len]`` tensor in which
                entries equal to 0 mark padding positions.

        Returns:
            Pooled text features of shape ``[batch_size, hidden_size]``.
        """
        hidden = self.embeddings(input_ids)  # [batch, seq_len, hidden]
        hidden = hidden + self.positional_embeddings[:, :hidden.size(1), :]

        # The PyTorch layers expect True where a position must be ignored.
        padding_mask = None if attention_mask is None else (attention_mask == 0)

        for block in self.transformer:
            hidden = block(hidden, src_key_padding_mask=padding_mask)

        # The same mask keeps padding out of the pooled summary.
        return self.map_head(hidden, key_padding_mask=padding_mask)

In [27]:
class Siglip2Model(PreTrainedModel):
    """Dual encoder: image and text towers projected into one shared space."""

    config_class = Siglip2Config

    def __init__(self, config):
        super().__init__(config)
        vision_cfg, text_cfg = config.vision_config, config.text_config

        self.vision_model = Siglip2VisionModel(vision_cfg)
        self.text_model = Siglip2TextModel(text_cfg)

        # Linear heads mapping each tower's width to ``projection_dim``.
        self.vision_projection = nn.Linear(vision_cfg.hidden_size, config.projection_dim)
        self.text_projection = nn.Linear(text_cfg.hidden_size, config.projection_dim)

    def forward(self, pixel_values, input_ids, attention_mask=None):
        """Embed a batch of images and a batch of texts.

        Args:
            pixel_values: Image batch passed to the vision tower.
            input_ids: Token id batch passed to the text tower.
            attention_mask: Optional padding mask passed to the text tower.

        Returns:
            Tuple ``(vision_proj, text_proj)``, each of shape
            ``[batch_size, projection_dim]``.
        """
        image_features = self.vision_model(pixel_values)  # [batch, hidden]
        text_features = self.text_model(input_ids, attention_mask)  # [batch, hidden]
        return (
            self.vision_projection(image_features),
            self.text_projection(text_features),
        )

In [28]:
def siglip_loss(vision_proj, text_proj, temperature=1.0):
    """Pairwise sigmoid loss over every image/text combination in the batch.

    Row ``i`` of ``vision_proj`` and row ``i`` of ``text_proj`` form a positive
    pair; every other combination is a negative. The loss is the mean binary
    cross-entropy over the full ``[batch_size, batch_size]`` logit matrix.

    Args:
        vision_proj: Image embeddings, ``[batch_size, dim]``.
        text_proj: Text embeddings, ``[batch_size, dim]``.
        temperature: Divisor applied to the dot-product logits.

    Returns:
        Scalar loss tensor.
    """
    logits = torch.matmul(vision_proj, text_proj.t()) / temperature  # [batch, batch]
    # Identity labels: only the diagonal holds matching pairs. Built in the
    # logits' dtype so reduced- or double-precision inputs stay consistent.
    labels = torch.eye(logits.size(0), device=logits.device, dtype=logits.dtype)
    # The loss is elementwise and the label matrix is symmetric, so the
    # text-to-image term over ``logits.t()`` equals the image-to-text term.
    # Averaging the two therefore reduces to computing it once.
    return F.binary_cross_entropy_with_logits(logits, labels)

In [29]:
# Build the composite configuration from default vision and text settings
vision_config, text_config = Siglip2VisionConfig(), Siglip2TextConfig()
config = Siglip2Config(text_config=text_config, vision_config=vision_config)

# Instantiate the dual-encoder model from that configuration
model = Siglip2Model(config)

In [30]:
# Dummy batch of 8 image/caption pairs for a smoke test
pixel_values = torch.rand((8, 3, 224, 224))  # [batch_size, channels, height, width]
input_ids = torch.randint(low=0, high=32000, size=(8, 64))  # [batch_size, seq_len]
attention_mask = torch.ones((8, 64))  # [batch_size, seq_len]; no position is padding

In [None]:
# Run the dual encoder once on the example batch
vision_proj, text_proj = model(
    pixel_values=pixel_values,
    input_ids=input_ids,
    attention_mask=attention_mask,
)

# Score the batch with the pairwise sigmoid loss
loss = siglip_loss(vision_proj=vision_proj, text_proj=text_proj)

print(f"Loss: {loss.item()}")

In [None]:
import torch

# Report whether PyTorch can see a CUDA device and, if so, which one.
print(torch.cuda.is_available())
print(
    f"CUDA Device: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "CUDA not detected"
)

In [None]:
import torch
# Show the CUDA version string exposed by the installed PyTorch build.
print(torch.version.cuda)

In [None]:
# Shell escape: ask the nvcc compiler on PATH to print its version.
!nvcc --version

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, BertTokenizer
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from datasets import load_dataset
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
import random
from tqdm import tqdm

# CUDA check: pick the GPU when PyTorch can see one, otherwise use the CPU.
print("Checking CUDA availability...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
else:
    # Actually fall back to the CPU, as the message announces. The previous
    # exit() call here terminated the interpreter (the kernel, in a notebook)
    # instead of continuing on the CPU device selected above.
    print("CUDA not available, falling back to CPU")

# Load the dataset
print("Loading Flickr30k dataset...")
dataset = load_dataset("nlphuji/flickr30k")
print("Dataset loaded successfully!")
# Everything below draws from the "test" split of the loaded dataset.
full_dataset = dataset["test"]
print(f"Full dataset size: {len(full_dataset)} examples")

# Carve out small training and validation subsets
print("Creating training and validation subsets...")
random.seed(42)  # fixed seed so the shuffle below is repeatable
all_indices = [*range(len(full_dataset))]
random.shuffle(all_indices)
# First 100 shuffled examples train the model; the next 50 validate it.
train_indices, val_indices = all_indices[:100], all_indices[100:150]
train_subset = full_dataset.select(train_indices)
val_subset = full_dataset.select(val_indices)
print(f"Training subset size: {len(train_subset)}")
print(f"Validation subset size: {len(val_subset)}")

# Tokenizer
print("Initializing BertTokenizer...")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
print("Tokenizer initialized!")

# Image preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel with mean 0.5 and std 0.5.
print("Defining image transformations...")
image_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
print("Image transformations defined!")

# Custom Dataset (fixed)
print("Defining custom Flickr30kMiniDataset class...")
class Flickr30kMiniDataset(Dataset):
    """Map-style dataset yielding one preprocessed image/caption pair per index.

    Args:
        dataset: Indexable collection whose items expose an ``"image"`` entry
            (an object with ``mode`` and ``convert``) and a ``"caption"``
            entry (a sequence of strings).
        transform: Callable turning the RGB image into ``pixel_values``.
        tokenizer: Callable tokenizer invoked with ``padding="max_length"``,
            ``truncation=True`` and ``return_tensors="pt"``.
        max_length: Fixed token length every caption is padded/truncated to.
    """

    def __init__(self, dataset, transform, tokenizer, max_length=64):
        self.dataset = dataset
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        print(f"Dataset initialized with {len(dataset)} examples")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``pixel_values``, ``input_ids`` and ``attention_mask`` for ``idx``.

        Deliberately silent: this runs once per sample, so per-item prints
        would flood the output and interleave with progress bars.
        """
        example = self.dataset[idx]

        image = example["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        pixel_values = self.transform(image)

        # Each example carries a list of captions; only the first is used.
        caption = example["caption"][0]
        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        # The tokenizer returns a leading batch axis of size 1; drop it so the
        # DataLoader can stack samples itself.
        return {
            "pixel_values": pixel_values,
            "input_ids": tokenized["input_ids"].squeeze(0),
            "attention_mask": tokenized["attention_mask"].squeeze(0),
        }

# Wrap both subsets in the custom dataset class
print("Creating training dataset...")
train_data = Flickr30kMiniDataset(
    dataset=train_subset, transform=image_transform, tokenizer=tokenizer
)
print("Creating validation dataset...")
val_data = Flickr30kMiniDataset(
    dataset=val_subset, transform=image_transform, tokenizer=tokenizer
)

# Data loaders: batches of 16, loaded in the main process (num_workers=0).
# Only the training loader shuffles.
print("Creating training data loader...")
train_loader = DataLoader(
    dataset=train_data, batch_size=16, shuffle=True, num_workers=0, pin_memory=True
)
print(f"Training loader created with {len(train_loader)} batches")
print("Creating validation data loader...")
val_loader = DataLoader(
    dataset=val_data, batch_size=16, shuffle=False, num_workers=0, pin_memory=True
)
print(f"Validation loader created with {len(val_loader)} batches")

# [Paste your Siglip2VisionConfig, Siglip2TextConfig, Siglip2Config, AttentionPooling,
# Siglip2VisionModel, Siglip2TextModel, Siglip2Model, and siglip_loss definitions here]

# Model initialisation
print("Initializing SIGLIP2 model...")
vision_config = Siglip2VisionConfig()
text_config = Siglip2TextConfig()
config = Siglip2Config(vision_config=vision_config, text_config=text_config)
model = Siglip2Model(config)
model.to(device)
# Report the device actually selected rather than a hard-coded "CUDA".
# On a CUDA device the printed text is unchanged.
print(f"Model initialized and moved to {str(device).upper()}!")

# Optimiser and learning-rate schedule
print("Setting up training environment...")
optimizer = AdamW(params=model.parameters(), lr=5e-5, weight_decay=0.01)
print("Optimizer initialized!")
# The cosine schedule's period is the number of batches in one epoch.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=len(train_loader))
print("Scheduler initialized!")

# Training loop: one optimisation pass per epoch, then a validation pass.
num_epochs = 1
print(f"Starting training for {num_epochs} epoch on {device}...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs} started")
    model.train()
    total_train_loss = 0
    print("Entering training loop...")
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}")):
        print(f"Batch {batch_idx+1}/{len(train_loader)} loaded")
        pixel_values = batch["pixel_values"].to(device, non_blocking=True)
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        print(f"Batch moved to {device} - pixel_values: {pixel_values.shape}, input_ids: {input_ids.shape}")
        vision_proj, text_proj = model(pixel_values, input_ids, attention_mask)
        print(f"Forward pass completed on {device}, vision_proj shape: {vision_proj.shape}")
        loss = siglip_loss(vision_proj, text_proj)
        # Read the scalar once; it is reused for logging and for the total.
        loss_value = loss.item()
        print(f"Loss computed on {device}: {loss_value}")
        optimizer.zero_grad()
        loss.backward()
        print("Backward pass completed")
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # The scheduler is created with T_max=len(train_loader), so it must
        # advance once per batch. Stepping it once per epoch left the
        # learning rate unchanged for the whole epoch.
        scheduler.step()
        print("Optimizer step completed")
        total_train_loss += loss_value
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} completed, Average Training Loss: {avg_train_loss}")

    # Validation
    print(f"Starting validation for epoch {epoch+1}")
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")):
            print(f"Validation batch {batch_idx+1}/{len(val_loader)} loaded")
            pixel_values = batch["pixel_values"].to(device, non_blocking=True)
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            print(f"Validation batch moved to {device} - pixel_values: {pixel_values.shape}")
            vision_proj, text_proj = model(pixel_values, input_ids, attention_mask)
            loss = siglip_loss(vision_proj, text_proj)
            print(f"Validation loss on {device}: {loss.item()}")
            total_val_loss += loss.item()
    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Validation Loss: {avg_val_loss}")

# Evaluation function
def evaluate_model(model, data_loader, device):
    """Compute the mean ``siglip_loss`` of ``model`` over ``data_loader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Callable as ``model(pixel_values, input_ids, attention_mask)``
            and returning ``(vision_proj, text_proj)``.
        data_loader: Yields dict batches with ``"pixel_values"``,
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        device: Device every batch tensor is moved to.

    Returns:
        The sum of per-batch losses divided by ``len(data_loader)``.
    """
    print("Starting model evaluation...")
    model.eval()
    running_total = 0
    with torch.no_grad():
        for step, batch in enumerate(tqdm(data_loader, desc="Evaluation"), start=1):
            print(f"Evaluation batch {step}/{len(data_loader)} loaded")
            images = batch["pixel_values"].to(device, non_blocking=True)
            token_ids = batch["input_ids"].to(device, non_blocking=True)
            mask = batch["attention_mask"].to(device, non_blocking=True)
            image_emb, text_emb = model(images, token_ids, mask)
            batch_loss = siglip_loss(image_emb, text_emb).item()
            print(f"Evaluation batch loss on {device}: {batch_loss}")
            running_total += batch_loss
    avg_loss = running_total / len(data_loader)
    print(f"Evaluation completed, Average Loss: {avg_loss}")
    return avg_loss

# Score the trained model on the validation loader one last time
print("Running final evaluation...")
evaluate_model(model=model, data_loader=val_loader, device=device)
print("Training and evaluation process finished!")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, BertTokenizer
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from datasets import load_dataset
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
import random
from tqdm import tqdm

# CUDA check: prefer the GPU, otherwise stay on the CPU.
print("Checking CUDA availability...")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    # Three diagnostic lines, emitted one per line.
    print(
        f"CUDA Device: {torch.cuda.get_device_name(0)}",
        f"CUDA Version: {torch.version.cuda}",
        f"PyTorch Version: {torch.__version__}",
        sep="\n",
    )

# Load the dataset and carve out small train/validation subsets
dataset = load_dataset("nlphuji/flickr30k")
full_dataset = dataset["test"]
random.seed(42)  # fixed seed so the shuffle below is repeatable
all_indices = [*range(len(full_dataset))]
random.shuffle(all_indices)
# First 100 shuffled examples train the model; the next 50 validate it.
train_indices, val_indices = all_indices[:100], all_indices[100:150]
train_subset = full_dataset.select(train_indices)
val_subset = full_dataset.select(val_indices)

# Tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Image preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel with mean 0.5 and std 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Custom Dataset
class Flickr30kMiniDataset(Dataset):
    """Map-style dataset yielding one preprocessed image/caption pair per index.

    Each item is a dict with ``pixel_values`` (the transformed RGB image) plus
    ``input_ids`` and ``attention_mask`` for the example's first caption,
    padded/truncated to 64 tokens.
    """

    def __init__(self, dataset, transform, tokenizer):
        self.dataset, self.transform, self.tokenizer = dataset, transform, tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        # Force three channels before applying the image transform.
        picture = record["image"]
        picture = picture if picture.mode == "RGB" else picture.convert("RGB")
        pixel_values = self.transform(picture)

        # Only the first caption of the example is used.
        first_caption = record["caption"][0]
        encoded = self.tokenizer(
            first_caption,
            padding="max_length",
            max_length=64,
            truncation=True,
            return_tensors="pt",
        )

        # Drop the tokenizer's leading batch axis of size 1.
        item = {"pixel_values": pixel_values}
        for field in ("input_ids", "attention_mask"):
            item[field] = encoded[field].squeeze(0)
        return item

# Datasets and loaders: batches of 16, loaded in the main process
# (num_workers=0). Only the training loader shuffles.
train_data = Flickr30kMiniDataset(
    dataset=train_subset, transform=image_transform, tokenizer=tokenizer
)
val_data = Flickr30kMiniDataset(
    dataset=val_subset, transform=image_transform, tokenizer=tokenizer
)
train_loader = DataLoader(
    dataset=train_data, batch_size=16, shuffle=True, num_workers=0, pin_memory=True
)
val_loader = DataLoader(
    dataset=val_data, batch_size=16, shuffle=False, num_workers=0, pin_memory=True
)

# Model Configs unchanged (Siglip2VisionConfig, Siglip2TextConfig, AttentionPooling as before)
# [Paste your previous Siglip2VisionConfig, Siglip2TextConfig, AttentionPooling here]

# Text Decoder (Image-to-Text için)
class Siglip2Decoder(nn.Module):
    """Causal transformer decoder that cross-attends to a vision embedding.

    Args:
        hidden_size: Width of token embeddings, decoder layers and memory.
        num_layers: Number of ``nn.TransformerDecoderLayer`` modules.
        num_heads: Attention heads per decoder layer.
        vocab_size: Size of the token embedding table and output projection.
    """

    def __init__(self, hidden_size=768, num_layers=6, num_heads=12, vocab_size=30522):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.decoder = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model=hidden_size, nhead=num_heads, dim_feedforward=hidden_size*4, batch_first=True)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(hidden_size, vocab_size)
        # Learned positions; their count (64) also caps the generated length.
        self.positional_embeddings = nn.Parameter(torch.zeros(1, 64, hidden_size))

    @staticmethod
    def _as_memory(vision_embedding):
        """Return the embedding as ``[batch, mem_len, hidden]`` for cross-attention."""
        # A pooled [batch, hidden] vector becomes a length-1 memory sequence;
        # the decoder layers reject a 2-D memory alongside a batched target.
        if vision_embedding.dim() == 2:
            return vision_embedding.unsqueeze(1)
        return vision_embedding

    def _decode(self, token_ids, memory):
        """Run all decoder layers over ``token_ids`` under a causal mask."""
        length = token_ids.size(1)
        x = self.embedding(token_ids) + self.positional_embeddings[:, :length]
        # Build the mask once per call, on the tensors' own device rather than
        # on a module-level ``device`` global.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(length).to(x.device)
        for layer in self.decoder:
            x = layer(x, memory, tgt_mask=causal_mask)
        return x

    def forward(self, vision_embedding, target_ids=None, attention_mask=None):
        """Score ``target_ids`` (training) or greedily generate ids (inference).

        Args:
            vision_embedding: ``[batch, hidden]`` or ``[batch, mem_len, hidden]``
                tensor used as cross-attention memory.
            target_ids: ``[batch, seq_len]`` token ids. When given, logits for
                every position are returned; when ``None``, ids are generated.
            attention_mask: Accepted for interface compatibility; not used.

        Returns:
            Logits ``[batch, seq_len, vocab_size]`` when ``target_ids`` is
            given, otherwise generated ids ``[batch, 64]`` (dtype long).
        """
        memory = self._as_memory(vision_embedding)

        if target_ids is not None:  # Training
            return self.fc_out(self._decode(target_ids, memory))

        # Inference: greedy decoding into a zero-initialised id buffer.
        # NOTE(review): at step t the slot being predicted is itself fed in as
        # a 0 placeholder and then overwritten, and no start token is used.
        # Confirm this matches how the decoder is trained.
        seq_len = self.positional_embeddings.size(1)
        output_ids = torch.zeros((memory.size(0), seq_len), dtype=torch.long, device=memory.device)
        for t in range(seq_len):
            x = self._decode(output_ids[:, :t+1], memory)
            output_ids[:, t] = self.fc_out(x[:, -1, :]).argmax(-1)
        return output_ids

# Simple Diffusion Model (Text-to-Image için temel)
class SimpleDiffusion(nn.Module):
    """Toy text-conditioned image generator using iterative refinement.

    Starting from Gaussian noise, a small convolutional network repeatedly
    predicts a correction that is subtracted from the image. The text
    embedding is projected to a 64-channel vector and added, broadcast over
    all pixels, to the features of the first convolution.

    Args:
        hidden_size: Width of the incoming text embedding.
        img_size: Side length of the square image that is generated.
    """

    def __init__(self, hidden_size=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        # Simple linear noise schedule. Kept as a non-persistent buffer so it
        # follows the module across devices without entering the state dict.
        # The refinement loop below uses a constant step size instead.
        self.register_buffer("noise_scheduler", torch.linspace(0, 1, 1000), persistent=False)
        self.unet = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            # Stride 1 keeps the spatial size, so the prediction can be
            # subtracted from the image it was computed from.
            nn.ConvTranspose2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )
        # One conditioning value per feature channel. Projecting to a full
        # 64 x img_size x img_size map would need billions of weights at the
        # default sizes.
        self.text_proj = nn.Linear(hidden_size, 64)

    def forward(self, text_embedding, steps=50):
        """Generate images conditioned on ``text_embedding``.

        Args:
            text_embedding: ``[batch, hidden_size]`` conditioning vectors.
            steps: Number of refinement iterations.

        Returns:
            Tensor ``[batch, 3, img_size, img_size]`` clamped to ``[-1, 1]``.
        """
        batch = text_embedding.size(0)
        # Create the noise on the input's device and dtype instead of relying
        # on a module-level ``device`` global.
        image = torch.randn(
            batch, 3, self.img_size, self.img_size,
            device=text_embedding.device, dtype=text_embedding.dtype,
        )
        # [batch, 64, 1, 1] so it broadcasts over every pixel of the features.
        cond = self.text_proj(text_embedding).view(batch, 64, 1, 1)
        stem, trunk = self.unet[0], self.unet[1:]
        for _ in range(steps):
            pred_noise = trunk(stem(image) + cond)
            image = image - 0.1 * pred_noise  # simple fixed-step denoising
        return image.clamp(-1, 1)

# Extended SIGLIP2 Model
class Siglip2ModelExtended(PreTrainedModel):
    """Dual encoder extended with a caption decoder and an image generator.

    ``forward`` dispatches on ``mode``:

    * ``"contrastive"``: returns ``(vision_proj, text_proj)``.
    * ``"image_to_text"``: returns the decoder output for the image. This is
      logits when ``input_ids`` is given and generated ids otherwise.
    * ``"text_to_image"``: returns an image generated from the text.
    """

    config_class = Siglip2Config

    _MODES = ("contrastive", "image_to_text", "text_to_image")

    def __init__(self, config):
        super().__init__(config)
        self.vision_model = Siglip2VisionModel(config.vision_config)
        self.text_model = Siglip2TextModel(config.text_config)
        self.vision_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim)
        self.text_projection = nn.Linear(config.text_config.hidden_size, config.projection_dim)
        # Newly added heads.
        # NOTE(review): the decoder width is the text hidden size while its
        # memory is the projected vision embedding, so ``projection_dim`` is
        # assumed to equal the text hidden size. The vocabulary size is read
        # from the module-level ``tokenizer``.
        self.decoder = Siglip2Decoder(config.text_config.hidden_size, vocab_size=tokenizer.vocab_size)
        self.diffusion = SimpleDiffusion(config.projection_dim)

    def forward(self, pixel_values=None, input_ids=None, attention_mask=None, mode="contrastive"):
        """Run one of the three task heads.

        Args:
            pixel_values: Image batch. Required for ``"contrastive"`` and
                ``"image_to_text"``; ignored for ``"text_to_image"``.
            input_ids: Token ids. Required for ``"contrastive"`` and
                ``"text_to_image"``; optional for ``"image_to_text"``.
            attention_mask: Optional padding mask matching ``input_ids``.
            mode: One of ``"contrastive"``, ``"image_to_text"``,
                ``"text_to_image"``.

        Raises:
            ValueError: If ``mode`` is unknown or an input the mode needs is
                missing.
        """
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self._MODES}")

        if mode == "text_to_image":
            # No image is involved, so the vision tower is skipped and
            # callers may omit ``pixel_values``.
            if input_ids is None:
                raise ValueError("input_ids is required for mode='text_to_image'")
            text_embeddings = self.text_model(input_ids, attention_mask)
            text_proj = self.text_projection(text_embeddings)
            return self.diffusion(text_proj)

        if pixel_values is None:
            raise ValueError(f"pixel_values is required for mode={mode!r}")
        vision_embeddings = self.vision_model(pixel_values)
        vision_proj = self.vision_projection(vision_embeddings)

        if mode == "image_to_text":
            # Cross-attention needs a memory sequence, so the pooled vector is
            # presented as a sequence of length 1.
            return self.decoder(vision_proj.unsqueeze(1), input_ids, attention_mask)

        # mode == "contrastive"
        if input_ids is None:
            raise ValueError("input_ids is required for mode='contrastive'")
        text_embeddings = self.text_model(input_ids, attention_mask)
        text_proj = self.text_projection(text_embeddings)
        return vision_proj, text_proj

# Training setup: default config, extended model, optimiser and LR schedule
config = Siglip2Config()
model = Siglip2ModelExtended(config=config)
model.to(device=device)
optimizer = AdamW(params=model.parameters(), lr=5e-5, weight_decay=0.01)
# The cosine schedule's period is the number of batches in one epoch.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=len(train_loader))

# Loss Functions
def contrastive_loss(vision_proj, text_proj):
    """Image/text matching loss; a thin alias that defers to ``siglip_loss``."""
    matching_loss = siglip_loss(vision_proj, text_proj)
    return matching_loss

def captioning_loss(logits, target_ids, ignore_index=None):
    """Token-level cross-entropy between decoder logits and target ids.

    Args:
        logits: ``[batch, seq_len, vocab_size]`` decoder scores.
        target_ids: ``[batch, seq_len]`` target token ids.
        ignore_index: Target id excluded from the loss. Defaults to the pad
            token id of the module-level ``tokenizer``.

    Returns:
        Scalar loss tensor.
    """
    if ignore_index is None:
        ignore_index = tokenizer.pad_token_id
    vocab_size = logits.size(-1)
    # reshape (not view) so sliced, non-contiguous logits/targets are accepted.
    return F.cross_entropy(
        logits.reshape(-1, vocab_size),
        target_ids.reshape(-1),
        ignore_index=ignore_index,
    )

def diffusion_loss(pred_img, target_img):
    """Mean squared error between a generated image and its target."""
    pixel_mse = F.mse_loss(input=pred_img, target=target_img, reduction="mean")
    return pixel_mse

# Training loop: each batch is scored by all three heads and the losses summed.
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        pixel_values = batch["pixel_values"].to(device, non_blocking=True)
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)

        # Contrastive loss
        vision_proj, text_proj = model(pixel_values, input_ids, attention_mask, mode="contrastive")
        loss_contrastive = contrastive_loss(vision_proj, text_proj)

        # Captioning loss
        # NOTE(review): the decoder is fed the same ids it is scored against
        # (no one-token shift), so each position can see its own target.
        # Confirm whether teacher forcing with shifted targets was intended.
        caption_logits = model(pixel_values, input_ids, attention_mask, mode="image_to_text")
        loss_caption = captioning_loss(caption_logits, input_ids)

        # Diffusion loss (text-to-image). ``pixel_values`` is passed
        # positionally because the model's forward takes it as its first
        # argument; omitting it raised a TypeError.
        generated_img = model(pixel_values, input_ids, attention_mask, mode="text_to_image")
        loss_diffusion = diffusion_loss(generated_img, pixel_values)

        # Total loss
        loss = loss_contrastive + loss_caption + loss_diffusion
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # The scheduler is created with T_max=len(train_loader) but was never
        # advanced; step it once per batch so the cosine schedule takes effect.
        scheduler.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Training Loss: {avg_loss}")

    # Validation (contrastive only for simplicity)
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}"):
            pixel_values = batch["pixel_values"].to(device, non_blocking=True)
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            vision_proj, text_proj = model(pixel_values, input_ids, attention_mask, mode="contrastive")
            loss = contrastive_loss(vision_proj, text_proj)
            total_val_loss += loss.item()
    print(f"Validation Loss: {total_val_loss / len(val_loader)}")

# Inference Functions
def image_to_text(model, image, tokenizer, max_length=64):
    """Generate a caption for one image.

    Args:
        model: Model callable as ``model(pixel_values, mode="image_to_text")``
            and returning generated token ids of shape ``[1, seq_len]``.
        image: Image object exposing ``mode`` and ``convert`` (as the dataset
            images do); it is preprocessed with the module-level
            ``image_transform``.
        tokenizer: Tokenizer whose ``decode`` turns ids back into text.
        max_length: Maximum number of generated ids that are decoded.

    Returns:
        The decoded caption string with special tokens removed.
    """
    model.eval()
    # Mirror the dataset preprocessing: the transform normalises three
    # channels, so other image modes are converted to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    with torch.no_grad():
        pixel_values = image_transform(image).unsqueeze(0).to(device)
        caption_ids = model(pixel_values, mode="image_to_text")
        # Honour ``max_length``; with the default of 64 nothing is cut off
        # when the model generates 64 ids.
        return tokenizer.decode(caption_ids[0][:max_length], skip_special_tokens=True)

def text_to_image(model, text, tokenizer):
    """Generate one image tensor from a text prompt.

    The prompt is tokenized to a fixed length of 64, moved to the module-level
    ``device`` and passed to ``model`` in ``"text_to_image"`` mode.

    Returns:
        The generated image on the CPU with the batch axis removed.
    """
    model.eval()
    with torch.no_grad():
        encoded = tokenizer(
            text,
            padding="max_length",
            max_length=64,
            truncation=True,
            return_tensors="pt",
        )
        prompt_ids = encoded["input_ids"].to(device)
        prompt_mask = encoded["attention_mask"].to(device)
        picture = model(input_ids=prompt_ids, attention_mask=prompt_mask, mode="text_to_image")
        return picture.squeeze(0).cpu()

# Smoke-test both inference helpers on one validation image and one prompt
sample_image = val_subset[0]["image"]
print("Generated Caption:", image_to_text(model=model, image=sample_image, tokenizer=tokenizer))
print(
    "Generated Image from Text:",
    text_to_image(model=model, text="A dog running in the park", tokenizer=tokenizer).shape,
)