# Most informative lines figures

In [None]:
import os
import yaml
import json
import shutil
import itertools as itt
from typing import List, Optional, Union, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from tqdm import tqdm

from helpers import latex_regime, prob_higher, plot_summary_1d, plot_summary_2d
from store_data_discrete import select_data
from astro30m import latex_line, latex_quantity

# Render the text of every figure through LaTeX.
plt.rcParams["text.usetex"] = True

# Directory under which the figures of this notebook are saved.
figures_dir = os.path.join("figures", "1_line_discrete")

## Load data files

In [None]:
# Stored data

filename = os.path.join("data", "discrete", "nlines-1.json")

if not os.path.isfile(filename):
    raise FileNotFoundError(f'{filename} has not been yet created')
with open(filename, 'r') as f:
    data = json.load(f)

# Reference data

filename = os.path.join("data", "discrete", "nlines-0.json")

if not os.path.isfile(filename):
    raise FileNotFoundError(f'{filename} has not been yet created')
with open(filename, 'r') as f:
    ref = json.load(f)

# Line information

with open(os.path.join('data', 'lines.yaml'), 'r') as f:
    l_lines = yaml.safe_load(f)

# Parameter information: the loaded parameters, followed by every
# unordered pair of them

with open(os.path.join('data', 'parameters.yaml'), 'r') as f:
    l_params = yaml.safe_load(f)
l_params.extend(list(itt.combinations(l_params, 2)))

# Regime information

with open(os.path.join('data', 'regimes.yaml'), 'r') as f:
    d_regimes = yaml.safe_load(f)

# Regimes iterated over for each restricted parameter
# (the lowest regime, '0', is ignored)
regimes = {
    name: ['1', '2', '3', '4']
    for name in ('dust-av', 'dust-g0')
}
for param in regimes:
    regime_table = d_regimes[param]
    del regime_table['all']  # Ignore regime 'all'
    del regime_table['0']    # Ignore lowest regime

## Plot figures (one or two parameters)

- Without regime restriction
- With restriction on a single parameter
- With restriction on two parameters

In [None]:
def param_latex(param: Union[str, Tuple[str, str]]):
    """Return the LaTeX form of a parameter or of a pair of parameters.

    A single name is converted with ``latex_quantity``. For a pair, both
    names are converted and returned comma-separated inside parentheses.
    """
    if not isinstance(param, str):
        first = latex_quantity(param[0])
        second = latex_quantity(param[1])
        return f"({first}, {second})"
    return latex_quantity(param)

def param_str(param: Union[str, Tuple[str, str]]):
    """Return a short identifier for a parameter or a pair of parameters.

    The ``'dust-'`` substring is removed from each name and the result is
    capitalized; the two names of a pair are joined with an underscore.
    """
    def shorten(name: str) -> str:
        return name.replace('dust-', '').capitalize()

    if not isinstance(param, str):
        return f"{shorten(param[0])}_{shorten(param[1])}"
    return shorten(param)

def regime_latex(param: Union[str, Tuple[str, str]], d_regimes: Dict[str, List[Tuple]], reg: Dict[str, str]):
    """Return the LaTeX description of the regime selected for ``param``.

    For each parameter name, the regime key ``reg[name]`` is looked up in
    ``d_regimes[name]`` and the result is formatted by ``latex_regime``.
    The descriptions of a pair are joined with ``", "``.
    """
    names = (param,) if isinstance(param, str) else (param[0], param[1])
    return ", ".join(
        latex_regime(name, d_regimes[name][reg[name]]) for name in names
    )

def regime_str(param: Union[str, Tuple[str, str]], reg: Dict[str, str]):
    """Return an identifier made of parameter names and their regime keys.

    Each parameter contributes ``<param_str(name)>_<reg[name]>``; the two
    parts of a pair are joined with an underscore.
    """
    names = (param,) if isinstance(param, str) else (param[0], param[1])
    pieces = []
    for name in names:
        pieces.append(param_str(name))
        pieces.append(reg[name])
    return "_".join(pieces)

In [None]:
# Create the figures directory (and any missing parent) if it does not exist.
# A single makedirs call with exist_ok=True replaces the check-then-create
# pairs and does not fail when the directory is already there.
os.makedirs(figures_dir, exist_ok=True)

