In [1]:
%run utils

In [None]:
import torch
from utils import ConvRelu
from torch import nn

class Inception(nn.Module):
    """Run four parallel branches over one input and join them on dim 1.

    Every branch receives the same tensor ``x``; the four results are
    concatenated along dimension 1 in the order branch1..branch4.

    Args:
        in_channels: first argument handed to the entry ``ConvRelu`` of
            every branch.
        ch1x1: second argument of the single ``conv_size=1`` ConvRelu in
            branch 1.
        ch3x3: indexable pair; ``ch3x3[0]`` sizes the ``conv_size=1`` stage
            and ``ch3x3[1]`` the ``conv_size=3`` stage of branch 2.
        ch5x5: indexable pair; ``ch5x5[0]`` sizes the ``conv_size=1`` stage
            and ``ch5x5[1]`` the ``conv_size=5`` stage of branch 3.
        pool_proj: second argument of the ``conv_size=1`` ConvRelu that
            follows the max-pool in branch 4.

    NOTE(review): ``ConvRelu`` comes from ``utils`` and is not visible here;
    presumably it is a convolution followed by ReLU taking
    (in_channels, out_channels) -- confirm against its definition.
    """

    def __init__(self, in_channels, ch1x1, ch3x3, ch5x5, pool_proj):
        super().__init__()

        # Branch 1: a lone size-1 ConvRelu.
        self.branch1 = ConvRelu(in_channels, ch1x1, conv_size=1)

        # Branch 2: size-1 stage feeding a size-3 stage (padding=1).
        squeeze_3x3 = ConvRelu(in_channels, ch3x3[0], conv_size=1)
        expand_3x3 = ConvRelu(ch3x3[0], ch3x3[1], conv_size=3, padding=1)
        self.branch2 = nn.Sequential(squeeze_3x3, expand_3x3)

        # Branch 3: size-1 stage feeding a size-5 stage (padding=2).
        squeeze_5x5 = ConvRelu(in_channels, ch5x5[0], conv_size=1)
        expand_5x5 = ConvRelu(ch5x5[0], ch5x5[1], conv_size=5, padding=2)
        self.branch3 = nn.Sequential(squeeze_5x5, expand_5x5)

        # Branch 4: 3x3 max-pool (stride 1, padding 1) then a size-1 ConvRelu.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        projection = ConvRelu(in_channels, pool_proj, conv_size=1)
        self.branch4 = nn.Sequential(pool, projection)

    def forward(self, x):
        """Apply each branch to ``x`` and concatenate the results on dim 1."""
        paths = (self.branch1, self.branch2, self.branch3, self.branch4)
        # All branch outputs must agree on every dimension except dim 1
        # for the concatenation to succeed.
        return torch.cat([path(x) for path in paths], dim=1)