In [None]:
import os
import sys
import itertools as itt
from typing import List

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from infovar import ContinuousHandler

sys.path.insert(0, os.path.join("..", ".."))
sys.path.insert(1, os.path.join(".."))

from infobs.plots import Plotter
from orion_util import latex_line, latex_param 

# Render all figure text with LaTeX: titles and labels below are written as $...$.
plt.rcParams["text.usetex"] = True

# Where the pre-computed continuous results are read from, and where the
# comparison figures are written.
data_dir = os.path.join("..", "data", "continuous")
figures_dir = "continuous_images_comparison_comb"

## Configuration of Handler

In [None]:
# Handler giving access to the results stored under `data_dir`.
handler = ContinuousHandler()
handler.set_paths(save_path=data_dir)

## Configuration of Plotter

In [None]:
# Plotter that turns line and parameter names into labels via the
# `latex_line` / `latex_param` formatters.
plotter = Plotter(line_formatter=latex_line, param_formatter=latex_param)


def latex_comb_lines(ls):
    """Label of a combination of lines, built by `plotter` in its short form."""
    return plotter.lines_comb_formatter(ls, short=True)


def latex_comb_params(ps):
    """Label of a combination of parameters, built by `plotter`."""
    return plotter.params_comb_formatter(ps)

## Settings

In [None]:
# Axis limits of the regime maps, per parameter. A `None` bound leaves that
# limit to matplotlib.
lims = dict(
    av=[1.0, 80.0],
    g0=[None, None],
)

In [None]:
# Parameter(s) whose informativity is studied: "av", "g0" or ["av", "g0"].
params_target = "g0"
# Parameters spanning the 2D regime maps (x axis first, then y axis).
# Cannot be modified.
params_regime = ["av", "g0"]

# Statistics for which figures are produced.
stats = ["mi", "linearinfo", "linearinfogauss"]

# Variables (single lines or combinations of lines) reported by the handler
# for `params_target`.
combs_list = handler.get_available_variables(params_target)

## Formatting: normalize targets and combinations to lists

In [None]:
# Normalize the target(s) and every combination to lists of names, so that
# single lines and combinations of lines can be handled uniformly below.
if isinstance(params_target, str):
    params_target = [params_target]
assert isinstance(params_target, list), f"unexpected params_target: {params_target!r}"

# Build a new list rather than writing into the one returned by
# `handler.get_available_variables` (it may be shared with the handler).
# Re-running this cell is a no-op.
combs_list = [[comb] if isinstance(comb, str) else comb for comb in combs_list]
assert all(isinstance(comb, list) for comb in combs_list), \
    "each available variable must be a line name (str) or a list of line names"

## Comparison between the informativity of individual (marginal) lines and of their pairwise combinations

In [None]:
# Every individual line appearing in at least one available combination,
# without duplicates, in sorted order.
lines_list = sorted({line for comb in combs_list for line in comb})

In [None]:
# Font sizes and figure geometry shared by both comparison figures.
_TICKS_FONTSIZE = 16
_LABEL_FONTSIZE = 18
_TITLE_FONTSIZE = 18
_SUPTITLE_FONTSIZE = 20
_FIGSIZE = (2.3 * 6.4, 4.8)
_DPI = 150

# Description of each statistic, used verbatim in the suptitle of `plot_comp`
# (no `.capitalize()`, which would lowercase "Gaussian").
# NOTE(review): 'linearinfo' and 'linearinfogauss' share the same description;
# confirm whether 'linearinfo' should be described differently.
_STAT_DESCRIPTIONS = {
    'mi': "Estimation of mutual information",
    'linearinfo': "Estimation of mutual information under multivariate Gaussian assumption",
    'linearinfogauss': "Estimation of mutual information under multivariate Gaussian assumption",
}


