In [6]:
import torch
import pyro
import torchvision
import preprocessing
from torch.utils.data import random_split

from pyro.infer.autoguide import AutoDiagonalNormal, AutoMultivariateNormal
from pyro.infer import SVI, Trace_ELBO, Predictive
import bcnn

In [2]:
# Paths to the Sign Language MNIST CSVs, relative to the notebook's working directory.
train_path = r'data/sign_mnist_train.csv'
test_path = r'data/sign_mnist_test.csv'

# Seed for the train/validation split and the global RNGs (torch, numpy, random)
# so that re-running the notebook reproduces the same split and batches.
SEED = 0
pyro.set_rng_seed(SEED)

# Training-time augmentation only. Random flips/rotations must not be applied to
# validation or test images, otherwise every evaluation pass sees different
# inputs and the reported metrics become noisy and non-reproducible.
images_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(10),
])

# Deterministic pipeline for validation/test: tensor conversion only.
eval_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# NOTE(review): currently unused -- both datasets pass target_transform=None.
labels_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

train_data_full = preprocessing.ASLDataset(
    csv_file=train_path,
    transform=images_transforms,
    target_transform=None
    )

# A second, non-augmented view of the same training CSV. The validation subset
# draws its rows from this view so validation images are not randomly altered.
# (Loads the training CSV a second time.)
val_source = preprocessing.ASLDataset(
    csv_file=train_path,
    transform=eval_transforms,
    target_transform=None
    )

test_data = preprocessing.ASLDataset(
    csv_file=test_path,
    transform=eval_transforms,
    target_transform=None
    )

# Number of training-CSV rows held out for validation.
val_size = 7455
train_size = len(train_data_full) - val_size
assert train_size > 0, "val_size must be smaller than the training set"

# Same semantics as random_split (one random permutation, first train_size
# indices for training, the rest for validation), but the indices are shared
# between the augmented and non-augmented views, and the split is seeded.
split_generator = torch.Generator().manual_seed(SEED)
permutation = torch.randperm(len(train_data_full), generator=split_generator).tolist()
train_data = torch.utils.data.Subset(train_data_full, permutation[:train_size])
val_data = torch.utils.data.Subset(val_source, permutation[train_size:])

print('train_data:', len(train_data))
print('val_data:', len(val_data))
print('test_data:', len(test_data))

batch_size = 64

train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
# Evaluation results do not depend on order, so keep validation/test deterministic.
valid_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

train_data: 20000
val_data: 7455
test_data: 7172


In [3]:
# Variational-inference setup: the project's BCNN model, a diagonal-Gaussian
# autoguide over its latent sample sites, and SVI optimising Trace_ELBO with Adam.
# NOTE(review): assumes bcnn.BCNN() is a Pyro model callable as
# model(images, labels), as the training cell uses it -- confirm.
LEARNING_RATE = 0.03

model = bcnn.BCNN()
guide = AutoDiagonalNormal(model)
optim = pyro.optim.Adam(dict(lr=LEARNING_RATE))
svi = SVI(model, guide, optim, loss=Trace_ELBO())

In [4]:
# Drop any parameters left in Pyro's global param store by earlier runs.
pyro.clear_param_store()

num_iterations = 2
loss = 0
# The per-example normaliser does not change between epochs; compute it once.
normalizer_train = len(train_loader.dataset)

for j in range(num_iterations):
    # Accumulate the loss returned by each SVI step over one full pass.
    loss = 0
    for images, labels in train_loader:
        loss += svi.step(images, labels)  # one gradient step on this mini-batch
    total_epoch_loss_train = loss / normalizer_train

    print("Epoch ", j, " Loss ", total_epoch_loss_train)

Epoch  0  Loss  17182.23447291355
Epoch  1  Loss  6735.6542171504025


In [5]:
# Posterior-predictive smoke check on one held-out batch.
# Predictive draws `num_samples` latent values from the fitted guide and runs
# the model once per draw. Doing that for every training batch and printing the
# full sample tensors is slow, floods the output, and evaluates on data the
# model was trained on. Summarise a single test batch instead; loop over
# test_loader here once a full evaluation is needed.
predictive = Predictive(model, guide=guide, num_samples=500)

images, labels = next(iter(test_loader))
with torch.no_grad():  # sampling only, no gradients required
    preds = predictive(images)

# Shape of the samples returned for each site (leading dim = num_samples).
{site: tuple(samples.shape) for site, samples in preds.items()}

NameError: name 'Predictive' is not defined

In [None]:
# NOTE(review): disabled draft of a test-set accuracy evaluation. It sits inside a
# bare string literal, so none of it executes. Known problems to fix before
# turning it back into code:
#   - `guide(None, None)` results are used as networks (`model(images)`), but the
#     draft itself treats them as dict-like (`.keys()`), so calling them fails;
#   - `predict` returns a NumPy array (np.argmax) that is compared with the
#     DataLoader's `labels`; convert both to the same type first;
#   - `dist`, `net` and `PyroSample` are not imported/defined in this notebook;
#   - `for model in sampled_models` would rebind the global `model`, which the
#     later `Predictive(model, ...)` line then receives.
'''

import numpy as np

sampled_models = [guide(None, None) for _ in range(10)]
print(type(sampled_models))
print(len(sampled_models))
print(list(sampled_models[3].keys()))
#print(sampled_models[2])

for j, (images, labels) in enumerate(test_loader):
    for model in sampled_models:
        print(type(model))
        out = model(images)
        print(out)


num_samples = 10
def predict(x):
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [model(x) for model in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print("accuracy: %d %%" % (100 * correct / total))


from pyro.infer import Predictive

guide.requires_grad_(False)
predictive = Predictive(model, guide=guide, num_samples=1000, return_sites=("linear.weight", "obs", "_RETURN"))

distribution = dist.Normal(torch.zeros(1, device='cpu'), torch.ones(1, device='cpu'))

for module_name, module in net.named_modules():
    for param_name, param in list(module.named_parameters(recurse=False)):
        full_name = module_name + "." + param_name
        #print('module:', module)
        #print('param_name:', param_name)
        #print('Full_name:', full_name)
        #print('param.shape:', param.shape)
        #print('param.dim():', param.dim())
        prior_dist = distribution.expand(param.shape).to_event(param.dim())
        #print(prior_dist)
        setattr(module, param_name, PyroSample(prior_dist))

#print(torchsummary.summary(net, (1, 28, 28)))
'''
# NOTE(review): looks like leftover debug output. If it is removed, the string
# above becomes the cell's last expression and IPython will echo it as output;
# delete the draft or end it with `;` at the same time.
print('hello')