In [1]:
from typing import NamedTuple, Callable
from functools import partial

import torch
import torch.nn as nn
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

In [2]:
class MLP1D(nn.Module):

    def __init__(self, in_shapes=[2,10,10,10], out_shapes = [10,10,10,1]):
        super().__init__()
        self.linears = [nn.Linear(in_shapes[i], out_shapes[i]) for i in range(len(in_shapes))]
        self.relu = nn.ReLU()
        
    def model(self):
        return nn.Sequential(self.linears[0], self.relu, 
                             self.linears[1], self.relu, 
                             self.linears[2], self.relu, self.linears[3])

    def forward(self, params, X):
        self.weights_and_biases(params)
        return self.model()(X)
    
    def weights_and_biases(self, params):
        for i,param in enumerate(params):
            self.linears[i].weight = nn.Parameter(param['weight'])
            self.linears[i].bias = nn.Parameter(param['bias'])
    
    def fetch_params(self):
        params = []
        for linear in self.linears:
            params.append({
                'weight': linear.weight.detach(),
                'bias': linear.bias.detach()
            })
        return params

def ravel_params(params):
    def ravel_per_item():
        all_params = []
        def ravel_per_item_per_key(param):
            for value in param.values(): 
                all_params.append(value.ravel())

        if type(params) == list:
            for param in params:
                ravel_per_item_per_key(param)
        else:
            ravel_per_item_per_key(params)
        return all_params
    return torch.cat(ravel_per_item())

def unravel_params(params_raveled: torch.Tensor, structured_tuple):
    
    def unravel_all_shapes(structure, is_list):
        shapes = []
        def extract_shapes(structure):
            if is_list:
                for per_structure in structure:
                    for val in per_structure.values():
                        shapes.append(val.numel())
            else:
                for val in structure.values():
                    shapes.append(val.numel())
            values_per_structure = params_raveled.split(shapes)
            return values_per_structure
        
        values_per_structure = extract_shapes(structure)
        result = []
        index = [0]

        def result_per_dict(values_per_structure, per_structure):
            per_result = {}
            for key, val in per_structure.items():
                # print(key, val.shape, values_per_structure[index[0]].shape)
                per_result[key] = values_per_structure[index[0]].reshape(val.shape)
                index[0] += 1
            return per_result
        
        if is_list:
            for per_structure in structure:
                per_result = result_per_dict(values_per_structure, per_structure)
                result.append(per_result)
            return result
        else:
            return result_per_dict(values_per_structure, structure)
        
    return unravel_all_shapes(structured_tuple, type(structured_tuple) == list)
    
def bnn_log_joint(params, X, y, model:MLP1D):
    logits = model.forward(params, X).ravel()
    flatten_params = ravel_params(params)
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(flatten_params).sum()
    log_likelihood = torch.distributions.Bernoulli(logits=logits).log_prob(y).sum()
    log_joint = log_prior + log_likelihood
    return log_joint

noise = 0.2
num_samples = 50
num_warmup = 1000
num_steps = 500
in_shapes = [2,10,10,10]
out_shapes = [10,10,10,1]

X, y = make_moons(n_samples=num_samples, noise=noise, random_state=314)
model = MLP1D(in_shapes=in_shapes, out_shapes=out_shapes)
params = model.fetch_params()
potential = partial(bnn_log_joint, X=torch.Tensor(X), y=torch.Tensor(y), model=model)


# Test
print(torch.Tensor(X).shape, torch.Tensor(y).shape)
# potential(params)
# print(ravel_params(params[0]).shape)
# print(unravel_params(ravel_params(params[1]),params[1]))
# print(unravel_params(ravel_params(params),params))

torch.Size([50, 2]) torch.Size([50])


In [3]:
class WelfordComputation:

    def __init__(self, params_shape):
        self.parameters = (torch.zeros(params_shape), torch.zeros(params_shape), 0)

    def update_step(self, value):
        mean, m2, no_of_samples = self.parameters
        no_of_samples += 1
        delta_change = value - mean
        mean = mean + delta_change/no_of_samples
        new_delta_change = value - mean
        new_m2 = m2 + delta_change*new_delta_change
        self.parameters = (mean, new_m2, no_of_samples)
    
    def final_step(self):
        mean, m2, no_of_samples = self.parameters
        baised_variance = m2/no_of_samples
        unbaised_variance = m2/(no_of_samples-1)
        return (mean, baised_variance, unbaised_variance)

