In [None]:
# Import required libraries
import tensorflow as tf
import numpy as np
import PIL.Image
from PIL import ImageEnhance
import random
import os
import io
import zipfile
import matplotlib.pyplot as plt
from google.colab import files

# Install required packages
!pip install transformers

# Import the pipeline after installation
from transformers import pipeline

# Download CIFAR-10 via Keras. The labels come back 2-D (one label per row),
# which is why the loop below reads subset_y[i][0].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Keep only the first 10 training images and their labels so captioning stays quick
subset_x = x_train[:10]
subset_y = y_train[:10]

# Print shapes of the subset to verify
print("Subset of Images shape:", subset_x.shape)
print("Subset of Labels shape:", subset_y.shape)

# Initialize the image captioning pipeline with the BLIP base captioning checkpoint.
# Each call returns a list of dicts; the caption text is under "generated_text".
image_captioning_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# CIFAR-10 class names, indexed by label id; used to build readable filenames
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Create a folder to store our images and descriptions
# NOTE(review): nothing in this notebook writes to image_descriptions/ — outputs go
# through temp_for_zip/ into image_descriptions.zip. Confirm whether it is still needed.
!mkdir -p image_descriptions

# Staging directory for files added to the zip; it is deleted once the zip is written
!mkdir -p temp_for_zip

# Function to apply different augmentations to an image
def augment_image(image, augmentation_type):
    """Return an augmented copy of a PIL image; the input image is never modified.

    Args:
        image: Source PIL image.
        augmentation_type: Which variant to produce:
            0 -> unchanged copy
            1 -> ``Image.rotate(10)``
            2 -> centre crop to 90% of the shorter side, resized back to the original size
            3 -> brightness enhanced by a factor of 1.2
            any other value -> contrast enhanced by a factor of 1.2

    Returns:
        A new PIL image.
    """
    img = image.copy()

    if augmentation_type == 0:
        return img
    if augmentation_type == 1:
        return img.rotate(10)
    if augmentation_type == 2:
        w, h = img.size
        side = int(min(w, h) * 0.9)
        x0 = (w - side) // 2
        y0 = (h - side) // 2
        return img.crop((x0, y0, x0 + side, y0 + side)).resize((w, h))
    if augmentation_type == 3:
        return ImageEnhance.Brightness(img).enhance(1.2)
    # Fallback for 4 (and any other value): contrast boost
    return ImageEnhance.Contrast(img).enhance(1.2)

# Build the zip archive in memory; it is written to disk after the loop
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:

    # Process each image
    for i in range(len(subset_x)):
        # Class name from the label (labels are 2-D, hence the [0])
        class_idx = subset_y[i][0]
        class_name = class_names[class_idx]

        # Filename pattern: image_<1-based index>_<class name>.png
        image_filename = f"image_{i+1}_{class_name}.png"

        # Convert numpy array to PIL Image
        image = PIL.Image.fromarray(subset_x[i])

        # Save the original image to the staging folder
        image_path = f"temp_for_zip/{image_filename}"
        image.save(image_path)

        # Add image to zip file (stored at the archive root)
        zip_file.write(image_path, arcname=image_filename)

        # Companion text file; the training cells look for "<image base name>_descriptions.txt"
        descriptions_filename = f"image_{i+1}_{class_name}_descriptions.txt"
        descriptions_path = f"temp_for_zip/{descriptions_filename}"

        with open(descriptions_path, 'w') as desc_file:
            # Header line; the training cells only parse lines of the form "Description <k>: <text>"
            desc_file.write(f"Descriptions for {image_filename}:\n\n")

            # Generate 5 captions, one per augmentation type 0-4, to get varied descriptions
            for j in range(5):
                # Apply different augmentation for each description
                modified_image = augment_image(image, j)

                # Generate caption (first result's "generated_text")
                caption_output = image_captioning_pipeline(modified_image)
                caption = caption_output[0]["generated_text"]

                # Write to file
                desc_file.write(f"Description {j+1}: {caption}\n")

                # Also print to console
                print(f"Image {i+1} ({class_name}) - Description {j+1}: {caption}")

        # Add descriptions file to zip (after the with-block, so the file is closed and flushed)
        zip_file.write(descriptions_path, arcname=descriptions_filename)

        # Display the image in the notebook
        plt.figure(figsize=(3, 3))
        plt.imshow(subset_x[i])
        plt.title(f"Image {i+1}: {class_name}")
        plt.axis('off')
        plt.show()

# Persist the in-memory archive to disk (getvalue() returns the whole buffer regardless of seek)
zip_buffer.seek(0)
with open('image_descriptions.zip', 'wb') as f:
    f.write(zip_buffer.getvalue())

# Clean up temporary files (the zip already holds copies)
!rm -rf temp_for_zip

# Download the zip file to the user's computer (google.colab helper; Colab only)
files.download('image_descriptions.zip')

# NOTE(review): the counts below are hard-coded; they are only accurate while the
# subset holds 10 images and 5 augmentations are used per image.
print("\nProcess complete! The zip file has been downloaded to your computer.")
print("The zip contains 10 images and 10 text files with 5 descriptions for each image.")

In [None]:
# Training a SigLIP model on CIFAR-10 images with descriptions
# For Google Colab execution - Using the ZIP file from the previous cell

import os
import zipfile
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torchvision.models import resnet18, ResNet18_Weights

# Install required packages
!pip install torch torchvision transformers tqdm matplotlib

# Set up device: GPU when available. Module-level global read by the functions below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Step 1: Use the zip file already generated from the previous cell
zip_path = "image_descriptions.zip"

# Fail fast if the captioning cell has not produced the archive
if not os.path.exists(zip_path):
    raise FileNotFoundError(f"Could not find {zip_path}. Make sure the previous cell completed successfully.")

print(f"Found existing zip file: {zip_path}")

# Extract into ./dataset (extractall overwrites files of the same name already there)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall("dataset")
    print("Extracted files to 'dataset' folder")

