In [1]:
# ! pip install torchsummary -q

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from datetime import datetime
from matplotlib import pyplot as plt
from time import time
from torch.utils.data import DataLoader
from torchsummary import summary

# PyTorch version used for this run (the recorded output below is '1.7.1').
torch.version.__version__

'1.7.1'

In [2]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
# ! pip install wandb -qqq
import wandb

# Authenticate with Weights & Biases (prompts for an API key when none is stored yet).
wandb.login()

# One W&B project per day, named like 'INCEPTION_210106'.
PROJECT = f"INCEPTION_{datetime.now():%y%m%d}"
PROJECT

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


'INCEPTION_210106'

In [5]:
batch_size = 64

# Both MNIST splits are downloaded into data/ on first use and converted with ToTensor.
train_dataset = datasets.MNIST(root='data/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='data/', train=False, transform=transforms.ToTensor(), download=True)
# Evaluation does not need shuffling; a fixed order keeps test passes deterministic.
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

5.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw


113.5%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
Done!


In [49]:
def accuracy(model, data_loader):
    """Return the fraction of samples in `data_loader` that `model` classifies correctly.

    The model runs in eval mode without gradient tracking. Its previous
    train/eval mode is restored before returning, so calling this in the middle
    of training leaves no lasting side effect.

    Args:
        model: classifier whose output holds class scores along dim 1.
        data_loader: iterable of (inputs, integer labels) batches; each batch is
            moved to the module-level ``device``.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if `data_loader` yields no samples.
    """
    was_training = model.training
    num_correct = 0
    num_samples = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, y in data_loader:
                x, y = x.to(device), y.to(device)
                y_pred = model(x).argmax(1)
                # .item() keeps the running count a plain int rather than a device tensor.
                num_correct += (y == y_pred).sum().item()
                num_samples += y.size(0)
    finally:
        # Restore the caller's mode even if the forward pass raised.
        model.train(was_training)
    if num_samples == 0:
        raise ValueError('data_loader yielded no samples')
    return num_correct / num_samples

def train(run_name, model, data_loader, num_epochs=30, learning_rate=1e-3):
    """Train `model` with Adam and cross-entropy, logging to a new W&B run.

    A run named ``<run_name>_<HHMMSS>`` is created in the W&B project ``PROJECT``.
    The loss of every batch is logged. After each epoch, ReduceLROnPlateau
    (factor 0.1, patience 5) is stepped on the mean epoch loss, and the training
    accuracy is printed. The W&B run is finished on exit, even if training fails.

    Args:
        run_name: prefix of the W&B run name.
        model: classifier expected to already live on ``device``; this function
            only moves the batches, not the model.
        data_loader: iterable of (inputs, integer labels) training batches.
        num_epochs: number of passes over ``data_loader``.
        learning_rate: initial Adam learning rate.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    name = '{}_{}'.format(run_name, datetime.now().strftime('%H%M%S'))
    wandb.init(project=PROJECT, name=name)
    try:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)
        loss_fun = nn.CrossEntropyLoss()

        for epoch in range(num_epochs):
            model.train()
            start_time = time()
            losses = []
            for x, y in data_loader:
                x, y = x.to(device), y.to(device)
                scores = model(x)
                loss = loss_fun(scores, y)
                # Log a plain Python float rather than the graph-attached loss tensor.
                loss_value = loss.item()
                losses.append(loss_value)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                wandb.log({'loss': loss_value})
            if not losses:
                raise ValueError('data_loader yielded no batches')
            mean_loss = sum(losses) / len(losses)
            scheduler.step(mean_loss)

            # Time only the optimisation pass, not the accuracy evaluation below.
            epoch_time = time() - start_time
            print(f'epoch {epoch}, {epoch_time:.1f}s: train acc {accuracy(model, data_loader):.1%}')
    finally:
        # Close the run so its data is flushed and later cells do not log into it.
        wandb.finish()

In [50]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, with shape and parameter bookkeeping.

    Attributes:
        in_channels, out_channels: channel counts of the convolution.
        in_size: spatial input size as a numpy array.
        out_size: spatial output size, (in_size + 2 * padding - kernel_size) // stride + 1.
        num_params: learnable parameter count, out_channels * (in_channels * kernel_size**2 + 3).
    """

    def __init__(self, in_channels, out_channels, in_size, kernel_size=3, stride=1, padding=0):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.in_size = np.array(in_size)
        padded = self.in_size + 2 * padding
        self.out_size = (padded - kernel_size) // stride + 1
        weights_per_filter = in_channels * kernel_size * kernel_size
        # +3 per output channel: the conv bias plus BatchNorm's weight and bias.
        self.num_params = out_channels * (weights_per_filter + 3)

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        convolved = self.conv(x)
        normalized = self.batchnorm(convolved)
        return self.relu(normalized)

class InceptionBlock(nn.Module):
    """Four parallel branches concatenated along the channel dimension.

    Branch output widths, as fractions of ``out_channels``:
      * 1x1 conv                                  -> out_channels // 4
      * 1x1 conv (// 2) -> 3x3 conv               -> out_channels // 2
      * 1x1 conv (// 16) -> 5x5 conv              -> out_channels // 8
      * 3x3 max-pool (stride 1) -> 1x1 conv       -> out_channels // 8
    Every branch preserves the spatial size, so ``out_size == in_size``.

    Attributes:
        in_size, out_size: spatial sizes as numpy arrays.
        num_params: number of learnable parameters across all branches.

    Raises:
        ValueError: if ``out_channels`` is not a multiple of 8 that is at least 16.
            The branch widths only sum to ``out_channels`` for multiples of 8, and
            the 5x5 branch's reduction layer would have zero channels below 16.
    """

    def __init__(self, in_channels, out_channels, in_size):
        super().__init__()
        if out_channels % 8 != 0 or out_channels < 16:
            raise ValueError(
                f'out_channels must be a multiple of 8 and at least 16, got {out_channels}')
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Plain arrays: a trailing comma here would silently wrap them in a 1-tuple.
        self.in_size = np.array(in_size)
        self.out_size = np.array(in_size)

        self.branch1 = ConvBlock(in_channels, out_channels // 4, self.in_size, kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, out_channels // 2, self.in_size, kernel_size=1),
            ConvBlock(out_channels // 2, out_channels // 2, self.in_size, kernel_size=3, padding=1)
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, out_channels // 16, self.in_size, kernel_size=1),
            ConvBlock(out_channels // 16, out_channels // 8, self.in_size, kernel_size=5, padding=2),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channels, out_channels // 8, self.in_size, kernel_size=1)
        )
        # Count the registered parameters directly instead of re-deriving every branch by hand.
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, x):
        return torch.cat([self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)], 1)
    
class Inception(nn.Module):
    """Small GoogLeNet-style classifier: stem conv, two pairs of Inception blocks, linear head.

    Pipeline (both pooling layer types use kernel 3, stride 2, padding 1):
        5x5 stem conv (64 ch) -> max-pool -> inc_1, inc_2 -> max-pool
        -> inc_3, inc_4 (64 ch) -> avg-pool -> flatten -> dropout -> linear.

    The model is moved to the module-level ``device`` on construction.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output logits.
        in_size: spatial input size (height, width). It is used to size the
            linear head, which has 1024 input features for 28x28 input.
    """

    def __init__(self, in_channels, num_classes, in_size):
        super().__init__()
        self.in_size = np.array(in_size)
        self.cnv_1 = ConvBlock(in_channels, 64, self.in_size, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Track the spatial size through each stride-2 pooling layer, so the blocks'
        # size bookkeeping and the linear head match the real feature maps.
        size_1 = self._pooled_size(self.in_size)  # after the first max-pool
        size_2 = self._pooled_size(size_1)        # after the second max-pool
        size_3 = self._pooled_size(size_2)        # after the average pool
        self.inc_1 = InceptionBlock(64, 32, size_1)
        self.inc_2 = InceptionBlock(32, 32, size_1)
        self.inc_3 = InceptionBlock(32, 64, size_2)
        self.inc_4 = InceptionBlock(64, 64, size_2)
        self.avgp = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.dout = nn.Dropout(p=0.4)
        # inc_4 emits 64 channels; flattening gives 64 * final spatial area features.
        self.fc = nn.Linear(64 * int(np.prod(size_3)), num_classes)
        self.to(device)

    @staticmethod
    def _pooled_size(size, kernel_size=3, stride=2, padding=1):
        """Spatial output size of a pooling layer with the given geometry."""
        return (np.asarray(size) + 2 * padding - kernel_size) // stride + 1

    def forward(self, x):
        x = self.cnv_1(x)
        x = self.pool(x)
        x = self.inc_1(x)
        x = self.inc_2(x)
        x = self.pool(x)
        x = self.inc_3(x)
        x = self.inc_4(x)
        x = self.avgp(x)
        x = x.reshape(x.shape[0], -1)
        x = self.dout(x)
        x = self.fc(x)
        return x

# Seed PyTorch's RNGs (on all devices) before building the model, so weight
# initialisation, dropout and DataLoader shuffling are repeatable across runs.
# cuDNN kernels may still introduce small run-to-run differences.
SEED = 0
torch.manual_seed(SEED)

model = Inception(1, 10, (28, 28))
# summary(model, (1, 28, 28))
train('TinyInception', model, train_loader)
print(f'test: {accuracy(model, test_loader):.1%}')

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.01602
_step,1875.0
_runtime,82.0
_timestamp,1609944692.0


0,1
loss,█▃▂▂▂▁▁▁▂▂▁▁▁▁▁▂▁▁▁▁▂▁▁▁▁▁▂▁▂▁▁▁▁▁▁▂▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███


epoch 0, 35.6s: 98.6%
epoch 1, 35.5s: 99.0%
epoch 2, 35.6s: 98.9%
epoch 3, 35.4s: 99.5%
epoch 4, 35.2s: 99.6%
epoch 5, 35.2s: 99.4%
epoch 6, 36.4s: 99.5%
epoch 7, 35.5s: 99.5%
epoch 8, 35.1s: 99.4%
epoch 9, 36.4s: 99.7%
epoch 10, 35.9s: 99.8%
epoch 11, 36.4s: 99.7%
epoch 12, 36.5s: 99.4%
epoch 13, 36.0s: 99.7%
epoch 14, 36.3s: 99.7%
epoch 15, 35.1s: 99.7%
epoch 16, 35.5s: 99.8%
epoch 17, 35.9s: 99.8%
epoch 18, 34.8s: 99.8%
epoch 19, 35.6s: 99.8%
epoch 20, 35.8s: 99.8%
epoch 21, 35.2s: 99.9%
epoch 22, 35.4s: 99.9%
epoch 23, 35.5s: 99.9%
epoch 24, 35.7s: 99.9%
epoch 25, 34.8s: 99.9%
epoch 26, 35.5s: 99.9%
epoch 27, 35.6s: 99.9%
epoch 28, 35.0s: 100.0%
epoch 29, 35.8s: 99.9%
test: 99.3%


In [12]:
# `a` is not defined anywhere above, so this cell failed on a fresh kernel.
# Recreate a standalone ConvBlock whose bookkeeping matches the recorded outputs:
# a 3x3 conv without padding maps 28x28 -> 26x26, and 25 * (1 * 3**2 + 3) = 300 parameters.
a = ConvBlock(1, 25, (28, 28))
a.out_size

array([26, 26])

In [13]:
# NOTE(review): relies on `a` from an earlier cell (presumably a ConvBlock, whose
# num_params is out_channels * (in_channels * kernel_size**2 + 3)); confirm it runs on a fresh kernel.
a.num_params

300

In [25]:
# Scratch check: a (height, width) pair converted to an integer tensor.
torch.tensor([28, 28])

tensor([28, 28])