# Spatial Relations Dataset Analysis

This notebook analyzes the spatial-relation datasets stored on disk in Hugging Face `datasets` format (loaded with `load_from_disk`).

Datasets:
- Visual Genome (one/two objects)
- COCO (one/two objects)
- Controlled images
- Controlled CLEVR

Each dataset has:
- `image_path`: Path to the image
- `captions`: List of captions
- `label`: Object label
- `spatial_relation`: Spatial relation (top/right/left/bottom/behind)

In [None]:
import random
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_from_disk
from PIL import Image

## Configuration

In [None]:
# Root directory that holds both the saved datasets and the image folders
BASE_PATH = Path("/leonardo_work/EUHPC_D27_102/compmech/whatsup_vlms_data")
# Saved datasets live in the "hf" sub-directory, one "<key>.hf" folder each
HF_PATH = BASE_PATH / "hf"

# (dataset key, human-readable description, image directory under BASE_PATH)
_DATASET_SPECS = [
    ('vg_one_obj', 'Visual Genome - One object', 'vg_images'),
    ('vg_two_obj', 'Visual Genome - Two objects', 'vg_images'),
    ('coco_one_obj', 'COCO - One object', 'val2017'),
    ('coco_two_obj', 'COCO - Two objects', 'val2017'),
    ('controlled_images', 'Controlled images', 'controlled_images'),
    ('controlled_clevr', 'Controlled CLEVR', 'controlled_clevr'),
]

# Per-dataset configuration keyed by dataset name:
#   'path'        -> saved dataset directory
#   'description' -> label used in printed output and plot titles
#   'image_base'  -> directory the dataset's images are resolved against
DATASETS = {
    key: {
        'path': HF_PATH / f'{key}.hf',
        'description': description,
        'image_base': BASE_PATH / image_dir,
    }
    for key, description, image_dir in _DATASET_SPECS
}

In [None]:
# Inspect the Controlled CLEVR captions: tally the second space-separated token
# of each sample's first caption.
# The path is taken from DATASETS so it cannot drift from the configuration, and
# the result gets its own name instead of the generic `dataset`.
clevr_dataset = load_from_disk(str(DATASETS['controlled_clevr']['path']))
Counter(captions[0].split(' ')[1] for captions in clevr_dataset['captions'])

## Load Datasets

In [None]:
# Fill `datasets` with every configured dataset that can be read from disk.
# A dataset that fails to load is reported and skipped so the others still load.
datasets = {}
for name, entry in DATASETS.items():
    try:
        datasets[name] = load_from_disk(str(entry['path']))
        n_loaded = len(datasets[name])
        print(f"Loaded {name}: {n_loaded} samples")
        print(f"  Features: {datasets[name].features}\n")
    except Exception as err:
        print(f"Error loading {name}: {err}\n")

## Spatial Relations Distribution

In [None]:
# Count how often each spatial relation occurs in every loaded dataset and
# print the counts with their share of the dataset.
spatial_stats = {}

for name in datasets:
    ds = datasets[name]
    relation_counts = Counter(ds['spatial_relation'])
    spatial_stats[name] = relation_counts
    n_total = len(ds)

    print(f"\n{DATASETS[name]['description']} ({name})")
    print(f"Total samples: {n_total}")
    print("Spatial relations:")
    # Alphabetical order keeps the listing comparable between datasets
    for relation in sorted(relation_counts):
        count = relation_counts[relation]
        print(f"  {relation:10s}: {count:5d} ({count/n_total*100:5.1f}%)")

## Visualization: Distribution Comparison

In [None]:
# Plot spatial relation distribution for all datasets, one panel per dataset

# Relations present in at least one dataset, in sorted order.
# Defined at notebook level because later cells reuse it.
all_relations = sorted({relation for stats in spatial_stats.values() for relation in stats})

# Fixed colour per relation so a relation looks the same in every chart;
# relations missing from this map fall back to grey ('#999999').
colors = {
    'top': '#FF6B6B',
    'bottom': '#4ECDC4',
    'left': '#45B7D1',
    'right': '#FFA07A',
    'behind': '#95E1D3'
}

