In [1]:
import argparse
import math
from math import pi

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans

from jax import numpy as jnp, random

import numpyro
from numpyro.distributions import (
    Beta,
    Categorical,
    Dirichlet,
    Gamma,
    Normal,
    SineSkewed,
    Uniform,
    VonMises,
    HalfNormal
)
from numpyro.infer import MCMC, NUTS, Predictive, init_to_value
import arviz as az
import seaborn as sns
import pandas as pd
from numpyro.infer.util import log_likelihood
from numpyro.handlers import seed

from directional import SineBivariateVonMises

In [None]:
def model(data, num_data, num_mix_comp=2):
    """Mixture of sine bivariate von Mises distributions over (phi, psi) angle pairs.

    Args:
        data: Observed angle pairs passed as ``obs`` to the ``phi_psi`` site, or
            ``None`` to sample from the model (as ``Predictive`` does). Presumably
            shaped ``(num_data, 2)`` — TODO confirm against callers.
        num_data: Number of observations; sizes the observation plate.
        num_mix_comp: Number of mixture components.

    Sites:
        ``mix_weights`` ~ Dirichlet(1, ..., 1) over ``num_mix_comp`` components;
        per component ``phi_loc``, ``psi_loc`` ~ Uniform(-pi, pi),
        ``phi_conc``, ``psi_conc`` ~ Uniform(1, 1000), ``corr_scale`` ~ Beta(2, 10);
        per observation ``mix_comp`` ~ Categorical(mix_weights), marked for
        parallel enumeration; observed site ``phi_psi``.
    """
    # Mixture weights (latent).
    mix_weights = numpyro.sample("mix_weights", Dirichlet(jnp.ones((num_mix_comp,))))

    # One parameter set per component. This plate must be sized by num_mix_comp:
    # `assign` below ranges over 0..num_mix_comp-1 and indexes these arrays, and
    # JAX clamps out-of-range indices instead of raising, so a smaller plate would
    # silently map extra components onto the last one.
    with numpyro.plate("mixture", num_mix_comp):
        phi_loc = numpyro.sample("phi_loc", Uniform(-jnp.pi, jnp.pi))
        psi_loc = numpyro.sample("psi_loc", Uniform(-jnp.pi, jnp.pi))
        phi_conc = numpyro.sample("phi_conc", Uniform(1, 1000))
        psi_conc = numpyro.sample("psi_conc", Uniform(1, 1000))
        corr_scale = numpyro.sample("corr_scale", Beta(2.0, 10.0))

    # Observations are conditionally independent given the component parameters.
    # Sized by num_data so it matches len(data) when fitting and the requested
    # number of draws when data is None.
    with numpyro.plate("obs_plate", num_data, dim=-1):
        # Component assignment per observation (latent, enumerated by NUTS).
        assign = numpyro.sample("mix_comp", Categorical(mix_weights), infer={"enumerate": "parallel"})

        # Likelihood: each observation uses the parameters of its component.
        sine = SineBivariateVonMises(
            phi_loc=phi_loc[assign],
            psi_loc=psi_loc[assign],
            phi_concentration=phi_conc[assign],
            psi_concentration=psi_conc[assign],
            weighted_correlation=corr_scale[assign],
        )

        # Observed when data is given; sampled otherwise.
        numpyro.sample("phi_psi", sine, obs=data)


