In [None]:
import os
import random

import numpy as np
import torch
import torchvision
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms

def select_imagenet_subset(data_root, subset_size=1000, strategy='random',
                          samples_per_class=None, random_seed=42):
    """
    Select a subset D' from ImageNet training data.

    The training split is loaded with ``torchvision.datasets.ImageNet``. If
    that raises ``RuntimeError`` (for example because the ILSVRC2012 devkit
    archive is missing, as in an Imagenette download) and
    ``<data_root>/train`` is a directory, that directory is loaded with
    ``datasets.ImageFolder`` instead.

    Args:
        data_root (str): Path to the dataset root directory.
        subset_size (int): Maximum number of samples to select. Clamped to the
            dataset size for 'random' and 'first_n'; also caps 'stratified'.
        strategy (str): Selection strategy ('random', 'stratified', 'first_n').
        samples_per_class (int): Number of samples per class (for stratified
            sampling). Defaults to ``subset_size // num_classes``, but at
            least 1.
        random_seed (int): Seed applied to torch, numpy and ``random``.

    Returns:
        Subset: PyTorch Subset object containing selected data points
        list: Indices of selected samples

    Raises:
        ValueError: If ``strategy`` is not a supported value.
        RuntimeError: If the dataset cannot be loaded from ``data_root``.
    """
    # Validate before the (slow) dataset load so a typo fails immediately.
    if strategy not in ('random', 'stratified', 'first_n'):
        raise ValueError("Strategy must be 'random', 'stratified', or 'first_n'")

    # Set random seed for reproducibility
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    # Define transforms (minimal for analysis)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    full_dataset = _load_train_split(data_root, transform)

    print(f"Full ImageNet training set size: {len(full_dataset)}")
    print(f"Number of classes: {len(full_dataset.classes)}")

    # Labels are read from `targets` (available on ImageNet and ImageFolder)
    # instead of iterating the dataset, which would decode and transform
    # every image just to obtain its label.
    indices = _select_indices(
        targets=full_dataset.targets,
        num_classes=len(full_dataset.classes),
        subset_size=subset_size,
        strategy=strategy,
        samples_per_class=samples_per_class,
    )

    # Create subset
    subset = Subset(full_dataset, indices)

    print(f"Selected subset size: {len(subset)}")
    if indices:  # min()/max() raise ValueError on an empty list
        print(f"Selected indices range: [{min(indices)}, {max(indices)}]")

    return subset, indices


def _load_train_split(data_root, transform):
    """Load the training split as ImageNet, falling back to ImageFolder on <root>/train."""
    try:
        return datasets.ImageNet(
            root=data_root,
            split='train',
            transform=transform
        )
    except RuntimeError as err:
        train_dir = os.path.join(data_root, 'train')
        if not os.path.isdir(train_dir):
            raise  # nothing to fall back to: surface the original error
        print(f"ImageNet loader failed ({err}); "
              f"falling back to ImageFolder at {train_dir}")
        return datasets.ImageFolder(train_dir, transform=transform)


def _select_indices(targets, num_classes, subset_size, strategy,
                    samples_per_class):
    """Return dataset indices chosen by `strategy` using the global `random` RNG."""
    n_total = len(targets)
    n_wanted = min(subset_size, n_total)  # never ask for more than exists

    if strategy == 'random':
        # Random subset selection
        return random.sample(range(n_total), n_wanted)

    if strategy == 'first_n':
        # Simply take first n samples (deterministic)
        return list(range(n_wanted))

    # Stratified sampling: (up to) equal samples per class
    if samples_per_class is None:
        # Guard against 0 when subset_size < num_classes.
        samples_per_class = max(1, subset_size // num_classes)

    # Group indices by class (in dataset order)
    class_indices = {}
    for idx, label in enumerate(targets):
        class_indices.setdefault(label, []).append(idx)

    indices = []
    for class_id in range(num_classes):
        pool = class_indices.get(class_id)
        if not pool:
            continue
        indices.extend(random.sample(pool, min(samples_per_class, len(pool))))
        if len(indices) >= subset_size:
            return indices[:subset_size]
    return indices

def create_dataloader(subset, batch_size=32, num_workers=4, pin_memory=None):
    """Create a DataLoader that iterates the subset in a fixed order.

    Args:
        subset: Dataset (e.g. ``torch.utils.data.Subset``) to load from.
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of loader worker processes.
        pin_memory (bool | None): Whether to pin host memory for faster GPU
            transfer. ``None`` (default) enables it only when CUDA is
            available; unconditionally pinning on a CPU-only machine just
            triggers a PyTorch warning.

    Returns:
        DataLoader: Loader yielding batches in index order (no shuffling).
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False,  # Keep deterministic order for analysis
        num_workers=num_workers,
        pin_memory=pin_memory
    )

# Example usage for UnitMem/LayerMem analysis
def get_subset_for_memorization_analysis(data_root, analysis_type='unitmem',
                                         random_seed=42):
    """
    Get an appropriately sized subset for memorization analysis.

    Presets (``analysis_type`` is matched case-insensitively):
      - 'unitmem':  stratified, 1 sample per class, capped at 1000 samples
        (so a dataset with fewer than 1000 classes yields one sample per
        class, not 1000).
      - 'layermem': 5000 samples drawn at random.
      - anything else: 1000 samples drawn at random.

    Args:
        data_root (str): Dataset root forwarded to ``select_imagenet_subset``.
        analysis_type (str): Which preset to use (see above).
        random_seed (int): Seed forwarded to ``select_imagenet_subset``.

    Returns:
        The ``(subset, indices)`` pair from ``select_imagenet_subset``.
    """
    presets = {
        # For UnitMem analysis, typically use 100-5000 samples;
        # 1 sample per class for fine-grained analysis.
        'unitmem': dict(subset_size=1000, strategy='stratified',
                        samples_per_class=1),
        # For LayerMem analysis, can use larger subsets.
        'layermem': dict(subset_size=5000, strategy='random',
                         samples_per_class=None),
    }
    # Default subset for general analysis
    default = dict(subset_size=1000, strategy='random', samples_per_class=None)

    config = presets.get(analysis_type.lower(), default)

    return select_imagenet_subset(
        data_root=data_root,
        random_seed=random_seed,
        **config
    )

# Example usage:
if __name__ == "__main__":
    # Dataset root handed to torchvision's datasets.ImageNet loader.
    # NOTE: datasets.ImageNet expects the ILSVRC2012 devkit archive
    # (ILSVRC2012_devkit_t12.tar.gz) under this root; an Imagenette download
    # does not ship it, so point this at a full ImageNet copy if loading fails.
    IMAGENET_ROOT = "../data/imagenette/"

    # For UnitMem analysis (as mentioned in the research)
    print("Creating subset for UnitMem analysis...")
    subset_unitmem, indices_unitmem = get_subset_for_memorization_analysis(
        IMAGENET_ROOT,
        analysis_type='unitmem'
    )

    # Create DataLoader
    dataloader = create_dataloader(subset_unitmem, batch_size=64)

    # Example: Access subset data
    print("\nSubset statistics:")
    print(f"Number of samples: {len(subset_unitmem)}")
    print(f"Sample indices: {indices_unitmem[:10]}...")  # First 10 indices

    # Verify data loading on the first three batches only.
    for batch_idx, (images, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx}: {images.shape}, {labels.shape}")
        if batch_idx >= 2:
            break


Creating subset for UnitMem analysis...


RuntimeError: The archive ILSVRC2012_devkit_t12.tar.gz is not present in the root directory or is corrupted. You need to download it externally and place it in ../imagenette2/.