In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import linalg as LA
import torch.optim as optim
from torchvision import datasets, transforms
from types import SimpleNamespace
import matplotlib.pyplot as plt
import numpy as np
from models import NeuralNet
from loss import SmoothSailing, kappa

In [2]:
# Notebook-wide settings, collected in one namespace so later cells can read
# them as attributes (e.g. config.batch_size, config.noise_level).
_settings = {
    'batch_size': 32,
    'test_batch_size': 1000,
    'epochs': 10,
    'lr': 0.0001,
    'momentum': 0.5,
    'seed': 1,
    'log_interval': 100,
    'noise_level': 0,
    'beta': 0.01,
}
config = SimpleNamespace(**_settings)

# Seed PyTorch's default RNG from the configured seed.
torch.manual_seed(config.seed)

# Use the GPU when one is available, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

In [3]:
# Extra DataLoader options are only passed when CUDA is available;
# on CPU the loaders fall back to their defaults.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else dict()

# Both splits use the same preprocessing: convert to a tensor, then
# normalise with mean 0.1307 and std 0.3081.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split (downloaded into ../data if it is not there yet).
train_set = datasets.MNIST('../data', train=True, download=True,
                           transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=config.batch_size, shuffle=True, **kwargs)

# Test split; read from the same ../data root, without requesting a download.
test_set = datasets.MNIST('../data', train=False, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=config.test_batch_size, shuffle=True, **kwargs)

In [5]:
# Build the network and place it on the selected device.
model = NeuralNet(size=2048)
model = model.to(device)

# Read the saved parameters onto the CPU first; load_state_dict then copies
# them into the model's own tensors.
# NOTE(review): torch.load unpickles the file, so 'model.pt' must come from a
# trusted source.
m = torch.load('model.pt', map_location='cpu')
model.load_state_dict(m)

In [None]:
# Measure classification accuracy on the test loader, averaged over
# `num_repeats` passes. Each pass perturbs the inputs with Gaussian noise
# scaled by config.noise_level, so repeated passes sample different noise.
model.eval()

acc = 0  # running sum of per-pass accuracies, in percent

num_repeats = 100

with torch.no_grad():
    for i in range(num_repeats):
        # Count over *every* batch of the pass. Previously only the last
        # batch's count survived the loop, so each pass scored a single batch.
        correct = 0
        total = 0

        for data, target in test_loader:
            # add noise (out of place, so the loader's batch is not mutated)
            data = data + config.noise_level * torch.randn_like(data)
            # flatten each 28x28 image into a 784-element row
            data = data.view(-1, 28*28)
            data, target = data.to(device), target.to(device)

            output = model(data)

            # predicted class = index of the largest output per row
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            # use the real batch size, which stays correct for a short last batch
            total += target.size(0)

        acc += correct / total * 100

print(f"Accuracy: {acc/num_repeats:.2f}%")

# save the results
with open('cond_class.txt', 'w') as f:
    f.write(f"\tAccuracy: {acc/num_repeats:.2f}%\n")
