In [1]:
import os
import sys
import itertools as itt
from typing import List, Dict, Tuple

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from matplotlib.colors import Normalize, LogNorm
from matplotlib.patches import Rectangle

import infovar

sys.path.insert(0, os.path.join("..", ".."))
from infobs.plots import Plotter

sys.path.insert(1, os.path.join(".."))
from pdr_util import get_physical_env, simulate, latex_line, latex_param, Settings

# Directory holding the precomputed results that `handler` reads below.
results_path = os.path.join("..", "data", "continuous", "results")

# Flag of the project's PDR utilities; presumably limits it to rotational
# transitions — see pdr_util.Settings to confirm.
Settings.only_rotational = True

# Render every label and title through LaTeX (needs a working TeX install).
plt.rcParams["text.usetex"] = True

In [2]:
# Results handler for continuous features, pointed at the saved results directory.
handler = infovar.ContinuousHandler()
handler.set_paths(save_path=results_path)

In [3]:
# Plotter configured with the project's LaTeX formatters for lines and parameters.
plotter = Plotter(
    line_formatter=latex_line,
    param_formatter=latex_param,
)


def latex_comb_lines(ls):
    """Format a line (or combination of lines) with the plotter, in short form.

    Thin wrapper around ``plotter.lines_comb_formatter(ls, short=True)``.
    """
    return plotter.lines_comb_formatter(ls, short=True)


def latex_comb_params(ps):
    """Format a parameter (or combination of parameters) with the plotter.

    Thin wrapper around ``plotter.params_comb_formatter(ps)``.
    """
    return plotter.params_comb_formatter(ps)

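## Combinations

This section makes one figure per CO isotopologue (12CO, 13CO, C18O). Each figure maps the mutual information between `Avmax` and three inputs, across the (`radm`, `Avmax`) windows:

- the j1→j0 line alone;
- the j2→j1 line alone;
- both lines combined.

All panels of a figure share one colour scale.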

In [4]:
# Font sizes for the single-row comparison figures: ticks, axis labels, titles.
ticksfontsize, labelfontsize, titlefontsize = 18, 22, 22

In [5]:
# Parameter, window features and statistic passed to `handler.read` below.
param = "Avmax"
wins_features = ["radm", "Avmax"]
stat = "mi"

# For each isotopologue: the j1->j0 line alone, the j2->j1 line alone,
# then both lines together as a combination.
lines_list = [
    [low, high, [low, high]]
    for low, high in (
        ("co_v0_j1__v0_j0", "co_v0_j2__v0_j1"),
        ("13c_o_j1__j0", "13c_o_j2__j1"),
        ("c_18o_j1__j0", "c_18o_j2__j1"),
    )
]

# Output file stem for each entry of `lines_list`, in the same order.
fignames = ["12CO", "13CO", "C18O"]

In [13]:
# One figure per isotopologue: a map of the statistic for each entry of `lines`
# (single lines and their combination), all sharing one colorbar.
for lines, name in zip(lines_list, fignames):

    # Read every panel once. The shared colour scale needs all of them, and the
    # same results are then reused for plotting instead of re-read from disk.
    panels = []
    vmax = 0
    for ls in lines:
        d = handler.read(ls, param, wins_features)
        panels.append((ls, d))
        # nanmax: a single NaN cell must not mask the maximum of a whole panel.
        vmax = max(vmax, np.nanmax(d["stats"][stat]["data"]))
        wins_features = d["features"]

    # One column per panel plus a narrow trailing column for the colorbar.
    # The width scales with the panel count (2*6.4 for the usual three panels).
    n_panels = len(panels)
    fig, axs = plt.subplots(
        1, n_panels + 1,
        figsize=(2*6.4 * n_panels / 3, 4.8),
        width_ratios=n_panels*[1] + [0.1],
        dpi=150,
    )

    for i, (ls, d) in enumerate(panels):
        ax = axs[i]

        data = d["stats"][stat]["data"]
        xticks, yticks = d["stats"][stat]["coords"]
        paramy, paramx = d["features"]

        X, Y = np.meshgrid(yticks, xticks)
        im = ax.pcolor(X, Y, data, cmap='jet', vmin=0, vmax=vmax)

        ax.set_xscale('log')
        ax.set_yscale('log')

        ax.xaxis.set_tick_params(labelsize=ticksfontsize)
        ax.set_xlabel(f"${plotter.param_formatter(paramx)}$", fontsize=labelfontsize, labelpad=10)
        if i == 0:
            # Only the leftmost panel carries y tick labels and the y label.
            ax.yaxis.set_tick_params(labelsize=ticksfontsize)
            ax.set_ylabel(f"${plotter.param_formatter(paramy)}$", fontsize=labelfontsize, labelpad=10)
        else:
            ax.yaxis.set_ticks([])

        ax.set_title(f"${plotter.lines_comb_formatter(ls)}$", fontsize=titlefontsize, pad=15)

        if name == "12CO" and i == 0:
            # Highlight a hand-picked (Av, G0) window on the first 12CO panel.
            # The anchor is center/sqrt(width) and the box extends by a factor
            # `width`, so it spans [center/sqrt(w), center*sqrt(w)], which is
            # centred on the log-scaled axes.
            av_center, g0_center = (1e1*2e1)**0.5, 2e1
            av_width, g0_width = 2, 5.2

            av_anchor = av_center/av_width**0.5
            g0_anchor = g0_center/g0_width**0.5

            rect = Rectangle(
                (av_anchor, g0_anchor), av_anchor*(av_width-1), g0_anchor*(g0_width-1), linewidth=1.5,
                edgecolor='tab:red', facecolor='none'
            )
            ax.scatter([av_center], [g0_center], color="tab:red", s=12)
            ax.add_patch(rect)

    # Every panel uses the same vmin/vmax, so the last image serves the colorbar.
    cbar = fig.colorbar(im, cax=axs[-1])
    cbar.set_label("Mutual information (bits)", labelpad=10, fontsize=labelfontsize)
    cbar.ax.tick_params(labelsize=ticksfontsize)

    fig.tight_layout()
    fig.savefig(f"{param}_{name}.png", bbox_inches="tight")
    plt.close(fig)

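## Grid

This section shows the same mutual-information maps for a grid of individual lines from several species. All panels share one colour scale, which can optionally be logarithmic.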

In [7]:
# Larger font sizes for the big grid figure: ticks, axis labels, titles.
ticksfontsize, labelfontsize, titlefontsize = 28, 28, 32

In [8]:
# Catalogue of line identifiers, grouped by species. This cell is not executed;
# it is kept as a reference when choosing entries for `grid_lines` below.
# ["co_v0_j1__v0_j0"],
# ["co_v0_j2__v0_j1"],
# ["co_v0_j3__v0_j2"],
# ["13c_o_j1__j0"],
# ["13c_o_j2__j1"],
# ["13c_o_j3__j2"],
# ["c_18o_j1__j0"],
# ["c_18o_j2__j1"],
# ["c_18o_j3__j2"],
# ["hcop_j1__j0"],
# ["hcop_j2__j1"],
# ["hcop_j3__j2"],
# ["hcop_j4__j3"],
# ["hnc_j1__j0"],
# ["hnc_j3__j2"],
# ["hcn_j1_f2__j0_f1"],
# ["hcn_j2_f3__j1_f2"],
# ["hcn_j3_f3__j2_f3"],
# ["cs_j2__j1"],
# ["cs_j3__j2"],
# ["cs_j5__j4"],
# ["cs_j6__j5"],
# ["cs_j7__j6"],
# # CN lines
# ["cn_n1_j0d5__n0_j0d5"],
# ["cn_n1_j1d5__n0_j0d5"],
# ["cn_n2_j1d5__n1_j0d5"],
# ["cn_n2_j2d5__n1_j1d5"],
# ["cn_n3_j3d5__n2_j2d5"],
# # C2H lines
# ["c2h_n1d0_j1d5_f2d0__n0d0_j0d5_f1d0"],
# ["c2h_n2d0_j2d5_f3d0__n1d0_j1d5_f2d0"],
# ["c2h_n3d0_j3d5_f4d0__n2d0_j2d5_f3d0"],
# ["c2h_n3d0_j2d5_f3d0__n2d0_j1d5_f2d0"],
# ["c2h_n4d0_j4d5_f5d0__n3d0_j3d5_f4d0"],
# # Carbon lines
# ["c_el3p_j1__el3p_j0"],
# ["c_el3p_j2__el3p_j1"],
# ["cp_el2p_j3_2__el2p_j1_2"]

In [9]:
# Parameter, window features and statistic passed to `handler.read` below.
param = "Avmax"
wins_features = ["radm", "Avmax"]
stat = "mi"
# Use a logarithmic colour scale (also appends "_log" to the output filename).
logscale = True

# Lines shown in the grid figure; each inner list is one row of panels.
grid_lines = [
    ["co_v0_j1__v0_j0", "13c_o_j1__j0", "c_18o_j1__j0", "hcop_j1__j0"],
    ["co_v0_j2__v0_j1", "13c_o_j2__j1", "c_18o_j2__j1", "hcop_j2__j1"],
    ["hcn_j1_f2__j0_f1", "hcn_j2_f3__j1_f2", "hnc_j1__j0", "hnc_j3__j2"],
    ["cs_j2__j1", "cs_j3__j2", "cn_n1_j0d5__n0_j0d5", "cn_n1_j1d5__n0_j0d5"],
    ["c2h_n1d0_j1d5_f2d0__n0d0_j0d5_f1d0", "c_el3p_j1__el3p_j0", "c_el3p_j2__el3p_j1", "cp_el2p_j3_2__el2p_j1_2"]
]

# Derive the grid shape from `grid_lines` so it cannot drift out of sync with a
# hardcoded size (IndexError, or panels silently left out). The plotting cell
# indexes grid_lines[i][j], so every row must have the same length.
rows = len(grid_lines)
cols = len(grid_lines[0]) if grid_lines else 0
assert all(len(row) == cols for row in grid_lines), "grid_lines must be rectangular"

In [10]:
# Shared colour limits for the whole grid. A log scale needs a strictly positive
# floor, so values below `vmin` (e.g. zero) are clipped up to it when plotting.
vmin = 5e-2 if logscale else 0
vmax = 0

# Read every grid cell once. The results are reused for plotting below instead
# of being read from disk a second time.
grid_data = []
for row in grid_lines:
    row_data = []
    for ls in row:
        d = handler.read(ls, param, wins_features)
        row_data.append(d)
        # nanmax: a single NaN cell must not mask the maximum of a whole map.
        vmax = max(vmax, np.nanmax(d["stats"][stat]["data"]))
        wins_features = d["features"]
    grid_data.append(row_data)

# `cols` map columns plus one extra column reserved for the colorbar.
fig, axs = plt.subplots(rows, cols+1, figsize=(0.8*cols*6.4, 1.1*rows*4.8), width_ratios=cols*[1]+[0.15], dpi=300)

for i, j in itt.product(range(rows), range(cols)):
    ls = grid_lines[i][j]
    d = grid_data[i][j]
    ax = axs[i, j]

    data = d["stats"][stat]["data"]
    xticks, yticks = d["stats"][stat]["coords"]
    paramy, paramx = d["features"]

    X, Y = np.meshgrid(yticks, xticks)

    norm = LogNorm(vmin, vmax) if logscale else Normalize(vmin, vmax)
    im = ax.pcolor(X, Y, data.clip(vmin, vmax), cmap='jet', norm=norm)

    ax.set_xscale('log')
    ax.set_yscale('log')

    # Tick labels and axis labels only on the outer edge of the grid
    # (bottom row for x, left column for y).
    if i == rows-1:
        ax.xaxis.set_tick_params(labelsize=ticksfontsize)
        ax.set_xlabel(f"${plotter.param_formatter(paramx)}$", fontsize=labelfontsize, labelpad=10)
    else:
        ax.xaxis.set_ticks([])
    if j == 0:
        ax.yaxis.set_tick_params(labelsize=ticksfontsize)
        ax.set_ylabel(f"${plotter.param_formatter(paramy)}$", fontsize=labelfontsize, labelpad=10)
    else:
        ax.yaxis.set_ticks([])

    ax.set_box_aspect(1)
    ax.set_title(f"${plotter.lines_comb_formatter(ls)}$", fontsize=titlefontsize, pad=15)

# Replace the per-row colorbar slots with a single axis spanning every row.
gs = axs[0, -1].get_gridspec()
for cbar_slot in axs[:, -1]:
    cbar_slot.remove()
axbig = fig.add_subplot(gs[:, -1])

# Every panel uses the same norm limits, so the last image serves the colorbar.
cbar = fig.colorbar(im, cax=axbig)
cbar.set_label("Mutual information (bits)", labelpad=10, fontsize=labelfontsize)
cbar.ax.tick_params(labelsize=ticksfontsize)

# tight_layout is not applied to this figure; bbox_inches="tight" trims the saved image instead.
fig.savefig(f"grid{'_log' if logscale else ''}.png", bbox_inches="tight")
plt.close(fig)