In [1]:
import os
os.environ['HF_HOME'] = "/mimer/NOBACKUP/groups/drl_mps_planner/agents/models"

In [2]:
from unsloth import FastModel
import torch

torch._dynamo.reset()
torch._dynamo.config.cache_size_limit = 64

model_id = "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit"

model, tokenizer = FastModel.from_pretrained(
    model_name = model_id,
    dtype = None, # None for auto detection
    max_seq_length = 1024, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.54.1.
   \\   /|    NVIDIA H100 NVL. Num GPUs = 1. Max memory: 93.086 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
from typing import Optional, Union
from transformers.cache_utils import Cache
from transformers.utils import is_torchdynamo_compiling
import torch

def embed_image_and_text(
    base_model,
    input_ids: Optional[torch.LongTensor] = None,  # text inputs
    pixel_values: Optional[torch.FloatTensor] = None,  # vision inputs
    input_features: Optional[torch.FloatTensor] = None,  # audio inputs
    attention_mask: Optional[torch.Tensor] = None,
    input_features_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    image_features: Optional[torch.Tensor] = None,
    **lm_kwargs,
) -> torch.Tensor:
    self = base_model
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    if input_ids is not None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

        # Prepare per-layer inputs from inputs_ids
        per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.model.vocab_size_per_layer_input)
        per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
        per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)

        # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
        vision_mask = torch.logical_and(
            input_ids >= self.model.embed_vision.vocab_offset, input_ids < self.model.embed_audio.vocab_offset
        )
        dummy_vision_token_id = self.model.embed_vision.vocab_offset + self.model.embed_vision.vocab_size - 1
        vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
        vision_embeds = self.model.embed_vision(input_ids=vision_input_ids)
        expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)

        # Handle audio tokens (>= embed_audio.vocab_offset)
        audio_mask = input_ids >= self.model.embed_audio.vocab_offset
        dummy_audio_token_id = self.model.embed_audio.vocab_offset + self.model.embed_audio.vocab_size - 1
        audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
        audio_embeds = self.model.embed_audio(input_ids=audio_input_ids)
        expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
    else:
        per_layer_inputs = None

    # Merge text and images
    if pixel_values is not None:
        if image_features is None:
            image_features = self.model.get_image_features(pixel_values)

        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.model.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
        else:
            special_image_mask = (input_ids == self.model.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
            raise ValueError(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {image_tokens_in_text} image tokens in the text and "
                f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
    outputs = self.language_model(
        input_ids=None,
        per_layer_inputs=per_layer_inputs,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **lm_kwargs,
    )
    # returns embeddings before language model and after language model
    return inputs_embeds, outputs.last_hidden_state

In [4]:
from PIL import Image
import requests

In [5]:
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Where is the cat standing?"},
        ]
    }
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(images=image, text=prompt, return_tensors="pt").to(model.device)

In [6]:
with torch.inference_mode():
    early_embeddings, late_embeddings = embed_image_and_text(base_model=model, **inputs)#, image_features=image_features)

In [7]:
output.shape

torch.Size([1, 275, 2048])

# Prepare dataset

In [8]:
def image_match_prompt(question: str):
    return f"""# General Instructions
You are given an image and a question and should indicate whether question about the image is true, wrong or irrelevant.

# Labels:
- "true": The image clearly supports the question—visual evidence is unambiguous.
- "false": The image clearly contradicts the question—visual evidence directly disproves it.
- "null": The image does not provide enough visual information to answer definitively.

# Instructions:
Use only the visual content of the image. Do not make assumptions. Choose "null" when there's not enough evidence, even if a guess seems likely.

# Question:
{question}

# Output format:
{{"answer": "<true|false|null>"}}
"""

In [9]:
import requests
from PIL import Image
from io import BytesIO
import os
import hashlib
import json
from datasets import Dataset

CACHE_DIR = "image_cache"
if not os.path.exists(CACHE_DIR):
    os.makedirs(CACHE_DIR)