# Step 2: Create Dataset class to load images and descriptions
class ImageTextDataset(Dataset):
    """Image-caption pairs read from a folder produced by the captioning cell.

    For every ``<name>.png`` in ``folder_path`` that has a sibling
    ``<name>_descriptions.txt``, each line matching ``Description <k>: <text>``
    becomes one ``(png_filename, text)`` entry in ``self.samples``. PNGs
    without a description file are skipped.

    Args:
        folder_path: Directory holding the extracted PNGs and description files.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.samples = []

        # os.listdir order is arbitrary; sort so sample indices (used by the
        # index-based retrieval demos) are reproducible across runs/machines.
        image_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.png'))

        for img_file in image_files:
            base_name = img_file.rsplit('.', 1)[0]
            desc_path = os.path.join(folder_path, f"{base_name}_descriptions.txt")
            if not os.path.exists(desc_path):
                continue

            with open(desc_path, 'r') as f:
                for line in f:
                    # Header lines ("Descriptions for ...:") do not match and are skipped.
                    match = re.search(r'Description \d+: (.*)', line)
                    if match:
                        self.samples.append((img_file, match.group(1)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, text)``; ``image`` is transformed when a transform is set."""
        img_file, text = self.samples[idx]
        img_path = os.path.join(self.folder_path, img_file)

        # Context manager releases the file handle deterministically.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, text

# Step 3: Create the SigLIP model architecture
class TextEncoder(nn.Module):
    """Text tower: pretrained ``bert-base-uncased`` followed by a linear projection.

    The first-position ([CLS]) vector of BERT's last hidden layer is projected
    to ``output_dim``.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        # BERT-base hidden width (768) -> shared embedding width
        self.projection = nn.Linear(self.bert.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_embedding = hidden[:, 0]
        return self.projection(cls_embedding)

class ImageEncoder(nn.Module):
    """Image tower: ImageNet-pretrained ResNet-18 whose classifier head is
    replaced by a linear projection to ``output_dim``."""

    def __init__(self, output_dim=512):
        super().__init__()
        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Reuse the pooled-feature width of the original head (512 for ResNet-18)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, output_dim)

    def forward(self, x):
        return self.resnet(x)

class SigLIP(nn.Module):
    """Dual encoder scoring every image against every caption in a batch.

    ``forward`` returns a ``(num_images, num_texts)`` matrix of cosine
    similarities multiplied by a learnable temperature ``exp(logit_scale)``.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        self.image_encoder = ImageEncoder(output_dim)
        self.text_encoder = TextEncoder(output_dim)
        # Temperature kept in log space; starts at log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, images, input_ids, attention_mask):
        img_emb = F.normalize(self.image_encoder(images), dim=1)
        txt_emb = F.normalize(self.text_encoder(input_ids, attention_mask), dim=1)
        scaled_img = self.logit_scale.exp() * img_emb
        return scaled_img @ txt_emb.t()

# Step 4: Define the SigLIP loss function
class SigLIPLoss(nn.Module):
    """Pairwise sigmoid loss from SigLIP (Zhai et al., 2023).

    Every (image_i, text_j) pair is an independent binary classification:
    the diagonal pairs are positives (label +1), all others negatives (-1).
    As in the paper, the per-pair ``-log sigmoid(label * logit)`` terms are
    summed and divided by the batch size n. Averaging over all n*n pairs
    instead would scale the loss (and gradients) down by an extra factor of n.
    """

    def __init__(self):
        super(SigLIPLoss, self).__init__()

    def forward(self, logits):
        """
        Args:
            logits: ``(n, n)`` image-text logits; row i and column i form the matching pair.

        Returns:
            Scalar loss tensor (zero for an empty batch).
        """
        n = logits.shape[0]
        if n == 0:
            return logits.sum()  # zero, still attached to the graph

        # +1 on the diagonal (positives), -1 elsewhere (negatives)
        labels = 2 * torch.eye(n, device=logits.device, dtype=logits.dtype) - 1
        return -F.logsigmoid(labels * logits).sum() / n

# Step 5: Set up the training loop
def train_siglip(model, train_loader, optimizer, criterion, tokenizer, epochs=10):
    """Train a SigLIP dual encoder, then plot the mean loss of each epoch.

    Args:
        model: Called as ``model(images, input_ids, attention_mask)``; returns
            logits that ``criterion`` reduces to a scalar loss.
        train_loader: Iterable of ``(images, texts)`` batches; ``texts`` is passed
            directly to ``tokenizer``.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss module applied to the logits.
        tokenizer: Hugging Face tokenizer matching the text encoder.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, trained in place.
    """
    model.train()
    training_losses = []

    for epoch in range(1, epochs + 1):
        loss_total = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")

        for step, (images, texts) in enumerate(progress_bar, start=1):
            batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
            logits = model(
                images.to(device),
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
            )
            loss = criterion(logits)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_total += loss.item()
            # Running mean of the batch losses seen so far this epoch
            progress_bar.set_postfix(loss=loss_total / step)

        mean_loss = loss_total / len(train_loader)
        training_losses.append(mean_loss)
        print(f"Epoch {epoch}/{epochs}, Loss: {mean_loss:.6f}")

    # Loss curve: one point per epoch
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), training_losses, marker='o')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()

    return model