In [4]:
class EnergyParameters:

    def __init__(self, potential_energy_fn:Callable, precision):
        self.potential_energy_fn = potential_energy_fn 
        self.precision = precision
    
    def set_position(self, position):
        self.position = position
        self.potential_energy = self.potential_energy_fn(position)
        self.potential_energy_grad = torch.func.grad(self.potential_energy_fn)(position)
    
    def set_velocity(self, velocity):
        self.velocity = velocity
        self.kinetic_energy = self.kinetic_energy_fn(self.velocity)
        self.kinetic_energy_grad = torch.func.grad(self.kinetic_energy_fn)(self.velocity)

    def init_per_step(self, position, velocity):
        self.set_position(position)
        self.set_velocity(velocity)
        self.total_init_energy = self.total_current_energy()
    
    def kinetic_energy_fn(self, velocity:torch.Tensor):
        return 0.5*torch.matmul(velocity, torch.matmul(self.precision, velocity.T))
    
    def update_position(self, step_size_with_direction):
        position_raveled = ravel_params(self.position)
        result = position_raveled + step_size_with_direction*self.kinetic_energy_grad
        n_position = unravel_params(result, self.position)
        self.set_position(n_position)
        return n_position

    def updated_velocity(self, step_size_with_direction, is_half_step_momentum = False):
        n_velocity = torch.subtract(self.velocity, step_size_with_direction*(0.5 if is_half_step_momentum else 1)*ravel_params(self.potential_energy_grad))
        self.set_velocity(n_velocity)
        return n_velocity
    
    def delta_energy(self):
        return - self.total_current_energy() + self.total_init_energy 
    
    def total_current_energy(self):
        return - self.potential_energy - self.kinetic_energy

# Test

# e_params = EnergyParameters(potential_energy_fn=potential)
# e_params.set_energy_parameters(params)

# Efficient No-U-Turn Sampler with Dual Averaging

