In [14]:
from __future__ import print_function, division

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
import time
import os
import copy
import torch.nn as nn
import torch.nn.functional as F

# TODO: Implement a convolutional neural network (https://pytorch.org/tutorials/recipes/recipes/defining_a_neural_network.html)
class Net(nn.Module):
    """
    Input - 1x32x32
    Output - 10
    """
    def __init__(self):
        super(Net, self).__init__()
        
        self.params = {'conv':[(), 
                               (3, 6, 3, 1, 0), 
                               (6, 8, 3, 1, 1),
                               (8, 12, 4, 1, 0),
                               (12, 16, 3, 1, 1)], # in_channels, out_channels, kernel_size, stride, padding
                       'pool':[(), 
                               (2, 2, 0),
                               (3, 2, 0)], # kernel_size, stride, padding
                       'fc':[(), 
                             (16*5*5, 120),
                             (120, 84), 
                             (84, 10)] # in_channels, out_channels
                      }
        
        self.conv1 = nn.Conv2d(*self.params['conv'][1])
        self.conv2 = nn.Conv2d(*self.params['conv'][2])
        self.conv3 = nn.Conv2d(*self.params['conv'][3])
        self.conv4 = nn.Conv2d(*self.params['conv'][4])
        # self.conv5 = nn.Conv2d(*self.params['conv'][5])
        
        self.pool1 = nn.MaxPool2d(*self.params['pool'][1])
        self.pool2 = nn.MaxPool2d(*self.params['pool'][2])
        
        self.fc1 = nn.Linear(*self.params['fc'][1])
        self.fc2 = nn.Linear(*self.params['fc'][2])
        self.fc3 = nn.Linear(*self.params['fc'][3])
        
        self.printed = False

        # TODO: Initialize layers

    def forward(self, img):

        # TODO: Implement forward pass
        x = img
        x = F.relu(self.conv1(x))
        if not self.printed: 
            print("CONV1", x.size())
        x = F.relu(self.conv2(x))
        if not self.printed: 
            print("CONV2", x.size())
        x = self.pool1(x)
        if not self.printed: 
            print("POOL1", x.size())
        x = F.relu(self.conv3(x))
        if not self.printed: 
            print("CONV3", x.size())
        x = F.relu(self.conv4(x))
        if not self.printed: 
            print("CONV4", x.size())
        x = self.pool2(x)
        if not self.printed: 
            print("POOL2", x.size())
            self.printed = True
        
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

# TODO: You can change these data augmentation and normalization strategies for
#  better training and testing (https://pytorch.org/vision/stable/transforms.html)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((32,32)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Dataset initialization
data_dir = 'data' # Suppose the dataset is stored under this folder
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'test']} # Read train and test sets, respectively.

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=0) for x in ['train', 'test']}
# trainloader = torch.utils.data.DataLoader(image_datasets['train'], batch_size=4, shuffle=True, num_workers=2)
# teatloader = torch.utils.data.DataLoader(image_datasets['test'], batch_size=4, shuffle=True, num_workers=2)

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}

class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Set device to "cpu" if you have no gpu

# TODO: Implement training and testing procedures (https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
def train_test(model, criterion, optimizer, scheduler, num_epochs=25):
    for epoch in range(num_epochs):  

        running_loss = 0.0
        for i, data in enumerate(dataloaders['train'], 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    print('Finished Training')
    
    '''
    # save training results
    PATH = './cifar_net.pth'
    torch.save(model.state_dict(), PATH)
    '''
    
    # testing overall correct rate
    correct = 0
    total = 0
    
    with torch.no_grad():
        for i, data in enumerate(dataloaders['train'], 0):
            images, labels = data
            # calculate outputs by running images through the network
            outputs = model(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
    
    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in class_names}
    total_pred = {classname: 0 for classname in class_names}

    # again no gradients needed
    with torch.no_grad():
        for i, data in enumerate(dataloaders['test'], 0):
            images, labels = data
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)
            # collect the correct predictions for each class
            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred[class_names[label]] += 1
                total_pred[class_names[label]] += 1


    # print accuracy for each class
    print("Accuracy: ")
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print("{:1s}: {:.1f}%;  ".format(classname, accuracy), end=' ')
    print()    
    
    return None

