# AlexNet iFood2019 - Data Exploration

This notebook explores the iFood 2019 dataset.

## Contents:
1. Setup & Load Data
2. Dataset Statistics
3. Class Distribution Analysis
4. Sample Visualization
5. Test Data Loader
6. Augmentation Visualization

In [None]:
# ============================================================
# Cell 1: Quick Setup
# ============================================================

from google.colab import drive
drive.mount('/content/drive')

import os
import sys

PROJECT_PATH = '/content/drive/MyDrive/AlexNet_iFood2019'
REPO_PATH = '/content/alexnet-ifood2019'

# Clone if needed
if not os.path.exists(REPO_PATH):
    !git clone https://github.com/deftorch/alexnet-ifood2019.git {REPO_PATH}

os.chdir(REPO_PATH)
sys.path.insert(0, REPO_PATH)
sys.path.insert(0, os.path.join(REPO_PATH, 'src'))

print(f"✓ Setup complete")

In [None]:
# ============================================================
# Cell 2: Check/Create Dataset
# ============================================================

import os

DATA_DIR = os.path.join(PROJECT_PATH, 'dataset')

# Check if dataset exists
required = ['train_images', 'val_images', 'train_info.csv', 'val_info.csv']
missing = [f for f in required if not os.path.exists(os.path.join(DATA_DIR, f))]

if missing:
    print("⚠️  Dataset not found. Creating mock data for testing...")
    !python src/create_mock_data.py --output_dir {DATA_DIR} --train_per_class 5 --val_per_class 2
else:
    print("✓ Dataset found!")
    
# Create symlink
!rm -rf data
!ln -s {DATA_DIR} data

In [None]:
# ============================================================
# Cell 3: Load and Explore Dataset Statistics
# ============================================================
# Reads the train/validation annotation CSVs and the class-name list,
# then prints the headline counts.

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Number of placeholder class names generated when no class list file is available.
DEFAULT_NUM_CLASSES = 251

# Load annotations (read as header-less, two-column files: image name, label)
train_df = pd.read_csv(os.path.join(DATA_DIR, 'train_info.csv'), header=None, names=['image', 'label'])
val_df = pd.read_csv(os.path.join(DATA_DIR, 'val_info.csv'), header=None, names=['image', 'label'])

# Load class names
class_file = os.path.join(DATA_DIR, 'class_list.txt')
if os.path.exists(class_file):
    classes = []
    with open(class_file, encoding='utf-8') as f:
        for line in f:
            # Each line is expected to look like "<id> <name>"; split once on whitespace.
            parts = line.split(maxsplit=1)
            if not parts:
                continue  # blank line (e.g. trailing newline at end of file)
            # A line with a single token has no separate name, so keep the token
            # itself instead of raising IndexError.
            classes.append(parts[1].strip() if len(parts) > 1 else parts[0])
else:
    classes = [f'class_{i}' for i in range(DEFAULT_NUM_CLASSES)]

print("Dataset Statistics")
print("=" * 50)
print(f"Total classes: {len(classes)}")
print(f"Training samples: {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")
print(f"\nSample class names: {classes[:5]}")

In [None]:
# ============================================================
# Cell 4: Class Distribution Analysis
# ============================================================
# Summarises how many training samples each label has and plots the
# distribution two ways: a histogram and a rank-ordered curve.

# Analyze class distribution.
# NOTE: value_counts() only lists labels that actually occur in train_df, so
# labels with zero training samples are not part of the statistics below.
class_counts = train_df['label'].value_counts().sort_index()

print("Class Distribution Statistics:")
print(f"  Mean samples per class: {class_counts.mean():.0f}")
print(f"  Median: {class_counts.median():.0f}")
print(f"  Min: {class_counts.min()}")
print(f"  Max: {class_counts.max()}")
print(f"  Std: {class_counts.std():.2f}")
# Guards the division; also False when train_df is empty (min() is NaN then).
if class_counts.min() > 0:
    print(f"  Imbalance ratio: {class_counts.max() / class_counts.min():.2f}x")

