In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim

In [None]:
# Vamos a implementar la arquitectura "Inception" empleada en GoogLeNet (Szegedy et al., 2014)
# Al tratarse de una arquitectura más amplia, vamos a modularizarla en clases
class CNN_Bloque(nn.Module):
    """Convolution -> BatchNorm2d -> ReLU building block.

    Args:
        num_entradas: number of input channels of the convolution.
        num_salidas: number of output channels (also the BatchNorm feature count).
        kernel_size: convolution kernel size (int or tuple).
        stride: convolution stride. Defaults to 1 so that the 1x1 / "same"
            convolutions used inside the Inception modules need not repeat it.
        padding: zero-padding added to both sides of the input. Defaults to 0.
    """

    def __init__(self, num_entradas, num_salidas, kernel_size, stride=1, padding=0):

        super().__init__()

        self.activacion = nn.ReLU()

        self.conv = nn.Conv2d(
            in_channels=num_entradas,
            out_channels=num_salidas,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.norm = nn.BatchNorm2d(num_features=num_salidas)

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU, in that order."""
        return self.activacion(self.norm(self.conv(x)))

class Inception_Bloque(nn.Module):
    """Inception module with dimension reductions (GoogLeNet, Szegedy et al.).

    Four parallel branches are applied to the same input and their outputs are
    concatenated along dim 1 (channels), giving
    ``sal_1x1 + sal_3x3 + sal_5x5 + sal_1x1_pool`` output channels. The
    argument names mirror the column labels of Table 1 in the paper.

    Every branch uses stride 1 with "same" padding, so all four outputs keep
    the input's spatial size, which ``torch.cat`` along channels requires.

    Args:
        num_entradas: number of input channels.
        sal_1x1: output channels of the 1x1 branch.
        sal_3x3_red: channels of the 1x1 reduction before the 3x3 convolution.
        sal_3x3: output channels of the 3x3 branch.
        sal_5x5_red: channels of the 1x1 reduction before the 5x5 convolution.
        sal_5x5: output channels of the 5x5 branch.
        sal_1x1_pool: output channels of the 1x1 projection after max pooling.
    """

    def __init__(self, num_entradas, sal_1x1, sal_3x3_red, sal_3x3, sal_5x5_red, sal_5x5, sal_1x1_pool):

        super().__init__()

        # stride/padding are always passed explicitly so this module does not
        # depend on CNN_Bloque providing defaults for them.
        # Branches are numbered left to right as drawn in the paper.

        # Branch 1: plain 1x1 convolution.
        self.rama1 = CNN_Bloque(
            num_entradas=num_entradas, num_salidas=sal_1x1,
            kernel_size=(1, 1), stride=1, padding=0,
        )

        # Branch 2: 1x1 channel reduction followed by a 3x3 convolution.
        self.rama2 = nn.Sequential(
            CNN_Bloque(num_entradas=num_entradas, num_salidas=sal_3x3_red,
                       kernel_size=(1, 1), stride=1, padding=0),
            CNN_Bloque(num_entradas=sal_3x3_red, num_salidas=sal_3x3,
                       kernel_size=(3, 3), stride=1, padding=1),
        )

        # Branch 3: 1x1 channel reduction followed by a 5x5 convolution.
        self.rama3 = nn.Sequential(
            CNN_Bloque(num_entradas=num_entradas, num_salidas=sal_5x5_red,
                       kernel_size=(1, 1), stride=1, padding=0),
            CNN_Bloque(num_entradas=sal_5x5_red, num_salidas=sal_5x5,
                       kernel_size=(5, 5), stride=1, padding=2),
        )

        # Branch 4: 3x3 max pooling (size-preserving) followed by a 1x1 projection.
        self.rama4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 3), stride=1, padding=1),
            CNN_Bloque(num_entradas=num_entradas, num_salidas=sal_1x1_pool,
                       kernel_size=(1, 1), stride=1, padding=0),
        )

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them along channels."""
        ramas = (self.rama1, self.rama2, self.rama3, self.rama4)
        return torch.cat([rama(x) for rama in ramas], dim=1)

class GoogleLeNet(nn.Module):
    """GoogLeNet (Inception v1) classifier following Table 1 of Szegedy et al.

    The paper's auxiliary classifiers are not implemented.

    Args:
        num_entradas: number of channels of the input images (e.g. 3 for RGB).
        num_clases: number of output classes. Defaults to 1000 (the value
            previously hard-coded, matching ImageNet).
    """

    def __init__(self, num_entradas, num_clases=1000):

        super().__init__()

        self.conv1 = CNN_Bloque(num_entradas=num_entradas, num_salidas=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        # NOTE(review): Table 1 lists a 1x1 reduction (64 channels) before this
        # 3x3 convolution; it is omitted here.
        self.conv2 = CNN_Bloque(num_entradas=64, num_salidas=192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # Every stage ends with the same stride-2 max pooling. MaxPool2d has no
        # parameters, so a single instance is shared.
        self.pool2 = self.pool1

        # Input channel counts equal the sum of the previous module's four branch outputs.
        self.incep3a = Inception_Bloque(num_entradas=192, sal_1x1=64, sal_3x3_red=96, sal_3x3=128, sal_5x5_red=16, sal_5x5=32, sal_1x1_pool=32)
        self.incep3b = Inception_Bloque(num_entradas=256, sal_1x1=128, sal_3x3_red=128, sal_3x3=192, sal_5x5_red=32, sal_5x5=96, sal_1x1_pool=64)
        self.pool3 = self.pool1

        self.incep4a = Inception_Bloque(num_entradas=480, sal_1x1=192, sal_3x3_red=96, sal_3x3=208, sal_5x5_red=16, sal_5x5=48, sal_1x1_pool=64)
        self.incep4b = Inception_Bloque(num_entradas=512, sal_1x1=160, sal_3x3_red=112, sal_3x3=224, sal_5x5_red=24, sal_5x5=64, sal_1x1_pool=64)
        self.incep4c = Inception_Bloque(num_entradas=512, sal_1x1=128, sal_3x3_red=128, sal_3x3=256, sal_5x5_red=24, sal_5x5=64, sal_1x1_pool=64)
        self.incep4d = Inception_Bloque(num_entradas=512, sal_1x1=112, sal_3x3_red=144, sal_3x3=288, sal_5x5_red=32, sal_5x5=64, sal_1x1_pool=64)
        self.incep4e = Inception_Bloque(num_entradas=528, sal_1x1=256, sal_3x3_red=160, sal_3x3=320, sal_5x5_red=32, sal_5x5=128, sal_1x1_pool=128)
        self.pool4 = self.pool1

        self.incep5a = Inception_Bloque(num_entradas=832, sal_1x1=256, sal_3x3_red=160, sal_3x3=320, sal_5x5_red=32, sal_5x5=128, sal_1x1_pool=128)
        self.incep5b = Inception_Bloque(num_entradas=832, sal_1x1=384, sal_3x3_red=192, sal_3x3=384, sal_5x5_red=48, sal_5x5=128, sal_1x1_pool=128)

        # Global average pooling down to 1x1. For a 224x224 input the last
        # feature map is 7x7, where this equals the paper's 7x7 average pooling.
        # For other input sizes it still yields exactly the 1024 features that
        # self.fc expects (a fixed 7x7 window would not).
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.drop_out = nn.Dropout(p=0.4)

        self.fc = nn.Linear(in_features=1024, out_features=num_clases)

    def forward(self, x):
        """Return class logits of shape (batch, num_clases) for the image batch ``x``."""

        x = self.pool1(self.conv1(x))

        x = self.pool2(self.conv2(x))

        x = self.incep3a(x)
        x = self.pool3(self.incep3b(x))

        x = self.incep4a(x)
        x = self.incep4b(x)
        x = self.incep4c(x)
        x = self.incep4d(x)
        x = self.pool4(self.incep4e(x))

        x = self.incep5a(x)
        x = self.avg_pool(self.incep5b(x))

        # (batch, 1024, 1, 1) -> (batch, 1024)
        x = torch.flatten(x, 1)

        x = self.drop_out(x)

        return self.fc(x)