In [1]:
import sys
import argparse
import math
from math import pi
import time

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans

from jax import numpy as jnp, random

import numpyro
from numpyro.distributions import (
    Beta,
    Categorical,
    Dirichlet,
    Gamma,
    Normal,
    SineSkewed,
    Uniform,
    VonMises,
    HalfNormal
)
from numpyro.infer import MCMC, NUTS, Predictive, init_to_value
import arviz as az
import seaborn as sns
import pandas as pd
from numpyro.infer.util import log_likelihood
from numpyro.handlers import seed

In [2]:
def model(num_mix_comp, data = None, n = None, identifiable=False, kappa_max=1000):
    if data is None:
        # We are predicting - no data
        assert(not n is None)
    else:
        # We are training - we have data
        assert(n is None)
        n = data.shape[0]
        assert(data.shape == (n,2))

    # Sampling mixture weights from a Dirichlet distribution (latent variable)
    mix_weights = numpyro.sample("mix_weights", Dirichlet(jnp.ones((num_mix_comp,))))

    if identifiable:
    # We constrain psi_loc so that it increases with mixture component m
        # Note the +1, this is because we will use cumsum
        edges = numpyro.sample("edges", Dirichlet(jnp.ones((num_mix_comp+1,))))
        cumsum = jnp.cumsum(edges)
        # Remove last dim in the cumsum because it is always 1
        cumsum = cumsum[:-1]
        psi_loc = 2*jnp.pi*cumsum - jnp.pi

    # Sample mixture rvs
    # Using plate to indicate that the following variables are conditionally independent
    with numpyro.plate("mixture", num_mix_comp):
        # Sampling mixture component parameters for von Mises distributions
        # Locations
        phi_loc = numpyro.sample("phi_loc", Uniform(-jnp.pi, jnp.pi)) # Mean direction for phi

        if not identifiable:
            # No constraints
            psi_loc = numpyro.sample("psi_loc", Uniform(-jnp.pi, jnp.pi)) # Mean direction for psi

        psi_conc = numpyro.sample("psi_conc", Uniform(0.1, kappa_max))
        phi_conc = numpyro.sample("phi_conc", Uniform(0.1, kappa_max))

    # Combine the locs and kappas for a single VonMises likelihood
    locs = jnp.stack((phi_loc, psi_loc), -1)
    assert(locs.shape==(num_mix_comp,2))
    kappas = jnp.stack((phi_conc, psi_loc), -1)
    assert(kappas.shape==(num_mix_comp,2))

    # Using plate for the observed data
    with numpyro.plate("data_plate", n):
        # Sampling the mixture component assignment (latent variable)
        assign = numpyro.sample("mix_comp", Categorical(mix_weights), infer={"enumerate": "parallel"})

        locs_a = locs[assign]
        kappas_a = kappas[assign]

        # to_event is used because the VM is univariate, but we have 2D data
        von = VonMises(locs_a, kappas_a).to_event(1)

        # Define the likelihood 
        phi_psi = numpyro.sample("phi_psi", von, obs = data)



In [3]:
# Run Hamiltonian Monte Carlo
def run_hmc(rng_key, model, data, num_mix_comp, args):
    from numpyro.infer import init_to_mean
    kernel = NUTS(model, init_strategy=init_to_mean())
    # TH: two chains 
    mcmc = MCMC(kernel, num_chains=2, num_samples=args.num_samples, num_warmup=args.num_warmup)
    # TH: fixed this
    mcmc.run(rng_key, num_mix_comp=num_mix_comp, n=None, data=data, identifiable=args.identifiable)
    mcmc.print_summary()
    post_samples = mcmc.get_samples()

    def waic(model, posterior, model_args, model_kwargs):
        waic_result = az.waic(get_idata(model, posterior, model_args, model_kwargs))
        elpd_waic = waic_result.elpd_waic
        p_waic = waic_result.p_waic
        return elpd_waic

    def loo(model, posterior, model_args, model_kwargs):
        loo_result = az.loo(get_idata(model, posterior, model_args, model_kwargs))
        elpd_loo = loo_result.elpd_loo
        p_loo = loo_result.p_loo
        return elpd_loo

    def ess(model, posterior, model_args, model_kwargs):
        return az.ess(get_idata(model, posterior, model_args, model_kwargs))

    def rhat(model, posterior, model_args, model_kwargs):
        return az.rhat(get_idata(model, posterior, model_args, model_kwargs))


     # We need to sample the assignments, looks az.from_numpyro doesn't handle discrete sites correctly.
     # **NOTE**: I have not checked whether we should enumerate mix_comp when computing WAIC.
    def get_idata(model, posterior, model_args, model_kwargs):
        ll = log_likelihood(seed(model, rng_seed=0), posterior, *model_args, **model_kwargs)
        ll = {k: v[None] for k, v in ll.items()}
        idata = az.convert_to_inference_data(
            {k: v[None] for k, v in posterior.items() if k not in ll}
        )
        idata.add_groups(log_likelihood=ll)
        return idata

    expected_waic = waic(model, post_samples, (num_mix_comp, data), {})
    expected_loo = loo(model, post_samples, (num_mix_comp, data), {})
    expected_ess = ess(model, post_samples, (num_mix_comp, data), {})
    expected_rhat = rhat(model, post_samples, (num_mix_comp, data), {})

    return post_samples, expected_waic, expected_loo, expected_ess, expected_rhat

