# ResNet on CIFAR10

In [1]:
'''
Loading necessary libraries
'''
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
'''
Setup parameters
'''
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

n_epochs = 80
n_classes = 10
batch_size = 100
lr = 1e-3

In [7]:
'''
Loading CIDAR10 dataset
'''
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='./data_cifar10', 
                                        train=True,
                                        download=True, 
                                        transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=batch_size,
                                          shuffle=True, 
                                          num_workers=12)

testset = torchvision.datasets.CIFAR10(root='./data_cifar10', 
                                       train=False,
                                       download=True, 
                                       transform=transform)

testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=batch_size,
                                         shuffle=False, 
                                         num_workers=12)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [11]:
'''
Define model class
'''
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = self._conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = self._conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        
        return out
        
    def _conv3x3(self, in_channels, out_channels, stride=1):
        return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                        stride=stride, padding=1, bias=False)
    
class ResNet(nn.Module):
    def __init__(self, block, layers, n_classes):
        super().__init__()
        self.in_channels = 16
        self.conv = self._conv3x3(3,16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, n_classes)
        
    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                self._conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        
        layers = []
        layers += [block(self.in_channels, out_channels, stride, downsample)]
        
        self.in_channels = out_channels
        
        for i in range(1, blocks):
            layers += [block(out_channels, out_channels)]
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out        
    
    def _conv3x3(self, in_channels, out_channels, stride=1):
        return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                        stride=stride, padding=1, bias=False)
    
model = ResNet(ResidualBlock, [2,2,2], n_classes).to(device)

In [14]:
'''
Optimizer and Loss function
'''
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [17]:
'''
Train the model
'''
total_steps = len(trainloader)
current_lr = lr

for epoch in range(n_epochs):
    for i, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # forward
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # print
        if (i+1)%100==0:
            print('Epoch [{}/{}], step [{}/{}], loss {:.4f}'
                 .format(epoch+1, n_epochs, i+1, total_steps, loss.item()))
    
    if (epoch+1)%20==0:
        current_lr /= 3
        update_lr(optimizer, current_lr)

Epoch [1/80], step [100/500], loss 0.4704
Epoch [1/80], step [200/500], loss 0.3051
Epoch [1/80], step [300/500], loss 0.3720
Epoch [1/80], step [400/500], loss 0.2905
Epoch [1/80], step [500/500], loss 0.3336
Epoch [2/80], step [100/500], loss 0.3214
Epoch [2/80], step [200/500], loss 0.3799
Epoch [2/80], step [300/500], loss 0.2835
Epoch [2/80], step [400/500], loss 0.2795
Epoch [2/80], step [500/500], loss 0.3328
Epoch [3/80], step [100/500], loss 0.3199
Epoch [3/80], step [200/500], loss 0.3183
Epoch [3/80], step [300/500], loss 0.4969
Epoch [3/80], step [400/500], loss 0.4103
Epoch [3/80], step [500/500], loss 0.4963
Epoch [4/80], step [100/500], loss 0.4162
Epoch [4/80], step [200/500], loss 0.3033
Epoch [4/80], step [300/500], loss 0.2805
Epoch [4/80], step [400/500], loss 0.1906
Epoch [4/80], step [500/500], loss 0.5452
Epoch [5/80], step [100/500], loss 0.2176
Epoch [5/80], step [200/500], loss 0.3401
Epoch [5/80], step [300/500], loss 0.2871
Epoch [5/80], step [400/500], loss

In [19]:
'''
Test the model
'''
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    print('Accuracy of the model on the test images: {} %'.format(100*correct/total))

Accuracy of the model on the test images: 88.55 %


In [20]:
'''
Save the model
'''
torch.save(model.state_dict(), 'resnet_cifar10.ckpt')