In [0]:
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10

In [0]:
def conv_relu(in_channel, out_channel, kernel, stride=1, padding=0):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by the convolution (and normalized by BN).
        kernel: convolution kernel size.
        stride: convolution stride (default 1).
        padding: zero-padding added on each side (default 0).

    Returns:
        nn.Sequential of (conv, batch-norm with eps=1e-3, in-place ReLU).
    """
    # NOTE(review): the conv keeps its bias even though the following BatchNorm
    # shift makes it redundant; kept so the parameter layout stays unchanged.
    conv = nn.Conv2d(in_channel, out_channel, kernel, stride, padding)
    norm = nn.BatchNorm2d(out_channel, eps=1e-3)
    activation = nn.ReLU(True)  # in-place to save activation memory
    return nn.Sequential(conv, norm, activation)

In [0]:
class inception(nn.Module):
    """Inception module: four parallel branches concatenated along the channel axis.

    Branches, in concatenation order:
      1. branchx1:    1x1 conv                         -> out1_1 channels
      2. branchx2:    1x1 conv (out2_1) -> 3x3 conv    -> out2_3 channels
      3. branchx3:    1x1 conv (out3_1) -> 5x5 conv    -> out3_5 channels
      4. branch_pool: 3x3 max-pool, stride 1 -> 1x1 conv -> out4_1 channels

    Every branch uses padding that keeps H and W unchanged, so the output is
    (N, out1_1 + out2_3 + out3_5 + out4_1, H, W) for an (N, in_channel, H, W) input.
    """

    def __init__(self, in_channel, out1_1, out2_1, out2_3, out3_1, out3_5, out4_1):
        super().__init__()
        # Branch 1: a single pointwise convolution.
        self.branchx1 = conv_relu(in_channel, out1_1, 1)

        # Branch 2: pointwise channel reduction, then a padded 3x3 convolution.
        reduce_3x3 = conv_relu(in_channel, out2_1, 1)
        conv_3x3 = conv_relu(out2_1, out2_3, 3, padding=1)
        self.branchx2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: pointwise channel reduction, then a padded 5x5 convolution.
        reduce_5x5 = conv_relu(in_channel, out3_1, 1)
        conv_5x5 = conv_relu(out3_1, out3_5, 5, padding=2)
        self.branchx3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: size-preserving max-pool, then a pointwise projection.
        pool = nn.MaxPool2d(3, stride=1, padding=1)
        project = conv_relu(in_channel, out4_1, 1)
        self.branch_pool = nn.Sequential(pool, project)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate their outputs on dim 1."""
        branches = (self.branchx1, self.branchx2, self.branchx3, self.branch_pool)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [4]:
