In [6]:
from typing import NamedTuple, Callable
from functools import partial

import torch
import torch.nn as nn
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

In [7]:
class MLP1D(nn.Module):

    def __init__(self, in_shapes=[2,10,10,10], out_shapes = [10,10,10,1]):
        super().__init__()
        self.linears = [nn.Linear(in_shapes[i], out_shapes[i]) for i in range(len(in_shapes))]
        self.relu = nn.ReLU()
        
    def model(self):
        return nn.Sequential(self.linears[0], self.relu, 
                             self.linears[1], self.relu, 
                             self.linears[2], self.relu, self.linears[3])

    def forward(self, params, X):
        self.weights_and_biases(params)
        return self.model()(X)
    
    def weights_and_biases(self, params):
        for i,param in enumerate(params):
            self.linears[i].weight = nn.Parameter(param['weight'])
            self.linears[i].bias = nn.Parameter(param['bias'])
    
    def fetch_params(self):
        params = []
        for linear in self.linears:
            params.append({
                'weight': linear.weight.detach(),
                'bias': linear.bias.detach()
            })
        return params

def ravel_params(params):
    def ravel_per_item():
        all_params = []
        def ravel_per_item_per_key(param):
            for value in param.values(): 
                all_params.append(value.ravel())

        if type(params) == list:
            for param in params:
                ravel_per_item_per_key(param)
        else:
            ravel_per_item_per_key(params)
        return all_params
    return torch.cat(ravel_per_item())

def unravel_params(params_raveled: torch.Tensor, structured_tuple):
    
    def unravel_all_shapes(structure, is_list):
        shapes = []
        def extract_shapes(structure):
            if is_list:
                for per_structure in structure:
                    for val in per_structure.values():
                        shapes.append(val.numel())
            else:
                for val in structure.values():
                    shapes.append(val.numel())
            values_per_structure = params_raveled.split(shapes)
            return values_per_structure
        
        values_per_structure = extract_shapes(structure)
        result = []
        index = [0]

        def result_per_dict(values_per_structure, per_structure):
            per_result = {}
            for key, val in per_structure.items():
                # print(key, val.shape, values_per_structure[index[0]].shape)
                per_result[key] = values_per_structure[index[0]].reshape(val.shape)
                index[0] += 1
            return per_result
        
        if is_list:
            for per_structure in structure:
                per_result = result_per_dict(values_per_structure, per_structure)
                result.append(per_result)
            return result
        else:
            return result_per_dict(values_per_structure, structure)
        
    return unravel_all_shapes(structured_tuple, type(structured_tuple) == list)
    
def bnn_log_joint(params, X, y, model:MLP1D):
    logits = model.forward(params, X).ravel()
    flatten_params = ravel_params(params)
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(flatten_params).sum()
    log_likelihood = torch.distributions.Bernoulli(logits=logits).log_prob(y).sum()
    log_joint = log_prior + log_likelihood
    return log_joint

noise = 0.2
num_samples = 50
num_warmup = 1000
num_steps = 500
in_shapes = [2,10,10,10]
out_shapes = [10,10,10,1]

X, y = make_moons(n_samples=num_samples, noise=noise, random_state=314)
model = MLP1D(in_shapes=in_shapes, out_shapes=out_shapes)
params = model.fetch_params()
potential = partial(bnn_log_joint, X=torch.Tensor(X), y=torch.Tensor(y), model=model)


# Test

# potential(params)
# print(ravel_params(params[0]).shape)
# print(unravel_params(ravel_params(params[1]),params[1]))
# print(unravel_params(ravel_params(params),params))

