In [1]:
# ## Final  Backbone

# +---------------------+      +-------------------+
# |     Input Image     |      |    Input Prompt   |
# +---------------------+      +-------------------+
#            |                         |
#            v                         v
# +---------------------+   +-------------------------+
# | Image Processor     |   | Tokenizer (Qwen)       |  [Pre-processing]
# | (SigLIP Processor)  |   | (model.tokenizer_...) |
# +---------------------+   +-------------------------+
#            |                         |
#            v                         v
#     pixel_values                prompt_input_ids
#            |                         |
# +----------+----------+              |
# | SiglipVisionModel   |              |
# | (Image Encoder)     |              |
# | [Frozen]            |              |
# +---------------------+              |
#            |                         |
#            v                         v
#  image_hidden_state        +--------------------------+
#  (Batch, ImgSeq, ImgDim)   | text_model.model.        |
#                            |   embed_tokens           | [Part of Frozen Gemma]
#                            | (Qwen Embeddings)        |
#                            +--------------------------+
#            |                         |
# +----------+----------+              |
# | image_projection    |              |
# | (Linear Layer)      |              |
# | [Trainable]         |              |
# +---------------------+              |
#            |                         |
#            v                         v
#  projected_image_embeds        prompt_embeds
#  (Batch, ImgSeq, TextDim)    (Batch, PromptSeq, TextDim)
#            |                         |
#            +-----------+-------------+
#                        |
#                        v
#       [Concatenate Embeddings & Masks]  <-- Creates combined_embeds & combined_mask
#                        |
# +----------------------V---------------------------------------------------+
# | text_model.generate()                                                    |
# |  - Uses Gemma Transformer Blocks                                         |
# |  - Autoregressive Generation Logic                                       |
# |  [Underlying text_model is Frozen]                                       |
# |  (Takes inputs_embeds=combined_embeds, attention_mask=combined_mask)   |
# +--------------------------------------------------------------------------+
#                        |
#                        v
#                Generated Token IDs
#                        |
# +----------------------V----------------------+
# | Tokenizer.decode (Qwen)                     |  [Post-processing]
# | (model.tokenizer_instance.decode)           |
# +---------------------------------------------+
#                        |
#                        v
# +---------------------------------------------+
# |               Final Output Text             |
# +---------------------------------------------+


# =========================================
# Component Status Summary (During Training):
# -----------------------------------------
# * [Frozen]: SiglipVisionModel (Image Encoder Backbone)
# * [Frozen]: AutoModelForCausalLM (Gemma Text Model - including transformer blocks and embeddings)
# * [Trainable]: image_projection (Linear layer mapping image features to text embedding dimension)
# * [Trainable]: logit_scale (Scalar parameter for SigLIP contrastive loss)
# * [Trainable]: logit_bias (Optional scalar parameter for SigLIP contrastive loss, if used)

# Note: During inference as shown above, only the forward paths of the frozen/trainable components are used. The `.generate()` method leverages the frozen Gemma transformer blocks. The trainable `logit_scale`/`logit_bias` are part of the model definition but aren't directly involved in the `.generate()` call itself.
# =========================================

SyntaxError: invalid syntax (<ipython-input-1-913967fd6b02>, line 1)

graph TD
    A[Input Image] --> B["Image Processor (SigLIP)"];
    C[Input Prompt] --> D["Tokenizer (Gemma)"];

    B -- pixel_values --> E["SiglipVisionModel [Frozen]<br/>(Image Encoder)"];
    D -- prompt_input_ids --> F["text_model.embed_tokens [Frozen]<br/>(Gemma Embedding Layer)"];

    E -- image_hidden_state --> G["image_projection [Trainable]<br/>(Linear Layer)"];
    F -- prompt_embeds --> H["Concatenate Embeddings & Masks"];
    G -- projected_image_embeds --> H;

    H -- combined_embeds, combined_mask --> I["text_model.generate() [Frozen]<br/>(Gemma Transformer Blocks &<br/>Generation Logic)"];

    I -- Generated Token IDs --> J["Tokenizer.decode (Gemma)"];
    J --> K[Final Output Text];

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    SiglipVisionModel, # Import the specific vision model class
    SiglipConfig,      # Helpful for getting config details
)


from typing import Optional, Dict, Tuple
import os # For securely handling tokens

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

Mounted at /content/drive


### For later


In [3]:
# --- Helper function for SigLIP Loss ---
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/modeling_siglip.py
def siglip_loss(
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: nn.Parameter, # Learned temperature parameter
    logit_bias: Optional[nn.Parameter] = None, # Optional learned bias
) -> torch.Tensor:
    """
    Computes the SigLIP loss.

    Args:
        image_features: Tensor of shape (batch_size, embed_dim)
        text_features: Tensor of shape (batch_size, embed_dim)
        logit_scale: A learnable parameter (equivalent to 1/temperature)
        logit_bias: An optional learnable parameter

    Returns:
        The SigLIP loss.
    """
    # Normalize features
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # Calculate pairwise similarity logits
    logits_per_image = logit_scale.exp() * torch.matmul(image_features, text_features.t())
    if logit_bias is not None:
        logits_per_image += logit_bias

    logits_per_text = logits_per_image.t()

    # Create labels (positive pairs are on the diagonal)
    labels = torch.diag(torch.ones(image_features.size(0), device=image_features.device, dtype=torch.long))

    # Binary cross-entropy with logits loss
    # Convert labels to float for BCE loss
    # Positive labels: 1.0 on the diagonal, 0.0 elsewhere
    # Negative labels: 0.0 on the diagonal, 1.0 elsewhere (implicitly handled by BCE)
    positive_labels = labels.float()
    negative_labels = 1.0 - positive_labels # Not strictly needed for BCEWithLogitsLoss

    # Calculate loss for image-to-text and text-to-image
    # Apply sigmoid and compute binary cross entropy: -[p*log(sigmoid(z)) + (1-p)*log(1-sigmoid(z))]
    # where p is the label (0 or 1) and z is the logit.
    # BCEWithLogitsLoss combines sigmoid and BCE for numerical stability.
    loss_img = F.binary_cross_entropy_with_logits(
        logits_per_image, positive_labels, reduction="mean"
    )
    loss_txt = F.binary_cross_entropy_with_logits(
        logits_per_text, positive_labels, reduction="mean"
    )

    # Average the two losses
    loss = (loss_img + loss_txt) / 2.0
    return loss


In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image # Import PIL Image

# Assume VisionLanguageModel class from previous response is defined here or imported

