In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from thesis_utils.plotting import save_figure, set_plotting, get_default_figsize
from thesis_utils.random import seed_everything
from thesis_utils import colours as thesis_colours

# Global notebook setup.
# set_plotting() presumably applies the project's shared matplotlib style
# (defined in thesis_utils.plotting, not visible here — TODO confirm).
# seed_everything() presumably seeds NumPy's global RNG, among others; the
# sample-generation cell draws from np.random, so the figure is only
# reproducible if it does — verify in thesis_utils.random.
set_plotting()
seed_everything()

In [None]:
# Synthetic 2-D samples from a two-component Gaussian mixture:
# points 0..4999 are centred on (1, -1) and points 5000..9999 on (-1, 1).
# The per-coordinate std is 0.5 for x and 0.7 for y.
# All x values are drawn before all y values, so that seeded results stay the same.
x = np.hstack([np.random.normal(centre, 0.5, 5000) for centre in (1, -1)])
y = np.hstack([np.random.normal(centre, 0.7, 5000) for centre in (-1, 1)])

In [None]:
# Schematic likelihood curve L(X) = ymax - (X + 0.5)^5, sampled at 100 points
# on X in [0, 1]. It decreases monotonically there, and ymax = 1.5^5 makes
# L(1) == 0.
ymax = 1.5 ** 5
xx = np.linspace(0, 1, 100)
yy = ymax - (xx + 0.5) ** 5

In [None]:
# Marker points on the x-axis of the likelihood panel, in descending order and
# ending at 0, together with their likelihood values on the same curve as yy.
X = np.array([0.9, 0.7, 0.5, 0.3, 0.0])
L = ymax - (X + 0.5) ** 5

In [None]:
# Two-panel schematic. The left panel shows filled KDE contours of the samples
# in parameter space. The right panel shows the same nested bands as strips
# under the likelihood curve L(X), with matching colours.

# Copy before shrinking: mutating the returned object in place is not
# idempotent if get_default_figsize() hands back a shared list (not visible
# here — TODO confirm), and it would fail outright on a tuple.
figsize = list(get_default_figsize())
figsize[1] /= 1.5
fig, axes = plt.subplots(1, 2, figsize=figsize)

# Left panel: one contour level per X point, so contourf yields
# X.size - 1 filled bands, one for each strip on the right panel.
sns.kdeplot(
    x=x, y=y, levels=len(X), fill=True, gridsize=100, ax=axes[0], color=thesis_colours.teal,
)
axes[0].set_xlim([-2.5, 2.5])
axes[0].set_ylim([-3.5, 3.5])
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_xlabel('Parameter space')

# Read the band colours back through the public Collection API rather than
# the private `_facecolors` / child ordering. matplotlib < 3.8 adds one
# collection per band, while >= 3.8 adds a single ContourSet holding one
# facecolor per band. Concatenating get_facecolor() over all collections
# gives one RGBA row per band in both cases.
band_colours = np.concatenate([c.get_facecolor() for c in axes[0].collections])
n_bands = X.size - 1
assert len(band_colours) >= n_bands, "fewer KDE contour bands than likelihood strips"

# Right panel: likelihood against the X values.
axes[1].plot(xx, yy, c=thesis_colours.pillarbox)
axes[1].set_xlim([0, 1])
axes[1].set_ylim([0, ymax + 1])
axes[1].set_yticks([])
# Ticks are labelled 1, X_1 .. X_{n-1}, 0. The labels are derived from X so
# they stay in sync with it; this assumes X is descending with X[-1] == 0.
axes[1].set_xticks([1] + X.tolist())
axes[1].set_xticklabels(['1'] + [rf'$X_{{{k}}}$' for k in range(1, X.size)] + ['0'])
axes[1].set_ylabel(r'$\mathcal{L}$', rotation=0, labelpad=15)

for i in range(n_bands):
    # Shade the strip between consecutive X points under the curve, using the
    # colour of the corresponding KDE band.
    strip_x = np.linspace(X[i], X[i + 1])
    strip_l = -(strip_x + 0.5) ** 5 + ymax
    axes[1].fill_between(strip_x, strip_l, color=band_colours[i])
    # The tick at X[i] is labelled X_{i+1}, so its likelihood L[i] is
    # annotated as L_{i+1} to match.
    axes[1].text(X[i], L[i] + 0.2, rf"$\mathcal{{L}}_{{{i + 1}}}$")

fig.tight_layout()
save_figure(fig, "nest_plot")