# Step 6: Function to test image retrieval
def test_image_retrieval(model, dataset, tokenizer):
    """Text -> image retrieval demo: for the first 5 captions, show the 3 best-matching images.

    Each distinct image file in ``dataset.samples`` is encoded once. Each query
    caption is encoded and ranked against all images by temperature-scaled
    cosine similarity. Matches are printed and displayed.

    Args:
        model: Trained SigLIP model (uses ``image_encoder``, ``text_encoder``, ``logit_scale``).
        dataset: Dataset exposing ``samples``, ``folder_path`` and a ``transform``
            that turns a PIL image into a tensor.
        tokenizer: Tokenizer matching the text encoder.
    """
    model.eval()

    # Distinct image files in first-seen order; every caption in sample order.
    image_files_list = []
    image_tensors = []
    all_texts = []
    seen = set()
    for img_file, text in dataset.samples:
        if img_file not in seen:
            seen.add(img_file)
            img_path = os.path.join(dataset.folder_path, img_file)
            with Image.open(img_path) as img:
                image_tensors.append(dataset.transform(img.convert('RGB')))
            image_files_list.append(img_file)
        all_texts.append(text)

    if not image_tensors:
        print("\nTesting image retrieval: dataset is empty, nothing to test.")
        return

    all_images = torch.stack(image_tensors).to(device)
    with torch.no_grad():
        image_features = F.normalize(model.image_encoder(all_images), dim=1)
        logit_scale = model.logit_scale.exp()

    # Query with the first 5 captions
    test_texts = all_texts[:5]

    print("\nTesting image retrieval:")
    for i, text in enumerate(test_texts):
        encodings = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
        input_ids = encodings['input_ids'].to(device)
        attention_mask = encodings['attention_mask'].to(device)

        with torch.no_grad():
            text_feature = F.normalize(model.text_encoder(input_ids, attention_mask), dim=1)
            # squeeze(0), not squeeze(): stays 1-D even when there is a single image
            similarities = logit_scale * (text_feature @ image_features.t()).squeeze(0)

        top_matches = torch.argsort(similarities, descending=True)[:3].tolist()

        print(f"\nFor text {i+1}: '{text}'")
        for rank, idx in enumerate(top_matches):
            img_file = image_files_list[idx]
            score = similarities[idx].item()
            print(f"  Match {rank+1}: {img_file} (score: {score:.4f})")

            # Display the matched image
            img_path = os.path.join(dataset.folder_path, img_file)
            with Image.open(img_path) as img:
                pixels = np.array(img.convert('RGB'))
            plt.figure(figsize=(3, 3))
            plt.imshow(pixels)
            plt.title(f"Match {rank+1}: {img_file}")
            plt.axis('off')
            plt.show()

# Step 7: Test text retrieval given an image
def test_text_retrieval(model, dataset, tokenizer):
    """Image -> text retrieval demo: for a few query samples, show the image and its top-3 captions.

    All captions in ``dataset.samples`` are encoded once. Query images come from
    sample indices 0, 10, 20, 30 and 40; indices past the end are skipped. Each
    query is ranked against every caption by temperature-scaled cosine similarity.

    Args:
        model: Trained SigLIP model (uses ``image_encoder``, ``text_encoder``, ``logit_scale``).
        dataset: Dataset exposing ``samples``, ``folder_path`` and ``__getitem__``
            returning ``(image_tensor, text)``.
        tokenizer: Tokenizer matching the text encoder.
    """
    model.eval()

    all_texts = [text for _, text in dataset.samples]
    all_image_files = [img_file for img_file, _ in dataset.samples]
    if not all_texts:
        print("\nTesting text retrieval: dataset is empty, nothing to test.")
        return

    # Encode every caption in one batch
    encodings = tokenizer(all_texts, padding=True, truncation=True, return_tensors="pt")
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)

    with torch.no_grad():
        text_features = F.normalize(model.text_encoder(input_ids, attention_mask), dim=1)
        logit_scale = model.logit_scale.exp()

    test_image_indices = [0, 10, 20, 30, 40]

    print("\nTesting text retrieval:")
    for idx in test_image_indices:
        if idx >= len(dataset):
            continue

        image, _ = dataset[idx]
        img_file = all_image_files[idx]

        # Display the query image
        img_path = os.path.join(dataset.folder_path, img_file)
        with Image.open(img_path) as img:
            pixels = np.array(img.convert('RGB'))
        plt.figure(figsize=(3, 3))
        plt.imshow(pixels)
        plt.title(f"Query Image: {img_file}")
        plt.axis('off')
        plt.show()

        with torch.no_grad():
            image_feature = F.normalize(model.image_encoder(image.unsqueeze(0).to(device)), dim=1)
            # squeeze(0), not squeeze(): stays 1-D even when there is a single caption
            similarities = logit_scale * (image_feature @ text_features.t()).squeeze(0)

        top_matches = torch.argsort(similarities, descending=True)[:3].tolist()

        print(f"\nFor image: {img_file}")
        for rank, text_idx in enumerate(top_matches):
            matched_text = all_texts[text_idx]
            score = similarities[text_idx].item()
            print(f"  Match {rank+1}: '{matched_text}' (score: {score:.4f})")

