In [6]:
import torch
from torch import nn
import segmentation_models_pytorch as smp

def replace_to_group_norm(module):
    '''Recursively replace BatchNorm2d layers of a module by GroupNorm.'''

    for name, child in module.named_children():
        if len(list(child.children())) > 0:
            replace_to_group_norm(child)
            
        if isinstance(child, nn.BatchNorm2d):
            # logic used in the paper to set the number of groups
            num_features = child.num_features
            num_groups = 32 if (num_features % 32 == 0) else num_features
            #---
            layer = nn.GroupNorm(num_channels=num_features, num_groups=num_groups)
            setattr(module, name, layer)

model = smp.FPN(
    encoder_name='resnet50',        
    encoder_weights='imagenet'
)
replace_to_group_norm(model)
x = torch.rand(1, 3, 224, 224)
model(x)


tensor([[[[-4.0490, -3.7096, -3.3701,  ..., -2.7198, -2.8360, -2.9522],
          [-3.8295, -3.5299, -3.2302,  ..., -2.6179, -2.8165, -3.0151],
          [-3.6100, -3.3501, -3.0902,  ..., -2.5161, -2.7971, -3.0781],
          ...,
          [-1.7003, -1.7106, -1.7210,  ..., -3.4018, -3.6852, -3.9686],
          [-1.7309, -1.7544, -1.7778,  ..., -3.1539, -3.4459, -3.7378],
          [-1.7615, -1.7981, -1.8347,  ..., -2.9060, -3.2065, -3.5070]]]],
       grad_fn=<UpsampleBilinear2DBackward0>)

In [5]:
model

FPN(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): GroupNorm(32, 256, eps=1e-05, affine=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
        )
      )
     

In [4]:
q = nn.BatchNorm2d(56)
q.num_features

56