# --- CIFAR Loading Function ---
def load_cifar10(root_dir: str, train: bool = False):
    """Downloads/Loads CIFAR-10 dataset."""
    print(f"Loading CIFAR-10 {'train' if train else 'test'} dataset...")
    # Basic transform to get PIL Images, processor will handle final transforms
    cifar_transform = transforms.Compose([
        # transforms.ToTensor() # Remove ToTensor, keep as PIL for processor
    ])
    try:
        dataset = torchvision.datasets.CIFAR10(
            root=root_dir,
            train=train,
            download=True,
            transform=cifar_transform # Keep as PIL
        )
        print("CIFAR-10 loaded successfully.")
        return dataset
    except Exception as e:
        print(f"Error loading CIFAR-10: {e}")
        raise

# --- Custom Dataset ---
class CustomImageTextDataset(Dataset):
    def __init__(self, metadata_file: str, cifar_dataset: torchvision.datasets.CIFAR10):
        """
        Args:
            metadata_file (str): Path to the CSV file with 'image_index' and 'generated_text'.
            cifar_dataset (Dataset): The loaded CIFAR-10 dataset (returning PIL images).
        """
        print(f"Loading custom metadata from: {metadata_file}")
        self.metadata_df = pd.read_csv(metadata_file)
        # Optional: Apply any filtering if needed, e.g., your previous iloc[17:]
        # self.metadata_df = self.metadata_df.iloc[17:].reset_index(drop=True)
        print(f"Loaded metadata for {len(self.metadata_df)} samples initially.")

        self.cifar_dataset = cifar_dataset
        self.valid_indices = [] # Store indices of successfully validated items

        # Validate entries during initialization
        for idx in range(len(self.metadata_df)):
            try:
                row = self.metadata_df.iloc[idx]
                image_index = int(row['image_index'])
                # Check if image index is valid for the loaded CIFAR set
                if 0 <= image_index < len(self.cifar_dataset):
                    _ = self.cifar_dataset[image_index] # Try accessing the image
                    if isinstance(row['generated_text'], str) and row['generated_text'].strip():
                         self.valid_indices.append(idx)
                    # else:
                    #     print(f"Warning: Skipping metadata row {idx} due to invalid text.")
                # else:
                #      print(f"Warning: Skipping metadata row {idx} due to invalid image_index: {image_index}")
            except Exception as e:
                print(f"Warning: Error validating metadata row {idx}: {e}. Skipping.")

        if not self.valid_indices:
             raise ValueError("No valid samples found after validation.")
        print(f"Found {len(self.valid_indices)} valid samples after validation.")


    def __len__(self):
        # Length is the number of *valid* samples
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # Map the input index to the corresponding index in the *original* dataframe
        metadata_idx = self.valid_indices[idx]
        item_info = self.metadata_df.iloc[metadata_idx]

        image_index = int(item_info['image_index'])
        text = str(item_info['generated_text']) # Ensure text is string

        # Get the PIL image from CIFAR dataset
        image_pil, _ = self.cifar_dataset[image_index] # Assuming CIFAR returns (PIL Image, label)

        # Return raw data: PIL Image and text string
        return {'image': image_pil, 'text': text}

# --- Collate Function ---
def create_collate_fn(image_processor, tokenizer, max_length=512):
    """Creates a collate function for batching images and text."""
    def collate_fn(batch):
        # Filter out None items potentially returned by dataset (though validation should prevent this)
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        images = [item['image'] for item in batch]
        texts = [item['text'] for item in batch]

        try:
            # Process images: applies transforms, normalization, converts to Tensor
            image_inputs = image_processor(images=images, return_tensors="pt")

            # Process texts: tokenizes, pads/truncates, creates attention mask
            text_inputs = tokenizer(
                text=texts,
                return_tensors="pt",
                padding="max_length", # Pad to max_length
                truncation=True,
                max_length=max_length
            )

            return {
                "pixel_values": image_inputs['pixel_values'],
                "input_ids": text_inputs['input_ids'],
                "attention_mask": text_inputs['attention_mask']
            }
        except Exception as e:
            print(f"Error during collation: {e}")
            # Depending on severity, might want to return None or raise
            return None # Skip batch if collation fails

    return collate_fn

# --- DataLoader Creation Function ---
def create_dataloader(dataset, batch_size, image_processor, tokenizer, max_text_length=512, num_workers=4):
     collate_fn = create_collate_fn(image_processor, tokenizer, max_length=max_text_length)
     return DataLoader(
         dataset,
         batch_size=batch_size,
         shuffle=True,
         collate_fn=collate_fn,
         num_workers=num_workers, # Use multiple workers for faster loading
         pin_memory=True if torch.cuda.is_available() else False # Improves GPU transfer speed
     )

In [None]:
import torch
import torch.optim as optim
from transformers import AutoImageProcessor, AutoTokenizer, get_cosine_schedule_with_warmup
from tqdm.auto import tqdm # Use auto version for notebook/console compatibility
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    SiglipVisionModel, # Import the specific vision model class
    SiglipConfig,      # Helpful for getting config details
)

from torch.cuda.amp import autocast, GradScaler # Import


from typing import Optional, Dict, Tuple
import os # For securely handling tokens



# Assume VisionLanguageModel class is defined above or imported
# Assume dataset/dataloader functions are defined above

# --- Configuration ---
CONFIG = {
    "model": {
        "image_encoder_name": "google/siglip-base-patch16-224",
        "text_model_name": "Qwen/Qwen2.5-0.5B", # ADJUST IF NEEDED
        "max_position_embeddings": 2048, # Match Gemma if possible, or use desired max sequence length
        "use_learned_siglip_params": True,
    },
    "data": {
        "cifar_root_dir": "./cifar_data", # Directory to download/store CIFAR
        "metadata_file": "/content/drive/MyDrive/EAG/a23-vlm/data/cifar10_smolvlm2_results.csv", # Path to your CSV
        "cifar_train_set": False, # Use False for CIFAR test set, True for train set
        "max_text_length": 64, # Max sequence length for tokenizer padding/truncation
    },
    "training": {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "batch_size": 16, # Adjust based on GPU memory
        "epochs": 50, # Number of training epochs
        "learning_rate": 1e-5, # Starting learning rate
        "weight_decay": 0.01, # Weight decay for optimizer
        "warmup_steps": 100, # Steps for learning rate warmup
        "gradient_accumulation_steps": 32, # Use > 1 for larger effective batch size if memory constrained
        "max_grad_norm": 1.0, # Gradient clipping norm
        "output_dir": "/content/drive/MyDrive/EAG/a23-vlm/model/a23-vlm", # Directory to save checkpoints
        "log_steps": 50, # Log loss every N steps
        "save_steps": 500, # Save checkpoint every N steps (optional)
    },
    "hf_token": '', # Load token from environment variable
}