In [8]:
class EnergyParameters:

    def __init__(self, potential_energy_fn:Callable, precision):
        self.potential_energy_fn = potential_energy_fn 
        self.precision = precision
    
    def set_position(self, position):
        self.position = position
        self.potential_energy = self.potential_energy_fn(position)
        self.potential_energy_grad = torch.func.grad(self.potential_energy_fn)(position)
    
    def set_velocity(self, velocity):
        self.velocity = velocity
        self.kinetic_energy = self.kinetic_energy_fn(self.velocity)
        self.kinetic_energy_grad = torch.func.grad(self.kinetic_energy_fn)(self.velocity)

    def init_per_step(self, position, velocity):
        self.set_position(position)
        self.set_velocity(velocity)
        self.total_init_energy = self.total_current_energy()
    
    def kinetic_energy_fn(self, velocity:torch.Tensor):
        return 0.5*torch.matmul(velocity, torch.matmul(self.precision, velocity.T))
    
    def update_position(self, step_size_with_direction):
        position_raveled = ravel_params(self.position)
        result = position_raveled + step_size_with_direction*self.kinetic_energy_grad
        n_position = unravel_params(result, self.position)
        self.set_position(n_position)
        return n_position

    def updated_velocity(self, step_size_with_direction, is_half_step_momentum = False):
        n_velocity = torch.subtract(self.velocity, step_size_with_direction*(0.5 if is_half_step_momentum else 1)*ravel_params(self.potential_energy_grad))
        self.set_velocity(n_velocity)
        return n_velocity
    
    def delta_energy(self):
        return self.total_current_energy() - self.total_init_energy 
    
    def total_current_energy(self):
        return - self.potential_energy - self.kinetic_energy

# Test

# e_params = EnergyParameters(potential_energy_fn=potential)
# e_params.set_energy_parameters(params)


# Naive No-U-Turn Sampler

In [9]:
class HMCAlgorithm:
    
    def __init__(self, log_density_fn: Callable, precision, l=1, step_size=0.1):
        self.log_density_fn = log_density_fn
        self.l = l
        self.step_size = step_size
        self.ep = EnergyParameters(potential_energy_fn=log_density_fn, precision=precision)
    

    def generate_default_velocity(self, position_shape):

        def generate_random_gaussian(covariance: torch.Tensor):
            return torch.distributions.MultivariateNormal(loc=torch.zeros(covariance.shape[-1]),covariance_matrix=covariance).sample()

        def generate_inverse_mass_matrix_and_sigma(covariance_shape: int):
            covariance = torch.eye(covariance_shape)
            inverse_mass_matrix = covariance.inverse()
            return inverse_mass_matrix, covariance
        
        _, covariance = generate_inverse_mass_matrix_and_sigma(position_shape)
        return generate_random_gaussian(covariance)

    def current_direction(self):
        direction = torch.distributions.Bernoulli(logits=torch.tensor([0.5])).sample()
        return -1 if direction.bool() else 1

    def stop_condition(self, p_minus, v_minus, p_plus, v_plus):
        diff = ravel_params(p_plus)-ravel_params(p_minus)
        condition = torch.all(diff*v_minus >= torch.zeros(diff.shape)) and torch.all(diff*v_plus >= torch.zeros(diff.shape))
        return condition.int()

    def build_tree(self, position, velocity, energy, direction, iter_no, step_size, m):
        delta_max = torch.tensor(1000)
        parameters = []
        if iter_no == 0:
            _, new_position, new_velocity, new_energy = self.step_integrator(position, velocity, direction*step_size)
            print(new_energy, energy, m)
            if energy <= torch.exp(new_energy):
                parameters.append((new_position, new_velocity))
            s = 1 if new_energy > torch.log(energy) - delta_max else 0
            return new_position, new_velocity, new_position, new_velocity, parameters, s
        else:
            p_minus, v_minus, p_plus, v_plus, n_params, n_s = self.build_tree(position, velocity, energy, direction, iter_no-1, step_size, m)
            if direction == -1:
                p_minus, v_minus, _, _, n_params, n_s = self.build_tree(p_minus, v_minus, energy, direction, iter_no-1, step_size, m)
            else:
                _, _, p_plus, v_plus, n_params, n_s = self.build_tree(p_plus, v_plus, energy, direction, iter_no-1, step_size, m)
            s = n_s * self.stop_condition(p_minus, v_minus, p_plus, v_plus)
            parameters += n_params
            return p_minus, v_minus, p_plus, v_plus, parameters, s

    def update(self, position, m):
        # self.step_size = self.find_reasonable_epsilon(position)
        velocity = self.generate_default_velocity(ravel_params(position).shape[-1])
        self.ep.init_per_step(position, velocity)
        energy = torch.distributions.Uniform(torch.tensor([0.0]),self.ep.total_current_energy()).sample()
        p_minus, p_plus, v_minus, v_plus = position, position, velocity, velocity
        counter, s = 0, 1
        parameters = []

        parameters.append((position, velocity))
        while s == 1:
            direction = self.current_direction()
            if direction == -1:
                p_minus, v_minus, _,_, n_parameters, new_s = self.build_tree(p_minus, v_minus, energy, direction, counter, self.step_size, m) 
            else:
                _, _, p_plus, v_plus, n_parameters, new_s = self.build_tree(p_plus, v_plus, energy, direction, counter, self.step_size, m) 
            if new_s == 1:
                parameters += n_parameters
            s = new_s * self.stop_condition(p_minus, v_minus, p_plus, v_plus)
            counter += 1
        selected_index = torch.distributions.Categorical(logits=torch.tensor([1.0])/len(parameters)).sample()
        return parameters[selected_index][0]

    def step_integrator(self, position, velocity, step_size_with_direction):
        self.ep.init_per_step(position, velocity)
        # init step
        self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        # l_frog_steps
        for i in range(self.l):
            self.ep.update_position(step_size_with_direction)
            self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=False)
        # final step
        f_step_position = self.ep.update_position(step_size_with_direction)
        half_step_velocity = self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        return self.ep.delta_energy(), f_step_position, half_step_velocity, self.ep.total_current_energy()

