In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
import jax.random as random
from functools import partial
from typing import Callable, Union
from collections import namedtuple
import numpy as np


import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.cm as cm
from matplotlib.lines import Line2D
from matplotlib.ticker import FormatStrFormatter
import matplotlib.ticker as ticker
import dill

from lib.kernel import imq_kernel, rbf_kernel
from lib.model import model
from lib.methods import VGD
from lib.experiment import experiment, diagnostic_experiment
from lib.calculate_mmd import calculate_mmd_squared
from lib.plot_functions import plot_predictives

## Data: sine-basis models and a misspecified $\sin(1/x)$ data set

In [None]:
def f(theta, x):
  """Sine-basis regression function.

  Computes  f(x) = sum_{i=1}^{p} theta_i * sin(i * x),  where p = len(theta).
  There is NO intercept term, so f(theta, 0) == 0 for every theta.

  Args:
    theta: 1-D array-like of p coefficients (theta[0] multiplies sin(x)).
    x: scalar or 1-D array of inputs, shape (n_data,).

  Returns:
    Array of shape (n_data,) for 1-D `x`, or shape (1,) for a scalar `x`.
  """
  theta = jnp.asarray(theta)
  x = jnp.asarray(x)

  # frequencies 1, 2, ..., p
  i_vals = jnp.arange(1, theta.shape[0] + 1)
  # (p, 1) * (n_data,) broadcasts to (p, n_data); row k holds sin((k+1) * x)
  sin_terms = jnp.sin(i_vals[:, None] * x)

  # contract over the basis axis: (p,) @ (p, n_data) -> (n_data,)
  return theta @ sin_terms

def sin_inverse(x):
    """Return sin(1/x) elementwise.

    Note: this is sine of the reciprocal, NOT the inverse sine (arcsin).
    Not defined at x == 0.
    """
    reciprocal = 1 / x
    return jnp.sin(reciprocal)

f_jit = jit(f)

# Ground-truth coefficients alternate in sign, a = (-1, 1, -1, 1, ...) (length 50);
# the true parameter of a p-coefficient model is a[:p].
a = jnp.tile(jnp.array([-1, 1]), 25)

sigma = 0.2  # passed to every model; also the std of the noise added to `data_inverse` below
n_data = 1000
x_min, x_max = -2.0, 2.0  # input range for every generated data set
key = random.PRNGKey(10)

# NOTE: for some names the numeric suffix is NOT theta_dim
# (sine_model_2 / theta_2 -> 3, sine_model_6 / theta_6 -> 7, sine_model_11 / theta_11 -> 12).
# The names are kept unchanged for compatibility with any cell that refers to them.
sine_model_1 = model(sigma, f_jit, theta_dim=1)
sine_model_2 = model(sigma, f_jit, theta_dim=3)
sine_model_5 = model(sigma, f_jit, theta_dim=5)
sine_model_6 = model(sigma, f_jit, theta_dim=7)
sine_model_10 = model(sigma, f_jit, theta_dim=10)
sine_model_11 = model(sigma, f_jit, theta_dim=12)
sine_model_20 = model(sigma, f_jit, theta_dim=20)
sine_model_50 = model(sigma, f_jit, theta_dim=50)

theta_1 = a[0:1]
theta_2 = a[0:3]
theta_5 = a[0:5]
theta_6 = a[0:7]
theta_10 = a[0:10]
theta_11 = a[0:12]
theta_20 = a[0:20]
theta_50 = a[0:50]


def _generate_sine_data(sine_model, theta):
    """Draw `n_data` observations from `sine_model` at `theta` on [x_min, x_max] using `key`."""
    return sine_model.generate_data(n_data, theta, x_max=x_max, x_min=x_min, key=key)


