In [1]:
import torch
import torch.nn.functional as F

import torchsde
import math
import matplotlib.pyplot as plt

import numpy as np

from tqdm.notebook import tqdm
# from torch import datasets

from torch import _vmap_internals
from torchvision import datasets, transforms
from functorch import vmap
from torchvision.transforms import ToTensor

# import torch.nn.functional as F

import pandas as pd

In [2]:
from cfollmer.objectives import log_g, relative_entropy_control_cost, stl_relative_entropy_control_cost_xu
from cfollmer.sampler_utils import FollmerSDE
from cfollmer.drifts import *
from cfollmer.trainers import basic_batched_trainer

functorch succesfully imported


# The Model

\begin{align}
\theta &\sim \mathcal{N}(\theta | 0, \sigma_w^2 \mathbb{I}) \\
y_i | x_i, \theta &\sim  \mathrm{Cat}\left[\mathrm{NN}_{\theta}\left(x_i \right)\right]
\end{align}

We want samples from $p(\theta | \{(y_i, x_i)\})$. Note $f(x; \theta)$ is a neural net with params $\theta$

## Loading the iris dataset

In [3]:
images_train = datasets.MNIST("../data/mnist/", download=True, transform=ToTensor(), train=True)
images_test = datasets.MNIST("../data/mnist/", download=True, transform=ToTensor(), train=False)

transform = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081)))

In [4]:
images_train.data.shape

torch.Size([60000, 28, 28])

In [5]:
X_train, y_train = images_train.data, images_train.targets
X_test, y_test = images_test.data, images_test.targets

X_train = transform(X_train.float())
X_test = transform(X_test.float())

y_train = F.one_hot(y_train)
y_test = F.one_hot(y_test)

# X_train = np.concatenate((X_train, np.ones((X_train.shape[0],X_train.shape[1]))), axis=1)
# X_test = np.concatenate((X_test, np.ones((X_test.shape[0],X_train.shape[1]))), axis=1)

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

X_train, X_test, y_train, y_test = \
    torch.tensor(X_train, dtype=torch.float32, device=device), \
    torch.tensor(X_test, dtype=torch.float32, device=device), \
    torch.tensor(y_train, dtype=torch.float32, device=device), \
    torch.tensor(y_test, dtype=torch.float32, device=device) 

  torch.tensor(X_train, dtype=torch.float32, device=device), \
  torch.tensor(X_test, dtype=torch.float32, device=device), \
  torch.tensor(y_train, dtype=torch.float32, device=device), \
  torch.tensor(y_test, dtype=torch.float32, device=device)


In [7]:
X_train.shape

torch.Size([60000, 28, 28])

$$\DeclareMathOperator*{\argmin}{arg\,min}$$
$$\def\E{{\mathbb{E}}}$$
$$\def\rvu{{\mathbf{u}}}$$
$$\def\rvTheta{{\bm{\Theta}}}$$
$$\def\gU{{\mathcal{U}}}$$
$$\def\mX{{\mathbf{X}}}$$

## Controlled Schrodinger Follmer Sampler

The objevtive we are trying to implement is:

\begin{align}
  \mathbf{u}_t^{*}=  \argmin_{\rvu_t \in \mathcal{U}}\mathbb{E}\left[\frac{1}{2\gamma}\int_0^1||\rvu(t, \Theta_t)||^2 dt - \ln\left(\frac{ p(\mX | \Theta_1)p(\Theta_1)}{\mathcal{N}(\Theta_1|\mathbf{0}, \gamma \mathbb{I} )}\right)\right] \
\end{align}

Where:
\begin{align}
d\Theta_t = \rvu(t, \Theta_t)dt + \sqrt{\gamma} dB_t
\end{align}

To do so we use the EM discretisation.

In [9]:
class LeNet5(torch.nn.Module):

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        
        self.feature_extractor = torch.nn.Sequential(            
            torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
            torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
        )

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=256, out_features=120),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=120, out_features=84),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=84, out_features=n_classes),
        )


    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        return logits

from functorch import make_functional






