# In [29]:
import torch.nn as nn
import torch
import numpy as np

class NIN(nn.Module):
    """Network-in-Network style CNN that ends in a global average pool.

    The convolutional stack (``classifier``) is three groups of one spatial
    convolution followed by two 1x1 convolutions.  The first two groups end
    with a stride-2 pooling layer and dropout; the last group emits
    ``num_classes`` channels and is followed by an 8x8 average pool.
    ``forward`` then collapses whatever spatial extent is left with an
    adaptive average pool, so the result is always ``(batch, num_classes)``.

    The input must be large enough that the feature map reaching the 8x8
    average pool is at least 8x8 (a 32x32 input works).

    Args:
        num_classes: Number of output channels of the last convolution and
            therefore the width of the tensor returned by ``forward``.

    Attributes:
        classifier: The convolutional stack described above.
        AdaptiveAvgPool: Global ``(1, 1)`` average pool used by ``forward``.
        _to_linear: Flattened per-sample size of the ``classifier`` output
            for a 3x224x224 input (``num_classes * 49 * 49``).
        fc: ``nn.Linear(_to_linear, 12)``.  It is NOT used by ``forward``.
    """

    def __init__(self, num_classes):
        super(NIN, self).__init__()
        self.num_classes = num_classes

        self.classifier = nn.Sequential(
            # Group 1: 5x5 conv + two 1x1 convs, then halve the resolution.
            nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Dropout(0.5),

            # Group 2: 5x5 conv + two 1x1 convs, then halve the resolution.
            nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            nn.Dropout(0.5),

            # Group 3: 3x3 conv + two 1x1 convs down to one channel per class.
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, self.num_classes, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((1, 1))

        # Probe the stack once with a dummy 224x224 image to learn the
        # flattened feature size.  Only the shape matters, so run the probe
        # without autograd: this avoids keeping every 224x224 activation
        # alive for a backward pass that never happens.
        self._to_linear = None
        with torch.no_grad():
            self.convs(torch.randn(1, 3, 224, 224))

        # NOTE: ``fc`` is not called by ``forward`` and its output width is
        # hard-coded to 12 rather than ``num_classes``.  It is kept so the
        # module's attributes and ``state_dict`` keys stay unchanged.
        self.fc = nn.Linear(self._to_linear, 12)

    def convs(self, x):
        """Run ``classifier`` on ``x`` and cache the flattened feature size.

        On the first call (while ``_to_linear`` is still ``None``) the product
        of the three per-sample output dimensions is stored in
        ``_to_linear``; later calls leave it untouched.

        Args:
            x: Batched image tensor accepted by ``classifier``.

        Returns:
            The un-flattened output of ``classifier``.
        """
        x = self.classifier(x)
        if self._to_linear is None:
            sample_shape = x[0].shape
            self._to_linear = sample_shape[0] * sample_shape[1] * sample_shape[2]
        return x

    def forward(self, x):
        """Return per-class scores of shape ``(batch, num_classes)``.

        Args:
            x: Batched image tensor with 3 channels.

        Returns:
            The ``classifier`` output, globally average-pooled to 1x1 and
            flattened from dimension 1 onwards.
        """
        x = self.classifier(x)
        x = self.AdaptiveAvgPool(x)
        return torch.flatten(x, 1)
    
class NIN_v2(nn.Module):
    """Network-in-Network CNN written out layer by layer.

    The layer sequence mirrors the ``classifier`` stack of ``NIN`` with two
    differences visible in this class:

    * ``conv2_2_3`` is a single 1x1 convolution that ``forward`` applies
      twice in a row, so both positions share the same weights.
    * There is no adaptive pooling: the output is the flattened result of
      the final 8x8 average pool.  It has shape ``(batch, num_classes)``
      only when that pool reduces the map to 1x1 (e.g. for 32x32 inputs).

    Args:
        num_classes: Number of output channels of the last convolution.
        verbose: When true, ``forward`` prints the tensor size after every
            convolution and pooling layer (twelve lines per call).  Off by
            default so that normal training/inference does not write to
            stdout on every forward pass.
    """

    def __init__(self, num_classes, *, verbose=False):
        super(NIN_v2, self).__init__()
        self.num_classes = num_classes
        self.verbose = verbose

        # Group 1
        self.conv1_1 = nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2)
        self.conv1_2 = nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0)
        self.conv1_3 = nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0)

        # Group 2 (conv2_2_3 is used for both 1x1 positions in forward)
        self.conv2_1 = nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2)
        self.conv2_2_3 = nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0)

        # Group 3
        self.conv3_1 = nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0)
        self.conv3_3 = nn.Conv2d(192, self.num_classes, kernel_size=1, stride=1, padding=0)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool_1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool_2 = nn.AvgPool2d(kernel_size=8, stride=1, padding=0)

        # Stateless layers shared by every position that needs them.
        self.activation = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def _trace(self, x):
        """Print ``x.size()`` when ``verbose`` is set, then return ``x``."""
        if self.verbose:
            print(x.size())
        return x

    def _conv_relu(self, conv, x):
        """Apply ``conv``, trace the pre-activation size, then apply ReLU."""
        return self.activation(self._trace(conv(x)))

    def forward(self, x):
        """Run the network and return the flattened final feature map.

        Args:
            x: Batched image tensor with 3 channels.

        Returns:
            The output of the last average pool flattened from dimension 1
            onwards.
        """
        # Group 1: three conv+ReLU layers, max-pool, dropout.
        x = self._conv_relu(self.conv1_1, x)
        x = self._conv_relu(self.conv1_2, x)
        x = self._conv_relu(self.conv1_3, x)
        x = self.dropout(self._trace(self.maxpool(x)))

        # Group 2: the same 1x1 convolution is applied twice (shared weights).
        x = self._conv_relu(self.conv2_1, x)
        x = self._conv_relu(self.conv2_2_3, x)
        x = self._conv_relu(self.conv2_2_3, x)
        x = self.dropout(self._trace(self.avgpool_1(x)))

        # Group 3: reduce to one channel per class, then 8x8 average pool.
        x = self._conv_relu(self.conv3_1, x)
        x = self._conv_relu(self.conv3_2, x)
        x = self._conv_relu(self.conv3_3, x)
        x = self._trace(self.avgpool_2(x))

        return torch.flatten(x, 1)

        
    
# Smoke test: push one random 224x224 RGB image through a 12-class NIN and
# print the size of the result.
_pixels = np.random.rand(1, 3, 224, 224)
# NumPy produces float64; convert to torch's default floating dtype.
img = torch.from_numpy(_pixels).to(torch.get_default_dtype())
del _pixels
nin = NIN(12)
out = nin(img)
print(out.size())

# Out: torch.Size([1, 12])
