# Group convolution playground

In [1]:
# Fetch the CIFAR-10 archive (wget's --no-clobber skips the download when the file
# already exists) and unpack it quietly, overwriting files without prompting (-o).
# NOTE(review): none of the cells visible in this notebook read the extracted
# data; every test below uses torch.randn inputs. Confirm this download is still needed.
!wget --quiet --no-clobber http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
!unzip -qq -o cifar10.zip

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

# torch.backends.cudnn.deterministic = True

## Standard model

In [2]:
# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """Small CNN: two 5x5 convolutions followed by three fully connected layers.

    ``fc1`` expects ``16 * 5 * 5`` features, which is what the two conv/pool
    stages produce for a 3-channel 32x32 input. The network outputs 10 values
    per sample.

    Args:
        const_init: Optional scalar (keyword-only). When given, every weight
            and bias of every layer is overwritten with this value; when
            ``None`` the default PyTorch initialisation is kept.
    """

    def __init__(self, *, const_init=None):
        super().__init__()

        # feature extractor
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        # classifier head
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

        if const_init is None:
            return

        # overwrite the default initialisation with a single constant
        for module in (self.conv1, self.conv2, self.fc1, self.fc2, self.fc3):
            for tensor in (module.weight, module.bias):
                tensor.data.fill_(const_init)

    def forward(self, x):
        """Run the network on ``x`` and return a ``(batch, 10)`` tensor."""
        # conv -> ReLU -> 2x2 max pooling, twice
        x = self.conv1(x)
        x = F.relu(x, inplace=True)
        x = F.max_pool2d(x, kernel_size=2)

        x = self.conv2(x)
        x = F.relu(x, inplace=True)
        x = F.max_pool2d(x, kernel_size=2)

        # keep the batch dimension, collapse everything else
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc1(x), inplace=True)
        x = F.relu(self.fc2(x), inplace=True)
        return self.fc3(x)

## Batched model

In [3]:
class BatchedNet(nn.Module):
    """``P`` independent copies of the small CNN evaluated in a single pass.

    Each layer is a grouped convolution with ``groups=P``, so every group of
    channels is processed with its own weights. The fully connected layers are
    expressed as ``Conv1d`` layers with ``kernel_size=1``.

    Args:
        P: Number of groups (i.e. number of model copies).
        const_init: Optional scalar (keyword-only). When given, every weight
            and bias is overwritten with this value; when ``None`` the default
            PyTorch initialisation is kept.
    """

    def __init__(self, P, *, const_init=None):
        super().__init__()

        self.P = P
        n_groups = self.P

        # grouped counterparts of the two 5x5 convolutions
        self.conv1 = nn.Conv2d(3 * n_groups, 6 * n_groups, kernel_size=5, groups=n_groups)
        self.conv2 = nn.Conv2d(6 * n_groups, 16 * n_groups, kernel_size=5, groups=n_groups)

        # grouped counterparts of the fully connected layers
        self.fc1 = nn.Conv1d(400 * n_groups, 120 * n_groups, kernel_size=1, groups=n_groups)
        self.fc2 = nn.Conv1d(120 * n_groups, 84 * n_groups, kernel_size=1, groups=n_groups)
        self.fc3 = nn.Conv1d(84 * n_groups, 10 * n_groups, kernel_size=1, groups=n_groups)

        if const_init is None:
            return

        # overwrite the default initialisation with a single constant
        for module in (self.conv1, self.conv2, self.fc1, self.fc2, self.fc3):
            for tensor in (module.weight, module.bias):
                tensor.data.fill_(const_init)

    def forward(self, x):
        """Run all ``P`` copies on ``x``; returns a ``(batch, 10 * P, 1)`` tensor."""
        n_samples = x.shape[0]

        # conv -> ReLU -> 2x2 max pooling, twice
        x = self.conv1(x)
        x = F.relu(x, inplace=True)
        x = F.max_pool2d(x, kernel_size=2)

        x = self.conv2(x)
        x = F.relu(x, inplace=True)
        x = F.max_pool2d(x, kernel_size=2)

        # Conv1d wants (batch, channels, length): flatten to channels, length 1
        x = x.view(n_samples, -1, 1)

        x = F.relu(self.fc1(x), inplace=True)
        x = F.relu(self.fc2(x), inplace=True)
        return self.fc3(x)

In [4]:
K = 10  # number of concurrent clients
batch_size = 5

net = Net(const_init=0.01)
batchedNet = BatchedNet(K, const_init=0.01)

# One random input batch per client; every entry of `images` has shape
# (batch_size, n_channels, 32, 32).
images = [torch.randn(batch_size, 3, 32, 32) for _ in range(K)]

# Lay the clients side by side along the channel axis. The result has shape
# (batch_size, n_channels * n_clients, 32, 32), where
#   batch[:, 0:3] is the batch input of client 0,
#   batch[:, 3:6] is the batch input of client 1, and so on.
# (Reshaping torch.stack(images) directly to this shape would NOT give this layout.)
batch = torch.cat(images, dim=1)

# sanity check: the k-th channel chunk must be exactly client k's batch
for client_slice, client_images in zip(batch.chunk(K, dim=1), images):
    assert torch.allclose(client_slice, client_images)

### Test convolution

Verify that applying a standard convolution separately to each client's batch produces the same output as a single grouped convolution applied to the stacked batch.

In [5]:
ln = nn.Conv2d(3, 6, 5)
lbn = nn.Conv2d(3 * K, 6 * K, 5, groups=K)

# give both layers identical constant parameters so their outputs are comparable
for conv in (ln, lbn):
    conv.weight.data.fill_(0.01)
    conv.bias.data.fill_(0.01)

# "standard" convolution: one forward pass per client batch
output_ln = list(map(ln, images))

# grouped convolution: one forward pass, then split the channels per client
output_lbn = lbn(batch).chunk(K, dim=1)

for out_ln, out_lbn in zip(output_ln, output_lbn):
    assert torch.allclose(out_ln, out_lbn, atol=1e-5)

### Test first layer

In [6]:
def net_first_layer(x):
    """Apply the convolutional part of the standard model to ``x``.

    Two freshly created 5x5 convolutions (3 -> 6 -> 16 channels), each with all
    weights and biases set to 0.01 and each followed by ReLU and 2x2 max pooling.

    Args:
        x: Input tensor with 3 channels.

    Returns:
        The pooled feature maps with 16 channels.
    """
    conv1 = nn.Conv2d(3, 6, 5)
    conv2 = nn.Conv2d(6, 16, 5)

    # constant parameters make the result comparable with the batched variant
    for conv in (conv1, conv2):
        conv.weight.data.fill_(0.01)
        conv.bias.data.fill_(0.01)

    pooled = F.max_pool2d(F.relu(conv1(x)), 2)
    return F.max_pool2d(F.relu(conv2(pooled)), 2)


# run the standard model's convolutional part once per client batch
output_first_layer_net = list(map(net_first_layer, images))

In [7]:
# Shape probe: run the stacked batch through a grouped convolution whose weights
# are all 0 and whose biases are all 1, and display the resulting output shape.
c = nn.Conv2d(3 * K, 6 * K, 5, groups=K)
c.weight.data.fill_(0)
c.bias.data.fill_(1)
c(batch).shape

torch.Size([5, 60, 28, 28])

In [8]:
def batched_net_first_layer(x, n_clients=None):
    """Apply the convolutional part of the batched model to ``x``.

    Two freshly created grouped 5x5 convolutions (one group per client), each
    with all weights and biases set to 0.01 and each followed by ReLU and 2x2
    max pooling.

    Args:
        x: Stacked input with ``3 * n_clients`` channels.
        n_clients: Number of groups. Defaults to the notebook-level ``K`` so
            existing single-argument calls behave exactly as before.

    Returns:
        The pooled feature maps with ``16 * n_clients`` channels.
    """
    if n_clients is None:
        # fall back to the notebook-wide client count
        n_clients = K

    conv1 = nn.Conv2d(3 * n_clients, 6 * n_clients, 5, groups=n_clients)
    conv2 = nn.Conv2d(6 * n_clients, 16 * n_clients, 5, groups=n_clients)

    # constant parameters make the result comparable with the standard variant
    conv1.weight.data.fill_(0.01)
    conv1.bias.data.fill_(0.01)

    conv2.weight.data.fill_(0.01)
    conv2.bias.data.fill_(0.01)

    x = F.max_pool2d(F.relu(conv1(x)), (2, 2))
    x = F.max_pool2d(F.relu(conv2(x)), 2)
    return x


# Run the batched model's convolutional part once on the stacked batch.
output_first_layer_batched_net = batched_net_first_layer(batch)
# For 32x32 inputs the result is expected to have shape (batch_size, 16 * K, 5, 5).

In [9]:
# Split the batched output back into one chunk per client and compare each
# chunk with the output the standard model produced for that client.
per_client_batched = torch.chunk(output_first_layer_batched_net, K, dim=1)
for batched_net_output, net_output in zip(per_client_batched, output_first_layer_net):
    assert torch.allclose(batched_net_output, net_output, atol=1e-5)

### Test second layer

In [10]:
def net_second_layer(x):
    """Apply the fully connected part of the standard model to ``x``.

    Three freshly created linear layers (400 -> 120 -> 84 -> 10), all weights
    and biases set to 0.01, with ReLU after the first two.

    Args:
        x: Tensor whose non-batch dimensions flatten to ``16 * 5 * 5`` features.

    Returns:
        A ``(batch, 10)`` tensor.
    """
    fc1 = nn.Linear(16 * 5 * 5, 120)
    fc2 = nn.Linear(120, 84)
    fc3 = nn.Linear(84, 10)

    # constant parameters make the result comparable with the batched variant
    for fc in (fc1, fc2, fc3):
        fc.weight.data.fill_(0.01)
        fc.bias.data.fill_(0.01)

    # keep the batch dimension, collapse everything else
    hidden = x.flatten(start_dim=1)
    hidden = F.relu(fc1(hidden))
    hidden = F.relu(fc2(hidden))
    return fc3(hidden)


# feed each client's convolutional output through the standard fully connected part
output_second_layer_net = list(map(net_second_layer, output_first_layer_net))

In [11]:
def batched_net_second_layer(x, n_clients=None):
    """Apply the fully connected part of the batched model to ``x``.

    The three linear layers are expressed as grouped ``Conv1d`` layers with
    ``kernel_size=1`` (one group per client), all weights and biases set to
    0.01, with ReLU after the first two.

    Args:
        x: Stacked feature maps whose non-batch dimensions flatten to
            ``16 * 5 * 5 * n_clients`` values.
        n_clients: Number of groups. Defaults to the notebook-level ``K`` so
            existing single-argument calls behave exactly as before.

    Returns:
        A ``(batch, 10 * n_clients, 1)`` tensor.
    """
    if n_clients is None:
        # fall back to the notebook-wide client count
        n_clients = K

    fc1 = nn.Conv1d(16 * 5 * 5 * n_clients, 120 * n_clients, kernel_size=1, groups=n_clients)
    fc2 = nn.Conv1d(120 * n_clients, 84 * n_clients, kernel_size=1, groups=n_clients)
    fc3 = nn.Conv1d(84 * n_clients, 10 * n_clients, kernel_size=1, groups=n_clients)

    fc1.weight.data.fill_(0.01)
    fc2.weight.data.fill_(0.01)
    fc3.weight.data.fill_(0.01)
    fc1.bias.data.fill_(0.01)
    fc2.bias.data.fill_(0.01)
    fc3.bias.data.fill_(0.01)

    # Take the batch size from the input itself (as BatchedNet.forward does)
    # rather than from a notebook global, so any batch size works.
    x = x.view(x.shape[0], -1, 1)
    x = F.relu(fc1(x))
    x = F.relu(fc2(x))
    x = fc3(x)
    return x


# The batched fully connected part returns a tensor with a trailing axis of
# length 1. Drop only that axis: a bare .squeeze() would also remove the batch
# axis when batch_size == 1 and break the per-client chunking along dim=1.
output_second_layer_batched_net = batched_net_second_layer(
    output_first_layer_batched_net
).squeeze(-1)
# split the channel axis back into one chunk per client
output_second_layer_batched_net = torch.chunk(output_second_layer_batched_net, K, dim=1)

In [12]:
# every client's chunk of the batched output must match the standard model's output
for batched_net_output, net_output in zip(
    output_second_layer_batched_net, output_second_layer_net
):
    assert torch.allclose(batched_net_output, net_output, atol=1e-5)

### Test full model

In [13]:
# standard model: one forward pass per client batch
output_net = [net(ims) for ims in images]

# Batched model: one forward pass on the stacked batch. BatchedNet.forward
# returns a trailing axis of length 1; drop only that axis, because a bare
# .squeeze() would also remove the batch axis when batch_size == 1.
output_batched_net = batchedNet(batch).squeeze(-1)
# split the channel axis back into one chunk per client
output_batched_net = torch.chunk(output_batched_net, K, dim=1)


# both models use the same constant parameters, so the outputs must match
assert all(
    torch.allclose(o1, o2, atol=1e-5) for o1, o2 in zip(output_net, output_batched_net)
)

## Test parameters sharing

In [14]:
net = Net(const_init=None)  # DO NOT initialize weights
batchedNet = BatchedNet(K, const_init=None)  # DO NOT initialize weights

# standard model: one forward pass per client batch
output_net = [net(ims) for ims in images]

# Batched model: drop only the trailing length-1 axis. A bare .squeeze() would
# also remove the batch axis when batch_size == 1.
output_batched_net = batchedNet(batch).squeeze(-1)
output_batched_net = torch.chunk(output_batched_net, K, dim=1)

# now, the outputs should not be the same due to the different weights
assert all(
    not torch.allclose(o1, o2, atol=1e-5)
    for o1, o2 in zip(output_net, output_batched_net)
)

In [15]:
reference_parameters = net.state_dict()

# Repeat every reference tensor K times along its first axis, so that each
# group of the batched model receives a copy of the same parameters.
parameters = {}
for key, params in reference_parameters.items():
    parameters[key] = torch.cat([params] * K, dim=0)

# The batched model implements its fc layers as Conv1d with kernel_size=1,
# whose weights carry an extra trailing axis of length 1.
for fc_name in ("fc1", "fc2", "fc3"):
    weight_key = f"{fc_name}.weight"
    parameters[weight_key] = parameters[weight_key].unsqueeze(-1)

# each of the K chunks must be a copy of the reference tensor
for client_parameters in parameters["conv1.weight"].chunk(K, dim=0):
    assert torch.allclose(reference_parameters["conv1.weight"], client_parameters)

for client_parameters in parameters["fc1.weight"].chunk(K, dim=0):
    assert torch.allclose(reference_parameters["fc1.weight"], client_parameters.squeeze())

In [16]:
# standard model: one forward pass per client batch
output_net = [net(ims) for ims in images]

# give every group of the batched model a copy of the standard model's parameters
batchedNet.load_state_dict(parameters)

# Batched model: drop only the trailing length-1 axis. A bare .squeeze() would
# also remove the batch axis when batch_size == 1.
output_batched_net = batchedNet(batch).squeeze(-1)
output_batched_net = torch.chunk(output_batched_net, K, dim=1)


# now the outputs should be the same, because both models share the same weights
assert all(
    torch.allclose(o1, o2, atol=1e-5) for o1, o2 in zip(output_net, output_batched_net)
)

## Benchmark

In [17]:
# move both models to the GPU
net = net.to("cuda")
batchedNet = batchedNet.to("cuda")

# fresh random inputs created directly on the GPU: one batch per client,
# plus the same data laid out along the channel axis for the batched model
images = [torch.randn(batch_size, 3, 32, 32, device="cuda") for _ in range(K)]
batch = torch.cat(images, dim=1)

# wait until all pending GPU work has finished before starting the benchmark
torch.cuda.synchronize()

In [18]:
def trace_handler(p):
    """Print a summary of the profiler ``p`` and export its trace.

    Prints the averaged events table (top 20 rows, sorted by total CPU time)
    and writes a Chrome trace to ``trace.json``.
    """
    averages = p.key_averages()
    print(averages.table(sort_by="cpu_time_total", row_limit=20))
    p.export_chrome_trace("trace.json")


# 1 skipped step, 1 warm-up step, then 2 recorded steps
profiler_schedule = torch.profiler.schedule(wait=1, warmup=1, active=2)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    schedule=profiler_schedule,
    on_trace_ready=trace_handler,
) as prof:

    # 4 iterations = wait + warmup + active steps of the schedule above
    for step in range(4):
        # standard model: one forward pass per client
        with record_function("non-batched"):
            for im in images:
                out = net(im)
            # make sure the GPU work is finished inside the labelled region
            torch.cuda.synchronize()

        # batched model: a single forward pass for all the clients
        with record_function("batched"):
            out = batchedNet(batch)
            torch.cuda.synchronize()

        # tell the profiler that one step is complete
        prof.step()

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         0.42%     139.000us        99.34%      32.674ms      16.337ms       0.000us         0.00%       9.101ms       4.551ms             2  
                                            non-batched        12.00%       3.946ms        82.31%      27.073ms      13.537ms       0.000us         0.00%       4.061ms       2.030ms             2  
         

In [19]:
# drop the notebook's references to both models now that the benchmark is done
del net, batchedNet

### Profiler trace

![Profiler trace](./assets/profiler_trace_grouped_convolutions.png)