# Plot distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram
axes[0].hist(class_counts.values, bins=30, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Number of Samples')
axes[0].set_ylabel('Number of Classes')
axes[0].set_title('Distribution of Samples per Class')
axes[0].grid(alpha=0.3)

# Sorted distribution (largest class first)
sorted_counts = class_counts.sort_values(ascending=False)
ranks = range(len(sorted_counts))
axes[1].plot(ranks, sorted_counts.values, linewidth=2)
axes[1].fill_between(ranks, sorted_counts.values, alpha=0.3)
axes[1].set_xlabel('Class Rank')
axes[1].set_ylabel('Number of Samples')
axes[1].set_title('Sorted Class Distribution')
axes[1].grid(alpha=0.3)

plt.tight_layout()

# savefig() does not create missing directories, so make sure the target exists.
results_dir = os.path.join(PROJECT_PATH, 'analysis_results')
os.makedirs(results_dir, exist_ok=True)
plt.savefig(os.path.join(results_dir, 'class_distribution.png'), dpi=150)
plt.show()

In [None]:
# ============================================================
# Cell 5: Visualize Sample Images
# ============================================================
# Draws a 2x5 grid of randomly chosen training images, each titled with
# its class name, and saves the figure under PROJECT_PATH/analysis_results.

from PIL import Image
import random

SAMPLE_SEED = 42        # fixed so Restart & Run All shows the same grid; change to see other samples
MAX_TITLE_CHARS = 15    # longer class names are truncated in the subplot titles

sample_rng = random.Random(SAMPLE_SEED)

# Sample images
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for ax in axes.flat:
    idx = sample_rng.randint(0, len(train_df) - 1)
    sample_row = train_df.iloc[idx]
    img_name = sample_row['image']
    label = sample_row['label']

    img_path = os.path.join(DATA_DIR, 'train_images', img_name)

    try:
        # Context manager closes the underlying file once the pixels are drawn.
        with Image.open(img_path) as img:
            ax.imshow(img)
        # The lower bound stops a negative label from silently indexing from the end.
        class_name = classes[label] if 0 <= label < len(classes) else f'class_{label}'
        # Only add an ellipsis when the name was actually shortened.
        if len(class_name) > MAX_TITLE_CHARS:
            title = f"{class_name[:MAX_TITLE_CHARS]}..."
        else:
            title = class_name
        ax.set_title(title, fontsize=9)
    except Exception as e:
        # Best effort: keep the grid going, but report what went wrong instead of hiding it.
        print(f"⚠️  Could not display {img_path}: {type(e).__name__}: {e}")
        ax.text(0.5, 0.5, f'Error\n{type(e).__name__}', ha='center', va='center')

    ax.axis('off')

plt.suptitle('Random Training Samples', fontsize=14)
plt.tight_layout()

# savefig() does not create missing directories, so make sure the target exists.
results_dir = os.path.join(PROJECT_PATH, 'analysis_results')
os.makedirs(results_dir, exist_ok=True)
plt.savefig(os.path.join(results_dir, 'sample_images.png'), dpi=150)
plt.show()

In [None]:
# ============================================================
# Cell 6: Test Data Loader
# ============================================================
# Smoke test for the project's data pipeline: build the loaders, report
# the size of every split, then pull one batch from the training split.

from src.data_loader import get_dataloaders, get_transforms

print("Testing data loaders...")

dataloaders = get_dataloaders(data_dir=DATA_DIR, batch_size=32, num_workers=2)

# Per-split summary
for split_name, split_loader in dataloaders.items():
    print(f"\n{split_name.upper()}:")
    print(f"  Batches: {len(split_loader)}")
    print(f"  Samples: {len(split_loader.dataset)}")

# Test one batch
first_batch = next(iter(dataloaders['train']))
images, labels = first_batch
print("\nBatch shapes:")
print(f"  Images: {images.shape}")
print(f"  Labels: {labels.shape}")
pixel_min, pixel_max = images.min(), images.max()
print(f"  Image range: [{pixel_min:.2f}, {pixel_max:.2f}]")

In [None]:
# ============================================================
# Cell 7: Visualize Augmented Images
# ============================================================
# Shows the first training image next to five outputs of the training
# transform pipeline and saves the figure under PROJECT_PATH/analysis_results.

from PIL import Image  # imported here too, so this cell does not rely on Cell 5 having run
from src.data_loader import get_transforms
import torch

# Get a sample image
sample_idx = 0
img_name = train_df.iloc[sample_idx]['image']
img_path = os.path.join(DATA_DIR, 'train_images', img_name)
# convert('RGB') returns an in-memory copy, so the file can be closed right away,
# and grayscale/RGBA files end up with the 3 channels the mean/std below expect.
with Image.open(img_path) as img_file:
    original_img = img_file.convert('RGB')

# Get train transforms
train_transform = get_transforms('train')

# Per-channel statistics used to undo normalization for display.
# NOTE(review): assumes get_transforms('train') returns a (3, H, W) tensor
# normalized with exactly these mean/std values — confirm against src/data_loader.py.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Show original and 5 augmented versions
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
panels = axes.ravel()  # row-major: panel 0 is the original, panels 1-5 the augmentations

# Original
panels[0].imshow(original_img)
panels[0].set_title('Original')
panels[0].axis('off')

# Augmented versions
for i, ax in enumerate(panels[1:], start=1):
    augmented = train_transform(original_img)

    # Denormalize for visualization and clip to the displayable [0, 1] range
    augmented_denorm = torch.clamp(augmented * std + mean, 0, 1)

    # Move channels last (H, W, C) as imshow expects
    ax.imshow(augmented_denorm.permute(1, 2, 0).numpy())
    ax.set_title(f'Augmented {i}')
    ax.axis('off')

plt.suptitle('Data Augmentation Examples', fontsize=14)
plt.tight_layout()

# savefig() does not create missing directories, so make sure the target exists.
results_dir = os.path.join(PROJECT_PATH, 'analysis_results')
os.makedirs(results_dir, exist_ok=True)
plt.savefig(os.path.join(results_dir, 'augmentation_examples.png'), dpi=150)
plt.show()

print("\n✓ Data exploration complete!")