<a href="https://colab.research.google.com/github/matthiasdellago/Loss-Cartography/blob/main/ensemble_approach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms

class SimpleMLP(nn.Module):
    """Fully connected classifier with layer widths 784 -> 128 -> 128 -> 10 and ReLU in between."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten everything after the first dimension and return the raw outputs of `fc3`."""
        hidden = F.relu(self.fc1(x.flatten(1)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Seed torch's global RNG so that the random initialisation of the center model
# (and any later torch-based random draws) is identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

center_model = SimpleMLP()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use a standard MNIST normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset = MNIST(root='./data', download=True, transform=transform)
# A tenth of the dataset per batch: big but not too big, because I think that the batch
# is duplicated for each model in the vmap-ensemble
BATCH_SIZE = len(dataset) // 10
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=2)

# criterion functional
criterion = F.cross_entropy

# parameter subspaces to investigate.
# for each string, the union of all parameters that contain the string in their name will be considered.
# every direction (gradient ascent/descent, radial, random) is additionally projected onto each of these subspaces.
SUBSPACES = ['fc1', 'weight', 'bias']

# check if we are on GPU, otherwise just proof of concept
if torch.cuda.is_available():
    NUM_DIRS = 3    # number of random directions to sample, in addition to the gradient ascent + descent, and radially in + out.
    MAX_OOM = -1    # maximum order of magnitude to sample
    MIN_OOM = -11   # go down until all roughness disappears due to numerical precision.
else:
    NUM_DIRS = 0    # number of random directions to sample, in addition to the gradient ascent + descent, and radially in + out.
    MAX_OOM = 0     # maximum order of magnitude to sample
    MIN_OOM = -1    # minimum order of magnitude to sample

print(f'Running on {device}')

Running on cpu


Define utility functions for parameter space arithmetic. iadd is inplace, all others create a new model.

In [2]:
from copy import deepcopy

@torch.no_grad()
def iadd(a:nn.Module, b:nn.Module):
    """In-place parameter-wise sum a <- a + b; returns `a` itself, not a copy."""
    # parameters are paired purely by iteration order
    pairs = zip(a.parameters(), b.parameters())
    for dst, src in pairs:
        dst.data.add_(src.data)
    return a

@torch.no_grad()
def add(a_old:nn.Module, b:nn.Module):
    """Return a new model holding a_old + b (parameter-wise); `a_old` and `b` are not modified."""
    total = deepcopy(a_old)
    iadd(total, b)  # accumulates into the copy in place
    return total

@torch.no_grad()
def scale(a_old:nn.Module, b:float):
    """Return a deep copy of `a_old` whose parameters are all multiplied by `b`; `a_old` is left untouched."""
    scaled = deepcopy(a_old)
    for param in scaled.parameters():
        param.data.mul_(b)
    return scaled

@torch.no_grad()
def sub(a:nn.Module, b:nn.Module):
    """Return a new model holding a - b (parameter-wise); neither input is modified."""
    # a - b is built as (-b) + a: the negation already yields a copy, which then absorbs `a` in place
    return iadd(scale(b, -1), a)

@torch.no_grad()
def abs(a:nn.Module):
    """
    Length of `a` in parameter space: the 2-norm of all parameters concatenated into one flat vector.
    Note that this name shadows the builtin `abs` in this namespace.
    """
    flat_params = [param.data.flatten() for param in a.parameters()]
    return torch.norm(torch.cat(flat_params))

@torch.no_grad()
def norm(a:nn.Module):
    """
    Return a copy of `a` rescaled to unit length in parameter space; `a` is not modified.

    Raises:
        ValueError: if every parameter of `a` is zero. A zero vector has no direction, and
            dividing by its length would silently fill the result with NaN/inf.
    """
    length = abs(a)
    if length == 0:
        raise ValueError("cannot normalise a model whose parameters are all zero")
    return scale(a, 1/length)

@torch.no_grad()
def rand_like(a: nn.Module):
    """
    Random unit-length direction in the parameter space of `a`.
    `a` only serves as a template: every parameter of a copy is overwritten with standard-normal noise
    before the copy is normalised.
    """
    direction = deepcopy(a)
    for param in direction.parameters():
        noise = torch.randn_like(param.data)
        param.data.copy_(noise)
    return norm(direction)

@torch.no_grad()
def project_to_module(a: nn.Module, target_subspace: str):
    """
    normalised projection onto the subspace of the model specified by `target_subspace`.
    subspace: Any parameter whose name contains the `target_subspace` string.

    All parameters outside the subspace are set to zero in a copy of `a`, and the copy is then
    rescaled to unit length. `a` itself is not modified.

    Raises:
        ValueError: if no parameter name contains `target_subspace`. The projection would then be
            the zero vector, which cannot be normalised.
    """
    projection = deepcopy(a)
    matched = False
    for name, param in projection.named_parameters():
        if target_subspace in name:
            matched = True
        else:
            param.data.zero_()
    if not matched:
        raise ValueError(f"no parameter name contains {target_subspace!r}")
    return norm(projection)

Define profiling context manager: measure execution time and GPU RAM usage before and after.

In [3]:
from time import perf_counter
from contextlib import contextmanager
import torch

@contextmanager
def profiler(description, length=80, pad_char=':'):
    """
    Context manager that reports wall-clock time and, when CUDA is available, GPU memory
    before and after the wrapped block.

    Args:
        description: label printed in the opening banner and in the closing summary line.
        length: total width the banner and summary lines are centred to.
        pad_char: character used to pad both lines to `length`.

    The closing report is printed even if the wrapped block raises; the exception still propagates.
    """
    print('\n'+description.center(length, pad_char))
    if torch.cuda.is_available():
        print(f'GPU RAM before execution: Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB | Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB')
    start = perf_counter()
    try:
        yield
    finally:
        # runs on success and on failure, so timing/memory of a crashed block is still reported
        if torch.cuda.is_available():
            print(f'GPU RAM after execution:  Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB | Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB')
        print(f'Finished {description} in {perf_counter() - start:.2f} s'.center(length, pad_char))

# Minimal usage demo: whatever runs inside the `with` block is what gets profiled
with profiler(description='Profiler example'):
    print('Computation goes here...')


::::::::::::::::::::::::::::::::Profiler example::::::::::::::::::::::::::::::::
Computation goes here...
::::::::::::::::::::::Finished Profiler example in 0.00 s:::::::::::::::::::::::


In [4]:
# Work on a throw-away copy of the center model, so the center itself is never updated
grad_model = deepcopy(center_model).to(device)

optimizer = torch.optim.SGD(grad_model.parameters(), lr=1)
optimizer.zero_grad()
torch.set_grad_enabled(True)

# Accumulate the gradient over the whole dataset (one backward pass per batch),
# then take a single step of gradient descent with it
with profiler('1 step of gradient descent over the whole dataset'):
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        loss = criterion(grad_model(data), target)
        loss.backward()

optimizer.step()
# from this point on we won't need any gradients
torch.set_grad_enabled(False)


# The stepped copy sits downhill of the center, so (center - copy) points uphill:
# that is the direction of gradient ascent, and its negative the direction of gradient descent
grad_model = grad_model.to('cpu')
dir_ascent = norm(sub(center_model, grad_model))
dir_descent = scale(dir_ascent, -1)

del grad_model



:::::::::::::::1 step of gradient descent over the whole dataset::::::::::::::::


:::::Finished 1 step of gradient descent over the whole dataset in 17.55 s::::::


In [5]:
# unit vector pointing from the origin of parameter space towards the center model
radial = norm(center_model)

# deterministic directions first
dirs = {'Ascent': dir_ascent, 'Descent': dir_descent}
dirs['Radially Out'] = radial
dirs['Radially In'] = scale(radial, -1)

# then NUM_DIRS random directions
for i in range(NUM_DIRS):
    dirs[f'Random {i}'] = rand_like(center_model)

# finally the projection of every direction collected so far onto each subspace in SUBSPACES
# (snapshot the items first, because the dict is extended with the projections)
full_space_dirs = list(dirs.items())
dirs.update({
    f'{dir_name} ⋅ P({subspace})': project_to_module(direction, subspace)
    for dir_name, direction in full_space_dirs
    for subspace in SUBSPACES
})

# check that they are all normalised
unit_length = torch.tensor(1.0)
assert all(torch.isclose(abs(d), unit_length) for d in dirs.values())

In [6]:
import pandas as pd
import numpy as np


# Generate distances from 10^MIN_OOM up to at least 10^MAX_OOM, each exactly twice as large as the previous.
assert MIN_OOM <= MAX_OOM, "MIN_OOM must not exceed MAX_OOM"
exponents = np.arange(MIN_OOM * np.log2(10), MAX_OOM * np.log2(10) + 1)
# Only the smallest scale is computed with a floating point power. The others are derived from it by
# multiplying with exact powers of two (ldexp), so that consecutive distances differ by a factor of
# exactly 2. Later cells assert this doubling, and separate float powers do not guarantee it bit-for-bit.
scales = np.ldexp(2.0 ** exponents[0], np.arange(len(exponents), dtype=np.intc))

# offset each direction by a different factor in [1, 2), so that we get a nice spread of distances, to avoid artefacts/robustness
offset_per_direction = 2. ** np.linspace(0, 1, len(dirs), endpoint=False)

# Calculate distances using broadcasting: after the transpose, row i holds all scales of direction i
distances = np.outer(scales, offset_per_direction).T.flatten()

# Repeat each direction name once per scale, matching the layout of `distances`
directions = np.repeat(list(dirs.keys()), len(scales))

# Create an (initially column-less) DataFrame indexed by (Direction, Distance)
df = pd.DataFrame(index=pd.MultiIndex.from_arrays([directions, distances], names=['Direction', 'Distance']))

df


Direction,Distance
Ascent,0.062500
Ascent,0.125000
Ascent,0.250000
Ascent,0.500000
Ascent,1.000000
...,...
Radially In ⋅ P(bias),0.119700
Radially In ⋅ P(bias),0.239401
Radially In ⋅ P(bias),0.478802
Radially In ⋅ P(bias),0.957603


In [7]:
# Place one model at every (direction, distance) grid point:
# location = center + distance * direction.
# `scale` returns a fresh copy, so the in-place `iadd` never touches the direction or the center.
models = [
    iadd(scale(dirs[dir_name], dist), center_model)
    for dir_name, dist in df.index
]
df['Model'] = pd.Series(models, index=df.index, dtype=object)

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Model
Direction,Distance,Unnamed: 2_level_1
Ascent,0.062500,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,0.125000,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,0.250000,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,0.500000,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Ascent,1.000000,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
...,...,...
Radially In ⋅ P(bias),0.119700,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Radially In ⋅ P(bias),0.239401,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Radially In ⋅ P(bias),0.478802,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."
Radially In ⋅ P(bias),0.957603,"SimpleMLP(\n (fc1): Linear(in_features=784, o..."


Parallel evaluation with vmap()


In [8]:
from torch.func import stack_module_state, functional_call
from torch import vmap

# The ensemble is the center model followed by every displaced model, in DataFrame row order
ensemble_list = [center_model] + list(df['Model'])

print(f'Ensemble size: {len(ensemble_list)}')

# stack to prepare for vmap
params, buffers = stack_module_state(ensemble_list)

with profiler(f'Moving stacked ensemble to {device}'):
    params = {name: tensor.to(device) for name, tensor in params.items()}
    buffers = {name: tensor.to(device) for name, tensor in buffers.items()}

stacked_ensemble = (params, buffers)

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
meta_model = deepcopy(center_model).to('meta')

def meta_model_loss(params_and_buffers, data, target):
    """Compute the loss of a set of params on a batch of data, via the meta model"""
    predictions = functional_call(meta_model, params_and_buffers, (data,))
    # double precision for better resolution of small scales
    # (the cast happens after the forward pass, so it only affects the loss computation)
    predictions = predictions.double()
    loss = criterion(predictions, target)
    return loss

# define a loss function that takes the stacked ensemble params, data, and target
# in_dims=(0, None, None) adds an ensemble dimension to the first argument 'params_and_buffers' but not to the other two arguments
vmap_loss = vmap(meta_model_loss, in_dims=(0, None, None))

# one entry per batch, each holding one loss per ensemble member
batch_losses = []
with profiler(f'Evaluating stacked ensemble on {device}'):
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        batch_loss = vmap_loss(stacked_ensemble, data, target)
        batch_losses.append(batch_loss)

with profiler('Freeing data and target'):
    del data, target
    torch.cuda.empty_cache()

with profiler('Freeing stacked ensemble'):
    # `params` and `buffers` reference the same stacked tensors as the tuple, so all three
    # names have to go before the memory can actually be released
    del stacked_ensemble, params, buffers
    torch.cuda.empty_cache()

# average the per-batch losses over the batches, leaving one loss per ensemble member
loss_tensor = torch.mean(torch.stack(batch_losses), dim=0)

Ensemble size: 81

:::::::::::::::::::::::::Moving stacked ensemble to cpu:::::::::::::::::::::::::
:::::::::::::::Finished Moving stacked ensemble to cpu in 0.00 s::::::::::::::::

:::::::::::::::::::::::Evaluating stacked ensemble on cpu:::::::::::::::::::::::
:::::::::::::Finished Evaluating stacked ensemble on cpu in 32.86 s:::::::::::::

::::::::::::::::::::::::::::Freeing data and target:::::::::::::::::::::::::::::
:::::::::::::::::::Finished Freeing data and target in 0.00 s:::::::::::::::::::

::::::::::::::::::::::::::::Freeing stacked ensemble::::::::::::::::::::::::::::
::::::::::::::::::Finished Freeing stacked ensemble in 0.00 s:::::::::::::::::::


In [9]:
# First entry of the loss tensor is the center model's loss, the remaining ones follow the DataFrame rows
center_loss, *dir_losses = loss_tensor.tolist()

# One loss per DataFrame row, otherwise the column assignment below would be misaligned
assert len(dir_losses) == len(df)
df['Loss'] = dir_losses

# Every direction starts at the center: add a Distance=0 row per direction carrying the center loss
direction_names = df.index.get_level_values('Direction').unique()
for direction in direction_names:
    df.loc[(direction, 0.0), 'Loss'] = center_loss

# Bring the freshly appended Distance=0 rows to the front of their direction
df = df.sort_index()

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Model,Loss
Direction,Distance,Unnamed: 2_level_1,Unnamed: 3_level_1
Ascent,0.000000,,2.310385
Ascent,0.062500,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.345272
Ascent,0.125000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.380990
Ascent,0.250000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.452914
Ascent,0.500000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.598581
...,...,...,...
Radially Out ⋅ P(weight),0.100656,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.310738
Radially Out ⋅ P(weight),0.201311,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.311108
Radially Out ⋅ P(weight),0.402623,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.311901
Radially Out ⋅ P(weight),0.805245,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.313717


In [10]:


import plotly.graph_objects as go

# One line-plus-marker trace per direction: loss as a function of the distance from the center
loss_traces = [
    go.Scatter(
        x=group.index.get_level_values('Distance'),
        y=group['Loss'],
        mode='lines+markers',
        name=direction,
    )
    for direction, group in df.groupby(level='Direction')
]
fig = go.Figure(data=loss_traces)

# Title, axis labels and figure size
loss_layout = dict(
    title=f'Loss Landscape of {center_model.__class__.__name__} on {dataset.__class__.__name__}',
    xaxis_title='Distance from Center in Parameter Space',
    yaxis_title='Loss',
    legend_title='Direction',
    template='seaborn',
    height=600,
)
fig.update_layout(**loss_layout)


# Show the figure
fig.show()


Now let's measure the roughness.

In [11]:
# Roughness at scale h along one direction:
# take three points on the curve (distance, loss): A at the center (distance 0), B at distance h and
# C at distance 2h. The roughness is the relative excess length of the detour A->B->C over the
# straight line A->C, i.e. (|AB| + |BC|) / |AC| - 1. It is 0 when A, B and C are collinear.
df['Roughness'] = np.nan

# Calculate the roughness, one direction at a time.
# Assumes the rows of each direction are sorted by distance with the Distance=0 center row first.
for direction, group in df.groupby(level='Direction'):
    distA = group.index.get_level_values('Distance').to_numpy()
    loss = group['Loss'].to_numpy()

    # Named slices, so that
    # distA[x] really corresponds to the distance between A and X ∈ {A, B, C},
    # loss[x] really corresponds to the loss at X ∈ {A, B, C}
    a = 0
    b = slice(1, -1)
    c = slice(2, None)

    # Ensure C is twice as far out as B. Compare with a relative tolerance only (atol=0):
    # exact float equality is fragile, and an absolute tolerance would make the check vacuous
    # at very small distances.
    assert np.allclose(distA[c], 2 * distA[b], rtol=1e-9, atol=0), "Distance check failed"

    # Calculate paths between A, B, and C (B->C spans the same horizontal distance as A->B)
    AB = np.sqrt((loss[b] - loss[a]) ** 2 + distA[b] ** 2)
    BC = np.sqrt((loss[c] - loss[b]) ** 2 + distA[b] ** 2)
    AC = np.sqrt((loss[c] - loss[a]) ** 2 + distA[c] ** 2)

    # Roughness is how much the path via B (A->B->C) is longer than the direct path A->C
    # subtracting 1, then we can use log scale better!
    roughness = (AB + BC) / AC - 1

    # Update DataFrame with roughness values.
    # The center row and the outermost row (which has no point C) stay NaN.
    df.loc[(direction,), 'Roughness'] = np.concatenate([[np.nan], roughness, [np.nan]])

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Model,Loss,Roughness
Direction,Distance,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Ascent,0.000000,,2.310385,
Ascent,0.062500,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.345272,1.269395e-05
Ascent,0.125000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.380990,7.937201e-06
Ascent,0.250000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.452914,1.109713e-05
Ascent,0.500000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.598581,1.135639e-03
...,...,...,...,...
Radially Out ⋅ P(weight),0.100656,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.310738,3.526662e-09
Radially Out ⋅ P(weight),0.201311,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.311108,1.502231e-08
Radially Out ⋅ P(weight),0.402623,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.311901,6.985093e-08
Radially Out ⋅ P(weight),0.805245,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.313717,3.773381e-07


In [12]:
# new nomenclature: grit := scale dependent roughness

# One marker-only trace per direction: roughness as a function of the coarse graining scale
roughness_traces = [
    go.Scatter(
        x=group.index.get_level_values('Distance'),
        y=group['Roughness'],
        mode='markers',
        name=direction,
    )
    for direction, group in df.groupby(level='Direction')
]
fig = go.Figure(data=roughness_traces)

# Both axes are logarithmic; dtick=1 on a log axis puts a tick on every power of ten
fig.update_layout(
    title=f'Scale Dependent Roughness of {center_model.__class__.__name__} on {dataset.__class__.__name__}',
    xaxis=dict(title='Coarse Graining Scale', type='log', dtick=1),
    yaxis=dict(title='Roughness', type='log'),
    legend_title='Direction',
    template='seaborn',
    height=600,
)

# Display the figure
fig.show()


Consider the formula for the second derivative of the loss function around point x:
$$
l''(x) = \lim_{{h \to 0}} \frac{l(x) - 2l(x+h) + l(x+2h)}{h^2}
$$

Now let's consider this not as a limit but at a finite h.
$$
curvature(x,h) = \frac{l(x) - 2l(x+h) + l(x+2h)}{h^2}
$$

But this is not what we want: it has dimensions of $loss/distance^2$, and is obviously scale dependent - a curve scaled up by a factor of $2$ has a curvature half as big.
To fix this we will multiply by $h$.

$$
h \cdot curvature(x,h) = \frac{l(x) - 2l(x+h) + l(x+2h)}{h}
$$

Now let's set $x$ to our central model and calculate the curvature for different values of $h$.
Each h is a different scale, and so we get scale dependent curvature. Let's call it "grit".

$$
Grit_{x}(h) = \frac{l(x) - 2l(x+h) + l(x+2h)}{h}
$$

In [17]:
df['Grit'] = np.nan

# Calculate the grit, one direction at a time:
# Grit(h) = (l(x) - 2*l(x+h) + l(x+2h)) / h, with x the center model.
# Assumes the rows of each direction are sorted by distance with the Distance=0 center row first.
for direction, group in df.groupby(level='Direction'):
    dist_h = group.index.get_level_values('Distance').to_numpy()
    loss = group['Loss'].to_numpy()

    # Named slices, so that loss[x], loss[x_plus_h] and loss[x_plus_2h] are the losses at
    # the center, at distance h and at distance 2h respectively
    x = 0
    x_plus_h = slice(1, -1)
    x_plus_2h = slice(2, None)

    # Check that each distance is twice the previous one. Compare with a relative tolerance only
    # (atol=0): exact float equality is fragile, and an absolute tolerance would make the check
    # vacuous at very small distances.
    assert np.allclose(dist_h[x_plus_2h], 2 * dist_h[x_plus_h], rtol=1e-9, atol=0), "Distances are not geometrically spaced"

    # calculate the grit: finite second difference of the loss, divided by h (not h^2)
    grit = (loss[x] - 2*loss[x_plus_h] + loss[x_plus_2h]) / dist_h[x_plus_h]

    # Update DataFrame with grit values.
    # The center row and the outermost row (which has no point at 2h) stay NaN.
    df.loc[(direction,), 'Grit'] = np.concatenate([[np.nan], grit, [np.nan]])

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Model,Loss,Roughness,Grit
Direction,Distance,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Ascent,0.000000,,2.310385,,
Ascent,0.062500,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.345272,1.269395e-05,0.013292
Ascent,0.125000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.380990,7.937201e-06,0.010559
Ascent,0.250000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.452914,1.109713e-05,0.012552
Ascent,0.500000,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.598581,1.135639e-03,0.134796
...,...,...,...,...,...
Radially Out ⋅ P(weight),0.100656,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.310738,3.526662e-09,0.000168
Radially Out ⋅ P(weight),0.201311,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.311108,1.502231e-08,0.000347
Radially Out ⋅ P(weight),0.402623,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.311901,6.985093e-08,0.000748
Radially Out ⋅ P(weight),0.805245,"SimpleMLP(\n (fc1): Linear(in_features=784, o...",2.313717,3.773381e-07,0.001737


In [19]:
# One marker-only trace per direction: signed grit as a function of the coarse graining scale
grit_traces = [
    go.Scatter(
        x=group.index.get_level_values('Distance'),
        y=group['Grit'],
        mode='markers',
        name=direction,
    )
    for direction, group in df.groupby(level='Direction')
]
fig = go.Figure(data=grit_traces)

# Log scale on the x-axis only (dtick=1 puts a tick on every power of ten); the y-axis stays linear
fig.update_layout(
    title=f'Grit of {center_model.__class__.__name__} on {dataset.__class__.__name__}',
    xaxis=dict(title='Coarse Graining Scale', type='log', dtick=1),
    yaxis=dict(title='Grit'),
    legend_title='Direction',
    template='seaborn',
    height=600,
)

# Display the figure
fig.show()

In [20]:
# One line-plus-marker trace per direction: absolute value of the grit, so it can go on a log axis
grit_magnitude_traces = [
    go.Scatter(
        x=group.index.get_level_values('Distance'),
        y=group['Grit'].abs(),
        mode='lines+markers',
        name=direction,
    )
    for direction, group in df.groupby(level='Direction')
]
fig = go.Figure(data=grit_magnitude_traces)

# Both axes are logarithmic; dtick=1 on a log axis puts a tick on every power of ten
fig.update_layout(
    title=f'Grit Magnitude of {center_model.__class__.__name__} on {dataset.__class__.__name__}',
    xaxis=dict(title='Coarse Graining Scale', type='log', dtick=1),
    yaxis=dict(title='|Grit|', type='log'),
    legend_title='Direction',
    template='seaborn',
    height=600,
)

# Display the figure
fig.show()