In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms

# Here's a simple MLP
class SimpleMLP(nn.Module):
    """Fully connected classifier with layer sizes 784 -> 128 -> 128 -> 10.

    The two hidden layers are followed by ReLU; the final layer returns raw
    scores (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the order fc1, fc2, fc3, which is also the
        # order in which `parameters()` yields their weights and biases.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten everything after dim 0 and run the three linear layers."""
        hidden = F.relu(self.fc1(x.flatten(1)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use a standard MNIST normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# No `train=` argument is passed, so torchvision's default split is used.
dataset = MNIST(root='./data', download=True, transform=transform)
# A tenth of the dataset per batch: big but not too big, because I think that
# the batch is duplicated for each model in the vmap-ensemble.
BATCH_SIZE = int(len(dataset)/10)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers = 2)

# Loss used everywhere below, in its functional (stateless) form.
criterion = F.cross_entropy

# Number of random directions to sample, in addition to the gradient
# ascent + descent directions and the radially in + out directions.
NUM_DIRS = 2
# Largest and smallest order of magnitude (power of ten) of the distances to sample.
MAX_OOM = 0
MIN_OOM = -1

In [2]:
# utils for model arithmetics
from copy import deepcopy

@torch.no_grad()
def iadd(a:nn.Module, b:nn.Module):
    """Add the parameters of ``b`` to ``a`` in place and return ``a``.

    Parameters are paired positionally, in ``parameters()`` order, so both
    modules are expected to share the same architecture.

    Raises:
        ValueError: if the modules do not have the same number of parameter
            tensors. The check runs before anything is modified, so ``a`` is
            left untouched in that case.
    """
    a_params = list(a.parameters())
    b_params = list(b.parameters())
    # A bare zip() would silently stop at the shorter module and leave the
    # remaining parameters of `a` unchanged.
    if len(a_params) != len(b_params):
        raise ValueError(
            f"parameter count mismatch: {len(a_params)} vs {len(b_params)}"
        )
    for a_param, b_param in zip(a_params, b_params):
        a_param.data.add_(b_param.data)
    return a

# Called form of the decorator: the bare `@torch.no_grad` spelling is only
# accepted by newer PyTorch releases.
@torch.no_grad()
def add(a_old:nn.Module, b:nn.Module):
    """Return a new module whose parameters are those of ``a_old`` plus ``b``.

    ``a_old`` is deep-copied first, so neither input is modified.
    """
    a = deepcopy(a_old)
    return iadd(a, b)

# Called form of the decorator: the bare `@torch.no_grad` spelling is only
# accepted by newer PyTorch releases.
@torch.no_grad()
def scale(a_old:nn.Module, b:float):
    """Return a deep copy of ``a_old`` with every parameter multiplied by ``b``.

    ``a_old`` itself is left unchanged. ``b`` is passed straight to ``mul_``.
    """
    a = deepcopy(a_old)
    for a_param in a.parameters():
        a_param.data.mul_(b)
    return a

# Called form of the decorator: the bare `@torch.no_grad` spelling is only
# accepted by newer PyTorch releases.
@torch.no_grad()
def sub(a:nn.Module, b:nn.Module):
    """Return a new module holding ``a - b`` (parameter-wise).

    Neither input is modified: ``b`` is copied and negated by ``scale``, and
    ``a`` is then added into that copy.
    """
    neg_b = scale(b, -1)
    return iadd(neg_b, a)

# NOTE: this intentionally keeps the name `abs` (shadowing the builtin) because
# other cells call it under that name.
@torch.no_grad()
def abs(a:nn.Module):
    """Return the 2-norm of all parameters of ``a``, treated as one flat vector.

    The result is a 0-dim tensor.
    """
    flat = torch.cat([param.data.flatten() for param in a.parameters()])
    # torch.norm is deprecated; vector_norm computes the same 2-norm by default.
    return torch.linalg.vector_norm(flat)

# Called form of the decorator: the bare `@torch.no_grad` spelling is only
# accepted by newer PyTorch releases.
@torch.no_grad()
def norm(a:nn.Module):
    """Return a copy of ``a`` rescaled so that its parameter vector has norm 1.

    ``a`` is not modified. If all parameters are zero, the scaling factor is a
    division by zero and the result contains non-finite values.
    """
    return scale(a, 1/abs(a))

# Called form of the decorator: the bare `@torch.no_grad` spelling is only
# accepted by newer PyTorch releases.
@torch.no_grad()
def rand_like(a:nn.Module):
    """Return a random, normalised direction in the parameter space of ``a``.

    A copy of ``a`` has every parameter replaced by standard-normal noise of
    the same shape and is then rescaled to unit norm. ``a`` is not modified.
    """
    new = deepcopy(a)
    for param in new.parameters():
        param.data = torch.randn_like(param.data)

    return norm(new)

Define a profiling context manager that measures execution time and reports GPU RAM usage before and after the wrapped block.

In [20]:
from time import perf_counter
from contextlib import contextmanager

@contextmanager
def timer(description):
    """Print wall-clock time (and GPU memory, if CUDA is available) for a block.

    Prints ``description`` on entry and a "Finished ..." line on exit. The
    closing report is emitted from a ``finally`` clause, so it is also printed
    when the wrapped block raises; the exception then propagates unchanged.
    """
    print(description)
    if torch.cuda.is_available():
        print(f'GPU RAM before execution: Allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB, Reserved: {torch.cuda.memory_reserved()/1e9:.2f}GB')
    start = perf_counter()
    try:
        yield
    finally:
        end = perf_counter()
        if torch.cuda.is_available():
            print(f'GPU RAM after execution: Allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB, Reserved: {torch.cuda.memory_reserved()/1e9:.2f}GB')
        # Fixed-point with two decimals; a bare `.2` means two significant
        # digits and renders e.g. 24 seconds as "2.4e+01".
        print(f'Finished {description} in {end-start:.2f}s')

In [61]:
from time import perf_counter
from contextlib import contextmanager
import torch

@contextmanager
def profiler(description, length=80, pad_char=':'):
    """Print a banner, timing and (if CUDA is available) GPU memory for a block.

    Args:
        description: label shown in the opening and closing banner lines.
        length: total width the banner lines are centred to.
        pad_char: fill character used on both sides of the banner text.

    The closing report is emitted from a ``finally`` clause, so it is also
    printed when the wrapped block raises; the exception then propagates
    unchanged.
    """
    print(description.center(length, pad_char))
    if torch.cuda.is_available():
        print(f'GPU RAM before execution: Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB | Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB')
    start = perf_counter()
    try:
        yield
    finally:
        if torch.cuda.is_available():
            print(f'GPU RAM after execution:  Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB | Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB')
        print(f'Finished {description} in {perf_counter() - start:.2f} s'.center(length, pad_char))

::::::::::::::::::::::::::::::::::::::test::::::::::::::::::::::::::::::::::::::
1
::::::::::::::::::::::::::::Finished test in 0.00 s:::::::::::::::::::::::::::::


In [3]:
from copy import deepcopy
# The freshly initialised model is the centre point of the landscape probe.
center_model = SimpleMLP()

# create a model to find the gradient
# (a copy, so that center_model itself is never updated)
grad_model = deepcopy(center_model)
grad_model = grad_model.to(device)

optimizer = torch.optim.SGD(grad_model.parameters(), lr=1)
optimizer.zero_grad()
torch.set_grad_enabled(True)

# take a single step of gradient descent
# Gradients accumulate over the whole dataset: zero_grad() was called once
# above and backward() runs for every batch before the single step() below.
# The profiler times only this accumulation loop, not optimizer.step().
with profiler('1 step of gradient descent over the whole dataset'):
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        output = grad_model(data)
        loss = criterion(output, target)
        loss.backward()

optimizer.step()
# from this point on we won't need any gradients
# (this switches autograd off globally, for every later cell as well)
torch.set_grad_enabled(False)


# Direction of gradient ASCENT from the center model: grad_model has taken one
# descent step away from center_model, so (center - stepped) points uphill.
# The overall scale of the accumulated gradient is irrelevant because the
# direction is normalised right after; the descent direction is its negation.
grad_model = grad_model.to('cpu')
dir_ascent = sub(center_model, grad_model)
dir_ascent = norm(dir_ascent)
dir_descent = scale(dir_ascent, -1)

del grad_model


In [4]:
# Unit vector pointing from the origin of parameter space towards center_model.
radial = norm(center_model)

# Named unit directions to probe; each value is a module whose parameters hold
# the direction vector.
dirs = {
    'Ascent': dir_ascent,
    'Descent': dir_descent,
    'Radially Out': radial,
    'Radially In': scale(radial,-1),
    }

# add random directions
dirs.update({f'Random {i}': rand_like(center_model) for i in range(NUM_DIRS)})

# check that they are all normalised
# (`abs` here is the parameter-norm helper defined above, not the builtin)
assert all(torch.isclose(abs(d), torch.tensor(1.0)) for d in dirs.values())

In [5]:
import pandas as pd
import numpy as np

# Convert the order-of-magnitude bounds from base 10 to base 2, rounding
# outwards so that the sampled range covers [10**MIN_OOM, 10**MAX_OOM].
POW_MIN_DIST = int(np.floor(MIN_OOM*np.log2(10)))
POW_MAX_DIST = int(np.ceil(MAX_OOM*np.log2(10)))

# will sample these distances (consecutive powers of two)
scales = 2. ** np.arange(POW_MIN_DIST, POW_MAX_DIST + 1)

# offset each direction by a different factor in [1, 2), so that we get a nice
# spread of distances, to avoid artefacts/robustness
offset_per_direction = 2. ** np.linspace(0, 1, len(dirs), endpoint=False)

# Calculate distances using broadcasting: after the transpose each row holds
# one direction's distances, so flattening lists them direction by direction.
distances = np.outer(scales, offset_per_direction).T.flatten()

# Repeat each direction name once per scale, matching the layout of `distances`.
directions = np.repeat(list(dirs.keys()), len(scales))
assert len(directions) == len(distances)

# Create an (initially column-less) DataFrame indexed by (Direction, Distance)
df = pd.DataFrame(index=pd.MultiIndex.from_arrays([directions, distances], names=['Direction', 'Distance']))

df


Direction,Distance
Ascent,0.0625
Ascent,0.125
Ascent,0.25
Ascent,0.5
Ascent,1.0
Descent,0.070154
Descent,0.140308
Descent,0.280616
Descent,0.561231
Descent,1.122462


In [6]:
# Fill the dataframe with the models: one shifted copy of center_model per
# (direction, distance) row, i.e. location = center + direction * distance.
# scale() returns a fresh copy, so adding center_model into it in place leaves
# both the direction and center_model untouched.
shifted_models = [
    iadd(scale(dirs[dir_name], dist), center_model)
    for dir_name, dist in df.index
]

# Assign the whole column at once (instead of enlarging the frame cell by cell
# with .loc); dtype=object keeps each module as an opaque Python object.
df['Model'] = pd.Series(shifted_models, index=df.index, dtype=object)

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Model
Direction,Distance,Unnamed: 2_level_1
Ascent,0.0625,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,0.125,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,0.25,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,0.5,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,1.0,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Descent,0.070154,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Descent,0.140308,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Descent,0.280616,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Descent,0.561231,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Descent,1.122462,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."


Parallel evaluation with vmap()


In [14]:
import torch
from torch.func import stack_module_state, functional_call
from torch import vmap
from time import perf_counter
from contextlib import contextmanager
from copy import deepcopy


# The center model goes first, followed by every shifted model in df order.
ensemble_list = [center_model] + list(df['Model'])
# Stack the per-model parameters/buffers along a new leading ensemble dimension.
params, buffers = stack_module_state(ensemble_list)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with profiler(f'Moving stacked ensemble to {device}'):
    params = {name: tensor.to(device) for name, tensor in params.items()}
    buffers = {name: tensor.to(device) for name, tensor in buffers.items()}

stacked_ensemble = (params, buffers)

# Storage-less skeleton of the architecture for functional_call.
# Use .iloc[0] for positional access: plain [0] on this label-indexed Series
# relies on a deprecated positional fallback.
meta_model = deepcopy(df['Model'].iloc[0])
meta_model = meta_model.to('meta')

# Define the function to vmap over
def model_function(params_and_buffers, x):
    """Run ``meta_model``'s forward pass on ``x`` using the given state.

    ``params_and_buffers`` is handed to ``functional_call`` in place of the
    module's own (meta, storage-less) parameters and buffers.
    """
    return functional_call(meta_model, params_and_buffers, (x,))

batch_losses = []

with profiler(f'Evaluating stacked ensemble on {device}'):
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)

        # Compute predictions for all models in the ensemble using vmap:
        # the stacked state is mapped over dim 0, the batch is shared.
        predictions = vmap(model_function, in_dims=(0, None))(stacked_ensemble, data)

        # Increase precision of predictions before computing the loss
        predictions = predictions.double()

        # Compute the loss
        # Use vmap again to compute the loss for each model separately
        losses = vmap(criterion, in_dims=(0, None))(predictions, target)

        batch_losses.append(losses)

