In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Hyper-parameters
num_epochs = 80
batch_size = 100
learning_rate = 0.001

# Image preprocessing modules
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

In [5]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)


Files already downloaded and verified
Files already downloaded and verified


In [6]:
trainloader = torch.utils.data.DataLoader(dataset=trainset,batch_size=4,
                                         shuffle=True,num_workers=2)
testloader = torch.utils.data.DataLoader(dataset=testset,batch_size=100,
                                         shuffle=True,num_workers=2)

In [7]:
# This is one residual block
class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels,stride=1,downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=stride, \
                              padding=1,bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1, \
                              padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.downsample = downsample
        
    def forward(self,x):
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample:
            residual = self.downsample(residual)
        x += residual
        x = self.relu(x)
        return x

In [8]:
## According to resnet architecture - (mellowed down by PyTorch)
class ResNet(nn.Module):
    def __init__(self,block,layers,num_classes=10):
        super(ResNet,self).__init__()
        self.in_channels = 16
        self.conv = nn.Conv2d(3,16,kernel_size=3,stride=1, \
                              padding=1,bias=False)
        self.bn = nn.BatchNorm2d(num_features=16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)
        
    def make_layer(self,block,out_channels,blocks,stride=1):
        downsample = None
        if (self.in_channels != out_channels) or (stride != 1):
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels,out_channels,kernel_size=3,stride=stride, \
                              padding=1,bias=False), \
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
    
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
total_step = len(trainloader)
curr_lr = learning_rate
for epoch in range(num_epochs):
    for i,(images,labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
        

Epoch [1/80], Step [100/12500] Loss: 2.5587
Epoch [1/80], Step [200/12500] Loss: 2.1955
Epoch [1/80], Step [300/12500] Loss: 1.8618
Epoch [1/80], Step [400/12500] Loss: 2.0797
Epoch [1/80], Step [500/12500] Loss: 1.6879
Epoch [1/80], Step [600/12500] Loss: 2.0044
Epoch [1/80], Step [700/12500] Loss: 2.0169
Epoch [1/80], Step [800/12500] Loss: 2.2538
Epoch [1/80], Step [900/12500] Loss: 2.1079
Epoch [1/80], Step [1000/12500] Loss: 2.1439
Epoch [1/80], Step [1100/12500] Loss: 1.5376
Epoch [1/80], Step [1200/12500] Loss: 1.9431
Epoch [1/80], Step [1300/12500] Loss: 1.9862
Epoch [1/80], Step [1400/12500] Loss: 2.7970
Epoch [1/80], Step [1500/12500] Loss: 1.8451
Epoch [1/80], Step [1600/12500] Loss: 1.8750
Epoch [1/80], Step [1700/12500] Loss: 2.0855
Epoch [1/80], Step [1800/12500] Loss: 2.0999
Epoch [1/80], Step [1900/12500] Loss: 1.4392
Epoch [1/80], Step [2000/12500] Loss: 2.3268
Epoch [1/80], Step [2100/12500] Loss: 1.8443
Epoch [1/80], Step [2200/12500] Loss: 1.6540
Epoch [1/80], Step 

Epoch [2/80], Step [5800/12500] Loss: 0.4902
Epoch [2/80], Step [5900/12500] Loss: 1.4300
Epoch [2/80], Step [6000/12500] Loss: 1.3151
Epoch [2/80], Step [6100/12500] Loss: 1.2521
Epoch [2/80], Step [6200/12500] Loss: 0.8263
Epoch [2/80], Step [6300/12500] Loss: 0.6794
Epoch [2/80], Step [6400/12500] Loss: 0.7500
Epoch [2/80], Step [6500/12500] Loss: 0.9534
Epoch [2/80], Step [6600/12500] Loss: 1.1327
Epoch [2/80], Step [6700/12500] Loss: 1.1782
Epoch [2/80], Step [6800/12500] Loss: 0.7276
Epoch [2/80], Step [6900/12500] Loss: 0.2783
Epoch [2/80], Step [7000/12500] Loss: 0.6306
Epoch [2/80], Step [7100/12500] Loss: 1.5533
Epoch [2/80], Step [7200/12500] Loss: 1.5356
Epoch [2/80], Step [7300/12500] Loss: 0.5336
Epoch [2/80], Step [7400/12500] Loss: 1.5859
Epoch [2/80], Step [7500/12500] Loss: 0.4513
Epoch [2/80], Step [7600/12500] Loss: 1.0298
Epoch [2/80], Step [7700/12500] Loss: 0.8559
Epoch [2/80], Step [7800/12500] Loss: 0.3866
Epoch [2/80], Step [7900/12500] Loss: 1.7260
Epoch [2/8

Epoch [3/80], Step [11500/12500] Loss: 0.4964
Epoch [3/80], Step [11600/12500] Loss: 0.3307
Epoch [3/80], Step [11700/12500] Loss: 0.5795
Epoch [3/80], Step [11800/12500] Loss: 0.5361
Epoch [3/80], Step [11900/12500] Loss: 0.4238
Epoch [3/80], Step [12000/12500] Loss: 2.1658
Epoch [3/80], Step [12100/12500] Loss: 0.2953
Epoch [3/80], Step [12200/12500] Loss: 0.5139
Epoch [3/80], Step [12300/12500] Loss: 1.1611
Epoch [3/80], Step [12400/12500] Loss: 0.8456
Epoch [3/80], Step [12500/12500] Loss: 0.4712
Epoch [4/80], Step [100/12500] Loss: 0.3636
Epoch [4/80], Step [200/12500] Loss: 1.2258
Epoch [4/80], Step [300/12500] Loss: 1.0139
Epoch [4/80], Step [400/12500] Loss: 2.1005
Epoch [4/80], Step [500/12500] Loss: 1.1643
Epoch [4/80], Step [600/12500] Loss: 0.2739
Epoch [4/80], Step [700/12500] Loss: 1.2365
Epoch [4/80], Step [800/12500] Loss: 0.4014
Epoch [4/80], Step [900/12500] Loss: 0.9463
Epoch [4/80], Step [1000/12500] Loss: 1.7920
Epoch [4/80], Step [1100/12500] Loss: 1.1772
Epoch [4

Epoch [5/80], Step [4700/12500] Loss: 0.3702
Epoch [5/80], Step [4800/12500] Loss: 0.8521
Epoch [5/80], Step [4900/12500] Loss: 0.6622
Epoch [5/80], Step [5000/12500] Loss: 0.3122
Epoch [5/80], Step [5100/12500] Loss: 0.7799
Epoch [5/80], Step [5200/12500] Loss: 0.9384
Epoch [5/80], Step [5300/12500] Loss: 0.1481
Epoch [5/80], Step [5400/12500] Loss: 1.2869
Epoch [5/80], Step [5500/12500] Loss: 0.3616
Epoch [5/80], Step [5600/12500] Loss: 0.3593
Epoch [5/80], Step [5700/12500] Loss: 0.9036
Epoch [5/80], Step [5800/12500] Loss: 0.5923
Epoch [5/80], Step [5900/12500] Loss: 0.3758
Epoch [5/80], Step [6000/12500] Loss: 0.2418
Epoch [5/80], Step [6100/12500] Loss: 0.2888
Epoch [5/80], Step [6200/12500] Loss: 0.2900
Epoch [5/80], Step [6300/12500] Loss: 1.4569
Epoch [5/80], Step [6400/12500] Loss: 0.7032
Epoch [5/80], Step [6500/12500] Loss: 0.1317
Epoch [5/80], Step [6600/12500] Loss: 1.3302
Epoch [5/80], Step [6700/12500] Loss: 0.1302
Epoch [5/80], Step [6800/12500] Loss: 0.9650
Epoch [5/8

Epoch [6/80], Step [10400/12500] Loss: 0.4092
Epoch [6/80], Step [10500/12500] Loss: 0.1689
Epoch [6/80], Step [10600/12500] Loss: 1.3556
Epoch [6/80], Step [10700/12500] Loss: 0.2569
Epoch [6/80], Step [10800/12500] Loss: 0.0457
Epoch [6/80], Step [10900/12500] Loss: 0.2080
Epoch [6/80], Step [11000/12500] Loss: 1.5791
Epoch [6/80], Step [11100/12500] Loss: 0.9342
Epoch [6/80], Step [11200/12500] Loss: 0.9940
Epoch [6/80], Step [11300/12500] Loss: 0.0919
Epoch [6/80], Step [11400/12500] Loss: 1.2279
Epoch [6/80], Step [11500/12500] Loss: 0.5780
Epoch [6/80], Step [11600/12500] Loss: 0.0578
Epoch [6/80], Step [11700/12500] Loss: 0.0808
Epoch [6/80], Step [11800/12500] Loss: 0.5510
Epoch [6/80], Step [11900/12500] Loss: 0.2208
Epoch [6/80], Step [12000/12500] Loss: 0.1342
Epoch [6/80], Step [12100/12500] Loss: 0.8073
Epoch [6/80], Step [12200/12500] Loss: 0.4103
Epoch [6/80], Step [12300/12500] Loss: 0.3960
Epoch [6/80], Step [12400/12500] Loss: 0.6809
Epoch [6/80], Step [12500/12500] L

Epoch [8/80], Step [3600/12500] Loss: 0.2525
Epoch [8/80], Step [3700/12500] Loss: 1.4601
Epoch [8/80], Step [3800/12500] Loss: 0.2022
Epoch [8/80], Step [3900/12500] Loss: 0.2024
Epoch [8/80], Step [4000/12500] Loss: 0.1160
Epoch [8/80], Step [4100/12500] Loss: 0.4743
Epoch [8/80], Step [4200/12500] Loss: 0.1036
Epoch [8/80], Step [4300/12500] Loss: 0.1825
Epoch [8/80], Step [4400/12500] Loss: 0.0943
Epoch [8/80], Step [4500/12500] Loss: 0.0289
Epoch [8/80], Step [4600/12500] Loss: 0.2133
Epoch [8/80], Step [4700/12500] Loss: 0.4588
Epoch [8/80], Step [4800/12500] Loss: 0.3297
Epoch [8/80], Step [4900/12500] Loss: 1.0723
Epoch [8/80], Step [5000/12500] Loss: 0.2229
Epoch [8/80], Step [5100/12500] Loss: 1.6734
Epoch [8/80], Step [5200/12500] Loss: 0.5132
Epoch [8/80], Step [5300/12500] Loss: 0.0894
Epoch [8/80], Step [5400/12500] Loss: 0.1231
Epoch [8/80], Step [5500/12500] Loss: 0.4787
Epoch [8/80], Step [5600/12500] Loss: 0.8134
Epoch [8/80], Step [5700/12500] Loss: 0.0565
Epoch [8/8

Epoch [9/80], Step [9300/12500] Loss: 1.0339
Epoch [9/80], Step [9400/12500] Loss: 1.4647
Epoch [9/80], Step [9500/12500] Loss: 1.0274
Epoch [9/80], Step [9600/12500] Loss: 0.2850
Epoch [9/80], Step [9700/12500] Loss: 0.1404
Epoch [9/80], Step [9800/12500] Loss: 0.4876
Epoch [9/80], Step [9900/12500] Loss: 0.9819
Epoch [9/80], Step [10000/12500] Loss: 1.1593
Epoch [9/80], Step [10100/12500] Loss: 0.2048
Epoch [9/80], Step [10200/12500] Loss: 1.0433
Epoch [9/80], Step [10300/12500] Loss: 0.2462
Epoch [9/80], Step [10400/12500] Loss: 0.0428
Epoch [9/80], Step [10500/12500] Loss: 1.0207
Epoch [9/80], Step [10600/12500] Loss: 0.4135
Epoch [9/80], Step [10700/12500] Loss: 0.5074
Epoch [9/80], Step [10800/12500] Loss: 0.4754
Epoch [9/80], Step [10900/12500] Loss: 0.2581
Epoch [9/80], Step [11000/12500] Loss: 0.6277
Epoch [9/80], Step [11100/12500] Loss: 0.1720
Epoch [9/80], Step [11200/12500] Loss: 0.0170
Epoch [9/80], Step [11300/12500] Loss: 0.2595
Epoch [9/80], Step [11400/12500] Loss: 0.

Epoch [11/80], Step [2200/12500] Loss: 0.2201
Epoch [11/80], Step [2300/12500] Loss: 0.3158
Epoch [11/80], Step [2400/12500] Loss: 1.0479
Epoch [11/80], Step [2500/12500] Loss: 0.7308
Epoch [11/80], Step [2600/12500] Loss: 0.7825
Epoch [11/80], Step [2700/12500] Loss: 0.6348
Epoch [11/80], Step [2800/12500] Loss: 0.1355
Epoch [11/80], Step [2900/12500] Loss: 0.0989
Epoch [11/80], Step [3000/12500] Loss: 0.0685
Epoch [11/80], Step [3100/12500] Loss: 0.2985
Epoch [11/80], Step [3200/12500] Loss: 0.1977
Epoch [11/80], Step [3300/12500] Loss: 0.7831
Epoch [11/80], Step [3400/12500] Loss: 0.8189
Epoch [11/80], Step [3500/12500] Loss: 0.3862
Epoch [11/80], Step [3600/12500] Loss: 0.3209
Epoch [11/80], Step [3700/12500] Loss: 0.6141
Epoch [11/80], Step [3800/12500] Loss: 0.9329
Epoch [11/80], Step [3900/12500] Loss: 1.2465
Epoch [11/80], Step [4000/12500] Loss: 0.2962
Epoch [11/80], Step [4100/12500] Loss: 0.1557
Epoch [11/80], Step [4200/12500] Loss: 1.0872
Epoch [11/80], Step [4300/12500] L

Epoch [12/80], Step [7500/12500] Loss: 0.6855
Epoch [12/80], Step [7600/12500] Loss: 0.8216
Epoch [12/80], Step [7700/12500] Loss: 0.3803
Epoch [12/80], Step [7800/12500] Loss: 0.0062
Epoch [12/80], Step [7900/12500] Loss: 0.3587
Epoch [12/80], Step [8000/12500] Loss: 1.0600
Epoch [12/80], Step [8100/12500] Loss: 0.2864
Epoch [12/80], Step [8200/12500] Loss: 0.0742
Epoch [12/80], Step [8300/12500] Loss: 0.0900
Epoch [12/80], Step [8400/12500] Loss: 0.7614
Epoch [12/80], Step [8500/12500] Loss: 0.6204
Epoch [12/80], Step [8600/12500] Loss: 0.9026
Epoch [12/80], Step [8700/12500] Loss: 0.1496
Epoch [12/80], Step [8800/12500] Loss: 0.7424
Epoch [12/80], Step [8900/12500] Loss: 0.8605
Epoch [12/80], Step [9000/12500] Loss: 0.1524
Epoch [12/80], Step [9100/12500] Loss: 1.2838
Epoch [12/80], Step [9200/12500] Loss: 0.5440
Epoch [12/80], Step [9300/12500] Loss: 0.9153
Epoch [12/80], Step [9400/12500] Loss: 0.3330
Epoch [12/80], Step [9500/12500] Loss: 0.2600
Epoch [12/80], Step [9600/12500] L

KeyboardInterrupt: 