In [None]:
def model2(data, num_data, num_mix_comp=2):
    """Sine bivariate von Mises mixture with HalfNormal-scale concentration priors.

    Same structure as the Uniform-prior variant, except that concentrations are
    derived as ``1 / (s + 0.001)`` from HalfNormal(1) scales ``s_phi``/``s_psi``.

    Args:
        data: Observed angle pairs passed as ``obs`` to ``phi_psi``, or ``None``
            to sample from the model. Presumably ``(num_data, 2)`` — TODO confirm.
        num_data: Number of observations; sizes the observation plate.
        num_mix_comp: Number of mixture components.

    The derived concentrations are recorded as deterministic sites ``phi_conc``
    and ``psi_conc`` so they appear in posterior samples.
    """
    # Mixture weights (latent).
    mix_weights = numpyro.sample("mix_weights", Dirichlet(jnp.ones((num_mix_comp,))))

    # One parameter set per mixture component.
    with numpyro.plate("mixture", num_mix_comp):
        phi_loc = numpyro.sample("phi_loc", Uniform(-jnp.pi, jnp.pi))
        psi_loc = numpyro.sample("psi_loc", Uniform(-jnp.pi, jnp.pi))
        # Concentrations are inverse scales; the 0.001 offset keeps them finite
        # as the sampled scale approaches 0.
        s_phi = numpyro.sample("s_phi", HalfNormal(1.0))
        phi_conc = numpyro.deterministic("phi_conc", 1.0 / (s_phi + 0.001))
        s_psi = numpyro.sample("s_psi", HalfNormal(1.0))
        psi_conc = numpyro.deterministic("psi_conc", 1.0 / (s_psi + 0.001))
        corr_scale = numpyro.sample("corr_scale", Beta(2.0, 10.0))

    # Observation plate sized by num_data (was hard-coded to 100, which broke for
    # any other dataset size and for Predictive with data=None).
    with numpyro.plate("obs_plate", num_data, dim=-1):
        # Component assignment per observation (latent, enumerated by NUTS).
        assign = numpyro.sample("mix_comp", Categorical(mix_weights), infer={"enumerate": "parallel"})

        sine = SineBivariateVonMises(
            phi_loc=phi_loc[assign],
            psi_loc=psi_loc[assign],
            phi_concentration=phi_conc[assign],
            psi_concentration=psi_conc[assign],
            weighted_correlation=corr_scale[assign],
        )

        # Observed when data is given; sampled otherwise.
        numpyro.sample("phi_psi", sine, obs=data)

In [None]:
# Run Hamiltonian Monte Carlo
def run_hmc(rng_key, model, data, num_mix_comp, args):
    """Fit ``model`` with NUTS and compute model-comparison diagnostics.

    Args:
        rng_key: JAX PRNG key for the MCMC run.
        model: NumPyro model called as ``model(data, num_data, num_mix_comp)``.
        data: Observed data; ``len(data)`` is passed as ``num_data``.
        num_mix_comp: Number of mixture components.
        args: Namespace providing ``num_samples`` and ``num_warmup``.

    Returns:
        ``(post_samples, elpd_waic, elpd_loo, ess, rhat)``. ``post_samples`` is
        ``mcmc.get_samples()``. The ELPD values are ``az.waic(...).elpd_waic`` and
        ``az.loo(...).elpd_loo``. ``ess``/``rhat`` are ``az.ess``/``az.rhat`` of
        the same InferenceData.
    """
    # Diagnostics must re-run the model with the same arguments used for fitting.
    # In particular num_mix_comp must be included; otherwise the model falls back
    # to its default component count and no longer matches the posterior shapes.
    model_args = (data, len(data), num_mix_comp)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=args.num_samples, num_warmup=args.num_warmup)
    mcmc.run(rng_key, *model_args)
    mcmc.print_summary()
    post_samples = mcmc.get_samples()

    # We need to sample the assignments; it looks like az.from_numpyro doesn't
    # handle discrete sites correctly.
    # **NOTE**: I have not checked whether we should enumerate mix_comp when computing WAIC.
    # NOTE(review): mix_comp is not in post_samples, so under `seed` it is drawn
    # fresh here. The log-likelihood is therefore conditional on one random
    # assignment draw rather than marginalised over assignments. Confirm this is intended.
    def get_idata(model, posterior, model_args, model_kwargs):
        ll = log_likelihood(seed(model, rng_seed=0), posterior, *model_args, **model_kwargs)
        # Prepend a single-chain axis: ArviZ expects (chain, draw, ...).
        ll = {k: v[None] for k, v in ll.items()}
        idata = az.convert_to_inference_data(
            {k: v[None] for k, v in posterior.items() if k not in ll}
        )
        idata.add_groups(log_likelihood=ll)
        return idata

    # The seed above is fixed, so the InferenceData is deterministic; build it
    # once instead of once per diagnostic.
    idata = get_idata(model, post_samples, model_args, {})

    expected_waic = az.waic(idata).elpd_waic
    expected_loo = az.loo(idata).elpd_loo
    expected_ess = az.ess(idata)
    expected_rhat = az.rhat(idata)

    return post_samples, expected_waic, expected_loo, expected_ess, expected_rhat