# Stack all the losses and compute the mean across batches but not models
loss_tensor = torch.mean(torch.stack(batch_losses), dim=0)

# Free VRAM. Deleting only the `stacked_ensemble` tuple releases nothing while
# `params` and `buffers` still reference the same tensors, so drop all three
# names before asking the caching allocator to return memory.
del stacked_ensemble, params, buffers
torch.cuda.empty_cache()

Moving stacked ensemble to cpu, finished in 0.00s
Evaluating stacked ensemble on cpu

, finished in 25.24s
Loss: tensor([2.3049, 2.3354, 2.3658, 2.4234, 2.5256, 2.7241, 2.2715, 2.2400, 2.1833,
        2.0901, 1.9345, 2.3050, 2.3051, 2.3052, 2.3056, 2.3065, 2.3049, 2.3048,
        2.3047, 2.3045, 2.3042, 2.3048, 2.3048, 2.3046, 2.3044, 2.3042, 2.3051,
        2.3052, 2.3055, 2.3062, 2.3078], dtype=torch.float64)


In [15]:
from torch.func import stack_module_state
from torch.func import functional_call
from torch.nn.functional import cross_entropy
from torch import vmap

# add the center model (first), followed by every shifted model in df order
ensemble = [center_model] + list(df['Model'])

# stack to prepare for vmap
params, buffers = stack_module_state(ensemble)