precision_shape = ravel_params(params).shape[-1]
# I^-1 = I so inverse doesn't matter but writing it for visibility
default_precision = torch.eye(precision_shape).inverse()
hmc_kernel = HMCAlgorithm(log_density_fn=potential, precision=default_precision)

# Test
# hmc_kernel.update(params)

In [10]:
def scan(func, init_values, length):
    carry = init_values
    result = []
    for i in range(length):
        carry = func(carry,i)
        result.append(carry)
    return carry, result 

final_state, result = scan(hmc_kernel.update, init_values=params, length=num_steps)


tensor(165.8301, grad_fn=<SubBackward0>) tensor([135.4637]) 0
tensor(152.3228, grad_fn=<SubBackward0>) tensor([146.5278]) 1
tensor(145.9471, grad_fn=<SubBackward0>) tensor([33.2595]) 2
tensor(143.4247, grad_fn=<SubBackward0>) tensor([86.9108]) 3
tensor(162.0123, grad_fn=<SubBackward0>) tensor([69.2887]) 4
tensor(167.0401, grad_fn=<SubBackward0>) tensor([60.5012]) 5
tensor(167.4197, grad_fn=<SubBackward0>) tensor([134.6359]) 6
tensor(139.0300, grad_fn=<SubBackward0>) tensor([60.0303]) 7
tensor(140.0326, grad_fn=<SubBackward0>) tensor([97.1153]) 8
tensor(140.8889, grad_fn=<SubBackward0>) tensor([130.6704]) 9
tensor(146.6672, grad_fn=<SubBackward0>) tensor([106.7928]) 10
tensor(167.2018, grad_fn=<SubBackward0>) tensor([97.0777]) 11
tensor(151.4746, grad_fn=<SubBackward0>) tensor([16.1752]) 12
tensor(131.6249, grad_fn=<SubBackward0>) tensor([45.6553]) 13
tensor(130.2507, grad_fn=<SubBackward0>) tensor([98.6568]) 14
tensor(155.9018, grad_fn=<SubBackward0>) tensor([46.0738]) 15
tensor(164.96

In [None]:
result

In [None]:
step = 0.2
vmin, vmax = X.min() - step, X.max() + step
X_grid = torch.mgrid[vmin:vmax:100j, vmin:vmax:100j]