# --- Main Training Function ---
def train():
    print("Starting Training Process...")
    print(f"Using device: {CONFIG['training']['device']}")
    os.makedirs(CONFIG['training']['output_dir'], exist_ok=True)

    # 1. Load Data Components
    print("\n--- Loading Data ---")
    cifar_dataset = load_cifar10(CONFIG['data']['cifar_root_dir'], train=CONFIG['data']['cifar_train_set'])
    custom_dataset = CustomImageTextDataset(
        metadata_file=CONFIG['data']['metadata_file'],
        cifar_dataset=cifar_dataset
    )

    # 2. Load Processor and Tokenizer
    print("\n--- Loading Processor & Tokenizer ---")
    image_processor = AutoImageProcessor.from_pretrained(
        CONFIG['model']['image_encoder_name'], token=CONFIG['hf_token']
    )
    tokenizer = AutoTokenizer.from_pretrained(
        CONFIG['model']['text_model_name'], token=CONFIG['hf_token']
    )
    # Add padding token if missing (common for Gemma)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        # Important: Resize model embeddings if adding tokens AFTER model init
        # Do this *before* initializing the model that uses the tokenizer if needed

    # 3. Create DataLoader
    print("\n--- Creating DataLoader ---")
    train_dataloader = create_dataloader(
        custom_dataset,
        batch_size=CONFIG['training']['batch_size'],
        image_processor=image_processor,
        tokenizer=tokenizer,
        max_text_length=CONFIG['data']['max_text_length']
    )

    # 4. Initialize Model
    print("\n--- Initializing Model ---")
    model = VisionLanguageModel(
        image_encoder_name=CONFIG['model']['image_encoder_name'],
        text_model_name=CONFIG['model']['text_model_name'],
        max_position_embeddings=CONFIG['model']['max_position_embeddings'],
        use_learned_siglip_params=CONFIG['model']['use_learned_siglip_params']
    )
    model.to(CONFIG['training']['device'])

    # Resize embeddings *if* pad token was added *after* text model was loaded inside VisionLanguageModel
    # It's generally safer to ensure tokenizer has pad token *before* model init
    # text_model_vocab_size = model.text_model.get_input_embeddings().weight.size(0)
    # if len(tokenizer) > text_model_vocab_size:
    #    model.text_model.resize_token_embeddings(len(tokenizer))
    #    print(f"Resized text model token embeddings to {len(tokenizer)}")

    # 5. Initialize Optimizer and Scheduler
    print("\n--- Setting up Optimizer & Scheduler ---")
    # Filter parameters that require gradients
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    print(f"Optimizing {len(trainable_params)} trainable parameters.")
    optimizer = optim.AdamW(
        trainable_params,
        lr=CONFIG['training']['learning_rate'],
        weight_decay=CONFIG['training']['weight_decay']
    )

    num_training_steps = len(train_dataloader) // CONFIG['training']['gradient_accumulation_steps'] * CONFIG['training']['epochs']
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=CONFIG['training']['warmup_steps'],
        num_training_steps=num_training_steps
    )

    # 6. Training Loop
    print("\n--- Starting Training ---")
    total_steps = 0
    model.train() # Set model to training mode

    from torch.cuda.amp import autocast, GradScaler # Import

    # ... inside the train() function ...

    # Initialize GradScaler
    scaler = GradScaler(enabled=(CONFIG['training']['device'] == 'cuda'))

    for epoch in range(CONFIG['training']['epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['training']['epochs']}")
        epoch_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", leave=False)

        for step, batch in enumerate(progress_bar):
            if batch is None: # Skip batch if collation failed
                print(f"Warning: Skipping None batch at step {step}")
                continue

            # Move batch to device
            try:
                 batch = {k: v.to(CONFIG['training']['device']) for k, v in batch.items()}
            except AttributeError:
                 print(f"Warning: Could not move batch to device at step {step}. Skipping.")
                 continue # Skip if batch items aren't tensors

            # Forward pass
            # Use torch.cuda.amp for mixed precision (optional, requires accelerate)
            # with torch.cuda.amp.autocast(enabled=CONFIG['training']['device'] == 'cuda'):
            with autocast(enabled=(CONFIG['training']['device'] == 'cuda')):

                outputs = model(**batch, return_loss=True)
                if outputs is None or 'loss' not in outputs or outputs['loss'] is None:
                    print(f"Warning: Invalid output or loss not found at step {step}. Skipping.")
                    continue

                loss = outputs['loss']

                # Scale loss for gradient accumulation
                loss = loss / CONFIG['training']['gradient_accumulation_steps']

            # Backward pass
            scaler.scale(loss).backward() # If using mixed precision scaler
            # loss.backward()

            # Accumulate gradients
            if (step + 1) % CONFIG['training']['gradient_accumulation_steps'] == 0:
                # Gradient Clipping
                if CONFIG['training']['max_grad_norm'] > 0:
                    scaler.unscale_(optimizer) # If using mixed precision scaler
                    torch.nn.utils.clip_grad_norm_(
                        trainable_params, CONFIG['training']['max_grad_norm']
                    )

                # Optimizer step
                scaler.step(optimizer) # If using mixed precision scaler
                # optimizer.step()

                scaler.update() # If using mixed precision scaler

                # Scheduler step
                lr_scheduler.step()

                # Zero gradients
                optimizer.zero_grad()

                total_steps += 1

                # Logging
                if total_steps % CONFIG['training']['log_steps'] == 0:
                    current_loss = loss.item() * CONFIG['training']['gradient_accumulation_steps'] # Unscale loss for logging
                    current_lr = lr_scheduler.get_last_lr()[0]
                    print(f"  Step: {total_steps}, Loss: {current_loss:.4f}, LR: {current_lr:.2e}")

                # Save checkpoint (optional)
                if CONFIG['training']['save_steps'] > 0 and total_steps % CONFIG['training']['save_steps'] == 0:
                     ckpt_path = os.path.join(CONFIG['training']['output_dir'], f"checkpoint-{total_steps}")
                     # Save only trainable parts or the whole model state_dict
                     # model.save_pretrained(ckpt_path) # Preferred if using HF Trainer compatible structure
                     torch.save(model.state_dict(), os.path.join(ckpt_path, "pytorch_model.bin"))
                     tokenizer.save_pretrained(ckpt_path)
                     image_processor.save_pretrained(ckpt_path)
                     print(f"  Checkpoint saved to {ckpt_path}")


            # Update progress bar description with current loss
            progress_bar.set_postfix(loss=f"{loss.item()*CONFIG['training']['gradient_accumulation_steps']:.4f}") # Show unscaled loss
            epoch_loss += loss.item() * CONFIG['training']['gradient_accumulation_steps']

        avg_epoch_loss = epoch_loss / len(train_dataloader)
        print(f"End of Epoch {epoch+1}, Average Loss: {avg_epoch_loss:.4f}")

    # 7. Save Final Model
    print("\n--- Training Finished ---")
    final_path = os.path.join(CONFIG['training']['output_dir'], "final_model")
    # model.save_pretrained(final_path) # Preferred HF method
    os.makedirs(final_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(final_path, "pytorch_model.bin"))
    tokenizer.save_pretrained(final_path)
    image_processor.save_pretrained(final_path)
    print(f"Final model saved to {final_path}")


In [None]:
train()


Starting Training Process...
Using device: cuda

--- Loading Data ---
Loading CIFAR-10 test dataset...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


CIFAR-10 loaded successfully.
Loading custom metadata from: /content/cifar_data/cifar10_smolvlm2_results.csv
Loaded metadata for 65 samples initially.
Found 65 valid samples after validation.

--- Loading Processor & Tokenizer ---

--- Creating DataLoader ---

--- Initializing Model ---


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.



--- Setting up Optimizer & Scheduler ---
Optimizing 3 trainable parameters.

--- Starting Training ---

Epoch 1/50


  scaler = GradScaler(enabled=(CONFIG['training']['device'] == 'cuda'))


Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]

  with autocast(enabled=(CONFIG['training']['device'] == 'cuda')):


End of Epoch 1, Average Loss: 0.6083

Epoch 2/50


Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
     Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
     self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive(): 
        ^ ^^^^^ ^^^ ^^^^^
^^  File "

End of Epoch 2, Average Loss: 0.6907

Epoch 3/50


Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 3, Average Loss: 0.5978

Epoch 4/50


Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 4, Average Loss: 0.6520

Epoch 5/50


Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 5, Average Loss: 0.7615

Epoch 6/50


Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 6, Average Loss: 0.6641

Epoch 7/50


Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 7, Average Loss: 0.6336

Epoch 8/50


Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

End of Epoch 8, Average Loss: 0.6168

Epoch 9/50


Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 9, Average Loss: 0.6827

Epoch 10/50


Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 10, Average Loss: 0.6936

Epoch 11/50


Epoch 11:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 11, Average Loss: 0.6366

Epoch 12/50


Epoch 12:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 12, Average Loss: 0.7450

Epoch 13/50


Epoch 13:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 13, Average Loss: 0.7672

Epoch 14/50


Epoch 14:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 14, Average Loss: 0.6526

Epoch 15/50


Epoch 15:   0%|          | 0/5 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>    if w.is_alive():

Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers()  
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
     ^if w.is_alive():^
^^^^ ^^ ^ ^^ ^
   ^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
^^^^^^^    
^^^^assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/lib/python3

End of Epoch 15, Average Loss: 0.5858

Epoch 16/50


Epoch 16:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 16, Average Loss: 0.6745

Epoch 17/50


Epoch 17:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 17, Average Loss: 0.6986

Epoch 18/50


Epoch 18:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 18, Average Loss: 0.6870

Epoch 19/50


Epoch 19:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 19, Average Loss: 0.6539

Epoch 20/50


Epoch 20:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 20, Average Loss: 0.6985

Epoch 21/50


Epoch 21:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 21, Average Loss: 0.6411

Epoch 22/50


Epoch 22:   0%|          | 0/5 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

End of Epoch 22, Average Loss: 0.6667

Epoch 23/50


Epoch 23:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 23, Average Loss: 0.6698

Epoch 24/50


Epoch 24:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 24, Average Loss: 0.6521

Epoch 25/50


Epoch 25:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 25, Average Loss: 0.6835

Epoch 26/50


Epoch 26:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 26, Average Loss: 0.6005

Epoch 27/50


Epoch 27:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 27, Average Loss: 0.6568

Epoch 28/50


Epoch 28:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 28, Average Loss: 0.6523

Epoch 29/50


Epoch 29:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 29, Average Loss: 0.6357

Epoch 30/50


Epoch 30:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 30, Average Loss: 0.7004

Epoch 31/50


Epoch 31:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 31, Average Loss: 0.6522

Epoch 32/50


Epoch 32:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 32, Average Loss: 0.5980

Epoch 33/50


Epoch 33:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 33, Average Loss: 0.6664

Epoch 34/50


Epoch 34:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 34, Average Loss: 0.6244

Epoch 35/50


Epoch 35:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 35, Average Loss: 0.6239

Epoch 36/50


Epoch 36:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 36, Average Loss: 0.6874

Epoch 37/50


Epoch 37:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 37, Average Loss: 0.6010

Epoch 38/50


Epoch 38:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 38, Average Loss: 0.6868

Epoch 39/50


Epoch 39:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 39, Average Loss: 0.7513

Epoch 40/50


Epoch 40:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 40, Average Loss: 0.6983

Epoch 41/50


Epoch 41:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 41, Average Loss: 0.6680

Epoch 42/50


Epoch 42:   0%|          | 0/5 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive

    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e6614392200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

End of Epoch 42, Average Loss: 0.6478

Epoch 43/50


Epoch 43:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 43, Average Loss: 0.6244

Epoch 44/50


Epoch 44:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 44, Average Loss: 0.7125

Epoch 45/50


Epoch 45:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 45, Average Loss: 0.7692

Epoch 46/50


Epoch 46:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 46, Average Loss: 0.6088

Epoch 47/50


Epoch 47:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 47, Average Loss: 0.6705

Epoch 48/50


Epoch 48:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 48, Average Loss: 0.6853

Epoch 49/50


Epoch 49:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 49, Average Loss: 0.6173

Epoch 50/50


Epoch 50:   0%|          | 0/5 [00:00<?, ?it/s]

End of Epoch 50, Average Loss: 0.6809

--- Training Finished ---
Final model saved to ./training_output/final_model


In [None]:
from numba import cuda
device = cuda.get_current_device()
device.reset()

import gc
gc.collect()

70

In [None]:

# --- Updated Model ---
class VisionLanguageModel(nn.Module):
    def __init__(
        self,
        image_encoder_name: str = "google/siglip-base-patch16-224",
        text_model_name: str = "Qwen/Qwen2.5-0.5B", # NOTE: Use a valid Gemma model name if 'gemma-3-1b-it' doesn't exist
        max_position_embeddings: int = 2048, # Match text model's max positions if possible
        dropout: float = 0.1, # Currently unused, consider adding Dropout layers if needed
        use_learned_siglip_params: bool = False # Option to load SigLIP's learned temp/bias
    ):
        super().__init__()

        # --- Securely get token ---
        # Best practice: Use environment variables or notebook secrets
        # access_token = os.getenv("HF_TOKEN")
        # Or login via CLI: huggingface-cli login
        access_token =  ''# Set to your token if needed, but avoid hardcoding

        # --- Image Encoder (SigLIP Vision Backbone) ---
        self.image_processor = AutoImageProcessor.from_pretrained(image_encoder_name, token=access_token)
        # Load ONLY the vision model
        self.image_encoder = SiglipVisionModel.from_pretrained(image_encoder_name, token=access_token)
        image_dim = self.image_encoder.config.hidden_size

        # --- Text Model (Gemma) ---
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_name, token=access_token)
        self.text_model = AutoModelForCausalLM.from_pretrained(text_model_name, token=access_token)
        text_dim = self.text_model.config.hidden_size

        # --- Model Configuration ---
        # Use text model's max sequence length if not specified otherwise
        self.max_position_embeddings = max_position_embeddings # Or self.text_model.config.max_position_embeddings

        # --- Freeze Parameters (Optional but common) ---
        # Freeze the text model's main body
        for param in self.text_model.parameters():
            param.requires_grad = False
        # Freeze the image encoder (optional, depends on training strategy)
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        # --- Projection Layer ---
        # Project image features to match text model's embedding dimension
        self.image_projection = nn.Linear(image_dim, text_dim)
        # Important: Unfreeze the projection layer so it can be trained
        for param in self.image_projection.parameters():
             param.requires_grad = True

        # --- SigLIP Loss Parameters ---
        # Initialize logit_scale (learned temperature) and optional logit_bias
        # Option 1: Initialize from scratch
        # self.logit_scale = nn.Parameter(torch.tensor(1.0 / 0.07)) # Common starting point (inverse temp)
        # self.logit_bias = nn.Parameter(torch.tensor(0.0)) # Optional bias
        # Option 2: Try loading from the original SigLIP model if available (recommended)
        # if use_learned_siglip_params:
        #      try:
        #          # Temporarily load the full SigLIP model to get its parameters
        #          from transformers import SiglipModel
        #          full_siglip_model = SiglipModel.from_pretrained(image_encoder_name, token=access_token)
        #          self.logit_scale = nn.Parameter(full_siglip_model.logit_scale.detach().clone())
        #          # SigLIP base doesn't always have bias, handle potential absence
        #          if hasattr(full_siglip_model, 'logit_bias') and full_siglip_model.logit_bias is not None:
        #               self.logit_bias = nn.Parameter(full_siglip_model.logit_bias.detach().clone())
        #          else:
        #               self.logit_bias = None # Or nn.Parameter(torch.zeros(())) if you want a trainable zero bias
        #          del full_siglip_model # Free memory
        #      except Exception as e:
        #          print(f"Warning: Could not load learned SigLIP params: {e}. Initializing from scratch.")
        #          self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07))) # Often stored as log(1/T)
        #          self.logit_bias = None # Or nn.Parameter(torch.zeros(()))
        # else:
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))
        self.logit_bias = None # Or nn.Parameter(torch.zeros(()))


        # --- Output Projection (Optional) ---
        # You might not need this if using pre-decoder features for loss
        # self.output_projection = nn.Linear(text_dim, text_dim) # Example projection dim

        # --- REMOVED: Custom Position Embeddings ---
        # Removed self.position_embeddings and self.register_buffer("position_ids")
        # We will rely on the text model's internal position handling

    def encode_image(self, pixel_values: torch.FloatTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes images using the SigLIP vision backbone and projects features.
        Returns both sequence output and pooled output.
        """
        # Pass through image encoder
        image_outputs = self.image_encoder(pixel_values=pixel_values, output_hidden_states=False) # No need for hidden states here

        # Option 1: Use the CLS token embedding (if available and desired)
        # image_features_seq = image_outputs.last_hidden_state
        # pooled_output = image_outputs.pooler_output # Usually the CLS token processed by a layer norm + optional dense

        # Option 2: Use last_hidden_state (sequence of patch embeddings + CLS token)
        image_features_seq = image_outputs.last_hidden_state

        # It's often better to use the features *before* the final pooler for projection
        # if contrastive loss relies on image features before text decoder interaction.
        # Let's use the CLS token from the sequence output as the pooled representation for loss
        image_cls_feature = image_features_seq[:, 0, :] # Assumes CLS token is first

        # Project the sequence of image features for input to the text model
        projected_image_features_seq = self.image_projection(image_features_seq)

        # Project the CLS feature separately for the contrastive loss
        # Note: We might want a separate projection head for the contrastive loss
        # vs. the one used to feed into the LLM. For simplicity, using the same one here.
        projected_image_cls = self.image_projection(image_cls_feature)

        return projected_image_features_seq, projected_image_cls

    def encode_text(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """ Gets text embeddings directly from the text model's embedding layer. """
        text_embeddings = self.text_model.model.embed_tokens(input_ids)
        return text_embeddings

    def get_text_features_for_loss(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.Tensor:
        """
        Gets text features for contrastive loss.
        Option 1: Pass text through the text model independently (like original SigLIP/CLIP).
        Option 2: Use output from the combined forward pass (more complex).
        Let's go with Option 1 for standard contrastive setup.
        """
        # We need the *pooled* output of the text model for comparison
        # Getting this from AutoModelForCausalLM requires careful handling.
        # A simpler approach for contrastive loss is often to have a separate
        # text tower or use the embeddings directly if the original model did.
        # SigLIP uses a separate text transformer. Since we're using Gemma,
        # let's use the embeddings of the [EOS] token or mean pool the embeddings
        # before they go into the main transformer blocks.

        # Simplest: Use the embedding of the first token (less common for causal LMs)
        # text_embeddings = self.encode_text(input_ids)
        # text_cls = text_embeddings[:, 0, :]

        # Better for Causal LMs: Use the embedding of the *last* non-padding token
        text_embeddings = self.encode_text(input_ids)
        sequence_lengths = attention_mask.sum(dim=1) - 1 # Get index of last token
        batch_indices = torch.arange(input_ids.size(0), device=input_ids.device)
        last_token_embeddings = text_embeddings[batch_indices, sequence_lengths]

        # Apply a text projection if needed (e.g., if image_projection maps to a specific space)
        # If image_projection's target dim is text_dim, no extra projection needed here.
        text_features_for_loss = last_token_embeddings

        # Alternative: Mean pooling over input embeddings
        # input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_embeddings.size()).float()
        # sum_embeddings = torch.sum(text_embeddings * input_mask_expanded, 1)
        # sum_mask = input_mask_expanded.sum(1)
        # sum_mask = torch.clamp(sum_mask, min=1e-9)
        # mean_pooled_embeddings = sum_embeddings / sum_mask
        # text_features_for_loss = mean_pooled_embeddings


        return text_features_for_loss


    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        return_loss: bool = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Dict[str, torch.Tensor]:

        return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states

        if pixel_values is None or input_ids is None:
            raise ValueError("Both pixel_values and input_ids must be provided")

        batch_size = pixel_values.size(0)
        device = pixel_values.device

        # 1. Encode Image (get sequence and pooled/CLS features)
        # image_features_seq: For input into Gemma | image_cls_for_loss: For contrastive loss
        image_features_seq, image_cls_for_loss = self.encode_image(pixel_values)
        image_seq_len = image_features_seq.size(1)

        # 2. Encode Text (get input embeddings)
        text_embeddings = self.encode_text(input_ids)
        text_seq_len = text_embeddings.size(1)

        # 3. Combine Embeddings
        combined_embeds = torch.cat([image_features_seq, text_embeddings], dim=1)
        combined_seq_len = combined_embeds.size(1)

        # 4. Create Combined Attention Mask
        image_attention_mask = torch.ones(
            (batch_size, image_seq_len),
            dtype=torch.long, # Match typical mask dtype
            device=device
        )
        # Ensure text attention mask is also LongTensor if it isn't
        text_attention_mask = attention_mask.to(torch.long)
        combined_attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)


        # 5. Prepare inputs for Gemma's Transformer blocks
        # The Gemma model's forward pass will internally handle position embeddings
        # based on the shape of inputs_embeds and the attention_mask.
        # We pass the combined embeddings directly as inputs_embeds.

        # No need to manually add position embeddings here.

        # --- Pass through Gemma's Transformer Blocks ---
        # Note: We pass combined_attention_mask. The model's _prepare_decoder_attention_mask
        # (or similar internal mechanism) will convert this into the appropriate causal mask.
        outputs = self.text_model.model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention_mask, # Provide the combined mask
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            # position_ids=None # Let the model handle position_ids generation
        )

        # last_hidden_state contains the output features for the *entire* sequence (image + text)
        last_hidden_state = outputs.last_hidden_state

        # Initialize loss to None
        loss = None

        if return_loss:
            # --- Calculate SigLIP Loss ---
            # Use the *projected CLS image feature* obtained *before* the text decoder
            image_embeds_for_loss = image_cls_for_loss

            # Get text features suitable for contrastive loss (e.g., last token embedding)
            text_embeds_for_loss = self.get_text_features_for_loss(input_ids, attention_mask)

            # Calculate the loss using the dedicated function
            loss = siglip_loss(
                image_features=image_embeds_for_loss,
                text_features=text_embeds_for_loss,
                logit_scale=self.logit_scale,
                logit_bias=self.logit_bias,
            )

        # --- Prepare Output Dictionary ---
        # Standard Hugging Face model output format is often preferred
        output_dict = {
            "loss": loss,
            "last_hidden_state": last_hidden_state,
            # Add other outputs if needed (e.g., logits for generation)
            "hidden_states": outputs.hidden_states,
            "attentions": outputs.attentions,
            # Include embeddings used for loss calculation for potential analysis
            "image_embeds_for_loss": image_cls_for_loss if return_loss else None,
            "text_embeds_for_loss": text_embeds_for_loss if return_loss else None,
        }

        # If generation is the goal, calculate logits using the text model's LM head
        # Apply LM head only to the text part of the sequence
        text_output_states = last_hidden_state[:, image_seq_len:]
        if hasattr(self.text_model, 'lm_head'):
             logits = self.text_model.lm_head(text_output_states)
             output_dict["logits"] = logits


        # Filter out None values if not returning a dict object
        if not return_dict:
            return tuple(v for v in output_dict.values() if v is not None)

        # Remove keys with None values before returning
        return {k: v for k, v in output_dict.items() if v is not None}

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Define the path where you want to save the model in your Drive
model_save_path = '/content/drive/MyDrive/my_model'

