In [126]:
from typing import NamedTuple, Callable
from functools import partial

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

In [127]:
class MLP1D(nn.Module):

    def __init__(self, in_shapes=[2,10,10,10], out_shapes = [10,10,10,1]):
        super().__init__()
        self.linears = [nn.Linear(in_shapes[i], out_shapes[i]) for i in range(len(in_shapes))]
        self.relu = nn.ReLU()
        
    def model(self):
        return nn.Sequential(self.linears[0], self.relu, 
                             self.linears[1], self.relu, 
                             self.linears[2], self.relu, self.linears[3])

    def forward(self, params, X):
        self.weights_and_biases(params)
        return self.model()(X)
    
    def weights_and_biases(self, params):
        for i,param in enumerate(params):
            self.linears[i].weight = nn.Parameter(param['weight'])
            self.linears[i].bias = nn.Parameter(param['bias'])
    
    def fetch_params(self):
        params = []
        for linear in self.linears:
            params.append({
                'weight': linear.weight.detach(),
                'bias': linear.bias.detach()
            })
        return params

def ravel_params(params):
    def ravel_per_item():
        all_params = []
        def ravel_per_item_per_key(param):
            for value in param.values(): 
                all_params.append(value.ravel())

        if type(params) == list:
            for param in params:
                ravel_per_item_per_key(param)
        else:
            ravel_per_item_per_key(params)
        return all_params
    return torch.cat(ravel_per_item())

def unravel_params(params_raveled: torch.Tensor, structured_tuple):
    
    def unravel_all_shapes(structure, is_list):
        shapes = []
        def extract_shapes(structure):
            if is_list:
                for per_structure in structure:
                    for val in per_structure.values():
                        shapes.append(val.numel())
            else:
                for val in structure.values():
                    shapes.append(val.numel())
            values_per_structure = params_raveled.split(shapes)
            return values_per_structure
        
        values_per_structure = extract_shapes(structure)
        result = []
        index = [0]

        def result_per_dict(values_per_structure, per_structure):
            per_result = {}
            for key, val in per_structure.items():
                # print(key, val.shape, values_per_structure[index[0]].shape)
                per_result[key] = values_per_structure[index[0]].reshape(val.shape)
                index[0] += 1
            return per_result
        
        if is_list:
            for per_structure in structure:
                per_result = result_per_dict(values_per_structure, per_structure)
                result.append(per_result)
            return result
        else:
            return result_per_dict(values_per_structure, structure)
        
    return unravel_all_shapes(structured_tuple, type(structured_tuple) == list)
    
def bnn_log_joint(params, X, y, model:MLP1D):
    logits = model.forward(params, X).ravel()
    flatten_params = ravel_params(params)
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(flatten_params).sum()
    log_likelihood = torch.distributions.Bernoulli(logits=logits).log_prob(y).sum()
    log_joint = log_prior + log_likelihood
    return log_joint

noise = 0.2
num_samples = 50
num_warmup = 1000
num_steps = 500
in_shapes = [2,10,10,10]
out_shapes = [10,10,10,1]

X, y = make_moons(n_samples=num_samples, noise=noise, random_state=314)
model = MLP1D(in_shapes=in_shapes, out_shapes=out_shapes)
params = model.fetch_params()
potential = partial(bnn_log_joint, X=torch.Tensor(X), y=torch.Tensor(y), model=model)


# Test

# potential(params)
# params1 = model.fetch_params()
# potential(params1)
# print(ravel_params(params[0]).shape)
# print(unravel_params(ravel_params(params[1]),params[1]))
# print(unravel_params(ravel_params(params),params))

