# What's Up VLMs Dataset Browser

This notebook helps you browse and understand the "What's Up" dataset (Kamath et al.), which consists of multiple subsets:
- Visual Genome Q&A (one/two objects)
- COCO Q&A (one/two objects) 
- Controlled images
- Controlled CLEVR

## Dataset Format

**Important:** Each dataset item follows this structure:
```
[image_id, correct_caption, incorrect_caption, ...]
```

- **Index 0**: Image ID (or filename)
- **Index 1**: CORRECT caption (always the first of the two captions)
- **Index 2**: INCORRECT caption (always the second of the two captions)
- Additional fields may exist at indices 3+

The datasets test VLM understanding of spatial relations by providing pairs of captions where only one correctly describes the spatial relationship in the image.

In [None]:
import json
import os
import random
import re
from collections import Counter
from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from PIL import Image

## Configuration

In [None]:
# Root directory of the dataset files on the cluster
BASE_PATH = Path("/leonardo_work/EUHPC_D27_102/compmech/whatsup_vlms_data/")

# One row per subset: (key, annotation file, image directory, description).
_DATASET_SPECS = (
    ('vg_two_obj', 'vg_qa_two_obj.json', 'vg_images',
     'Visual Genome - Two objects (left/right)'),
    ('vg_one_obj', 'vg_qa_one_obj.json', 'vg_images',
     'Visual Genome - One object (left/right)'),
    ('controlled_images', 'controlled_images_dataset.json', 'controlled_images',
     'Controlled images - Two objects (left/right/up/down)'),
    ('coco_two_obj', 'coco_qa_two_obj.json', 'val2017',
     'COCO - Two objects (up/down/left/right)'),
    ('coco_one_obj', 'coco_qa_one_obj.json', 'val2017',
     'COCO - One object (up/down/left/right)'),
    ('controlled_clevr', 'controlled_clevr_dataset.json', 'controlled_clevr',
     'Controlled CLEVR (front/behind/left/right)'),
)

# Dataset configurations, keyed by subset name. Every entry holds the
# annotation file ('json'), the image directory ('images') and a
# human-readable 'description'.
DATASETS = {
    key: {
        'json': BASE_PATH / json_name,
        'images': BASE_PATH / image_dir,
        'description': description,
    }
    for key, json_name, image_dir, description in _DATASET_SPECS
}

## Load Dataset