model_ft = Net() # Model initialization

model_ft = model_ft.to(device) # Move model to cpu

criterion = nn.CrossEntropyLoss() # Loss function initialization

# TODO: Adjust the following hyper-parameters: learning rate, decay strategy, number of training epochs.
optimizer_ft = optim.Adam(model_ft.parameters(), lr=1e-4) # Optimizer initialization

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.1) # Learning rate decay strategy

train_test(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=30)

CONV1 torch.Size([4, 6, 30, 30])
CONV2 torch.Size([4, 8, 30, 30])
POOL1 torch.Size([4, 8, 15, 15])
CONV3 torch.Size([4, 12, 12, 12])
CONV4 torch.Size([4, 16, 12, 12])
POOL2 torch.Size([4, 16, 5, 5])
[1,  2000] loss: 2.188
[1,  4000] loss: 1.593
[1,  6000] loss: 1.461
[2,  2000] loss: 1.226
[2,  4000] loss: 1.093
[2,  6000] loss: 1.070
[3,  2000] loss: 0.932
[3,  4000] loss: 0.908
[3,  6000] loss: 0.872
[4,  2000] loss: 0.788
[4,  4000] loss: 0.776
[4,  6000] loss: 0.757
[5,  2000] loss: 0.730
[5,  4000] loss: 0.696
[5,  6000] loss: 0.702
[6,  2000] loss: 0.632
[6,  4000] loss: 0.656
[6,  6000] loss: 0.638
[7,  2000] loss: 0.597
[7,  4000] loss: 0.594
[7,  6000] loss: 0.589
[8,  2000] loss: 0.567
[8,  4000] loss: 0.559
[8,  6000] loss: 0.537
[9,  2000] loss: 0.532
[9,  4000] loss: 0.531
[9,  6000] loss: 0.520
[10,  2000] loss: 0.489
[10,  4000] loss: 0.495
[10,  6000] loss: 0.504
[11,  2000] loss: 0.472
[11,  4000] loss: 0.458
[11,  6000] loss: 0.485
[12,  2000] loss: 0.467
[12,  4000] 

In [None]:
'''
RECORD 1. 
self.params = {'conv':[(), (3, 6, 3, 1, 0), (6, 8, 3, 1, 1), (8, 12, 4, 1, 0), (12, 16, 3, 1, 1)], # in_channels, out_channels, kernel_size, stride, padding
               'pool':[(), (2, 2, 0), (3, 2, 0)], # kernel_size, stride, padding
               'fc':[(), (16*5*5, 120), (120, 84), (84, 10)] # in_channels, out_channels
               }

CONV1 torch.Size([4, 6, 30, 30])
CONV2 torch.Size([4, 8, 30, 30])
POOL1 torch.Size([4, 8, 15, 15])
CONV3 torch.Size([4, 12, 12, 12])
CONV4 torch.Size([4, 16, 12, 12])
POOL2 torch.Size([4, 16, 5, 5])

[1,  2000] loss: 2.301
[1,  4000] loss: 1.912
[1,  6000] loss: 1.412
[2,  2000] loss: 1.164
[2,  4000] loss: 1.124
[2,  6000] loss: 1.064
[3,  2000] loss: 0.999
[3,  4000] loss: 0.926
[3,  6000] loss: 0.889
[4,  2000] loss: 0.817
[4,  4000] loss: 0.803
[4,  6000] loss: 0.762
[5,  2000] loss: 0.735
[5,  4000] loss: 0.698
[5,  6000] loss: 0.693
Finished Training
Accuracy of the network on the 10000 test images: 79 %
Accuracy for class 0  is: 84.6 %
Accuracy for class 1  is: 84.4 %
Accuracy for class 2  is: 77.6 %
Accuracy for class 3  is: 63.4 %
Accuracy for class 4  is: 81.6 %
Accuracy for class 5  is: 76.6 %
Accuracy for class 6  is: 79.0 %
Accuracy for class 7  is: 85.6 %
Accuracy for class 8  is: 63.8 %
Accuracy for class 9  is: 82.4 %
'''

