In [1]:
import os
import time
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from Models import ResNet20, ResNet50

## Test network

In [2]:
input_test = torch.randn([5,3,32,32])
resnet20 = ResNet20()
out = resnet20(input_test)
print(out.size())

resnet50 = ResNet50()
out = resnet50(input_test)
print(out.size())

torch.Size([5, 10])
torch.Size([5, 10])


## Data Preprocess

In [3]:
DATA_ROOT = "./data"
CIFAR10_shape = (3, 32, 32)
pad_size = 2
BATCH_SIZE = 128


# Preprocessing
transform_train = transforms.Compose([torchvision.transforms.RandomHorizontalFlip(p=0.5),
                                      torchvision.transforms.RandomCrop((32,32), padding=4),
                                      transforms.ToTensor(), 
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_val = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                    (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                    (0.2023, 0.1994, 0.2010))])

train_CIFAR10 = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform = transform_train)

test_CIFAR10 = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=transform_test)

num_train = int(1.0 * len(train_CIFAR10) * 95 / 100)
num_val = len(train_CIFAR10) - num_train
train_CIFAR10, val_CIFAR10 = torch.utils.data.random_split(train_CIFAR10, [num_train, num_val])

train_loader = DataLoader(
    train_CIFAR10, 
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

val_loader = DataLoader(
    val_CIFAR10, 
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

test_loader = DataLoader(
    test_CIFAR10, 
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("The model is deployed to", device)

resnet20 = ResNet20().to(device)
resnet50 = ResNet50().to(device)

The model is deployed to cuda


## Define Training and Testing Function

In [5]:
def train_model(model, mode):
    # some hyperparameters
    # total number of training epochs
    EPOCHS = 150

    # learning rate decay policy
    DECAY_EPOCHS = [75,110]
    DECAY = 0.1

    # the folder where the trained model is saved
    CHECKPOINT_PATH = "./saved_model"

    # start the training/validation process
    # the process should take about 5 minutes on a GTX 1070-Ti
    # if the code is written efficiently.
    best_val_acc = 0
    current_learning_rate = INITIAL_LR

    print("==> Training starts!")
    print("="*50)
    for i in range(0, EPOCHS):
        # handle the learning rate scheduler.
        if i in DECAY_EPOCHS and i != 0:
            current_learning_rate = current_learning_rate * DECAY
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)

        model.train()

        print("Epoch %d:" %i)
        # this help you compute the training accuracy
        total_examples = 0
        correct_examples = 0

        train_loss = 0 # track training loss if you want

        # Train the model for 1 epoch.
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # compute the output and loss
            y_preds = model(inputs)        
            loss = criterion(y_preds,targets)
            train_loss += loss

            # zero the gradient
            optimizer.zero_grad()

            # backpropagation
            loss.backward()

            # apply gradient and update the weights
            optimizer.step()

            # count the number of correctly predicted samples in the current batch
            y_preds_class = torch.argmax(y_preds, dim=1)
            correct_examples += (targets == y_preds_class).sum().item()
            total_examples += targets.size(0)

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))

        # switch to eval mode
        model.eval()

        total_examples = 0
        correct_examples = 0

        val_loss = 0 # again, track the validation loss if you want

        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(val_loader):
                inputs, targets = inputs.to(device), targets.to(device)

                # compute the output and loss
                y_preds = model(inputs)        
                loss = criterion(y_preds,targets)
                val_loss += loss

                # count the number of correctly predicted samples in the current batch
                y_preds_class = torch.argmax(y_preds, dim=1)
                correct_examples += (targets == y_preds_class).sum().item()
                total_examples += targets.size(0)

        avg_loss = val_loss / len(val_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))

        # save the model checkpoint
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            if not os.path.exists(CHECKPOINT_PATH):
                os.makedirs(CHECKPOINT_PATH)
            print("Saving ...")
            state = {'state_dict': model.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            if mode == 'resnet20':
                torch.save(state, os.path.join(CHECKPOINT_PATH, 'resnet20.pth'))
            elif mode == 'resnet50':
                torch.save(state, os.path.join(CHECKPOINT_PATH, 'resnet50.pth'))
            else:
                raise Exception('Mode not exist')

        print('')

    print("="*50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")

In [6]:
def test_model(model):
    model.to(device)
    model.eval()

    total_examples = 0
    correct_examples = 0
    softmax = torch.nn.Softmax(dim=1)

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            pred = model(inputs)
            total_examples += inputs.shape[0]

            out = softmax(pred)
            out = torch.max(out, 1)

            correct_examples += torch.sum(targets==out[1]).cpu().data.numpy().tolist()

    avg_acc = correct_examples / total_examples
    print("Total examples is {}, correct examples is {}; Test accuracy: {}".format(total_examples, correct_examples, avg_acc))

## ResNet20 Model

In [7]:
INITIAL_LR = 0.01

# momentum for optimizer
MOMENTUM = 0.9

# L2 regularization strength
REG = 1e-4

# Define loss
criterion = nn.CrossEntropyLoss()

# Add optimizer
optimizer = torch.optim.SGD(resnet20.parameters(), lr=INITIAL_LR, momentum=MOMENTUM, weight_decay=REG)

In [8]:
train_model(resnet20, "resnet20")

==> Training starts!
Epoch 0:
Training loss: 1.5555, Training accuracy: 0.4207
Validation loss: 1.6231, Validation accuracy: 0.4148
Saving ...

Epoch 1:
Training loss: 1.1050, Training accuracy: 0.6001
Validation loss: 1.2231, Validation accuracy: 0.5660
Saving ...

Epoch 2:
Training loss: 0.9277, Training accuracy: 0.6701
Validation loss: 0.9561, Validation accuracy: 0.6648
Saving ...

Epoch 3:
Training loss: 0.8099, Training accuracy: 0.7148
Validation loss: 0.8946, Validation accuracy: 0.6856
Saving ...

Epoch 4:
Training loss: 0.7235, Training accuracy: 0.7475
Validation loss: 0.7908, Validation accuracy: 0.7284
Saving ...

Epoch 5:
Training loss: 0.6652, Training accuracy: 0.7691
Validation loss: 0.7477, Validation accuracy: 0.7472
Saving ...

Epoch 6:
Training loss: 0.6127, Training accuracy: 0.7889
Validation loss: 0.7014, Validation accuracy: 0.7548
Saving ...

Epoch 7:
Training loss: 0.5726, Training accuracy: 0.8007
Validation loss: 0.6193, Validation accuracy: 0.7784
Saving 

In [12]:
resnet20 = ResNet20()
checkpoint = torch.load("./saved_model/resnet20.pth")
resnet20.load_state_dict(checkpoint['state_dict'])

test_model(resnet20)

Total examples is 10000, correct examples is 8967; Test accuracy: 0.8967


## ResNet50

In [13]:
INITIAL_LR = 0.01

# momentum for optimizer
MOMENTUM = 0.9

# L2 regularization strength
REG = 1e-4

# Define loss
criterion = nn.CrossEntropyLoss()

# Add optimizer
optimizer = torch.optim.SGD(resnet50.parameters(), lr=INITIAL_LR, momentum=MOMENTUM, weight_decay=REG)

In [14]:
train_model(resnet50, "resnet50")

==> Training starts!
Epoch 0:
Training loss: 1.4989, Training accuracy: 0.4452
Validation loss: 1.5594, Validation accuracy: 0.4268
Saving ...

Epoch 1:
Training loss: 1.1586, Training accuracy: 0.5820
Validation loss: 1.1808, Validation accuracy: 0.5792
Saving ...

Epoch 2:
Training loss: 0.9546, Training accuracy: 0.6619
Validation loss: 1.0306, Validation accuracy: 0.6236
Saving ...

Epoch 3:
Training loss: 0.8147, Training accuracy: 0.7143
Validation loss: 0.7943, Validation accuracy: 0.7284
Saving ...

Epoch 4:
Training loss: 0.7184, Training accuracy: 0.7483
Validation loss: 0.7780, Validation accuracy: 0.7312
Saving ...

Epoch 5:
Training loss: 0.6531, Training accuracy: 0.7722
Validation loss: 0.7320, Validation accuracy: 0.7444
Saving ...

Epoch 6:
Training loss: 0.6025, Training accuracy: 0.7924
Validation loss: 0.6469, Validation accuracy: 0.7808
Saving ...

Epoch 7:
Training loss: 0.5567, Training accuracy: 0.8071
Validation loss: 0.6832, Validation accuracy: 0.7672

Epoch 

In [17]:
resnet50 = ResNet50()
checkpoint = torch.load("./saved_model/resnet50_final.pth")
resnet50.load_state_dict(checkpoint['state_dict'])

test_model(resnet50)

Total examples is 10000, correct examples is 9079; Test accuracy: 0.9079
