In [3]:
import torch
import torch.nn.functional as F

import torchsde
import math
import matplotlib.pyplot as plt

import numpy as np

from tqdm.notebook import tqdm
# from torch import datasets

from torch import _vmap_internals
from torchvision import datasets, transforms
from functorch import vmap
# import torch.nn.functional as F

import pandas as pd

In [4]:
from cfollmer.objectives import log_g, relative_entropy_control_cost, stl_relative_entropy_control_cost_xu
from cfollmer.sampler_utils import FollmerSDE
from cfollmer.drifts import *
from cfollmer.trainers import basic_batched_trainer

functorch succesfully imported


# The Model

\begin{align}
\theta &\sim \mathcal{N}(\theta | 0, \sigma_w^2 \mathbb{I}) \\
y_i | x_i, \theta &\sim  \mathrm{Cat}\left[\mathrm{NN}_{\theta}\left(x_i \right)\right]
\end{align}

We want samples from $p(\theta | \{(y_i, x_i)\})$. Note $f(x; \theta)$ is a neural net with params $\theta$

## Loading the iris dataset

In [5]:
images_train = datasets.MNIST("../data/mnist/", download=True, train=True)
images_test = datasets.MNIST("../data/mnist/", download=True, train=False)

transform = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081)))

In [6]:
X_train, y_train = images_train.data, images_train.targets
X_test, y_test = images_test.data, images_test.targets

X_train = torch.flatten(transform(X_train.float()), 1)
X_test = torch.flatten(transform(X_test.float()), 1)

y_train = F.one_hot(y_train)
y_test = F.one_hot(y_test)

# X_train = np.concatenate((X_train, np.ones((X_train.shape[0],X_train.shape[1]))), axis=1)
# X_test = np.concatenate((X_test, np.ones((X_test.shape[0],X_train.shape[1]))), axis=1)

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

X_train, X_test, y_train, y_test = \
    torch.tensor(X_train, dtype=torch.float32, device=device), \
    torch.tensor(X_test, dtype=torch.float32, device=device), \
    torch.tensor(y_train, dtype=torch.float32, device=device), \
    torch.tensor(y_test, dtype=torch.float32, device=device) 

  torch.tensor(X_train, dtype=torch.float32, device=device), \
  torch.tensor(X_test, dtype=torch.float32, device=device), \
  torch.tensor(y_train, dtype=torch.float32, device=device), \
  torch.tensor(y_test, dtype=torch.float32, device=device)


In [8]:
X_train.shape

torch.Size([60000, 784])

$$\DeclareMathOperator*{\argmin}{arg\,min}$$
$$\def\E{{\mathbb{E}}}$$
$$\def\rvu{{\mathbf{u}}}$$
$$\def\rvTheta{{\bm{\Theta}}}$$
$$\def\gU{{\mathcal{U}}}$$
$$\def\mX{{\mathbf{X}}}$$

## Controlled Schrodinger Follmer Sampler

The objevtive we are trying to implement is:

\begin{align}
  \mathbf{u}_t^{*}=  \argmin_{\rvu_t \in \mathcal{U}}\mathbb{E}\left[\frac{1}{2\gamma}\int_0^1||\rvu(t, \Theta_t)||^2 dt - \ln\left(\frac{ p(\mX | \Theta_1)p(\Theta_1)}{\mathcal{N}(\Theta_1|\mathbf{0}, \gamma \mathbb{I} )}\right)\right] \
\end{align}

Where:
\begin{align}
d\Theta_t = \rvu(t, \Theta_t)dt + \sqrt{\gamma} dB_t
\end{align}

To do so we use the EM discretisation.

In [9]:
import torch.nn.functional as F


class ClassificationNetwork(object):
    
    def __init__(
        self, input_dim=1, output_dim=1, depth=None,
        width=20, width_seq=None, device="cpu", activation=F.relu
    ):
        
        self.device = device
        self.output_dim = output_dim
        self.input_dim = input_dim 
        self.activation = activation
        
        self.depth = depth
        if not self.depth:
            self.depth = 1
        if not width_seq:
            self.width = width
            self.width_seq = [self.width] * (self.depth + 1)
            self.shapes = [(self.width_seq[i-1], self.width_seq[i])  for i in range(1,self.depth)]
            self.shapes += [(self.width_seq[-1], self.output_dim)]
            self.shapes = [(self.input_dim, self.width_seq[0])] + self.shapes
        
        self.dim = sum([wx * wy + wy for wx, wy in self.shapes])
        
    def forward(self, x, Θ):
        index = 0
        n, d = x.shape

        for wx, wy in self.shapes[:-1]:
            x = F.linear(
                x,
                Θ[index: index + wx * wy].reshape(wy, wx),
                Θ[index + wx * wy: index + wx * wy + wy].reshape(1,wy)
            )
            x = self.activation(x)
            index += wx * wy  + wy
        wx, wy = self.shapes[-1]
        x = F.linear(
            x,
            Θ[index: index + wx * wy].reshape(wy, wx), #* σ_Θ + μ_Θ,
            Θ[index + wx * wy: index + wx * wy + wy].reshape(1,wy) # * σ_Θ + μ_Θ
        )
        return x.to(self.device)
    
    def map_forward(self, x, Θ):
        preds_func = lambda θ: self.forward(x, θ)
        batched_preds = vmap(preds_func)
        vmapped = batched_preds(preds_func, Θ)
        preds = torch.hstack(vmapped)
#         preds = torch.hstack(list(map(preds_func, Θ)))
        return preds

In [10]:
dim = X_train.shape[1]
out_dim = y_train.shape[1]

net = ClassificationNetwork(
    dim, out_dim, device=device, depth=1, width=50, activation=F.tanh
)