In [None]:
'''
RECORD 2: 
CONV1 torch.Size([4, 6, 30, 30])
CONV2 torch.Size([4, 8, 30, 30])
POOL1 torch.Size([4, 8, 15, 15])
CONV3 torch.Size([4, 12, 12, 12])
CONV4 torch.Size([4, 16, 12, 12])
POOL2 torch.Size([4, 16, 5, 5])
[1,  2000] loss: 2.188
[1,  4000] loss: 1.593
[1,  6000] loss: 1.461
[2,  2000] loss: 1.226
[2,  4000] loss: 1.093
[2,  6000] loss: 1.070
[3,  2000] loss: 0.932
[3,  4000] loss: 0.908
[3,  6000] loss: 0.872
[4,  2000] loss: 0.788
[4,  4000] loss: 0.776
[4,  6000] loss: 0.757
[5,  2000] loss: 0.730
[5,  4000] loss: 0.696
[5,  6000] loss: 0.702
[6,  2000] loss: 0.632
[6,  4000] loss: 0.656
[6,  6000] loss: 0.638
[7,  2000] loss: 0.597
[7,  4000] loss: 0.594
[7,  6000] loss: 0.589
[8,  2000] loss: 0.567
[8,  4000] loss: 0.559
[8,  6000] loss: 0.537
[9,  2000] loss: 0.532
[9,  4000] loss: 0.531
[9,  6000] loss: 0.520
[10,  2000] loss: 0.489
[10,  4000] loss: 0.495
[10,  6000] loss: 0.504
[11,  2000] loss: 0.472
[11,  4000] loss: 0.458
[11,  6000] loss: 0.485
[12,  2000] loss: 0.467
[12,  4000] loss: 0.460
[12,  6000] loss: 0.442
[13,  2000] loss: 0.432
[13,  4000] loss: 0.452
[13,  6000] loss: 0.434
[14,  2000] loss: 0.408
[14,  4000] loss: 0.415
[14,  6000] loss: 0.429
[15,  2000] loss: 0.404
[15,  4000] loss: 0.417
[15,  6000] loss: 0.406
[16,  2000] loss: 0.379
[16,  4000] loss: 0.389
[16,  6000] loss: 0.399
[17,  2000] loss: 0.373
[17,  4000] loss: 0.366
[17,  6000] loss: 0.395
[18,  2000] loss: 0.358
[18,  4000] loss: 0.372
[18,  6000] loss: 0.376
[19,  2000] loss: 0.351
[19,  4000] loss: 0.346
[19,  6000] loss: 0.376
[20,  2000] loss: 0.331
[20,  4000] loss: 0.325
[20,  6000] loss: 0.364
[21,  2000] loss: 0.321
[21,  4000] loss: 0.349
[21,  6000] loss: 0.332
[22,  2000] loss: 0.334
[22,  4000] loss: 0.303
[22,  6000] loss: 0.329
[23,  2000] loss: 0.304
[23,  4000] loss: 0.310
[23,  6000] loss: 0.320
[24,  2000] loss: 0.305
[24,  4000] loss: 0.300
[24,  6000] loss: 0.304
[25,  2000] loss: 0.297
[25,  4000] loss: 0.292
[25,  6000] loss: 0.307
[26,  2000] loss: 0.284
[26,  4000] loss: 0.293
[26,  6000] loss: 0.287
[27,  2000] loss: 0.260
[27,  4000] loss: 0.287
[27,  6000] loss: 0.279
[28,  2000] loss: 0.276
[28,  4000] loss: 0.258
[28,  6000] loss: 0.272
[29,  2000] loss: 0.260
[29,  4000] loss: 0.262
[29,  6000] loss: 0.250
[30,  2000] loss: 0.257
[30,  4000] loss: 0.260
[30,  6000] loss: 0.266
Finished Training
Accuracy of the network on the 10000 test images: 92 %
Accuracy: 
0: 89.0%;   1: 89.0%;   2: 85.2%;   3: 80.2%;   4: 91.4%;   5: 83.2%;   6: 84.6%;   7: 88.0%;   8: 84.0%;   9: 85.4%;
'''