In [128]:
class EnergyParameters:

    def __init__(self, potential_energy_fn:Callable, precision):
        self.potential_energy_fn = potential_energy_fn 
        self.precision = precision
    
    def set_position(self, position):
        self.position = position
        self.potential_energy = self.potential_energy_fn(position)
        self.potential_energy_grad = torch.func.grad(self.potential_energy_fn)(position)
    
    def set_velocity(self, velocity):
        self.velocity = velocity
        self.kinetic_energy = self.kinetic_energy_fn(self.velocity)
        self.kinetic_energy_grad = torch.func.grad(self.kinetic_energy_fn)(self.velocity)

    def init_per_step(self, position, velocity):
        self.set_position(position)
        self.set_velocity(velocity)
        self.total_init_energy = self.total_current_energy()
    
    def kinetic_energy_fn(self, velocity:torch.Tensor):
        return 0.5*torch.matmul(velocity, torch.matmul(self.precision, velocity.T))
    
    def update_position(self, step_size_with_direction):
        position_raveled = ravel_params(self.position)
        result = position_raveled + step_size_with_direction*self.kinetic_energy_grad
        n_position = unravel_params(result, self.position)
        self.set_position(n_position)
        return n_position

    def updated_velocity(self, step_size_with_direction, is_half_step_momentum = False):
        n_velocity = torch.subtract(self.velocity, step_size_with_direction*(0.5 if is_half_step_momentum else 1)*ravel_params(self.potential_energy_grad))
        self.set_velocity(n_velocity)
        return n_velocity
    
    def delta_energy(self):
        return -self.total_current_energy() + self.total_init_energy 
    
    def total_current_energy(self):
        return -self.potential_energy - self.kinetic_energy

# Test

# e_params = EnergyParameters(potential_energy_fn=potential)
# e_params.set_energy_parameters(params)


# Vanilla HMC with Custom Step

In [129]:
class HMCAlgorithm:
    
    def __init__(self, log_density_fn: Callable, precision, l=1):
        self.log_density_fn = log_density_fn
        self.l = l
        self.ep = EnergyParameters(potential_energy_fn=log_density_fn, precision=precision)
    
    def generate_default_velocity(self, position_shape):

        def generate_random_gaussian(covariance: torch.Tensor):
            return torch.distributions.MultivariateNormal(loc=torch.zeros(covariance.shape[-1]),covariance_matrix=covariance).sample()

        def generate_inverse_mass_matrix_and_sigma(covariance_shape: int):
            covariance = torch.eye(covariance_shape)
            inverse_mass_matrix = covariance.inverse()
            return inverse_mass_matrix, covariance
        
        _, covariance = generate_inverse_mass_matrix_and_sigma(position_shape)
        return generate_random_gaussian(covariance)

    def find_reasonable_epsilon(self, position):
        step_size, velocity = 0.01, self.generate_default_velocity(ravel_params(position).shape[-1])
        delta_energy, proposal_position, proposal_velocity, _ = self.step_integrator(position, velocity, step_size)
        print(delta_energy)
        alpha = 2*(1 if torch.exp(delta_energy) > 0.5 else 0) - 1
        while torch.pow(torch.exp(delta_energy), alpha) > torch.pow(torch.tensor(0.5), alpha):
            step_size = torch.pow(torch.tensor(2), alpha) * step_size
            print(f'Step size : {step_size, alpha, delta_energy}')
            delta_energy, proposal_position, proposal_velocity, _ = self.step_integrator(proposal_position, proposal_velocity, step_size)
        return step_size

    def update(self, position, sample_no):
        if sample_no == 0:
            self.step_size = self.find_reasonable_epsilon(position)
            
        velocity = self.generate_default_velocity(ravel_params(position).shape[-1])
        delta_energy, proposal_position, _, _ = self.step_integrator(position, velocity, self.step_size)
        # nan fix
        delta_energy = torch.where(torch.isnan(delta_energy), -torch.inf, delta_energy)
        # MH Algo
        alpha = torch.min(torch.exp(delta_energy), torch.ones(1))
        print(f'Delta energy : {torch.exp(delta_energy)}, {alpha}')
        accept_condition = torch.distributions.Bernoulli(alpha).sample()
        # print(f'Delta energy : {delta_energy}, {alpha}')
        return proposal_position if accept_condition.bool() else position
    
    def step_integrator(self, position, velocity, step_size_with_direction):
        self.ep.init_per_step(position, velocity)
        # init step
        self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        # l_frog_steps
        for i in range(self.l):
            self.ep.update_position(step_size_with_direction)
            self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=False)
        # final step
        f_step_position = self.ep.update_position(step_size_with_direction)
        half_step_velocity = self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        
        return self.ep.delta_energy(), f_step_position, half_step_velocity, self.ep.total_current_energy()

