In [1]:
import os
import time
import importlib
import json
from collections import OrderedDict
import logging
import argparse
import numpy as np
import random
import time
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch.backends.cudnn
import torchvision.utils
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import confusion_matrix

import torchvision.transforms as transforms
import torchvision


In [2]:
class block(nn.Module):
  def __init__(self, in_channels, out_channels, stride = 1, convert = None):
    super(block, self).__init__()
    self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0, bias=False)
    self.bn1 = nn.BatchNorm2d(in_channels)
    self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(in_channels)
    self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
    self.bn3 = nn.BatchNorm2d(out_channels)

    self.relu = nn.ReLU(inplace=True)
    self.convert = convert

  def forward(self, x):
    y = x.clone()
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)
    out = self.conv3(out)
    out = self.bn3(out)

    if self.convert:
      y = self.convert(y)

    out = out + y
    return out

In [3]:
class ResNet(nn.Module):
  def __init__(self, block, layers, classes = 10):
    super(ResNet, self).__init__()
    self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn = nn.BatchNorm2d(16)
    self.relu = nn.ReLU(inplace=True)
    self.in_channel = 16

    self.layer1 = self.make_layer(block, 16, layers[0])
    self.layer2 = self.make_layer(block, 32, layers[1], 2)
    self.layer3 = self.make_layer(block, 64, layers[2], 2)
    
    self.avg_pool = nn.AvgPool2d(8)
    self.fc = nn.Linear(64, classes)

  def make_layer(self, block, out_channel, num_layers, stride = 1):
    layer = []
    conv = None
    if stride != 1 or self.in_channel != out_channel:
      conv = nn.Sequential(
          nn.Conv2d(self.in_channel, 
                    out_channel, 
                    kernel_size=3,
                    stride=stride, 
                    padding=1, 
                    bias=False), 
          nn.BatchNorm2d(out_channel))
    
    layer.append(block(self.in_channel, out_channel, stride, conv))
    self.in_channel = out_channel

    for _ in range(num_layers - 1):
      layer.append(block(self.in_channel, out_channel))

    return nn.Sequential(*layer)
    
  def forward(self, x):
    out = self.conv(x)
    out = self.bn(out)
    out = self.relu(out)
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.avg_pool(out)
    out = out.view(out.size(0), -1)
    out = self.fc(out)
    return out

In [4]:
transform = transforms.Compose([transforms.Pad(4),transforms.RandomHorizontalFlip(),transforms.RandomCrop(32),transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
test_transform  = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])

train_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=True,transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=False,transform=test_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=100, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ../../data/cifar-10-python.tar.gz to ../../data/


In [5]:
depth = 3
epochs = 40
batch_size = 128
base_lr = 0.01
lr_decay = 0.1
milestones = '[80, 120]'
device = "cuda"
num_workers = 3

model = ResNet(block,[10,10,10]).to(device)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=lr_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=lr_decay)

In [7]:
for epoch in range(epochs):
  for i, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (i+1) % 100 == 0:
      print ("Epoch {}, Step {} Loss: {:.4f}".format(epoch+1, i+1, loss.item()))   
  scheduler.step()

Epoch 1, Step 100 Loss: 1.9018
Epoch 1, Step 200 Loss: 1.7809
Epoch 1, Step 300 Loss: 1.4379
Epoch 1, Step 400 Loss: 1.5117
Epoch 1, Step 500 Loss: 1.3800
Epoch 2, Step 100 Loss: 1.2779
Epoch 2, Step 200 Loss: 1.0572
Epoch 2, Step 300 Loss: 1.2797
Epoch 2, Step 400 Loss: 1.1540
Epoch 2, Step 500 Loss: 1.2174
Epoch 3, Step 100 Loss: 0.9010
Epoch 3, Step 200 Loss: 1.0845
Epoch 3, Step 300 Loss: 1.0221
Epoch 3, Step 400 Loss: 0.9947
Epoch 3, Step 500 Loss: 0.9971
Epoch 4, Step 100 Loss: 0.8517
Epoch 4, Step 200 Loss: 0.9917
Epoch 4, Step 300 Loss: 0.8789
Epoch 4, Step 400 Loss: 1.0005
Epoch 4, Step 500 Loss: 0.7615
Epoch 5, Step 100 Loss: 0.8970
Epoch 5, Step 200 Loss: 0.7268
Epoch 5, Step 300 Loss: 0.7909
Epoch 5, Step 400 Loss: 0.6040
Epoch 5, Step 500 Loss: 0.9037
Epoch 6, Step 100 Loss: 0.8072
Epoch 6, Step 200 Loss: 0.7213
Epoch 6, Step 300 Loss: 0.7595
Epoch 6, Step 400 Loss: 0.6179
Epoch 6, Step 500 Loss: 0.5534
Epoch 7, Step 100 Loss: 0.7832
Epoch 7, Step 200 Loss: 0.7528
Epoch 7,

In [9]:
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy ( test images ) : {} %'.format(100 * correct / total))

Accuracy ( test images ) : 81.2 %