class LeNet5Fun(object):
    
    def __init__(
        self, input_dim=1, output_dim=1, depth=None,
        width=20, width_seq=None, device="cpu", activation=F.relu
    ):
        
        self.device = device
        self.output_dim = output_dim
        self.input_dim = input_dim 
        self.activation = activation
        
        self.model = LeNet5(n_classes=10)
        self.func_model, self.params = make_functional(self.model)
        
        
        self.dim = sum([math.prod(x.shape) for x in self.params])
        
        self.size_tuples = [p.shape for p in self.params]

    def get_params_from_array(self, array):
        cur_index = 0
        param_list = []
        for s in self.size_tuples:
            step_number = math.prod(s)
            param_list.append(array[cur_index:cur_index+step_number].reshape(s))
            cur_index += step_number
        return param_list
    
    def forward(self, x, Θ):
        Θ = self.get_params_from_array(Θ)
        return self.func_model(Θ, x)



In [10]:
dim = X_train.shape[1]
out_dim = y_train.shape[1]

# net = ClassificationNetwork(
#     dim, out_dim, device=device, depth=1, width=50, activation=F.tanh
# )
net = LeNet5Fun()


def gaussian_prior(Θ, σ_w=3.8):
    """
    Logistic regresion bayesian prior
    """
    return -0.5 * (Θ**2).sum(axis=1) / σ_w


def log_likelihood_vmap_nn(Θ, X, y, net=net):
    """
    Hoping this implementation is less buggy / faster
    
    still feels a bit slow.
    """
    
    def loss(θ):
        preds = net.forward(X, θ)
        cel = torch.nn.CrossEntropyLoss(reduction="sum")
#         import pdb; pdb.set_trace()
        ll_cel = -1.0 * cel(preds, y.argmax(dim=1))
        return ll_cel
    
    batched_loss =  vmap(loss)

    return batched_loss(Θ)

In [11]:
net.dim

44426

In [None]:
class SimpleForwardNetBN_larger(AbstractDrift):

    def __init__(self, input_dim=1, width=300, activation=torch.nn.Softplus):
        super(SimpleForwardNetBN_larger, self).__init__()
        
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(input_dim + 1, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, width), torch.nn.BatchNorm1d(width, affine=False), activation(),
            torch.nn.Linear(width, input_dim )
        )
        
        self.nn[-1].weight.data.fill_(0.0)


γ =  0.1**2
Δt=0.01

dim= net.dim

prior = gaussian_prior