# Smoke test: push a 3x96x96 input through one inception module.
# Expected channels = out1_1 + out2_3 + out3_5 + out4_1 = 64 + 64 + 96 + 32 = 256,
# and every branch preserves the 96x96 spatial size.
test_net = inception(3, 64, 48, 64, 64, 96, 32)
test_x = torch.zeros(1, 3, 96, 96)  # plain tensor: the Variable wrapper is deprecated
print('input shape:{} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
with torch.no_grad():  # shape check only, no autograd graph needed
    test_y = test_net(test_x)
print('output shape:{} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))
assert tuple(test_y.shape) == (1, 256, 96, 96), test_y.shape

input shape:3 x 96 x 96
output shape:256 x 96 x 96


In [0]:
class googlenet(nn.Module):
    """GoogLeNet (Inception v1) classifier built from conv_relu and inception blocks.

    Args:
        in_channel: number of channels in the input images.
        num_classes: number of logits produced by the final Linear layer.
        verbose: when True, forward() prints the tensor shape after each block.

    Shapes were checked for 3x96x96 inputs: the five blocks end at 23, 11, 5, 2
    and 1 pixels per side. The ``AvgPool2d(2)`` + ``Linear(1024, ...)`` head
    relies on block 5 ending at 1x1. Other input sizes may leave a larger map,
    which would not match the 1024-wide classifier.
    """

    def __init__(self, in_channel, num_classes, verbose=False):
        super().__init__()
        self.verbose = verbose

        # Stem: 7x7 stride-2 conv followed by a 3x3 stride-2 max-pool.
        self.block1 = nn.Sequential(
            conv_relu(in_channel, out_channel=64, kernel=7, stride=2, padding=3),
            nn.MaxPool2d(3, 2),
        )

        # 1x1 conv, 3x3 conv, then a 3x3 stride-2 max-pool.
        self.block2 = nn.Sequential(
            conv_relu(64, 64, kernel=1),
            conv_relu(64, 192, kernel=3, padding=1),
            nn.MaxPool2d(3, 2),
        )

        # Each inception module outputs out1_1 + out2_3 + out3_5 + out4_1 channels.
        self.block3 = nn.Sequential(
            inception(192, 64, 96, 128, 16, 32, 32),     # 3a -> 256
            inception(256, 128, 128, 192, 32, 96, 64),   # 3b -> 480
            nn.MaxPool2d(3, 2),
        )

        self.block4 = nn.Sequential(
            inception(480, 192, 96, 208, 16, 48, 64),    # 4a -> 512
            inception(512, 160, 112, 224, 24, 64, 64),   # 4b -> 512
            inception(512, 128, 128, 256, 24, 64, 64),   # 4c -> 512
            inception(512, 112, 144, 288, 32, 64, 64),   # 4d -> 528
            inception(528, 256, 160, 320, 32, 128, 128), # 4e -> 832
            nn.MaxPool2d(3, 2),
        )

        self.block5 = nn.Sequential(
            inception(832, 256, 160, 320, 32, 128, 128), # 5a -> 832
            # 5b: the GoogLeNet paper uses a 192-wide 3x3-reduce here (previously
            # mistyped as 182; the output width is 1024 either way).
            inception(832, 384, 192, 384, 48, 128, 128), # 5b -> 1024
            nn.AvgPool2d(2),
        )

        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Run the five blocks, flatten per sample, and return class logits."""
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for idx, stage in enumerate(stages, start=1):
            x = stage(x)
            if self.verbose:
                print('block {} output:{}'.format(idx, x.shape))
        x = torch.flatten(x, 1)  # (N, C, 1, 1) -> (N, C)
        return self.classifier(x)
        

In [11]:
# Smoke test: trace shapes through the full network for one 3x96x96 image.
test_net = googlenet(3, 10, verbose=True)
test_x = torch.zeros(1, 3, 96, 96)  # plain tensor: the Variable wrapper is deprecated
with torch.no_grad():  # shape check only, no autograd graph needed
    test_y = test_net(test_x)
print('output:{}'.format(test_y.shape))
assert tuple(test_y.shape) == (1, 10), test_y.shape

block 1 output:torch.Size([1, 64, 23, 23])
block 2 output:torch.Size([1, 192, 11, 11])
block 3 output:torch.Size([1, 480, 5, 5])
block 4 output:torch.Size([1, 832, 2, 2])
block 5 output:torch.Size([1, 1024, 1, 1])
output:torch.Size([1, 10])


In [12]:
from Myutils import train

def data_tf(x):
    """Turn an image into a normalized float32 CHW tensor at 96x96.

    Steps:
      1. resize to 96x96 with resample code 2 (PIL's bilinear filter);
      2. convert to a float32 array and divide by 255 (assumes 8-bit pixels);
      3. map to [-1, 1] via (v - 0.5) / 0.5;
      4. reorder axes HWC -> CHW (assumes a 3-D, i.e. multi-channel, image).
    """
    resized = x.resize((96, 96), 2)
    pixels = np.array(resized, dtype='float32') / 255
    centered = (pixels - 0.5) / 0.5
    chw = centered.transpose((2, 0, 1))
    return torch.from_numpy(chw)

# Configuration for the CIFAR-10 run (images are upsampled to 96x96 by data_tf).
DATA_ROOT = './data'
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 128
LEARNING_RATE = 0.01
SEED = 0

# Seed torch's global RNG, which drives weight initialization and (by default)
# the DataLoader's shuffling sampler. GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)

train_set = CIFAR10(DATA_ROOT, train=True, transform=data_tf, download=True)
train_data = torch.utils.data.DataLoader(train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True)

test_set = CIFAR10(DATA_ROOT, train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=TEST_BATCH_SIZE, shuffle=False)

# NOTE(review): net is left on the CPU here; presumably Myutils.train moves it to
# the GPU (its warnings mention .cuda()) — confirm against that module.
net = googlenet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified


In [13]:
# `train` comes from the local Myutils module; its 4th positional argument is
# the number of epochs (the log below prints one line per epoch, ~2.5 min each).
NUM_EPOCHS = 20
train(net, train_data, test_data, NUM_EPOCHS, optimizer, criterion)

  im = Variable(im.cuda(), volatile=True)
  label = Variable(label.cuda(), volatile=True)


Epoch 0. Train Loss: 1.540103, Train Acc: 0.436681, Valid Loss: 1.662348, Valid Acc: 0.420886, Time 00:02:30
Epoch 1. Train Loss: 1.079803, Train Acc: 0.616568, Valid Loss: 1.330100, Valid Acc: 0.539359, Time 00:02:42
Epoch 2. Train Loss: 0.863344, Train Acc: 0.697530, Valid Loss: 1.193356, Valid Acc: 0.598794, Time 00:02:40
Epoch 3. Train Loss: 0.700412, Train Acc: 0.757573, Valid Loss: 0.910633, Valid Acc: 0.686017, Time 00:02:42
Epoch 4. Train Loss: 0.584154, Train Acc: 0.800352, Valid Loss: 0.868417, Valid Acc: 0.704905, Time 00:02:40
Epoch 5. Train Loss: 0.484561, Train Acc: 0.832920, Valid Loss: 1.139949, Valid Acc: 0.635581, Time 00:02:41
Epoch 6. Train Loss: 0.411403, Train Acc: 0.858556, Valid Loss: 1.082032, Valid Acc: 0.655854, Time 00:02:40
Epoch 7. Train Loss: 0.347685, Train Acc: 0.880954, Valid Loss: 0.825741, Valid Acc: 0.735364, Time 00:02:41
Epoch 8. Train Loss: 0.288629, Train Acc: 0.900416, Valid Loss: 0.802042, Valid Acc: 0.754153, Time 00:02:40
Epoch 9. Train Loss