In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F

In [2]:
class FilterGeneratingNetwork(nn.Module):
    """Small CNN that predicts, for every sample, a flat vector of filter weights.

    Architecture: three unpadded 3x3 ``Conv2d -> BatchNorm2d -> ReLU`` stages
    with 32, 64 and 16 channels, then ``Flatten``, a lazily-sized dense layer
    of width 256 with ReLU, and a final linear projection to
    ``s * s * in_chan * out_chan`` values.

    Args:
        s: Spatial size of each generated filter.
        in_chan: Channels of the input images (and of each generated filter).
        out_chan: Number of filters generated per sample.

    Notes:
        The three unpadded 3x3 convolutions shrink H and W by 6 in total, so
        inputs need H, W >= 7. ``LazyLinear`` fixes its input size on the first
        forward pass, so later inputs must flatten to the same size.
    """

    def __init__(self, s, in_chan, out_chan):
        super().__init__()

        layers = []
        channels = in_chan
        # Same layer order as a hand-written Sequential, so indices 0..12
        # (and therefore state_dict keys) are unchanged.
        for width in (32, 64, 16):
            layers += [nn.Conv2d(channels, width, 3), nn.BatchNorm2d(width), nn.ReLU()]
            channels = width

        n_filter_weights = s * s * in_chan * out_chan
        layers += [
            nn.Flatten(),
            nn.LazyLinear(256),
            nn.ReLU(),
            nn.Linear(256, n_filter_weights),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(batch, in_chan, H, W)`` to ``(batch, s * s * in_chan * out_chan)``."""
        return self.net(x)
    
class DynamicFilterNetworkLayer(nn.Module):
    """Convolution whose kernels are predicted per sample from the input itself.

    A ``FilterGeneratingNetwork`` maps each sample ``x[b]`` to its own bank of
    ``out_chan`` filters of shape ``(in_chan, s, s)``. Those filters are then
    convolved with that same sample only (stride 1, no bias). "same" padding
    keeps the spatial size unchanged for any kernel size ``s``.

    Args:
        s: Kernel size of the generated filters.
        in_chan: Number of input channels.
        out_chan: Number of output channels.
    """

    def __init__(self, s, in_chan, out_chan):
        super().__init__()
        self.in_chan = in_chan
        self.out_chan = out_chan
        self.s = s

        # NOTE: the attribute name is misspelled ("fitler") but is kept as-is
        # because it determines the state_dict keys of saved checkpoints.
        self.fitler_generator = FilterGeneratingNetwork(s, in_chan, out_chan)

    def forward(self, x):
        """Apply per-sample generated filters to ``x``.

        Args:
            x: Tensor of shape ``(batch, in_chan, H, W)``.

        Returns:
            Tensor of shape ``(batch, out_chan, H, W)``.
        """
        batch_size, _, h, w = x.shape

        # (batch, s*s*in_chan*out_chan) -> weight tensor of a grouped conv:
        # rows [b*out_chan:(b+1)*out_chan] are the filters for sample b.
        filters = self.fitler_generator(x).reshape(
            batch_size * self.out_chan, self.in_chan, self.s, self.s
        )

        # Fold the batch into the channel axis (explicit batch dim of 1) so
        # that groups=batch_size convolves each sample with its own filters.
        # reshape (not view) also accepts non-contiguous inputs.
        folded = x.reshape(1, batch_size * self.in_chan, h, w)

        # "same" keeps H and W for every kernel size; the previous padding=1
        # was only correct for s == 3.
        out = F.conv2d(folded, filters, groups=batch_size, padding="same")
        return out.reshape(batch_size, self.out_chan, h, w)

In [3]:
# Fixed seed so the random inputs and the network's initial weights are
# reproducible under Restart & Run All.
torch.manual_seed(0)

BATCH_SIZE = 100
IN_CHANNELS = 20
OUT_CHANNELS = 64
KERNEL_SIZE = 3
IMG_SIZE = 32

images = torch.rand(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE)
net = DynamicFilterNetworkLayer(KERNEL_SIZE, IN_CHANNELS, OUT_CHANNELS)

In [4]:
out = net(images)
# The layer should preserve batch and spatial size and emit net.out_chan channels.
assert out.shape == (images.shape[0], net.out_chan, *images.shape[-2:])
out.shape

torch.Size([100, 64, 32, 32])