In [4]:
# Work from the project root: later cells use paths relative to it
# (e.g. './data/splits') and import the `src` package from it.
# NOTE(review): hard-coded absolute path that looks like a Google Drive mount on
# Colab — it only works where that mount exists; consider making it configurable.
%cd /content/drive/MyDrive/Paper-Replications/Memorization-to-generalization

/content/drive/MyDrive/Paper-Replications/Memorization-to-generalization


In [9]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle
from IPython.display import display, HTML

# Make the project's `src` package importable. The %cd magic at the top of the
# notebook already moved the working directory to the project root, so the
# current directory *is* the root (a relative '../..' would point above it).
# The membership check keeps re-runs of this cell from piling up duplicate entries.
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.append(project_root)

# Import our dataset utilities
from src.data.datasets import generate_dataset_splits, load_dataset_split, load_base_dataset

In [None]:
# Experiment configuration shared by the cells below.
datasets = ['mnist', 'fashion_mnist', 'cifar10']
num_splits = 38            # how many training splits to create per dataset
seed = 42                  # forwarded to the split generator
save_dir = './data/splits'  # later cells read <save_dir>/<dataset>/splits_info.pkl

# Create the splits for every dataset (presumably written under save_dir by the
# helper, since later cells load them back from there).
for dataset_name in datasets:
    print(f"Generating splits for {dataset_name}...")
    splits_info = generate_dataset_splits(
        dataset_name, num_splits=num_splits, seed=seed, save_dir=save_dir,
    )

In [None]:
# One line per dataset: number of training samples in each split.
fig, ax = plt.subplots(figsize=(12, 6))
for dataset_name in datasets:
    # Split metadata saved alongside the splits (pickle: only load trusted,
    # locally generated files).
    with open(os.path.join(save_dir, dataset_name, "splits_info.pkl"), 'rb') as f:
        splits_info = pickle.load(f)

    sizes = [info['size'] for info in splits_info]
    ax.plot(range(1, len(sizes) + 1), sizes, 'o-', label=dataset_name)

ax.set_xlabel('Split Index')
ax.set_ylabel('Number of Training Samples')
ax.set_title('Dataset Split Sizes')
ax.grid(True, which='both', linestyle='--', alpha=0.7)
# Sizes span orders of magnitude, so use a log-scaled y axis.
ax.set_yscale('log')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
def show_samples(dataset, indices, title, n_cols=10, dataset_name=None):
    """Display ``dataset[idx]`` for each ``idx`` in ``indices`` as an image grid.

    Each panel is titled ``"#<idx>: <label>"``. Single-channel images are drawn
    in grayscale; multi-channel images are drawn as RGB, de-normalised first
    when ``dataset_name == 'cifar10'``.

    Args:
        dataset: Indexable where ``dataset[idx]`` returns ``(img, label)``.
            ``img`` may be a ``torch.Tensor`` in CHW (or HW) layout, or anything
            ``np.asarray`` turns into an HxW / HxWxC array.
        indices: Sequence of indices into ``dataset`` to show.
        title: Figure super-title.
        n_cols: Number of grid columns; rows are added as needed.
        dataset_name: Dataset name controlling de-normalisation. If ``None``,
            the notebook-level ``dataset_name`` variable is used instead, which
            is what earlier versions of this function always did implicitly.
    """
    if dataset_name is None:
        # Backward-compatible fallback to the notebook global.
        dataset_name = globals().get('dataset_name')

    n_samples = len(indices)
    if n_samples == 0:
        return  # nothing to draw; plt.subplots(0, n_cols) would raise

    n_rows = (n_samples + n_cols - 1) // n_cols  # ceil(n_samples / n_cols)

    # squeeze=False always returns a 2D array of Axes, so flatten() also works
    # for a 1x1 grid (where the default returns a bare Axes object).
    _, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5),
                           squeeze=False)
    axes = axes.flatten()

    for ax, idx in zip(axes, indices):
        img, label = dataset[idx]
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()
            # CHW -> HWC for imshow; a 2D tensor is already HW.
            img = (img.permute(1, 2, 0) if img.ndim == 3 else img).numpy()
        else:
            img = np.asarray(img)

        if img.ndim == 2 or img.shape[2] == 1:  # Grayscale
            ax.imshow(img if img.ndim == 2 else img[..., 0], cmap='gray')
        else:  # RGB
            if dataset_name == 'cifar10':
                # Undo per-channel normalisation so colours display correctly.
                # NOTE(review): assumes the dataset transform normalised with
                # exactly these statistics.
                mean = np.array([0.4914, 0.4822, 0.4465])
                std = np.array([0.2470, 0.2435, 0.2616])
                img = np.clip(img * std + mean, 0, 1)
            ax.imshow(img)

        ax.set_title(f"#{idx}: {label}")
        ax.axis('off')

    # Hide the unused trailing panels of the last row.
    for ax in axes[n_samples:]:
        ax.axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# Select a few key splits to visualize (1-based split numbers, the same