# Start from an empty folder for each parameter: an existing folder is
# deleted together with its content (previously saved figures are lost),
# then recreated.
for param in l_params:
    path = os.path.join(figures_dir, param_str(param))
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.mkdir(path)

### Helpers

In [None]:
# Shared settings read by plot_condh and plot_mi.

# Resolution passed to plt.subplots
dpi = 200
# Width of the bars
width = 0.6
# Multiplier applied to the 6.4 figure width
xscale = 1.2
# Multiplier applied to the 4.8 figure height
yscale = 1.0
# Cap size of the error bars
capsize = 6

def plot_condh(
    lines: List[str],
    hs: List[float],
    ref_h: Optional[float]=None,
    sorted:bool=False,
    errs: Optional[List[Union[float, Tuple[float, float]]]]=None,
    ref_err: Optional[Union[float, Tuple[float]]]=None
) -> Figure:
    """Draw a bar chart of the conditional entropy associated with each line.

    Parameters
    ----------
    lines
        Line names, one per bar. They are converted with ``latex_line``
        to build the tick labels.
    hs
        Bar heights, in the same order as ``lines``.
    ref_h
        Optional reference value. When given, it is drawn as an extra
        orange bar labelled ``REF`` to the right of the other bars.
    sorted
        When True, the bars are drawn by increasing value of ``hs``.
        (The name shadows the builtin; it is kept for compatibility.)
    errs
        Optional error bars, forwarded to ``Axes.errorbar`` as ``yerr``.
        They follow the same order as ``hs`` and are reordered with it.
    ref_err
        Optional error bar of the reference value.

    Returns
    -------
    Figure
        The newly created figure.
    """
    fig, ax = plt.subplots(1, 1, figsize = (xscale*6.4, yscale*4.8), dpi=dpi)

    if sorted:
        indices = np.array(hs).argsort()
        hs = [hs[i] for i in indices]
        lines = [lines[i] for i in indices]
        # errs is optional: only reorder it when it was provided
        # (indexing None used to raise a TypeError here).
        if errs is not None:
            errs = [errs[i] for i in indices]

    positions = np.arange(len(hs))
    ax.bar(positions, hs, width=width, color='tab:blue')
    ax.errorbar(positions, hs, yerr=errs, fmt='none', capsize=capsize, color='tab:red')

    labels = [latex_line(l, short=True, equation_mode=True) for l in lines]
    n_ticks = len(hs)

    if ref_h is not None:
        # The reference takes the first free position after the lines.
        ax.bar(len(hs), ref_h, color='tab:orange', width=width)
        ax.errorbar(len(hs), ref_h, yerr=ref_err, fmt='none', capsize=capsize, color='tab:red')
        labels.append('REF')
        n_ticks += 1

    ax.set_xticks(np.arange(n_ticks))
    ax.set_xticklabels(labels, rotation = 45, fontsize = 12)

    ax.set_xlabel('Integrated molecular lines', labelpad = 20)
    ax.set_ylabel('Conditional entropy (bits)', labelpad = 20)

    return fig

def plot_mi(
    lines: List[str],
    mis: List[float],
    sorted:bool=False,
    errs: Optional[List[Union[float, Tuple[float, float]]]]=None,
) -> Figure:
    """Draw a bar chart of the mutual information associated with each line.

    Parameters
    ----------
    lines
        Line names, one per bar. They are converted with ``latex_line``
        to build the tick labels.
    mis
        Bar heights, in the same order as ``lines``.
    sorted
        When True, the bars are drawn by decreasing value of ``mis``.
        (The name shadows the builtin; it is kept for compatibility.)
    errs
        Optional error bars, forwarded to ``Axes.errorbar`` as ``yerr``.
        They follow the same order as ``mis`` and are reordered with it.

    Returns
    -------
    Figure
        The newly created figure.
    """
    fig, ax = plt.subplots(1, 1, figsize = (xscale*6.4, yscale*4.8), dpi=dpi)

    if sorted:
        indices = np.array(mis).argsort()[::-1]
        mis = [mis[i] for i in indices]
        lines = [lines[i] for i in indices]
        # The error bars must follow the same permutation as the values,
        # otherwise each bar is drawn with the uncertainty of another line.
        if errs is not None:
            errs = [errs[i] for i in indices]

    positions = np.arange(len(mis))
    ax.bar(positions, mis, width=width, color='tab:blue')
    ax.errorbar(positions, mis, yerr=errs, fmt='none', capsize=capsize, color='tab:red')

    ax.set_xticks(positions)
    ax.set_xticklabels([latex_line(l, short=True, equation_mode=True) for l in lines], rotation=45, fontsize = 12)

    ax.set_xlabel('Integrated molecular lines', labelpad = 20)
    ax.set_ylabel('Mutual information (bits)', labelpad = 20)

    return fig

