In [1]:
import torch
import torch.nn as nn

class LinearNet(nn.Module):
    """Small CNN whose classifier also receives a random noise vector.

    Two Conv(3x3, padding 1) -> ReLU -> AvgPool(2x2) stages map an image batch
    of shape ``(N, in_channels, input_size, input_size)`` to
    ``128 * (input_size // 4) ** 2`` flattened features per sample. Before the
    MLP classifier, ``noise_dim`` freshly sampled standard-normal values are
    appended to each sample's features. The output is therefore stochastic
    even for identical inputs, and also in ``eval()`` mode.

    With the default arguments the architecture is exactly the original one:
    3x32x32 inputs, a classifier input width of 128 * 8 * 8 + 10 = 8202, and 10
    output logits.

    Args:
        num_classes: Number of output logits.
        noise_dim: Width of the random vector concatenated to the features.
        input_size: Height/width of the (square) input images.
        in_channels: Number of channels in the input images.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two 2x2
            average-pooling stages.
    """

    def __init__(
        self,
        num_classes: int = 10,
        noise_dim: int = 10,
        input_size: int = 32,
        in_channels: int = 3,
    ) -> None:
        super().__init__()

        # Each AvgPool2d(2, 2) floors the spatial size by half, so two stages
        # leave input_size // 4 pixels per side.
        pooled_size = input_size // 4
        if pooled_size < 1:
            raise ValueError(
                f"input_size must be at least 4 to survive two 2x2 poolings, "
                f"got {input_size}"
            )

        self.noise_dim = noise_dim

        self.feature_map = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(128 * pooled_size * pooled_size + noise_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(N, num_classes)`` for the batch ``x``."""
        features = self.feature_map(x)
        # Sample the noise on the same device and in the same dtype as the
        # features. Otherwise torch.cat fails for GPU models, and reduced-
        # precision models feed the Linear layer a promoted dtype.
        noise = torch.randn(
            features.size(0),
            self.noise_dim,
            device=features.device,
            dtype=features.dtype,
        )
        return self.classifier(torch.cat([features, noise], dim=1))

# Build the network and display its layer-by-layer structure.
model = LinearNet()
print(model)

# Smoke-test the forward pass on one synthetic batch:
# 32 samples of 3-channel, 32x32 standard-normal noise.
random_dataset = torch.randn((32, 3, 32, 32))
output = model(random_dataset)


LinearNet(
  (feature_map): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Linear(in_features=8202, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=10, bias=True)
  )
)
