In [1]:
import torch

In [2]:
class ConvNet(torch.nn.Module):
    """Small convolutional network producing one output value per input image.

    Pipeline: a 7x7 stride-2 stem conv + ReLU + 3x3 stride-2 max-pool, then one
    stride-2 ``Block`` per entry of ``layers``, global average pooling over the
    two spatial dimensions, and a ``Linear(c, 1)`` head.
    """

    class Block(torch.nn.Module):
        """Two (3x3 conv -> ReLU) stages; only the first conv applies ``stride``."""

        def __init__(self, n_input, n_output, stride=1):
            """
            Args:
                n_input: number of input channels.
                n_output: number of output channels of both convs.
                stride: stride of the first conv (controls spatial downsampling).
            """
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(n_input, n_output, kernel_size=3, padding=1, stride=stride),
                torch.nn.ReLU(),
                torch.nn.Conv2d(n_output, n_output, kernel_size=3, padding=1),
                torch.nn.ReLU(),
            )
            # Only the first conv is re-initialized (Xavier weights, bias 0.1);
            # the second conv keeps PyTorch's default Conv2d initialization.
            torch.nn.init.xavier_normal_(self.net[0].weight)
            torch.nn.init.constant_(self.net[0].bias, 0.1)

        def forward(self, x):
            """Apply the two conv/ReLU stages."""
            return self.net(x)

    def __init__(self, layers=(32, 64, 128), n_input_channels=3):
        """
        Args:
            layers: output channel count of each stride-2 ``Block``, in order.
                Any iterable of ints; a tuple default avoids the shared
                mutable-default-argument pitfall.
            n_input_channels: number of channels of the input images.
        """
        super().__init__()
        stem_channels = 32  # width of the stem conv and input width of the first Block
        modules = [
            torch.nn.Conv2d(n_input_channels, stem_channels, kernel_size=7, padding=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        c = stem_channels
        for n_out in layers:
            modules.append(self.Block(c, n_out, stride=2))
            c = n_out
        self.network = torch.nn.Sequential(*modules)
        # c is now the channel count of the last stage (stem width if layers is empty).
        self.classifier = torch.nn.Linear(c, 1)

        # Zero head weights: at initialization the output equals the classifier
        # bias for every input.
        torch.nn.init.zeros_(self.classifier.weight)
        torch.nn.init.xavier_normal_(self.network[0].weight)
        torch.nn.init.constant_(self.network[0].bias, 0.1)

    def forward(self, x):
        """Return one value per example.

        Args:
            x: a 4-D batch; the spatial mean over dims 2 and 3 requires rank 4
               (presumably (B, n_input_channels, H, W)).

        Returns:
            Tensor of shape (B,): the single classifier output per example.
        """
        features = self.network(x)
        # Global average pooling over the spatial dimensions -> (B, c).
        pooled = features.mean(dim=[2, 3])
        # Drop the trailing size-1 output dimension -> (B,).
        return self.classifier(pooled)[:, 0]

In [3]:
# Seed the global torch RNG so the printed initial weights are reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
net = ConvNet()
print(net.network[0].weight, net.network[0].bias)
print()
print(net.classifier.weight, net.classifier.bias)

Parameter containing:
tensor([[[[ 1.7270e-02, -7.3183e-04, -3.1007e-02,  ...,  7.2382e-02,
           -5.6650e-02,  3.4035e-02],
          [ 3.4279e-02, -1.5811e-02, -4.0903e-02,  ..., -3.3977e-02,
            1.0024e-02,  5.9173e-02],
          [ 3.1137e-02,  1.8381e-02, -4.6479e-02,  ..., -3.4519e-02,
            3.6153e-02,  5.8218e-02],
          ...,
          [ 3.8489e-02,  7.1153e-03, -1.7250e-03,  ...,  2.4888e-02,
           -1.6550e-02, -5.4573e-04],
          [ 3.4834e-02, -3.4391e-02, -5.1466e-02,  ..., -4.7134e-02,
            8.0856e-02,  6.8783e-02],
          [ 1.4183e-02, -5.3272e-02,  2.7214e-03,  ...,  4.6807e-02,
            6.0609e-02, -4.9352e-02]],

         [[-3.0341e-02, -6.1002e-02, -2.6696e-02,  ...,  1.3062e-02,
            1.6921e-02, -1.7774e-02],
          [ 2.1336e-02, -4.2838e-02,  1.3234e-02,  ..., -7.4801e-03,
           -2.2554e-02, -1.7629e-02],
          [-4.9333e-03, -2.8675e-02, -1.5096e-02,  ..., -9.5340e-02,
           -3.4195e-02,  2.3809e-02]