In [None]:
# Pixel-count threshold: the lines of a regime are ranked only when
# entries[0]["pixels"] is strictly greater than this value.
min_pixels = 100

### Without regime restriction

In [None]:
# One pair of figures (original order and sorted) per parameter, using the
# 'all' regime for every restricted parameter.
for param in l_params:
    print("param:", param)

    reg = dict.fromkeys(regimes, 'all')

    entries = [select_data(data, param, line, reg) for line in l_lines]

    mis = np.array([entry["mi"] for entry in entries])
    sigmas_mi = np.array([entry["sigma-mi"] for entry in entries])

    # Output folder of this parameter

    name = param_str(param)
    path = os.path.join(figures_dir, name, "regime_all")
    if not os.path.isdir(path):
        os.mkdir(path)

    title = f"Mutual information between ${param_latex(param)}$ and lines intensity"

    # Same figure twice: lines in file order, then sorted
    for is_sorted, suffix in ((False, "mi"), (True, "mi_sorted")):
        plot_mi(l_lines, mis, sorted=is_sorted, errs=sigmas_mi)
        plt.title(title)
        plt.savefig(os.path.join(path, f"{name}_{suffix}"), bbox_inches="tight")

    plt.close('all')

### With restriction on one parameter

In [None]:
# For every parameter (or pair) and every restricted parameter: plot the
# mutual information of each line in each regime, then a summary figure of
# the best lines per regime.
# NOTE(review): l_params[:2] presumably selects the two single parameters
# that have an entry in `regimes` — confirm against data/parameters.yaml.
for param, param_regime in itt.product(l_params, l_params[:2]):
    print("param:", param)
    print("param regime:", param_regime)

    regime_keys = regimes[param_regime]

    # One slot per regime; a slot stays None when the regime is not ranked.
    best_lines = [None] * len(regime_keys)
    confidences = [None] * len(regime_keys)

    # The output folder does not depend on the regime: create it once, so
    # that it is also defined for the summary figure saved after the loop.
    path = os.path.join(figures_dir, param_str(param), f"regime_{param_str(param_regime)}")
    os.makedirs(path, exist_ok=True)

    for i, regime_key in enumerate(tqdm(regime_keys)):

        reg = {param_regime: regime_key}

        entries = [select_data(data, param, line, reg) for line in l_lines]

        mis = np.array([entry["mi"] for entry in entries])
        sigmas_mi = np.array([entry["sigma-mi"] for entry in entries])

        # Rank the lines only when every value is available and the regime
        # contains more than min_pixels pixels.
        if all(el is not None for el in mis) and entries[0]["pixels"] > min_pixels:
            probs = prob_higher(mis, sigmas_mi, approx=True)
            probs[np.isnan(probs)] = 0.
            order = np.argsort(probs)[::-1]
            order = order[probs[order] > 0.10] # We take the probabilities higher than 10%
            order = order[:3] # We take only the 3 first probabilities for display reasons
            best_lines[i] = [l_lines[k] for k in order]
            confidences[i] = [probs[k] for k in order]

        # Same figure twice: lines in file order, then sorted

        title = f"Mutual information between ${param_latex(param)}$ and lines intensity ({regime_latex(param_regime, d_regimes, reg)})"
        prefix = f"{param_str(param)}_regime_{regime_str(param_regime, reg)}"

        for is_sorted, suffix in ((False, "mi"), (True, "mi_sorted")):
            plot_mi(l_lines, mis, sorted=is_sorted, errs=sigmas_mi)
            plt.title(title)
            plt.savefig(os.path.join(path, f"{prefix}_{suffix}"), bbox_inches="tight")

        plt.close('all')

    fig = plot_summary_1d(param, {param_regime: d_regimes[param_regime]}, best_lines, confidences)
    plt.title(f"Most informative line on ${param_latex(param)}$ for each regime of ${param_latex(param_regime)}$")
    plt.savefig(os.path.join(path, f"{param_str(param)}_regime_{param_str(param_regime)}_summary"), bbox_inches="tight")

    plt.close(fig)

    print()