# Save the model (assuming you have a Keras model named 'model')
model.save(model_save_path)

print(f'Model saved to {model_save_path}')


In [12]:
# prompt: using the above code, generate the code for inference, assume the model is still loaded into the memory, the code should take in the image and generate the text related to that image, modify the code to check if model is loaded into memory, else load the model

import torch
import gc
from google.colab import drive
import shutil

# Check if model is loaded
try:
  model
  print("Model already loaded.")
except NameError:
  print("Model not loaded. Loading now...")
  # Mount Google Drive
  drive.mount('/content/drive')

  # Define the path where you saved the model in your Drive
  model_load_path = '/content/drive/MyDrive/EAG/a23-vlm/final_model'

  # Load the model
  model = VisionLanguageModel()
  model.load_state_dict(torch.load(os.path.join(model_load_path, 'model_weights.pth'), map_location=torch.device('cuda')))

  model.to("cuda")
  print("Model loaded successfully.")


foo = ""

def infer(image_path):
    image = Image.open(image_path)
    inputs = model.image_processor(images=image, return_tensors="pt").to("cuda")

    with torch.no_grad():
      generated_ids = model.generate(**inputs)
    generated_text = model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    print(generated_text)
    return generated_text




Model already loaded.


In [14]:
load_dataset()

NameError: name 'load_dataset' is not defined

