In [1]:
import os
os.chdir('..')

In [2]:
from torchvision.datasets import MNIST
import torch
import torch.nn as nn
import torchvision.transforms as TF
from tqdm.auto import tqdm
from NetBayesianization import wrap, api

In [3]:
train_ds = MNIST(root='./.cache', train=True, download=True, 
                 transform=TF.ToTensor()) 
test_ds = MNIST(root='./.cache', train=False, download=False, 
                transform=TF.ToTensor())

In [15]:
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

In [16]:
num_classes = 10

In [17]:
model = nn.Sequential(nn.Flatten() ,nn.Linear(28*28, 128), 
                       nn.ReLU(), nn.Linear(128, 64), 
                       nn.ReLU(), nn.Linear(64, num_classes), nn.Softmax(dim=1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
criterion = nn.CrossEntropyLoss()
images, labels = next(iter(train_dl))
images = images.view(images.shape[0], -1)
logps = model(images)
loss = criterion(logps, labels)

In [18]:
def testAccuracy():
    
    model.eval()
    accuracy = 0.0
    total = 0.0
    
    with torch.no_grad():
        for data in test_dl:
            images, labels = data
            # run the model on the test set to predict labels
            outputs = model(images)
            # the label with the highest energy will be our prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            accuracy += (predicted == labels).sum().item()
    
    # compute the accuracy over all test images
    accuracy = (100 * accuracy / total)
    return(accuracy)

In [19]:
# train
n_epochs = 5
best_accuracy = 0.0
for e in range(n_epochs):
    running_loss = 0
    for i, (images, labels) in enumerate(train_dl, 0):
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        # Let's print statistics for every 1,000 images
        running_loss += loss.item()     # extract the loss value
        if i % 1000 == 999:    
            # print every 1000 (twice per epoch) 
            print('[%d, %5d] loss: %.3f' %
                  (e + 1, i + 1, running_loss / 1000))
            # zero the loss
            running_loss = 0.0
    
    accuracy = testAccuracy()
    print('For epoch', e + 1,'the test accuracy over the whole test set is %d %%' % (accuracy))

    if accuracy > best_accuracy:
        best_accuracy = accuracy

[1,  1000] loss: 2.296
For epoch 1 the test accuracy over the whole test set is 56 %
[2,  1000] loss: 1.858
For epoch 2 the test accuracy over the whole test set is 74 %
[3,  1000] loss: 1.714
For epoch 3 the test accuracy over the whole test set is 80 %
[4,  1000] loss: 1.658
For epoch 4 the test accuracy over the whole test set is 89 %
[5,  1000] loss: 1.576
For epoch 5 the test accuracy over the whole test set is 91 %


In [20]:
from torchmetrics.classification import MulticlassAccuracy
metric = MulticlassAccuracy(num_classes=10)

In [43]:
with torch.no_grad():
    n = 0 
    test_acc = 0 
    for batch_idx, (inputs, targets) in enumerate(test_dl):
        outputs = model(inputs)
        test_acc += metric(outputs, targets)
        n += 1
    
    test_acc = test_acc / n
    print(f'the test accuracy over the whole test set is {round((100 * test_acc.numpy()), 2)}%')

the test accuracy over the whole test set is 88.87%


In [9]:
# correct_count, all_count = 0, 0
# for images,labels in test_dl:
#   for i in range(len(labels)):
#     img = images[i].view(1, 784)
#     with torch.no_grad():
#         logps = model(img)

#     ps = torch.exp(logps)
#     probab = list(ps.numpy()[0])
#     pred_label = probab.index(max(probab))
#     true_label = labels.numpy()[i]
#     if(true_label == pred_label):
#       correct_count += 1
#     all_count += 1

# print("Number Of Images Tested =", all_count)
# print("\nModel Accuracy =", (correct_count/all_count))

Number Of Images Tested = 10000

Model Accuracy = 0.9067


In [40]:
# build classical bayesian model
from torchmetrics.classification import MulticlassAccuracy
metric = MulticlassAccuracy(num_classes=10)
bayes_model = api.BasicBayesianWrapper(model, 'basic', 0.05, None, None)

In [44]:
n_iter = 5
with torch.no_grad():
    n = 0 
    test_acc = 0 
    for batch_idx, (inputs, targets) in enumerate(test_dl):
        test = bayes_model.predict(inputs, n_iter)
        test_acc += metric(test['mean'], targets)
        n += 1
    
    test_acc = test_acc / n
    print(f'the test accuracy over the whole test set is {round((100 * test_acc.numpy()), 2)}%')

the test accuracy over the whole test set is 88.71%


In [12]:
# build bayesian model with beta distibution
bayes_model = api.BasicBayesianWrapper(model, 'beta', None, 0.2, 4.0)

In [45]:
n_iter = 5
with torch.no_grad():
    n = 0 
    test_acc = 0 
    for batch_idx, (inputs, targets) in enumerate(test_dl):
        test = bayes_model.predict(inputs, n_iter)
        test_acc += metric(test['mean'], targets)
        n += 1
    
    test_acc = test_acc / n
    print(f'the test accuracy over the whole test set is {round((100 * test_acc.numpy()), 2)}%')

the test accuracy over the whole test set is 88.33%