with profiler(f'Moving stacked ensemble to {device}'):
    # Move parameters to the specified device
    params = {name: tensor.to(device) for name, tensor in params.items()}
    # Move buffers to the specified device
    buffers = {name: tensor.to(device) for name, tensor in buffers.items()}


stacked_ensemble = (params, buffers)

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
# Use .iloc[0] for positional access: plain [0] on this label-indexed Series
# relies on a deprecated positional fallback.
meta_model = deepcopy(df['Model'].iloc[0])
meta_model = meta_model.to('meta')

def meta_model_loss(params_and_buffers, data, target):
    """Loss of one ensemble member on a batch.

    Runs ``meta_model`` with the supplied parameters/buffers via
    ``functional_call``, casts the predictions to float64 and applies
    ``criterion`` to them.
    """
    predictions = functional_call(meta_model, params_and_buffers, (data,))
    predictions = predictions.double()
    # No `reduction` argument is passed, so this is the criterion's default
    # reduction over the batch, not a per-sample loss.
    loss = criterion(predictions, target)
    return loss

# in_dims=(0, None, None): dimension 0 of every tensor in the first argument
# (the (params, buffers) tuple) is the ensemble dimension that vmap maps over.
# data and target are not ensembled over, so they are passed unmapped and every
# model sees the same batch.
ensembled_loss = vmap(meta_model_loss, in_dims=(0, None, None))
batch_losses = []

