# What's Up VLMs Dataset Browser

This notebook helps you browse and understand the "What's Up" dataset (Kamath et al.), which is made up of several subsets:
- Visual Genome Q&A (one/two objects)
- COCO Q&A (one/two objects) 
- Controlled images
- Controlled CLEVR

In [None]:
import json
import os
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
import random
from collections import Counter
import ipywidgets as widgets
from IPython.display import display, clear_output

## Configuration

In [None]:
# Root directory of the dataset files on the cluster filesystem
BASE_PATH = Path("/leonardo_work/EUHPC_D27_102/compmech/whatsup_vlms_data/")

# Registry of browsable subsets. Every entry carries:
#   'json'        - annotation file, located under BASE_PATH
#   'images'      - image directory, located under BASE_PATH
#   'description' - human-readable label used in printed reports
# The rows below are (key, annotation file, image directory, description);
# their order is the iteration order of DATASETS.
DATASETS = {
    key: {
        'json': BASE_PATH / json_name,
        'images': BASE_PATH / image_dir,
        'description': description,
    }
    for key, json_name, image_dir, description in (
        ('vg_two_obj', 'vg_qa_two_obj.json', 'vg_images',
         'Visual Genome - Two objects (left/right)'),
        ('vg_one_obj', 'vg_qa_one_obj.json', 'vg_images',
         'Visual Genome - One object (left/right)'),
        ('controlled_images', 'controlled_images_dataset.json', 'controlled_images',
         'Controlled images - Two objects (left/right/up/down)'),
        ('coco_two_obj', 'coco_qa_two_obj.json', 'val2017',
         'COCO - Two objects (up/down/left/right)'),
        ('coco_one_obj', 'coco_qa_one_obj.json', 'val2017',
         'COCO - One object (up/down/left/right)'),
        ('controlled_clevr', 'controlled_clevr_dataset.json', 'controlled_clevr',
         'Controlled CLEVR (front/behind/left/right)'),
    )
}

## Load Dataset

In [None]:
def load_dataset(dataset_key):
    """Load the annotation JSON of one registered dataset.

    Args:
        dataset_key: Key into the module-level ``DATASETS`` registry.

    Returns:
        A ``(data, config)`` tuple: the parsed JSON content and the
        ``DATASETS`` entry it was read from.

    Raises:
        KeyError: If ``dataset_key`` is not a key of ``DATASETS``.
        OSError: If the annotation file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    config = DATASETS[dataset_key]
    # Pin the encoding: open() otherwise falls back to the locale's
    # preferred encoding, which can fail on non-ASCII text in the JSON
    # (e.g. under a C/POSIX locale).
    with open(config['json'], 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data, config

def get_image_path(config, image_filename):
    """Return ``image_filename`` joined onto the dataset's image directory.

    ``config`` is a dataset entry whose 'images' value is the directory.
    """
    image_dir = config['images']
    return image_dir / image_filename

## Dataset Statistics

In [None]:
def analyze_dataset(data, dataset_key):
    """Print summary statistics for a loaded dataset.

    Prints the dataset description, the sample count, the keys and values
    of the first sample (long strings truncated to 100 characters) and,
    when the first sample has them, question examples and the answer
    distribution.

    Args:
        data: Loaded samples; indexed with ``data[0]`` and iterated, each
            sample being accessed as a dict.
        dataset_key: Key into ``DATASETS``; used only for the description.

    Returns:
        ``data``, unchanged.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Dataset: {DATASETS[dataset_key]['description']}")
    print(banner)

    print(f"\nTotal samples: {len(data)}")

    # Everything below inspects data[0]; on an empty dataset that would
    # raise IndexError, so stop after reporting the zero count.
    if not data:
        return data

    first = data[0]
    print(f"\nSample keys: {list(first.keys())}")
    print("\nFirst example:")
    for key, value in first.items():
        if isinstance(value, str) and len(value) > 100:
            print(f"  {key}: {value[:100]}...")
        else:
            print(f"  {key}: {value}")

    # Whether a field is analysed is decided by the first sample only;
    # later samples lacking the field contribute an empty string.
    if 'question' in first:
        # dict.fromkeys de-duplicates while keeping first-seen order, so
        # the printed examples are the same on every run (a set's
        # iteration order is not).
        unique_questions = list(
            dict.fromkeys(item.get('question', '') for item in data)
        )
        print(f"\nUnique questions: {len(unique_questions)}")
        print("Question examples:")
        for question in unique_questions[:5]:
            print(f"  - {question}")

    if 'answer' in first:
        answer_dist = Counter(item.get('answer', '') for item in data)
        print("\nAnswer distribution:")
        total = len(data)
        for ans, count in answer_dist.most_common():
            print(f"  {ans}: {count} ({count/total*100:.1f}%)")

    return data

