In [1]:
import functools

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset

# Function to perform mixup on a batch of inputs and labels.
def mixup_data(x, y, alpha=1.0):
    """Return mixed inputs, paired labels, and the mixing coefficient.

    Each sample is blended with a partner sample from the same batch chosen by
    a random permutation: ``mixed_x = lam * x + (1 - lam) * x[perm]``.

    Args:
        x: Batch tensor whose first dimension is the batch dimension. Any rank
            >= 1 is accepted, because only dim 0 is permuted.
        y: Labels, indexed with the same permutation along dim 0.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution that
            ``lam`` is drawn from. If ``alpha <= 0``, mixing is disabled and
            ``lam`` is 1.0.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``:
        - ``y_a`` is ``y`` unchanged.
        - ``y_b`` holds the labels of the partner samples.
        - ``lam`` is a Python float.
        A typical loss is ``lam * crit(out, y_a) + (1 - lam) * crit(out, y_b)``.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    batch_size = x.size(0)
    # Random pairing of each sample with another sample in the batch.
    index = torch.randperm(batch_size)
    # Index along dim 0 only (not `[index, :]`), so 1-D batches work too.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Three complementary augmentation pipelines. Every pipeline finishes with
# ToTensor, so each one turns a dataset image into a tensor.

# (a) Geometric: 32x32 random crop from a 4-pixel padded image, plus a random
#     horizontal flip.
transform_a = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# (b) Photometric jitter (brightness/contrast/saturation) followed by a random
#     rotation with a 15-degree range.
transform_b = transforms.Compose(
    [
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
    ]
)

# (c) Occasional grayscale conversion (p=0.2) followed by a random affine
#     rotation/translation.
transform_c = transforms.Compose(
    [
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
        transforms.ToTensor(),
    ]
)

# Load the CIFAR-10 training split once per augmentation pipeline.
# Only the first instance passes download=True. It fetches/extracts the archive
# if needed, so the later instances can read the already-present files without
# another download/verification pass. In the original notebook output, that
# pass printed "Files already downloaded and verified" for each extra instance.
# NOTE(review): presumably each CIFAR10 instance keeps its own in-memory copy of
# the images, so three instances triple the footprint. Verify against
# torchvision, and consider a single shared base dataset if memory matters.
DATA_ROOT = './data'
dataset_a = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform_a)
dataset_b = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=False, transform=transform_b)
dataset_c = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=False, transform=transform_c)

# Concatenate the three views. Every training image appears three times per
# epoch, once under each augmentation pipeline.
enriched_dataset = ConcatDataset([dataset_a, dataset_b, dataset_c])

# Collate function: turn a list of (image, label) samples into one mixed-up batch.
def mixup_collate_fn(batch, alpha=0.4):
    """Stack ``(image, label)`` samples into a batch and apply mixup to it.

    Args:
        batch: Sequence of ``(image, label)`` pairs, as produced by the dataset.
            Images are stacked with ``torch.stack``; labels become one tensor.
        alpha: Beta-distribution parameter forwarded to ``mixup_data``.

    Returns:
        ``(mixed_images, labels_a, labels_b, lam)``. Train against both label
        sets with
        ``loss = lam * criterion(output, labels_a) + (1 - lam) * criterion(output, labels_b)``.
    """
    image_seq, label_seq = zip(*batch)
    stacked_images = torch.stack(image_seq)
    label_tensor = torch.tensor(label_seq)
    return mixup_data(stacked_images, label_tensor, alpha)

# Loader hyper-parameters.
BATCH_SIZE = 64
MIXUP_ALPHA = 0.4

# DataLoader that applies mixup to every batch through the custom collate_fn.
# functools.partial is used instead of a lambda because lambdas cannot be
# pickled. A pickled collate_fn is required when worker processes are started
# with the "spawn" method (e.g. num_workers > 0 on Windows/macOS).
dataloader = DataLoader(
    enriched_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=functools.partial(mixup_collate_fn, alpha=MIXUP_ALPHA),
)

# Example: pull a single batch and inspect what the collate function returns.
for mixed_images, labels_a, labels_b, lam in dataloader:
    print("Mixed images shape:", mixed_images.shape)  # Expected shape: [batch_size, 3, 32, 32]
    print("Lambda:", lam)
    break


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 51393826.94it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified
Mixed images shape: torch.Size([64, 3, 32, 32])
Lambda: 0.593706996454326