In [9]:
class HMCAlgorithm:
    
    def __init__(self, log_density_fn: Callable, precision, position_shape, expected_prob_density=0.65, Lambda = 1, m_warmup=1000, m_after_warmup=50):
        self.log_density_fn = log_density_fn
        # Lambda
        self.Lambda = Lambda
        # delta
        self.expected_prob_density = expected_prob_density
        # M_adapt
        self.m_warmup = m_warmup
        # M
        self.M = m_warmup +  m_after_warmup
        self.ep = EnergyParameters(potential_energy_fn=log_density_fn, precision=precision)
        self.welford_position = WelfordComputation(position_shape)
    

    def generate_default_velocity(self, position_shape):

        def generate_random_gaussian(covariance: torch.Tensor):
            return torch.distributions.MultivariateNormal(loc=torch.zeros(covariance.shape[-1]),covariance_matrix=covariance).sample()

        def generate_inverse_mass_matrix_and_sigma(covariance_shape: int):
            covariance = torch.eye(covariance_shape)
            inverse_mass_matrix = covariance.inverse()
            return inverse_mass_matrix, covariance
        
        _, covariance = generate_inverse_mass_matrix_and_sigma(position_shape)
        return generate_random_gaussian(covariance)

    # welford algorithm
    def square_distance_from_mean(self, position):
        mean, m2, no_of_samples = self.parameters
        no_of_samples += 1
        delta_change = position - mean
        mean = mean + delta_change/no_of_samples
        new_delta_change = position - mean
        new_m2 = m2 + delta_change*new_delta_change
        self.parameters = (mean, new_m2, no_of_samples)

    def find_reasonable_epsilon(self, position):
        step_size, velocity = 0.01, self.generate_default_velocity(ravel_params(position).shape[-1])
        delta_energy, proposal_position, proposal_velocity, _ = self.step_integrator(position, velocity, step_size)
        alpha = 2*(1 if torch.exp(delta_energy) > 0.5 else 0) - 1
        while torch.pow(torch.exp(delta_energy), alpha) > torch.pow(torch.tensor(0.5), alpha):
            step_size = torch.pow(torch.tensor(2), alpha) * step_size
            delta_energy, proposal_position, proposal_velocity, _ = self.step_integrator(proposal_position, proposal_velocity, step_size)
        print(f'Step size (find_reasonable_epsilon) : {step_size}')
        return step_size

    def current_direction(self):
        direction = torch.distributions.Bernoulli(logits=torch.tensor([0.5])).sample()
        return -1 if direction.bool() else 1

    def stop_condition(self, p_minus, v_minus, p_plus, v_plus):
        diff = ravel_params(p_plus)-ravel_params(p_minus)
        condition = torch.all(diff*v_minus >= torch.zeros(diff.shape)) and torch.all(diff*v_plus >= torch.zeros(diff.shape))
        return 1 if condition else 0

    def build_tree(self, position, velocity, energy, direction, iter_no, step_size, init_energy, m):
        delta_max = torch.tensor(1000)
        if iter_no == 0:
            delta_energy, new_position, new_velocity, new_energy = self.step_integrator(position, velocity, direction*step_size)
            n_leaves = 1 if energy <= new_energy else 0
            s = 1 if energy < torch.log(delta_max) + new_energy  else 0
            # print(f'build_tree {m} : delta, steps :', delta_energy, direction*step_size)
            return new_position, new_velocity, new_position, new_velocity, new_position, n_leaves, s, torch.min(torch.tensor(1), torch.exp(-new_energy+init_energy)), 1
        else:
            p_minus, v_minus, p_plus, v_plus, p, n_leaves, s, alpha, n_alpha = self.build_tree(position, velocity, energy, direction, iter_no-1, step_size, init_energy, m)
            if s == 1:
                if direction == -1:
                    p_minus, v_minus, _, _, new_p, n_n_leaves, new_s, new_alpha, n_new_alpha = self.build_tree(p_minus, v_minus, energy, direction, iter_no-1, step_size, init_energy, m)
                else:
                    _, _, p_plus, v_plus, new_p, n_n_leaves, new_s, new_alpha, n_new_alpha = self.build_tree(p_plus, v_plus, energy, direction, iter_no-1, step_size, init_energy, m)
                p = new_p if torch.distributions.Bernoulli(logits=new_s/(s+new_s)).sample().bool() else p
                alpha += new_alpha
                n_alpha += n_new_alpha
                s = new_s * self.stop_condition(p_minus, v_minus, p_plus, v_plus)
                n_leaves += n_n_leaves
            return p_minus, v_minus, p_plus, v_plus, p, n_leaves, s, alpha, n_alpha

    def update(self, init_position):
        self.init_step(init_position)
        log_step_size_average = torch.log(self.step_size_average)
        # starting from 1
        # annotations : n --> number, new -> new value 
        proposal_position = init_position
        for m in range(1,self.M+1):
            velocity = self.generate_default_velocity(ravel_params(proposal_position).shape[-1])
            self.ep.init_per_step(proposal_position, velocity)
            init_energy = self.ep.total_current_energy()
            # print(init_energy)
            energy = torch.distributions.Uniform(torch.tensor([0.0]),init_energy).sample()
            p_minus, p_plus, v_minus, v_plus = proposal_position, proposal_position, velocity, velocity
            counter, s, n_leaves = 0, 1, 1
            while s == 1:
                direction = self.current_direction()
                if direction == -1:
                    p_minus, v_minus, _,_, new_p, n_n_leaves, new_s, alpha, n_alpha = self.build_tree(p_minus, v_minus, energy, direction, counter, self.step_size, init_energy, m) 
                else:
                    _, _, p_plus, v_plus, new_p, n_n_leaves, new_s, alpha, n_alpha = self.build_tree(p_plus, v_plus, energy, direction, counter, self.step_size, init_energy, m) 
                if new_s == 1:
                    probs = torch.tensor(min(1, n_n_leaves/n_leaves))
                    proposal_position = new_p if torch.distributions.Bernoulli(logits=probs).sample().bool() else proposal_position
                n_leaves += n_n_leaves
                s = new_s * self.stop_condition(p_minus, v_minus, p_plus, v_plus)
                counter += 1
            if m <= self.m_warmup:
                part_1 = (1-1/(m+self.t_0))*self.H_m
                part_2 = (1/(m+self.t_0))*(self.expected_prob_density - alpha/n_alpha)
                self.H_m =  part_1 + part_2 
                log_step_size = self.mu - (torch.sqrt(torch.tensor(m))/self.gamma)*self.H_m
                step = self.step_fn(m)
                log_step_size_average = step * log_step_size + (1-step)*log_step_size_average
                self.step_size_average = torch.exp(log_step_size_average)
                self.step_size = torch.exp(log_step_size_average)
                # print("s, s', m, step, alpha, n_alpha, H_m: ", s, new_s, m, self.step_size, alpha, n_alpha, self.H_m)
            else:
                # print("After warmup : s, s', m, step_size, alpha, n_alpha : ", s, new_s, m, self.step_size, alpha, n_alpha)
                self.step_size = torch.exp(log_step_size_average)
            print(m, self.step_size, alpha/n_alpha)
            self.welford_position.update_step(ravel_params(proposal_position))
        return proposal_position, self.welford_position.final_step()
    
    def step_fn(self, current_step):
        return torch.pow(torch.tensor(1/current_step), self.k)

    def init_step(self, position):
        self.step_size = self.find_reasonable_epsilon(position)
        self.mu = torch.log(10*self.step_size)  
        self.step_size_average = torch.tensor(1)
        self.H_m = torch.zeros(1)
        self.gamma = 0.05
        self.t_0 = 10
        self.k = 0.75

    def step_integrator(self, position, velocity, step_size_with_direction, l_frog_step=1):
        self.ep.init_per_step(position, velocity)
        # init step
        self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        # l_frog_steps
        for i in range(l_frog_step):
            self.ep.update_position(step_size_with_direction)
            self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=False)
        # final step
        f_step_position = self.ep.update_position(step_size_with_direction)
        half_step_velocity = self.ep.updated_velocity(step_size_with_direction, is_half_step_momentum=True)
        return self.ep.delta_energy(), f_step_position, half_step_velocity, self.ep.total_current_energy()