# Size the grid from the number of datasets that actually loaded: a hard-coded
# 2x3 grid raises IndexError with a seventh dataset and leaves blank framed
# panels when fewer than six are available.
n_cols = 3
n_rows = max(1, -(-len(spatial_stats) // n_cols))  # ceiling division, at least one row
# squeeze=False always returns a 2-D array of axes, so flatten() is always valid
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for idx, (name, stats) in enumerate(spatial_stats.items()):
    ax = axes[idx]

    # Only draw the relations this dataset actually contains
    relations_list = [relation for relation in all_relations if relation in stats]
    counts = [stats[relation] for relation in relations_list]
    bar_colors = [colors.get(relation, '#999999') for relation in relations_list]

    # Plot
    bars = ax.bar(relations_list, counts, color=bar_colors, alpha=0.8, edgecolor='black', linewidth=1.5)

    # Add count labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')

    ax.set_title(f"{DATASETS[name]['description']}\n({len(datasets[name])} samples)",
                fontsize=11, fontweight='bold', pad=10)
    ax.set_ylabel('Count', fontsize=10)
    ax.set_xlabel('Spatial Relation', fontsize=10)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Rotate x labels so longer relation names do not collide
    ax.tick_params(axis='x', rotation=45, labelsize=9)

# Hide panels that have no dataset to show
for unused_ax in axes[len(spatial_stats):]:
    unused_ax.axis('off')

plt.suptitle('Spatial Relation Distribution Across Datasets',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

## Summary Statistics Table

In [None]:
# Create summary table: one row per dataset, one column per spatial relation

# Derive the relation columns here instead of reusing a variable created by the
# plotting cell, so this cell only depends on `spatial_stats`.
relation_columns = sorted({relation for stats in spatial_stats.values() for relation in stats})

# Build every row up front and create the DataFrame once
summary_data = [
    {
        'Dataset': DATASETS[name]['description'],
        'Total': len(datasets[name]),
        # Relations a dataset does not contain are reported as 0
        **{relation: stats.get(relation, 0) for relation in relation_columns},
    }
    for name, stats in spatial_stats.items()
]

summary_df = pd.DataFrame(summary_data)
print("\n" + "="*80)
print("SUMMARY: Sample counts by spatial relation")
print("="*80)
print(summary_df.to_string(index=False))
print()

## Sample Grid Visualization

In [None]:
def get_image_path(image_path_str, image_base):
    """Resolve a dataset ``image_path`` entry to a file under ``image_base``.

    Candidates are tried in order and the first one that exists is returned:

    1. the path itself, when it is absolute;
    2. ``image_base / <file name>``;
    3. ``image_base / <everything after a known image directory>``, for paths
       containing ``vg_images``, ``val2017``, ``controlled_images`` or
       ``controlled_clevr``.

    Args:
        image_path_str: Path as stored in the dataset. A leading ``data/`` is
            ignored.
        image_base: Directory holding the dataset's images (``str`` or ``Path``).

    Returns:
        Path: The first existing candidate, otherwise
        ``image_base / <file name>``. That fallback may not exist, so callers
        should still check ``.exists()``.
    """
    image_base = Path(image_base)
    img_path = Path(image_path_str)

    # Strip only a *leading* 'data/'. str.replace() would also delete later
    # occurrences (e.g. the tail of 'mydata/') and corrupt the path.
    prefix = 'data/'
    raw_path = str(img_path)
    if raw_path.startswith(prefix):
        img_path = Path(raw_path[len(prefix):])

    # Try full path first
    if img_path.is_absolute() and img_path.exists():
        return img_path

    # Try the bare file name directly under image_base
    full_path = image_base / img_path.name
    if full_path.exists():
        return full_path

    # Try the part of the path that follows a known image directory
    known_dirs = ('vg_images', 'val2017', 'controlled_images', 'controlled_clevr')
    parts = img_path.parts
    for idx, part in enumerate(parts):
        remainder = parts[idx + 1:]
        # An empty remainder would resolve to image_base itself (a directory),
        # which is never a valid image, so skip it.
        if part in known_dirs and remainder:
            full_path = image_base.joinpath(*remainder)
            if full_path.exists():
                return full_path

    # Nothing matched: return the most likely location and let the caller check it
    return image_base / img_path.name


def display_sample_grid(dataset_name, n_samples=9, figsize=(15, 15), filter_relation=None, seed=None):
    """
    Display a grid of random samples from a dataset

    Each panel shows the image (or an "Image not found" placeholder) titled with
    the sample's spatial relation and its first caption, truncated to 50
    characters. Samples without captions are titled "<relation> - <label>".

    Args:
        dataset_name: Which dataset to display (key of `datasets` / `DATASETS`)
        n_samples: Maximum number of samples to show (default: 9 for 3x3 grid);
            must be at least 1
        figsize: Figure size tuple
        filter_relation: If specified, only show samples with this spatial relation
        seed: Optional seed for the sampling. With a seed the same samples are
            drawn on every call; with None the global `random` state is used.

    Raises:
        ValueError: If n_samples is smaller than 1.
    """
    # Fail early with a clear message instead of an opaque matplotlib error
    # about a 0x0 subplot grid.
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    dataset = datasets[dataset_name]
    config = DATASETS[dataset_name]

    # Filter by relation if specified
    if filter_relation:
        indices = [i for i, rel in enumerate(dataset['spatial_relation']) if rel == filter_relation]
        if not indices:
            print(f"No samples found with relation '{filter_relation}'")
            return
        title_suffix = f" - {filter_relation} only"
    else:
        indices = list(range(len(dataset)))
        title_suffix = ""

    # Smallest square grid that can hold n_samples panels
    grid_size = int(n_samples ** 0.5)
    if grid_size * grid_size < n_samples:
        grid_size += 1

    # Sample random indices; a dedicated generator is used only when a seed is
    # given, so unseeded calls behave exactly as before.
    rng = random if seed is None else random.Random(seed)
    sample_indices = rng.sample(indices, min(n_samples, len(indices)))

    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid
    fig, axes = plt.subplots(grid_size, grid_size, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for idx, sample_idx in enumerate(sample_indices):
        ax = axes[idx]
        sample = dataset[sample_idx]

        # Get image path and load
        img_path = get_image_path(sample['image_path'], config['image_base'])

        if img_path.exists():
            # Release the file handle as soon as the image has been drawn;
            # imshow converts the image to an array when it is called.
            with Image.open(img_path) as img:
                ax.imshow(img)
        else:
            ax.text(0.5, 0.5, "Image\nnot found",
                    ha='center', va='center', fontsize=8)

        ax.axis('off')

        # Default title: spatial relation and label
        relation = sample['spatial_relation']
        label = sample['label']
        caption = f"{relation} - {label}"

        # Prefer the first caption (truncated) when the sample has any
        if sample['captions']:
            caption_text = sample['captions'][0]
            if len(caption_text) > 50:
                caption_text = caption_text[:47] + "..."
            caption = f"{relation}\n{caption_text}"

        ax.set_title(caption, fontsize=7, pad=5, wrap=True, fontweight='bold')

    # Hide unused subplots
    for idx in range(len(sample_indices), len(axes)):
        axes[idx].axis('off')

    plt.suptitle(f"{config['description']}{title_suffix}\n({len(indices)} samples)",
                 fontsize=12, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.show()


def display_all_dataset_grids(n_samples=9, figsize=(15, 15)):
    """Show one sample grid per loaded dataset, each preceded by a text banner.

    Args:
        n_samples: Number of samples per grid, forwarded to display_sample_grid
        figsize: Figure size tuple, forwarded to display_sample_grid
    """
    rule = '=' * 80
    for dataset_name in datasets:
        print(f"\n{rule}\nDataset: {dataset_name}\n{rule}")
        display_sample_grid(dataset_name, n_samples=n_samples, figsize=figsize)

## Display Sample Grids for Individual Datasets

In [None]:
# Display a grid of 9 randomly chosen samples from Visual Genome - one object
display_sample_grid('vg_one_obj', n_samples=9)

In [None]:
# Display a grid of 9 randomly chosen samples from COCO - one object
display_sample_grid('coco_one_obj', n_samples=9)

In [None]:
# Display a grid of 9 randomly chosen samples from Controlled images
display_sample_grid('controlled_images', n_samples=9)

In [None]:
# Display a grid of 9 randomly chosen samples from Controlled CLEVR
display_sample_grid('controlled_clevr', n_samples=9)

## Display Grids Filtered by Spatial Relation

In [None]:
# Example: restrict the Visual Genome one-object grid to samples whose
# spatial_relation is 'left'
display_sample_grid('vg_one_obj', n_samples=9, filter_relation='left')

In [None]:
# Example: restrict the COCO one-object grid to samples whose
# spatial_relation is 'top'
display_sample_grid('coco_one_obj', n_samples=9, filter_relation='top')

## Display All Dataset Grids

In [None]:
# Display a 9-sample grid for every loaded dataset, one after another
display_all_dataset_grids(n_samples=9)

## Combined Analysis: All Relations Together

In [None]:
# Grouped bar chart showing all datasets together: one cluster of bars per
# dataset, one bar per spatial relation inside each cluster (bars sit side by
# side, they are not stacked).
fig, ax = plt.subplots(figsize=(12, 6))

dataset_names = [DATASETS[name]['description'] for name in spatial_stats]
x = np.arange(len(dataset_names))
# Keep the 0.15 bar width, but shrink it when there are many relations so that a
# cluster never spans more than 0.8 of the distance to the next dataset and
# neighbouring clusters cannot overlap.
width = min(0.15, 0.8 / max(len(all_relations), 1))

# Plot bars for each relation
for i, relation in enumerate(all_relations):
    counts = [spatial_stats[name].get(relation, 0) for name in spatial_stats]
    # Centre the cluster on the dataset's tick position
    offset = (i - len(all_relations)/2 + 0.5) * width
    ax.bar(x + offset, counts, width, label=relation,
           color=colors.get(relation, '#999999'), alpha=0.8, edgecolor='black')

ax.set_xlabel('Dataset', fontsize=12, fontweight='bold')
ax.set_ylabel('Count', fontsize=12, fontweight='bold')
ax.set_title('Spatial Relations Comparison Across All Datasets', fontsize=14, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels(dataset_names, rotation=45, ha='right')
ax.legend(title='Spatial Relation', fontsize=10, title_fontsize=11)
ax.grid(axis='y', alpha=0.3, linestyle='--')
ax.set_axisbelow(True)

plt.tight_layout()
plt.show()

## Analysis: Filtering Out 'Behind'

In [None]:
# Show statistics after filtering out 'behind': for every dataset, how many
# samples carry that relation and how the remaining relations are distributed.
print("\n" + "="*80)
print("IMPACT OF FILTERING OUT 'BEHIND' RELATION")
print("="*80)

for name, stats in spatial_stats.items():
    total = sum(stats.values())
    behind_count = stats.get('behind', 0)
    remaining = total - behind_count

    print(f"\n{DATASETS[name]['description']}:")
    print(f"  Original samples: {total}")

    # An empty dataset has nothing to take percentages of; dividing by `total`
    # would raise ZeroDivisionError.
    if total == 0:
        print("  No samples - skipping percentages")
        continue

    print(f"  'Behind' samples: {behind_count} ({behind_count/total*100:.1f}%)")
    print(f"  Remaining samples: {remaining} ({remaining/total*100:.1f}%)")

    # Show new distribution
    print("  New distribution:")
    if remaining == 0:
        # Every sample was 'behind', so there is no distribution left to print
        print("    (no samples remain)")
        continue
    for relation, count in sorted(stats.items()):
        if relation != 'behind':
            print(f"    {relation:10s}: {count:5d} ({count/remaining*100:5.1f}%)")