# Every data set is generated with the same `key`; presumably this gives them common
# input locations across p (TODO confirm in lib.model).
# data_1_m, data_5_m and data_10_m are not used by the experiment cells in this notebook.
data_1_w = _generate_sine_data(sine_model_1, theta_1)
data_1_m = _generate_sine_data(sine_model_2, theta_2)
data_5_w = _generate_sine_data(sine_model_5, theta_5)
data_5_m = _generate_sine_data(sine_model_6, theta_6)
data_10_w = _generate_sine_data(sine_model_10, theta_10)
data_10_m = _generate_sine_data(sine_model_11, theta_11)
data_20_w = _generate_sine_data(sine_model_20, theta_20)
data_50_w = _generate_sine_data(sine_model_50, theta_50)

# Misspecified data: y = sin(1/x) + sigma * noise at the inputs of data_1_w.
# The noise uses a key derived with fold_in rather than `key` itself: `key` was already
# passed to generate_data for these same inputs, and reusing a JAX key yields correlated draws.
noise_key = random.fold_in(key, 1)
data_inverse = (
    data_1_w[0],
    sin_inverse(data_1_w[0]) + sigma * random.normal(noise_key, shape=data_1_w[1].shape),
)

In [None]:
# Sanity-check plot of the p = 50 well-specified training set
# (assumes data_50_w is an (x, y) pair — TODO confirm against lib.model.generate_data).
fig, ax = plt.subplots()
ax.scatter(*data_50_w)
ax.set(title=r"Training data, $p = 50$ (well-specified)", xlabel="x", ylabel="y")
plt.show()

## p = 1, well-specified

In [None]:
# p = 1 model fitted to data generated from the same p = 1 model.
experiment_1_w = experiment(
    sine_model_1,
    data_1_w,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(11),
)
experiment_1_w.run(step_size=0.0001, n_steps=10_000)

In [None]:
# p = 1, well-specified: KGD plot.
experiment_1_w.plot_KGD()

In [None]:
# Per-coordinate mean of the SVGD particles. axis=0 matches every other mean cell;
# the bare .mean() averaged over all entries and returned a scalar.
# (Assumes particles are stored as (n_particles, theta_dim) — TODO confirm.)
experiment_1_w.particles_SVGD.mean(axis=0)

In [None]:
# p = 1, well-specified: KSD plot.
experiment_1_w.plot_KSD()

In [None]:
# MMD diagnostic for p = 1 (well-specified); plot_diagnostic returns a pair,
# unpacked here as (all_mmd_values, actual_mmd).
diagnostic_experiment_1_w = diagnostic_experiment(experiment_1_w)
all_mmd_values_1_w, actual_mmd_1_w = diagnostic_experiment_1_w.plot_diagnostic()

## p = 1, misspecified ($\sin(1/x)$ data)