In [4]:
def plot_metric_vs_components(metric_values, num_components, metric_name, iters):
    metric_values = metric_values
    print(f"{metric_name} metric_vals : ", metric_values)
    print(f"{metric_name} num_comp : ", num_components)
    plt.figure()
    plt.plot(num_components, metric_values, marker='o')
    plt.xlabel('Number of Components')
    plt.ylabel(metric_name)
    plt.title(f'{metric_name} vs Number of Components')
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    save_path = f'model1_{metric_name}_vs_components_{iters}.png'
    plt.savefig(save_path)

    # Calculate and print the optimal component
    if len(num_components) > 1:
        optimal_component = num_components[np.argmax(metric_values)]
        print(f'Optimal {metric_name} Component: {optimal_component}')

In [5]:
def plot_rama(phi_values, psi_values):
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.set_title("Ramachandran Plot")

    mask = ~np.isnan(phi_values) & ~np.isnan(psi_values)
    phi_values = phi_values[mask]
    psi_values = psi_values[mask]

    # Create hexbin plot 
    hb = ax.hexbin(phi_values, psi_values, bins='log', gridsize=50, cmap='inferno')
   
    # Set labels and colorbar
    ax.set_xlabel('Phi Angles')
    ax.set_ylabel('Psi Angles')
    cb = fig.colorbar(hb, ax=ax, label='Counts')

    plt.show()

In [6]:
def plot_density(phi_values, psi_values, iters):

    # Check if the lengths of phi_values and psi_values match
    if len(phi_values) != len(psi_values):
        raise ValueError("The lengths of phi_values and psi_values must be the same.")

    # Create a DataFrame from the phi_values and psi_values
    data = pd.DataFrame({'Phi': phi_values, 'Psi': psi_values})

    # Create a density plot
    plt.figure(figsize=(10, 8))
    sns.kdeplot(data=data, x='Phi', y='Psi', fill=True, cmap='viridis', cbar=True)

    # Add labels and title
    plt.xlabel("Phi")
    plt.ylabel("Psi")
    plt.title("Phi vs. Psi Angles Density")
    save_path = f'model1_density_{iters}.png'
    plt.savefig(save_path)

    # Show the plot
    plt.show()

In [7]:
def plot_concentrations(post_samples):
    phi_concentration = post_samples['phi_conc']
    psi_concentration = post_samples['psi_conc']

    # Fixing random state for reproducibility
    np.random.seed(19680801)

    x = np.array(phi_concentration)
    y = np.array(psi_concentration)
    xlim = (x.min(), x.max())
    ylim = (y.min(), y.max())

    plt.figure()

    hb = plt.hexbin(x, y, gridsize=50, bins='log', cmap='inferno')
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.title("Phi vs Psi Concentrations")
    cb = plt.colorbar(hb, label='counts')

    plt.show()

In [8]:
def extract_values(result):
    values = []
    for var in result.data_vars:
        values.extend(result[var].values.flatten())
    return np.array(values)