In [None]:
# Print the statistics report for every registered dataset. A failure in
# one dataset is reported and the loop moves on to the next one.
for dataset_key in DATASETS:
    try:
        data, config = load_dataset(dataset_key)
        analyze_dataset(data, dataset_key)
    except Exception as err:
        print(f"\nError loading {dataset_key}: {err}")

## Browse Examples

In [None]:
def display_example(data, config, index):
    """Show one example: its image (when available) and its metadata.

    The image file name is the example's 'image' value, or
    ``<image_id>.jpg`` when only 'image_id' is present. If neither field
    exists a message is printed and nothing is drawn. If the file is
    missing, a message is printed and the figure is shown without an image.

    Args:
        data: Indexable collection of example dicts.
        config: Dataset entry whose 'images' value is the image directory.
        index: Position of the example within ``data``.
    """
    example = data[index]

    # Resolve the path before creating a figure, so the early return below
    # does not leave an unused figure open.
    if 'image' in example:
        img_path = get_image_path(config, example['image'])
    elif 'image_id' in example:
        # Handle different image naming conventions
        img_path = get_image_path(config, f"{example['image_id']}.jpg")
    else:
        print("No image field found")
        return

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))

    if img_path.exists():
        # The context manager releases the file handle once the pixels
        # have been handed to matplotlib.
        with Image.open(img_path) as img:
            ax.imshow(img)
    else:
        print(f"Image not found: {img_path}")
    # Hide ticks and frame in both cases; they carry no information here.
    ax.axis('off')

    # Question/answer, when present, become the figure title
    title_parts = []
    if 'question' in example:
        title_parts.append(f"Q: {example['question']}")
    if 'answer' in example:
        title_parts.append(f"A: {example['answer']}")
    ax.set_title('\n'.join(title_parts), fontsize=12, pad=20)

    # Print the remaining metadata as text
    print(f"\nExample {index + 1}/{len(data)}")
    print("-" * 60)
    for key, value in example.items():
        if key not in ('image', 'image_id'):  # Already shown
            print(f"{key}: {value}")

    fig.tight_layout()
    plt.show()

In [None]:
# Interactive browser
def create_browser(dataset_key):
    """Create and display an interactive browser for one dataset.

    The browser is an index slider, a "Random Example" button and an output
    area showing the selected example via ``display_example``. If the
    dataset has no samples, a message is printed and no widgets are built.

    Args:
        dataset_key: Key into ``DATASETS`` naming the dataset to browse.
    """
    data, config = load_dataset(dataset_key)

    # An empty dataset would give the slider max=-1 (< min) and make both
    # the initial display and the random button fail.
    if not data:
        print(f"Dataset '{dataset_key}' is empty; nothing to browse")
        return

    last_index = len(data) - 1

    # Widgets
    index_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=last_index,
        step=1,
        description='Index:',
        # Redraw only when the slider is released, not on every drag step
        continuous_update=False
    )
    random_button = widgets.Button(description="Random Example")
    output = widgets.Output()

    def render(index):
        """Replace the output area's content with example ``index``."""
        with output:
            clear_output(wait=True)
            display_example(data, config, index)

    def on_index_change(change):
        render(change['new'])

    def on_random_click(b):
        # Setting the slider value triggers on_index_change, which redraws
        index_slider.value = random.randint(0, last_index)

    index_slider.observe(on_index_change, names='value')
    random_button.on_click(on_random_click)

    # Initial display
    render(0)

    display(widgets.VBox([index_slider, random_button, output]))