In [17]:
datasets[1000]

(<PIL.Image.Image image mode=RGB size=32x32>, 5)

In [5]:
import torch
import os
import argparse
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer
from tqdm import tqdm # Optional, if processing multiple images

# Assume VisionLanguageModel class is defined here or imported from your model file
# Make sure the VisionLanguageModel class definition is available in the scope
# from model import VisionLanguageModel # Example if saved in model.py

# --- Paste or Import the VisionLanguageModel Class Definition Here ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    SiglipVisionModel, # Import the specific vision model class
    SiglipConfig,      # Helpful for getting config details
)
from typing import Optional, Dict, Tuple
import os # For securely handling tokens

# --- Helper function for SigLIP Loss (Not needed for inference, but part of class) ---
def siglip_loss(image_features: torch.Tensor, text_features: torch.Tensor, logit_scale: nn.Parameter, logit_bias: Optional[nn.Parameter] = None) -> torch.Tensor:
    # ... (loss implementation - can be kept or removed for pure inference code) ...
    # For brevity in inference, we can potentially remove the loss part if class allows it
    # Or just keep it as it doesn't harm inference.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits_per_image = logit_scale.exp() * torch.matmul(image_features, text_features.t())
    if logit_bias is not None:
        logits_per_image += logit_bias
    logits_per_text = logits_per_image.t()
    labels = torch.diag(torch.ones(image_features.size(0), device=image_features.device, dtype=torch.long))
    positive_labels = labels.float()
    loss_img = F.binary_cross_entropy_with_logits(logits_per_image, positive_labels, reduction="mean")
    loss_txt = F.binary_cross_entropy_with_logits(logits_per_text, positive_labels, reduction="mean")
    loss = (loss_img + loss_txt) / 2.0
    return loss

