# Dataset Exploration Notebook

This notebook allows you to visualize images and prompts from the spatial reasoning datasets.

In [None]:
import matplotlib.pyplot as plt
import json
from dataset_zoo import get_dataset
import numpy as np
from PIL import Image

%matplotlib inline

## Configuration

In [None]:
# Dataset to explore.
# Options: Controlled_Images_A, Controlled_Images_B, COCO_QA_one_obj,
#          COCO_QA_two_obj, VG_QA_one_obj, VG_QA_two_obj
DATASET_NAME = "Controlled_Images_A"

# Root directory that holds the dataset files
DATA_DIR = "/leonardo_work/EUHPC_D27_102/compmech/whatsup_vlms_data"

# Option-count word associated with each dataset; it is later interpolated
# into the prompt file name.
dataset_options = {
    **dict.fromkeys(
        (
            "COCO_QA_one_obj",
            "COCO_QA_two_obj",
            "Controlled_Images_A",
            "Controlled_Images_B",
        ),
        "four",
    ),
    **dict.fromkeys(("VG_QA_one_obj", "VG_QA_two_obj"), "six"),
}

# Dataset names missing from the table fall back to "four".
if DATASET_NAME in dataset_options:
    option = dataset_options[DATASET_NAME]
else:
    option = "four"

## Load Dataset

In [None]:
# Load dataset
dataset = get_dataset(DATASET_NAME, image_preprocess=None, download=False, root_dir=DATA_DIR)

# Load prompts: a JSON Lines file, one object per line, each providing a
# 'question' and an 'answer' field.
prompt_file = f"./prompts/{DATASET_NAME}_with_answer_{option}_options.jsonl"
prompts = []
answers = []

# Explicit UTF-8 so decoding does not depend on the platform's locale default.
with open(prompt_file, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines instead of letting json.loads fail on them.
            continue
        data = json.loads(line)
        prompts.append(data['question'])
        answers.append(data['answer'])

print(f"Dataset: {DATASET_NAME}")
print(f"Total samples: {len(dataset)}")
print(f"Total prompts: {len(prompts)}")

# Prompts and answers are looked up by dataset index when plotting, so a
# length mismatch would show the wrong text or raise IndexError later.
if len(prompts) != len(dataset):
    print(
        f"Warning: number of prompts ({len(prompts)}) does not match "
        f"number of samples ({len(dataset)})"
    )

## Visualize Random Samples

In [None]:
def plot_samples(dataset, prompts, answers, indices=None, n=9):
    """
    Plot a grid of dataset samples, each titled with its prompt and answer.

    Args:
        dataset: Indexable dataset object. Each sample must provide
            ``image_options`` (as an attribute or as a mapping key); the
            first entry is the image that gets displayed.
        prompts: List of prompts, indexed with the same index as ``dataset``.
        answers: List of answers, indexed like ``prompts``. When an entry is
            a list, its first element is shown.
        indices: Specific sample indices to plot (optional). When omitted,
            distinct indices are drawn at random.
        n: Number of random samples to plot when ``indices`` is not given
            (default: 9); capped at ``len(dataset)``.

    Returns:
        None. The figure is shown with ``plt.show()``. If there is nothing to
        plot, a short message is printed and no figure is created.
    """
    max_prompt_chars = 100  # Titles longer than this are truncated

    if indices is None:
        indices = np.random.choice(len(dataset), size=min(n, len(dataset)), replace=False)

    n = len(indices)
    if n == 0:
        # A zero-sized grid would give rows == 0 and divide by zero below.
        print("No samples to plot.")
        return

    # Near-square grid large enough to hold all n images
    rows = int(np.ceil(np.sqrt(n)))
    cols = int(np.ceil(n / rows))

    # squeeze=False always returns a 2-D array of axes, so a 1x1 grid needs
    # no special casing.
    _, axes = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False)
    axes = axes.ravel()

    for ax, sample_idx in zip(axes, indices):
        sample = dataset[sample_idx]

        # Get image (handle different dataset formats)
        if hasattr(sample, 'image_options'):
            img = sample.image_options[0]
        else:
            img = sample['image_options'][0]

        # Anything that is not already a PIL image goes through Image.fromarray
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)

        ax.imshow(img)
        ax.axis('off')

        # Add prompt as title; append an ellipsis only when text was cut off
        prompt_text = prompts[sample_idx]
        if len(prompt_text) > max_prompt_chars:
            prompt_text = prompt_text[:max_prompt_chars] + "..."

        answer_text = answers[sample_idx]
        if isinstance(answer_text, list):
            answer_text = answer_text[0]

        ax.set_title(f"Sample {sample_idx}\nPrompt: {prompt_text}\nAnswer: {answer_text}",
                     fontsize=10, wrap=True)

    # Hide the grid cells left over when n does not fill rows * cols
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Draw nine samples at random and show them in a grid
plot_samples(dataset=dataset, prompts=prompts, answers=answers, n=9)

## Explore Specific Samples

In [None]:
# Plot specific indices (e.g., the first 9). The count is capped at the
# dataset size so that a dataset with fewer than 9 samples does not raise
# IndexError.
specific_indices = list(range(min(9, len(dataset))))
plot_samples(dataset, prompts, answers, indices=specific_indices)

## Analyze Answer Distribution

In [None]:
from collections import Counter

# Tally the answers; a list-valued answer is represented by its first element
answer_list = []
for ans in answers:
    answer_list.append(ans[0] if isinstance(ans, list) else ans)
answer_counts = Counter(answer_list)

# Text summary, most frequent answer first, with its share of all answers
total_answers = len(answers)
print("Answer Distribution:")
for answer, count in answer_counts.most_common():
    percentage = 100 * count / total_answers
    print(f"  {answer}: {count} ({percentage:.1f}%)")

# Bar chart of the same counts
plt.figure(figsize=(10, 6))
plt.bar(answer_counts.keys(), answer_counts.values())
plt.title(f'Answer Distribution in {DATASET_NAME}')
plt.xlabel('Answer')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## Interactive Sample Viewer

In [None]:
def view_sample(idx):
    """
    Show one sample at full size with its complete prompt and golden answer.

    Reads the module-level ``dataset``, ``prompts`` and ``answers``.

    Args:
        idx: Index of the sample to display.
    """
    sample = dataset[idx]

    # image_options may be an attribute or a mapping key, depending on the sample
    if hasattr(sample, 'image_options'):
        image_options = sample.image_options
    else:
        image_options = sample['image_options']
    img = image_options[0]

    # Anything that is not already a PIL image goes through Image.fromarray
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)

    # A list-valued answer is shown by its first element
    answer_text = answers[idx]
    if isinstance(answer_text, list):
        answer_text = answer_text[0]

    title = f"Sample {idx}\n\nPrompt: {prompts[idx]}\n\nGolden Answer: {answer_text}"

    plt.figure(figsize=(10, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.title(title, fontsize=12, pad=20)
    plt.tight_layout()
    plt.show()

# Show one sample in detail; edit the index to inspect a different sample
view_sample(idx=0)