def gaussian_prior(Θ, σ_w=3.8):
    """
    Logistic regresion bayesian prior
    """
    return -0.5 * (Θ**2).sum(axis=1) / σ_w


def log_likelihood_vmap_nn(Θ, X, y, net=net):
    """
    Hoping this implementation is less buggy / faster
    
    still feels a bit slow.
    """
    
    def loss(θ):
        preds = net.forward(X, θ)
        cel = torch.nn.CrossEntropyLoss(reduction="sum")
#         import pdb; pdb.set_trace()
        ll_cel = -1.0 * cel(preds, y.argmax(dim=1))
        return ll_cel
    
    batched_loss =  vmap(loss)

    return batched_loss(Θ)

In [11]:
net.dim

39760

In [None]:
class SimpleForwardNetBN_larger(AbstractDrift):

    def __init__(self, input_dim=1, width=300, activation=torch.nn.Softplus):
        super(SimpleForwardNetBN_larger, self).__init__()
        
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(input_dim + 1, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, input_dim )
        )
        
        self.nn[-1].weight.data.fill_(0.0)


γ =  0.1**2
Δt=0.01

dim= net.dim

prior = gaussian_prior

sde, losses = basic_batched_trainer(
    γ, Δt, prior, log_likelihood_vmap_nn, dim, X_train, y_train,
    method="euler", stl="stl_xu", adjoint=False, optimizer=None,
    num_steps=79, batch_size_data=int(X_train.shape[0] // 5), batch_size_Θ=30,
    batchnorm=True, device=device, lr=0.0001, drift=SimpleForwardNetBN_larger, schedule="uniform",
    γ_min= 0.1**2, γ_max= 0.4**2
)

  0%|          | 0/79 [00:00<?, ?it/s]

  return torch.batch_norm(
  return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)


2.411092519760132
2.225189447402954
1.966545820236206
1.6961148977279663
1.3891446590423584
1.2065753936767578
1.0401029586791992
0.8705145716667175
0.7646529674530029
0.6692946553230286
0.615804135799408
0.5522686839103699
0.4932783246040344
0.4630196988582611
0.4415580928325653
0.4189043641090393
0.3944348096847534
0.3844749629497528
0.38448360562324524
0.3780539929866791
0.3753856122493744
0.3657078444957733
0.3354945182800293
0.3230648934841156
0.33025792241096497
0.3336712718009949
0.3149867355823517
0.30683112144470215
0.3031422197818756
0.3259376585483551
0.30745261907577515
0.3041743338108063
0.30191516876220703
0.30404382944107056
0.28975051641464233
0.2922874689102173
0.29828792810440063
0.3024911880493164
0.3197266459465027
0.2959238886833191
0.28482943773269653
0.28801682591438293
0.3185204565525055
0.2715175151824951
0.2924477458000183
0.2989695370197296
0.281375527381897
0.27676498889923096
0.29948675632476807
0.27484557032585144
0.265476793050766
0.2891625165939331
0.296

In [None]:
losses

In [None]:
plt.plot(losses[:])

In [None]:
X_train.shape

In [None]:
t_size = int(math.ceil(1.0/Δt))
ts = torch.linspace(0, 1, t_size).to(device)
no_posterior_samples = 100
Θ_0 = torch.zeros((no_posterior_samples, net.dim)).to(device)

Θ_1 = torchsde.sdeint(sde, Θ_0, ts, dt=Δt)[-1,...]

In [None]:
fig, (ax1,ax2,ax3) = plt.subplots(1,3)

ax1.hist(Θ_1[:,0].cpu().detach().numpy())
ax2.hist(Θ_1[:,1].cpu().detach().numpy())
ax3.hist(Θ_1[:,2].cpu().detach().numpy())

In [None]:
def predc(X, Θ):
    return torch.vstack([(net.forward(X, θ)[None,...]).softmax(dim=-1) for θ in Θ]).mean(dim=0)

In [None]:
pred = predc(X_train, Θ_1)

In [None]:
pred.shape

In [None]:

((pred.argmax(dim=-1)).float().flatten()== y_train.argmax(dim=-1)).float().mean()

In [None]:
pred_test = predc(X_test.float(), Θ_1)

In [None]:
((pred_test.argmax(dim=-1)).float().flatten()== y_test.argmax(dim=-1)).float().mean()

## MAP Baseline

We run the point estimate approximation (Maximum a posteriori) to double check what the learned weights look like.  We get the  exact same training accuracy as with the controlled model and similarly large weights for the non bias weights. 

In [None]:
Θ_map = torch.zeros((1, dim), requires_grad=True, device=device)
optimizer_map = torch.optim.Adam([Θ_map], lr=0.05)
#     optimizer = torch.optim.LBFGS(gpr.parameters(), lr=0.01)

losses_map = []
num_steps = 1000
for i in tqdm(range(num_steps)):
    optimizer_map.zero_grad()

    if isinstance(optimizer_map, torch.optim.LBFGS):
        def closure_map():
            loss_map = log_likelihood_vmap()
            optimizer_map.zero_grad()
            loss_map.backward()
            return loss

        optimizer_map.step(closure_map)
        losses_map.append(closure_map().item())
    else:
        loss_map = -(log_likelihood_vmap(Θ_map, X_train, y_train) + gaussian_prior(Θ_map))
        optimizer_map.zero_grad()
        loss_map.backward()
        print(loss_map.item())
        optimizer_map.step()
        losses_map.append(loss_map.item())

Θ_map
pred_map = torch.sigmoid(X_train.mm(Θ_map.T)).mean(axis=1)
((pred_map < 0.5).float() == y_train).float().mean(), Θ_map