# Push Bayesian Deep Learning Tutorial
## Introduction

In this notebook we will introduce the concept of Bayesian Deep Learning and demonstrate its usage in Push by running a deep ensemble.

## The Posterior Predictive Distribution
The goal of Bayesian Deep Learning (BDL) methods is to estimate the posterior predictive distribution

$$
p(y \mid x, D) = \int p(y \mid x, w) \, p(w \mid D) \, dw
$$

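where $y$ is an output, $x$ is an input, $w$ are the network parameters (weights), and $D$ is the training data. For neural networks this integral is intractable and must be approximated. The typical approximation is Monte Carlo: draw $S$ samples from the posterior and average the corresponding predictive distributions,

$$
p(y \mid x, D) \approx \frac{1}{S} \sum_{s=1}^{S} p(y \mid x, w_s), \qquad w_s \sim p(w \mid D).
$$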
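As a rough illustration of the Monte Carlo approximation, the sketch below averages $p(y \mid x, w_s)$ over $S$ weight samples $w_s$ for a classifier, where a softmax over the logits gives each $p(y \mid x, w_s)$ explicitly. The function name `mc_posterior_predictive` and the randomly initialized `nn.Linear` stand-ins for posterior samples are hypothetical; in practice the $w_s$ would come from a method such as MCMC, variational inference, or a deep ensemble.

```python
import torch
import torch.nn as nn


def mc_posterior_predictive(models, x):
    """Monte Carlo estimate of p(y | x, D) for a classifier.

    Each element of `models` is treated as one posterior sample w_s
    (an assumption: below they are only random stand-ins). softmax(logits)
    gives p(y | x, w_s); the estimate is the average over the S samples.
    """
    with torch.no_grad():
        probs = torch.stack([torch.softmax(m(x), dim=-1) for m in models])  # (S, N, C)
    return probs.mean(dim=0)  # (N, C)


# Hypothetical usage: S = 5 randomly initialized linear classifiers
# stand in for posterior samples (4 input features, 3 classes).
torch.manual_seed(0)
models = [nn.Linear(4, 3) for _ in range(5)]
x = torch.randn(8, 4)
p = mc_posterior_predictive(models, x)
print(p.shape)  # torch.Size([8, 3])
```

Because each term is a proper distribution over classes, their average is one too, so every row of `p` sums to 1.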

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Push's deep-ensemble utilities
import push.bayes.ensemble

# =============================================================================
# Simple Dataset + Neural Network
# =============================================================================

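class RandDataset(Dataset):
    """1280 random inputs of dimension D paired with random scalar targets.

    Inputs and targets are independent noise, so there is nothing to learn;
    the dataset only exercises the training loop.
    """
    def __init__(self, D):
        self.xs = torch.randn(128*10, D)
        self.ys = torch.randn(128*10, 1)

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return self.xs[idx], self.ys[idx]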


class MiniNN(nn.Module):
    """Two D -> D linear layers with a ReLU in between."""
    def __init__(self, D):
        super().__init__()
        self.fc1 = nn.Linear(D, D)
        self.fc2 = nn.Linear(D, D)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        return self.fc2(x)
    

class BiggerNN(nn.Module):
    """A stack of n MiniNN blocks followed by a D -> 1 linear output layer."""
    def __init__(self, n, D):
        super().__init__()
        self.n = n
        # nn.ModuleList registers each block's parameters with the parent module
        self.minis = nn.ModuleList([MiniNN(D) for _ in range(n)])
        self.fc = nn.Linear(D, 1)

    def forward(self, x):
        for mini in self.minis:
            x = mini(x)
        return self.fc(x)






# Train a deep ensemble of `num_ensembles` BiggerNN models with Push.
# The model is passed as a class followed by its constructor arguments (L, D).
L = 10
D = 20
dataset = RandDataset(D)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

epochs = 10
num_ensembles = 3
push.bayes.ensemble.train_deep_ensemble(
    dataloader,
    torch.nn.MSELoss(),
    epochs,
    BiggerNN, L, D,
    num_ensembles=num_ensembles
)
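Conceptually, a deep ensemble trains several copies of the same architecture from different random initializations and treats the trained members like the samples in a Monte Carlo approximation of the posterior predictive. The following is a minimal plain-PyTorch sketch of that idea under stated assumptions, not Push's implementation: the helpers `make_model`, `train_member`, and `ensemble_predict`, the loader `toy_loader`, the Adam optimizer, and the learning rate are all hypothetical choices for illustration. For the regression setting used in this notebook each member outputs a point prediction, so the sketch reports the mean across members as the prediction and the variance across members as a rough measure of model uncertainty.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def make_model(D):
    # Hypothetical stand-in regression network; BiggerNN(L, D) could be used instead.
    return nn.Sequential(nn.Linear(D, D), nn.ReLU(), nn.Linear(D, 1))


def train_member(seed, loader, D, epochs, lr=1e-3):
    """Train one ensemble member from its own random initialization."""
    torch.manual_seed(seed)  # distinct seed -> distinct initial weights and shuffle order
    model = make_model(D)
    optim = torch.optim.Adam(model.parameters(), lr=lr)  # assumed optimizer
    loss_fn = nn.MSELoss()
    for _ in range(epochs):
        for xb, yb in loader:
            optim.zero_grad()
            loss_fn(model(xb), yb).backward()
            optim.step()
    return model


def ensemble_predict(members, x):
    """Mean and (population) variance of the members' predictions."""
    with torch.no_grad():
        preds = torch.stack([m(x) for m in members])  # (K, N, 1)
    mean = preds.mean(dim=0)
    var = ((preds - mean) ** 2).mean(dim=0)
    return mean, var


D, K, epochs = 20, 3, 2
xs, ys = torch.randn(256, D), torch.randn(256, 1)
toy_loader = DataLoader(TensorDataset(xs, ys), batch_size=128, shuffle=True)

members = [train_member(seed, toy_loader, D, epochs) for seed in range(K)]
mean, var = ensemble_predict(members, xs[:8])
print(mean.shape, var.shape)  # torch.Size([8, 1]) torch.Size([8, 1])
```

Every member sees the same data; in this sketch the diversity between members comes only from the per-member seed, which sets both the initial weights and the minibatch order.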

References:
https://cims.nyu.edu/~andrewgw/deepensembles/