In [1]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
#from utee import misc, quant, selector

import torch.nn.functional as F  # useful stateless functions

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np


#Load CIFAR-10
NUM_TRAIN = 49000

# The torchvision.transforms package provides tools for preprocessing data
# and for performing data augmentation; here we set up a transform to
# preprocess the data by subtracting the mean RGB value and dividing by the
# standard deviation of each RGB value; we've hardcoded the mean and std.
transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])

# We set up a Dataset object for each split (train / val / test); Datasets load
# training examples one at a time, so we wrap each Dataset in a DataLoader which
# iterates through the Dataset and forms minibatches. We divide the CIFAR-10
# training set into train and val sets by passing a Sampler object to the
# DataLoader telling how it should sample from the underlying Dataset.
cifar10_train = dset.CIFAR10('./cs231n/datasets', train=True, download=True,
                             transform=transform)
loader_train = DataLoader(cifar10_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10('./cs231n/datasets', train=True, download=True,
                           transform=transform)
loader_val = DataLoader(cifar10_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

cifar10_test = dset.CIFAR10('./cs231n/datasets', train=False, download=True, 
                            transform=transform)
loader_test = DataLoader(cifar10_test, batch_size=64)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [2]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cuda


In [3]:
def check_accuracy_part34(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')   
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

In [4]:
def train_part34(model, optimizer, epochs=1):
    """
    Train a model on CIFAR-10 using the PyTorch Module API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    t_begin = time.time()
    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if t % print_every == 0:
                t_elapse = time.time() - t_begin
                print('Elapsed %.4f s, Epoch %d,  Iteration %d, loss = %.4f' % (t_elapse, e, t, loss.item()))
                check_accuracy_part34(loader_val, model)
                print()

In [5]:
from Model.vgg_modules import *

VGG19 = vgg19()
VGG19.state_dict().keys()

odict_keys(['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.16.weight', 'features.16.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.23.weight', 'features.23.bias', 'features.25.weight', 'features.25.bias', 'features.28.weight', 'features.28.bias', 'features.30.weight', 'features.30.bias', 'features.32.weight', 'features.32.bias', 'features.34.weight', 'features.34.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])

In [12]:
PATH = '../pretrain_model/model_best.pth.tar'
VGG19.features = torch.nn.DataParallel(VGG19.features)
VGG19.cuda()
checkpoint = torch.load(PATH)
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
#print(VGG19.state_dict().keys())
#print(checkpoint['state_dict'].keys())
VGG19.load_state_dict(checkpoint['state_dict'])

RuntimeError: Error(s) in loading state_dict for VGG:
	Missing key(s) in state_dict: "features.module.module.0.weight", "features.module.module.0.bias", "features.module.module.2.weight", "features.module.module.2.bias", "features.module.module.5.weight", "features.module.module.5.bias", "features.module.module.7.weight", "features.module.module.7.bias", "features.module.module.10.weight", "features.module.module.10.bias", "features.module.module.12.weight", "features.module.module.12.bias", "features.module.module.14.weight", "features.module.module.14.bias", "features.module.module.16.weight", "features.module.module.16.bias", "features.module.module.19.weight", "features.module.module.19.bias", "features.module.module.21.weight", "features.module.module.21.bias", "features.module.module.23.weight", "features.module.module.23.bias", "features.module.module.25.weight", "features.module.module.25.bias", "features.module.module.28.weight", "features.module.module.28.bias", "features.module.module.30.weight", "features.module.module.30.bias", "features.module.module.32.weight", "features.module.module.32.bias", "features.module.module.34.weight", "features.module.module.34.bias". 
	Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.2.weight", "features.module.2.bias", "features.module.5.weight", "features.module.5.bias", "features.module.7.weight", "features.module.7.bias", "features.module.10.weight", "features.module.10.bias", "features.module.12.weight", "features.module.12.bias", "features.module.14.weight", "features.module.14.bias", "features.module.16.weight", "features.module.16.bias", "features.module.19.weight", "features.module.19.bias", "features.module.21.weight", "features.module.21.bias", "features.module.23.weight", "features.module.23.bias", "features.module.25.weight", "features.module.25.bias", "features.module.28.weight", "features.module.28.bias", "features.module.30.weight", "features.module.30.bias", "features.module.32.weight", "features.module.32.bias", "features.module.34.weight", "features.module.34.bias". 

In [21]:
from Model.Fixedvgg import *

FixedVGG19 = fixed_vgg19()
FixedVGG19

FixedVGG(
  (features): Sequential(
    (Q0): activation_quantization()
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (Q2): activation_quantization()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Q5): activation_quantization()
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (Q7): activation_quantization()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Q10): activation_quantization()
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (Q12): activation_quantization()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    

In [22]:
PATH = '../pretrain_model/model_best.pth.tar'
FixedVGG19.features = torch.nn.DataParallel(FixedVGG19.features)
FixedVGG19.cuda()
checkpoint = torch.load(PATH)
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
#print(VGG19.state_dict().keys())
#print(checkpoint['state_dict'].keys())
FixedVGG19.load_state_dict(checkpoint['state_dict'])


In [23]:
print("\nFixed VGG11 Accuracy:")
check_accuracy_part34(loader_test, FixedVGG19)
print("\nFloadt VGG11 Accuracy:")
check_accuracy_part34(loader_test, VGG19)


Fixed VGG11 Accuracy:
Checking accuracy on test set
Got 1000 / 10000 correct (10.00)

Floadt VGG11 Accuracy:
Checking accuracy on test set
Got 9206 / 10000 correct (92.06)


In [24]:
#Train this model
learning_rate = 2e-5

optimizer = optim.Adam(params=FixedVGG19.parameters(), lr=learning_rate)
#optimizer = optim.SGD(model.parameters(), lr=learning_rate,momentum=0.9, nesterov=True)
train_part34(FixedVGG19, optimizer, epochs=5)

Elapsed 0.0839 s, Epoch 0,  Iteration 0, loss = 2.3067
Checking accuracy on validation set
Got 112 / 1000 correct (11.20)

Elapsed 16.6850 s, Epoch 0,  Iteration 100, loss = 0.0776
Checking accuracy on validation set
Got 931 / 1000 correct (93.10)

Elapsed 33.3569 s, Epoch 0,  Iteration 200, loss = 0.1747
Checking accuracy on validation set
Got 934 / 1000 correct (93.40)

Elapsed 50.0442 s, Epoch 0,  Iteration 300, loss = 0.1509
Checking accuracy on validation set
Got 961 / 1000 correct (96.10)

Elapsed 66.7280 s, Epoch 0,  Iteration 400, loss = 0.1938
Checking accuracy on validation set
Got 959 / 1000 correct (95.90)

Elapsed 83.4200 s, Epoch 0,  Iteration 500, loss = 0.1710
Checking accuracy on validation set
Got 941 / 1000 correct (94.10)

Elapsed 100.1263 s, Epoch 0,  Iteration 600, loss = 0.0779
Checking accuracy on validation set
Got 953 / 1000 correct (95.30)

Elapsed 116.8172 s, Epoch 0,  Iteration 700, loss = 0.2561
Checking accuracy on validation set
Got 973 / 1000 correct (9

KeyboardInterrupt: 

In [25]:
print("\nFinetune Fixed VGG11 Accuracy:")
check_accuracy_part34(loader_test, FixedVGG19)
print("\nFloadt VGG11 Accuracy:")
check_accuracy_part34(loader_test, VGG19)


Finetune Fixed VGG11 Accuracy:
Checking accuracy on test set
Got 9073 / 10000 correct (90.73)

Floadt VGG11 Accuracy:
Checking accuracy on test set
Got 9206 / 10000 correct (92.06)