class _DistinctImageBatchSampler:
    """Batch sampler whose batches never hold two captions of the same image.

    SigLIPLoss treats only the diagonal of the (batch x batch) logit matrix as
    positive. Two captions of one image in the same batch would therefore be
    trained as a negative pair. Sample indices are grouped by image file; each
    "round" takes at most one caption per image, and each round is chunked into
    batches of ``batch_size``. Order is reshuffled on every ``__iter__``, i.e.
    every epoch.

    Args:
        samples: Sequence of ``(image_file, text)`` pairs (``dataset.samples``).
        batch_size: Maximum number of samples per batch (>= 1).
        seed: Seed for the private shuffling generator (global RNG is untouched).
    """

    def __init__(self, samples, batch_size, seed=0):
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.batch_size = batch_size
        self.generator = torch.Generator().manual_seed(seed)
        groups = {}
        for idx, (img_file, _) in enumerate(samples):
            groups.setdefault(img_file, []).append(idx)
        self.groups = list(groups.values())

    def _shuffled(self, items):
        # Random permutation of `items` drawn from the private generator.
        order = torch.randperm(len(items), generator=self.generator).tolist()
        return [items[i] for i in order]

    def _round_sizes(self):
        # Number of images still contributing a caption in each round.
        num_rounds = max((len(g) for g in self.groups), default=0)
        return [sum(1 for g in self.groups if len(g) > r) for r in range(num_rounds)]

    def __iter__(self):
        groups = [self._shuffled(g) for g in self.groups]
        batches = []
        for r, size in enumerate(self._round_sizes()):
            row = self._shuffled([g[r] for g in groups if len(g) > r])
            batches.extend(row[i:i + self.batch_size] for i in range(0, size, self.batch_size))
        return iter(self._shuffled(batches))

    def __len__(self):
        # ceil(round size / batch_size) batches per round
        return sum(-(-size // self.batch_size) for size in self._round_sizes())


# Step 8: Main function to run the training
def main():
    """Train SigLIP on ./dataset, save the weights to siglip_cifar10.pth, and run the retrieval demos."""
    # ImageNet-style preprocessing for the pretrained ResNet backbone (upsamples the small CIFAR images)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = ImageTextDataset(folder_path="dataset", transform=transform)
    print(f"Dataset size: {len(dataset)} image-text pairs")

    print("\nSample image-text pairs:")
    for i in range(min(5, len(dataset))):
        print(f"Example {i+1}: {dataset.samples[i][0]} - {dataset.samples[i][1]}")

    # The sigmoid loss needs negatives: with batch_size=1 the logit matrix is 1x1,
    # every example is a positive and nothing pushes non-matching pairs apart.
    # Batch several pairs, but never two captions of the same image (they would be
    # scored as negatives). 16 >= the 10 images captioned earlier, so each batch
    # holds one caption per image.
    batch_size = 16
    train_loader = DataLoader(
        dataset, batch_sampler=_DistinctImageBatchSampler(dataset.samples, batch_size)
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = SigLIP(output_dim=512).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    criterion = SigLIPLoss()

    print("\nStarting SigLIP training...")
    model = train_siglip(model, train_loader, optimizer, criterion, tokenizer, epochs=10)

    torch.save(model.state_dict(), "siglip_cifar10.pth")
    print("Model saved to siglip_cifar10.pth")

    # Qualitative checks in both retrieval directions
    test_image_retrieval(model, dataset, tokenizer)
    test_text_retrieval(model, dataset, tokenizer)

    print("\nTraining and evaluation complete!")

# Run the main function. In a Jupyter/IPython kernel __name__ is "__main__",
# so executing this cell starts training immediately.
if __name__ == "__main__":
    main()

In [None]:
!pip install torch torchvision transformers tqdm matplotlib peft bitsandbytes accelerate wandb


In [None]:
!pip install -U bitsandbytes # Ensure bitsandbytes is installed/updated
# Training a VLM (Vision + Phi3) with SigLIP vision encoder and frozen Phi-3 text decoder
# Optimized for smaller model size with QLoRA
# For Google Colab execution

import os
import zipfile
import re
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel
import bitsandbytes as bnb # Import the bitsandbytes library
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
import logging
import wandb
from datetime import datetime

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Install required packages
!pip install torch torchvision transformers tqdm matplotlib peft bitsandbytes accelerate wandb

# ... (Rest of the code)

In [None]:
# Training a VLM (Vision + Phi3) with SigLIP vision encoder and frozen Phi-3 text decoder
# Optimized for smaller model size with QLoRA
# For Google Colab execution

import os
import zipfile
import re
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel
import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
import logging
import wandb
from datetime import datetime

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Install required packages
!pip install torch torchvision transformers tqdm matplotlib peft bitsandbytes accelerate wandb

# Set up device: GPU when available. Module-level global read by the functions below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Fail fast if the SigLIP weights saved by the training cell are missing
siglip_path = "siglip_cifar10.pth"
if not os.path.exists(siglip_path):
    raise FileNotFoundError(f"Could not find {siglip_path}. Make sure the model file is available.")

logger.info(f"Found SigLIP model file: {siglip_path}")

# Check if the zip file exists (in case we need to use the dataset again)
# NOTE(review): this raises even when ./dataset already exists and the zip is not needed.
zip_path = "image_descriptions.zip"
if not os.path.exists(zip_path):
    raise FileNotFoundError(f"Could not find {zip_path}. Make sure the dataset file is available.")

# Extract only if ./dataset is absent; an existing (possibly stale) folder is reused as-is
if not os.path.exists("dataset"):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall("dataset")
        logger.info("Extracted files to 'dataset' folder")

# Define model classes from the previous code
class TextEncoder(nn.Module):
    """BERT (``bert-base-uncased``) text tower with a linear projection head.

    Must keep the same submodule names as the SigLIP training cell so its saved
    state dict loads.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        # BERT-base hidden width (768) -> shared embedding width
        self.projection = nn.Linear(self.bert.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # First ([CLS]) position summarises the sequence
        return self.projection(hidden[:, 0])

class ImageEncoder(nn.Module):
    """ResNet-18 image tower whose classifier is replaced by a projection to ``output_dim``.

    Must keep the same submodule names as the SigLIP training cell so its saved
    state dict loads.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        from torchvision.models import resnet18, ResNet18_Weights
        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Reuse the pooled-feature width of the original head (512 for ResNet-18)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, output_dim)

    def forward(self, x):
        return self.resnet(x)

class SigLIP(nn.Module):
    """SigLIP dual encoder with an optional vision-only forward path.

    ``forward(images, input_ids, attention_mask)`` returns the temperature-scaled
    image x text cosine-similarity matrix. When either text argument is ``None``,
    ``forward(images)`` returns the raw (un-normalised) image embeddings, which
    the VLM uses as its vision features.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        self.image_encoder = ImageEncoder(output_dim)
        self.text_encoder = TextEncoder(output_dim)
        # Temperature kept in log space; starts at log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, images, input_ids=None, attention_mask=None):
        if input_ids is None or attention_mask is None:
            # Vision-only path (VLM)
            return self.image_encoder(images)

        img_emb = F.normalize(self.image_encoder(images), dim=1)
        txt_emb = F.normalize(self.text_encoder(input_ids, attention_mask), dim=1)
        return (self.logit_scale.exp() * img_emb) @ txt_emb.t()

# Create VLM dataset class
class VLMDataset(Dataset):
    """Image-caption pairs for VLM training, read from the extracted caption folder.

    For every ``<name>.png`` in ``folder_path`` that has a sibling
    ``<name>_descriptions.txt``, each line matching ``Description <k>: <text>``
    becomes one ``(png_filename, text)`` entry in ``self.samples``. PNGs
    without a description file are skipped.

    Args:
        folder_path: Directory holding the extracted PNGs and description files.
        transform: Optional callable applied to each loaded RGB PIL image.
        max_length: Stored as ``self.max_length``; not used by this class itself.
    """

    def __init__(self, folder_path, transform=None, max_length=512):
        self.folder_path = folder_path
        self.transform = transform
        self.max_length = max_length
        self.samples = []

        # os.listdir order is arbitrary; sort so sample indices (and any
        # index-based split built on them) are reproducible.
        image_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.png'))

        for img_file in image_files:
            base_name = img_file.rsplit('.', 1)[0]
            desc_path = os.path.join(folder_path, f"{base_name}_descriptions.txt")
            if not os.path.exists(desc_path):
                continue

            with open(desc_path, 'r') as f:
                for line in f:
                    # Header lines ("Descriptions for ...:") do not match and are skipped.
                    match = re.search(r'Description \d+: (.*)', line)
                    if match:
                        self.samples.append((img_file, match.group(1)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, text)``; ``image`` is transformed when a transform is set."""
        img_file, text = self.samples[idx]
        img_path = os.path.join(self.folder_path, img_file)

        # Context manager releases the file handle deterministically.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, text

# Define the Vision-Language Model
class VLM(nn.Module):
    """Prefix-conditioned vision-language model.

    The frozen vision encoder's features are projected into the text decoder's
    embedding space. They are prepended to the token embeddings as one "image
    token", so every text position can attend to the image.

    Args:
        vision_encoder: Module mapping ``images`` to ``(batch, vision_dim)`` features
            (here the SigLIP model's vision-only path). Kept frozen and in eval mode.
        text_decoder: Causal LM exposing ``config.hidden_size`` and
            ``get_input_embeddings()``, and accepting ``inputs_embeds``.
        projection_dim: Unused; kept for backward compatibility. The projection's
            output width is taken from ``text_decoder.config.hidden_size``.
        vision_dim: Width of the vision features (512 for the SigLIP model here).
    """

    def __init__(self, vision_encoder, text_decoder, projection_dim=4096, vision_dim=512):
        super(VLM, self).__init__()
        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder

        # Maps vision features into the decoder's token-embedding space
        text_dim = self.text_decoder.config.hidden_size
        self.projection = nn.Linear(vision_dim, text_dim)

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen vision encoder in eval mode
        (e.g. so its BatchNorm running statistics are not updated by training batches)."""
        super().train(mode)
        self.vision_encoder.eval()
        return self

    def forward(self, images, input_ids, attention_mask):
        """
        Args:
            images: Batch accepted by ``vision_encoder``.
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask: ``(batch, seq_len)`` mask (1 = real token).

        Returns:
            Logits of shape ``(batch, seq_len, vocab)`` aligned with ``input_ids``.
            Position t predicts token t+1, as for a text-only LM, but every
            position is conditioned on the image prefix.
        """
        with torch.no_grad():  # vision encoder is frozen
            image_features = self.vision_encoder(images)

        token_embeds = self.text_decoder.get_input_embeddings()(input_ids)
        # One prefix "token" per image: (batch, 1, hidden), in the embeddings' dtype
        image_prefix = self.projection(image_features).unsqueeze(1).to(token_embeds.dtype)
        num_prefix = image_prefix.shape[1]

        inputs_embeds = torch.cat([image_prefix, token_embeds], dim=1)
        prefix_mask = attention_mask.new_ones((attention_mask.shape[0], num_prefix))
        full_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        outputs = self.text_decoder(inputs_embeds=inputs_embeds, attention_mask=full_mask)

        # Drop the prefix position(s) so logits line up with input_ids for the
        # caller's shifted next-token loss.
        return outputs.logits[:, num_prefix:, :]

def get_qlora_config():
    """LoRA adapter settings for the 4-bit Phi-3 decoder.

    Phi-3 fuses its projections. Attention uses ``qkv_proj`` + ``o_proj``, and
    the MLP uses ``gate_up_proj`` + ``down_proj``. Llama-style names such as
    ``q_proj``/``k_proj``/``v_proj``/``gate_proj``/``up_proj`` match nothing in
    Phi-3; with them, only ``o_proj`` and ``down_proj`` would receive adapters.

    Returns:
        A ``peft.LoraConfig`` for causal-LM fine-tuning.
    """
    return LoraConfig(
        r=4,  # low rank keeps the adapter small
        lora_alpha=16,
        target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

# Initialize the models and training components
def initialize_models():
    """Assemble the VLM: frozen SigLIP vision tower + 4-bit Phi-3 decoder with LoRA adapters.

    Reads the SigLIP weights from the module-level ``siglip_path`` and places the
    vision encoder and projection on the module-level ``device``.

    Returns:
        ``(vlm_model, tokenizer)``: the VLM and the Phi-3 tokenizer (pad token set to EOS).
    """
    # 1. Load the trained SigLIP model; only its vision path is used.
    # weights_only=True restricts unpickling to tensors/containers, so a tampered
    # checkpoint file cannot execute arbitrary code.
    vision_encoder = SigLIP(output_dim=512)
    vision_encoder.load_state_dict(torch.load(siglip_path, map_location=device, weights_only=True))
    vision_encoder = vision_encoder.to(device)
    vision_encoder.eval()
    # Frozen: keep its weights out of the trainable set (and any optimizer built from requires_grad)
    vision_encoder.requires_grad_(False)
    logger.info("SigLIP vision encoder loaded successfully")

    # 2. Phi-3 mini, quantized to 4-bit NF4 with bfloat16 compute
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model_id = "microsoft/phi-3-mini-4k-instruct"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding

    text_decoder = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto"
    )

    text_decoder = prepare_model_for_kbit_training(text_decoder)

    # get_peft_model already freezes every non-LoRA weight. Do NOT also freeze
    # text_decoder.base_model: on a PeftModel that is the LoRA wrapper, whose
    # parameters include the LoRA adapters, which would leave nothing to train.
    text_decoder = get_peft_model(text_decoder, get_qlora_config())

    # 3. Combine into the VLM
    vlm_model = VLM(vision_encoder, text_decoder)

    # Only the newly created projection head still needs moving: the vision encoder is
    # already on `device`, and the quantized decoder was placed by device_map="auto".
    vlm_model.projection.to(device)

    trainable_params = sum(p.numel() for p in vlm_model.parameters() if p.requires_grad)
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Total parameters: {sum(p.numel() for p in vlm_model.parameters()):,}")

    return vlm_model, tokenizer

# Training function with logging
def train_vlm(model, train_loader, optimizer, tokenizer, epochs=5):
    """Fine-tune the VLM with a next-token loss restricted to the caption tokens.

    Each caption is wrapped in a fixed user/assistant prompt. Prompt and padding
    positions are excluded from the loss (label -100). Per-step and per-epoch
    metrics go to Weights & Biases, and a checkpoint is saved after each epoch
    via ``save_efficient_model``.

    Args:
        model: VLM returning logits ``(batch, seq_len, vocab)`` aligned with
            ``input_ids``, and exposing ``vision_encoder``.
        train_loader: Iterable of ``(images, texts)`` batches.
        optimizer: Optimizer over the trainable parameters.
        tokenizer: Text-decoder tokenizer (must define a pad token).
        epochs: Number of passes over ``train_loader``.

    Returns:
        The trained ``model``.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    wandb.init(project="vlm-training", name=f"vlm-siglip-phi3-{datetime.now().strftime('%Y%m%d-%H%M%S')}")

    model.train()
    # The vision encoder is frozen; keep it in eval mode so e.g. its BatchNorm
    # running statistics are not updated by training batches.
    model.vision_encoder.eval()

    total_steps = epochs * steps_per_epoch
    loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100

    # The prompt is identical for every sample, so tokenize it once.
    prompt = "<|user|>\nDescribe this image.\n<|assistant|>\n"
    prompt_len = len(tokenizer(prompt, add_special_tokens=False)['input_ids'])
    bos_id = tokenizer.bos_token_id

    try:
        for epoch in range(epochs):
            running_loss = 0.0
            correct_predictions = 0
            total_predictions = 0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch_idx, (images, texts) in enumerate(progress_bar):
                # NOTE(review): Phi-3's published chat template appears to close turns with
                # <|end|> rather than </s> — confirm; kept as-is to match the evaluation prompt.
                formatted_texts = [f"{prompt}{text}</s>" for text in texts]

                # Pad only to the longest sequence in the batch; padded positions are
                # excluded from the loss below anyway.
                encodings = tokenizer(formatted_texts, padding=True, truncation=True,
                                      max_length=512, return_tensors="pt")
                input_ids = encodings['input_ids'].to(device)
                attention_mask = encodings['attention_mask'].to(device)

                # Train only on the response: mask padding and prompt tokens with -100.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                for i in range(labels.size(0)):
                    # Positions of real tokens; works for both left and right padding.
                    real_positions = attention_mask[i].nonzero(as_tuple=True)[0]
                    n_prompt = prompt_len
                    if (bos_id is not None and real_positions.numel() > 0
                            and input_ids[i, real_positions[0]].item() == bos_id):
                        n_prompt += 1  # tokenizer-added BOS precedes the prompt
                    labels[i, real_positions[:n_prompt]] = -100

                images = images.to(device)
                logits = model(images, input_ids, attention_mask)

                # Position t predicts token t+1
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

                # Token accuracy over supervised positions only (no autograd needed)
                with torch.no_grad():
                    flat_labels = shift_labels.view(-1)
                    active_mask = flat_labels != -100
                    batch_total = int(active_mask.sum().item())
                    batch_correct = 0
                    if batch_total > 0:
                        preds = shift_logits.view(-1, shift_logits.size(-1))[active_mask].argmax(dim=-1)
                        batch_correct = int((preds == flat_labels[active_mask]).sum().item())
                correct_predictions += batch_correct
                total_predictions += batch_total

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                progress_bar.set_postfix(
                    loss=running_loss/(batch_idx+1),
                    acc=correct_predictions/max(1, total_predictions)
                )

                step = epoch * steps_per_epoch + batch_idx
                wandb.log({
                    "train/loss": loss.item(),
                    "train/accuracy": batch_correct / max(1, batch_total),
                    "train/step": step,
                    "train/progress": step/total_steps
                })

            # End of epoch
            epoch_loss = running_loss / steps_per_epoch
            epoch_accuracy = correct_predictions / max(1, total_predictions)

            logger.info(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.6f}, Accuracy: {epoch_accuracy:.4f}")

            wandb.log({
                "train/epoch": epoch+1,
                "train/epoch_loss": epoch_loss,
                "train/epoch_accuracy": epoch_accuracy
            })

            # Save model checkpoint efficiently (only the necessary parts)
            checkpoint_path = f"efficient_vlm_checkpoint_epoch_{epoch+1}"
            save_efficient_model(model, tokenizer, path=checkpoint_path)
            logger.info(f"Efficient checkpoint saved to {checkpoint_path}")
    finally:
        # Close the W&B run even if training raises
        wandb.finish()

    return model

# Evaluation function
def evaluate_vlm(model, tokenizer, eval_dataset, num_samples=5):
    """Caption a few randomly chosen evaluation images and compare them with the references.

    For each selected sample this prints the reference and generated text, shows the
    input image, and collects the result. All results are also written to
    ``vlm_evaluation_results.json``.

    Args:
        model: VLM exposing ``vision_encoder``, ``projection`` and ``text_decoder``.
        tokenizer: Tokenizer used to encode the prompt and decode generated ids.
        eval_dataset: A ``torch.utils.data.Subset`` (e.g. from ``random_split``). Its
            ``.indices`` and the underlying ``.dataset.folder_path`` / ``.samples``
            are used to locate each image file on disk.
        num_samples: Maximum number of samples to caption (values <= 0 caption none).

    Returns:
        list[dict]: one ``{"image_path", "true_text", "generated_text"}`` dict per sample.
    """
    model.eval()
    logger.info("Generating image descriptions...")

    original_dataset = eval_dataset.dataset

    # Pick the random positions up front and iterate them in a fixed order, so the
    # sample being captioned and the file displayed/recorded are the same one.
    # (A shuffled DataLoader combined with eval_dataset.indices[i] mismatched them.)
    num_to_eval = max(0, min(num_samples, len(eval_dataset)))
    positions = torch.randperm(len(eval_dataset))[:num_to_eval].tolist()
    eval_loader = DataLoader(
        torch.utils.data.Subset(eval_dataset, positions), batch_size=1, shuffle=False
    )

    # The prompt is identical for every sample, so tokenize it only once.
    prompt = "<|user|>\nDescribe this image.\n<|assistant|>\n"
    prompt_inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = prompt_inputs.input_ids.to(device)
    attention_mask = prompt_inputs.attention_mask.to(device)

    results = []
    for i, ((image, true_text), position) in enumerate(zip(eval_loader, positions)):
        generated_text = _generate_description(
            model, tokenizer, image.to(device), input_ids, attention_mask
        )

        # batch_size=1: the collated reference text is a length-1 batch
        # (presumably a list with one str — TODO confirm against VLMDataset).
        reference_text = true_text[0]
        print(f"\nSample {i+1}:")
        print(f"True description: {reference_text}")
        print(f"Generated description: {generated_text}")

        # Map the position inside eval_dataset back to the underlying dataset index.
        original_idx = eval_dataset.indices[position]
        img_path = os.path.join(original_dataset.folder_path, original_dataset.samples[original_idx][0])
        _show_image(img_path)

        results.append({
            "image_path": img_path,
            "true_text": reference_text,
            "generated_text": generated_text,
        })

    with open("vlm_evaluation_results.json", "w") as results_file:
        json.dump(results, results_file, indent=2)

    return results


def _generate_description(model, tokenizer, image, input_ids, attention_mask):
    """Generate a caption for ``image`` conditioned on its projected features and the prompt."""
    with torch.no_grad():
        image_features = model.vision_encoder(image)
        projected_features = model.projection(image_features)

        # NOTE(review): assumes VLM.forward also prepends the projected image features
        # to the prompt token embeddings — confirm against the VLM class. Without this,
        # generation never sees the image at all.
        prompt_embeds = model.text_decoder.get_input_embeddings()(input_ids)
        if projected_features.dim() == 2:
            # One feature vector per image -> treat it as a single "image token".
            projected_features = projected_features.unsqueeze(1)
        projected_features = projected_features.to(
            device=prompt_embeds.device, dtype=prompt_embeds.dtype
        )
        inputs_embeds = torch.cat([projected_features, prompt_embeds], dim=1)

        # The image tokens must be attended to as well.
        image_mask = torch.ones(
            projected_features.shape[:2], dtype=attention_mask.dtype, device=attention_mask.device
        )
        full_attention_mask = torch.cat([image_mask, attention_mask], dim=1)

        outputs = model.text_decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention_mask,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            num_return_sequences=1,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the assistant's answer if the prompt was echoed back.
    if "<|assistant|>" in generated_text:
        generated_text = generated_text.split("<|assistant|>")[1]
    return generated_text.strip()


def _show_image(img_path):
    """Display the image stored at ``img_path`` and release the file and figure afterwards."""
    with Image.open(img_path) as img:
        rgb_pixels = np.array(img.convert('RGB'))
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(rgb_pixels)
    ax.set_title("Input Image")
    ax.axis('off')
    plt.show()
    plt.close(fig)

# Function to save efficient model (only the necessary components)
def save_efficient_model(model, tokenizer, path="efficient_vlm"):
    """Save the trainable parts of the VLM plus the metadata needed to rebuild it.

    Writes into ``path`` (created if missing):
      * ``projection_layer.pt`` - state dict of ``model.projection``
      * ``lora_weights/``       - output of ``model.text_decoder.save_pretrained``
        (only the adapter weights when the decoder is a PEFT model)
      * ``tokenizer/``          - output of ``tokenizer.save_pretrained``
      * ``config.json``         - architecture metadata read back by the loader

    The vision encoder is not written here.

    Args:
        model: VLM with ``projection`` (exposing ``in_features``/``out_features``)
            and ``text_decoder``.
        tokenizer: Tokenizer to store alongside the weights.
        path: Output directory.
    """
    os.makedirs(path, exist_ok=True)

    # 1. Projection layer (maps vision features into the text model's input space).
    torch.save(model.projection.state_dict(), os.path.join(path, "projection_layer.pt"))

    # 2. LoRA weights only (not the full base model).
    model.text_decoder.save_pretrained(os.path.join(path, "lora_weights"))

    # 3. Tokenizer.
    tokenizer.save_pretrained(os.path.join(path, "tokenizer"))

    # 4. Architecture metadata. LoRA settings are read from the decoder when possible
    #    so the file describes the adapter that was actually saved.
    config = {
        "vision_model": "SigLIP",
        "text_model": "microsoft/phi-3-mini-4k-instruct",
        "projection_dim": model.projection.out_features,
        "input_dim": model.projection.in_features,
        "quantization": "4bit",
        "lora_config": _describe_lora_config(model.text_decoder),
    }

    with open(os.path.join(path, "config.json"), "w") as config_file:
        json.dump(config, config_file, indent=2)

    logger.info(f"Efficient model saved to {path}")

    # 5. Report the on-disk size of everything written under `path`.
    total_size = sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, _dirnames, filenames in os.walk(path)
        for filename in filenames
    )

    logger.info(f"Total model size: {total_size / (1024 * 1024):.2f} MB")


def _describe_lora_config(text_decoder):
    """Return JSON-serialisable LoRA metadata (r, alpha, target_modules) for ``text_decoder``.

    Reads the adapter config from ``text_decoder.peft_config`` (a dict of adapter
    name -> config on PEFT models). Falls back to the values that used to be
    hard-coded when no PEFT config is available.
    """
    fallback = {"r": 4, "alpha": 16, "target_modules": ["q_proj", "v_proj"]}

    peft_configs = getattr(text_decoder, "peft_config", None)
    if not isinstance(peft_configs, dict) or not peft_configs:
        return fallback

    active = getattr(text_decoder, "active_adapter", None)
    if isinstance(active, str) and active in peft_configs:
        adapter_config = peft_configs[active]
    elif len(peft_configs) == 1:
        adapter_config = next(iter(peft_configs.values()))
    else:
        adapter_config = peft_configs.get("default")
    if adapter_config is None:
        return fallback

    target_modules = getattr(adapter_config, "target_modules", fallback["target_modules"])
    if isinstance(target_modules, (set, frozenset, tuple)):
        # Sets are not JSON-serialisable; sort for a stable, diff-friendly file.
        target_modules = sorted(target_modules)

    return {
        "r": getattr(adapter_config, "r", fallback["r"]),
        "alpha": getattr(adapter_config, "lora_alpha", fallback["alpha"]),
        "target_modules": target_modules,
    }

# Function to load efficient model
def load_efficient_model(path="efficient_vlm", vision_weights_path="siglip_cifar10.pth"):
    """Rebuild a VLM from a directory written by ``save_efficient_model``.

    Args:
        path: Directory containing ``config.json``, ``tokenizer/``, ``lora_weights/``
            and ``projection_layer.pt``.
        vision_weights_path: State-dict file for the SigLIP vision encoder, which is
            not stored in ``path``. Defaults to the previously hard-coded file name.

    Returns:
        tuple: ``(vlm_model, tokenizer)``.
    """
    with open(os.path.join(path, "config.json"), "r") as config_file:
        config = json.load(config_file)

    # Vision encoder: size taken from the saved config, weights from a separate file.
    vision_encoder = SigLIP(output_dim=config["input_dim"])
    vision_encoder.load_state_dict(torch.load(vision_weights_path, map_location=device))
    # load_state_dict copies values into the module's existing parameters; it does not
    # move the module, so place it on `device` explicitly to match inputs sent there.
    vision_encoder.to(device)
    vision_encoder.eval()

    # 4-bit NF4 quantization (with double quantization) for the base text model.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    tokenizer = AutoTokenizer.from_pretrained(os.path.join(path, "tokenizer"))

    text_decoder = AutoModelForCausalLM.from_pretrained(
        config["text_model"],
        quantization_config=bnb_config,
        device_map="auto"
    )

    # Attach the saved LoRA adapter on top of the quantized base model.
    text_decoder = PeftModel.from_pretrained(text_decoder, os.path.join(path, "lora_weights"))

    vlm_model = VLM(vision_encoder, text_decoder)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
    projection_state = torch.load(os.path.join(path, "projection_layer.pt"), map_location=device)
    vlm_model.projection.load_state_dict(projection_state)
    # NOTE(review): harmless no-op if VLM already places the projection on `device`.
    vlm_model.projection.to(device)

    return vlm_model, tokenizer

# Main function
def main():
    """Run the pipeline end to end.

    Steps:
      1. Optional wandb login.
      2. Build ``VLMDataset`` from ``dataset/``.
      3. Make a seeded 80/20 train/eval split.
      4. Train for 5 epochs.
      5. Save to ``efficient_vlm_final``.
      6. Caption 5 eval samples.
      7. Reload the saved model as a smoke test.
    """
    # wandb is optional. On failure, disable it so later wandb.init/wandb.log calls become
    # no-ops instead of prompting or failing inside training.
    try:
        wandb.login()
    except Exception as exc:  # not a bare except: keep KeyboardInterrupt/SystemExit working
        logger.warning("Could not log in to Weights & Biases. Continuing without logging.")
        logger.debug(f"wandb.login() failed: {exc!r}")
        os.environ["WANDB_MODE"] = "disabled"

    # Resize + CenterCrop(224) gives 224x224 inputs, normalised with ImageNet mean/std.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = VLMDataset(folder_path="dataset", transform=transform)
    logger.info(f"Dataset size: {len(dataset)} image-text pairs")

    # 80% train / 20% eval, with a fixed seed so the eval set is reproducible across runs.
    split_seed = 42
    train_size = int(0.8 * len(dataset))
    eval_size = len(dataset) - train_size
    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, eval_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    batch_size = 8  # Increased batch size since we're using less memory with QLoRA
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    vlm_model, tokenizer = initialize_models()

    # Only the projection layer and the LoRA weights are meant to be trained. Hand AdamW
    # just the decoder parameters that require gradients rather than every base weight.
    trainable_decoder_params = [
        param for param in vlm_model.text_decoder.parameters() if param.requires_grad
    ]
    optimizer = torch.optim.AdamW([
        {'params': vlm_model.projection.parameters(), 'lr': 1e-4},
        {'params': trainable_decoder_params, 'lr': 5e-5},
    ])

    logger.info("Starting VLM training...")
    vlm_model = train_vlm(vlm_model, train_loader, optimizer, tokenizer, epochs=5)

    save_efficient_model(vlm_model, tokenizer, path="efficient_vlm_final")
    logger.info("Efficient model saved to efficient_vlm_final")

    logger.info("Evaluating VLM...")
    evaluate_vlm(vlm_model, tokenizer, eval_dataset, num_samples=5)

    # Demonstrate that the saved artifacts are sufficient to rebuild the model.
    loaded_model, loaded_tokenizer = load_efficient_model(path="efficient_vlm_final")
    logger.info("Successfully loaded the efficient model")

    logger.info("Training and evaluation complete!")

# Runs the full pipeline when executed as a script. In a Jupyter kernel __name__ is also
# "__main__", so executing this cell starts training.
if __name__ == "__main__":
    main()