In [2]:
import sys
sys.path.append('..')

from src.config import load_config
from src.load_data import load_data, get_class_dist
from src.splits import save_splits
from src.transforms import get_transforms, get_base_transforms, compute_dataset_stats, denormalize
import matplotlib.pyplot as plt

In [None]:
# Load the project configuration and the three dataset splits.
config = load_config()
train_dataset, val_dataset, test_dataset = load_data(split_data=True)

# (label, size) per split; labels are pre-padded so the counts line up.
split_sizes = (
    ("Train: ", len(train_dataset)),
    ("Val:   ", len(val_dataset)),
    ("Test:  ", len(test_dataset)),
)
total = sum(size for _, size in split_sizes)

# Report each split's size and its share of all images.
rule = '=' * 50
print(f"\n{rule}")
print(f"Total Images per Split:")
print(rule)
for label, size in split_sizes:
    print(f"{label}{size:4d} images ({size / total * 100:.2f}%)")
print(f"Total: {total:4d} images")
print(rule)

## Dataset Distributions

In [None]:
# Class balance verification: tabulate, for every class, the image count and
# the share of its split, with train / val / test side by side.
train_dist = get_class_dist(train_dataset)
val_dist = get_class_dist(val_dataset)
test_dist = get_class_dist(test_dataset)

# Row order follows the train split. Classes that appear only in val/test are
# appended so they show up in the table instead of being silently dropped.
# NOTE(review): assumes get_class_dist returns a dict-like mapping of
# class name -> count (it is used with .keys()/.values() in this notebook).
class_names = list(train_dist.keys())
for extra_dist in (val_dist, test_dist):
    for extra_name in extra_dist.keys():
        if extra_name not in class_names:
            class_names.append(extra_name)

n_train = len(train_dataset)
n_val = len(val_dataset)
n_test = len(test_dataset)

separator = f"{'-'*12}-+-{'-'*15}-+-{'-'*15}-+-{'-'*15}"

print(f"\n{'='*80}")
print("Class Balance Verification")
print(f"{'='*80}")
print(f"{'Class':<12} | {'Train':>15} | {'Val':>15} | {'Test':>15}")
print(separator)

for class_name in class_names:
    # A class absent from a split counts as 0 rather than raising KeyError,
    # which is exactly the imbalance this table is meant to expose.
    train_count = train_dist.get(class_name, 0)
    val_count = val_dist.get(class_name, 0)
    test_count = test_dist.get(class_name, 0)

    # Guard against an empty split so the table still prints.
    train_pct = (train_count / n_train) * 100 if n_train else 0.0
    val_pct = (val_count / n_val) * 100 if n_val else 0.0
    test_pct = (test_count / n_test) * 100 if n_test else 0.0

    print(f"{class_name:<12} | {train_count:4d} ({train_pct:5.1f}%) | {val_count:4d} ({val_pct:5.1f}%) | {test_count:4d} ({test_pct:5.1f}%)")

print(separator)
print(f"{'Total':<12} | {n_train:4d} (100.0%) | {n_val:4d} (100.0%) | {n_test:4d} (100.0%)")
print(f"{'='*80}")

In [None]:
# Bar chart of per-class image counts, one panel per split.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

splits = [('Train', train_dist, len(train_dataset)),
          ('Validation', val_dist, len(val_dataset)),
          ('Test', test_dist, len(test_dataset))]

# The loop target is `split_size` rather than `total` so that plotting does not
# overwrite the notebook-level `total` (all images across splits).
for ax, (split_name, dist, split_size) in zip(axes, splits):
    classes = list(dist.keys())
    counts = list(dist.values())

    ax.bar(classes, counts, color='steelblue', edgecolor='black', alpha=0.7)
    ax.set_xlabel('Waste Category', fontsize=11)
    ax.set_ylabel('Number of Images', fontsize=11)
    ax.set_title(f'{split_name} Split (n={split_size})', fontsize=13, fontweight='bold')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)

    # Label each bar with its count. The gap above the bar is given in points,
    # not data units, so it looks the same on every panel whatever the bar
    # heights are. Bars sit at x = 0, 1, 2, ... in the order of `classes`.
    for bar_index, count in enumerate(counts):
        ax.annotate(str(count), xy=(bar_index, count), xytext=(0, 3),
                    textcoords='offset points',
                    ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

In [None]:
# Persist the splits; the returned metadata is reused for the reports below.
splits_info = save_splits(train_dataset, val_dataset, test_dataset, config)

splits_dir = config['paths']['splits_dir']
print(f"\nSplits saved to {splits_dir}")

# The file names are hard-coded here (not read from splits_info); only the
# sample counts come from the metadata returned by save_splits.
print(f"\nFiles created:")
for filename, size_key in (
    ("train_indices.txt", 'train_size'),
    ("val_indices.txt", 'val_size'),
    ("test_indices.txt", 'test_size'),
):
    print(f"  - {filename} ({splits_info[size_key]} samples)")
print(f"  - splits.json (complete metadata)")

## Final Summary

In [None]:
# Final recap built from the metadata returned by save_splits and the config.
total_images = splits_info['total']
banner = '=' * 60

print(f"\n{banner}")
print(f"FINAL SPLIT SUMMARY")
print(banner)
print(f"\nDataset: {splits_info['dataset_name']}")
print(f"Random Seed: {splits_info['seed']}")

# One line per split: size and share of the total. Labels are pre-padded so
# the numbers line up in a column.
print(f"\nSplit Distribution:")
for label, size_key in (
    ("  Training:   ", 'train_size'),
    ("  Validation: ", 'val_size'),
    ("  Test:       ", 'test_size'),
):
    size = splits_info[size_key]
    print(f"{label}{size:4d} images ({size/total_images*100:.1f}%)")
print(f"  Total:      {total_images:4d} images")

data_cfg = config['data']
print(f"\nNumber of Classes: {data_cfg['num_classes']}")
print(f"Classes: {', '.join(data_cfg['classes'])}")
print(f"\nSplit files saved to: {config['paths']['splits_dir']}")
print(banner)