params_shape = ravel_params(params).shape[-1]
# I^-1 = I so inverse doesn't matter but writing it for visibility
default_precision = torch.eye(params_shape).inverse()
hmc_kernel = HMCAlgorithm(log_density_fn=potential, precision=default_precision, position_shape=params_shape)
# Test
final_state, (_,_, unbiased_variance_positon) = hmc_kernel.update(params)
final_params = unravel_params(unbiased_variance_positon, final_state)

Step size (find_reasonable_epsilon) : 0.1599999964237213
1 tensor([0.4908], grad_fn=<ExpBackward0>) tensor(1.4091e-08, grad_fn=<DivBackward0>)
2 tensor([0.1602], grad_fn=<ExpBackward0>) tensor(0., grad_fn=<DivBackward0>)
3 tensor([0.1448], grad_fn=<ExpBackward0>) tensor(1., grad_fn=<DivBackward0>)
4 tensor([0.0673], grad_fn=<ExpBackward0>) tensor(0., grad_fn=<DivBackward0>)
5 tensor([0.0569], grad_fn=<ExpBackward0>) tensor(1., grad_fn=<DivBackward0>)
6 tensor([0.0298], grad_fn=<ExpBackward0>) tensor(0., grad_fn=<DivBackward0>)
7 tensor([0.0245], grad_fn=<ExpBackward0>) tensor(1., grad_fn=<DivBackward0>)
8 tensor([0.0138], grad_fn=<ExpBackward0>) tensor(0., grad_fn=<DivBackward0>)
9 tensor([0.0061], grad_fn=<ExpBackward0>) tensor(7.5760e-06, grad_fn=<DivBackward0>)
10 tensor([0.0023], grad_fn=<ExpBackward0>) tensor(5.0953e-08, grad_fn=<DivBackward0>)
11 tensor([0.0008], grad_fn=<ExpBackward0>) tensor(0.0922, grad_fn=<DivBackward0>)
12 tensor([0.0004], grad_fn=<ExpBackward0>) tensor(1., 

In [25]:
predict_X, predict_y = make_moons(n_samples=100, random_state=400)
y_predicted = model.forward(final_params, torch.Tensor(predict_X)).ravel()
print(y_predicted[:10], y[:10])

tensor([0.0129, 0.0127, 0.0129, 0.0127, 0.0127, 0.0127, 0.0128, 0.0127, 0.0129,
        0.0129], grad_fn=<SliceBackward0>) [1 0 0 0 1 0 1 0 0 0]


In [None]:
torch.arange(vmin, vmax, 100)