### IMPORTS

In [None]:
import Inference.Variational as Variational
import Inference.NN as NN
from Tools.NNtools import *

In [None]:
import torch
from torch import nn
from torch import functional as F
import matplotlib.pyplot as plt
import numpy as np
from livelossplot import PlotLosses

In [None]:
def _log_norm(x, mu, std):
        return -0.5 * torch.log(2*np.pi*std**2) -(0.5 * (1/(std**2))* (x-mu)**2)

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [None]:
# Load the (x, y) training data and move both tensors to the working device.
# map_location remaps any tensors saved from a GPU onto `device`, so the file also
# loads on a CPU-only machine; the .to(device) calls below are then no-ops.
data = torch.load('data/foong_data.pt', map_location=device)
x_data = data[0].to(device)
y_data = data[1].to(device)
# Append a trailing singleton dim to y (presumably so it lines up with the
# network's (…, 1)-shaped output in loglikelihood — confirm).
y_data = y_data.unsqueeze(-1)

In [None]:
# Network used for both inference and plotting; its weights are supplied via
# Net.set_parameters(...) rather than trained directly.
Net = NN.ParallelNN(
    input_size=1,
    output_size=1,
    layer_width=50,
    nb_layers=4,
    device=device,
)

In [None]:
# Number of parameters of Net; used below as the dimension of both the prior and the
# variational distribution q (presumably the length of the flat vector that
# Net.set_parameters expects — confirm in NN.ParallelNN).
param_count = Net.count_parameters()

In [None]:
# Mean-field prior over the network parameters with sigma=1.0.
prior = Variational.MeanFieldVariationalDistribution(
    param_count, sigma=1.0, device=device
)
# Freeze the prior: its mu and rho must not receive gradients.
prior.mu.requires_grad = prior.rho.requires_grad = False

In [None]:
def logprior(theta):
    """Return the log-density of `theta` under the module-level `prior` distribution."""
    log_density = prior.log_prob(theta)
    return log_density

In [None]:
def loglikelihood(theta, sigma=0.1):
    """Return the Gaussian log-likelihood of the training data given network parameters.

    Loads `theta` into the module-level `Net`, predicts on `x_data`, and scores the
    predictions against `y_data` with a Gaussian of fixed standard deviation `sigma`.

    Args:
        theta: parameter sample(s) passed to `Net.set_parameters`; presumably shaped
            (n_samples, param_count) — TODO confirm against NN.ParallelNN.
        sigma: observation-noise standard deviation as a Python float (default 0.1).

    Returns:
        The per-point log-likelihood terms summed over dims 1 and 2, with a trailing
        singleton dim appended.
    """
    Net.set_parameters(theta)
    # Evaluate the same network whose parameters were just set.
    y_pred = Net.forward(x_data)
    L = _log_norm(y_pred, y_data, torch.tensor([sigma], device=device))
    return torch.sum(L, dim=[1, 2]).unsqueeze(-1)

In [None]:
def logposterior(theta):
    """Return the unnormalised log-posterior of `theta`: log prior + log-likelihood.

    This is the target density passed to the variational optimizer's `run`.
    """
    log_p_prior = logprior(theta)
    log_p_data = loglikelihood(theta)
    return log_p_prior + log_p_data

In [None]:
# Variational approximation q over the network parameters, starting with a tiny spread.
q = Variational.MeanFieldVariationalDistribution(
    param_count, sigma=0.0001, device=device
)
# Initialise q's mean at the first row of a prior sample. It is detached and cloned so
# it is a fresh trainable leaf that shares no storage or graph with the prior.
q.mu = nn.Parameter(
    prior.sample()[0, :].detach().clone().to(device),
    requires_grad=True,
)

In [None]:
# Optimizer for q. patience/factor presumably drive a reduce-on-plateau learning-rate
# schedule (confirm in Inference.Variational).
vo = Variational.VariationalOptimizer(
    learning_rate=0.01,
    patience=20,
    factor=0.7,
    device=device,
)

In [None]:
# Fit q to the log-posterior: 1000 epochs, n_ELBO_samples=100, with live plotting.
vo.run(
    q,
    logposterior,
    n_epoch=1000,
    n_ELBO_samples=100,
    plot=True,
)

In [None]:
# 100 evenly spaced test inputs on [-2, 2], shaped (100, 1) for the single input feature.
# `steps` must be given explicitly in current PyTorch; 100 was the old implicit default.
x_test = torch.linspace(-2.0, 2.0, steps=100).unsqueeze(1).to(device)

In [None]:
# Draw 1000 networks from q and overlay their predictions with low alpha, so the
# overlapping curves read as a smooth predictive distribution over functions.
fig, ax = plt.subplots()
fig.set_size_inches(11.7, 8.27)
ax.scatter(x_data.cpu(), y_data.cpu())
# Loop-invariant: convert the test inputs to numpy once.
x_plot = x_test.detach().cpu().numpy()
# Inference only: no_grad stops autograd from recording 1000 throwaway graphs.
with torch.no_grad():
    for _ in range(1000):
        z = q.sample()
        Net.set_parameters(z)

        y_test = Net.forward(x_test)
        ax.plot(x_plot, y_test.squeeze(0).detach().cpu().numpy(), alpha=0.05, linewidth=1, color='lightblue')