In [1]:
SEED = 11

In [2]:
import torch
import matplotlib.pyplot as plt
from IPython.display import display, Math

In [3]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

In [4]:
def conv_1_block(in_chn, out_chn, act_fn, stride = 1): # kernel size = 1
    model = nn.Sequential(
        nn.Conv2d(in_chn, out_chn, kernel_size = 1, stride = stride),
        act_fn,
    )
    return(model)

def conv_3_block(in_chn, out_chn, act_fn):
    model = nn.Sequential(
        nn.Conv2d(in_chn, out_chn, kernel_size = 3, padding = 1),
        act_fn,
    )
    return(model)

class bottle_neck(nn.Module):
    def __init__(self, in_chn, mid_chn, out_chn, act_fn, down_shape = False, num_classes = 10):
        super(bottle_neck, self).__init__()
        self.act_fn = act_fn
        self.down_shape = down_shape
        if self.down_shape:
            self.layer = nn.Sequential(
                conv_1_block(in_chn, mid_chn, act_fn, stride = 2),
                conv_3_block(mid_chn, mid_chn, act_fn),
                conv_1_block(mid_chn, out_chn, act_fn),
            )
            self.downsample = nn.Conv2d(in_chn, out_chn, kernel_size = 1, stride = 2) # reduce shape
        else:
            self.layer = nn.Sequential(
                conv_1_block(in_chn, mid_chn, act_fn, stride = 1),
                conv_3_block(mid_chn, mid_chn, act_fn),
                conv_1_block(mid_chn, out_chn, act_fn),
            )
        self.dim_equalizer = nn.Conv2d(in_chn, out_chn, kernel_size = 1)

    def forward(self, x):
        if self.down_shape: # down sampling (bottle neck)
            out = self.layer(x) # normal connection
            downsample = self.downsample(x) # residual skip connection
            out = out + downsample
        else:
            out = self.layer(x) # normal connection
            if x.size() is not out.size:
                x = self.dim_equalizer(x) # residual skip connection
            out = out + x
        return(out)

In [5]:
class my_resnet(nn.Module):
    def __init__(self, in_chn, base_chn, act_fn, num_classes = 10):
        super(my_resnet, self).__init__()
        self.act_fn = act_fn
        self.layer_1 = nn.Sequential(
            nn.Conv2d(in_chn, base_chn, kernel_size = 7, stride = 2, padding = 3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
        )
        
        self.layer_2 = nn.Sequential(
            bottle_neck(base_chn, base_chn, base_chn*2, self.act_fn),
            bottle_neck(base_chn*2, base_chn, base_chn*2, self.act_fn),
            bottle_neck(base_chn*2, base_chn, base_chn*2, self.act_fn, down_shape = True)
        )
        
        self.layer_3 = nn.Sequential(
            bottle_neck(base_chn*2, base_chn, base_chn*4, self.act_fn),
            bottle_neck(base_chn*4, base_chn*2, base_chn*4, self.act_fn),
            bottle_neck(base_chn*4, base_chn*2, base_chn*4, self.act_fn),
            bottle_neck(base_chn*4, base_chn*2, base_chn*8, self.act_fn, down_shape = True)
        )
    
        self.maxpool = nn.MaxPool2d(kernel_size = 2, stride = 1)
        self.fc_layer = nn.Linear(base_chn*8, num_classes)
        
    def forward(self, x):
        out = self.layer_1(x)
        out = self.layer_2(out)
        out = self.layer_3(out)
        out = self.maxpool(out)
        out = out.view(batch_size, -1)
        out = self.fc_layer(out)
        return(out)


In [6]:
mnist_train = datasets.MNIST("./", train=True,
                            transform = transforms.ToTensor(),
                            target_transform=None,
                            download=True)

In [7]:
mnist_test = datasets.MNIST("./", train=False,
                            transform = transforms.ToTensor(),
                            target_transform=None,
                            download=True)

In [8]:
device = torch.device("cuda:0")
model = my_resnet(in_chn = 1, base_chn = 16, act_fn = nn.ReLU(), num_classes=len(mnist_train.classes)).to(device)


In [9]:
batch_size = 128
learning_rate = 0.0002
num_epoch = 10

In [10]:
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

In [11]:
train_loader = DataLoader(mnist_train, batch_size = batch_size,
                         shuffle = True, num_workers = 2,
                         drop_last = True)

In [12]:
test_loader = DataLoader(mnist_test, batch_size = batch_size,
                         shuffle = False, num_workers = 2,
                         drop_last = True)

In [13]:
arr_loss = []
for ii in range(num_epoch):
    for jj, [image, label] in enumerate(train_loader):
        x = image.to(device)
        y_ = label.to(device)
        
        optimizer.zero_grad() # initialize 0 for each data
        output = model.forward(x)
        loss = loss_func(output, y_)
        loss.backward() # calculate back prop (gradient)
        optimizer.step() # update weight
        if jj % 1000 == 0:
            print(loss)
            arr_loss.append(loss.cpu().detach().numpy())
        

tensor(2.3007, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.3600, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.1132, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.0903, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.1445, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.1408, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.1022, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.0782, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.0587, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.0971, device='cuda:0', grad_fn=<NllLossBackward>)


In [14]:
btnk = bottle_neck(16, 4,16, nn.ReLU()).to(device)

In [24]:
x.size()

torch.Size([128, 1, 28, 28])

In [19]:
correct = 0
total = 0

with torch.no_grad():
    for image, label in test_loader:
        x = image.to(device)
        y_ = label.to(device)
        output = model.forward(x)
        _, output_index = torch.max(output, 1)
        total += label.size(0)
        correct += (output_index == y_).sum().float()
    print('Test acc: {}'.format(correct/total))

Test acc: 0.9814703464508057