# --- Updated Model ---
class VisionLanguageModel(nn.Module):
    def __init__(
        self,
        image_encoder_name: str = "google/siglip-base-patch16-224",
        text_model_name: str = "Qwen/Qwen2.5-0.5B", # Ensure consistency
        max_position_embeddings: int = 2048,
        dropout: float = 0.1,
        use_learned_siglip_params: bool = True
    ):
        super().__init__()
        access_token = os.getenv("HF_TOKEN") # Use environment variable

        access_token =  ''

        self.image_processor_name = image_encoder_name # Store names for reference
        self.text_model_name = text_model_name

        # --- Image Encoder (SigLIP Vision Backbone) ---
        self.image_processor_instance = AutoImageProcessor.from_pretrained(image_encoder_name, token=access_token) # Store processor instance
        self.image_encoder = SiglipVisionModel.from_pretrained(image_encoder_name, token=access_token)
        image_dim = self.image_encoder.config.hidden_size

        # --- Text Model (Gemma) ---
        self.tokenizer_instance = AutoTokenizer.from_pretrained(text_model_name, token=access_token) # Store tokenizer instance
        self.text_model = AutoModelForCausalLM.from_pretrained(text_model_name, token=access_token)
        text_dim = self.text_model.config.hidden_size

        # Add pad token if missing (important for tokenizer instance)
        if self.tokenizer_instance.pad_token is None:
            self.tokenizer_instance.pad_token = self.tokenizer_instance.eos_token
            # Adjust model embedding size if needed (should be done *before* loading state_dict if vocab changed)
            # self.text_model.resize_token_embeddings(len(self.tokenizer_instance))


        # --- Model Configuration ---
        self.max_position_embeddings = max_position_embeddings

        # --- Projection Layer ---
        self.image_projection = nn.Linear(image_dim, text_dim)

        # --- SigLIP Loss Parameters (Loaded from checkpoint during inference) ---
        # Initialize them here, but they will be overwritten by load_state_dict
        if use_learned_siglip_params:
             try:
                 from transformers import SiglipModel
                 temp_siglip = SiglipModel.from_pretrained(self.image_processor_name, token=access_token)
                 # Ensure loaded value is treated as scalar parameter
                 self.logit_scale = nn.Parameter(temp_siglip.logit_scale.detach().clone().squeeze())
                 if hasattr(temp_siglip, 'logit_bias') and temp_siglip.logit_bias is not None:
                      # Assuming bias is also scalar if it exists
                      self.logit_bias = nn.Parameter(temp_siglip.logit_bias.detach().clone().squeeze())
                 else:
                      self.logit_bias = None
                 del temp_siglip
             except Exception as e:
                 print("Warning: Could not pre-load default SigLIP params. Initializing as scalar.")
                 # Initialize as SCALAR
                 self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07))) # No []
                 self.logit_bias = None # Or nn.Parameter(torch.zeros(())) if bias was trained as scalar
        else:
            # Initialize as SCALAR
            self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07))) # No []
            self.logit_bias = None


    def encode_image(self, pixel_values: torch.FloatTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes images and returns projected sequence and CLS features."""
        image_outputs = self.image_encoder(pixel_values=pixel_values, output_hidden_states=False)
        image_features_seq = image_outputs.last_hidden_state
        image_cls_feature = image_features_seq[:, 0, :] # Assumes CLS token is first
        projected_image_features_seq = self.image_projection(image_features_seq)
        projected_image_cls = self.image_projection(image_cls_feature) # Not typically used in generation
        return projected_image_features_seq, projected_image_cls

    # encode_text, get_text_features_for_loss, forward can be kept or removed
    # for pure inference, as they are mainly for training.
    # Keeping them doesn't hurt unless they have large dependencies.
    # For clean inference, only encode_image and necessary components are needed.
    # We will call text_model.generate() directly.

# --- End of VisionLanguageModel Class Definition ---



def generate_caption(image_path: str, prompt: str, checkpoint_path: str, config: dict):
    """
    Generates a caption for a given image using the trained model,
    guided by an optional text prompt.

    Args:
        image_path (str): Path to the input image file.
        prompt (str): Text prompt to guide generation.
        checkpoint_path (str): Path to the directory containing the trained model weights.
        config (dict): Configuration dictionary.

    Returns:
        str: The generated caption (only the newly generated part).
    """
    device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Load Model Components ---
    print("Loading model components...")
    # (Error handling for loading components omitted for brevity, refer to previous version)
    model = VisionLanguageModel(
        image_encoder_name=config["model"]["image_encoder_name"],
        text_model_name=config["model"]["text_model_name"],
        max_position_embeddings=config["model"]["max_position_embeddings"],
        use_learned_siglip_params=config["model"]["use_learned_siglip_params"]
    )
    image_processor = model.image_processor_instance
    tokenizer = model.tokenizer_instance
    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

    # --- Load Trained Weights ---
    print(f"Loading trained weights from: {checkpoint_path}")
    state_dict_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if not os.path.exists(state_dict_path):
        raise FileNotFoundError(f"Checkpoint file not found at {state_dict_path}")
    state_dict = torch.load(state_dict_path, map_location="cpu")
    # Handle 'module.' prefix if model was saved with DataParallel/DDP
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    load_result = model.load_state_dict(new_state_dict, strict=False)
    # Print warnings for missing/unexpected keys
    if load_result.missing_keys: print("Warning: Missing keys:", load_result.missing_keys)
    if load_result.unexpected_keys: print("Warning: Unexpected keys:", load_result.unexpected_keys)

    model.eval()
    model.to(device)
    print("Model loaded successfully.")

    # --- Load and Process Image ---
    print(f"Loading image: {image_path}")
    try:
        # raw_image = Image.open(image_path).convert('RGB')
        inputs = image_processor(images=image_path, return_tensors="pt").to(device)
        pixel_values = inputs['pixel_values']
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None

    # --- Encode Image Features ---
    print("Encoding image features...")
    with torch.no_grad():
        image_embeds, _ = model.encode_image(pixel_values)
        # image_embeds shape: (1, image_seq_len, text_hidden_dim) - assuming batch size 1

    # --- Process Text Prompt ---
    print(f"Processing prompt: '{prompt}'")
    # Tokenize prompt - IMPORTANT: add_special_tokens=False prevents BOS/EOS here
    # We rely on generate() to handle starting/ending tokens appropriately
    # Add a space if prompt is not empty, often helps models start generation
    if prompt and not prompt.endswith(" "):
         prompt += " "
    prompt_inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    with torch.no_grad():
        prompt_embeds = model.text_model.model.embed_tokens(prompt_inputs.input_ids)
        # prompt_embeds shape: (1, prompt_seq_len, text_hidden_dim)

    # --- Combine Embeddings and Masks ---
    combined_embeds = torch.cat([image_embeds, prompt_embeds], dim=1)
    # Attention mask: 1s for image tokens, use tokenizer's mask for prompt tokens
    image_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long, device=device)
    prompt_mask = prompt_inputs.attention_mask
    combined_mask = torch.cat([image_mask, prompt_mask], dim=1)

    # --- Define Generation Parameters ---
    generation_kwargs = config.get("generation", {})
    generation_kwargs.setdefault("pad_token_id", tokenizer.pad_token_id)
    generation_kwargs.setdefault("eos_token_id", tokenizer.eos_token_id)
    # If using sampling, ensure temperature/top_k/top_p are set
    if generation_kwargs.get("do_sample", False):
        generation_kwargs.setdefault("temperature", 0.7)
        generation_kwargs.setdefault("top_p", 0.9)
        generation_kwargs.setdefault("top_k", 50)


    # --- Generate Text ---
    print("Generating text...")
    with torch.no_grad():
        outputs = model.text_model.generate(
            inputs_embeds=combined_embeds,        # Pass combined image+prompt embeddings
            attention_mask=combined_mask,       # Pass combined mask
            **generation_kwargs
        )

    # --- Decode Generated Text (Excluding Input) ---
    print("Decoding generated text...")
    # outputs tensor contains token IDs for the full sequence (image+prompt+generated)
    # Input length is the sequence length of the combined embeddings
    input_length = combined_embeds.shape[1]
    # Slice the outputs tensor to get only the generated token IDs
    generated_ids = outputs[0, input_length:]
    # Decode the generated IDs
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)



In [7]:
# --- Build Configuration from Args and Defaults ---
INFERENCE_CONFIG = {
    "model": {
        "image_encoder_name": "google/siglip-base-patch16-224",
        "text_model_name": "Qwen/Qwen2.5-0.5B",
        "max_position_embeddings": 2048, # Should match training if possible
        "use_learned_siglip_params": True, # Assumes checkpoint saved them
    },

    # image_encoder_name: str = "google/siglip-base-patch16-224",
    #     text_model_name: str = "Qwen/Qwen2.5-0.5B", # Ensure consistency
    # "generation": {
    #     "max_new_tokens": args.max_new_tokens,
    #     "do_sample": args.do_sample,
    #     "temperature": args.temperature,
    #     "top_k": args.top_k,
    #     "top_p": args.top_p,
    #     # Add other generate() args as needed (e.g., num_beams)
    # },
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

prompt = 'Describe this image'

datasets = load_cifar10(root_dir='./cifar_data/')

image, _ = datasets[0]
# generated_text = infer(image)

checkpoint_path = "/content/drive/MyDrive/EAG/a23-vlm/model/final_model"

# --- Run Generation ---
try:
    generated_caption = generate_caption(
        image_path=image,
        prompt=prompt,
        checkpoint_path=checkpoint_path,
        config=INFERENCE_CONFIG
    )

    if generated_caption:
        print("\n--- Generated Caption ---")
        print(generated_caption)

except FileNotFoundError as e:
    print(f"\nError: {e}")
except Exception as e:
    print(f"\nAn unexpected error occurred: {e}")
    import traceback
    traceback.print_exc() # Print detailed traceback for debugging

Loading CIFAR-10 test dataset...


100%|██████████| 170M/170M [00:02<00:00, 79.6MB/s]


CIFAR-10 loaded successfully.
Using device: cuda
Loading model components...


preprocessor_config.json:   0%|          | 0.00/368 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/813M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/7.23k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/681 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Loading trained weights from: /content/drive/MyDrive/EAG/a23-vlm/model/final_model
Model loaded successfully.
Loading image: <PIL.Image.Image image mode=RGB size=32x32 at 0x79BD6A6B2810>
Encoding image features...
Processing prompt: 'Describe this image'
Generating text...
Decoding generated text...


In [10]:
# Assuming 'model' is your instantiated and loaded VisionLanguageModel object
model = VisionLanguageModel(
)


print(model)

# You can also iterate through named parameters to see their names and requires_grad status
print("\n--- Trainable Parameters ---")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.shape}")

print("\n--- Frozen Parameters (Example) ---")
count = 0
for name, param in model.named_parameters():
    if not param.requires_grad and count < 10: # Print first few frozen params
        print(f"{name}: {param.shape}")
        count += 1

VisionLanguageModel(
  (image_encoder): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=valid)
        (position_embedding): Embedding(196, 768)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-11): 12 x SiglipEncoderLayer(
            (self_attn): SiglipSdpaAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=768, out_features=3072,

In [6]:

# Example usage (replace with your image path)
datasets = load_cifar10(root_dir='./cifar_data/')

image, _ = datasets[0]
generated_text = infer(image)


Loading CIFAR-10 test dataset...
CIFAR-10 loaded successfully.


NameError: name 'infer' is not defined

In [None]:
from google.colab import drive
import shutil

# Mount Google Drive
drive.mount('/content/drive')

# Define source and destination paths
src = '/content/final_model'
dst = '/content/drive/MyDrive/final_model'

# Copy the entire folder
shutil.copytree(src, dst)

print(f'Model copied to {dst}')


In [10]:
# Example usage (replace with actual image and text)
from PIL import Image
image = Image.open("path/to/your/image.jpg")
text = "A photo of a cat"
outputs = infer(image, text)
print(outputs)


NameError: name 'model' is not defined