# List every registered dataset key with its description, then remind the
# user how to open a browser for one of them
print("Available datasets:")
for key, config in DATASETS.items():
    print(f"  - {key}: {config['description']}")

print("\nUse: create_browser('dataset_key') to browse")

In [None]:
# Open the interactive browser on the Visual Genome two-object subset
create_browser('vg_two_obj')

In [None]:
# Open the interactive browser on the COCO two-object subset
create_browser('coco_two_obj')

In [None]:
# Open the interactive browser on the controlled-images subset
create_browser('controlled_images')

In [None]:
# Open the interactive browser on the controlled CLEVR subset
create_browser('controlled_clevr')

## Compare Across Datasets

In [None]:
def compare_datasets():
    """Plot the answer distribution of every registered dataset.

    Draws one bar-chart panel per ``DATASETS`` entry, at most three panels
    per row. A dataset that fails to load, is empty, or has no 'answer'
    field in its first sample gets a text note in its panel instead of a
    chart, so one bad dataset does not prevent the others from plotting.
    """
    dataset_keys = list(DATASETS)
    if not dataset_keys:
        print("No datasets configured")
        return

    # Size the grid from the registry instead of assuming six entries;
    # six datasets still give the 2x3, 15x10-inch layout.
    n_cols = min(3, len(dataset_keys))
    n_rows = -(-len(dataset_keys) // n_cols)  # ceiling division
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(5 * n_cols, 5 * n_rows),
        squeeze=False,  # always a 2-D array, even for a single panel
    )
    axes = axes.flatten()

    for ax, dataset_key in zip(axes, dataset_keys):
        # Title first, so a failing panel is still identifiable
        ax.set_title(dataset_key, fontsize=10)
        try:
            data, _ = load_dataset(dataset_key)

            if not data or 'answer' not in data[0]:
                ax.text(0.5, 0.5, "No answer data",
                        ha='center', va='center', transform=ax.transAxes)
                continue

            # Count answers
            answer_dist = Counter(item.get('answer', '') for item in data)

            # str() keeps the x axis categorical even if answers are not
            # strings (e.g. numbers or booleans in the JSON)
            labels = [str(label) for label in answer_dist]
            values = list(answer_dist.values())

            ax.bar(labels, values)
            ax.set_title(f"{dataset_key}\n({len(data)} samples)", fontsize=10)
            ax.set_ylabel('Count')
            ax.tick_params(axis='x', rotation=45)
        except Exception as e:
            # Deliberate best-effort: report the problem inside the panel
            ax.text(0.5, 0.5, f"Error: {e}",
                    ha='center', va='center', transform=ax.transAxes,
                    wrap=True)

    # Blank out grid cells that have no dataset assigned
    for ax in axes[len(dataset_keys):]:
        ax.axis('off')

    fig.tight_layout()
    plt.show()

# Draw the answer-distribution comparison for all registered datasets
compare_datasets()

## Filter and Search

In [None]:
def search_examples(dataset_key, answer=None, question_contains=None):
    """Load a dataset and return the examples matching the given filters.

    Args:
        dataset_key: Key into ``DATASETS`` naming the dataset to search.
        answer: If not ``None``, keep only examples whose 'answer' equals
            this value.
        question_contains: If non-empty, keep only examples whose
            'question' contains this text, ignoring case.

    Returns:
        A ``(filtered, config)`` tuple: the matching examples (the loaded
        data itself when no filter is active) and the dataset entry, ready
        to be passed to ``display_example``.
    """
    data, config = load_dataset(dataset_key)

    filtered = data

    # Compare against None rather than truthiness, so that falsy answers
    # such as 0, False or '' can be searched for as well.
    if answer is not None:
        filtered = [item for item in filtered if item.get('answer') == answer]

    if question_contains:
        needle = question_contains.lower()  # lower-case once, not per item
        filtered = [
            item for item in filtered
            # 'or' covers a question stored as None; str() covers non-text
            if needle in str(item.get('question') or '').lower()
        ]

    print(f"Found {len(filtered)} examples")

    return filtered, config

# Example: Find all "left" answers in VG dataset
# filtered_data, config = search_examples('vg_two_obj', answer='left')
# if filtered_data:
#     display_example(filtered_data, config, 0)