In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms, models
from torch.autograd import Variable

# load data

In [2]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, download=False,
                   transform=transforms.Compose([
                       transforms.RandomResizedCrop(max((299, 299))),
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       normalize
                   ])),
    batch_size=65, num_workers=2, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False,
                   transform=transforms.Compose([
                       transforms.Resize(int(max((299, 299))/224*256)),
                       transforms.CenterCrop(max((299, 299))),
                       transforms.ToTensor(),
                       normalize
                   ])),
    batch_size=65, num_workers=2, shuffle=True)

# train

In [3]:
class BNConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kargs):
        super(BNConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x, inplace=True)
        return x

In [4]:
class BASE(nn.Module):
    def __init__(self):
        super(BASE, self).__init__()
        self.conv2d_bn_1 = BNConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_bn_2 = BNConv2d(32, 32, kernel_size=3)
        self.conv2d_bn_3 = BNConv2d(32, 64, kernel_size=3, padding=1)
        self.conv2d_bn_4 = BNConv2d(64, 80, kernel_size=1)
        self.conv2d_bn_5 = BNConv2d(80, 192, kernel_size=3)

        
    def forward(self, x):
        x = self.conv2d_bn_1(x)
        x = self.conv2d_bn_2(x)
        x = self.conv2d_bn_3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        
        x = self.conv2d_bn_4(x)
        x = self.conv2d_bn_5(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        return x

In [5]:
class NIN_A(nn.Module):
    def __init__(self, in_channels, out_channels_of_pool):
        super(NIN_A, self).__init__()
        self.branch1x1 = BNConv2d(in_channels, 64, kernel_size=1)
        
        self.branch5x5 = nn.Sequential(BNConv2d(in_channels, 48, kernel_size=1),
                                       BNConv2d(48, 64, kernel_size=5, padding=2))
        
        self.branch3x3dbl = nn.Sequential(BNConv2d(in_channels, 64, kernel_size=1),
                                          BNConv2d(64, 96, kernel_size=3, padding=1),
                                          BNConv2d(96, 96, kernel_size=3, padding=1))
        
        self.branch_pool = nn.Sequential(nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
                                         BNConv2d(in_channels, out_channels_of_pool, kernel_size=1))
        
    def forward(self, x):
        out = [self.branch1x1(x),
               self.branch5x5(x), 
               self.branch3x3dbl(x),
               self.branch_pool(x)]
        return torch.cat(out, 1)

In [6]:
class NIN_B(nn.Module):
    def __init__(self, in_channels):
        super(NIN_B, self).__init__()
        self.branch3x3 = BNConv2d(in_channels, 384, kernel_size=3, stride=2)
        
        self.branch3x3dbl = nn.Sequential(BNConv2d(in_channels, 64, kernel_size=1),
                                          BNConv2d(64, 96, kernel_size=3, padding=1),
                                          BNConv2d(96, 96, kernel_size=3, stride=2))
        
    def forward(self, x):
        out = [self.branch3x3(x),
               self.branch3x3dbl(x),
               F.max_pool2d(x, kernel_size=3, stride=2)]
        return torch.cat(out, 1)

In [7]:
class NIN_C(nn.Module):
    def __init__(self, in_channels, in_channels_7x7):
        super(NIN_C, self).__init__()
        self.branch1x1 = BNConv2d(in_channels, 192, kernel_size=1)
        
        self.branch7x7 = nn.Sequential(BNConv2d(in_channels, in_channels_7x7, kernel_size=1),
                                       BNConv2d(in_channels_7x7, in_channels_7x7,
                                                kernel_size=(1,7), padding=(0,3)),
                                       BNConv2d(in_channels_7x7, 192,
                                                kernel_size=(7,1), padding=(3,0)))
        
        self.branch7x7dbl = nn.Sequential(BNConv2d(in_channels, in_channels_7x7, kernel_size=1),
                                          BNConv2d(in_channels_7x7, in_channels_7x7,
                                                   kernel_size=(7,1), padding=(3,0)),
                                          BNConv2d(in_channels_7x7, in_channels_7x7,
                                                   kernel_size=(1,7), padding=(0,3)),
                                          BNConv2d(in_channels_7x7, in_channels_7x7,
                                                   kernel_size=(7,1), padding=(3,0)),
                                          BNConv2d(in_channels_7x7, 192,
                                                   kernel_size=(1,7), padding=(0,3)))
        
        self.branch_pool = nn.Sequential(nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
                                         BNConv2d(in_channels, 192, kernel_size=1))
        
    def forward(self, x):
        out = [self.branch1x1(x),
               self.branch7x7(x), 
               self.branch7x7dbl(x),
               self.branch_pool(x)]
        return torch.cat(out, 1)

In [8]:
class NIN_D(nn.Module):
    def __init__(self, in_channels):
        super(NIN_D, self).__init__()
        self.branch3x3 = nn.Sequential(BNConv2d(in_channels, 192, kernel_size=1),
                                       BNConv2d(192, 320, kernel_size=3, stride=2))
        
        self.branch7x7x3 = nn.Sequential(BNConv2d(in_channels, 192, kernel_size=1),
                                         BNConv2d(192, 192, kernel_size=(1,7), padding=(0,3)),
                                         BNConv2d(192, 192, kernel_size=(7,1), padding=(3,0)),
                                         BNConv2d(192, 192, kernel_size=3, stride=2))
        
    def forward(self, x):
        out = [self.branch3x3(x),
               self.branch7x7x3(x),
               F.max_pool2d(x, kernel_size=3, stride=2)]
        return torch.cat(out, 1)

In [9]:
class NIN_E(nn.Module):
    def __init__(self, in_channels):
        super(NIN_E, self).__init__()
        self.branch1x1 = BNConv2d(in_channels, 320, kernel_size=1)
        
        self.branch3x3_1 = BNConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BNConv2d(384, 384, kernel_size=(1,3), padding=(0,1))
        self.branch3x3_2b = BNConv2d(384, 384, kernel_size=(3,1), padding=(1,0))
        
        self.branch3x3dbl_1 = BNConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BNConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BNConv2d(384, 384, kernel_size=(1,3), padding=(0,1))
        self.branch3x3dbl_3b = BNConv2d(384, 384, kernel_size=(3,1), padding=(1,0))
        
        self.branch_pool = nn.Sequential(nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
                                         BNConv2d(in_channels, 192, kernel_size=1))
        
    def forward(self, x):   
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = torch.cat(branch3x3, 1)
        
        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)
        
        out = [self.branch1x1(x),
               branch3x3,
               branch3x3dbl,
               self.branch_pool(x)]
        return torch.cat(out, 1)