with profiler(f'Evaluating stacked ensemble on {device}'):
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        batch_loss = ensembled_loss(stacked_ensemble, data, target)
        batch_losses.append(batch_loss)

# Average over batches (dim 0 of the stack), keeping one loss per model.
loss_tensor = torch.mean(torch.stack(batch_losses), dim=0)

# Free the VRAM. The evaluation takes up a LOT of space (tens of GB), but I'm not sure why.
# The models should be small even all together, but I think that the vmap-ensemble is duplicating the data batch for each model.
# It looks like memory is around len(ensemble) * BATCH_SIZE * 4 bytes. Why is it scaling with the number of models?!
# Shouldn't it be len(ensemble) * sizeof(model) + BATCH_SIZE * 4 bytes? Why does vmap duplicate the data for each model?
#
# Deleting only the `stacked_ensemble` tuple releases nothing while `params`
# and `buffers` still reference the same tensors, so drop all three names
# before asking the caching allocator to return memory.
del stacked_ensemble, params, buffers
torch.cuda.empty_cache()

Moving stacked ensemble to cpu, finished in 0.00011s
Evaluating stacked ensemble on cpu

, finished in 2.4e+01s


In [None]:
# Convert the loss tensor to a list and unpack: the first entry belongs to the
# center model (it was placed first in the ensemble), the rest follow df order.
center_loss, *dir_losses = loss_tensor.tolist()

# Ensure the length of dir_losses matches the DataFrame length
# NOTE(review): this cell mutates df in place and adds rows below, so running
# it a second time without rebuilding df makes this assertion fail.
assert len(dir_losses) == len(df)

# Add the directional losses to the DataFrame
df['Loss'] = dir_losses

# Add the center_loss to each Direction, for Distance=0
# (these are new rows, so their 'Model' entry is left empty)
for direction in df.index.get_level_values('Direction').unique():
    df.loc[(direction, 0.0), 'Loss'] = center_loss