# convention load_dataset_split is called with elsewhere in this notebook).
splits_to_show = [1, 5, 10, 20]

dataset_name = 'mnist'  # Change to visualize different datasets

# The base dataset does not depend on the split, so load it once instead of
# once per loop iteration. Each split subset stores indices into it.
base_dataset, _ = load_base_dataset(dataset_name)

for split_idx in splits_to_show:
    train_subset, _ = load_dataset_split(dataset_name, split_idx)

    # Map the subset back to base-dataset indices; show at most 20 samples.
    indices = train_subset.indices[:20]

    show_samples(
        base_dataset,
        indices,
        f"{dataset_name} - Split {split_idx} (Size: {len(train_subset)} samples)"
    )

In [None]:
dataset_name = 'mnist'  # Change to visualize different datasets
num_classes = 10  # MNIST, Fashion-MNIST and CIFAR-10 all have 10 classes

# Load the full (unsplit) training set; split subsets store indices into it.
full_dataset, _ = load_base_dataset(dataset_name)

# Per-class label counts and size for each split.
label_distributions = []
split_sizes = []

for split_idx in range(1, num_splits + 1):
    train_subset, _ = load_dataset_split(dataset_name, split_idx)
    split_sizes.append(len(train_subset))

    # NOTE(review): indexing full_dataset loads every sample just to read its
    # label; if the dataset exposes a `targets` attribute that matches
    # dataset[idx][1], using it would be much faster — confirm before switching.
    labels = np.array([int(full_dataset[idx][1]) for idx in train_subset.indices],
                      dtype=np.int64)
    # bincount with minlength keeps classes absent from this split as zeros.
    distribution = np.bincount(labels, minlength=num_classes).astype(float)
    label_distributions.append(distribution)

# Convert to percentages (an empty split stays all-zero instead of dividing by 0).
label_distributions_percent = [
    dist / dist.sum() * 100 if dist.sum() > 0 else dist
    for dist in label_distributions
]

# Plot the label distribution for selected splits.
plt.figure(figsize=(14, 8))

splits_to_plot = [0, 9, 19, 29, -1]  # First, a few middle ones, and last split
# Skip positions that don't exist when fewer splits were generated.
splits_to_plot = [i for i in splits_to_plot if -len(split_sizes) <= i < len(split_sizes)]
for i, split_idx in enumerate(splits_to_plot):
    # 1-based split number; the modulo maps -1 to the last split.
    split_number = split_idx % len(split_sizes) + 1
    plt.subplot(len(splits_to_plot), 1, i + 1)
    plt.bar(range(num_classes), label_distributions_percent[split_idx])
    plt.title(f"Split {split_number} (Size: {split_sizes[split_idx]} samples)")
    plt.xlabel('Class')
    plt.ylabel('Percentage')
    plt.xticks(range(num_classes))
    plt.ylim(0, 50)  # Limit to better see variations

plt.tight_layout()
plt.show()

# Print summary statistics
print("Dataset splits summary:")
for dataset_name in datasets:
    # NOTE: pickle.load can execute arbitrary code; this file is presumably the
    # one generate_dataset_splits wrote under save_dir — never load untrusted data.
    with open(os.path.join(save_dir, dataset_name, "splits_info.pkl"), 'rb') as f:
        splits_info = pickle.load(f)

    sizes = [info['size'] for info in splits_info]
    n_found = len(sizes)

    print(f"\n{dataset_name}:")
    print(f"  Number of splits: {n_found}")
    if n_found == 0:
        continue  # nothing else to report; sizes[0] would raise IndexError
    print(f"  Smallest split (split 1): {sizes[0]} samples")
    print(f"  Largest split (split {n_found}): {sizes[-1]} samples")

    # Print some intermediate split sizes. Positions are 0-based (-1 = last);
    # positions past the end are skipped rather than raising IndexError.
    indices = [idx for idx in (0, 4, 9, 19, 29, -1) if -n_found <= idx < n_found]
    print("  Selected split sizes:")
    for idx in indices:
        # The modulo turns -1 into the last position before converting to 1-based.
        print(f"    Split {idx % n_found + 1}: {sizes[idx]} samples")