In [1]:
import torch

In [2]:
class ConvNet(torch.nn.Module):
    """Strided CNN that maps a batch of images to one logit per image.

    Layout: 7x7 stride-2 stem conv (32 channels) + ReLU + 3x3 stride-2 max-pool,
    then one ``Block`` per entry of ``layers`` (each with stride 2), global
    average pooling over the two spatial dims, and a linear layer with a single
    output unit.

    Args:
        layers: output channel count of each successive ``Block``. A tuple is
            used as the default so the default value cannot be mutated and shared
            between instances.
        n_input_channels: number of channels in the input images.
    """

    class Block(torch.nn.Module):
        """Two 3x3 conv + ReLU pairs; ``stride`` is applied by the first conv."""

        def __init__(self, n_input, n_output, stride=1):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(n_input, n_output, kernel_size=3, padding=1, stride=stride),
                torch.nn.ReLU(),
                torch.nn.Conv2d(n_output, n_output, kernel_size=3, padding=1),
                torch.nn.ReLU(),
            )
            # Only the first conv of the block gets the custom init; the second
            # conv keeps PyTorch's default Conv2d initialisation.
            torch.nn.init.xavier_normal_(self.net[0].weight)
            torch.nn.init.constant_(self.net[0].bias, 0.1)

        def forward(self, x):
            return self.net(x)

    def __init__(self, layers=(32, 64, 128), n_input_channels=3):
        super().__init__()
        stem_channels = 32
        modules = [
            torch.nn.Conv2d(n_input_channels, stem_channels, kernel_size=7, padding=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        c = stem_channels
        for l in layers:
            modules.append(self.Block(c, l, stride=2))
            c = l
        self.network = torch.nn.Sequential(*modules)
        self.classifier = torch.nn.Linear(c, 1)

        # Custom initialisation (applied after all modules exist).
        # Zero classifier weights: at init the output equals the classifier bias
        # for every input.
        torch.nn.init.zeros_(self.classifier.weight)
        # network[0] is the stem conv. Note: network[1] is ReLU, network[2] is
        # MaxPool2d and network[3] is the first Block -- its first conv is
        # network[3].net[0], not network[3] itself.
        torch.nn.init.xavier_normal_(self.network[0].weight)
        torch.nn.init.constant_(self.network[0].bias, 0.1)

    def forward(self, x):
        """Return one logit per image.

        ``x`` must be batched, (B, n_input_channels, H, W), because the features
        are averaged over dims 2 and 3. Returns a tensor of shape (B,).
        """
        # Compute the features
        z = self.network(x)
        # Global average pooling over the spatial dims
        z = z.mean(dim=[2, 3])
        # Classify; drop the singleton output dim
        return self.classifier(z)[:, 0]

In [3]:
# Build the model and inspect the custom initialisation of the stem conv and the classifier.
torch.manual_seed(0)  # fixed seed so the printed xavier-normal weights are reproducible on Restart & Run All
net = ConvNet()

# Sanity-check the zero / constant initialisation applied in ConvNet.__init__.
assert torch.all(net.classifier.weight == 0), "classifier weights should be zero-initialised"
assert torch.allclose(net.network[0].bias, torch.full_like(net.network[0].bias, 0.1)), "stem conv bias should be 0.1"

print(net.network[0].weight, net.network[0].bias)
print()
print(net.classifier.weight, net.classifier.bias)

Parameter containing:
tensor([[[[ 4.2394e-02,  2.2789e-02, -7.5122e-04,  ...,  3.8984e-02,
           -5.8399e-02, -1.7611e-02],
          [-1.5585e-02, -8.1914e-03,  3.0027e-02,  ..., -1.8647e-02,
           -6.8906e-02,  4.4575e-02],
          [-3.2298e-02,  3.9990e-03,  1.9327e-02,  ...,  3.4929e-02,
            3.1065e-02,  3.4080e-02],
          ...,
          [-1.5307e-02, -4.4499e-02,  6.3312e-03,  ...,  3.2467e-02,
           -1.9778e-02,  2.2951e-02],
          [-4.4020e-02,  3.2547e-02,  2.2710e-02,  ...,  1.3135e-02,
            4.1214e-02,  3.5381e-02],
          [ 6.4735e-03,  2.4502e-02,  1.0369e-01,  ...,  6.3564e-02,
            8.0561e-03,  9.2223e-03]],

         [[-4.8233e-02, -1.0506e-02,  1.2357e-03,  ..., -4.7673e-02,
            7.2925e-02, -1.0875e-02],
          [ 8.7330e-03,  1.0796e-02, -2.8819e-02,  ...,  3.9095e-02,
            3.4284e-02,  3.6296e-02],
          [ 8.7510e-04,  4.5300e-02, -6.9471e-03,  ..., -1.8138e-02,
           -2.5587e-02, -3.9576e-02]