In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
# Device configurations
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

In [3]:
# Hyper-parameters
num_epochs = 2
num_classes= 10
learning_rate = 0.001
batch_size=100

In [4]:
# Image Processing Modules
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

In [5]:
# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train=True, 
                                             transform=transform,  # using transform when training while no transform when testing
                                             download=False)

test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                            train=False, 
                                            transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

In [6]:
# A frequently used conv
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)

In [7]:
# ResNet's famous ResidualBlock
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels=in_channels, out_channels=out_channels, stride=stride)
        self.bn1= nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(in_channels=out_channels, out_channels=out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.downsample = downsample
    
    def forward(self, x): # if this Residual Block has downsampled the inputs
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

In [8]:
# ResNet
class ResNet(nn.Module):
    
    # ResNet is actually composed of ResidualBlocks, whose class prototype is passed in as "block"
    # a layer is consist of ResidualBlocks (layers[k] are num of ResidualBlocks of layer k)
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(in_channels=3, out_channels=16)
        self.bn = nn.BatchNorm2d(num_features=16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block=block, out_channels=16, blocks=layers[0], stride=1)
        self.layer2 = self.make_layer(block=block, out_channels=32, blocks=layers[0], stride=2)
        self.layer3 = self.make_layer(block=block, out_channels=64, blocks=layers[1], stride=2)
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(in_features=64, out_features=num_classes)
        
    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride!=1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(in_channels=self.in_channels, out_channels=out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
                        

In [9]:
# Init Net & Loss Func & Optimizer
model = ResNet(ResidualBlock, [2,2,2,2]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for param in criterion.parameters():
    print(param.shape)

In [10]:
# for updating learning rate
def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [11]:
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i,(images, labels) in enumerate(train_loader):
        # get OUTPUT & GT
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward through net & loss(output, labels)
        outputs = model(images)
        #print(labels)
        #print(outputs)
        loss = criterion(outputs, labels)
        
        # zero & backward & step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        # logging
        if (i+1) % 1 == 0:
            print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch+1, num_epochs, i+1, total_step, loss.item()))
    # Decay learning rate
    if (epoch+1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)


Epoch [1/2], Step [1/500] Loss: 2.3178
Epoch [1/2], Step [2/500] Loss: 2.2689
Epoch [1/2], Step [3/500] Loss: 2.2506
Epoch [1/2], Step [4/500] Loss: 2.1893
Epoch [1/2], Step [5/500] Loss: 2.1937
Epoch [1/2], Step [6/500] Loss: 2.1620
Epoch [1/2], Step [7/500] Loss: 2.1337
Epoch [1/2], Step [8/500] Loss: 2.0788
Epoch [1/2], Step [9/500] Loss: 2.0407
Epoch [1/2], Step [10/500] Loss: 2.0308
Epoch [1/2], Step [11/500] Loss: 2.0681
Epoch [1/2], Step [12/500] Loss: 1.9786
Epoch [1/2], Step [13/500] Loss: 1.9680
Epoch [1/2], Step [14/500] Loss: 2.0013
Epoch [1/2], Step [15/500] Loss: 1.9849
Epoch [1/2], Step [16/500] Loss: 1.8921
Epoch [1/2], Step [17/500] Loss: 1.9099
Epoch [1/2], Step [18/500] Loss: 1.9642
Epoch [1/2], Step [19/500] Loss: 1.9957
Epoch [1/2], Step [20/500] Loss: 1.9018
Epoch [1/2], Step [21/500] Loss: 1.9149
Epoch [1/2], Step [22/500] Loss: 1.9803
Epoch [1/2], Step [23/500] Loss: 2.0222
Epoch [1/2], Step [24/500] Loss: 1.9821
Epoch [1/2], Step [25/500] Loss: 1.9576
Epoch [1/

Epoch [1/2], Step [204/500] Loss: 1.5780
Epoch [1/2], Step [205/500] Loss: 1.4807
Epoch [1/2], Step [206/500] Loss: 1.4424
Epoch [1/2], Step [207/500] Loss: 1.4919
Epoch [1/2], Step [208/500] Loss: 1.4198
Epoch [1/2], Step [209/500] Loss: 1.5606
Epoch [1/2], Step [210/500] Loss: 1.3574
Epoch [1/2], Step [211/500] Loss: 1.4911
Epoch [1/2], Step [212/500] Loss: 1.3797
Epoch [1/2], Step [213/500] Loss: 1.4822
Epoch [1/2], Step [214/500] Loss: 1.4021
Epoch [1/2], Step [215/500] Loss: 1.5840
Epoch [1/2], Step [216/500] Loss: 1.3966
Epoch [1/2], Step [217/500] Loss: 1.4545
Epoch [1/2], Step [218/500] Loss: 1.4777
Epoch [1/2], Step [219/500] Loss: 1.4694
Epoch [1/2], Step [220/500] Loss: 1.3869
Epoch [1/2], Step [221/500] Loss: 1.4096
Epoch [1/2], Step [222/500] Loss: 1.5131
Epoch [1/2], Step [223/500] Loss: 1.4397
Epoch [1/2], Step [224/500] Loss: 1.3971
Epoch [1/2], Step [225/500] Loss: 1.2019
Epoch [1/2], Step [226/500] Loss: 1.4513
Epoch [1/2], Step [227/500] Loss: 1.5105
Epoch [1/2], Ste

Epoch [1/2], Step [404/500] Loss: 1.3463
Epoch [1/2], Step [405/500] Loss: 1.2586
Epoch [1/2], Step [406/500] Loss: 1.1480
Epoch [1/2], Step [407/500] Loss: 1.2144
Epoch [1/2], Step [408/500] Loss: 1.2044
Epoch [1/2], Step [409/500] Loss: 1.2112
Epoch [1/2], Step [410/500] Loss: 1.2650
Epoch [1/2], Step [411/500] Loss: 1.2779
Epoch [1/2], Step [412/500] Loss: 1.2561
Epoch [1/2], Step [413/500] Loss: 1.1775
Epoch [1/2], Step [414/500] Loss: 1.4571
Epoch [1/2], Step [415/500] Loss: 1.4193
Epoch [1/2], Step [416/500] Loss: 1.1280
Epoch [1/2], Step [417/500] Loss: 1.2572
Epoch [1/2], Step [418/500] Loss: 1.1814
Epoch [1/2], Step [419/500] Loss: 1.1616
Epoch [1/2], Step [420/500] Loss: 1.3687
Epoch [1/2], Step [421/500] Loss: 1.1103
Epoch [1/2], Step [422/500] Loss: 1.2243
Epoch [1/2], Step [423/500] Loss: 1.2273
Epoch [1/2], Step [424/500] Loss: 1.2811
Epoch [1/2], Step [425/500] Loss: 1.2804
Epoch [1/2], Step [426/500] Loss: 1.2095
Epoch [1/2], Step [427/500] Loss: 1.1524
Epoch [1/2], Ste

Epoch [2/2], Step [107/500] Loss: 1.1536
Epoch [2/2], Step [108/500] Loss: 1.1177
Epoch [2/2], Step [109/500] Loss: 1.0539
Epoch [2/2], Step [110/500] Loss: 1.0847
Epoch [2/2], Step [111/500] Loss: 1.1412
Epoch [2/2], Step [112/500] Loss: 0.9852
Epoch [2/2], Step [113/500] Loss: 1.1596
Epoch [2/2], Step [114/500] Loss: 1.0379
Epoch [2/2], Step [115/500] Loss: 1.1484
Epoch [2/2], Step [116/500] Loss: 1.0701
Epoch [2/2], Step [117/500] Loss: 1.0401
Epoch [2/2], Step [118/500] Loss: 1.2463
Epoch [2/2], Step [119/500] Loss: 0.9698
Epoch [2/2], Step [120/500] Loss: 1.0563
Epoch [2/2], Step [121/500] Loss: 0.9203
Epoch [2/2], Step [122/500] Loss: 1.1829
Epoch [2/2], Step [123/500] Loss: 1.0928
Epoch [2/2], Step [124/500] Loss: 1.0534
Epoch [2/2], Step [125/500] Loss: 1.0053
Epoch [2/2], Step [126/500] Loss: 1.1918
Epoch [2/2], Step [127/500] Loss: 1.2074
Epoch [2/2], Step [128/500] Loss: 1.1324
Epoch [2/2], Step [129/500] Loss: 1.0540
Epoch [2/2], Step [130/500] Loss: 0.8369
Epoch [2/2], Ste

Epoch [2/2], Step [307/500] Loss: 1.1078
Epoch [2/2], Step [308/500] Loss: 1.1141
Epoch [2/2], Step [309/500] Loss: 1.0154
Epoch [2/2], Step [310/500] Loss: 0.9645
Epoch [2/2], Step [311/500] Loss: 1.0366
Epoch [2/2], Step [312/500] Loss: 0.9836
Epoch [2/2], Step [313/500] Loss: 1.0035
Epoch [2/2], Step [314/500] Loss: 0.9024
Epoch [2/2], Step [315/500] Loss: 1.1951
Epoch [2/2], Step [316/500] Loss: 0.9812
Epoch [2/2], Step [317/500] Loss: 1.0483
Epoch [2/2], Step [318/500] Loss: 0.8998
Epoch [2/2], Step [319/500] Loss: 0.9192
Epoch [2/2], Step [320/500] Loss: 1.0050
Epoch [2/2], Step [321/500] Loss: 1.1316
Epoch [2/2], Step [322/500] Loss: 0.8479
Epoch [2/2], Step [323/500] Loss: 1.0059
Epoch [2/2], Step [324/500] Loss: 1.0367
Epoch [2/2], Step [325/500] Loss: 0.8875
Epoch [2/2], Step [326/500] Loss: 1.2369
Epoch [2/2], Step [327/500] Loss: 1.0009
Epoch [2/2], Step [328/500] Loss: 0.8089
Epoch [2/2], Step [329/500] Loss: 0.9664
Epoch [2/2], Step [330/500] Loss: 0.8948
Epoch [2/2], Ste

In [14]:
# Test the model
model.eval()
with torch.no_grad():
    correct =0
    total = 0
    total_step = len(test_loader)
    for i,(images, lebels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        if (i+1) % 10 == 0:
            print('[{}/{}]Accuracy of the model on the test images: {:.4f} %'.format(i+1, total_step, 100 * correct / total))

[10/100]Accuracy of the model on the test images: 11.3000 %
[20/100]Accuracy of the model on the test images: 10.6500 %
[30/100]Accuracy of the model on the test images: 11.1333 %
[40/100]Accuracy of the model on the test images: 10.5000 %
[50/100]Accuracy of the model on the test images: 10.5000 %
[60/100]Accuracy of the model on the test images: 10.2500 %
[70/100]Accuracy of the model on the test images: 10.0286 %
[80/100]Accuracy of the model on the test images: 10.1125 %
[90/100]Accuracy of the model on the test images: 9.9444 %
[100/100]Accuracy of the model on the test images: 9.9000 %


In [13]:
torch.save(model.state_dict(), 'model.ckpt')