sde, losses = basic_batched_trainer(
    γ, Δt, prior, log_likelihood_vmap_nn, dim, X_train.reshape(-1,1,28,28), y_train,
    method="euler", stl="stl_xu", adjoint=False, optimizer=None,
    num_steps=79, batch_size_data=int(X_train.shape[0] // 20), batch_size_Θ=30,
    batchnorm=True, device=device, lr=0.0001, drift=SimpleForwardNetBN_larger, schedule="uniform",
    γ_min= 0.1**2, γ_max= 0.4**2
)

  0%|          | 0/79 [00:00<?, ?it/s]

  return torch.batch_norm(
  return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)


1.6934889554977417
1.682041049003601
1.7083704471588135
1.6547383069992065
1.4843652248382568
1.2474050521850586
1.0920374393463135
0.9228026270866394
0.8101496696472168
0.575107216835022
0.4128902554512024
0.37032076716423035
0.21082764863967896
0.11910095810890198
0.09894826263189316
0.05193038284778595
-0.01872221939265728
-0.03882938623428345
-0.06529614329338074
-0.10613273084163666
-0.0931284949183464
-0.2079002857208252
-0.2010226994752884
-0.23627139627933502
-0.25669246912002563
-0.25908640027046204
-0.2945830523967743
-0.27499616146087646
-0.3318319618701935
-0.2823198437690735
-0.35332196950912476
-0.3623194992542267
-0.31903567910194397
-0.3487713932991028
-0.36587467789649963
-0.37138012051582336
-0.40518203377723694
-0.40245339274406433
-0.3899686336517334
-0.3916984498500824
-0.38440078496932983
-0.41223815083503723
-0.3923581838607788
-0.37056660652160645
-0.41244783997535706
-0.4361793100833893
-0.4233507812023163
-0.41886937618255615
-0.4071148633956909
-0.41920357942

-0.5155816674232483
-0.5135477781295776
-0.49482157826423645
-0.5380510091781616
-0.5188863277435303
-0.5013899803161621
-0.5141793489456177
-0.5280399322509766
-0.5257059931755066
-0.5129442811012268
-0.5114151835441589
-0.5313513278961182
-0.5478516221046448
-0.5320116877555847
-0.5235782861709595
-0.5168309211730957
-0.5253767371177673
-0.5193021893501282
-0.5083229541778564
-0.5198453068733215
-0.5490954518318176
-0.5139727592468262
-0.5380529761314392
-0.5293859243392944
-0.49353182315826416
-0.5372241735458374
-0.5226840972900391
-0.515090823173523
-0.5264928936958313
-0.5203284621238708
-0.5428516268730164
-0.5296195149421692
-0.5482606291770935
-0.5098691582679749
-0.5239064693450928
-0.5343331098556519
-0.5079092383384705
-0.533955454826355
-0.5291103720664978
-0.5417516827583313
-0.5236768126487732
-0.5199479460716248
-0.5445462465286255
-0.5259272456169128
-0.5272689461708069
-0.5370115041732788
-0.5350571870803833
-0.532121479511261
-0.524308443069458
-0.5151787996292114
-0

-0.5101883411407471


In [None]:
losses

In [None]:
plt.plot(losses[:])

In [None]:
X_train.shape

In [None]:
t_size = int(math.ceil(1.0/Δt))
ts = torch.linspace(0, 1, t_size).to(device)
no_posterior_samples = 50
Θ_0 = torch.zeros((no_posterior_samples, net.dim)).to(device)

Θ_1 = torchsde.sdeint(sde, Θ_0, ts, dt=Δt)[-1,...]

In [None]:
fig, (ax1,ax2,ax3) = plt.subplots(1,3)

ax1.hist(Θ_1[:,0].cpu().detach().numpy())
ax2.hist(Θ_1[:,1].cpu().detach().numpy())
ax3.hist(Θ_1[:,2].cpu().detach().numpy())

In [None]:
def predc(X, Θ):
    return torch.vstack([(net.forward(X, θ)[None,...]).softmax(dim=-1) for θ in Θ]).mean(dim=0)

In [None]:
import gc

from tqdm.notebook import tqdm

pred = []

gc.collect()
torch.cuda.empty_cache()

subsamp = 30

stride = 10

for i in tqdm(range(0,len(X_train), stride)):
    
    pred.append(predc(X_train[i:i+stride,...].reshape(-1,1,28,28), Θ_1[:subsamp,:]).cpu())
    gc.collect()
    torch.cuda.empty_cache()


In [None]:
pred = torch.vstack(pred)

In [None]:
# pred = torch.vstack(pred2)

In [None]:
subsamp = 30

In [None]:

((pred.argmax(dim=-1)).float().flatten().cpu() == y_train[:len(pred)].argmax(dim=-1).cpu() ).float().mean()

In [None]:
pred.shape

In [None]:
gc.collect()
torch.cuda.empty_cache()


pred_test = []
for i in tqdm(range(0,len(X_test), stride)):
    
    pred_test.append(predc(X_test[i:i+stride,...].float().reshape(-1,1,28,28), Θ_1[:subsamp,:]).cpu())

    gc.collect()
    torch.cuda.empty_cache()

pred_test = pred_test(torch.vstack(pred_test))

In [None]:
((pred_test.argmax(dim=-1)).float().flatten().cpu() == y_test.argmax(dim=-1).cpu()).float().mean()

## MAP Baseline

We run the point estimate approximation (Maximum a posteriori) to double check what the learned weights look like.  We get the  exact same training accuracy as with the controlled model and similarly large weights for the non bias weights. 

In [None]:
Θ_map = torch.zeros((1, dim), requires_grad=True, device=device)
optimizer_map = torch.optim.Adam([Θ_map], lr=0.05)
#     optimizer = torch.optim.LBFGS(gpr.parameters(), lr=0.01)

losses_map = []
num_steps = 1000
for i in tqdm(range(num_steps)):
    optimizer_map.zero_grad()

    if isinstance(optimizer_map, torch.optim.LBFGS):
        def closure_map():
            loss_map = log_likelihood_vmap()
            optimizer_map.zero_grad()
            loss_map.backward()
            return loss

        optimizer_map.step(closure_map)
        losses_map.append(closure_map().item())
    else:
        loss_map = -(log_likelihood_vmap(Θ_map, X_train, y_train) + gaussian_prior(Θ_map))
        optimizer_map.zero_grad()
        loss_map.backward()
        print(loss_map.item())
        optimizer_map.step()
        losses_map.append(loss_map.item())

Θ_map
pred_map = torch.sigmoid(X_train.mm(Θ_map.T)).mean(axis=1)
((pred_map < 0.5).float() == y_train).float().mean(), Θ_map