# Dataset Exploration

Explore the five polyp segmentation datasets and visualize the domain differences across centers.

In [None]:
import sys
sys.path.insert(0, '..')

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from data.datasets import PolypCenterDataset, build_center_datasets
from data.augmentations import get_test_transforms

In [None]:
# Load all centers
#
# `centers` lists the center names handed to build_center_datasets together
# with the '../datasets' root; every center shares the same test-time
# transform.
# NOTE(review): 352 is presumably the target image size - confirm against
# get_test_transforms.
centers = [
    'Kvasir',
    'CVC-ClinicDB',
    'CVC-ColonDB',
    'ETIS-LaribPolypDB',
    'CVC-300',
]
transform = get_test_transforms(352)
datasets = build_center_datasets(
    '../datasets',
    centers,
    transform=transform,
)

# Print statistics: a header, a 40-character rule, then one line per entry of
# `datasets` (iterated as a mapping of name -> dataset) with its length.
print('Dataset Statistics')
print('=' * 40)
for name in datasets:
    ds = datasets[name]
    print(f'{name:25s}: {len(ds):4d} images')

In [None]:
# Visualize samples from each center
#
# One row per center, four columns per row:
#   image 1 | mask 1 | image 2 | mask 2
n_centers = len(datasets)
n_rows = max(1, n_centers)  # plt.subplots rejects a zero-row grid
# Size the grid from the datasets actually loaded instead of assuming five
# centers (4 inches of height per row, i.e. 16x20 for five centers).
# squeeze=False keeps `axes` two-dimensional even when there is a single row,
# so axes[i, k] indexing below always works.
fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

# Per-channel statistics used to undo the input normalization; built once
# rather than on every iteration.
# NOTE(review): these look like the usual ImageNet mean/std - confirm they
# match what get_test_transforms applies.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

for i, (name, ds) in enumerate(datasets.items()):
    if len(ds) == 0:
        # Nothing to draw: blank the row instead of failing on ds[0].
        axes[i, 0].set_title(f'{name} (empty)')
        for ax in axes[i]:
            ax.axis('off')
        continue

    for j in range(2):
        # Sample 1 is index 0; sample 2 sits a third of the way into the
        # dataset (both are index 0 when the dataset has fewer than 3 items).
        sample = ds[j * (len(ds) // 3)]

        # Move the first axis last for imshow.
        # NOTE(review): assumes sample['image'] is a channels-first tensor.
        img = sample['image'].permute(1, 2, 0).numpy()
        # Denormalize, then clamp to the [0, 1] range imshow accepts for floats.
        img = np.clip(img * std + mean, 0, 1)
        # Only the first entry along the mask's leading axis is shown.
        mask = sample['mask'][0].numpy()

        image_ax = axes[i, j * 2]
        mask_ax = axes[i, j * 2 + 1]

        image_ax.imshow(img)
        image_ax.set_title(f'{name} - Image {j+1}')
        image_ax.axis('off')

        mask_ax.imshow(mask, cmap='gray')
        mask_ax.set_title(f'{name} - Mask {j+1}')
        mask_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Polyp size distribution per center
#
# One histogram per center of the per-image foreground ratio
# (sum of mask values divided by the number of mask elements).
n_centers = len(datasets)
n_cols = max(1, n_centers)  # plt.subplots rejects a zero-column grid
# Size the grid from the datasets actually loaded instead of assuming five
# centers (4 inches of width per panel, i.e. 20x4 for five centers).
# squeeze=False guarantees a 2-D array even for a single center; row 0 is
# then the flat list of panels.
fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), squeeze=False)
axes = axes[0]

for i, (name, ds) in enumerate(datasets.items()):
    ratios = []
    # Indexed access on purpose: only __len__ and __getitem__ are known to be
    # supported by the datasets here.
    for idx in range(len(ds)):
        mask = ds[idx]['mask'][0].numpy()
        # NOTE(review): this equals the foreground fraction only if the mask
        # holds 0/1 values - confirm the mask encoding.
        ratios.append(mask.sum() / mask.size)  # foreground ratio

    ax = axes[i]
    ax.hist(ratios, bins=20, color=f'C{i}', alpha=0.7)
    ax.set_title(name)
    ax.set_xlabel('Foreground Ratio')
    ax.set_ylabel('Count')

plt.suptitle('Polyp Size Distribution per Center', fontsize=14)
plt.tight_layout()
plt.show()