In [None]:
import os
# Optional device / renderer pinning. Uncomment and adjust *before* importing
# jax if a specific GPU or an EGL renderer is required:
#   os.environ['CUDA_VISIBLE_DEVICES'] = '2'
#   os.environ['MUJOCO_GL'] = 'egl'
#   os.environ['EGL_DEVICE_ID'] = '2'
#   import jax; print("JAX devices:", jax.devices())

# Ask XLA not to preallocate GPU memory up front.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE='false')


In [None]:

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from learning.module.target_examples.funnel import Funnel
from learning.module.target_examples.gmm40 import GMM40
import functools
import chex


Setup GMMVI

In [None]:
.network import create_gmm_network_and_state
.network import GMMTrainingState
from learning.module.target_examples.student_t_mixture import StudentTMixtureModel
dim = 20

# Target density to fit. Alternatives that were tried:
# target = StudentTMixtureModel(dim=dim, sample_bounds=[-30., 30.], num_components=40)
# target = Funnel(dim=dim, sample_bounds=[-30, 30])
target = GMM40(dim=dim)

# Symmetric per-dimension box taken from the target's plotting bound.
low = jnp.ones(dim) * (-target._plot_bound)
high = jnp.ones(dim) * (target._plot_bound)
bound_info = (low, high)

num_envs = 1024
batch_size = 1024

# JAX PRNG keys must not be consumed twice: derive independent keys for the
# network initialisation and for the diagnostic draw below instead of feeding
# the same key to both.
key = jax.random.PRNGKey(0)
key, init_key, sample_key = jax.random.split(key, 3)

initial_train_state, gmm_network = create_gmm_network_and_state(
    dim, num_envs, batch_size, init_key, prior_scale=.1, bound_info=bound_info)

# Sanity check: draw from the freshly initialised GMM and show the first two
# coordinates as a 2-D histogram.
contexts, logp = gmm_network.model.sample(initial_train_state.model_state.gmm_state, sample_key, 1000)
plt.figure(figsize=(5, 5))
plt.hist2d(contexts[:, 0], contexts[:, 1], bins=100, range=[[low[0], high[0]], [low[1], high[1]]])
plt.title("Initial GMM samples")
plt.show()


In [None]:
@functools.partial(jax.jit, static_argnames=['target_log_prob_fn'])
def gather_samples(train_state: GMMTrainingState, key: chex.Array, target_log_prob_fn):
    """Draw fresh samples, score them under the target, and store them.

    Only the sample database changes; every other field of ``train_state`` is
    copied through untouched, so no model update happens here.

    Args:
        train_state: Current training state.
        key: PRNG key; it is split once and the second half drives the
            sample selection.
        target_log_prob_fn: Per-sample target log-density. It is marked static
            for ``jax.jit`` and therefore has to be hashable.

    Returns:
        A new ``GMMTrainingState`` whose ``sample_db_state`` contains the
        newly drawn samples with their target log-densities and gradients.
    """
    selector = gmm_network.sample_selector
    model_state = train_state.model_state

    _, sample_key = jax.random.split(key)
    drawn, drawn_mapping = selector.select_samples(model_state, sample_key)

    # Target log-density and its gradient, evaluated one sample at a time.
    drawn_lnpdfs, drawn_grads = jax.vmap(jax.value_and_grad(target_log_prob_fn))(drawn)

    updated_db = selector.save_samples(
        model_state,
        train_state.sample_db_state,
        drawn,
        drawn_lnpdfs,
        drawn_grads,
        drawn_mapping,
    )
    return GMMTrainingState(
        temperature=train_state.temperature,
        model_state=model_state,
        component_adaptation_state=train_state.component_adaptation_state,
        num_updates=train_state.num_updates,
        sample_db_state=updated_db,
        weight_stepsize=train_state.weight_stepsize,
    )
