# Network in Network

In [2]:
import torch
from torch import nn

To enhance model discriminability for local patches within the receptive field, researchers proposed applying a small MLP to the channel vector of each pixel individually, with the same weights shared across all pixel locations.

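A clever implementation trick is to apply convolution layers with kernel size 1 to the output of a regular convolution: a 1x1 convolution acts as a fully connected layer applied across the channels at each pixel.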

<center>
    <img src='images/nin.svg' width=50% style="margin-left:auto; margin-right:auto"/>
    <p style="font-size:14px;">Source: <a href='http://d2l.ai/'>D2L</a></p>
</center>
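The equivalence behind the 1x1 convolution trick can be checked numerically. The following is a minimal sketch, not part of the original notebook: the channel counts and input size are arbitrary illustrative values, and the comparison simply copies the weights of a 1x1 `nn.Conv2d` into an `nn.Linear` and applies the latter to each pixel's channel vector.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the notebook)
in_channels, out_channels = 3, 5

conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
linear = nn.Linear(in_channels, out_channels)

# Copy the convolution parameters into the linear layer.
# Conv weight has shape (out, in, 1, 1); linear weight has shape (out, in).
with torch.no_grad():
    linear.weight.copy_(conv1x1.weight.reshape(out_channels, in_channels))
    linear.bias.copy_(conv1x1.bias)

x = torch.randn(2, in_channels, 4, 4)  # (batch, channels, height, width)

y_conv = conv1x1(x)
# Move channels last so nn.Linear acts on each pixel's channel vector,
# then move channels back to the second dimension.
y_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(y_conv, y_linear, atol=1e-5))
```

Both paths compute the same per-pixel affine map, which is why stacking 1x1 convolutions with ReLUs in between behaves like an MLP shared across all pixel locations.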

In [4]:
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        # Spatial convolution that extracts local features
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        # Two 1x1 convolutions acting as a per-pixel MLP across channels
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU())

The 1x1 convolutions mix information across the output channels (i.e., the extracted features) at each pixel at a very low cost: each one adds only `out_channels * out_channels` weights plus `out_channels` biases.
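To make the cost of a 1x1 convolution concrete, the sketch below counts parameters for one spatial convolution and one 1x1 convolution. The channel counts and the 5x5 kernel are illustrative assumptions, and `count_params` is a hypothetical helper introduced only for this example.

```python
from torch import nn


def count_params(layer):
    """Total number of learnable parameters in a layer (hypothetical helper)."""
    return sum(p.numel() for p in layer.parameters())


# Illustrative sizes (assumptions, not taken from the notebook)
spatial = nn.Conv2d(96, 256, kernel_size=5, padding=2)
pointwise = nn.Conv2d(256, 256, kernel_size=1)

n_spatial = count_params(spatial)      # 96 * 256 * 5 * 5 weights + 256 biases
n_pointwise = count_params(pointwise)  # 256 * 256 weights + 256 biases

print(n_spatial, n_pointwise)
```

With these sizes, the 1x1 layer has roughly a tenth of the parameters of the 5x5 layer, since its cost does not include the kernel's spatial extent.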

Another distinctive feature of *NiN* is the absence of a fully connected layer at the output. Instead, the number of output channels of the last block equals the number of classes, and a global average pooling layer then reduces each output channel to a single value, producing the logits.
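The output stage described above could be sketched as follows. This is an illustration under stated assumptions rather than the notebook's actual model: `num_classes = 10`, the 384 input channels, and the feature-map sizes are hypothetical values, and the final block is written out inline so that the snippet runs on its own.

```python
import torch
from torch import nn

num_classes = 10  # hypothetical number of classes

head = nn.Sequential(
    # Last NiN-style block: its output channels equal the number of classes
    nn.Conv2d(384, num_classes, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(num_classes, num_classes, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(num_classes, num_classes, kernel_size=1),
    nn.ReLU(),
    # Global average pooling: one value per channel, whatever the spatial size
    nn.AdaptiveAvgPool2d((1, 1)),
    # (batch, num_classes, 1, 1) -> (batch, num_classes)
    nn.Flatten(),
)

# The same head handles feature maps of different spatial sizes
for size in (5, 9):
    features = torch.randn(2, 384, size, size)
    logits = head(features)
    print(size, tuple(logits.shape))
```

Because the pooling layer has no parameters and collapses any spatial size to 1x1, the head contains no `nn.Linear` layer and its parameter count does not depend on the input resolution.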