In [1]:
from typing import NamedTuple, Callable
from functools import partial

import torch
import torch.nn as nn
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

In [3]:
class MLP1D(nn.Module):

    def __init__(self, in_shapes=[2,10,10,10], out_shapes = [10,10,10,1]):
        super().__init__()
        self.linears = [nn.Linear(in_shapes[i], out_shapes[i]) for i in range(len(in_shapes))]
        self.relu = nn.ReLU()
        
    def model(self):
        return nn.Sequential(self.linears[0], self.relu, 
                             self.linears[1], self.relu, 
                             self.linears[2], self.relu, self.linears[3])

    def forward(self, params, X):
        self.weights_and_biases(params)
        return self.model()(X)
    
    def weights_and_biases(self, params):
        for i,param in enumerate(params):
            self.linears[i].weight = nn.Parameter(param['weight'])
            self.linears[i].bias = nn.Parameter(param['bias'])
    
    def fetch_params(self):
        params = []
        for linear in self.linears:
            params.append({
                'weight': linear.weight.detach(),
                'bias': linear.bias.detach()
            })
        return params

def ravel_params(params):
    def ravel_per_item():
        all_params = []
        def ravel_per_item_per_key(param):
            for value in param.values(): 
                all_params.append(value.ravel())

        if type(params) == list:
            for param in params:
                ravel_per_item_per_key(param)
        else:
            ravel_per_item_per_key(params)
        return all_params
    return torch.cat(ravel_per_item())

def unravel_params(params_raveled: torch.Tensor, structured_tuple):
    
    def unravel_all_shapes(structure, is_list):
        shapes = []
        def extract_shapes(structure):
            if is_list:
                for per_structure in structure:
                    for val in per_structure.values():
                        shapes.append(val.numel())
            else:
                for val in structure.values():
                    shapes.append(val.numel())
            values_per_structure = params_raveled.split(shapes)
            return values_per_structure
        
        values_per_structure = extract_shapes(structure)
        result = []
        index = [0]

        def result_per_dict(values_per_structure, per_structure):
            per_result = {}
            for key, val in per_structure.items():
                # print(key, val.shape, values_per_structure[index[0]].shape)
                per_result[key] = values_per_structure[index[0]].reshape(val.shape)
                index[0] += 1
            return per_result
        
        if is_list:
            for per_structure in structure:
                per_result = result_per_dict(values_per_structure, per_structure)
                result.append(per_result)
            return result
        else:
            return result_per_dict(values_per_structure, structure)
        
    return unravel_all_shapes(structured_tuple, type(structured_tuple) == list)
    
def bnn_log_joint(params, X, y, model:MLP1D):
    logits = model.forward(params, X).ravel()
    flatten_params = ravel_params(params)
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(flatten_params).sum()
    log_likelihood = torch.distributions.Bernoulli(logits=logits).log_prob(y).sum()
    log_joint = log_prior + log_likelihood
    return log_joint

noise = 0.2
num_samples = 50
num_warmup = 1000
num_steps = 500
in_shapes = [2,10,10,10]
out_shapes = [10,10,10,1]

X, y = make_moons(n_samples=num_samples, noise=noise, random_state=314)
model = MLP1D(in_shapes=in_shapes, out_shapes=out_shapes)
params = model.fetch_params()
potential = partial(bnn_log_joint, X=torch.Tensor(X), y=torch.Tensor(y), model=model)


# Test

potential(params)
# print(ravel_params(params[0]).shape)
# print(unravel_params(ravel_params(params[1]),params[1]))
# print(unravel_params(ravel_params(params),params))

tensor(-281.3270, grad_fn=<AddBackward0>)