In [None]:
def load_data(file_path):
    """Read (phi, psi) dihedral angles from a whitespace-delimited text file.

    Rules applied per line:
      * lines starting with ``#`` are skipped;
      * lines with fewer than four whitespace-separated fields are skipped;
      * otherwise the last two fields are read as (phi, psi), unless either of
        them is the literal token ``NT`` or ``CT``, in which case the line is skipped.

    NOTE(review): values are stored in whatever unit the file uses. The models in
    this notebook put Uniform(-pi, pi) priors on the locations, so radians are
    presumably expected. Confirm against the data source.

    Args:
        file_path: Path of the text file to parse.

    Returns:
        A two-element list ``[phi, psi]`` of 1-D ``jnp`` arrays of equal length.
    """
    terminal_tokens = {"NT", "CT"}
    phi_angles, psi_angles = [], []

    with open(file_path, 'r') as handle:
        for raw_line in handle:
            if raw_line.startswith("#"):
                continue
            fields = raw_line.split()
            if len(fields) < 4:
                continue
            phi_text, psi_text = fields[-2:]
            if not terminal_tokens.isdisjoint((phi_text, psi_text)):
                continue
            phi_value, psi_value = float(phi_text), float(psi_text)
            phi_angles.append(phi_value)
            psi_angles.append(psi_value)

    return [jnp.array(phi_angles), jnp.array(psi_angles)]

In [None]:
def plot_metric_vs_components(metric_values, num_components, metric_name, iters, file_prefix="model1"):
    """Plot a model-selection metric against the number of mixture components.

    The figure is saved to ``f"{file_prefix}_{metric_name}_vs_components_{iters}.png"``
    and then shown. When more than one component count is given, the count with
    the largest metric value is printed. This assumes higher is better, as for ELPD.

    Args:
        metric_values: Metric value per entry of ``num_components``.
        num_components: Component counts (x-axis), aligned with ``metric_values``.
        metric_name: Label used for the axis, title, printout and file name.
        iters: Tag included in the saved file name.
        file_prefix: File name prefix (defaults to the previous hard-coded ``model1``).
    """
    print(f"{metric_name} metric_vals : ", metric_values)
    print(f"{metric_name} num_comp : ", num_components)

    fig, ax = plt.subplots()
    ax.plot(num_components, metric_values, marker='o')
    ax.set_xlabel('Number of Components')
    ax.set_ylabel(metric_name)
    ax.set_title(f'{metric_name} vs Number of Components')
    ax.grid(True)
    fig.tight_layout()

    # Save BEFORE showing: once shown, the figure may be closed (e.g. the inline
    # backend), and a later plt.savefig would write a new, empty figure.
    save_path = f'{file_prefix}_{metric_name}_vs_components_{iters}.png'
    fig.savefig(save_path)
    plt.show()

    # Report the component count with the best (largest) metric value.
    if len(num_components) > 1:
        optimal_component = num_components[np.argmax(metric_values)]
        print(f'Optimal {metric_name} Component: {optimal_component}')


