<center><h1>NiN</h1></center>

<center><p><a href="http://arxiv.org/abs/1312.4400">Network In Network</a></p></center>

<img src="https://gojay.top/gallery/thumbnails/NIN.png" width="1000"/>

In [1]:
import torch
from torch import nn

# NiN Block

In [2]:
def nin_block(in_channels, out_channels, **kwargs):
    """Assemble one NiN block: a spatial convolution followed by two 1x1 convolutions.

    Each of the three convolutions is followed by a ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by every convolution in the block.
        **kwargs: Forwarded unchanged to the first ``nn.Conv2d``
            (e.g. ``kernel_size``, ``stride``, ``padding``).

    Returns:
        An ``nn.Sequential`` holding six modules in the order
        Conv2d, ReLU, Conv2d(1x1), ReLU, Conv2d(1x1), ReLU.
    """
    # Leading convolution: its geometry is decided entirely by the caller.
    layers = [nn.Conv2d(in_channels, out_channels, **kwargs), nn.ReLU()]
    # Two per-pixel (1x1) convolutions that keep the channel count unchanged.
    for _ in range(2):
        layers += [nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU()]
    return nn.Sequential(*layers)

# NiN Network

In [3]:
class NiN(nn.Module):
    """Network-in-Network classifier assembled from ``nin_block`` stages.

    ``features`` stacks three NiN blocks, each followed by a 3x3, stride-2 max
    pool. ``classifier`` applies dropout, a final NiN block whose channel count
    equals ``num_classes``, global average pooling down to 1x1 and a flatten,
    so ``forward`` yields one value per class for every sample. Because the
    final NiN block ends in a ReLU, those values are non-negative.

    Args:
        num_classes: Output channels of the last NiN block, i.e. the length of
            the per-sample output vector.
        init_weights: When true, ``_initialize_weights`` runs at construction.
        dropout: Probability handed to the ``nn.Dropout`` placed in front of
            the classifier block.
    """

    def __init__(self, num_classes=1000, init_weights=True, dropout=0.5):
        super().__init__()
        # The size annotations below assume a 224x224 input.
        self.features = nn.Sequential(
            nin_block(3, 96, kernel_size=11, stride=4, padding=0),  # 224 -> 54
            nn.MaxPool2d(kernel_size=3, stride=2),  # 54 -> 26
            nin_block(96, 256, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 26 -> 12
            nin_block(256, 384, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 12 -> 5
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            # One output channel per class; no fully connected layer is used.
            nin_block(384, num_classes, kernel_size=3, stride=1, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run ``x`` through ``features`` and then ``classifier`` and return the result."""
        x = self.features(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Re-initialise every Conv2d and Linear submodule in place.

        Conv2d weights get Kaiming-normal init (``fan_out``, ReLU gain) and
        Linear weights get Xavier-uniform init. Biases are zeroed when the
        layer has one; layers created with ``bias=False`` are left without one.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # Same guard as the Conv2d branch: a bias-free Linear has
                # ``bias is None`` and must not be passed to ``constant_``.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

# Summary

## Data

In [4]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Random stand-in batch of shape (32, 3, 224, 224), moved to the chosen device.
data = torch.randn(32, 3, 224, 224)
data = data.to(device)

## NiN

In [5]:
from torchkeras import summary

# Build a 1000-class NiN, move it to the selected device and print its
# layer-by-layer summary for the sample batch.
net = NiN(num_classes=1000)
net = net.to(device)
summary(net, input_data=data)

# Drop the reference to the model once the summary has been printed.
del net

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                            [-1, 96, 54, 54]               34,944
ReLU-2                              [-1, 96, 54, 54]                    0
Conv2d-3                            [-1, 96, 54, 54]                9,312
ReLU-4                              [-1, 96, 54, 54]                    0
Conv2d-5                            [-1, 96, 54, 54]                9,312
ReLU-6                              [-1, 96, 54, 54]                    0
MaxPool2d-7                         [-1, 96, 26, 26]                    0
Conv2d-8                           [-1, 256, 26, 26]              614,656
ReLU-9                             [-1, 256, 26, 26]                    0
Conv2d-10                          [-1, 256, 26, 26]               65,792
ReLU-11                            [-1, 256, 26, 26]                    0
Conv2d-12                          [-