In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
import random
import json
import argparse
from omegaconf import OmegaConf

In [None]:
class MLP(nn.Module):
    """
    Bias-free fully connected network with ReLU activations.

    The network is a stack of:
      * one ``input_size -> hidden_size`` linear layer followed by ReLU,
      * ``num_layers`` blocks of ``hidden_size -> hidden_size`` linear + ReLU,
      * one final ``hidden_size -> output_size`` linear layer (no activation).

    Args:
        input_size (int): Number of input features.
        hidden_size (int): Width of every hidden layer.
        output_size (int): Number of output features.
        num_layers (int): Number of hidden-to-hidden blocks placed after the
            input block.
    """

    def __init__(self, input_size=784, hidden_size=16, output_size=10, num_layers=3):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        def relu_block(in_features):
            # Linear projection into the hidden width, then ReLU; no bias term.
            return nn.Sequential(nn.Linear(in_features, hidden_size, bias=False), nn.ReLU())

        # Blocks are created in forward order: input block, hidden blocks, output head.
        blocks = [relu_block(input_size)]
        blocks.extend(relu_block(hidden_size) for _ in range(num_layers))
        blocks.append(nn.Linear(hidden_size, output_size, bias=False))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Pass ``x`` through every block in order and return the final output."""
        out = x
        for block in self.layers:
            out = block(out)
        return out

    @property
    def device(self):
        """Device on which the first parameter of the model is stored."""
        return next(self.parameters()).device

class DataLoader:
    """
    Builds a shuffled training loader over the first samples of a torchvision dataset.

    The dataset is downloaded (if needed) under ``~/.pytorch/<dataset_name>_data/``,
    converted to tensors, normalised with mean 0.5 and std 0.5, and truncated
    to the samples with indices ``0 .. train_size - 1``.

    Args:
        dataset_name (str): Key into ``name2dataset``; one of 'MNIST',
            'FashionMNIST', 'CIFAR10', 'CIFAR100'.
        train_size (int): Number of leading training samples to keep.
        train_batch_size (int): Batch size of ``train_dataloader``.
    """

    def __init__(self, dataset_name='MNIST', train_size=10000, train_batch_size=64):
        self.dataset_name = dataset_name
        self.train_size = train_size
        self.train_batch_size = train_batch_size

        # Per-sample preprocessing: to tensor, then normalise.
        to_tensor = transforms.ToTensor()
        normalise = transforms.Normalize((0.5,), (0.5,))
        self.transform = transforms.Compose([to_tensor, normalise])

        # Supported dataset names mapped to their torchvision classes.
        self.name2dataset = dict(
            MNIST=datasets.MNIST,
            FashionMNIST=datasets.FashionMNIST,
            CIFAR10=datasets.CIFAR10,
            CIFAR100=datasets.CIFAR100,
        )

        dataset_cls = self.name2dataset[self.dataset_name]
        root = f'~/.pytorch/{self.dataset_name}_data/'
        full_train_split = dataset_cls(root, download=True, train=True, transform=self.transform)

        # Keep only the first `train_size` samples of the training split.
        leading_indices = np.arange(self.train_size)
        self.train_dataset = torch.utils.data.Subset(full_train_split, leading_indices)
        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
        )

class ModelTrainer:
    """
    Evaluates a model's loss under randomly drawn parameters and summarises
    how the running mean of the loss changes as samples are added.

    Args:
        model (nn.Module): The network to evaluate. It must expose ``device``
            and ``input_size`` attributes, which this class reads.
        dataloader (DataLoader): Object exposing ``train_dataloader``; its
            underlying ``dataset`` is iterated sample by sample.
        criterion (nn.Module): Loss called as ``criterion(output, target)``.

    Attributes:
        device: Taken from ``model.device`` at construction time.
    """
    def __init__(self, model, dataloader, criterion):
        self.model = model
        self.dataloader = dataloader
        self.criterion = criterion
        self.device = model.device

    def criterion_params(self, params, x, y):
        """
        Evaluate the loss of the model with its parameters replaced by ``params``.

        Args:
            params (list): Replacement tensors, in the same order as
                ``self.model.named_parameters()``.
            x (torch.Tensor): Input data.
            y (torch.Tensor): Target data.

        Returns:
            torch.Tensor: Whatever ``self.criterion`` returns for the
            functional forward pass (call ``.item()`` for a Python number).
        """
        names = (name for name, _ in self.model.named_parameters())
        # Stateless forward pass: the model's own parameters are left untouched.
        output = torch.func.functional_call(self.model, dict(zip(names, params)), x)
        return self.criterion(output, y)

    def get_loss_abs_differences(self, loss_values, shuffle=True):
        """
        Absolute change of the running mean loss each time a sample is added.

        Args:
            loss_values (list): Per-sample loss values (any iterable of numbers).
            shuffle (bool): Whether to randomly permute the values first
                (uses the ``random`` module).

        Returns:
            np.ndarray: Array of length ``len(loss_values) - 1`` whose entry
            ``i`` is ``|mean(first i + 2 values) - mean(first i + 1 values)|``.
            Empty when fewer than two values are given.
        """
        # Materialise first: random.sample only accepts real sequences, so
        # arrays and other iterables would otherwise be rejected.
        loss_values = list(loss_values)
        if shuffle:
            loss_values = random.sample(loss_values, len(loss_values))
        running_sum = np.cumsum(loss_values)
        running_mean = running_sum / np.arange(1, len(running_sum) + 1)
        return np.abs(np.diff(running_mean))

    def calculate_ema(self, data, window=10):
        """
        Smooth ``data`` with an exponentially weighted moving window.

        The first ``window`` outputs are filled with the value at index
        ``window`` so the warm-up region is flat. When ``data`` has at most
        ``window`` elements no such value exists; the whole output is then
        filled with the weighted mean of the available samples.

        Args:
            data (np.ndarray): One-dimensional data to smooth.
            window (int): The window size; must be at least 1.

        Returns:
            np.ndarray: Smoothed values, same length as ``data``.

        Raises:
            ValueError: If ``window`` is smaller than 1.
        """
        if window < 1:
            raise ValueError(f"window must be a positive integer, got {window}")

        data = np.asarray(data)
        if data.size == 0:
            # np.convolve rejects empty input; nothing to smooth.
            return np.asarray(data, dtype=float)

        weights = np.exp(np.linspace(-1., 0., window))
        weights /= weights.sum()
        # NOTE(review): np.convolve flips the kernel, so the newest sample in
        # each window is multiplied by weights[0], the smallest weight. Kept
        # as-is to preserve existing results - confirm this is intended.
        ema = np.convolve(data, weights, mode='full')[:len(data)]

        if len(data) > window:
            ema[:window] = ema[window]
        else:
            # Only weights[:len(data)] contributed to the last output, so
            # renormalise by their sum to get a proper weighted mean.
            ema[:] = ema[-1] / weights[:len(data)].sum()
        return ema

    def create_random_parameters(self):
        """
        Draw a standard-normal replacement for every model parameter.

        Returns:
            list: One ``nn.Parameter`` per model parameter, in
            ``self.model.parameters()`` order, each with the same shape,
            dtype and device as the parameter it mirrors.
        """
        return [nn.Parameter(torch.randn_like(param)) for param in self.model.parameters()]

    def calculate_delta(self, B=10, k=100):
        r"""
        Calculate the \Delta_k using Monte-Carlo approximation.

        For each of ``B`` random parameter draws, the per-sample losses of the
        first ``k + 1`` samples of the training dataset (in stored order, not
        shuffled) are computed, and the absolute difference between the mean
        of the first ``k + 1`` losses and the mean of the first ``k`` losses
        is recorded. The result is the average of these differences.

        Args:
            B (int): Number of random parameter draws.
            k (int): Number of samples in the smaller running mean; must be
                at least 1.

        Returns:
            float: Mean over the ``B`` draws of the absolute difference.

        Raises:
            ValueError: If ``k`` is smaller than 1, or the dataset yields
                fewer than ``k + 1`` samples.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")

        dataset = self.dataloader.train_dataloader.dataset
        abs_values = []

        for _ in tqdm(range(B)):
            random_params = self.create_random_parameters()
            loss_values = []

            with torch.no_grad():
                for idx, (x, y) in enumerate(dataset):
                    # Flatten the single sample into a batch of one.
                    x = x.to(self.device).view(-1, self.model.input_size)
                    y = torch.tensor([y]).to(self.device)
                    loss_values.append(self.criterion_params(random_params, x, y).item())
                    if idx == k:
                        # k + 1 losses collected (indices 0..k).
                        break

            if len(loss_values) < k + 1:
                raise ValueError(
                    f"dataset yielded {len(loss_values)} samples, "
                    f"but k={k} requires at least {k + 1}"
                )

            loss_abs_differences = self.get_loss_abs_differences(loss_values, shuffle=False)
            # Entry k - 1 compares the mean over k + 1 samples with the mean over k.
            abs_values.append(loss_abs_differences[k - 1])

        return np.mean(abs_values)

In [None]:
# Compute device: the GPU when CUDA is available, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Experiment settings, loaded from a YAML file at a path relative to the working directory
config = OmegaConf.load('../configs/delta_experiments.yaml')

In [None]:
# Build the network from the `model` section of the config, then move it to the chosen device
model = MLP(**config.model)
model = model.to(device)

# Build the training data pipeline from the `data` section of the config
dataloader = DataLoader(**config.data)

# Loss function used to score predictions
criterion = nn.CrossEntropyLoss()

# Bundle model, data and loss for the delta computation
model_trainer = ModelTrainer(model=model, dataloader=dataloader, criterion=criterion)

In [None]:
# Monte-Carlo estimate of delta: 10 random parameter draws, k = 100 samples
delta = model_trainer.calculate_delta(k=100, B=10)
delta