@functools.partial(jax.jit, static_argnames=['target_log_prob_fn'])
def train_iter(train_state: GMMTrainingState, key: chex.Array, target_log_prob_fn):
    """Run one full update of the GMM.

    Steps, in order: draw and store new samples, pick the training batch from
    the sample database, refresh the per-component step sizes, estimate the
    expected (negated) Hessian and gradient with ``more_ng_estimator``, update
    the components, update the mixture weights, and finally run the component
    adapter.

    Args:
        train_state: Training state before the update.
        key: PRNG key; split once for sample selection and once more for the
            component adapter.
        target_log_prob_fn: Per-sample target log-density. It is marked static
            for ``jax.jit`` and therefore has to be hashable.

    Returns:
        A new ``GMMTrainingState`` with ``num_updates`` incremented by one.
        ``temperature`` and ``weight_stepsize`` are carried over unchanged.
    """
    selector = gmm_network.sample_selector
    prev_model_state = train_state.model_state

    # --- collect new samples and score them under the target -------------
    key, sample_key = jax.random.split(key)
    drawn, drawn_mapping = selector.select_samples(prev_model_state, sample_key)
    drawn_lnpdfs, drawn_grads = jax.vmap(jax.value_and_grad(target_log_prob_fn))(drawn)
    sample_db = selector.save_samples(
        prev_model_state,
        train_state.sample_db_state,
        drawn,
        drawn_lnpdfs,
        drawn_grads,
        drawn_mapping,
    )

    # --- training batch from the (updated) database -----------------------
    (samples,
     _batch_mapping,
     sample_dist_densities,
     target_lnpdfs,
     target_lnpdf_grads) = selector.select_train_datas(sample_db)

    # --- component update ---------------------------------------------------
    stepsizes = gmm_network.component_stepsize_fn(prev_model_state)
    model_state = gmm_network.model.update_stepsizes(prev_model_state, stepsizes)
    # `gmm_network.ng_estimator` takes the same arguments and was the
    # previously used alternative to `more_ng_estimator`.
    neg_hessian, neg_grad = gmm_network.more_ng_estimator(
        model_state,
        samples,
        sample_dist_densities,
        target_lnpdfs,
        target_lnpdf_grads,
    )
    model_state = gmm_network.component_updater(
        model_state, neg_hessian, neg_grad, model_state.stepsizes)

    # --- weight update ------------------------------------------------------
    model_state = gmm_network.weight_updater(
        model_state, samples, sample_dist_densities, target_lnpdfs, train_state.weight_stepsize)

    # --- component adaptation ----------------------------------------------
    update_count = train_state.num_updates + 1
    key, adapt_key = jax.random.split(key)
    model_state, adaptation_state, sample_db = gmm_network.component_adapter(
        train_state.component_adaptation_state,
        sample_db,
        model_state,
        update_count,
        adapt_key,
    )

    return GMMTrainingState(
        temperature=train_state.temperature,
        model_state=model_state,
        component_adaptation_state=adaptation_state,
        num_updates=update_count,
        sample_db_state=sample_db,
        weight_stepsize=train_state.weight_stepsize,
    )
def eval(seed: chex.Array, train_state: GMMTrainingState, target_log_prob_fn, n_eval_samples, target_samples=None):
    """Compute the raw quantities needed for reverse and forward metrics.

    NOTE: the name shadows Python's builtin ``eval``; it is kept because
    callers refer to it by this name.

    Args:
        seed: PRNG key used to draw the model samples.
        train_state: Training state whose ``model_state.gmm_state`` is evaluated.
        target_log_prob_fn: Per-sample target log-density.
        n_eval_samples: Number of samples to draw from the model.
        target_samples: Optional batch of samples from the target. When given,
            the forward log-ratio is evaluated on them.

    Returns:
        A 6-tuple ``(samples, log_ratio, log_prob_target, fwd_log_ratio,
        n_eval_samples, log_prob_model_fn)`` where ``log_ratio`` is the target
        log-density minus the model log-density at the model samples,
        ``fwd_log_ratio`` is the same difference at ``target_samples`` (or
        ``None`` when no target samples were supplied), and
        ``log_prob_model_fn`` maps a batch of points to model log-densities.
    """
    gmm_state = train_state.model_state.gmm_state
    samples = gmm_network.model.sample(gmm_state, seed, n_eval_samples)[0]

    def log_prob_model_fn(sample):
        # Batched model log-density. A correction term for the [low, high]
        # bounds used to be defined here but was never applied (its
        # subtraction was commented out), so the dead definition was removed.
        return jax.vmap(functools.partial(gmm_network.model.log_density, gmm_state=gmm_state))(sample=sample)

    # Reverse direction: everything evaluated at samples drawn from the model.
    log_prob_model = log_prob_model_fn(sample=samples)
    log_prob_target = jax.vmap(target_log_prob_fn)(samples)
    log_ratio = log_prob_target - log_prob_model

    # Forward direction: everything evaluated at samples drawn from the target.
    if target_samples is not None:
        fwd_log_prob_model = jax.vmap(gmm_network.model.log_density, in_axes=(None, 0))(gmm_state, target_samples)
        fwd_log_prob_target = jax.vmap(target_log_prob_fn)(target_samples)
        fwd_log_ratio = fwd_log_prob_target - fwd_log_prob_model
    else:
        fwd_log_ratio = None

    return samples, log_ratio, log_prob_target, fwd_log_ratio, n_eval_samples, log_prob_model_fn


