In [1]:
import torch

from mnist_pytorch_lightning import MNISTDataModule, LitMNIST
import os
from torchmetrics import Accuracy
from pytorch_lightning import seed_everything

seed_everything(0, workers=True)  # seeds torch, numpy and Python's random

# Configuration dict handed to MNISTDataModule and LitMNIST below.
params = {}
params['data_path'] = "."
# NOTE(review): min(0, n) is 0 for every device count, so AVAIL_GPUS is always 0.
# Presumably intentional (calc_accuracy below feeds CPU tensors) -- confirm; the
# usual Lightning template uses min(1, ...).
params['AVAIL_GPUS'] = min(0, torch.cuda.device_count())
params['batch_size'] = 64
params['max_epoch'] = 50
params['lr'] = 2e-4

mnist_data_module = MNISTDataModule(params)
# Prepare the data for both stages; val_dataloader()/test_dataloader() are used later.
mnist_data_module.setup(stage = "fit")
mnist_data_module.setup(stage = "test")

results_dir = 'lightning_logs/version_16'
checkpoint_dir = os.path.join(results_dir, 'checkpoints')
# os.listdir returns entries in arbitrary order; sort (and keep only .ckpt files)
# so the same checkpoint is chosen on every run when several are present.
ckpt_path = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt'))
if not ckpt_path:
    raise FileNotFoundError(f"No .ckpt files found in {checkpoint_dir!r}")
ckpt_file = os.path.join(checkpoint_dir, ckpt_path[0])
# Raw checkpoint dict, kept for inspection. map_location="cpu" lets a checkpoint
# written by a GPU run be opened on a CPU-only machine.
ckpt = torch.load(ckpt_file, map_location="cpu")

network = LitMNIST.load_from_checkpoint(params=params, checkpoint_path=ckpt_file)

def calc_accuracy(dataloader, network):
    """Evaluate ``network`` on every batch of ``dataloader`` and return its accuracy.

    Accuracy is computed twice as a sanity check. One value comes from
    ``torchmetrics.Accuracy``; the other is counted by hand from the
    concatenated predictions. The two must agree to within 1e-3.

    Args:
        dataloader: Iterable yielding ``(x, y)`` batches, where ``y`` holds
            integer class labels (assumed 1-D per batch -- TODO confirm).
        network: Model mapping ``x`` to per-class scores along dim 1. It is
            switched to eval mode and left in eval mode.

    Returns:
        Tuple ``(accuracy, pred_epoch, y_epoch)``:

        - the torchmetrics accuracy as a Python float;
        - the predicted class indices, shape ``(N, 1)``;
        - the concatenated labels.

        Both tensors live on the network's device.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    network.eval()
    # Keep inputs and the metric on the model's device so a GPU-resident network
    # also works; for a CPU model this is a no-op.
    first_param = next(network.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # NOTE(review): torchmetrics >= 0.11 requires Accuracy(task=..., num_classes=...);
    # the bare constructor only works on older versions.
    acc = Accuracy().to(device)
    y_batches, pred_batches = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = network(x)
            # Predicted class = index of the highest score along dim 1 (kept as (batch, 1)).
            pred = out.max(1, keepdim=True)[1]
            y_batches.append(y)
            pred_batches.append(pred)
            acc.update(pred.view(-1), y.view(-1).int())

    if not pred_batches:
        raise ValueError("calc_accuracy: dataloader yielded no batches")
    y_epoch = torch.concat(y_batches)
    pred_epoch = torch.concat(pred_batches)

    metric_accuracy = float(acc.compute())
    manual_accuracy = pred_epoch.eq(y_epoch.view_as(pred_epoch)).sum().item() / y_epoch.shape[0]
    assert abs(metric_accuracy - manual_accuracy) < 0.001, (
        f"torchmetrics accuracy {metric_accuracy} disagrees with manual {manual_accuracy}"
    )
    return metric_accuracy, pred_epoch, y_epoch
# Score the restored model on the validation split (accuracy only) and on the
# test split (keeping the test predictions and labels for later inspection).
val_acc = calc_accuracy(mnist_data_module.val_dataloader(), network)[0]
test_acc, pred_epoch, y_epoch = calc_accuracy(mnist_data_module.test_dataloader(), network)

print(f"Calculated validation accuracy: {val_acc}")
print(f"Calculated test accuracy: {test_acc}")


  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 0
Global seed set to 0


Calculated validation accuracy: 0.9850000143051147
Calculated test accuracy: 0.9740999937057495