In [None]:
def load_dataset(dataset_key):
    """Load the annotation file of one dataset.

    Expected item format: [image_id, correct_caption, incorrect_caption(s), ...]
    - First caption option (index 1) is always CORRECT
    - Second caption option (index 2) is always INCORRECT

    Args:
        dataset_key: Key into ``DATASETS`` selecting the subset to load.

    Returns:
        A ``(data, config)`` tuple: the parsed JSON content and the
        ``DATASETS`` entry it was loaded from.

    Raises:
        KeyError: If ``dataset_key`` is not a key of ``DATASETS``.
        OSError: If the annotation file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    config = DATASETS[dataset_key]
    # JSON text is UTF-8; state it explicitly instead of relying on the
    # locale-dependent default encoding of open().
    with open(config['json'], 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data, config

def get_image_path(config, image_id, dataset_key):
    """Build the full path of an image from its ID.

    Args:
        config: Dataset entry; ``config['images']`` is the image directory
            (a ``Path``).
        image_id: Image ID or filename (int or str). A trailing ``.jpg``,
            ``.jpeg`` or ``.png`` (any letter case) is kept as is; otherwise
            ``.jpg`` is appended.
        dataset_key: Dataset name. When it contains ``'coco'`` the ID part is
            zero-padded to 12 characters.

    Returns:
        ``config['images'] / filename``.
    """
    name = str(image_id)

    # Keep an extension that is already present instead of blindly appending
    # '.jpg' (which used to turn 'x.jpeg' into 'x.jpeg.jpg').
    if name.lower().endswith(('.jpg', '.jpeg', '.png')):
        stem, dot, extension = name.rpartition('.')
        suffix = dot + extension
    else:
        stem, suffix = name, '.jpg'

    if 'coco' in dataset_key:
        # COCO uses zero-padded 12-digit IDs; pad the ID only, not the extension
        stem = stem.zfill(12)

    return config['images'] / f"{stem}{suffix}"

## Dataset Statistics

In [None]:
def analyze_dataset(data, dataset_key):
    """Print summary statistics for one loaded dataset.

    Prints the sample count, the first example, the distribution of spatial
    relations detected in the correct captions, and the first five correct
    captions.

    Args:
        data: Loaded dataset; each item is expected to be
            [image_id, correct_caption, incorrect_caption, ...].
        dataset_key: Key into ``DATASETS``, used for the printed description.

    Returns:
        ``data``, unchanged.
    """
    print(f"\n{'='*60}")
    print(f"Dataset: {DATASETS[dataset_key]['description']}")
    print(f"{'='*60}")

    # Total samples
    print(f"\nTotal samples: {len(data)}")

    # Data format: [image_id, correct_caption, incorrect_caption, ...]
    if data:
        print(f"\nData format: [image_id, correct_caption, incorrect_caption, ...]")
        print(f"Number of fields per item: {len(data[0])}")
        print(f"\nFirst example:")
        example = data[0]
        print(f"  Image ID: {example[0]}")
        print(f"  Correct caption: {example[1]}")
        print(f"  Incorrect caption: {example[2]}")
        if len(example) > 3:
            print(f"  Additional fields: {example[3:]}")

    # Extract correct captions to analyze prepositions
    correct_captions = [item[1] for item in data]

    # Spatial relations (prepositions), in priority order: the first word of
    # this list that occurs in a caption is the relation counted for it.
    spatial_words = ['left', 'right', 'above', 'below', 'up', 'down', 'front', 'behind']
    # Match whole words only. A plain substring test finds 'up' inside "cup",
    # which mislabels captions such as "A cup in front of a plate".
    relation_patterns = [
        (word, re.compile(rf'\b{re.escape(word)}\b'))
        for word in spatial_words
    ]
    caption_relations = []

    for caption in correct_captions:
        caption_lower = caption.lower()
        for word, pattern in relation_patterns:
            if pattern.search(caption_lower):
                caption_relations.append(word)
                break

    if caption_relations:
        relation_dist = Counter(caption_relations)
        print(f"\nSpatial relation distribution:")
        for rel, count in relation_dist.most_common():
            # Percentages are relative to all samples, including captions in
            # which no relation was detected.
            print(f"  {rel}: {count} ({count/len(data)*100:.1f}%)")

    # Show some example captions
    print(f"\nExample correct captions:")
    for caption in correct_captions[:5]:
        print(f"  - {caption}")

    return data

In [None]:
# Print statistics for every configured dataset. A failure in one dataset is
# reported and does not stop the remaining ones.
for dataset_key in DATASETS:
    try:
        data, config = load_dataset(dataset_key)
        analyze_dataset(data, dataset_key)
    except Exception as exc:
        print(f"\nError loading {dataset_key}: {exc}")

## Browse Examples

In [None]:
def display_example(data, config, index, dataset_key):
    """Display a single example: its image, both captions and its metadata.

    Expected item format: [image_id, correct_caption, incorrect_caption, ...]

    Args:
        data: Loaded dataset (indexable sequence of items).
        config: Dataset entry providing the 'images' directory.
        index: Position of the example inside ``data``.
        dataset_key: Dataset name, forwarded to ``get_image_path`` to select
            the filename convention.
    """
    example = data[index]

    image_id = example[0]
    correct_caption = example[1]
    incorrect_caption = example[2]

    # Create figure with image and text
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))

    # Load and display image
    img_path = get_image_path(config, image_id, dataset_key)

    if img_path.exists():
        # Open the image in a context manager so its file handle is released
        # once matplotlib has been given the picture; browsing many examples
        # otherwise keeps one open file per displayed image.
        with Image.open(img_path) as img:
            ax.imshow(img)
    else:
        ax.text(0.5, 0.5, f"Image not found:\n{img_path}",
                ha='center', va='center', fontsize=10)
    ax.axis('off')

    # Show both captions in the title, marked with a check / cross symbol
    title = f"✓ CORRECT: {correct_caption}\n✗ INCORRECT: {incorrect_caption}"
    ax.set_title(title, fontsize=12, pad=20, loc='left')

    # Print all metadata
    print(f"\nExample {index + 1}/{len(data)}")
    print("-" * 70)
    print(f"Image ID: {image_id}")
    print(f"Image path: {img_path}")
    print(f"\n✓ CORRECT caption:   {correct_caption}")
    print(f"✗ INCORRECT caption: {incorrect_caption}")

    if len(example) > 3:
        print(f"\nAdditional fields: {example[3:]}")

    plt.tight_layout()
    plt.show()

In [None]:
# Interactive browser
def create_browser(dataset_key):
    """Create and display an interactive browser for one dataset.

    The browser consists of an index slider, a dropdown that filters the
    examples by the spatial relation found in the correct caption, and a
    button that jumps to a random example of the current selection.

    Args:
        dataset_key: Key into ``DATASETS`` selecting the dataset to browse.
    """
    data, config = load_dataset(dataset_key)

    # With no examples the slider would get max=-1 (below min=0) and the
    # initial display would index an empty list.
    if not data:
        print(f"Dataset '{dataset_key}' is empty; nothing to browse.")
        return

    # Widgets
    index_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(data)-1,
        step=1,
        description='Index:',
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    random_button = widgets.Button(description="Random Example")

    # Filter by spatial relation. Whole-word patterns are used because a
    # plain substring test finds 'up' inside "cup".
    spatial_words = ['left', 'right', 'above', 'below', 'up', 'down', 'front', 'behind']
    relation_patterns = {
        word: re.compile(rf'\b{re.escape(word)}\b') for word in spatial_words
    }
    lowered_captions = [item[1].lower() for item in data]
    relations_found = {
        word for word, pattern in relation_patterns.items()
        if any(pattern.search(caption) for caption in lowered_captions)
    }

    filter_dropdown = widgets.Dropdown(
        options=['all'] + sorted(relations_found),
        value='all',
        description='Filter by:',
        style={'description_width': 'initial'}
    )

    output = widgets.Output()

    # Indices into `data` of the examples that pass the current filter; the
    # slider value is a position inside this list.
    filtered_indices = list(range(len(data)))

    # True while the slider is being reset programmatically, so the reset
    # does not trigger a second rendering of the same filter change.
    resetting_slider = False

    def render(position):
        """Show the example at `position` of the filtered selection."""
        with output:
            clear_output(wait=True)
            if filtered_indices:
                display_example(data, config, filtered_indices[position], dataset_key)

    def update_filtered_indices():
        nonlocal filtered_indices, resetting_slider
        if filter_dropdown.value == 'all':
            filtered_indices = list(range(len(data)))
        else:
            pattern = relation_patterns[filter_dropdown.value]
            filtered_indices = [i for i, caption in enumerate(lowered_captions)
                                if pattern.search(caption)]
        resetting_slider = True
        try:
            # Reset the value before shrinking the range so it never lies
            # outside the new bounds.
            index_slider.value = 0
            index_slider.max = max(0, len(filtered_indices) - 1)
        finally:
            resetting_slider = False

    def on_index_change(change):
        if resetting_slider:
            return
        render(change['new'])

    def on_random_click(b):
        if filtered_indices:
            index_slider.value = random.randint(0, len(filtered_indices)-1)

    def on_filter_change(change):
        update_filtered_indices()
        render(0)

    index_slider.observe(on_index_change, names='value')
    random_button.on_click(on_random_click)
    filter_dropdown.observe(on_filter_change, names='value')

    # Initial display
    render(0)

    controls = widgets.HBox([index_slider, filter_dropdown, random_button])
    display(widgets.VBox([controls, output]))

# List every dataset key with its description, then show how to start a browser
print("Available datasets:")
for key, config in DATASETS.items():
    print("  - {}: {}".format(key, config['description']))

print("\nUse: create_browser('dataset_key') to browse")

In [None]:
# Browse Visual Genome two objects: opens the interactive browser
# (index slider, relation filter, random button) for the 'vg_two_obj' subset.
create_browser('vg_two_obj')

In [None]:
# Browse COCO two objects: opens the interactive browser
# (index slider, relation filter, random button) for the 'coco_two_obj' subset.
create_browser('coco_two_obj')

In [None]:
# Browse Controlled images: opens the interactive browser
# (index slider, relation filter, random button) for the 'controlled_images' subset.
create_browser('controlled_images')

In [None]:
# Browse Controlled CLEVR: opens the interactive browser
# (index slider, relation filter, random button) for the 'controlled_clevr' subset.
create_browser('controlled_clevr')

## Compare Across Datasets

In [None]:
def compare_datasets():
    """Plot the spatial relation distribution of every dataset side by side.

    Draws one bar chart per entry of ``DATASETS`` in a grid with three
    columns. A dataset that cannot be loaded or processed gets a panel
    showing the (truncated) error message instead of a chart.
    """
    # Size the grid from the number of datasets so adding a seventh entry
    # does not index past the last panel (six datasets still give 2 x 3).
    n_cols = 3
    n_rows = max(1, (len(DATASETS) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    # Spatial relations in priority order: the first word of this list that
    # occurs in a caption is the relation counted for it. Whole-word patterns
    # are used because a plain substring test finds 'up' inside "cup".
    spatial_words = ['left', 'right', 'above', 'below', 'up', 'down', 'front', 'behind']
    relation_patterns = [
        (word, re.compile(rf'\b{re.escape(word)}\b'))
        for word in spatial_words
    ]

    for idx, dataset_key in enumerate(DATASETS):
        try:
            data, _ = load_dataset(dataset_key)

            # Extract spatial relations from correct captions
            caption_relations = []
            for item in data:
                caption_lower = item[1].lower()
                for word, pattern in relation_patterns:
                    if pattern.search(caption_lower):
                        caption_relations.append(word)
                        break

            axes[idx].set_title(f"{dataset_key}\n({len(data)} samples)", fontsize=11, fontweight='bold')
            if caption_relations:
                relation_dist = Counter(caption_relations)
                labels = list(relation_dist.keys())
                values = list(relation_dist.values())

                axes[idx].bar(labels, values, color='steelblue')
                axes[idx].set_ylabel('Count', fontsize=10)
                axes[idx].tick_params(axis='x', rotation=45, labelsize=9)
                axes[idx].grid(axis='y', alpha=0.3)
            else:
                # Say so explicitly instead of leaving an unexplained empty panel
                axes[idx].text(0.5, 0.5, "No spatial relations found",
                               ha='center', va='center', transform=axes[idx].transAxes,
                               fontsize=9)
        except Exception as e:
            axes[idx].text(0.5, 0.5, f"Error: {str(e)[:50]}",
                          ha='center', va='center', transform=axes[idx].transAxes,
                          fontsize=9, wrap=True)
            axes[idx].set_title(f"{dataset_key}\n(Error)", fontsize=11)

    # Hide grid cells that have no dataset assigned
    for unused_ax in axes[len(DATASETS):]:
        unused_ax.axis('off')

    plt.suptitle('Spatial Relation Distribution Across Datasets', fontsize=14, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.show()

# Draw the comparison figure for all configured datasets when this cell runs.
compare_datasets()

## Filter and Search

In [None]:
def search_examples(dataset_key, spatial_relation=None, caption_contains=None):
    """Search a dataset for examples whose correct caption matches.

    Prints the number of matches and the first five of them.

    Args:
        dataset_key: Which dataset to search
        spatial_relation: Filter by spatial relation (e.g., 'left', 'right',
            'above'). Matched case-insensitively as a whole word, so 'up'
            does not match "cup".
        caption_contains: Filter by text in correct caption
            (case-insensitive substring match).

    Returns:
        A ``(filtered, config)`` tuple: the matching items and the dataset's
        ``DATASETS`` entry.
    """
    data, config = load_dataset(dataset_key)

    filtered = data

    if spatial_relation:
        # Whole-word match: a plain substring test would also accept
        # captions that merely contain the letters (e.g. 'up' in "cup").
        relation_pattern = re.compile(rf'\b{re.escape(spatial_relation.lower())}\b')
        filtered = [item for item in filtered
                    if relation_pattern.search(item[1].lower())]

    if caption_contains:
        needle = caption_contains.lower()
        filtered = [item for item in filtered
                    if needle in item[1].lower()]

    print(f"Found {len(filtered)} examples matching criteria")

    # Show first few examples
    if filtered:
        print("\nFirst 5 matches:")
        for i, item in enumerate(filtered[:5]):
            print(f"\n{i+1}. Image ID: {item[0]}")
            print(f"   Correct: {item[1]}")
            print(f"   Incorrect: {item[2]}")

    return filtered, config

# Example usage (uncomment to use):
# filtered_data, config = search_examples('coco_two_obj', spatial_relation='left')
# if filtered_data:
#     display_example(filtered_data, config, 0, 'coco_two_obj')