def _draw_map(ax, X, Y, mat, paramx, paramy, title, vmax, cmap='jet', show_ylabel=False):
    """Draw ``mat`` as a log-log regime map on ``ax`` and return the ``pcolor`` artist.

    Colours span ``[0, vmax]``; axis limits come from the global ``lims`` dict.
    Only the panel drawn with ``show_ylabel=True`` gets a y label; the other
    panels drop their y ticks (all panels use the same ``Y`` grid and limits).
    """
    im = ax.pcolor(X, Y, mat, cmap=cmap, vmin=0, vmax=vmax)
    ax.set_xlim(lims[paramx])
    ax.set_ylim(lims[paramy])

    ax.set_xscale('log')
    ax.set_yscale('log')

    if show_ylabel:
        ax.set_ylabel(f"${plotter.param_formatter(paramy)}$", fontsize=_LABEL_FONTSIZE)
    else:
        ax.set_yticks([])
    ax.set_xlabel(f"${plotter.param_formatter(paramx)}$", fontsize=_LABEL_FONTSIZE)
    ax.set_title(title, fontsize=_TITLE_FONTSIZE, pad=5)

    ax.tick_params(axis='both', labelsize=_TICKS_FONTSIZE)
    return im


def _add_colorbar(fig, im, ax):
    """Attach an information colorbar for ``im`` next to ``ax``; return it."""
    cbar = fig.colorbar(im, ax=[ax])
    cbar.set_label("Amount of information (bits)", labelpad=15, fontsize=_LABEL_FONTSIZE)
    cbar.ax.tick_params(labelsize=_TICKS_FONTSIZE)
    return cbar


def plot_comp(
    xticks, yticks,
    mat1, mat2, mat12,
    vmax,
    paramx, paramy,
    line1, line2,
    stat
):
    """Plot the maps of ``stat`` for ``line1``, ``line2`` and their combination, side by side.

    Parameters
    ----------
    xticks, yticks : grid coordinates along ``paramx`` / ``paramy``, passed to
        ``np.meshgrid``.
    mat1, mat2, mat12 : maps for ``line1``, ``line2`` and ``[line1, line2]``,
        shaped for ``pcolor`` on that meshgrid.
    vmax : upper bound of the shared colour scale (lower bound is 0).
    paramx, paramy : regime parameter names of the x and y axes (keys of ``lims``).
    line1, line2 : names of the two lines.
    stat : key of ``_STAT_DESCRIPTIONS`` used in the suptitle.

    Returns
    -------
    ``(fig, (ax1, ax2, ax12))``. Also reads the globals ``plotter``, ``lims``
    and ``params_target``.
    """
    fig, axs = plt.subplots(1, 3, figsize=_FIGSIZE, dpi=_DPI, constrained_layout=True)
    ax1, ax2, ax12 = axs.flatten().tolist()

    X, Y = np.meshgrid(xticks, yticks)

    _draw_map(ax1, X, Y, mat1, paramx, paramy,
              f"${plotter.lines_comb_formatter(line1)}$", vmax, show_ylabel=True)
    _draw_map(ax2, X, Y, mat2, paramx, paramy,
              f"${plotter.lines_comb_formatter(line2)}$", vmax)
    im = _draw_map(ax12, X, Y, mat12, paramx, paramy,
                   f"${plotter.lines_comb_formatter([line1, line2])}$", vmax)

    # One colorbar is enough: all three panels share the same colour scale.
    _add_colorbar(fig, im, ax12)

    fig.suptitle(
        f"{_STAT_DESCRIPTIONS[stat]} between ${plotter.params_comb_formatter(params_target)}$ and observables",
        fontsize=_SUPTITLE_FONTSIZE,
    )

    return fig, (ax1, ax2, ax12)