precision_shape = ravel_params(params).shape[-1]
# I^-1 = I so inverse doesn't matter but writing it for visibility
default_precision = torch.eye(precision_shape).inverse()
hmc_kernel = HMCAlgorithm(log_density_fn=potential, precision=default_precision)

# Test
hmc_kernel.update(params, 0)

tensor(0.0841, grad_fn=<AddBackward0>)
Step size : (tensor(0.0200), 1, tensor(0.0841, grad_fn=<AddBackward0>))
Step size : (tensor(0.0400), 1, tensor(0.2379, grad_fn=<AddBackward0>))
Step size : (tensor(0.0800), 1, tensor(0.5374, grad_fn=<AddBackward0>))
Delta energy : 1.4600844383239746, tensor([1.], grad_fn=<MinimumBackward0>)


[{'weight': tensor([[ 0.1768,  0.3229],
          [ 0.2188, -0.0412],
          [-0.8197,  0.6504],
          [-0.6700, -0.0302],
          [ 0.1581, -0.4905],
          [ 0.2462, -0.3139],
          [-0.0928,  0.1792],
          [-0.6698, -0.6330],
          [-0.1577, -0.0029],
          [-0.1000,  0.3291]]),
  'bias': tensor([-0.6610, -0.0309, -0.2686,  0.0346, -0.0242,  0.0723, -0.0021, -0.1374,
          -0.7409, -0.5061])},
 {'weight': tensor([[ 0.4226,  0.3729, -0.1392,  0.0475, -0.3746,  0.0776,  0.0948,  0.4588,
           -0.0121, -0.3430],
          [-0.5789, -0.0504,  0.3856, -0.2566, -0.0842, -0.0336,  0.0200,  0.0273,
           -0.4046,  0.1061],
          [ 0.1664, -0.1215,  0.5548, -0.4408, -0.0179, -0.1699,  0.2189,  0.2268,
           -0.1085,  0.0833],
          [-0.1088, -0.1356, -0.0064, -0.0774,  0.5039,  0.0245, -0.0948, -0.0461,
            0.2187, -0.1022],
          [ 0.1483,  0.1493, -0.1479,  0.4428, -0.3385, -0.0875,  0.1107, -0.3266,
            0.3848,  0

In [38]:
def scan(func, init_values, length):
    carry = init_values
    result = []
    for i in range(length):
        carry = func(carry, i)
        result.append(carry)
    return carry, result 

final_state, result = scan(hmc_kernel.update, init_values=params, length=num_steps)


Step size : (tensor(0.0200), 1, tensor(-0.0089, grad_fn=<SubBackward0>))
Step size : (tensor(0.0400), 1, tensor(0.0003, grad_fn=<SubBackward0>))
Step size : (tensor(0.0800), 1, tensor(-0.0329, grad_fn=<SubBackward0>))
Delta energy : 0.8822007179260254, tensor([0.8822], grad_fn=<MinimumBackward0>)
Delta energy : 7.980336666107178, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 24.392852783203125, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 9.355762481689453, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 7879471923200.0, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 105835.484375, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 10486066.0, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 391434764288.0, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 1.2827541468141135e-05, tensor([1.2828e-05], grad_fn=<MinimumBackward0>)
Delta energy : 23238808.0, tensor([1.], grad_fn=<MinimumBackward0>)
Delta energy : 6608.90380859375, ten

In [None]:
result

In [None]:
step = 0.2
vmin, vmax = X.min() - step, X.max() + step
X_grid = torch.mgrid[vmin:vmax:100j, vmin:vmax:100j]