### With restriction on two parameters

In [None]:
# For every parameter (or pair) and every pair of restricted parameters:
# plot the mutual information of each line in each combination of regimes,
# then a summary figure of the best lines per combination.
# NOTE(review): l_params[2:] presumably contains only pairs of parameters
# that both have an entry in `regimes` — confirm against data/parameters.yaml.
for param, params_regime in itt.product(l_params, l_params[2:]):
    print("param:", param)
    print("params regime:", params_regime)

    regime_keys_0 = regimes[params_regime[0]]
    regime_keys_1 = regimes[params_regime[1]]

    # One slot per combination of regimes; a slot stays None when the
    # combination is not ranked.
    best_lines = [[None] * len(regime_keys_1) for _ in regime_keys_0]
    confidences = [[None] * len(regime_keys_1) for _ in regime_keys_0]

    # The output folder does not depend on the regimes: create it once, so
    # that it is also defined for the summary figure saved after the loop.
    path = os.path.join(figures_dir, param_str(param), f"regime_{param_str(params_regime)}")
    os.makedirs(path, exist_ok=True)

    combinations = list(itt.product(enumerate(regime_keys_0), enumerate(regime_keys_1)))
    for (i, regime_key_0), (j, regime_key_1) in tqdm(combinations):

        reg = {
            params_regime[0]: regime_key_0,
            params_regime[1]: regime_key_1
        }

        entries = [select_data(data, param, line, reg) for line in l_lines]

        mis = np.array([entry["mi"] for entry in entries])
        sigmas_mi = np.array([entry["sigma-mi"] for entry in entries])

        # Rank the lines only when every value is available and the
        # combination of regimes contains more than min_pixels pixels.
        if all(el is not None for el in mis) and entries[0]["pixels"] > min_pixels:
            probs = prob_higher(mis, sigmas_mi, approx=True)
            probs[np.isnan(probs)] = 0.
            order = np.argsort(probs)[::-1]
            order = order[probs[order] > 0.10] # We take the probabilities higher than 10%
            # We keep at most the 4 first probabilities for display reasons.
            # NOTE(review): the previous comment said 3 while the code kept
            # 4 — confirm which one is intended.
            order = order[:4]
            best_lines[i][j] = [l_lines[k] for k in order]
            confidences[i][j] = [probs[k] for k in order]

        # Same figure twice: lines in file order, then sorted.
        # Plotting stays best-effort (a failing combination is skipped), but
        # the failure is reported and KeyboardInterrupt is no longer swallowed.
        try:
            title = f"Mutual information between ${param_latex(param)}$ and lines intensity ({regime_latex(params_regime, d_regimes, reg)})"
            prefix = f"{param_str(param)}_regime_{regime_str(params_regime, reg)}"

            for is_sorted, suffix in ((False, "mi"), (True, "mi_sorted")):
                plot_mi(l_lines, mis, sorted=is_sorted, errs=sigmas_mi)
                plt.title(title)
                plt.savefig(os.path.join(path, f"{prefix}_{suffix}"), bbox_inches="tight")
        except Exception as err:
            tqdm.write(f"Skipped figures for param={param!r}, regimes={reg!r}: {err!r}")

        plt.close('all')

    fig = plot_summary_2d(param, {key: d_regimes[key] for key in params_regime}, best_lines, confidences)
    plt.title(f"Most informative line on ${param_latex(param)}$ for different regimes of ${param_latex(params_regime[0])}$ and ${param_latex(params_regime[1])}$")
    plt.savefig(os.path.join(path, f"{param_str(param)}_regime_{param_str(params_regime)}_summary"), bbox_inches="tight")

    plt.close(fig)

    print()