In [12]:
import os
os.environ["KERAS_BACKEND"] = "jax"  # Must be specified before loading keras_core
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # CPU is faster for batchsize=1 inference.

import keras_core as kerasjk
import jax
import jax.numpy as jnp
from jax import random, vmap, jit, grad
#assert jax.default_backend() == 'gpu'
import utils
import numpy as np
import pandas as pd
import time
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
#import elegy # pip install elegy. # Trying to do this with keras core instead.
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
import tensorflow as tf

In [13]:
mcmc_or_hmc = 'mcmc' # 'mcmc' or 'hmc'
num_results = 110_000 #110_000 #150000 #500000 # 10k takes 11min. About 1/5 of these accepted? now .97
num_steps_between_results = 0 # Thinning
num_burnin_steps = 100_000 # Number of steps before beginning sampling

# Note: below parameters are only for hmc
num_adaptation_steps = np.floor(.8*num_burnin_steps) #Somewhat smaller than number of burnin
step_size = 1e-4 # 1e-3 (experiment?) # 1e-5 has 0.95 acc rate and moves
max_tree_depth = 10 # Default=10. Smaller results in shorter steps. Larger takes memory.
max_energy_diff = 1000 # Default 1000.0. Divergent samples are those that exceed this.
unrolled_leapfrog_steps = 1 # Default 1. The number of leapfrogs to unroll per tree expansion step

In [16]:
# Custom normal function with adjustable scale
def custom_normal_fn(scale):
    def normal_fn(x):
        return tfd.Normal(loc=x, scale=scale)
    return normal_fn

class SimpleScaleAdjustedRandomWalkMetropolis(tfp.mcmc.RandomWalkMetropolis):
    def __init__(self, target_log_prob_fn, initial_step_size, num_adaptation_steps, adaptation_rate=0.01, target_accept_prob=0.25, *args, **kwargs):
        super().__init__(target_log_prob_fn, *args, **kwargs)
        self._step_size = initial_step_size
        self.num_adaptation_steps = num_adaptation_steps
        self.adaptation_rate = adaptation_rate
        self.target_accept_prob = target_accept_prob
        self.accept_prob_history = []
        self.adaptation_completed = False

    @property
    def step_size(self):
        return self._step_size

    @step_size.setter
    def step_size(self, value):
        self._step_size = value

    def one_step(self, current_state, previous_kernel_results, seed=None):
        # Update the scale of the normal function based on the current step size
        self.new_state_fn = custom_normal_fn(self.step_size)
        kernel_results = super().one_step(current_state, previous_kernel_results, seed=seed)
        
        # Record the acceptance probability
        accept_prob = tf.math.exp(kernel_results.log_accept_ratio)
        self.accept_prob_history.append(accept_prob)
        
        # Adapt the step size during the first num_adaptation_steps
        if not self.adaptation_completed:
            if len(self.accept_prob_history) >= self.num_adaptation_steps:
                mean_accept_prob = np.mean(self.accept_prob_history)
                if mean_accept_prob > self.target_accept_prob:
                    self._step_size *= (1.0 + self.adaptation_rate)
                else:
                    self._step_size *= (1.0 - self.adaptation_rate)
                    
                # Reset accept_prob_history for the next adaptation phase
                self.accept_prob_history = []
                self.adaptation_completed = True

        return kernel_results

In [18]:
# Initial step size
initial_step_size = 1.0
target_distribution = tfd.Normal(loc=0., scale=1.)

# Create Custom Random Walk Metropolis kernel
kernel = SimpleScaleAdjustedRandomWalkMetropolis(
    target_log_prob_fn=target_distribution.log_prob,
    initial_step_size=step_size, 
    num_adaptation_steps=num_adaptation_steps, 
    adaptation_rate=0.01, 
    target_accept_prob=0.25
)

def trace_fn(_, pkr):
    return [pkr.log_accept_ratio]

samples, pkr = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        num_steps_between_results=num_steps_between_results,
        kernel=kernel,
        trace_fn=trace_fn,
        current_state=0.0,
        seed=random.PRNGKey(5)
)

AttributeError: can't set attribute