In [None]:
def plot_rama(phi_values, psi_values):
    """Show a Ramachandran plot of (phi, psi) pairs as a log-coloured hexbin.

    Args:
        phi_values: Phi angle per point.
        psi_values: Psi angle per point, aligned with ``phi_values``.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.set_title("Ramachandran Plot")

    # Log-scaled bin colours keep sparse regions visible next to dense ones.
    density = ax.hexbin(
        phi_values,
        psi_values,
        gridsize=50,
        bins="log",
        cmap='inferno',
    )

    ax.set_xlabel('Phi Angles')
    ax.set_ylabel('Psi Angles')
    fig.colorbar(density, ax=ax, label='Counts')

    plt.show()

In [None]:
def plot_density(phi_values, psi_values, iters):
    """Draw, save and show a filled 2-D KDE of (phi, psi) pairs.

    The figure is written to ``model1_density_{iters}.png`` before being shown.

    Args:
        phi_values: Phi angle per point.
        psi_values: Psi angle per point; must have the same length as ``phi_values``.
        iters: Tag included in the saved file name.

    Raises:
        ValueError: if the two inputs differ in length.
    """
    if len(phi_values) != len(psi_values):
        raise ValueError("The lengths of phi_values and psi_values must be the same.")

    angles = pd.DataFrame({'Phi': phi_values, 'Psi': psi_values})

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot()
    sns.kdeplot(data=angles, x='Phi', y='Psi', fill=True, cmap='viridis', cbar=True, ax=ax)
    ax.set_xlabel("Phi")
    ax.set_ylabel("Psi")
    ax.set_title("Phi vs. Psi Angles Density")

    # Save before show so the written image contains the plot.
    fig.savefig(f'model1_density_{iters}.png')
    plt.show()


In [None]:
def plot_concentrations(post_samples):
    """Show a log-coloured hexbin of posterior phi vs psi concentration draws.

    Args:
        post_samples: Mapping of posterior draws containing ``'phi_conc'`` and
            ``'psi_conc'``. When these carry a per-component axis (mixture plate),
            the draws of all components are pooled into one plot.

    Raises:
        KeyError: if either key is missing from ``post_samples``.
    """
    # hexbin needs 1-D coordinates; flatten (num_samples, num_components) draws so
    # every component contributes. 1-D input is unaffected.
    x = np.ravel(np.asarray(post_samples['phi_conc']))
    y = np.ravel(np.asarray(post_samples['psi_conc']))

    fig, ax = plt.subplots()
    hb = ax.hexbin(x, y, gridsize=50, bins='log', cmap='inferno')
    # Clip the view to the observed range of the draws.
    ax.set_xlim(x.min(), x.max())
    ax.set_ylim(y.min(), y.max())
    ax.set_xlabel("Phi concentration")
    ax.set_ylabel("Psi concentration")
    ax.set_title("Phi vs Psi Concentrations")
    fig.colorbar(hb, ax=ax, label='counts')

    plt.show()

In [None]:
def plot_corr_scale_density(post_samples):
    """Plot the posterior density of ``corr_scale`` with ArviZ.

    Args:
        post_samples: Mapping of posterior draws (e.g. ``MCMC.get_samples()``)
            containing ``'corr_scale'``, with draws along the leading axis and any
            per-component axis after it.
    """
    # MCMC.get_samples() returns draws along the leading axis with chains merged,
    # but az.from_dict interprets the first two axes as (chain, draw). Prepend a
    # single-chain axis so draws are not misread as chains. This matches the
    # v[None] convention used when building InferenceData for WAIC/LOO.
    corr_scale = np.asarray(post_samples['corr_scale'])[np.newaxis, ...]

    posterior_data = az.from_dict(
        posterior={
            "corr_scale": corr_scale
        }
    )

    # Scope the ArviZ style to this figure instead of changing the global
    # matplotlib style for every subsequent plot.
    with az.style.context("arviz-doc"):
        axes = az.plot_density(
            posterior_data,
            data_labels=["Posterior"],
            var_names=["corr_scale"],
            shade=0.2,
        )

        fig = np.ravel(axes)[0].get_figure()
        fig.suptitle("Density Plot for Correlation Scale")
        plt.show()

In [None]:
def _flatten_dataset(dataset):
    """Pool every data variable of an xarray ``Dataset`` into one flat NumPy array.

    ``np.std``/``np.mean`` cannot take a ``Dataset`` directly: NumPy forwards
    ``axis=None`` to the Dataset's own reduction, which xarray rejects in favour of
    ``dim``. Variables may have different shapes, so each is flattened first.
    """
    arrays = [np.ravel(np.asarray(var.values)) for var in dataset.data_vars.values()]
    if not arrays:
        return np.array([])
    return np.concatenate(arrays)


def main(args, num_data_limit=1000, iters=50):
    """Sweep the number of mixture components, fit ``model`` for each, and plot results.

    For every component count in ``range(5, iters, 5)``, the model is fitted with
    ``run_hmc``, a posterior-predictive sample is drawn, and the WAIC/LOO ELPD
    values are recorded. Afterwards:
    * the ELPD curves are plotted;
    * the predictive sample of the last fitted model is shown as Ramachandran
      and density plots;
    * that model's posterior summaries are plotted and printed.

    Args:
        args: Parsed CLI namespace. ``file_path`` and ``rng_seed`` are read here;
            ``num_samples``/``num_warmup`` are read by ``run_hmc``.
        num_data_limit: Keep only the first this-many data rows (limits runtime
            while testing); ``None`` keeps every row.
        iters: Exclusive upper bound of the component sweep; also used in the
            saved figure file names.

    Raises:
        ValueError: if ``range(5, iters, 5)`` is empty.
    """
    # (n, 2) array of (phi, psi) rows, truncated for faster testing.
    data = jnp.array(np.transpose(load_data(args.file_path)))[0:num_data_limit]

    rng_key = random.PRNGKey(args.rng_seed)

    pred_datas = {}
    all_num_mix_comp = []
    all_expected_waic = []
    all_expected_loo = []

    for num_mix_comp in range(5, iters, 5):
        # Independent keys for inference and prediction on every iteration.
        rng_key, inf_key, pred_key = random.split(rng_key, 3)

        posterior_samples, expected_waic, expected_loo, expected_ess, expected_rhat = run_hmc(
            inf_key, model, data, num_mix_comp, args
        )

        # Posterior predictive: data=None and num_data=1 request one simulated
        # observation per posterior draw.
        predictive = Predictive(model, posterior_samples, parallel=True)
        pred_datas[num_mix_comp] = predictive(pred_key, None, 1, num_mix_comp)["phi_psi"].reshape(-1, 2)

        all_num_mix_comp.append(num_mix_comp)
        all_expected_waic.append(expected_waic)
        all_expected_loo.append(expected_loo)

    if not all_num_mix_comp:
        raise ValueError(f"iters={iters} yields no component counts; it must be greater than 5.")

    # Sort once, after the sweep, by component count.
    sorted_num_mix_comp, sorted_expected_waic, sorted_expected_loo = zip(
        *sorted(zip(all_num_mix_comp, all_expected_waic, all_expected_loo))
    )

    plot_metric_vs_components(sorted_expected_waic, sorted_num_mix_comp, "WAIC", iters)
    plot_metric_vs_components(sorted_expected_loo, sorted_num_mix_comp, "LOO", iters)

    # Predictive sample of the last fitted model.
    last_pred = pred_datas[all_num_mix_comp[-1]]
    phi_values = last_pred[:, 0]
    psi_values = last_pred[:, 1]

    plot_rama(phi_values, psi_values)
    plot_density(phi_values, psi_values, iters)
    plot_concentrations(posterior_samples)
    plot_corr_scale_density(posterior_samples)

    # ESS / R-hat summaries of the last fitted model, pooled over all variables.
    print('expected_ess_std: ', np.std(_flatten_dataset(expected_ess)))
    print('expected_rhat_mean: ', np.mean(_flatten_dataset(expected_rhat)))


In [None]:
if __name__ == "__main__":
    # Entry point: parse run options and hand them to main().
    # NOTE(review): the description says "Sine-skewed", but the models in this
    # file use SineBivariateVonMises directly, with no SineSkewed wrapper.
    parser = argparse.ArgumentParser(
        description="Sine-skewed sine (bivariate von mises) mixture model"
    )
    # MCMC draws kept after warmup, and warmup steps (read by run_hmc).
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=500, type=int)
    # Seed for the JAX PRNG key created in main().
    parser.add_argument("--rng_seed", type=int, default=123)
    # NOTE(review): --device is parsed but never applied (nothing calls
    # numpyro.set_platform), so JAX uses its default backend. Because the default
    # here is "gpu", wiring it up as-is may fail on CPU-only machines; confirm
    # the intended default first.
    parser.add_argument("--device", default="gpu", type=str, help='use "cpu" or "gpu".')
    parser.add_argument("--file-path", type=str, default="top500.txt", help="Path to the data file.")

    # parse_known_args returns unrecognised arguments (discarded into `_`)
    # instead of erroring, e.g. extra argv a notebook kernel may supply.
    args, _ = parser.parse_known_args()
    main(args)