In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import task_complexity
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class AddGaussianNoise:
    """Transform that adds Gaussian noise to a tensor.

    Each call draws fresh standard-normal noise with the same size as the input
    and returns ``tensor + noise * std + mean``. The input tensor itself is not
    modified (a new tensor is returned).

    Args:
        mean: Constant offset added to every element (the mean of the noise).
        std: Scale applied to the unit-variance noise (its standard deviation).
    """

    def __init__(self, mean = 0, std = 1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Draw the noise on the input's device: torch.randn defaults to the CPU,
        # which would raise a device-mismatch error for tensors on e.g. CUDA.
        noise = torch.randn(tensor.size(), device = tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [3]:
class Net_CIFAR10(nn.Module):
    """Fully connected autoencoder-style network with one hidden layer.

    Each sample is flattened, mapped to ``hidden_dim`` units with a ReLU, then
    mapped back to ``prod(input_shape)`` values (no output activation) and
    reshaped to ``(batch, *input_shape)``, i.e. the output has the input's shape.

    Args:
        input_shape: Per-sample (channels, height, width). Defaults to the
            CIFAR-10 image shape (3, 32, 32).
        hidden_dim: Width of the hidden layer. Defaults to 1024.
    """

    def __init__(self, input_shape = (3, 32, 32), hidden_dim = 1024):
        super().__init__()
        self.input_shape = tuple(input_shape)
        n_features = int(np.prod(self.input_shape))
        self.fc1 = nn.Linear(n_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_features)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Restore the per-sample image shape; -1 keeps the batch dimension.
        return torch.reshape(x, (-1, *self.input_shape))

In [4]:
# Loss handed to task_complexity.compute_complexity below: squared error,
# averaged over every element of the batch.
criterion = nn.MSELoss(reduction = 'mean')

In [5]:
# Sweep the std of additive Gaussian pixel noise over 1e-3 ... 1e3 and record
# the score task_complexity.compute_complexity reports for each level.
SEED = 0
BATCH_SIZE = 16
DATA_ROOT = '~/datasets'

noise_levels = np.logspace(-3, 3, 7).tolist()


def cifar10_label_to_one_hot(label, num_classes = 10):
    """Convert an integer class label into a one-hot tensor of length `num_classes`."""
    return F.one_hot(torch.tensor(label), num_classes = num_classes)


def make_noisy_cifar10_loader(noise_level, batch_size = BATCH_SIZE, root = DATA_ROOT):
    """Return a shuffled CIFAR-10 training DataLoader whose images carry Gaussian noise.

    Noise with std `noise_level` is added after ToTensor and before Normalize;
    labels are one-hot encoded.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        AddGaussianNoise(0., noise_level),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    trainset = CIFAR10(root = root, train = True, transform = transform,
                       target_transform = cifar10_label_to_one_hot)
    # NOTE(review): with num_workers > 0 on spawn-based platforms (macOS/Windows),
    # notebook-defined transforms may fail to pickle — confirm on the target platform.
    return torch.utils.data.DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 2)


mi_values = []
for noise_level in noise_levels:
    print(f"Noise Level: {noise_level}")
    # Re-seed per level so the shuffle order, the noise draws, and (presumably)
    # the model initialisation inside compute_complexity are identical across
    # levels — score differences are then attributable to the noise alone.
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    trainloader_cifar10 = make_noisy_cifar10_loader(noise_level)

    # NOTE(review): the positional `10` is passed through unchanged; its meaning
    # depends on compute_complexity's signature — confirm.
    mi_score = task_complexity.compute_complexity(Net_CIFAR10, 10, trainloader_cifar10, criterion, n_bins = 100, autoencoder = True)
    print(f"MI Score: {mi_score}")
    mi_values.append(mi_score)

Noise Level: 0.001
MI Score: 0.01387316169873909
Noise Level: 0.01
MI Score: 0.013668853413963067
Noise Level: 0.1
MI Score: 0.013748495477376999
Noise Level: 1.0
MI Score: 0.014693348315390153
Noise Level: 10.0
MI Score: 0.015271912859752046
Noise Level: 100.0
MI Score: 0.016104534759693312
Noise Level: 1000.0
MI Score: 0.014569042590887982