def plot_info_gain(
    xticks, yticks,
    mat1, mat2, mat12,
    vmax, vmax_gain,
    paramx, paramy,
    line1, line2,
):
    """Plot the best marginal map, the combination map and the information gain.

    The gain is ``mat12 - max(mat1, mat2)`` computed pointwise, drawn on its
    own colour scale ``[0, vmax_gain]``, so negative values are clipped to 0.
    The other arguments are as in ``plot_comp``.

    Returns
    -------
    ``(fig, (axmax, axcomb, axgain))``. Also reads the globals ``plotter``,
    ``lims`` and ``params_target``.
    """
    fig, axs = plt.subplots(1, 3, figsize=_FIGSIZE, dpi=_DPI, constrained_layout=True)
    axmax, axcomb, axgain = axs.flatten().tolist()

    X, Y = np.meshgrid(xticks, yticks)

    # Pointwise best information obtained from a single line.
    matmax = np.maximum(mat1, mat2)

    _draw_map(axmax, X, Y, matmax, paramx, paramy,
              "Maximum marginal mutual information", vmax, show_ylabel=True)
    _draw_map(axcomb, X, Y, mat12, paramx, paramy,
              f"${plotter.lines_comb_formatter([line1, line2])}$", vmax)
    im = _draw_map(axgain, X, Y, mat12 - matmax, paramx, paramy,
                   "Information gain", vmax_gain, cmap='magma')

    # NOTE(review): only the gain panel has a colorbar; the first two panels
    # (scale [0, vmax]) have none.
    _add_colorbar(fig, im, axgain)

    fig.suptitle(
        f"Information gain on ${plotter.params_comb_formatter(params_target)}$ when knowing the combination of observables",
        fontsize=_SUPTITLE_FONTSIZE,
    )

    return fig, (axmax, axcomb, axgain)

In [None]:
def _read_triplet(l1, l2):
    """Read the results for line ``l1``, line ``l2`` and their combination ``[l1, l2]``.

    Returns ``(d1, d2, d12)`` as returned by ``handler.read``; any exception
    raised by ``handler.read`` propagates to the caller.
    """
    d1 = handler.read(params_target, l1, params_regime)
    d2 = handler.read(params_target, l2, params_regime)
    d12 = handler.read(params_target, [l1, l2], params_regime)
    return d1, d2, d12


# Colour-scale upper bounds shared by all figures (every pair and every
# statistic), so that the saved maps are directly comparable.
vmax = 0
vmax_gain = 0
for lines in combs_list:
    # Only pairs are compared with their two marginals; larger combinations
    # would inflate `vmax` and their "gain" would ignore the extra lines.
    if len(lines) != 2:
        continue
    try:
        d1, d2, d12 = _read_triplet(lines[0], lines[1])
    except Exception:
        # Skip combinations whose results cannot be read. Unlike a bare
        # `except:`, this does not swallow KeyboardInterrupt.
        # NOTE(review): narrow to the exception `handler.read` actually raises.
        continue
    for st in stats:
        # The marginal maps are drawn with the same `vmax`, so include them.
        for d in (d1, d2, d12):
            vmax = max(vmax, np.nanmax(d[st]['data']))
        gain = d12[st]['data'] - np.maximum(d1[st]['data'], d2[st]['data'])
        vmax_gain = max(vmax_gain, np.nanmax(gain))


for stat in stats:
    stat_dir = os.path.join(figures_dir, stat)
    # `makedirs` also creates `figures_dir` itself when it does not exist yet
    # (where `os.mkdir` raised FileNotFoundError).
    os.makedirs(stat_dir, exist_ok=True)

    for l1, l2 in tqdm(list(itt.combinations(lines_list, 2)), desc=stat):
        try:
            d1, d2, d12 = _read_triplet(l1, l2)
            xticks, yticks = d1[stat]['coords']
            # Data are transposed before `pcolor` (presumably stored as
            # (x, y) while pcolor expects (y, x) — TODO confirm).
            mat1 = d1[stat]['data'].T
            mat2 = d2[stat]['data'].T
            mat12 = d12[stat]['data'].T
        except Exception:
            # Pairs without available results for `stat` are skipped.
            continue

        prefix = f"{'_'.join(params_target)}__{l1}_{l2}_{stat}"

        # Side-by-side maps of the two marginals and of their combination.
        fig, _ = plot_comp(
            xticks, yticks,
            mat1, mat2, mat12, vmax,
            params_regime[0], params_regime[1],
            l1, l2,
            stat
        )
        fig.savefig(os.path.join(stat_dir, f"{prefix}.png"), bbox_inches="tight")
        plt.close(fig)

        # Information gained by combining the two lines.
        fig, _ = plot_info_gain(
            xticks, yticks,
            mat1, mat2, mat12, vmax, vmax_gain,
            params_regime[0], params_regime[1],
            l1, l2,
        )
        fig.savefig(os.path.join(stat_dir, f"{prefix}_gain.png"), bbox_inches="tight")
        plt.close(fig)