In [3]:
import lightly

net = lightly.models.SimSiam(512, num_ftrs=512)#, num_mlp_layers=3)

net



SimSiam(
  (projection_mlp): SimSiamProjectionHead(
    (layers): Sequential(
      (0): Linear(in_features=512, out_features=2048, bias=True)
      (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=2048, out_features=2048, bias=True)
      (4): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=2048, out_features=2048, bias=True)
      (7): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (prediction_mlp): SimSiamPredictionHead(
    (layers): Sequential(
      (0): Linear(in_features=2048, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=512, out_features=2048, bias=True)
    )
  )
)

In [13]:
import lightly
import torch

def make_resnet_backbone(
    backbone_in_ch: int,
    backbone_type: str):

    if backbone_type == 'resnet18':
        net = lightly.models.resnet.ResNetGenerator(name='resnet-18', num_classes=10)
    elif backbone_type == 'resnet34':
        net = lightly.models.resnet.ResNetGenerator(name='resnet-34', num_classes=10)
    elif backbone_type == 'resnet50':
        net = lightly.models.resnet.ResNetGenerator(name='resnet-50', num_classes=10)
    else:
        raise Exception("backbone_type not recognized. Received ->", backbone_type)

    first_conv_out_channels = list(net.children())[0].out_channels

    new_net = torch.nn.Sequential(
        torch.nn.Conv2d(
            backbone_in_ch,
            first_conv_out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        ),
        *list(net.children())[1:-1],  # note that I am excluding the last_fc_layer.
        #torch.nn.AdaptiveAvgPool2d(1),  # adding adaptive pooling
        # torch.nn.Flatten(start_dim=1)  # adding flattening
    )
    return new_net

net = make_resnet_backbone(backbone_in_ch=9, backbone_type='resnet18')

In [14]:
net

Sequential(
  (0): Conv2d(9, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, t

In [18]:
x_in = torch.zeros(1,9,16,16)
x_out = net(x_in)
print(x_out.shape)

torch.Size([1, 512, 2, 2])