# Sort by (Direction, Distance) so that each direction starts at its Distance=0 row.
df.sort_index(inplace=True)

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Model,Loss
Direction,Distance,Unnamed: 2_level_1,Unnamed: 3_level_1
Ascent,0.0,,2.310358
Ascent,0.0625,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.340195
Ascent,0.125,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.370295
Ascent,0.25,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.430307
Ascent,0.5,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.553088
Ascent,1.0,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.8652
Descent,0.0,,2.310358
Descent,0.070154,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.278104
Descent,0.140308,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.248296
Descent,0.280616,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.198165


In [None]:
import plotly.graph_objects as go

# Create a Plotly figure
fig = go.Figure()

# Plotting directly from grouped data: one line per direction, with the
# distance from the center on the x-axis and the loss on the y-axis.
for direction, data in df.groupby(level='Direction'):
    fig.add_trace(go.Scatter(
        x=data.index.get_level_values('Distance'),
        y=data['Loss'],
        mode='lines+markers',
        name=direction
    ))

# Set axis labels and title with adjusted figure dimensions
# (the title is built from the class names of the model and the dataset)
fig.update_layout(
    xaxis_title='Distance from Center in Parameter Space',
    yaxis_title='Loss',
    title=f'Loss Landscape of {center_model.__class__.__name__} on {dataset.__class__.__name__}',
    legend_title='Direction',
    template='seaborn',
    height=600,
)


# Show the figure
fig.show()


Now let's measure the roughness.

In [None]:
# Roughness calculation.
# For every direction, take three points in the (distance, loss) plane:
#   A = the center (first row of the group, distance 0),
#   B = a sample at distance d,
#   C = the next sample, at distance 2d.
# The roughness at B is (|AB| + |BC|) / |AC|, i.e. how much longer the path
# through B is than the straight segment from A to C. It equals 1 when B lies
# on that segment and grows as the loss curve bends at scale d.
df['Roughness'] = np.nan

# Calculate the roughness, one direction at a time
for direction, group in df.groupby(level='Direction'):
    distA = group.index.get_level_values('Distance').to_numpy()
    losses = group['Loss'].to_numpy()

    # Named index/slices, so that
    # distA[x] really corresponds to the distance between A and X ∈ {A, B, C},
    # losses[x] really corresponds to the loss at X ∈ {A, B, C}
    a = 0
    b = slice(1, -1)
    c = slice(2, None)

    # Ensure distances and slices are correct: each C must sit at exactly
    # twice the distance of its B (consecutive samples differ by a factor 2).
    assert np.all(distA[c] == 2 * distA[b]), "Distance check failed"

    # Calculate path lengths between A, B, and C.
    # The horizontal extent of B->C is distA[c] - distA[b], which equals
    # distA[b] given the check above.
    AB = np.sqrt((losses[b] - losses[a]) ** 2 + distA[b] ** 2)
    BC = np.sqrt((losses[c] - losses[b]) ** 2 + distA[b] ** 2)
    AC = np.sqrt((losses[c] - losses[a]) ** 2 + distA[c] ** 2)

    # Roughness is how much the path via B (A->B->C) is longer than the direct path A->C
    roughness = (AB + BC) / AC

    # Update DataFrame with roughness values. The first row (A itself) and the
    # last row (which has no C beyond it) have no roughness and stay NaN.
    df.loc[(direction,), 'Roughness'] = np.concatenate([[np.nan], roughness, [np.nan]])

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Model,Loss,Roughness
Direction,Distance,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Ascent,0.0,,2.310358,
Ascent,0.0625,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.340195,1.000001
Ascent,0.125,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.370295,1.0
Ascent,0.25,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.430307,1.000011
Ascent,0.5,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.553088,1.001407
Ascent,1.0,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.8652,
Descent,0.0,,2.310358,
Descent,0.070154,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.278104,1.000106
Descent,0.140308,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.248296,1.000672
Descent,0.280616,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.198165,1.002076


In [None]:
import plotly.graph_objects as go

# Create a Plotly figure
fig = go.Figure()

# Plotting roughness for each direction using grouped data.
# Rows whose Roughness is NaN (the first and last row of each direction)
# contribute no marker.
for direction, data in df.groupby(level='Direction'):
    fig.add_trace(go.Scatter(
        x=data.index.get_level_values('Distance'),
        y=data['Roughness'],
        mode='markers',
        name=direction
    ))

# Set axis labels, title, and configure log scale on x-axis
fig.update_layout(
    xaxis_title='Coarse Graining Scale',
    yaxis_title='Roughness',
    title=f'Scale Dependent Roughness of {center_model.__class__.__name__} on {dataset.__class__.__name__}',
    legend_title='Direction',
    template='seaborn',
    xaxis={
        'type': 'log',
        'dtick': 1,  # Tick every power of ten
    },
    height=600
)

# Display the figure
fig.show()