In [10]:
class EnergyParameters:

    def __init__(self, potential_energy_fn:Callable, precision):
        self.potential_energy_fn = potential_energy_fn 
        self.precision = precision
    
    def set_position(self, position):
        self.position = position
        self.potential_energy = self.potential_energy_fn(position)
        self.potential_energy_grad = torch.func.grad(self.potential_energy_fn)(position)
    
    def set_velocity(self, velocity):
        self.velocity = velocity
        self.kinetic_energy = self.kinetic_energy_fn(self.velocity)
        self.kinetic_energy_grad = torch.func.grad(self.kinetic_energy_fn)(self.velocity)

    def init_per_step(self, position, velocity):
        self.set_position(position)
        self.set_velocity(velocity)
        self.total_init_energy = self.total_current_energy()
    
    def kinetic_energy_fn(self, velocity:torch.Tensor):
        return 0.5*torch.matmul(velocity, torch.matmul(self.precision, velocity.T))
    
    def update_position(self, step_size_with_direction):
        position_raveled = ravel_params(self.position)
        result = position_raveled + step_size_with_direction*self.kinetic_energy_grad
        n_position = unravel_params(result, self.position)
        self.set_position(n_position)
        return n_position

    def updated_velocity(self, step_size_with_direction, is_half_step_momentum = False):
        n_velocity = torch.subtract(self.velocity, step_size_with_direction*(0.5 if is_half_step_momentum else 1)*ravel_params(self.potential_energy_grad))
        self.set_velocity(n_velocity)
        return n_velocity
    
    def delta_energy(self):
        return - self.total_current_energy() + self.total_init_energy 
    
    def total_current_energy(self):
        return - self.potential_energy - self.kinetic_energy

# Test

# e_params = EnergyParameters(potential_energy_fn=potential)
# e_params.set_energy_parameters(params)


# HMC with Dual Averaging

In [12]:
class HMCAlgorithm:
    
    def __init__(self, log_density_fn: Callable, precision, expected_prob_density=0.65, Lambda = 0.12, m_warmup=1000, num_steps=500):
        self.log_density_fn = log_density_fn
        # Lambda
        self.Lambda = Lambda
        # delta
        self.expected_prob_density = expected_prob_density
        # M_adapt
        self.m_warmup = m_warmup
        # M
        self.M = m_warmup + num_steps
        self.ep = EnergyParameters(potential_energy_fn=log_density_fn, precision=precision)

    
    def generate_default_velocity(self, position_shape):

        def generate_random_gaussian(covariance: torch.Tensor):
            return torch.distributions.MultivariateNormal(loc=torch.zeros(covariance.shape[-1]),covariance_matrix=covariance).sample()

        def generate_inverse_mass_matrix_and_sigma(covariance_shape: int):
            covariance = torch.eye(covariance_shape)
            inverse_mass_matrix = covariance.inverse()
            return inverse_mass_matrix, covariance
        
        _, covariance = generate_inverse_mass_matrix_and_sigma(position_shape)
        return generate_random_gaussian(covariance)

    def find_reasonable_epsilon(self, position):
        step_size, velocity = 0.01, self.generate_default_velocity(ravel_params(position).shape[-1])
        delta_energy, proposal_position, proposal_velocity, _ = self.step_integrator(position, velocity, step_size)
        alpha = 2*(1 if torch.exp(delta_energy) > 0.5 else 0) - 1
        while torch.pow(torch.exp(delta_energy), alpha) > torch.pow(torch.tensor(0.5), alpha):
            step_size = torch.pow(torch.tensor(2), alpha) * step_size
            delta_energy, proposal_position, proposal_velocity, _ = self.step_integrator(proposal_position, proposal_velocity, step_size)
        print(f'Step size (find_reasonable_epsilon) : {step_size}')
        return step_size

    def update(self, position):
        self.init_step(position)
        log_step_size_average = torch.log(self.step_size_average)
        # starting from 1
        for m in range(1,self.M+1):
            velocity = self.generate_default_velocity(ravel_params(position).shape[-1])
            proposal_position, n_position, n_velocity = position, position, velocity
            current_l_frog_steps = torch.max(torch.tensor(1),torch.round(self.Lambda/self.step_size)).int()
            delta_energy, proposal_position, proposal_velocity, _ = self.step_integrator(position, velocity, self.step_size, current_l_frog_steps)
            # nan fix
            delta_energy = torch.where(torch.isnan(delta_energy), -torch.inf, delta_energy)
            # MH Algo
            alpha = torch.min(torch.exp(delta_energy), torch.ones(1))
            accept_condition = torch.distributions.Bernoulli(alpha).sample()
            proposal_position, proposal_velocity = (proposal_position, proposal_velocity) if accept_condition.bool() else (position, velocity)
            # tuning step size
            if m <= self.m_warmup:
                self.H_m = (1-1/(m+self.t_0))*self.H_m + (1/(m+self.t_0))*(self.expected_prob_density - alpha)
                log_step_size = self.mu - (torch.sqrt(torch.tensor(m))/self.gamma)*self.H_m
                step = self.step_fn(m)
                log_step_size_average = step * log_step_size + (1-step)*log_step_size_average
            self.step_size = torch.exp(log_step_size_average)
            self.step_size_average = torch.exp(log_step_size_average)
            print('Step size : ', self.step_size, delta_energy, current_l_frog_steps)

    def step_fn(self, current_step):
        return torch.pow(torch.tensor(1/current_step), self.k)

    def init_step(self, position):
        self.step_size = self.find_reasonable_epsilon(position)
        self.mu = torch.log(10*self.step_size)  
        self.step_size_average = torch.tensor(1)
        self.H_m = 0
        self.gamma = 0.05
        self.t_0 = 10
        self.k = 0.75
    
    def step_integrator(self, position, velocity, step_size_with_direction, l_frog_step=1):
        self.ep.init_per_step(position, velocity)
        # init step
        self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        # l_frog_steps
        for i in range(l_frog_step):
            self.ep.update_position(step_size_with_direction)
            self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=False)
        # final step
        f_step_position = self.ep.update_position(step_size_with_direction)
        half_step_velocity = self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        return self.ep.delta_energy(), f_step_position, half_step_velocity, self.ep.total_current_energy()

