In [1]:
# Make the project root (one level up) importable so the `models.*` imports resolve.
# Guarded so that re-running this cell does not keep appending duplicate entries.
import sys
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import os
import sys

import torch
import torch.nn as nn
import torchvision as tv
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

from models.mnist import BaseModel
from models.cifar10 import Resnet, Vgg
from models.torch_util import validate

In [3]:
# Enable IPython's autoreload so edits to project .py files (e.g. models.*)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# Shared configuration used by every evaluation cell below.
loss = nn.CrossEntropyLoss()  # criterion handed to validate() for each model
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
PATH_FILE = os.path.join('..', 'data')  # dataset root, relative to the kernel's working directory

## Testing MNIST

In [5]:
# Evaluation data: ToTensor only (no augmentation), iterated in a fixed, unshuffled order.
transforms = tv.transforms.Compose([tv.transforms.ToTensor()])
dataset_train, dataset_test = (
    datasets.MNIST(PATH_FILE, train=split_is_train, download=True, transform=transforms)
    for split_is_train in (True, False)
)
dataloader_train, dataloader_test = (
    DataLoader(split_dataset, batch_size=128, shuffle=False)
    for split_dataset in (dataset_train, dataset_test)
)

In [6]:
# Load the trained MNIST model and report its accuracy on both splits.
PATH_MODEL = os.path.join('..', 'results', 'mnist_200.pt')
model = BaseModel().to(device)
# map_location remaps the saved tensors onto `device`, so a checkpoint saved
# from a GPU still loads when the notebook falls back to the CPU.
model.load_state_dict(torch.load(PATH_MODEL, map_location=device))
# Explicit eval mode; harmless if validate() already switches to it.
model.eval()

for split, loader in (('training', dataloader_train), ('test', dataloader_test)):
    _, acc = validate(model, loader, loss, device)
    print('Accuracy on {} set: {}'.format(split, acc))

Accuracy on training set: 0.9971166666666667
Accuracy on test set: 0.9852


## Testing CIFAR10

In [7]:
# NOTE: the training run used random horizontal flips and random crops; evaluation
# here applies only ToTensor (no augmentation) and keeps a fixed, unshuffled order.
transforms = tv.transforms.Compose([tv.transforms.ToTensor()])
dataset_train, dataset_test = (
    datasets.CIFAR10(PATH_FILE, train=split_is_train, download=True, transform=transforms)
    for split_is_train in (True, False)
)
dataloader_train, dataloader_test = (
    DataLoader(split_dataset, batch_size=128, shuffle=False)
    for split_dataset in (dataset_train, dataset_test)
)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Load the trained CIFAR-10 ResNet and report its accuracy on both splits.
PATH_MODEL = os.path.join('..', 'results', 'cifar10_resnet_200.pt')
model = Resnet().to(device)
# map_location remaps the saved tensors onto `device`, so a checkpoint saved
# from a GPU still loads when the notebook falls back to the CPU.
model.load_state_dict(torch.load(PATH_MODEL, map_location=device))
# Explicit eval mode; harmless if validate() already switches to it.
model.eval()

for split, loader in (('training', dataloader_train), ('test', dataloader_test)):
    _, acc = validate(model, loader, loss, device)
    print('Accuracy on {} set: {}'.format(split, acc))

Accuracy on training set: 0.97072
Accuracy on test set: 0.8812


In [9]:
# Load the trained CIFAR-10 VGG and report its accuracy on both splits.
PATH_MODEL = os.path.join('..', 'results', 'cifar10_vgg_200.pt')
model = Vgg().to(device)
# map_location remaps the saved tensors onto `device`, so a checkpoint saved
# from a GPU still loads when the notebook falls back to the CPU.
model.load_state_dict(torch.load(PATH_MODEL, map_location=device))
# Explicit eval mode; harmless if validate() already switches to it.
model.eval()

for split, loader in (('training', dataloader_train), ('test', dataloader_test)):
    _, acc = validate(model, loader, loss, device)
    print('Accuracy on {} set: {}'.format(split, acc))

Accuracy on training set: 0.99078
Accuracy on test set: 0.9044
