In [None]:
import torch
import torchvision
import torchvision.transforms as transforms


def estimate_trace_and_largest_eigenvalue_of_covariance(dataset, num_samples=10000):
    """Estimate the trace and largest eigenvalue of a dataset's covariance.

    One shuffled batch of at most ``num_samples`` items is drawn from
    ``dataset``. Each item is flattened to a vector, the batch is centered
    feature-wise, and the unbiased sample covariance matrix of those vectors
    is formed explicitly.

    Args:
        dataset: A dataset accepted by ``torch.utils.data.DataLoader`` whose
            items are ``(data, label)`` pairs; the label is ignored.
        num_samples: Upper bound on the number of items to draw. If the
            dataset holds fewer items, all of them are used.

    Returns:
        A ``(trace, largest_eigenvalue)`` tuple of Python floats.

    Raises:
        ValueError: If fewer than two samples were drawn, since the unbiased
            sample covariance is undefined in that case.
    """
    # Draw a single random batch; it may be smaller than ``num_samples``.
    loader = torch.utils.data.DataLoader(dataset, batch_size=num_samples, shuffle=True)
    images, _ = next(iter(loader))

    # Normalize by the number of samples actually drawn, not the number
    # requested, otherwise small datasets get a covariance that is scaled down.
    sample_count = images.size(0)
    if sample_count < 2:
        raise ValueError(
            "At least two samples are required to estimate a covariance matrix, "
            f"got {sample_count}."
        )

    # Flatten every sample to a single feature vector.
    flattened_images = images.reshape(sample_count, -1)
    if not flattened_images.is_floating_point():
        # ``mean`` and ``eigvalsh`` need a floating-point dtype.
        flattened_images = flattened_images.to(torch.get_default_dtype())

    # Center out of place so the loaded batch is not mutated through a view.
    centered = flattened_images - flattened_images.mean(dim=0)

    # Unbiased sample covariance (features x features).
    covariance_matrix = centered.T @ centered / (sample_count - 1)

    trace_estimate = torch.trace(covariance_matrix).item()

    # The covariance matrix is symmetric, so the symmetric eigensolver applies.
    largest_eigenvalue = torch.linalg.eigvalsh(covariance_matrix).max().item()

    return trace_estimate, largest_eigenvalue


if __name__ == "__main__":
    # The pipeline only converts each image to a tensor; no mean/std
    # normalization is applied here.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Load the training split of each dataset (downloading it into ./data if needed).
    cifar10_train = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    cifar100_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True, transform=transform)
    mnist_train = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)

    # For each dataset report the covariance trace, its largest eigenvalue,
    # and their ratio (printed as the "intrinsic dimension").
    cifar10_trace, cifar10_largest_eigenvalue = estimate_trace_and_largest_eigenvalue_of_covariance(cifar10_train)
    print(f"CIFAR-10 -> Trace: {cifar10_trace}, Largest Eigenvalue: {cifar10_largest_eigenvalue}")
    print(f"CIFAR-10 -> intrinsic dimension: {cifar10_trace / cifar10_largest_eigenvalue}")

    cifar100_trace, cifar100_largest_eigenvalue = estimate_trace_and_largest_eigenvalue_of_covariance(cifar100_train)
    print(f"CIFAR-100 -> Trace: {cifar100_trace}, Largest Eigenvalue: {cifar100_largest_eigenvalue}")
    print(f"CIFAR-100 -> intrinsic dimension: {cifar100_trace / cifar100_largest_eigenvalue}")

    mnist_trace, mnist_largest_eigenvalue = estimate_trace_and_largest_eigenvalue_of_covariance(mnist_train)
    print(f"MNIST -> Trace: {mnist_trace}, Largest Eigenvalue: {mnist_largest_eigenvalue}")
    print(f"MNIST -> intrinsic dimension: {mnist_trace / mnist_largest_eigenvalue}")


100%|██████████| 170M/170M [00:05<00:00, 29.2MB/s] 
100%|██████████| 169M/169M [00:02<00:00, 79.2MB/s] 
100%|██████████| 9.91M/9.91M [00:00<00:00, 44.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 2.46MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 15.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.48MB/s]


CIFAR-10 -> Trace: 190.6145782470703, Largest Eigenvalue: 55.66353225708008
CIFAR-100 -> Trace: 217.10032653808594, Largest Eigenvalue: 73.90557861328125
MNIST -> Trace: 52.9100456237793, Largest Eigenvalue: 5.135982513427734