In [None]:
# p = 1 model fitted to the misspecified sin(1/x) data.
experiment_1_m = experiment(
    sine_model_1,
    data_inverse,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_1_m.run(step_size=0.0001, n_steps=10_000)

In [None]:
# p = 1, misspecified: KGD plot.
experiment_1_m.plot_KGD()

In [None]:
# p = 1, misspecified: KSD plot.
experiment_1_m.plot_KSD()

In [None]:
# Per-coordinate mean of the VGD particles (assumes (n_particles, theta_dim) layout — TODO confirm).
experiment_1_m.particles_VGD.mean(axis=0)

In [None]:
# MMD diagnostic for p = 1 (misspecified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_1_m = diagnostic_experiment(experiment_1_m)
all_mmd_values_1_m, actual_mmd_1_m = diagnostic_experiment_1_m.plot_diagnostic()

In [None]:
# Predictive bands (50/80/90 %) for p = 1: well-specified vs misspecified.
plot_predictives(experiment_1_w, experiment_1_m, intervals=[50, 80, 90])

## p = 5, well-specified

In [None]:
# p = 5 model fitted to data generated from the same p = 5 model.
experiment_5_w = experiment(
    sine_model_5,
    data_5_w,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_5_w.run(step_size=0.0001, n_steps=15_000)

In [None]:
# p = 5, well-specified: KGD plot.
experiment_5_w.plot_KGD()

In [None]:
# p = 5, well-specified: KSD plot.
experiment_5_w.plot_KSD()

In [None]:
# Per-coordinate mean of the SVGD particles (assumes (n_particles, theta_dim) layout — TODO confirm).
experiment_5_w.particles_SVGD.mean(axis=0)

In [None]:
# Per-coordinate mean of the VGD particles, for comparison with the SVGD mean above.
experiment_5_w.particles_VGD.mean(axis=0)

In [None]:
# MMD diagnostic for p = 5 (well-specified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_5_w = diagnostic_experiment(experiment_5_w)
all_mmd_values_5_w, actual_mmd_5_w = diagnostic_experiment_5_w.plot_diagnostic(parallel=False, trajectory=False)

## p = 5, misspecified ($\sin(1/x)$ data)

In [None]:
# p = 5 model fitted to the misspecified sin(1/x) data.
experiment_5_m = experiment(
    sine_model_5,
    data_inverse,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_5_m.run(step_size=0.0001, n_steps=15_000)

In [None]:
# p = 5, misspecified: KGD plot.
experiment_5_m.plot_KGD()

In [None]:
# p = 5, misspecified: KSD plot.
experiment_5_m.plot_KSD()

In [None]:
# MMD diagnostic for p = 5 (misspecified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_5_m = diagnostic_experiment(experiment_5_m)
all_mmd_values_5_m, actual_mmd_5_m = diagnostic_experiment_5_m.plot_diagnostic(parallel=False, trajectory=False)

In [None]:
# Predictive bands (50/80/90 %) for p = 5: well-specified vs misspecified.
plot_predictives(experiment_5_w, experiment_5_m, intervals=[50, 80, 90])

## p = 10, well-specified

In [None]:
# p = 10 model fitted to data generated from the same p = 10 model.
experiment_10_w = experiment(
    sine_model_10,
    data_10_w,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_10_w.run(step_size=0.0001, n_steps=20_000)

In [None]:
# Per-coordinate mean of the SVGD particles (assumes (n_particles, theta_dim) layout — TODO confirm).
experiment_10_w.particles_SVGD.mean(axis=0)

In [None]:
# Per-coordinate mean of the VGD particles, for comparison with the SVGD mean above.
experiment_10_w.particles_VGD.mean(axis=0)

In [None]:
# p = 10, well-specified: KGD plot.
experiment_10_w.plot_KGD()

In [None]:
# p = 10, well-specified: KSD plot.
experiment_10_w.plot_KSD()

In [None]:
# MMD diagnostic for p = 10 (well-specified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_10_w = diagnostic_experiment(experiment_10_w)
all_mmd_values_10_w, actual_mmd_10_w = diagnostic_experiment_10_w.plot_diagnostic()

## p = 10, misspecified ($\sin(1/x)$ data)

In [None]:
# p = 10 model fitted to the misspecified sin(1/x) data.
experiment_10_m = experiment(
    sine_model_10,
    data_inverse,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_10_m.run(step_size=0.0001, n_steps=20_000)

In [None]:
# p = 10, misspecified: KGD plot.
experiment_10_m.plot_KGD()

In [None]:
# p = 10, misspecified: KSD plot.
experiment_10_m.plot_KSD()

In [None]:
# MMD diagnostic for p = 10 (misspecified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_10_m = diagnostic_experiment(experiment_10_m)
all_mmd_values_10_m, actual_mmd_10_m = diagnostic_experiment_10_m.plot_diagnostic()

In [None]:
# Predictive bands (50/80/90 %) for p = 10: well-specified vs misspecified.
plot_predictives(experiment_10_w, experiment_10_m, intervals=[50, 80, 90])

## p = 20, well-specified

In [None]:
# p = 20 model fitted to data generated from the same p = 20 model.
experiment_20_w = experiment(
    sine_model_20,
    data_20_w,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_20_w.run(step_size=0.0001, n_steps=60_000)

In [None]:
# p = 20, well-specified: KGD plot.
experiment_20_w.plot_KGD()

In [None]:
# p = 20, well-specified: KSD plot.
experiment_20_w.plot_KSD()

In [None]:
# Per-coordinate mean of the VGD particles (assumes (n_particles, theta_dim) layout — TODO confirm).
experiment_20_w.particles_VGD.mean(axis=0)

In [None]:
# MMD diagnostic for p = 20 (well-specified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_20_w = diagnostic_experiment(experiment_20_w)
all_mmd_values_20_w, actual_mmd_20_w = diagnostic_experiment_20_w.plot_diagnostic(parallel=False, trajectory=False)

## p = 20, misspecified ($\sin(1/x)$ data)

In [None]:
# p = 20 model fitted to the misspecified sin(1/x) data.
experiment_20_m = experiment(
    sine_model_20,
    data_inverse,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_20_m.run(step_size=0.0001, n_steps=50_000)

In [None]:
# p = 20, misspecified: KGD plot.
experiment_20_m.plot_KGD()

In [None]:
# p = 20, misspecified: KSD plot.
experiment_20_m.plot_KSD()

In [None]:
# MMD diagnostic for p = 20 (misspecified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_20_m = diagnostic_experiment(experiment_20_m)
all_mmd_values_20_m, actual_mmd_20_m = diagnostic_experiment_20_m.plot_diagnostic(parallel=False, trajectory=False)

In [None]:
# Predictive bands (50/80/90 %) for p = 20: well-specified vs misspecified.
plot_predictives(experiment_20_w, experiment_20_m, intervals=[50, 80, 90])

## p = 50, well-specified

In [None]:
# p = 50 model fitted to data generated from the same p = 50 model.
experiment_50_w = experiment(
    sine_model_50,
    data_50_w,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_50_w.run(step_size=0.0001, n_steps=100_000)

In [None]:
# p = 50, well-specified: KGD plot.
experiment_50_w.plot_KGD()

In [None]:
# p = 50, well-specified: KSD plot.
experiment_50_w.plot_KSD()

In [None]:
# MMD diagnostic for p = 50 (well-specified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_50_w = diagnostic_experiment(experiment_50_w)
all_mmd_values_50_w, actual_mmd_50_w = diagnostic_experiment_50_w.plot_diagnostic(parallel=False, trajectory=False)

## p = 50, misspecified ($\sin(1/x)$ data)

In [None]:
# p = 50 model fitted to the misspecified sin(1/x) data.
experiment_50_m = experiment(
    sine_model_50,
    data_inverse,
    n_particles=20,
    kernel=imq_kernel,
    key=random.PRNGKey(49),
)
experiment_50_m.run(step_size=0.0001, n_steps=100_000)

In [None]:
# p = 50, misspecified: KGD plot.
experiment_50_m.plot_KGD()

In [None]:
# p = 50, misspecified: KSD plot.
experiment_50_m.plot_KSD()

In [None]:
# MMD diagnostic for p = 50 (misspecified); result unpacked as (all_mmd_values, actual_mmd).
diagnostic_experiment_50_m = diagnostic_experiment(experiment_50_m)
all_mmd_values_50_m, actual_mmd_50_m = diagnostic_experiment_50_m.plot_diagnostic(parallel=False, trajectory=False)

In [None]:
# Predictive bands (50/80/90 %) for p = 50: well-specified vs misspecified.
plot_predictives(experiment_50_w, experiment_50_m, intervals=[50, 80, 90])

## Save results

In [None]:
# Persist everything `plot_main_figure` reads: p = 5, 20, 50, each well-specified ('w')
# and misspecified ('m'). Written to the .dill path because every loading cell below
# reads 'sine_different_dimension_experiments.dill' (the previous '.pkl' was never read).
data_to_save = {
    'experiment_5_w': experiment_5_w,
    'all_mmd_values_5_w': all_mmd_values_5_w,
    'actual_mmd_5_w': actual_mmd_5_w,
    'experiment_5_m': experiment_5_m,
    'all_mmd_values_5_m': all_mmd_values_5_m,
    'actual_mmd_5_m': actual_mmd_5_m,
    'experiment_20_w': experiment_20_w,
    'all_mmd_values_20_w': all_mmd_values_20_w,
    'actual_mmd_20_w': actual_mmd_20_w,
    'experiment_20_m': experiment_20_m,
    'all_mmd_values_20_m': all_mmd_values_20_m,
    'actual_mmd_20_m': actual_mmd_20_m,
    'experiment_50_w': experiment_50_w,
    'all_mmd_values_50_w': all_mmd_values_50_w,
    'actual_mmd_50_w': actual_mmd_50_w,
    'experiment_50_m': experiment_50_m,
    'all_mmd_values_50_m': all_mmd_values_50_m,
    'actual_mmd_50_m': actual_mmd_50_m,
}
# Handle named `fh`, not `f`: `f` is the sine-basis model function defined earlier.
with open('sine_different_dimension_experiments.dill', 'wb') as fh:
    dill.dump(data_to_save, fh)

In [None]:
# Patch an existing results file: load it (or start from an empty dict), overwrite the
# p = 20 misspecified entries, and write it back. This replaces `data_to_save` in memory.
try:
    with open('sine_different_dimension_experiments.dill', 'rb') as fh:
        data_to_save = dill.load(fh)
    # "Load succeeded: the old dict has been read into memory."
    print("加载成功：旧的字典已读入内存。")
except FileNotFoundError:
    # "Warning: old file not found. A new dict will be created."
    print("警告：未找到旧文件。将创建一个新的字典。")
    data_to_save = {}

data_to_save.update({
    'experiment_20_m': experiment_20_m,
    'all_mmd_values_20_m': all_mmd_values_20_m,
    'actual_mmd_20_m': actual_mmd_20_m,
})

# Handle named `fh`, not `f`: `f` is the sine-basis model function defined earlier.
with open('sine_different_dimension_experiments.dill', 'wb') as fh:
    dill.dump(data_to_save, fh)

In [None]:
# Reload the saved results for inspection. The handle is `fh` so the model function `f`
# defined earlier is not overwritten by a file object.
with open('sine_different_dimension_experiments.dill', 'rb') as fh:
    data = dill.load(fh)

In [None]:
# NOTE(review): this shows the keys of `data_to_save` (the dict written by the save cells),
# not of `data` just reloaded in the previous cell — use `data.keys()` if the reload is
# what should be checked.
data_to_save.keys()

## Summary figure (p = 5, 20, 50)

In [None]:
# NOTE(review): other cells import from `lib.plot_functions`; this cell imports a top-level
# `plot_functions` module — confirm both are on the path. predictive_posterior_distribution_k
# and plot_diagnostic_manual_broken are not used below.
from plot_functions import predictive_posterior_distribution_k, plot_shaded_region_predictive, plot_diagnostic, plot_diagnostic_manual_broken


def _plot_half_row(ax_row, first_col, experiment_obj, all_mmd_values, actual_mmd, svgd_color, vgd_color):
    """Fill three adjacent panels of one figure row for a single experiment.

    Column first_col gets the SVGD predictive and first_col + 1 the VGD predictive.
    The VGD panel shares the SVGD panel's y-limits and hides its y tick labels.
    Column first_col + 2 gets the MMD diagnostic.
    """
    ax_svgd = ax_row[first_col]
    ax_vgd = ax_row[first_col + 1]

    plot_shaded_region_predictive(ax_svgd, experiment_obj, experiment_obj.particles_SVGD, svgd_color)
    plot_shaded_region_predictive(ax_vgd, experiment_obj, experiment_obj.particles_VGD, vgd_color)
    ax_vgd.yaxis.set_ticks_position('none')
    plt.setp(ax_vgd.get_yticklabels(), visible=False)

    plot_diagnostic(ax_row[first_col + 2], all_mmd_values=all_mmd_values, actual_mmd=actual_mmd)

    # Share y-range so the SVGD and VGD predictives are directly comparable.
    ax_vgd.set_ylim(ax_svgd.get_ylim())


def plot_main_figure(file_name='main_fig.dill'):
    """Draw the 3x6 summary figure for p = 5, 20, 50 from a dill results file.

    For every prefix p in ('5', '20', '50') and suffix s in ('w', 'm'), the file must hold
    the keys 'experiment_{p}_{s}', 'all_mmd_values_{p}_{s}' and 'actual_mmd_{p}_{s}'.
    Columns 0-2 show the well-specified ('w') experiment and columns 3-5 the
    misspecified ('m') one. Each triple is SVGD predictive, VGD predictive, MMD diagnostic.

    Side effects: sets the global seaborn theme and enables `jax_enable_x64`.

    Raises:
        FileNotFoundError: if `file_name` does not exist.
    """
    sns.set_theme(
        style="white",
        rc={
            "font.family": "serif",
            "font.serif": ["Computer Modern Roman", "CMU Serif", "Times New Roman"],
            "mathtext.fontset": "cm",
            "legend.frameon": True,
        },
    )

    # Enabled before loading — presumably so float64 arrays in the file stay float64 (TODO confirm).
    jax.config.update("jax_enable_x64", True)
    try:
        with open(file_name, 'rb') as fh:
            data = dill.load(fh)
    except FileNotFoundError:
        # "Error: file not found."
        print(fr"错误：'{file_name}' 文件未找到。")
        # Re-raise rather than exit(): exiting would stop the interpreter/kernel and
        # discard the in-memory experiments.
        raise
    else:
        # "Loaded file successfully, containing N arrays: [...]"
        print(f"成功加载文件，包含 {len(data.keys())} 个数组: {list(data.keys())}")

    fig, axes = plt.subplots(nrows=3, ncols=6, figsize=(24, 12.5))

    data_prefixes = ['5', '20', '50']
    row_titles = [r'$\mathrm{p=5}$', r'$\mathrm{p=20}$', r'$\mathrm{p=50}$']
    col_titles = [r'$P_\mathrm{Bayes}$', r'$P_\mathrm{PrO}$', r'$\mathrm{MMD}$', r'$P_\mathrm{Bayes}$', r'$P_\mathrm{PrO}$', r'$\mathrm{MMD}$']
    VGD_color = '#ff7f0e'
    SVGD_color = '#1f77b4'

    for row_idx, prefix in enumerate(data_prefixes):
        for first_col, suffix in ((0, 'w'), (3, 'm')):
            tag = f'{prefix}_{suffix}'
            _plot_half_row(
                axes[row_idx],
                first_col,
                data[f'experiment_{tag}'],
                data[f'all_mmd_values_{tag}'],
                data[f'actual_mmd_{tag}'].item(),
                SVGD_color,
                VGD_color,
            )

    plt.tight_layout(rect=[0.03, 0, 1, 0.94])
    # Final spacing must be applied BEFORE reading axes positions below; otherwise the
    # axes move afterwards and the column/row titles no longer line up with them.
    fig.subplots_adjust(wspace=0.16, hspace=0.12)

    for i, title in enumerate(col_titles):
        pos = axes[0, i].get_position()
        fig.text((pos.x0 + pos.x1) / 2, 0.96, title, ha='center', va='top', fontsize=28)

    for i, title in enumerate(row_titles):
        pos = axes[i, 0].get_position()
        fig.text(0.02, (pos.y0 + pos.y1) / 2, title, ha='left', va='center', fontsize=28, rotation=90)

    plt.show()

In [None]:
# Print the installed seaborn version (useful when reproducing the figure styling).
print(sns.__version__)

In [None]:
# Draw the summary figure from the results file written by the save cells above.
plot_main_figure('sine_different_dimension_experiments.dill')