precision_shape = ravel_params(params).shape[-1]
# I^-1 = I so inverse doesn't matter but writing it for visibility
default_precision = torch.eye(precision_shape).inverse()
hmc_kernel = HMCAlgorithm(log_density_fn=potential, precision=default_precision)

# Test
hmc_kernel.update(params)

Step size (find_reasonable_epsilon) : 0.07999999821186066
Step size :  tensor([0.4697], grad_fn=<ExpBackward0>) tensor(-1.0297, grad_fn=<WhereBackward0>) tensor(2, dtype=torch.int32)
Step size :  tensor([0.1725], grad_fn=<ExpBackward0>) tensor(-6.1974, grad_fn=<WhereBackward0>) tensor([1], dtype=torch.int32)
Step size :  tensor([0.0526], grad_fn=<ExpBackward0>) tensor(-28.1749, grad_fn=<WhereBackward0>) tensor([1], dtype=torch.int32)
Step size :  tensor([0.0199], grad_fn=<ExpBackward0>) tensor(-1.1208, grad_fn=<WhereBackward0>) tensor([2], dtype=torch.int32)
Step size :  tensor([0.0141], grad_fn=<ExpBackward0>) tensor(-0.0612, grad_fn=<WhereBackward0>) tensor([6], dtype=torch.int32)
Step size :  tensor([0.0146], grad_fn=<ExpBackward0>) tensor(0.7122, grad_fn=<WhereBackward0>) tensor([8], dtype=torch.int32)
Step size :  tensor([0.0119], grad_fn=<ExpBackward0>) tensor(-1.0486, grad_fn=<WhereBackward0>) tensor([8], dtype=torch.int32)
Step size :  tensor([0.0075], grad_fn=<ExpBackward0>) t