def download_image(url):
    # Create a unique filename based on the URL hash
    url_hash = hashlib.sha256(url.encode()).hexdigest()
    file_extension = os.path.splitext(url.split('?')[0])[-1] or '.jpg' # Default to .jpg if no extension
    cache_path = os.path.join(CACHE_DIR, f"{url_hash}{file_extension}")

    if os.path.exists(cache_path):
        try:
            return Image.open(cache_path)
        except IOError as e:
            print(f"Warning: Could not open cached image {cache_path}: {e}")
            # If cache is corrupt, proceed to download

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Resize the image to max 512x512, preserving aspect ratio
        # image.thumbnail((512, 512))
        # Save the resized image to the cache
        image.save(cache_path)
        return image
    except (requests.exceptions.RequestException, IOError) as e:
        print(f"Warning: Could not download image {url}: {e}")
        return None

def data_generator(path: str):
    """Yields records from a JSONL file for memory-efficient loading."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            
            obj = json.loads(line)
            question = obj.get("query") or obj.get("question")
            image_url = obj.get("imageUrl")
            reply = obj.get("reply")

            image = download_image(image_url)

            if not (question and image_url and reply and image):
                continue

            yield {
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "prompt": image_match_prompt(question)},
                            {"type": "text", "text": question},
                            {"type": "image", "image": image},
                            {"type": "url", "text": image_url},
                        ]
                    },
                    {
                        "role": "assistant",
                        "content": [
                            {"type": "text", "text": reply},
                        ]
                    }
                ],
            }

def create_dataset(data_path):
    """Create a HuggingFace dataset by converting to a simpler format first."""
    # First, collect all the data in the simple format
    if isinstance(data_path, list):
        # If data_path is a list, process all files and combine results
        data_items = []
        for path in data_path:
            data_items.extend(list(data_generator(path)))
    else:
        # If data_path is a single path, process it normally
        data_items = list(data_generator(data_path))    
    # Convert to a flat structure that HuggingFace can handle
    flat_data = []
    for item in data_items:
        # Store the complex structure as a single field and handle images separately
        user_message = item["messages"][0]
        assistant_message = item["messages"][1]
        
        # Extract text and image from user message
        text_content = None
        image_content = None
        
        for content in user_message["content"]:
            if content["type"] == "text":
                if "prompt" in content:
                    prompt_content
                else:
                    text_content = content["text"]
            elif content["type"] == "image":
                image_content = content["image"]
            elif content["type"] == "url":
                image_url = content["text"]
        
        raw_reply = assistant_message["content"][0]["text"]
        if raw_reply == "doesn't match":
            reply = '{"response": "false"}'
            label = 0
        elif raw_reply == "matches":
            reply = '{"response": "true"}'
            label = 1
        elif raw_reply == "irrelevant":
            reply = '{"response": "null"}'
            label = 2
        else:
            print(f"{raw_reply} not found")

        flat_data.append({
            "user_text": text_content,
            "user_image": image_content,
            "assistant_text": reply,
            "image_url": image_url,
            "label": label,
            # "original_messages": item["messages"]  # Keep original structure
        })
    
    # Create dataset from the flat structure
    dataset = Dataset.from_list(flat_data)
    
    # Add a method to get back the original format
    def get_original_format(example):
        return example["original_messages"]
    
    dataset.get_original_format = get_original_format
    
    return dataset

In [10]:
data_path = "./image-analysis-results.jsonl"
dataset = create_dataset(data_path)

KeyError: 'text'

In [None]:
dataset

In [None]:
dataset[0]

# Prompt eval

In [120]:
# Split dataset into train and validation sets
train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = train_test_split['train']
val_dataset = train_test_split['test']

print(f"Total dataset size: {len(dataset)}")
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")

# instruction = "You are an expert at understanding aparment images. Answer accurately."

def convert_to_conversation(sample):
    conversation = [
        { "role": "user",
          "content" : [
            {"type" : "text",  "text"  : sample["user_text"]},
            {"type" : "image", "image" : sample["user_image"]} ]
        },
        { "role" : "assistant",
          "content" : [
            {"type" : "text",  "text"  : sample["assistant_text"]} ]
        },
    ]
    return { "messages" : conversation }


train_converted_dataset = [convert_to_conversation(sample) for sample in train_dataset]
val_converted_dataset = [convert_to_conversation(sample) for sample in val_dataset]
pass

Total dataset size: 6444
Training set size: 5155
Validation set size: 1289


In [125]:
import re
import json
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def extract_classification(text):
    """
    Extract classification from model output.
    Returns 'true', 'false', 'null', or None if not found.
    """
    text = text.lower().strip()
    
    # Try to extract from JSON format first
    try:
        # Look for JSON-like patterns
        json_match = re.search(r'\{[^}]*"response"[^}]*:[^}]*"([^"]+)"[^}]*\}', text)
        if json_match:
            response = json_match.group(1).lower().strip()
            if response in ['true', 'false', 'null']:
                return response
    except:
        pass
    
    # Direct string matching as fallback
    if 'null' in text:
        return 'null'
    elif 'true' in text:
        return 'true'
    elif 'false' in text:
        return 'false'
    
    return None

def calculate_accuracy_batched(model, tokenizer, val_converted_dataset, batch_size=8):
    """Calculate accuracy using batched inference with robust classification extraction."""
    total = len(val_converted_dataset)
    
    # Track different types of errors for debugging
    stats = {
        'correct': 0,
        'wrong_classification': 0,
        'unparseable': 0,
        'by_class': {'true': {'correct': 0, 'total': 0}, 
                     'false': {'correct': 0, 'total': 0}, 
                     'null': {'correct': 0, 'total': 0}}
    }
    
    y_true = []
    y_pred = []
    
    model.eval()  # Set model to evaluation mode
    
    processed_count = 0
    
    # Process in batches
    for i in range(0, total, batch_size):
        batch_end = min(i + batch_size, total)
        batch_samples = val_converted_dataset[i:batch_end]
        
        batch_inputs = []
        expected_responses = []
        valid_indices = []  # Track which samples were successfully tokenized
        
        # Tokenize each sample individually
        for idx, sample in enumerate(batch_samples):
            try:
                messages = sample['messages']
                
                # Extract user message (first message, excluding assistant)
                user_message = messages[0]  # This contains the user text and image
                
                # Extract expected response (assistant message content)
                assistant_message = messages[1]  # This is the assistant response
                expected_response = assistant_message['content'][0]['text']  # Extract the text
                
                # Tokenize individual sample
                inputs = tokenizer.apply_chat_template(
                    [user_message],
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                )
                
                batch_inputs.append(inputs[0])  # Remove batch dimension
                expected_responses.append(expected_response)
                valid_indices.append(i + idx)  # Track original index
                
            except Exception as e:
                print(f"Error tokenizing sample {i + idx}: {e}")
                continue
        
        if not batch_inputs:
            continue
            
        # Manually pad to the same length (left padding)
        max_length = max(len(inp) for inp in batch_inputs)
        padded_inputs = []
        original_lengths = []
        
        for inp in batch_inputs:
            original_lengths.append(len(inp))
            if len(inp) < max_length:
                # Left pad with pad_token_id
                padding_length = max_length - len(inp)
                padded = torch.cat([
                    torch.full((padding_length,), tokenizer.pad_token_id, dtype=inp.dtype),
                    inp
                ])
                padded_inputs.append(padded)
            else:
                padded_inputs.append(inp)
        
        # Stack into batch tensor
        input_ids = torch.stack(padded_inputs).to(model.device)
        
        # Create attention mask for left-padded sequences
        attention_mask = (input_ids != tokenizer.pad_token_id).to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                temperature=0.0,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
            )
        
        # Process each sample in the batch
        for j in range(len(batch_inputs)):
            # Use the original length we stored
            original_length = original_lengths[j]
            
            # For left-padded sequences, we need to account for the padding
            padding_length = max_length - original_length
            actual_input_end = padding_length + original_length
            
            # The generated part starts after the actual input
            generated_tokens = outputs[j][actual_input_end:]
            
            # Decode the generated response
            generated_text = tokenizer.decode(
                generated_tokens, 
                skip_special_tokens=True
            ).strip()
            
            # Extract classifications
            expected_class = extract_classification(expected_responses[j])
            predicted_class = extract_classification(generated_text)
            
            if expected_class and predicted_class:
                y_true.append(expected_class)
                y_pred.append(predicted_class)
            
            # Update statistics
            if expected_class:
                stats['by_class'][expected_class]['total'] += 1
            
            if predicted_class is None:
                stats['unparseable'] += 1
                if processed_count < 10:  # Show first few for debugging
                    print(f"Unparseable output: '{generated_text}' (expected: '{expected_responses[j]}')")
            elif predicted_class == expected_class:
                stats['correct'] += 1
                if expected_class:
                    stats['by_class'][expected_class]['correct'] += 1
            else:
                stats['wrong_classification'] += 1
                if processed_count < 10:  # Show first few misclassifications for debugging
                    print(f"Misclassification: predicted '{predicted_class}', expected '{expected_class}'")
            
            processed_count += 1
        
        # Print progress
        print(f"Processed {processed_count}/{total} samples")
    
    # Calculate final metrics
    accuracy = stats['correct'] / processed_count if processed_count > 0 else 0
    
    print(f"\n=== FINAL RESULTS ===")
    print(f"Overall Accuracy: {stats['correct']}/{processed_count} = {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Correct predictions: {stats['correct']}")
    print(f"Wrong classifications: {stats['wrong_classification']}")
    print(f"Unparseable outputs: {stats['unparseable']}")
    
    print(f"\n=== BY CLASS ===")
    for class_name, class_stats in stats['by_class'].items():
        if class_stats['total'] > 0:
            class_acc = class_stats['correct'] / class_stats['total']
            print(f"{class_name.upper()}: {class_stats['correct']}/{class_stats['total']} = {class_acc:.4f} ({class_acc*100:.2f}%)")
    
    # Generate and plot confusion matrix
    if y_true and y_pred:
        labels = sorted(list(set(y_true)))
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.savefig('confusion_matrix.png')
        print("\nConfusion matrix plot saved to 'confusion_matrix.png'")

    return accuracy, stats

In [126]:
# Ensure pad token is properly set and configure padding side
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Set pad_token to eos_token")

# Set left padding for decoder-only models
tokenizer.padding_side = 'left'
print("Set padding_side to 'left' for decoder-only model")

# Calculate accuracy on validation set
accuracy, detailed_stats = calculate_accuracy_batched(model, tokenizer, val_converted_dataset, batch_size=64)

Set padding_side to 'left' for decoder-only model
Misclassification: predicted 'null', expected 'false'
Misclassification: predicted 'null', expected 'false'
Misclassification: predicted 'null', expected 'false'
Misclassification: predicted 'null', expected 'false'
Misclassification: predicted 'false', expected 'null'
Processed 64/1289 samples
Processed 128/1289 samples
Processed 192/1289 samples
Processed 256/1289 samples
Processed 320/1289 samples
Processed 384/1289 samples
Processed 448/1289 samples
Processed 512/1289 samples
Processed 576/1289 samples
Processed 640/1289 samples
Processed 704/1289 samples
Processed 768/1289 samples
Processed 832/1289 samples
Processed 896/1289 samples
Processed 960/1289 samples
Processed 1024/1289 samples
Processed 1088/1289 samples
Processed 1152/1289 samples
Processed 1216/1289 samples
Processed 1280/1289 samples
Processed 1289/1289 samples

=== FINAL RESULTS ===
Overall Accuracy: 606/1289 = 0.4701 (47.01%)
Correct predictions: 606
Wrong classific

# Embed data

In [16]:
from tqdm import tqdm

In [18]:
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Where is the cat standing?"},
        ]
    }
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(images=image, text=prompt, return_tensors="pt").to(model.device)

image_features = model.get_image_features(inputs["pixel_values"])

with torch.inference_mode():
    embeddings, output = embed_image_and_text(base_model=model, **inputs, image_features=image_features)

AttributeError: 'Gemma3nForConditionalGeneration' object has no attribute 'embed_image_and_text'

In [20]:
embedding_cache = {}
embeddings = []
embeddings_1 = []
with torch.inference_mode():
    for sample in tqdm(dataset):
        url = sample["image_url"]
        image = sample["user_image"]
        #image = Image.open(requests.get(url, stream=True).raw)
    
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": sample["user_text"]},
                ]
            }
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(images=image, text=prompt, return_tensors="pt").to(model.device)
    
        if url in embedding_cache:
            image_features = embedding_cache[url]
        else:
            image_features = model.get_image_features(inputs["pixel_values"])
            embedding_cache[url] = image_features
        embedding_1, embedding = embed_image_and_text(base_model=model, **inputs, image_features=image_features)
        if embedding.isnan().any():
            print("nan")
            break
        embeddings.append(embedding.detach().clone().cpu())
        embeddings_1.append(embedding_1.detach().clone().cpu())

100%|██████████| 6444/6444 [09:51<00:00, 10.90it/s]


# Train classifier

In [21]:
import pickle

In [22]:
with open("embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)

In [11]:
with open("embeddings.pkl", "rb") as f:
    embeddings = pickle.load(f)

In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import lightning as L
import torchmetrics

class AttentionPooling(nn.Module):
    """
    Implements an attention mechanism that can handle padded sequences.
    """
    def __init__(self, embedding_dim):
        super().__init__()
        self.attention_net = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.Tanh(),
            nn.Linear(embedding_dim // 2, 1, bias=False)
        )

    def forward(self, x, attention_mask):
        """
        Args:
            x (torch.Tensor): Padded input tensor of shape [batch_size, seq_len, embedding_dim]
            attention_mask (torch.Tensor): Boolean mask of shape [batch_size, seq_len, 1]
        Returns:
            torch.Tensor: A pooled representation of shape [batch_size, embedding_dim]
        """
        attention_scores = self.attention_net(x)

        # --- MASKING STEP ---
        # Where the mask is False (i.e., for padded elements), set scores to a very
        # large negative number so they become zero after softmax.
        attention_scores.masked_fill_(~attention_mask, -float('inf'))

        # Convert scores to weights
        attention_weights = torch.softmax(attention_scores, dim=1)

        # --- FIX for NaN issue ---
        # If a sequence is entirely padded, softmax(-inf) results in NaN.
        # We replace any NaNs with 0.0 to ensure numerical stability.
        attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

        # The weights for padded elements will be 0, so they don't contribute
        # to the weighted average.
        weighted_average = torch.sum(x * attention_weights, dim=1)

        return weighted_average

class EmbeddingClassifier(L.LightningModule):
    """
    A classifier that uses an attention mechanism to pool variable-length sequences.
    """
    def __init__(self, embedding_dim=2048, hidden_dim=512, num_classes=3, learning_rate=1e-3, dropout_p=0.5, class_weights=None):
        super().__init__()
        self.save_hyperparameters()
        self.attention_pooling = AttentionPooling(embedding_dim)
        self.classifier = nn.Sequential(
            nn.Linear(self.hparams.embedding_dim, self.hparams.hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=self.hparams.dropout_p),
            nn.Linear(self.hparams.hidden_dim, self.hparams.num_classes)
        )
        self.loss_fn = nn.CrossEntropyLoss(weight=self.hparams.class_weights)
        self.train_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def setup(self, stage=None):
        if self.hparams.class_weights is not None:
            self.loss_fn.weight = self.hparams.class_weights.to(self.device)

    def forward(self, x, attention_mask):
        """
        The forward pass now accepts the mask and passes it to the pooling layer.
        """
        pooled_output = self.attention_pooling(x, attention_mask)
        logits = self.classifier(pooled_output)
        return logits

    def training_step(self, batch, batch_idx):
        # The batch now contains embeddings, labels, and the attention mask
        embeddings, labels, mask = batch
        logits = self(embeddings, mask)
        loss = self.loss_fn(logits, labels)
        self.train_accuracy(logits, labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        embeddings, labels, mask = batch
        logits = self(embeddings, mask)
        loss = self.loss_fn(logits, labels)
        self.val_accuracy(logits, labels)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_accuracy, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',      # Reduce on minimum validation loss
            factor=0.2,      # new_lr = lr * factor
            patience=2,      # Number of epochs with no improvement after which lr is reduced
            min_lr=1e-6,     # Don't let the learning rate go too low
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss", # The metric to watch
            },
        }

In [80]:
import torch
from torch.utils.data import Dataset

class FeatureDataset(Dataset):
    """
    Custom PyTorch Dataset for variable-length sequences.

    Args:
        features (list): A list of feature tensors. Each tensor in the list
                         can have a variable sequence length,
                         e.g., shape [seq_len, embedding_dim].
        labels (list or torch.Tensor): A list or tensor of corresponding labels.
    """
    def __init__(self, features, labels):
        # Ensure features and labels have the same number of samples
        assert len(features) == len(labels), "Features and labels must have the same length"
        self.features = features
        self.labels = labels

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.features)

    def __getitem__(self, idx):
        """
        Retrieves the feature and label for a given index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: (feature_tensor, label_tensor)
        """
        # Get the feature tensor at the given index
        feature = self.features[idx].to(torch.float32)
        # Get the label at the given index
        label = self.labels[idx]

        # It's good practice to ensure the label is a tensor
        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label, dtype=torch.long)

        return feature, label


def collate_fn_pad(batch, target_embedding_dim=2048):
    """
    This function is passed to the DataLoader to handle padding.
    It takes a list of samples (tuples) and returns a single batch.
    It also normalizes the embedding dimension to a fixed size to prevent errors.
    """
    # 1. Separate embeddings and labels
    embeddings_list = [item[0] for item in batch]
    labels_list = [item[1] for item in batch]

    # --- NEW: Normalize embedding dimension to a fixed size ---
    # This loop fixes the error by ensuring all tensors have the same embedding dim.
    normalized_embeddings_list = []
    for emb in embeddings_list:
        # --- FIX: Handle potential extra batch dimension ---
        if emb.dim() == 3 and emb.shape[0] == 1:
            emb = emb.squeeze(0)

        # Use the last dimension as the embedding dimension for robustness
        current_dim = emb.shape[-1]

        if current_dim > target_embedding_dim:
            # Truncate the embedding dimension if it's too large
            normalized_emb = emb[:, :target_embedding_dim]
        elif current_dim < target_embedding_dim:
            # Pad the embedding dimension with zeros if it's too small
            padding_size = target_embedding_dim - current_dim
            padding = torch.zeros((emb.shape[0], padding_size), dtype=emb.dtype, device=emb.device)
            normalized_emb = torch.cat([emb, padding], dim=1)
        else:
            # No change needed if the dimension is correct
            normalized_emb = emb
        normalized_embeddings_list.append(normalized_emb)

    # 2. Pad the sequences of embeddings using the normalized list
    # `pad_sequence` stacks them and pads with 0 to the length of the longest sequence.
    # batch_first=True makes the output shape [batch_size, seq_len, embedding_dim]
    padded_embeddings = pad_sequence(normalized_embeddings_list, batch_first=True, padding_value=0.0)

    # 3. Create the attention mask
    # The mask is True where there is real data, False for padding (0.0)
    # We check the sum across the embedding dimension. If it's 0, it's padding.
    attention_mask = (padded_embeddings.sum(dim=2) != 0).unsqueeze(2)

    # 4. Stack the labels into a tensor
    labels = torch.stack(labels_list)

    return padded_embeddings, labels, attention_mask

In [25]:
import copy
vision_text_embeddings = copy.deepcopy(embeddings)

In [26]:
(vision_text_embeddings[110]==vision_text_embeddings[6000]).all()

tensor(False)

In [58]:
vision_text_embeddings[1].max(dim=1).values

tensor([[15.5625, 13.1875, 11.1875,  ...,  7.2188,  8.6250,  7.7812]],
       dtype=torch.bfloat16)

In [82]:
from sklearn.model_selection import train_test_split

#mean_pool_embeddings = [e.nanmean(dim=1) for e in vision_text_embeddings]
all_labels = torch.tensor(dataset["label"], dtype=torch.long)
mean_embeddings = torch.cat([e.mean(dim=1) for e in vision_text_embeddings]).to(torch.float32)
max_embeddings = torch.cat([e.max(dim=1).values for e in vision_text_embeddings]).to(torch.float32)
min_embeddings = torch.cat([e.min(dim=1).values for e in vision_text_embeddings]).to(torch.float32)
all_embeddings = torch.cat([mean_embeddings, max_embeddings, min_embeddings], dim=1)

In [83]:
all_embeddings.shape

torch.Size([6444, 6144])

In [84]:
train_embeddings, val_embeddings, train_labels, val_labels = train_test_split(
    vision_text_embeddings,
    all_labels,
    test_size=0.2,
    random_state=42,
    stratify=all_labels
)

In [62]:
import numpy as np

In [86]:
print("--- Before Oversampling ---")
# Get the unique classes and their counts in the training set
unique_classes, class_counts = np.unique(train_labels.numpy(), return_counts=True)
for cls, count in zip(unique_classes, class_counts):
    print(f"Class {cls}: {count} samples")

# Find the majority class count
majority_count = class_counts.max()
print(f"Majority class count: {majority_count}\n")

# Convert features to a list if they are not already
# (This makes appending easier)
train_labels_list = train_labels.tolist()
train_embeddings_list = train_embeddings


# Loop through each class to perform oversampling
for cls, count in zip(unique_classes, class_counts):
    if count < majority_count:
        # Calculate how many samples we need to add
        num_to_add = int((majority_count - count) // 1.2)
        print(f"Oversampling Class {cls}: Adding {num_to_add} samples...")

        # Get the indices of all samples belonging to the current minority class
        minority_indices = np.where(train_labels.numpy() == cls)[0]

        # Randomly choose indices from the minority class to duplicate
        # `replace=True` allows us to pick the same sample multiple times
        indices_to_add = np.random.choice(minority_indices, size=num_to_add, replace=True)

        # Add the chosen samples to our training lists
        for idx in indices_to_add:
            train_embeddings_list.append(train_embeddings[idx])
            train_labels_list.append(train_labels[idx])

print("\n--- After Oversampling ---")
# Verify the new class distribution
new_labels_tensor = torch.tensor(train_labels_list)
unique_classes, class_counts = np.unique(new_labels_tensor.numpy(), return_counts=True)
for cls, count in zip(unique_classes, class_counts):
    print(f"Class {cls}: {count} samples")

--- Before Oversampling ---
Class 0: 2185 samples
Class 1: 193 samples
Class 2: 2777 samples
Majority class count: 2777

Oversampling Class 0: Adding 493 samples...
Oversampling Class 1: Adding 2153 samples...

--- After Oversampling ---
Class 0: 2678 samples
Class 1: 2346 samples
Class 2: 2777 samples


In [88]:
BATCH_SIZE = 32

train_dataset = FeatureDataset(features=train_embeddings_list, labels=train_labels_list)
val_dataset = FeatureDataset(features=val_embeddings, labels=val_labels)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn_pad
)
#
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn_pad
)

In [89]:
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(train_labels_list),
    y=np.array(train_labels_list)
)
class_weights = torch.tensor(class_weights, dtype=torch.float)

print(f"Calculated class weights: {class_weights}")

Calculated class weights: tensor([0.9710, 1.1084, 0.9364])


In [90]:
os.environ['TOKENIZERS_PARALLELISM'] = "false"

In [110]:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint

In [111]:
L.seed_everything(42)

EMBEDDING_DIM = 1*2048
NUM_CLASSES = 3

classifier = EmbeddingClassifier(
    embedding_dim=EMBEDDING_DIM,
    num_classes=NUM_CLASSES,
    learning_rate=1e-4,
    dropout_p=0.4,
    class_weights=class_weights
)


model_checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='checkpoints',
    filename='best-model-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)

logger = WandbLogger(log_model="none")

trainer = L.Trainer(
    max_epochs=50,
    accelerator='auto',
    logger=logger,
    callbacks=[model_checkpoint_callback]
)

[rank: 0] Seed set to 42
/mimer/NOBACKUP/groups/drl_mps_planner/agents/agents_venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /apps/Arch/software/jupyter-server/2.7.2-GCCcore-12. ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [112]:
trainer.fit(classifier, train_dataloaders=train_loader, val_dataloaders=val_loader)

/mimer/NOBACKUP/groups/drl_mps_planner/agents/agents_venv/lib/python3.11/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/mimer/NOBACKUP/groups/drl_mps_planner/agents/agents_venv/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:658: Checkpoint directory /mimer/NOBACKUP/groups/drl_mps_planner/agents/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type               | Params | Mode 
-----------------------------------------------------------------
0 | attention_pooling | AttentionPooling   | 2.1 M  | train
1 | classifier        | Sequential         | 1.1 M  | train
2 | loss_fn           | CrossEntropyLoss   | 0      | train
3 | train_accuracy    | MulticlassAccuracy | 0      | train
4 | val_accuracy      | Multicl

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/mimer/NOBACKUP/groups/drl_mps_planner/agents/agents_venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/mimer/NOBACKUP/groups/drl_mps_planner/agents/agents_venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [113]:
best_model_path = model_checkpoint_callback.best_model_path
print(f"Path to the best model: {best_model_path}")

best_model = EmbeddingClassifier.load_from_checkpoint(best_model_path)
best_model.eval()

Path to the best model: /mimer/NOBACKUP/groups/drl_mps_planner/agents/checkpoints/best-model-epoch=22-val_acc=0.85.ckpt


EmbeddingClassifier(
  (attention_pooling): AttentionPooling(
    (attention_net): Sequential(
      (0): Linear(in_features=2048, out_features=1024, bias=True)
      (1): Tanh()
      (2): Linear(in_features=1024, out_features=1, bias=False)
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=2048, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Linear(in_features=512, out_features=3, bias=True)
  )
  (loss_fn): CrossEntropyLoss()
  (train_accuracy): MulticlassAccuracy()
  (val_accuracy): MulticlassAccuracy()
)

In [114]:
print("\nRunning test on validation data (as a proxy for a test set)...")
trainer.validate(dataloaders=val_loader)

Restoring states from the checkpoint path at /mimer/NOBACKUP/groups/drl_mps_planner/agents/checkpoints/best-model-epoch=22-val_acc=0.85.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /mimer/NOBACKUP/groups/drl_mps_planner/agents/checkpoints/best-model-epoch=22-val_acc=0.85.ckpt



Running test on validation data (as a proxy for a test set)...


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.5024335980415344, 'val_acc': 0.8487199544906616}]

In [118]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

# --- Generate Predictions for the Entire Validation Set ---
print("Generating predictions for the validation set...")
model.eval() # Set the model to evaluation mode

all_preds = []
all_labels = []

# Use the val_loader for efficient batch processing
for batch in val_loader:
    embeddings, labels, mask = batch
    # Move data to the same device as the model if using a GPU
    embeddings = embeddings.to(best_model.device)
    labels = labels.to(best_model.device)
    mask = mask.to(best_model.device)

    with torch.no_grad():
        logits = best_model(embeddings, mask)
        predictions = torch.argmax(logits, dim=1)
        
        # Append batch predictions and labels to our lists
        all_preds.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

Generating predictions for the validation set...


In [119]:
print("Plotting the confusion matrix...")

# Define the class names for the plot labels
class_names = ['False', 'True', 'Null']
cm = confusion_matrix(all_labels, all_preds)

# Plotting the confusion matrix using seaborn
plt.figure(figsize=(8, 6))
sns.heatmap(
    cm, 
    annot=True,     # Show the numbers in the cells
    fmt='d',        # Format as integers
    cmap='Blues',   # Color scheme
    xticklabels=class_names, 
    yticklabels=class_names
)

plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.savefig("./conf_matrix.png")

Plotting the confusion matrix...


In [127]:
import matplotlib.pyplot as plt
import numpy as np

# Data for the plot
labels = ['Prompt Based', 'Classification']
accuracies = [47, 85]

# Create a figure and a set of subplots
fig, ax = plt.subplots(figsize=(6, 5))

# Create the bar chart
bars = ax.bar(labels, accuracies, color=['skyblue', 'steelblue'])

# Add titles and labels for clarity
ax.set_ylabel('Accuracy (%)')
ax.set_title('Model Performance Comparison')
ax.set_ylim(0, 100) # Set y-axis to go from 0 to 100 for percentage context

# Add the accuracy value on top of each bar
for bar in bars:
    yval = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2.0, yval + 2, f'{yval}%', ha='center', va='bottom')

# Display the plot
plt.tight_layout()
plt.savefig("./acc_bar.png")