In [None]:
from time import time

# Metric histories, one list per metric; `eval_fn` and the training loop
# append to these. A comprehension is used so every key owns its own list.
logger = {
    metric_name: []
    for metric_name in (
        'KL/elbo',
        'KL/eubo',
        'logZ/delta_forward',
        'logZ/forward',
        'logZ/delta_reverse',
        'logZ/reverse',
        'ESS/forward',
        'ESS/reverse',
        'num_components',
    )
}
def compute_reverse_ess(log_weights, eval_samples):
    """Normalised effective sample size of a set of log importance weights.

    Computes ``(sum w)^2 / (eval_samples * sum w^2)`` with
    ``w = exp(log_weights - max(log_weights))``. Shifting by the maximum keeps
    the exponentials from overflowing and cancels in the ratio.

    Args:
        log_weights: Array of log importance weights.
        eval_samples: Number used to normalise the ESS (the sample count).

    Returns:
        The ESS divided by ``eval_samples``.
    """
    weights = jnp.exp(log_weights - jnp.max(log_weights))
    squared_total = jnp.sum(weights) ** 2
    total_of_squares = jnp.sum(weights ** 2)
    return squared_total / (eval_samples * total_of_squares)
def eval_fn(samples, log_ratio, target_log_prob, fwd_log_ratio, n_eval_samples, model_log_prob_fn):
    """Append one round of evaluation metrics to the module-level ``logger``.

    The parameters mirror, in order, the tuple returned by ``eval`` so that it
    can be splatted straight in.

    Args:
        samples: Model samples (not used here).
        log_ratio: Target minus model log-density at the model samples.
        target_log_prob: Target log-density at the model samples (not used here).
        fwd_log_ratio: Target minus model log-density at target samples, or
            ``None`` to skip all forward metrics.
        n_eval_samples: Number of model samples behind ``log_ratio``.
        model_log_prob_fn: Batched model log-density (not used here).

    Returns:
        The module-level ``logger`` dict, which is mutated in place.
    """
    # Reverse (model-sample) estimates.
    ln_z = jax.nn.logsumexp(log_ratio) - jnp.log(n_eval_samples)
    elbo = jnp.mean(log_ratio)

    if target.log_Z is not None:
        logger['logZ/delta_reverse'].append(jnp.abs(ln_z - target.log_Z))

    logger['logZ/reverse'].append(ln_z)
    logger['KL/elbo'].append(elbo)
    logger['ESS/reverse'].append(compute_reverse_ess(log_ratio, n_eval_samples))

    if fwd_log_ratio is not None:
        # Normalise by the number of *target* samples actually supplied; it is
        # not guaranteed to equal the number of model samples.
        log_n_fwd = jnp.log(fwd_log_ratio.shape[0])
        eubo = jnp.mean(fwd_log_ratio)
        fwd_ln_z = -(jax.nn.logsumexp(-fwd_log_ratio) - log_n_fwd)
        fwd_ess = jnp.exp(fwd_ln_z - (jax.nn.logsumexp(fwd_log_ratio) - log_n_fwd))

        if target.log_Z is not None:
            logger['logZ/delta_forward'].append(jnp.abs(fwd_ln_z - target.log_Z))
        logger['logZ/forward'].append(fwd_ln_z)
        logger['KL/eubo'].append(eubo)
        logger['ESS/forward'].append(fwd_ess)
    # if dim==2:
    #     logger.update(target.visualise(samples=samples, model_log_prob_fn=model_log_prob_fn ,show=True))
    return logger


