In [None]:
import math 

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pylab as plt
from sklearn import datasets

In [None]:
# Toy 1-D regression problem: 40 noisy observations of sin(x), x ~ U(-3, 3).
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same data and fit
X = torch.empty(40, 1).uniform_(-3, 3)
y = torch.sin(X)
y = y + torch.randn_like(y) * .1  # additive Gaussian noise with std 0.1

# Standardise the inputs to zero mean and unit standard deviation.
# The targets above were computed from the raw inputs and are left unscaled.
X = (X - X.mean(0)) / X.std(0)
train_ds = torch.utils.data.TensorDataset(X, y)

plt.scatter(X.squeeze(), y.squeeze())
plt.show()

In [None]:
from dal_toolbox.models.utils.variational_inference import BayesianLinear, BayesianConv2d

class Net(nn.Module):
    """Two-layer Bayesian MLP for scalar regression (1 -> 50 -> 1) with a tanh hidden activation.

    Args:
        prior_sigma: Forwarded unchanged to both ``BayesianLinear`` layers
            (presumably the standard deviation of the weight prior -- confirm
            against ``BayesianLinear``).
    """

    def __init__(self, prior_sigma=1) -> None:
        super().__init__()
        # Attribute names are kept as-is (including the l1 -> l3 gap) so that
        # parameter names stay stable.
        self.l1 = BayesianLinear(1, 50, prior_sigma=prior_sigma)
        self.l3 = BayesianLinear(50, 1, prior_sigma=prior_sigma)
        self.act = nn.Tanh()

    def forward(self, x):
        """Apply ``l1``, the tanh activation, then ``l3`` and return the result."""
        hidden = self.act(self.l1(x))
        return self.l3(hidden)

In [None]:
from dal_toolbox.models.variational_inference.trainer import VITrainer

# Fit the Bayesian MLP to the toy regression data with the VI trainer.
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)

model = Net(prior_sigma=1)
optimizer = torch.optim.SGD(model.parameters(), lr=.1, momentum=0.9)

trainer = VITrainer(model, optimizer, criterion=nn.MSELoss(), kl_temperature=1)
train_stats = trainer.train(1000, train_loader)  # 1000 is presumably the epoch count -- confirm
train_history = train_stats['train_history']

# Evaluation grid reaching 2 units beyond the observed inputs on either side.
axis = torch.linspace(X.min()-2, X.max()+2, 101)
grid_inputs = axis.view(-1, 1)

# Run the model 100 times on the same grid and stack the curves column-wise.
# Presumably every forward pass samples fresh weights, so the spread of the
# curves visualises predictive uncertainty -- confirm against BayesianLinear.
with torch.no_grad():
    sampled_curves = [model(grid_inputs).squeeze() for _ in range(100)]
    axis_pred = torch.stack(sampled_curves, dim=-1)

fig = plt.figure(figsize=(20, 5))

# Panel 1: training data with the sampled prediction curves on top.
ax_fit = fig.add_subplot(1, 4, 1)
ax_fit.scatter(X, y)
ax_fit.plot(axis, axis_pred, color='red', alpha=.2)

# Panels 2-4: one training curve each, given as (history key, legend label, colour).
curve_specs = [
    ('train_loss', 'Total', 'green'),
    ('train_nll', 'NLL', 'red'),
    ('train_kl_loss', 'KL', 'blue'),
]
for panel_idx, (history_key, curve_label, curve_color) in enumerate(curve_specs, start=2):
    ax_curve = fig.add_subplot(1, 4, panel_idx)
    ax_curve.plot([d[history_key] for d in train_history], label=curve_label, color=curve_color)

fig.legend(loc='upper center')
plt.show()


In [None]:
class BayesianCNN(nn.Module):
    """Small Bayesian CNN: two strided 5x5 convolutions followed by a linear classifier.

    A ReLU follows each convolution. Without it the network is only a
    composition of linear layers, so stacking them adds no expressive power.

    Args:
        prior_sigma: Forwarded unchanged to every Bayesian layer (presumably
            the standard deviation of the weight prior -- confirm against the
            layer implementations).
    """

    def __init__(self, prior_sigma=1) -> None:
        super().__init__()
        self.conv1 = BayesianConv2d(1, 16, kernel_size=5, stride=2, prior_sigma=prior_sigma)
        self.conv2 = BayesianConv2d(16, 32, kernel_size=5, stride=2, prior_sigma=prior_sigma)
        # 512 input features = 32 channels * 4 * 4 for 1x28x28 inputs, assuming
        # BayesianConv2d follows nn.Conv2d's unpadded output-size rule -- confirm.
        self.l1 = BayesianLinear(512, 10, prior_sigma=prior_sigma)
        # Parameter-free, so existing parameter names are unaffected.
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the 10 class logits for a batch of single-channel images."""
        out = self.act(self.conv1(x))
        out = self.act(self.conv2(out))
        out = torch.flatten(out, start_dim=1)
        out = self.l1(out)
        return out
        
class CNN(nn.Module):
    """Deterministic CNN: two strided 5x5 convolutions followed by a linear classifier.

    A ReLU follows each convolution. Without it, conv -> conv -> linear
    collapses into a single affine map of the input pixels.
    """

    def __init__(self) -> None:
        super().__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        # 1x28x28 -> 16x12x12 -> 32x4x4, i.e. 512 features after flattening.
        self.l1 = nn.Linear(512, 10)
        # Parameter-free, so the state_dict keys are unchanged.
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the 10 class logits for a batch of 1x28x28 images."""
        out = self.act(self.conv1(x))
        out = self.act(self.conv2(out))
        out = torch.flatten(out, start_dim=1)
        out = self.l1(out)
        return out
        

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# MNIST train / test splits, downloaded to /tmp on first use; ToTensor turns
# the PIL images into tensors.
train_ds = MNIST('/tmp', train=True, transform=ToTensor(), download=True)
test_ds = MNIST('/tmp', train=False, transform=ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=256)

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end to end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = BayesianCNN(prior_sigma=1)
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=.1, momentum=0.9)
# Pass the same `device` the model was moved to, not a second hard-coded literal.
trainer = VITrainer(model, optimizer, criterion=nn.CrossEntropyLoss(), kl_temperature=1, device=device)
train_stats = trainer.train(10, train_loader)  # 10 is presumably the epoch count -- confirm
train_history = train_stats['train_history']

# Bind the new figure to `fig`, so the legend below lands on THIS figure and
# not on a stale `fig` left over from an earlier cell.
fig = plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.plot([d['train_loss'] for d in train_history], label='Total', color='green')
plt.subplot(132)
plt.plot([d['train_nll'] for d in train_history], label='NLL', color='red')
plt.subplot(133)
plt.plot([d['train_kl_loss'] for d in train_history], label='KL', color='blue')
fig.legend(loc='upper center')
plt.show()

In [None]:
# Run the trainer's evaluation on the MNIST test loader. Being the last
# expression of the cell, whatever `evaluate` returns is rendered as the cell
# output (presumably a collection of test metrics -- confirm against VITrainer).
trainer.evaluate(test_loader)