Install and import the required libraries

In [None]:
# Install the core dependencies (unpinned).
# NOTE(review): later cells also import torchvision, scikit-learn and matplotlib, which are not
# listed here. Consider `%pip install` with pinned versions, since newer torch releases change
# defaults such as torch.load's weights_only.
!pip install -q transformers accelerate datasets torch pillow


In [19]:
# Cell 2: Imports and setup
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration
from datasets import load_dataset
from PIL import Image
import requests
from tqdm.notebook import tqdm
import gc
import os

Load the LLaVA-1.5-7B model and its processor

In [None]:
# Cell 3: Load model and processor
model_id = "llava-hf/llava-1.5-7b-hf"

# fp16 weights, loaded with reduced peak CPU RAM, then moved to GPU index 0
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
).to(0)

# Chat template, tokenizer and image preprocessing that match the checkpoint
processor = AutoProcessor.from_pretrained(model_id)

Run a single POPE example to check that the model fits in GPU memory and produces a sensible answer

In [None]:
# Cell 4: Load and process POPE example
# Smoke test: one POPE question through the model, checking memory and answer quality.
dataset = load_dataset("lmms-lab/POPE", split="test")
example = dataset[0]

# The image is already a PIL Image object, so nothing needs to be downloaded
raw_image = example['image']

# Single user turn: the image placeholder followed by the POPE question (free generation)
conversation_positive = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": example['question']},
        ],
    },
]

# Apply chat template; add_generation_prompt appends the assistant prefix
prompt = processor.apply_chat_template(conversation_positive, add_generation_prompt=True)

# Process inputs
inputs = processor(
    images=raw_image,
    text=prompt,
    return_tensors='pt'
).to(0, torch.float16)

# Generate output
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
# generate() returns prompt + continuation; decode only the newly generated tokens
prompt_length = inputs['input_ids'].shape[1]
response = processor.decode(output[0][prompt_length:], skip_special_tokens=True)

print("Question:", example['question'])
print("Model response:", response)
print("Ground truth:", example['answer'])

In [None]:
# Cell 5 (Optional): Display the image
# Renders the POPE image used in the smoke-test cell (raw_image) inline in the notebook.
from IPython.display import display
display(raw_image)

Extract hidden-state embeddings for contrastive (forced "Yes" / forced "No") pairs from the POPE dataset

In [6]:
# Enable performance optimizations
# TF32 speeds up float32 matmuls on Ampere+ GPUs at slightly reduced precision;
# cudnn.benchmark autotunes kernels for repeated input shapes.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Clear initial CUDA memory. Collect unreachable Python objects first, so the GPU tensors they
# held are freed before the allocator's cached blocks are released.
gc.collect()
torch.cuda.empty_cache()

def _pope_forced_answer_states(model, processor, images, questions, answer):
    """Run one forced-answer forward pass and return per-example, per-layer last-token states.

    Args:
        model: LLaVA model called as ``model(**inputs, output_hidden_states=True)``.
        processor: matching processor (chat template + image/text preprocessing).
        images: one image per question.
        questions: POPE question strings.
        answer: text appended after " ASSISTANT: " (e.g. "Yes" or "No").

    Returns:
        A list with one entry per question. Each entry is a list holding one numpy vector per
        element of ``outputs.hidden_states``, so ``entry[layer_idx]`` selects a layer.
    """
    conversations = [
        [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{question} ASSISTANT: {answer}"},
            ],
        }]
        for question in questions
    ]
    prompts = [processor.apply_chat_template(conv, add_generation_prompt=False) for conv in conversations]
    inputs = processor(
        images=images,
        text=prompts,
        return_tensors='pt',
        padding=True,
        truncation=True
    ).to('cuda', torch.float16)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # The padding side is tokenizer-dependent. Position -1 is a pad token for shorter rows
        # under right padding, so locate each row's last real token from the attention mask.
        mask = inputs['attention_mask'].long()
        last = mask.shape[1] - 1 - mask.flip(dims=[1]).argmax(dim=1)
        rows = torch.arange(mask.shape[0], device=mask.device)
        per_layer = [
            layer[rows.to(layer.device), last.to(layer.device), :].cpu().numpy()
            for layer in outputs.hidden_states
        ]

    # Clear memory
    del inputs, outputs
    torch.cuda.empty_cache()
    gc.collect()

    # Re-group from per-layer (batch, hidden) arrays into per-example lists of layer vectors
    return [[layer_states[i] for layer_states in per_layer] for i in range(len(prompts))]


def process_batch(model, processor, batch_items, batch_size=8):
    """Extract last-token hidden states for forced "Yes" and forced "No" answers of a POPE batch.

    Each question is run twice with the same image: once with " ASSISTANT: Yes" appended to the
    user text, once with " ASSISTANT: No".

    Args:
        model: LLaVA model.
        processor: matching processor.
        batch_items: iterable of rows with 'question' and 'image' fields.
        batch_size: unused; kept for interface compatibility.

    Returns:
        (pos_hidden_batch, neg_hidden_batch): one entry per example, each a list of per-layer
        numpy vectors taken at the example's final non-padding token. Later cells read this
        layout as ``states[layer_idx]``.
    """
    items = list(batch_items)
    images = [item['image'] for item in items]
    questions = [item['question'] for item in items]

    pos_hidden_batch = _pope_forced_answer_states(model, processor, images, questions, "Yes")
    neg_hidden_batch = _pope_forced_answer_states(model, processor, images, questions, "No")
    return pos_hidden_batch, neg_hidden_batch

def extract_hidden_states(model, processor, dataset, checkpoint_dir='checkpoints', batch_size=16):
    """Run ``process_batch`` over the POPE ``dataset`` with periodic, resumable checkpoints.

    Every 10 batches the accumulated states are saved to
    ``checkpoint_dir/pope_checkpoint_{last_processed_idx}.pt``. On an exception the state is
    saved to ``pope_checkpoint_error_{batch_idx}.pt`` and the exception is re-raised. On start,
    the most recently written ``pope_checkpoint_*.pt`` file is loaded and extraction resumes
    right after its ``last_processed_idx``.

    Args:
        model, processor: passed through to ``process_batch``.
        dataset: object with ``len()`` and ``select(range)`` (e.g. a datasets.Dataset).
        checkpoint_dir: directory for checkpoint files (created if missing).
        batch_size: examples per ``process_batch`` call.

    Returns:
        (positive_hidden_states, negative_hidden_states): concatenation of what
        ``process_batch`` returned for every batch, in dataset order.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    positive_hidden_states = []
    negative_hidden_states = []

    total_examples = len(dataset)
    num_batches = (total_examples + batch_size - 1) // batch_size
    print(f"Total examples: {total_examples}, Number of batches: {num_batches}")

    # Load last checkpoint if exists
    last_processed_idx = -1
    checkpoint_files = [
        os.path.join(checkpoint_dir, f)
        for f in os.listdir(checkpoint_dir)
        if f.startswith('pope_checkpoint_') and f.endswith('.pt')
    ]
    if checkpoint_files:
        # The newest file holds the furthest progress. A lexicographic sort would rank
        # "..._359" above "..._1079" and every "..._error_*" file above the regular ones.
        latest_path = max(checkpoint_files, key=os.path.getmtime)
        # Self-produced file holding numpy arrays; torch>=2.6 defaults to weights_only=True,
        # which refuses them. Only load trusted checkpoints this way.
        latest_checkpoint = torch.load(latest_path, weights_only=False)
        positive_hidden_states = latest_checkpoint['positive_hidden_states']
        negative_hidden_states = latest_checkpoint['negative_hidden_states']
        last_processed_idx = latest_checkpoint['last_processed_idx']
        print(f"Resuming from example {last_processed_idx + 1}")

    batch_idx = None
    try:
        # Start exactly after the last processed example, even if batch_size changed since
        for start_idx in tqdm(range(last_processed_idx + 1, total_examples, batch_size)):
            batch_idx = start_idx // batch_size
            end_idx = min(start_idx + batch_size, total_examples)
            batch_items = dataset.select(range(start_idx, end_idx))

            pos_hidden, neg_hidden = process_batch(model, processor, batch_items, batch_size)

            positive_hidden_states.extend(pos_hidden)
            negative_hidden_states.extend(neg_hidden)
            # Track progress right after the states are stored, so an error checkpoint never
            # claims fewer (or more) examples than it actually contains
            last_processed_idx = end_idx - 1

            # Save checkpoint every N batches
            if (batch_idx + 1) % 10 == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f'pope_checkpoint_{end_idx-1}.pt')
                print(f"\nSaving checkpoint at example {end_idx}/{total_examples}")
                torch.save({
                    'positive_hidden_states': positive_hidden_states,
                    'negative_hidden_states': negative_hidden_states,
                    'last_processed_idx': last_processed_idx
                }, checkpoint_path)

    except Exception as e:
        print(f"Error occurred at batch {batch_idx}: {str(e)}")
        # Save checkpoint on error
        checkpoint_path = os.path.join(checkpoint_dir, f'pope_checkpoint_error_{batch_idx}.pt')
        torch.save({
            'positive_hidden_states': positive_hidden_states,
            'negative_hidden_states': negative_hidden_states,
            'last_processed_idx': last_processed_idx
        }, checkpoint_path)
        raise

    return positive_hidden_states, negative_hidden_states

In [None]:
# Load dataset
print("Loading dataset...")
dataset = load_dataset("lmms-lab/POPE", split="test")

# Forced "Yes"/"No" hidden states for every POPE test question, 36 examples per forward pass
print("Starting extraction...")
pos_hidden_states, neg_hidden_states = extract_hidden_states(
    model,
    processor,
    dataset,
    batch_size=36,
)

# Save final results
print("Saving final results...")
torch.save(
    dict(
        positive_hidden_states=pos_hidden_states,
        negative_hidden_states=neg_hidden_states,
    ),
    'pope_contrast_hidden_states_final.pt',
)

print("Done!")

Apply a PGD attack to the images and extract contrastive-pair embeddings from the perturbed inputs

In [9]:
from torchvision import transforms

def pgd_attack(images, batch_items, model, processor, epsilon=8/255, alpha=2/255, num_iter=10):
    """
    Perform an L-infinity PGD attack on a batch of images and return perturbed PIL images.

    The perturbation ``delta`` lives in [0, 1] pixel space. Each step rebuilds ``pixel_values``
    with the image processor's mean/std normalisation inside the autograd graph, so gradients
    actually reach ``delta``. Round-tripping through uint8 PIL images and the processor would
    detach the graph.

    Objective: ascend ``loss = -mean(final-layer hidden states)``, i.e. push that mean down.
    NOTE(review): this surrogate is carried over unchanged and is not tied to the Yes/No
    answer. Confirm it is the intended attack target.

    Args:
        images: sequence of PIL images, one per item in ``batch_items``.
        batch_items: rows with a 'question' field; the prompt is
            "<question> ASSISTANT: Yes" in the user turn.
        model: LLaVA model; must accept ``input_ids``, ``attention_mask``, ``pixel_values``.
        processor: matching processor (``image_processor.size['shortest_edge']``,
            ``image_mean``, ``image_std`` are read).
        epsilon: L-infinity radius of the perturbation in [0, 1] pixel units.
        alpha: step size per iteration.
        num_iter: number of PGD steps.

    Returns:
        List of RGB PIL images (target_size x target_size), one per input image.
    """
    image_processor = processor.image_processor
    target_size = image_processor.size["shortest_edge"]
    reference_param = next(model.parameters())
    device, model_dtype = reference_param.device, reference_param.dtype

    # Clean images as float32 in [0, 1], channel-first, resized to target_size x target_size.
    # NOTE(review): assumes the processor's own resize/crop is then a no-op (size shows
    # {'shortest_edge': 336}); verify crop_size matches.
    clean_arrays = []
    for img in images:
        img = img.convert("RGB")  # grayscale/RGBA inputs would otherwise break the channel axis
        if img.size != (target_size, target_size):
            img = img.resize((target_size, target_size))
        clean_arrays.append(np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0)
    x = torch.from_numpy(np.stack(clean_arrays)).to(device)

    mean = torch.tensor(image_processor.image_mean, dtype=torch.float32, device=device).view(1, -1, 1, 1)
    std = torch.tensor(image_processor.image_std, dtype=torch.float32, device=device).view(1, -1, 1, 1)

    # Build text inputs once (prompts must be strings, not conversation lists)
    prompts = [
        processor.apply_chat_template(
            [{
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": f"{item['question']} ASSISTANT: Yes"},
                ],
            }],
            add_generation_prompt=False,
        )
        for item in batch_items
    ]
    # The clean images are passed only so the processor expands each <image> placeholder;
    # its pixel values are discarded and rebuilt differentiably below.
    text_inputs = processor(images=list(images), text=prompts, return_tensors="pt", padding=True)
    input_ids = text_inputs["input_ids"].to(device)
    attention_mask = text_inputs["attention_mask"].to(device)

    delta = torch.zeros_like(x, requires_grad=True)

    # PGD attack loop
    for _ in range(num_iter):
        perturbed = torch.clamp(x + delta, 0, 1)
        pixel_values = ((perturbed - mean) / std).to(model_dtype)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        loss = -outputs.hidden_states[-1].float().mean()
        # Gradient w.r.t. delta only, so nothing accumulates in the model weights' .grad
        (grad,) = torch.autograd.grad(loss, delta)

        with torch.no_grad():
            delta.add_(alpha * grad.sign())
            delta.clamp_(-epsilon, epsilon)
            # keep x + delta inside the valid [0, 1] image range
            delta.copy_(torch.clamp(x + delta, 0, 1) - x)

        # Clear memory
        del outputs, loss, grad, pixel_values, perturbed
        torch.cuda.empty_cache()

    perturbed_final = torch.clamp(x + delta.detach(), 0, 1)
    # Round (not truncate) when quantising back to 8-bit pixels
    pixels = (perturbed_final * 255).round().to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
    return [Image.fromarray(arr) for arr in pixels]
def get_forced_embedding_prompt(question, answer):
    """Build a one-turn chat (image + question) whose user text ends with a forced answer.

    The answer is embedded as "<question> ASSISTANT: <answer>" inside the user message, so the
    last token of the templated prompt belongs to the forced answer.
    """
    text_part = {"type": "text", "text": f"{question} ASSISTANT: {answer}"}
    user_turn = {"role": "user", "content": [{"type": "image"}, text_part]}
    return [user_turn]

def get_generation_prompt(question):
    """Build a one-turn chat (image + question) for free-form answer generation."""
    text_part = {"type": "text", "text": f"{question}"}
    return [{"role": "user", "content": [{"type": "image"}, text_part]}]

def _pgd_forced_answer_states(model, processor, images, questions, answer):
    """One forced-answer forward pass on (perturbed) images; per-example, per-layer states.

    Returns a list with one entry per question; each entry is a list with one numpy vector per
    element of ``outputs.hidden_states``, taken at the row's last non-padding token.
    """
    conversations = [
        [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{question} ASSISTANT: {answer}"},
            ],
        }]
        for question in questions
    ]
    prompts = [processor.apply_chat_template(conv, add_generation_prompt=False)
               for conv in conversations]
    inputs = processor(
        images=images,
        text=prompts,
        return_tensors='pt',
        padding=True
    ).to('cuda', torch.float16)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # Position -1 is a pad token for shorter rows under right padding; use the mask instead
        mask = inputs['attention_mask'].long()
        last = mask.shape[1] - 1 - mask.flip(dims=[1]).argmax(dim=1)
        rows = torch.arange(mask.shape[0], device=mask.device)
        per_layer = [
            layer[rows.to(layer.device), last.to(layer.device), :].cpu().numpy()
            for layer in outputs.hidden_states
        ]

    del inputs, outputs
    torch.cuda.empty_cache()

    # per-layer (batch, hidden) arrays -> per-example lists of layer vectors
    return [[layer_states[i] for layer_states in per_layer] for i in range(len(prompts))]


def process_pope_batch_with_pgd(model, processor, batch_items, batch_size=8):
    """Attack a POPE batch's images with PGD, then extract forced Yes/No last-token states.

    Args:
        model, processor: LLaVA model and matching processor.
        batch_items: iterable of rows with 'question' and 'image' fields.
        batch_size: unused; kept for interface compatibility.

    Returns:
        (pos_hidden_batch, neg_hidden_batch, generations): one entry per example, each a list of
        per-layer numpy vectors. ``generations`` is always an empty list; no free-form
        generation is run here.
    """
    items = list(batch_items)
    images = [item['image'] for item in items]
    questions = [item['question'] for item in items]

    perturbed_images = pgd_attack(images, items, model, processor)

    pos_hidden_batch = _pgd_forced_answer_states(model, processor, perturbed_images, questions, "Yes")
    neg_hidden_batch = _pgd_forced_answer_states(model, processor, perturbed_images, questions, "No")
    generations = []
    return pos_hidden_batch, neg_hidden_batch, generations

Extract contrastive-pair (forced "Positive" / forced "Negative") embeddings for the IMDB dataset

In [10]:
def load_imdb_dataset():
    """Load the IMDB test split and return a reproducible random subset of at most 3000 reviews."""
    print("Loading IMDB dataset...")
    imdb_test = load_dataset("imdb", split="test")
    n_keep = min(len(imdb_test), 3000)
    # Fixed shuffle seed so every call selects the same subset in the same order
    shuffled = imdb_test.shuffle(seed=42)
    return shuffled.select(range(n_keep))

def get_imdb_forced_prompt(review, sentiment):
    """Build a text-only chat asking for the review's sentiment, with the answer forced in.

    The user text ends with " ASSISTANT: <sentiment>", so the templated prompt's last tokens are
    the forced sentiment word.
    """
    question = f"Is this movie review positive or negative? Review: {review} ASSISTANT: {sentiment}"
    return [{"role": "user", "content": [{"type": "text", "text": question}]}]

def get_imdb_generation_prompt(review):
    """Build a text-only chat asking for the review's sentiment, for free generation."""
    question = f"Is this movie review positive or negative? Review: {review}"
    return [{"role": "user", "content": [{"type": "text", "text": question}]}]

def _imdb_forced_answer_states(model, processor, reviews, sentiment):
    """One text-only forced-sentiment pass; per-example lists of per-layer last-token states."""
    prompts = [
        processor.apply_chat_template(get_imdb_forced_prompt(review, sentiment), add_generation_prompt=False)
        for review in reviews
    ]
    inputs = processor(
        text=prompts,
        return_tensors='pt',
        padding=True,
        truncation=True
    ).to('cuda', torch.float16)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # Position -1 is a pad token for shorter rows under right padding; use the mask instead
        mask = inputs['attention_mask'].long()
        last = mask.shape[1] - 1 - mask.flip(dims=[1]).argmax(dim=1)
        rows = torch.arange(mask.shape[0], device=mask.device)
        per_layer = [
            layer[rows.to(layer.device), last.to(layer.device), :].cpu().numpy()
            for layer in outputs.hidden_states
        ]

    del inputs, outputs
    torch.cuda.empty_cache()

    # per-layer (batch, hidden) arrays -> per-example lists of layer vectors
    return [[layer_states[i] for layer_states in per_layer] for i in range(len(prompts))]


def process_imdb_batch(model, processor, batch_items, batch_size=8):
    """Generate free answers and extract forced Positive/Negative last-token states for IMDB.

    Args:
        model, processor: LLaVA model and matching processor (used text-only here).
        batch_items: iterable of rows with a 'text' field (the review).
        batch_size: unused; kept for interface compatibility.

    Returns:
        (pos_hidden_batch, neg_hidden_batch, generations): the first two hold one entry per
        review, each a list of per-layer numpy vectors taken at the review's final non-padding
        token. ``generations`` holds the decoded newly generated text (prompt excluded).
    """
    reviews = [item['text'] for item in batch_items]
    generations = []

    # First, get generations with free responses
    gen_prompts = [processor.apply_chat_template(get_imdb_generation_prompt(review), add_generation_prompt=True)
                   for review in reviews]

    # Batched decoder-only generation needs left padding; with right padding, shorter prompts
    # are continued after their pad tokens. Restore the tokenizer's setting afterwards.
    tokenizer = processor.tokenizer
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        gen_inputs = processor(
            text=gen_prompts,
            return_tensors='pt',
            padding=True,
            truncation=True
        ).to('cuda', torch.float16)
    finally:
        tokenizer.padding_side = original_padding_side

    with torch.no_grad():
        gen_outputs = model.generate(
            **gen_inputs,
            max_new_tokens=20,
            pad_token_id=tokenizer.pad_token_id,
            return_dict_in_generate=True
        )

    # sequences = prompt + continuation; decode only the continuation
    prompt_length = gen_inputs['input_ids'].shape[1]
    for output_ids in gen_outputs.sequences:
        generations.append(processor.decode(output_ids[prompt_length:], skip_special_tokens=True))

    del gen_inputs, gen_outputs
    torch.cuda.empty_cache()

    # Now get embeddings with forced answers
    pos_hidden_batch = _imdb_forced_answer_states(model, processor, reviews, "Positive")
    neg_hidden_batch = _imdb_forced_answer_states(model, processor, reviews, "Negative")

    return pos_hidden_batch, neg_hidden_batch, generations

In [11]:
def extract_hidden_states(model, processor, dataset, process_batch_fn, checkpoint_dir='checkpoints', batch_size=16):
    """Run ``process_batch_fn`` over ``dataset`` in batches with periodic, resumable checkpoints.

    NOTE: this redefines the POPE-only ``extract_hidden_states`` from an earlier cell with a
    different signature (adds ``process_batch_fn``) and a three-value return.

    Every 10 batches the accumulated results are saved to
    ``checkpoint_dir/checkpoint_{last_processed_idx}.pt``. On an exception they are saved to
    ``checkpoint_error_{batch_idx}.pt`` and the exception is re-raised. On start, the most
    recently written ``checkpoint_*.pt`` is loaded and work resumes after its
    ``last_processed_idx``.

    Args:
        model, processor: passed through to ``process_batch_fn``.
        dataset: object with ``len()`` and ``select(range)`` (e.g. a datasets.Dataset).
        process_batch_fn: ``fn(model, processor, batch_items, batch_size)`` returning
            ``(pos_hidden, neg_hidden, generations)`` lists for the batch.
        checkpoint_dir: directory for checkpoint files (created if missing).
        batch_size: examples per ``process_batch_fn`` call.

    Returns:
        (positive_hidden_states, negative_hidden_states, generations), concatenated in dataset
        order.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    positive_hidden_states = []
    negative_hidden_states = []
    generations = []

    total_examples = len(dataset)
    num_batches = (total_examples + batch_size - 1) // batch_size
    print(f"Total examples: {total_examples}, Number of batches: {num_batches}")

    # Load last checkpoint if exists
    last_processed_idx = -1
    checkpoint_files = [
        os.path.join(checkpoint_dir, f)
        for f in os.listdir(checkpoint_dir)
        if f.startswith('checkpoint_') and f.endswith('.pt')
    ]
    if checkpoint_files:
        # Newest file = furthest progress; lexicographic order ranks "..._359" above "..._1079"
        # and every "..._error_*" file above all regular ones.
        latest_path = max(checkpoint_files, key=os.path.getmtime)
        # Self-produced file holding numpy arrays; torch>=2.6 defaults to weights_only=True,
        # which refuses them. Only load trusted checkpoints this way.
        latest_checkpoint = torch.load(latest_path, weights_only=False)
        positive_hidden_states = latest_checkpoint['positive_hidden_states']
        negative_hidden_states = latest_checkpoint['negative_hidden_states']
        generations = latest_checkpoint.get('generations', [])
        last_processed_idx = latest_checkpoint['last_processed_idx']
        print(f"Resuming from example {last_processed_idx + 1}")

    batch_idx = None
    try:
        # Start exactly after the last processed example, even if batch_size changed since
        for start_idx in tqdm(range(last_processed_idx + 1, total_examples, batch_size)):
            batch_idx = start_idx // batch_size
            end_idx = min(start_idx + batch_size, total_examples)
            batch_items = dataset.select(range(start_idx, end_idx))

            pos_hidden, neg_hidden, gens = process_batch_fn(model, processor, batch_items, batch_size)

            positive_hidden_states.extend(pos_hidden)
            negative_hidden_states.extend(neg_hidden)
            generations.extend(gens)
            # Update progress only once the batch's results are stored
            last_processed_idx = end_idx - 1

            # Save checkpoint every N batches
            if (batch_idx + 1) % 10 == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{end_idx-1}.pt')
                print(f"\nSaving checkpoint at example {end_idx}/{total_examples}")
                torch.save({
                    'positive_hidden_states': positive_hidden_states,
                    'negative_hidden_states': negative_hidden_states,
                    'generations': generations,
                    'last_processed_idx': last_processed_idx
                }, checkpoint_path)

    except Exception as e:
        print(f"Error occurred at batch {batch_idx}: {str(e)}")
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_error_{batch_idx}.pt')
        torch.save({
            'positive_hidden_states': positive_hidden_states,
            'negative_hidden_states': negative_hidden_states,
            'generations': generations,
            'last_processed_idx': last_processed_idx
        }, checkpoint_path)
        raise

    return positive_hidden_states, negative_hidden_states, generations

In [None]:
# For POPE with PGD:
# extract_hidden_states returns three values (positive states, negative states, generations);
# unpacking four raised "ValueError: not enough values to unpack".
pope_pos_hidden, pope_neg_hidden, pope_gens = extract_hidden_states(
    model, processor, dataset,
    process_batch_fn=process_pope_batch_with_pgd,
    checkpoint_dir='pope_pgd_checkpoints',
    batch_size=1
)

torch.save({
    'positive_hidden_states': pope_pos_hidden,
    'negative_hidden_states': pope_neg_hidden,
    'generations': pope_gens,
}, 'pope_contrast_hidden_states_pgd_final.pt')

In [75]:
# Inspect the image-processor settings: pgd_attack resizes to size["shortest_edge"], and
# do_resize shows whether the processor itself also resizes its inputs.
print(processor.image_processor.size)
print(processor.image_processor.do_resize)

{'shortest_edge': 336}
True


In [None]:
imdb_dataset = load_imdb_dataset()

# For IMDB:
# extract_hidden_states returns three values (positive states, negative states, generations);
# unpacking four raised "ValueError: not enough values to unpack".
imdb_pos_hidden, imdb_neg_hidden, imdb_gens = extract_hidden_states(
    model, processor, imdb_dataset,
    process_batch_fn=process_imdb_batch,
    checkpoint_dir='imdb_checkpoints',
    batch_size=36
)

torch.save({
    'positive_hidden_states': imdb_pos_hidden,
    'negative_hidden_states': imdb_neg_hidden,
    'generations': imdb_gens,
}, 'imdb_contrast_hidden_states_final.pt')

Train and evaluate CCS (Contrast-Consistent Search) on the POPE dataset

In [29]:
import copy

class CCS(object):
    """Contrast-Consistent Search (CCS) probe trainer.

    Fits a probe p(x) in [0, 1] on paired hidden states (x0, x1) without labels. The loss
    rewards consistency, p(x0) ~= 1 - p(x1), and penalises the uninformative solution
    min(p(x0), p(x1)) > 0. ``get_acc`` uses labels only to score the trained probe.

    Args:
        x0, x1: array-likes of shape (n_examples, hidden_dim), the two halves of each pair.
            Each is mean-centred independently (optionally also divided by its per-feature std).
        nepochs: maximum epochs per training attempt (early stop after 20 flat epochs).
        ntries: number of random restarts in ``repeated_train``.
        lr: base learning rate (scaled by batch_size / 32 in ``train``).
        batch_size: minibatch size; -1 means full batch.
        verbose: show progress bars and per-trial losses.
        device: torch device for the probe and the data.
        linear: linear probe if True, otherwise a one-hidden-layer MLP.
        weight_decay: AdamW weight decay.
        var_normalize: also scale features to unit variance.
    """

    def __init__(self, x0, x1, nepochs=1000, ntries=10, lr=1e-3, batch_size=-1,
                 verbose=False, device="cuda", linear=True, weight_decay=0.01, var_normalize=False):
        # data
        self.var_normalize = var_normalize
        self.x0 = self.normalize(x0)
        self.x1 = self.normalize(x1)
        self.d = self.x0.shape[-1]

        # training
        self.nepochs = nepochs
        self.ntries = ntries
        self.lr = lr
        self.verbose = verbose
        self.device = device
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        # Mixed precision and the large-GPU batch heuristic apply only on a usable CUDA device,
        # which also lets the probe train on CPU (GPU 0 was previously queried unconditionally).
        self._use_cuda_amp = torch.device(device).type == "cuda" and torch.cuda.is_available()

        # probe
        self.linear = linear
        self.probe = self.initialize_probe()
        self.best_probe = copy.deepcopy(self.probe)

        self.scaler = torch.cuda.amp.GradScaler(enabled=self._use_cuda_amp)

        # Progress tracking; train_losses keeps growing across train() calls and restarts
        self.train_losses = []
        self.best_loss = float('inf')

    def initialize_probe(self):
        """Create a freshly initialised probe on ``self.device`` and return it."""
        if self.linear:
            probe = nn.Sequential(nn.Linear(self.d, 1), nn.Sigmoid())
        else:
            # Use a user-defined MLPProbe if one exists; otherwise fall back to an equivalent
            # one-hidden-layer MLP (MLPProbe is not defined anywhere in this notebook).
            mlp_cls = globals().get("MLPProbe")
            if mlp_cls is not None:
                probe = mlp_cls(self.d)
            else:
                probe = nn.Sequential(nn.Linear(self.d, 100), nn.ReLU(), nn.Linear(100, 1), nn.Sigmoid())
        probe.to(self.device)
        return probe

    def normalize(self, x):
        """Mean-centre ``x`` per feature (and scale to unit std if ``var_normalize``).

        Works in float32, so float16 hidden states and plain lists are both accepted.
        """
        x = np.asarray(x, dtype=np.float32)
        normalized_x = x - x.mean(axis=0, keepdims=True)
        if self.var_normalize:
            std = normalized_x.std(axis=0, keepdims=True)
            # Constant features would otherwise become NaN and poison training
            normalized_x = normalized_x / np.where(std == 0, 1.0, std)
        return normalized_x

    def get_tensor_data(self):
        """Return the normalised training pairs as float32 tensors on ``self.device``."""
        x0 = torch.tensor(self.x0, dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.x1, dtype=torch.float, requires_grad=False, device=self.device)
        return x0, x1

    def get_loss(self, p0, p1):
        """CCS loss: informativeness term min(p0, p1)^2 plus consistency term (p0 - (1 - p1))^2."""
        informative_loss = (torch.min(p0, p1)**2).mean(0)
        consistent_loss = ((p0 - (1-p1))**2).mean(0)
        return informative_loss + consistent_loss

    def get_acc(self, x0_test, x1_test, y_test):
        """Accuracy of ``best_probe`` on test pairs, sign-agnostic (max of acc and 1 - acc).

        The test pairs are normalised with their own statistics, not the training ones.
        """
        x0 = torch.tensor(self.normalize(x0_test), dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.normalize(x1_test), dtype=torch.float, requires_grad=False, device=self.device)
        with torch.no_grad():
            p0, p1 = self.best_probe(x0), self.best_probe(x1)
        avg_confidence = 0.5*(p0 + (1-p1))
        predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
        acc = (predictions == y_test).mean()
        return max(acc, 1 - acc)

    def train(self):
        """Train ``self.probe`` once and return the lowest mean epoch loss seen.

        Uses AdamW (lr scaled by batch_size / 32), ReduceLROnPlateau on the epoch loss, early
        stopping after 20 epochs without improvement, and mixed precision on CUDA.
        """
        x0, x1 = self.get_tensor_data()
        n_examples = len(x0)

        no_improve_epochs = 0
        self.best_loss = float('inf')
        batch_size = n_examples if self.batch_size == -1 else self.batch_size
        # Use 4x larger batches on GPUs with >= 40 GB memory (e.g. A100)
        if self._use_cuda_amp and torch.cuda.get_device_properties(0).total_memory >= 40e9:
            batch_size = min(batch_size * 4, n_examples)
        # Never exceed the dataset size; otherwise nbatches == 0 and every epoch loss is NaN
        batch_size = max(1, min(batch_size, n_examples))

        nbatches = n_examples // batch_size

        # Initialize optimizer with larger learning rate for bigger batches
        base_lr = self.lr * (batch_size / 32)
        optimizer = torch.optim.AdamW(self.probe.parameters(), lr=base_lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        autocast_device = "cuda" if self._use_cuda_amp else "cpu"

        pbar = tqdm(range(self.nepochs), desc="Training CCS", disable=not self.verbose)
        for epoch in pbar:
            epoch_losses = []
            # Shuffle data each epoch
            permutation = torch.randperm(n_examples, device=self.device)
            x0_shuffled = x0[permutation]
            x1_shuffled = x1[permutation]

            for j in range(nbatches):
                start_idx = j * batch_size
                end_idx = start_idx + batch_size
                x0_batch = x0_shuffled[start_idx:end_idx]
                x1_batch = x1_shuffled[start_idx:end_idx]

                with torch.autocast(device_type=autocast_device, enabled=self._use_cuda_amp):
                    p0, p1 = self.probe(x0_batch), self.probe(x1_batch)
                    loss = self.get_loss(p0, p1)

                # Scaled backprop (the scaler is a pass-through when AMP is disabled)
                optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(optimizer)
                self.scaler.update()

                epoch_losses.append(loss.detach().cpu().item())

            avg_epoch_loss = np.mean(epoch_losses)
            self.train_losses.append(avg_epoch_loss)

            scheduler.step(avg_epoch_loss)

            pbar.set_postfix({
                'loss': f'{avg_epoch_loss:.4f}',
                'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'
            })

            # Early stopping
            if avg_epoch_loss < self.best_loss:
                self.best_loss = avg_epoch_loss
                no_improve_epochs = 0
            else:
                no_improve_epochs += 1

            if no_improve_epochs >= 20:
                if self.verbose:
                    print(f"Early stopping at epoch {epoch}")
                break

        return self.best_loss

    def repeated_train(self):
        """Train ``ntries`` freshly initialised probes; keep the lowest-loss one as ``best_probe``."""
        best_overall_loss = float('inf')

        for trial in tqdm(range(self.ntries), desc="Training attempts"):
            # Each restart needs a new random probe. Previously the freshly created probe was
            # discarded, so every "restart" kept training the same weights.
            self.probe = self.initialize_probe()
            loss = self.train()

            if loss < best_overall_loss:
                best_overall_loss = loss
                self.best_probe = copy.deepcopy(self.probe)

            if self.verbose:
                print(f"Trial {trial + 1}/{self.ntries}: Loss = {loss:.4f}")

        return best_overall_loss

In [None]:
# First, let's look at what's in the POPE dataset
dataset = load_dataset("lmms-lab/POPE", split="test")
print("POPE Dataset example:")
print(dataset[0])

# Count actual Yes/No distribution in POPE.
# Read only the answer column (iterating rows would decode every image), and compare
# case-insensitively so the counts do not depend on how the answers are capitalised.
pope_answers = [str(answer).strip().lower() for answer in dataset['answer']]
print(pope_answers[:20])  # a short sample instead of all answers
yes_count = pope_answers.count('yes')
no_count = pope_answers.count('no')
print(f"\nPOPE Dataset distribution:")
print(f"Yes answers: {yes_count}")
print(f"No answers: {no_count}")

In [None]:
def load_vision_language_embeddings(checkpoint_path, dataset, layer_idx=-1):
    """
    Load vision-language embeddings from checkpoint and match with ground truth labels
    Args:
        checkpoint_path: path to the checkpoint file; expected to hold
            'positive_hidden_states' / 'negative_hidden_states', each a list with one entry per
            example where ``entry[layer]`` is that layer's hidden-state vector
        dataset: POPE dataset (or any sequence of rows with an 'answer' field), in the same
            order the embeddings were extracted
        layer_idx: which layer to use (default: last layer)
    Returns:
        (pos_hidden, neg_hidden, ground_truth): two (n_examples, hidden_dim) arrays and an int
        array that is 1 where the answer is "yes" (case-insensitive), else 0
    Raises:
        AssertionError: if the Yes/No arrays differ in shape
        ValueError: if the number of embeddings differs from the number of labels
    """
    # Self-produced file holding numpy arrays; torch>=2.6 defaults to weights_only=True, which
    # refuses them. Only load trusted files this way.
    checkpoint = torch.load(checkpoint_path, weights_only=False)

    pos_hidden_states = checkpoint['positive_hidden_states']
    neg_hidden_states = checkpoint['negative_hidden_states']

    pos_hidden = np.array([states[layer_idx] for states in pos_hidden_states])
    neg_hidden = np.array([states[layer_idx] for states in neg_hidden_states])

    assert pos_hidden.shape == neg_hidden.shape
    print(f"Loaded embeddings with shape: {pos_hidden.shape}")

    if hasattr(dataset, 'column_names'):
        answers = dataset['answer']  # column read: avoids decoding every image
    else:
        answers = [example['answer'] for example in dataset]
    ground_truth = np.array(
        [1 if str(answer).strip().lower() == 'yes' else 0 for answer in answers], dtype=int
    )

    # A silent length mismatch would pair embeddings with the wrong labels
    if len(ground_truth) != len(pos_hidden):
        raise ValueError(
            f"{len(pos_hidden)} embeddings but {len(ground_truth)} labels; "
            "the checkpoint does not line up with the dataset"
        )

    print(f"Number of total examples: {len(ground_truth)}")
    print(f"Distribution of labels: {np.bincount(ground_truth)}")

    return pos_hidden, neg_hidden, ground_truth

# Load data
dataset = load_dataset("lmms-lab/POPE", split="test")
checkpoint_path = "pope_contrast_hidden_states_final.pt"
pos_hs, neg_hs, ground_truth = load_vision_language_embeddings(checkpoint_path, dataset)

# Split data with explicit train/test sets: 20% train, the remaining 80% test
np.random.seed(42)  # For reproducibility
all_indices = np.arange(len(pos_hs))
train_size = len(pos_hs) // 5
train_indices = np.random.choice(all_indices, train_size, replace=False)
# setdiff1d returns a sorted, deterministic array; list(set(...)) ordering is an
# implementation detail
test_indices = np.setdiff1d(all_indices, train_indices)

# Training data
pos_hs_train = pos_hs[train_indices]
neg_hs_train = neg_hs[train_indices]

# Test data
pos_hs_test = pos_hs[test_indices]
neg_hs_test = neg_hs[test_indices]
y_test = ground_truth[test_indices]

print(f"\nTraining set size: {len(train_indices)}")
print(f"Test set size: {len(test_indices)}")
print(f"Test label distribution: {np.bincount(y_test)}")

In [None]:
# CCS on (No, Yes) contrast pairs from the training split; verbose shows progress bars
ccs = CCS(
    neg_hs_train, pos_hs_train,
    nepochs=10, ntries=10, lr=0.0001, batch_size=16,
    device="cuda", var_normalize=True, verbose=True,
)
best_loss = ccs.repeated_train()

# Loss curve (train_losses is appended to across all restarts, so trials follow each other)
import matplotlib.pyplot as plt
plt.plot(ccs.train_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('CCS Training Loss')
plt.show()

# Persist the best probe's weights together with its training loss
torch.save(
    {'probe_state_dict': ccs.best_probe.state_dict(), 'training_loss': best_loss},
    'ccs_model.pt',
)

In [None]:
# Evaluate on test set (CCS is unsupervised; the labels only score its predictions)
test_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
print(f"CCS Test Accuracy: {test_acc:.4f}")

# Supervised reference point: logistic regression on the (No - Yes) difference vectors
from sklearn.linear_model import LogisticRegression

y_train, y_test = ground_truth[train_indices], ground_truth[test_indices]
x_train, x_test = neg_hs_train - pos_hs_train, neg_hs_test - pos_hs_test

lr = LogisticRegression(class_weight="balanced").fit(x_train, y_train)
lr_acc = lr.score(x_test, y_test)
print(f"Logistic Regression Test Accuracy: {lr_acc:.4f}")

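Train and evaluate CCS in the cross-domain transfer setting (train on IMDB, test on POPE) and the adversarial setting (train on clean POPE, test on PGD-perturbed POPE)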

In [None]:
def load_and_prepare_imdb_data(filename='imdb_contrast_hidden_states_final.pt', dataset=None, layer_idx=-1):
    """Load IMDB contrast-pair embeddings for one layer together with the real sentiment labels.

    Args:
        filename: file written by the IMDB extraction cell, holding 'positive_hidden_states' /
            'negative_hidden_states' with one entry per review, indexable by layer.
        dataset: the reviews the embeddings were extracted from, in the same order (rows with
            a 'label' field). Defaults to ``load_imdb_dataset()``, the subset used for extraction.
        layer_idx: hidden-state layer to keep (default: last).

    Returns:
        (pos_hidden, neg_hidden, labels): (n, hidden_dim) arrays for the forced "Positive" /
        "Negative" answers and an int array of the dataset's labels.

    Raises:
        ValueError: if the embedding and label counts disagree.
    """
    # Self-produced file holding numpy arrays; torch>=2.6 defaults to weights_only=True, which
    # refuses them. Only load trusted files this way.
    checkpoint = torch.load(filename, weights_only=False)
    # CCS needs 2-D arrays, not the raw per-example lists of per-layer vectors
    pos_hidden = np.array([states[layer_idx] for states in checkpoint['positive_hidden_states']])
    neg_hidden = np.array([states[layer_idx] for states in checkpoint['negative_hidden_states']])

    if dataset is None:
        dataset = load_imdb_dataset()
    # Use the real labels. The subset is shuffled, so "first half positive, second half
    # negative" labels were unrelated to the reviews.
    if hasattr(dataset, 'column_names'):
        raw_labels = dataset['label']
    else:
        raw_labels = [example['label'] for example in dataset]
    labels = np.asarray(list(raw_labels), dtype=int)

    if not (len(pos_hidden) == len(neg_hidden) == len(labels)):
        raise ValueError(
            f"Mismatched sizes: {len(pos_hidden)} positive, {len(neg_hidden)} negative, "
            f"{len(labels)} labels"
        )

    return pos_hidden, neg_hidden, labels

def load_and_prepare_pope_data(filename='pope_contrast_hidden_states_final.pt', dataset=None, layer_idx=-1):
    """Load POPE contrast-pair embeddings for one layer together with yes/no labels.

    Args:
        filename: file holding 'positive_hidden_states' / 'negative_hidden_states' with one
            entry per example, indexable by layer (clean or PGD-perturbed extraction).
        dataset: POPE rows with an 'answer' field, in extraction order. Defaults to the
            "lmms-lab/POPE" test split.
        layer_idx: hidden-state layer to keep (default: last).

    Returns:
        (pos_hidden, neg_hidden, labels): (n, hidden_dim) arrays for the forced "Yes" / "No"
        answers and an int array that is 1 where the answer is "yes" (case-insensitive).

    Raises:
        ValueError: if the embedding and label counts disagree.
    """
    if dataset is None:
        dataset = load_dataset("lmms-lab/POPE", split="test")
    # Self-produced file holding numpy arrays; torch>=2.6 defaults to weights_only=True, which
    # refuses them. Only load trusted files this way.
    checkpoint = torch.load(filename, weights_only=False)

    # CCS needs 2-D arrays, not the raw per-example lists of per-layer vectors
    pos_hidden = np.array([states[layer_idx] for states in checkpoint['positive_hidden_states']])
    neg_hidden = np.array([states[layer_idx] for states in checkpoint['negative_hidden_states']])

    # Create labels from dataset (column read avoids decoding every image)
    if hasattr(dataset, 'column_names'):
        answers = dataset['answer']
    else:
        answers = [example['answer'] for example in dataset]
    labels = np.array([1 if str(answer).strip().lower() == 'yes' else 0 for answer in answers], dtype=int)

    if not (len(pos_hidden) == len(neg_hidden) == len(labels)):
        raise ValueError(
            f"Mismatched sizes: {len(pos_hidden)} positive, {len(neg_hidden)} negative, "
            f"{len(labels)} labels"
        )

    return pos_hidden, neg_hidden, labels

def cross_domain_experiment():
    """Run cross-domain experiments between IMDB and POPE.

    Trains CCS on one source and scores it on a different target:
      * 'IMDB->POPE': probe fit on IMDB sentiment pairs, evaluated on clean POPE.
      * 'POPE->ADV_POPE': probe fit on clean POPE, evaluated on PGD-perturbed POPE.

    Returns:
        dict mapping setting name to accuracy; the same dict is saved to
        'cross_domain_results.pt'.
    """
    imdb_pos, imdb_neg, imdb_labels = load_and_prepare_imdb_data()
    pope_pos, pope_neg, pope_labels = load_and_prepare_pope_data()
    pope_adv_pos, pope_adv_neg, pope_adv_labels = load_and_prepare_pope_data(
        'pope_contrast_hidden_states_pgd_final.pt'
    )

    results = {}

    # Transfer: sentiment-trained probe applied to visual yes/no questions
    print("\nTraining on IMDB, testing on POPE...")
    imdb_probe = CCS(imdb_pos, imdb_neg, verbose=True)
    imdb_probe.repeated_train()
    results['IMDB->POPE'] = imdb_probe.get_acc(pope_pos, pope_neg, pope_labels)
    print(f"IMDB->POPE Accuracy: {results['IMDB->POPE']:.4f}")

    # Robustness: clean-POPE probe applied to adversarially perturbed images
    print("\nTraining on POPE, testing on Adversarial POPE...")
    pope_probe = CCS(pope_pos, pope_neg, verbose=True)
    pope_probe.repeated_train()
    results['POPE->ADV_POPE'] = pope_probe.get_acc(pope_adv_pos, pope_adv_neg, pope_adv_labels)
    print(f"POPE->Adversarial POPE Accuracy: {results['POPE->ADV_POPE']:.4f}")

    torch.save(results, 'cross_domain_results.pt')
    return results

In [None]:
# Import here as well, so this cell does not rely on `plt` leaking from the CCS training cell
import matplotlib.pyplot as plt

# Run cross-domain experiments
print("\nRunning cross-domain experiments...")
cross_domain_results = cross_domain_experiment()

# One bar per (train -> test) setting
plt.figure(figsize=(10, 6))
plt.bar(list(cross_domain_results.keys()), list(cross_domain_results.values()))
plt.title('Cross-Domain Transfer Results')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

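Utility: free GPU memory after an error (e.g. CUDA out-of-memory). Delete the Python variables that hold GPU tensors first; the cache can only release memory that no live tensor references.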

import gc
import torch
import subprocess
import numpy as np

# Free cached GPU memory after an error such as an OOM.
# empty_cache() can only return blocks that no live tensor references, so first `del` the
# variables holding GPU tensors (model, inputs, outputs, ...), or run `%reset -f` in its own
# cell to clear the whole namespace. A loop doing `del obj` over gc.get_objects() only unbinds
# the loop variable and frees nothing, so it is not used here.

# Collect unreachable objects (e.g. tensors kept alive by an exception traceback) first, so
# the cache release below can actually return their memory.
gc.collect()

if torch.cuda.is_available():
    torch.cuda.synchronize()  # let pending kernels finish before releasing memory
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Clear TensorFlow state too, if TensorFlow is installed (best effort)
try:
    import tensorflow as tf
    tf.keras.backend.clear_session()
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)
except ImportError:
    pass  # TensorFlow not installed
except RuntimeError as err:
    # set_memory_growth must be called before TensorFlow initialises its GPUs
    print(f"TensorFlow GPU reset skipped: {err}")

# `nvidia-smi --gpu-reset` is intentionally not attempted: it needs root and an idle GPU, and
# this kernel still holds a CUDA context. Restart the kernel for a full reset.

# Print current GPU memory usage (captured, so it shows in the notebook output)
print("\nCurrent GPU Memory Status:")
try:
    smi = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=False)
    print(smi.stdout or smi.stderr)
except FileNotFoundError:
    print("nvidia-smi not found")