In [None]:
# --- run configuration -------------------------------------------------------
seed = 23
iterations = 5000        # total training budget
eval_freq = 10           # `iterations` must be a multiple of this
n_eval_samples = 10000   # sample count for both model and target evaluation

# Fixed batch of target samples, reused at every evaluation for the forward metrics.
target_samples = target.sample(jax.random.PRNGKey(0), (n_eval_samples,))

# --- mutable run state -------------------------------------------------------
key, rng = jax.random.split(jax.random.PRNGKey(seed))
state = initial_train_state
timer = 0  # accumulated wall-clock seconds spent in training steps

def _train(carry, _):
    """``jax.lax.scan`` body: advance the carried state by one ``train_iter``.

    Args:
        carry: ``(train_state, prng_key)`` pair.
        _: Per-step scan input; passed through as the per-step output.

    Returns:
        ``((new_train_state, new_prng_key), _)``.
    """
    prev_state, carry_key = carry
    carry_key, step_key = jax.random.split(carry_key)
    next_state = train_iter(prev_state, step_key, target.log_prob)
    return (next_state, carry_key), _
def jitted_train(state, key):
    """Run ``eval_freq`` consecutive training iterations via ``jax.lax.scan``.

    Despite the name, this function is not itself wrapped in ``jax.jit``; it
    only scans over ``_train``.

    Args:
        state: Training state to start from.
        key: PRNG key threaded through the scan.

    Returns:
        The training state after ``eval_freq`` iterations.
    """
    final_carry, _ = jax.lax.scan(_train, (state, key), (), length=eval_freq)
    final_state, _ = final_carry
    return final_state
# jax.config.update("jax_disable_jit", True) 
# The budget is `iterations` updates, evaluated every `eval_freq` updates.
assert iterations % eval_freq == 0
num_eval_calls = iterations // eval_freq

# Warm-up: fill the sample database before the first model update.
for _ in range(batch_size // num_envs):
    key, subkey = jax.random.split(key)
    state = gather_samples(state, subkey, target.log_prob)

global_step = 0
for _eval_round in range(num_eval_calls):
    iter_time = time()
    # Run `eval_freq` updates between evaluations so that the total number of
    # updates really is `iterations` (previously only one update was run per
    # evaluation round). `train_iter` is already jitted, so a plain Python
    # loop is sufficient here.
    for _step in range(eval_freq):
        key, subkey = jax.random.split(key)
        state = train_iter(state, subkey, target.log_prob)
        global_step += 1
    # JAX dispatches asynchronously; wait for the result so the timer measures
    # compute time rather than dispatch time.
    jax.block_until_ready(state)
    timer += time() - iter_time

    key, subkey = jax.random.split(key)
    logger = eval_fn(*eval(subkey, state, target.log_prob, n_eval_samples, target_samples))
    logger['num_components'].append(int(state.model_state.gmm_state.num_components))
# Headline curves: one figure each.
headline_metrics = ('KL/eubo', 'KL/elbo', 'num_components')
for metric_name in headline_metrics:
    plt.plot(logger[metric_name], label=metric_name)
    plt.legend()
    plt.show()
    plt.close()

# All remaining metrics share a single figure. The headline metrics are
# skipped rather than deleted from `logger`: deleting them would make the next
# run of the training cell fail with a KeyError when it appends to those keys.
for k, v in logger.items():
    if k in headline_metrics:
        continue
    plt.plot(v, label=k)
plt.legend()
plt.show()
