In [2]:
import torch
from torch import nn

from lenet_5 import LeNet5_5
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import numpy as np

In [4]:
# Load data from MNIST
BATCH_SIZE = 256
BATCH_TEST_SIZE = 1024
data_train = MNIST('./data/mnist',
                   download=True,
                   transform=transforms.Compose([
                       transforms.Resize((32, 32)),
                       transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
                  train=False,
                  download=True,
                  transform=transforms.Compose([
                      transforms.Resize((32, 32)),
                      transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size = BATCH_SIZE , shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test,  batch_size = BATCH_TEST_SIZE, num_workers=8)
data_test_loader2 = DataLoader(data_test,  batch_size = 1, num_workers=0)

TRAIN_SIZE = len(data_train_loader.dataset)
TEST_SIZE = len(data_test_loader.dataset)
NUM_BATCHES = len(data_train_loader)
NUM_TEST_BATCHES = len(data_test_loader)

In [8]:
# Load pre-trained model
model_loaded = LeNet5_5()
model_loaded.load_state_dict(torch.load("./LeNet-saved-5"))
criterion = nn.NLLLoss()

In [10]:
# Validate
def validate (net, criterion):
    net.eval()
    total_correct = 0
    avg_loss = 0.0
    for i, (images, labels) in enumerate(data_test_loader):
        labels = (labels > 5).long()
        output = net(images)
        avg_loss += criterion(output, labels).sum()
        pred = output.detach().max(1)[1]
        total_correct += pred.eq(labels.view_as(pred)).sum()

    avg_loss /= len(data_test)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test)))
    return 

In [12]:
# Validate checks accuracy of the model
validate (model_loaded, criterion)

Test Avg. Loss: 0.000023, Accuracy: 0.992900


In [None]:
imgs = []
intermediate_acts = []
total_correct = 0

model_loaded.eval()
for i, (images, labele) in enumerate(data_test_loader2):
    imgs.append(((np.reshape(np.squeeze(images.detach().numpy()), (1,-1)) )))
    x = images
    x = model_loaded.convnet(x)
    
    intermediate_acts.append(((np.reshape(np.squeeze(x.detach().numpy()), (1,-1)) )))
    
    np.save("images", np.array(imgs).squeeze(1))
    np.save("intermediate_act", np.array(intermediate_acts).squeeze(1))

In [None]:
import sys
sys.path.insert(1, './ite-repo')
import ite