In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Input files, in the order they are bound to data1..data4 below: the
# "mean-field" and "empirical" files for the epsilon2/epsilon3 sweep, then the
# "mean-field" and "empirical" files for the beta2/beta3 sweep.
_polarization_paths = (
    "Data/polarization/mean-field_epsilon2_epsilon3_polarization.json",
    "Data/polarization/empirical_epsilon2_epsilon3_polarization.json",
    "Data/polarization/mean-field_beta2_beta3_polarization.json",
    "Data/polarization/empirical_beta2_beta3_polarization.json",
)

_parsed = []
for _path in _polarization_paths:
    with open(_path) as file:
        _parsed.append(json.load(file))
data1, data2, data3, data4 = _parsed

# From each file, pull the two parameter entries and the "psi" entry and
# convert them to float arrays.
epsilon2_mf, epsilon3_mf, psi_1 = (
    np.array(data1[key], dtype=float) for key in ("epsilon2", "epsilon3", "psi")
)

epsilon2_sim, epsilon3_sim, psi_2 = (
    np.array(data2[key], dtype=float) for key in ("epsilon2", "epsilon3", "psi")
)

beta2_mf, beta3_mf, psi_3 = (
    np.array(data3[key], dtype=float) for key in ("beta2", "beta3", "psi")
)

beta2_sim, beta3_sim, psi_4 = (
    np.array(data4[key], dtype=float) for key in ("beta2", "beta3", "psi")
)

In [None]:
# Font sizes used by the figure cell for tick labels, axis labels and the
# colorbar label, respectively.
tick_label_fontsize, axis_label_fontsize, colorbar_fontsize = 12, 14, 16

In [None]:
def _draw_psi_panel(
    ax,
    psi,
    x_values,
    y_values,
    xlabel,
    ylabel,
    xticks,
    yticks,
    xticklabels=None,
    yticklabels=None,
):
    """Draw one psi heatmap panel on ``ax`` and return the image handle.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    psi : array handed to ``imshow`` after a vertical flip.
    x_values, y_values : parameter values; their min/max define the image
        extent along x and y.
    xlabel, ylabel : axis label strings.
    xticks, yticks : tick locations.
    xticklabels, yticklabels : optional explicit tick label strings. When
        omitted, Matplotlib's own labels for the tick locations are kept.

    Returns
    -------
    The object returned by ``ax.imshow`` (used to build the shared colorbar).
    """
    # imshow draws row 0 at the top by default, so flip vertically to put the
    # first row at the bottom, next to the minimum of the y extent.
    image = ax.imshow(
        np.flipud(psi),
        extent=[min(x_values), max(x_values), min(y_values), max(y_values)],
        vmin=0,
        vmax=1,
        aspect="auto",
        cmap="inferno",
    )
    ax.set_xlabel(xlabel, fontsize=axis_label_fontsize)
    ax.set_ylabel(ylabel, fontsize=axis_label_fontsize)

    # Tick locations are set without text kwargs: set_xticks/set_yticks only
    # accept text properties such as `fontsize` together with `labels`
    # (otherwise they are ignored or rejected, depending on the Matplotlib
    # version). The label size is applied through tick_params instead.
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    if xticklabels is not None:
        ax.set_xticklabels(xticklabels)
    if yticklabels is not None:
        ax.set_yticklabels(yticklabels)
    ax.tick_params(axis="both", labelsize=tick_label_fontsize)
    return image


fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(7, 6))

# Axis decoration shared by the two panels of each row.
epsilon_style = dict(
    xlabel=r"$\epsilon_2$",
    ylabel=r"$\epsilon_3$",
    xticks=[0, 0.5, 1],
    yticks=[0.5, 0.75, 1],
    xticklabels=["0", "0.5", "1"],
    yticklabels=["0.5", "0.75", "1"],
)
beta_style = dict(
    xlabel=r"$\widetilde{\beta}_2$",
    ylabel=r"$\widetilde{\beta}_3$",
    xticks=[0, 0.25, 0.5],
    yticks=[3, 4, 5, 6],
    xticklabels=["0", "0.25", "0.5"],
)

# NOTE(review): the mean-field grids (psi_1, psi_3) are plotted as stored while
# the empirical grids (psi_2, psi_4) are transposed first -- presumably the two
# kinds of data file use opposite axis orders; confirm against the files.
_draw_psi_panel(axes[0, 0], psi_1, epsilon2_mf, epsilon3_mf, **epsilon_style)
_draw_psi_panel(axes[0, 1], psi_2.T, epsilon2_sim, epsilon3_sim, **epsilon_style)
_draw_psi_panel(axes[1, 0], psi_3, beta2_mf, beta3_mf, **beta_style)
# All panels share vmin/vmax/cmap, so the last image can drive the colorbar.
im = _draw_psi_panel(axes[1, 1], psi_4.T, beta2_sim, beta3_sim, **beta_style)

# Leave a strip on the right of the panels for a single shared colorbar.
fig.subplots_adjust(bottom=0.15, top=0.95, left=0.1, right=0.85, wspace=0.4, hspace=0.4)
cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.8])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label(r"$\psi$", fontsize=colorbar_fontsize, rotation=270, labelpad=10)
fig.savefig("Figures/Fig4/balanced_psi.pdf", dpi=1000)
fig.savefig("Figures/Fig4/balanced_psi.png", dpi=1000)
plt.show()