# Group convolution benchmark

In [1]:
# Download the CIFAR-10 archive (skipped when cifar10.zip is already present) and
# unpack it quietly, overwriting any previously extracted files.
# NOTE(review): the benchmark cells visible in this notebook run on random tensors
# and never read these files -- confirm whether this download is still needed.
!wget --quiet --no-clobber http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
!unzip -qq -o cifar10.zip

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

In [3]:
# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """Small LeNet-style CNN from the PyTorch "Neural Networks" tutorial.

    Two 5x5 convolutions (3 -> 6 -> 16 channels), each followed by ReLU and
    2x2 max pooling, feed three fully connected layers that end in 10 outputs
    per sample. The first linear layer expects a 16x5x5 feature map, which is
    what a 3x32x32 input is reduced to by the convolutional stage.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 3 input channels, 5x5 square kernels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head made of affine maps (y = Wx + b); 16 * 5 * 5 is the
        # flattened size of the pooled conv2 output.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run a batch of images through the network and return (batch, 10) outputs."""
        # conv -> ReLU -> 2x2 max pool, twice.
        pooled1 = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)
        pooled2 = F.max_pool2d(F.relu(self.conv2(pooled1)), kernel_size=2)
        # Keep the batch dimension and flatten everything else.
        features = pooled2.flatten(start_dim=1)
        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [4]:
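# Adapted from: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
# (the tutorial network rewritten with grouped convolutions so several copies run in one forward pass)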
class BatchedNet(nn.Module):
    """``p`` independent copies of the tutorial CNN fused into one grouped network.

    Every layer is a grouped convolution with ``groups=p``, so channel group
    ``g`` of the input only ever meets the ``g``-th slice of each layer's
    weights. The fully connected layers of the single network are expressed as
    kernel-size-1 ``Conv1d`` layers so they can be grouped the same way.

    Args:
        p: Number of networks evaluated side by side (the group count).
        device: Device the layers are created on. Defaults to ``"cuda"``, the
            device the layers were previously hard-wired to; pass ``"cpu"`` to
            build the module on a machine without a GPU.
    """

    def __init__(self, p, device="cuda"):
        super().__init__()
        self.p = p
        # Convolutional stage. Per group: 3 -> 6 -> 16 channels, 5x5 square
        # kernels, each convolution followed by ReLU and 2x2 max pooling.
        self.ops = nn.Sequential(
            nn.Conv2d(3 * p, 6 * p, 5, groups=p, device=device),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(6 * p, 16 * p, 5, groups=p, device=device),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Classifier stage. Per group: 16*5*5 -> 120 -> 84 -> 10 features. A
        # kernel-size-1 convolution over a length-1 sequence is an affine map,
        # and ``groups=p`` keeps each copy's weights separate.
        self.ops2 = nn.Sequential(
            nn.Conv1d(16 * 5 * 5 * p, 120 * p, kernel_size=1, groups=p, device=device),
            nn.ReLU(),
            nn.Conv1d(120 * p, 84 * p, kernel_size=1, groups=p, device=device),
            nn.ReLU(),
            nn.Conv1d(84 * p, 10 * p, kernel_size=1, groups=p, device=device),
        )

    def forward(self, x):
        """Evaluate all ``p`` networks on one channel-concatenated batch.

        Args:
            x: Tensor of shape ``(batch, 3 * p, H, W)`` whose channels
                ``[3 * g, 3 * g + 3)`` hold the image for network ``g``. The
                spatial size must reduce to 5x5 after the convolutional stage
                (32x32 inputs do).

        Returns:
            Tensor of shape ``(batch, 10 * p, 1)``; channels
            ``[10 * g, 10 * g + 10)`` are the outputs of network ``g``.
        """
        x = self.ops(x)
        # Turn each sample's (16 * p, 5, 5) feature map into the channels of a
        # length-1 sequence. Flattening per sample keeps this correct for any
        # batch size instead of assuming exactly 10 samples.
        x = x.flatten(1).unsqueeze(-1)
        return self.ops2(x)


In [5]:
K = 10  # number of concurrent clients
BATCH_SIZE = 10  # samples each client contributes to one forward pass

net = Net().to("cuda")
batchedNet = BatchedNet(K).to("cuda")

# One random (BATCH_SIZE, 3, 32, 32) mini-batch per client.
images = [torch.randn((BATCH_SIZE, 3, 32, 32), device="cuda") for _ in range(K)]

# A grouped convolution splits the channel axis into K contiguous groups, so
# client k's images must sit in channels [3k, 3k + 3) of every sample.
# Concatenating along dim 1 produces exactly that (BATCH_SIZE, 3 * K, 32, 32)
# layout; stacking on dim 0 and reshaping would mix up the client and sample
# axes instead.
batch = torch.cat(images, dim=1)

def trace_handler(p):
    """Profiler ``on_trace_ready`` callback: print a summary and save the trace.

    Prints the averaged-events table limited to 10 rows and sorted by self
    CUDA time, then exports a Chrome trace to the relative path
    ``trace_<step_num>.json``.

    Args:
        p: The profiler object handed to the callback.
    """
    averages = p.key_averages()
    print(averages.table(sort_by="self_cuda_time_total", row_limit=10))
    trace_path = f"trace_{p.step_num}.json"
    p.export_chrome_trace(trace_path)

# One profiler cycle: skip a step, warm up for a step, then record a step.
profiler_schedule = torch.profiler.schedule(wait=1, warmup=1, active=1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    schedule=profiler_schedule,
    on_trace_ready=trace_handler,
) as prof:
    # Three steps = wait + warmup + active, i.e. exactly one scheduled cycle.
    for step in range(3):
        # Baseline: run the single network once per client, one after another.
        with record_function("non-batched"):
            for client_batch in images:
                net(client_batch)
        # Grouped variant: all clients in a single forward pass.
        with record_function("batched"):
            batchedNet(batch)
        # Tell the profiler that this step is finished.
        prof.step()


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::cudnn_convolution        17.88%       4.179ms        26.88%       6.285ms     251.400us       5.440ms        69.34%       5.440ms     217.600us            25  
void cudnn::cnn::conv2d_grouped_direct_kernel<float,...         0.00%       0.000us         0.00%       0.000us       0.000us       2.920ms        37.22%       2.920ms       1.460ms             2  
void prec

### Profiler trace

![Profiler trace](./assets/profiler_trace_grouped_convolutions.png)