In [10]:
class Inception(nn.Module):
    def __init__(self):
        super(Inception, self).__init__()
        self.base = BASE()
        self.mixed0 = NIN_A(in_channels=192, out_channels_of_pool=32)
        self.mixed1 = NIN_A(in_channels=256, out_channels_of_pool=64)
        self.mixed2 = NIN_A(in_channels=288, out_channels_of_pool=64)

        self.mixed3 = NIN_B(in_channels=288)

        self.mixed4 = NIN_C(768, in_channels_7x7=128)
        self.mixed5 = NIN_C(768, in_channels_7x7=160)
        self.mixed6 = NIN_C(768, in_channels_7x7=160)
        self.mixed7 = NIN_C(768, in_channels_7x7=192)

        self.mixed8 = NIN_D(768)

        self.mixed9 = NIN_E(1280)
        self.mixed10 = NIN_E(2048)
        self.fc = nn.Linear(2048, 1000)

    def forward(self, x):
        x = self.base(x)  # 35 x 35 x 192

        x = self.mixed0(x)  # 35 x 35 x 256
        x = self.mixed1(x)  # 35 x 35 x 288
        x = self.mixed2(x)  # 35 x 35 x 288

        x = self.mixed3(x)  # 17 x 17 x 768

        x = self.mixed4(x)  # 17 x 17 x 768
        x = self.mixed5(x)  # 17 x 17 x 768
        x = self.mixed6(x)  # 17 x 17 x 768
        x = self.mixed7(x)  # 17 x 17 x 768

        x = self.mixed8(x)  # 8 x 8 x 1280

        x = self.mixed9(x)  # 8 x 8 x 2048
        x = self.mixed10(x)  # 8 x 8 x 2048

        x = F.avg_pool2d(x, kernel_size=8)
        x = F.dropout(x, training=self.training)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [11]:
model = Inception().cuda()
model

Inception(
  (base): BASE(
    (conv2d_bn_1): BNConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True)
    )
    (conv2d_bn_2): BNConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True)
    )
    (conv2d_bn_3): BNConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True)
    )
    (conv2d_bn_4): BNConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True)
    )
    (conv2d_bn_5): BNConv2d(
      (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True)
    )
  )
  (mixed0): NIN_A(
    (branch1x1): BNConv2d(
      (conv): Conv2d(

In [23]:
optimizer = optim.SGD(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()
loss_func.cuda()

def train(train_loader):
    model.train()
    train_loss = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.cuda(), y.cuda()
        b_x = Variable(x)
        b_y = Variable(y)
        
        optimizer.zero_grad()
        outputs = model(b_x)
        
        loss = loss_func(outputs, b_y)
        train_loss += loss.data[0]

        loss.backward()
        optimizer.step()

        if batch_idx == 100:
            break

    train_loss = train_loss / len(train_loader)
    
    return train_loss

def valid(test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    for i, (x, y) in enumerate(test_loader):
        x, y = x.cuda(), y.cuda()
        b_x = Variable(x, volatile=True)
        b_y = Variable(y, volatile=True)
        
        outputs = model(b_x)
        loss = loss_func(outputs, b_y)
        test_loss += loss.data[0]
        
        pred = outputs.data.max(1, keepdim=True)[1]
        correct += pred.eq(b_y.data.view_as(pred)).cpu().sum()
        total += b_y.size(0)

        if i == 10:
            break

    val_loss = test_loss / len(test_loader)
    val_acc = correct / total
    
    return val_loss, val_acc

In [24]:
epochs = 1
loss_list = []
val_loss_list = []
val_acc_list = []
for epoch in range(epochs):
    loss = train(train_loader)
    val_loss, val_acc = valid(test_loader)

    print('epoch %d, loss: %.4f val_loss: %.4f val_acc: %.4f'
          % (epoch, loss, val_loss, val_acc))
    
    # logging
    loss_list.append(loss)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)

epoch 0, loss: 0.5273 val_loss: 0.2450 val_acc: 0.1986