In [9]:
def main(args):
    start_time = time.time()

    # Load data from the file
    N = 1000  # Limit the data set size for testing
    data = jnp.array(np.load("top500_inliers.npy")[0:N])
    assert data.shape == (N, 2)

    phi = data[:, 0]
    psi = data[:, 1]
    
    # Initialize lists to accumulate the ESS and R-hat values
    accumulated_ess_values = []
    accumulated_rhat_values = []

    # Initialize lists to store the results for other metrics
    all_num_mix_comp = []
    all_expected_waic = []
    all_expected_loo = []
    pred_datas = {}

    # Define the loop parameters
    start = 5
    iters = 30
    step = 5
    rng_key = random.PRNGKey(0)  # Initialize your random key here

    # Loop over the specified range
    for m in range(start, iters, step):
        # Split the random key into multiple keys for different purposes
        rng_key, inf_key, pred_key = random.split(rng_key, 3)

        # Determine the number of mixture components for this iteration
        num_mix_comp = m

        # Time the HMC sampling
        hmc_start = time.time()
        posterior_samples, expected_waic, expected_loo, expected_ess, expected_rhat = run_hmc(inf_key, model, data, num_mix_comp, args)
        hmc_end = time.time()
        print(f"HMC sampling for {num_mix_comp} components took {hmc_end - hmc_start:.2f} seconds")

        # Time the predictive sampling
        predictive_start = time.time()
        predictive = Predictive(model, posterior_samples, parallel=True)
        pred_samples = []
        num_predictions = 50  # Number of different prediction sets to generate
        for _ in range(num_predictions):
            pred_key, new_pred_key = random.split(pred_key)
            pred_samples.append(predictive(new_pred_key, num_mix_comp=num_mix_comp, data=None, n=N, identifiable=args.identifiable)["phi_psi"])
        pred_datas[m] = jnp.concatenate(pred_samples, axis=0)
        predictive_end = time.time()
        print(f"Predictive sampling for {num_mix_comp} components took {predictive_end - predictive_start:.2f} seconds")

        # Extract and accumulate ESS and R-hat values
        accumulated_ess_values.extend(extract_values(expected_ess))
        accumulated_rhat_values.extend(extract_values(expected_rhat))

        # Store the values for plotting or further analysis
        all_num_mix_comp.append(num_mix_comp)
        all_expected_waic.append(expected_waic)
        all_expected_loo.append(expected_loo)


    # Time the plotting
    plotting_start = time.time()
    plot_metric_vs_components(all_expected_waic, all_num_mix_comp, "WAIC", iters)
    plot_metric_vs_components(all_expected_loo, all_num_mix_comp, "LOO", iters)

    for m in range(start, iters, step):
        phi_values = pred_datas[m][:, 0]
        psi_values = pred_datas[m][:, 1]
        plot_rama(phi_values, psi_values)
    
    plot_rama(phi, psi)
    plot_concentrations(posterior_samples)
    plotting_end = time.time()
    print(f"Plotting took {plotting_end - plotting_start:.2f} seconds")
    
    # After the loop, compute the std of ESS and mean of R-hat
    accumulated_ess_values = np.array(accumulated_ess_values)
    accumulated_rhat_values = np.array(accumulated_rhat_values)

    std_ess = np.std(accumulated_ess_values)
    mean_rhat = np.mean(accumulated_rhat_values[~np.isnan(accumulated_rhat_values)])  # Filter out NaN values

    print(f"Overall Standard Deviation of ESS: {std_ess}")
    print(f"Overall Mean of R-hat: {mean_rhat}")

    end_time = time.time()
    print(f"Total execution time: {end_time - start_time:.2f} seconds")

In [None]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Von mises mixture model"
    )
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=500, type=int)
    parser.add_argument("--rng_seed", type=int, default=123)
    parser.add_argument("--device", default="gpu", type=str, help='use "cpu" or "gpu".')
    parser.add_argument("--file-path", type=str, default="top500.txt", help="Path to the data file.")
    parser.add_argument("--identifiable", dest="identifiable", default=False, action="store_true", help="Identifiable or not.")

    args, _ = parser.parse_known_args()
    main(args)


  mcmc = MCMC(kernel, num_chains=2, num_samples=args.num_samples, num_warmup=args.num_warmup)
warmup:   5%|█▎                         | 71/1500 [00:18<12:22,  1.93it/s, 1023 steps of size 1.91e-03. acc. prob=0.74]