In [430]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Seed PyTorch's global random number generator so repeated runs are repeatable.
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    """Three-layer fully connected network with ReLU activations.

    With the default sizes it takes inputs that flatten to 784 features per
    sample (e.g. ``1 x 28 x 28``) and produces 10 output values per sample.

    Args:
        in_features: Number of features per sample after flattening.
        hidden_features: Width of both hidden layers.
        out_features: Number of values produced per sample.
    """

    def __init__(self, in_features=784, hidden_features=128, out_features=10):
        super().__init__()
        # Keep the fc1 -> fc2 -> fc3 creation order: parameter initialisation
        # consumes the global RNG, so the order determines the initial weights
        # obtained under a fixed seed.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return the network output for a batch ``x``.

        Every dimension after the first is flattened, so ``x`` may have any
        trailing shape whose sizes multiply to ``in_features``.
        """
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [431]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end to end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_models = 10

# 100 random minibatches of 64 inputs shaped 1 x 28 x 28, plus one random
# integer label in [0, 10) for each of the 100 * 64 = 6400 samples.
data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

# One independently initialised model per ensemble member.
models = [SimpleMLP().to(device) for _ in range(num_models)]

# Option 1: each model gets a different minibatch
minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]
# Option 2: every model gets the same minibatch
minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

In [433]:
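# Use vmap to vectorize the ensemble

# First, combine the states of the models by stacking each of their parameters. For example, models[i].fc1.weight has shape [128, 784]; stacking the fc1.weight of each of the 10 models produces one big weight of shape [10, 128, 784].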
import copy
from torch.func import functional_call, stack_module_state

# Stack the parameters and buffers of all models into two collections of
# tensors that share a new leading "model" dimension.
params, buffers = stack_module_state(models)

# Build a "stateless" template of the architecture: a copy of one model moved
# to the 'meta' device, so its parameters are meta tensors without storage.
base_model = copy.deepcopy(models[0]).to('meta')

def fmodel(params, buffers, x):
    """Call ``base_model`` on ``x`` using the supplied state.

    The given ``params`` and ``buffers`` are used in place of the module's own
    for this call, via ``functional_call``.
    """
    state = (params, buffers)
    return functional_call(base_model, state, (x,))

# Option 1: get predictions using a different minibatch for each model.
print([p.shape[0] for p in params.values()]) # show the leading 'num_models' dimension

# The data must carry the same leading 'num_models' dimension as the stacked
# parameters, because vmap pairs them up along that dimension.
assert minibatches.shape == (num_models, 64, 1, 28, 28)

from torch import vmap

# in_dims=0 / out_dims=0 are vmap's defaults, spelled out here: dimension 0 of
# every input is mapped over and the results are stacked along dimension 0.
predictions1_vmap = vmap(fmodel, in_dims=0, out_dims=0)(params, buffers, minibatches)

# Verify the ``vmap`` predictions match the predictions from the per-model loop.
assert torch.allclose(
    predictions1_vmap,
    torch.stack(predictions_diff_minibatch_loop, dim=0),
    rtol=1e-5,
    atol=1e-3,
)

# Option 2: get predictions using the same minibatch for every model.
# in_dims=None for the third argument passes the minibatch unmapped to each call.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(
    predictions2_vmap,
    torch.stack(predictions2, dim=0),
    rtol=1e-5,
    atol=1e-3,
)

[10, 10, 10, 10, 10, 10]


In [434]:
from torch.utils.benchmark import Timer

# Baseline: run the models one after another in a Python loop.
without_vmap = Timer(
    "[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals(),
)
# Vectorized: one vmap call over the stacked parameters and minibatches.
with_vmap = Timer(
    "vmap(fmodel)(params, buffers, minibatches)",
    globals=globals(),
)

# Each statement is executed 100 times; printing the result shows the timing summary.
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')

Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f4aa24c6ad0>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
  1.67 ms
  1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f4a865d80d0>
vmap(fmodel)(params, buffers, minibatches)
  534.89 us
  1 measurement, 100 runs , 1 thread
