In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Here's a simple MLP
class SimpleMLP(nn.Module):
    """Three-layer fully connected network with ReLU activations.

    The layer sizes are constructor arguments; the defaults reproduce the
    original 784 -> 128 -> 128 -> 10 architecture, so ``SimpleMLP()`` behaves
    exactly as before.

    Args:
        in_features: Number of features per sample once every dimension after
            the first has been flattened (default 784).
        hidden_features: Width of both hidden layers (default 128).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_features=784, hidden_features=128, num_classes=10):
        super().__init__()
        # Attribute names and creation order are kept so state-dict keys and
        # the random initialization sequence are unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return the raw (pre-softmax) scores for a batch.

        Args:
            x: Batched input; all dimensions after the first are flattened
                into one and must multiply out to ``in_features``.

        Returns:
            Tensor with one row of ``num_classes`` scores per sample.
        """
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
        
from torchvision.datasets import MNIST
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNG: the model weights are randomly initialized and the
# batch below comes from a shuffled loader, so without a seed every
# "Restart & Run All" would produce different numbers.
torch.manual_seed(0)

# Use a standard MNIST normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# train=False selects the MNIST test split; it is downloaded to ./data if missing.
dataset = MNIST(root='./data', train=False, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

num_models = 3

# Independently initialized models that all share the same architecture.
models = [SimpleMLP().to(device) for _ in range(num_models)]

# A single batch is enough for the comparison; move it to the models' device.
data, target = next(iter(dataloader))
data = data.to(device)
target = target.to(device)

# predictions with for loop
predictions_loop = torch.stack([model(data) for model in models])
# One loss value per model, computed on float64 copies of the predictions.
loss_loop = torch.stack([F.cross_entropy(predictions.double(), target) for predictions in predictions_loop])
print(loss_loop.shape)

torch.Size([3])


Parallel evaluation with vmap()


In [12]:
from torch.func import stack_module_state
from icecream import ic

# Collect the parameters and buffers of every ensemble member into stacked
# collections, so they can later be mapped over along the leading dimension.
params, buffers = stack_module_state(models)

from torch.func import functional_call
from torch.nn.functional import cross_entropy
import copy

# Build a "stateless" skeleton of the architecture: deep-copy one ensemble
# member and move the copy to the 'meta' device, whose tensors carry no
# storage. The real weights are supplied at call time.
meta_model = copy.deepcopy(models[0]).to('meta')

def meta_model_loss(params, buffers, data, target):
    """Evaluate the cross-entropy loss of ``meta_model`` under the given state.

    Args:
        params: Parameters to run ``meta_model`` with.
        buffers: Buffers to run ``meta_model`` with.
        data: Input batch passed to the model.
        target: Class targets for ``data``.

    Returns:
        The cross-entropy loss with the default reduction, i.e. a single
        value averaged over the batch (not one value per sample).
    """
    logits = functional_call(meta_model, (params, buffers), (data,))
    # The loss is computed on a float64 copy of the model output.
    return cross_entropy(logits.double(), target)

from torch import vmap
# Map the loss over dimension 0 (the ensemble dimension) of the first two
# arguments, params and buffers. in_dims=None for data and target means they
# are not ensembled over: every model receives the same batch.
ensembled_loss = vmap(
    meta_model_loss,
    in_dims=(0, 0, None, None),
)
loss_vmap = ensembled_loss(params, buffers, data, target)

# The vectorized losses must agree with the for-loop losses within tolerance.
assert torch.allclose(
    loss_loop,
    loss_vmap,
    atol=1e-3,
    rtol=1e-5,
)
