# Most informative combination of two lines

In [1]:
import os
import yaml
import json
import shutil
import itertools as itt
from typing import List, Optional, Union, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from tqdm import tqdm

from helpers import latex_regime, prob_higher, plot_summary_1d, plot_summary_2d
from store_data_discrete import select_data
from astro30m import latex_line, latex_quantity

# Render all figure text through LaTeX (matplotlib's `usetex` mode)
plt.rc("text", usetex=True)

# Root folder under which this notebook writes all of its figures
figures_dir = os.path.join("figures", "2_lines_discrete")

## Load data files

In [2]:
def _read_required_json(filename: str):
    """Parse a JSON file, raising an explicit error when the file does not exist."""
    if not os.path.isfile(filename):
        raise FileNotFoundError(f'{filename} has not been yet created')
    with open(filename, 'r') as f:
        return json.load(f)


def _read_json(*path_parts: str):
    """Parse the JSON file found at os.path.join(*path_parts)."""
    with open(os.path.join(*path_parts), 'r') as f:
        return json.load(f)


# Stored results (queried below with pairs of lines)
filename = os.path.join("data", "discrete", 'nlines-2.json')
data = _read_required_json(filename)

# Reference results (queried below with a single line)
filename = os.path.join("data", "discrete", 'nlines-1.json')
ref_1 = _read_required_json(filename)

# Line information
l_lines = _read_json('data', 'lines.json')

# Parameter information: the loaded parameters, followed by every pair of them
l_params = _read_json('data', 'parameters.json')
l_params += list(itt.combinations(l_params, 2))

# Regime information
d_regimes = _read_json('data', 'regimes.json')

# Regime labels iterated over for each conditioning parameter (lowest regime '0' is ignored)
regimes = {name: ['1', '2', '3', '4'] for name in ('dust-av', 'dust-g0')}

for param in ['dust-av', 'dust-g0']:
    # Drop the regime 'all' first, then the lowest regime '0'
    for label in ('all', '0'):
        del d_regimes[param][label]

## Plot figures (one or two parameters)

- Without regime restriction
- With restriction on a single parameter
- With restriction on two parameters

In [3]:
def param_latex(param: Union[str, Tuple[str, str]]):
    """Return the LaTeX form of a parameter, or of a pair of parameters.

    A single parameter name is passed to `latex_quantity`. For a pair, the two
    results are wrapped as "(first, second)".
    """
    if not isinstance(param, str):
        first, second = param[0], param[1]
        return f"({latex_quantity(first)}, {latex_quantity(second)})"
    return latex_quantity(param)

def param_str(param: Union[str, Tuple[str, str]]):
    """Return a short identifier for a parameter or a pair of parameters.

    Each name has every 'dust-' occurrence removed and is then capitalized.
    For a pair, the two identifiers are joined with an underscore.
    """
    names = (param,) if isinstance(param, str) else (param[0], param[1])
    return '_'.join(name.replace('dust-', '').capitalize() for name in names)

def regime_latex(param: Union[str, Tuple[str, str]], d_regimes: Dict[str, List[Tuple]], reg: Dict[str, str]):
    """Return the LaTeX description of the regime selected for one or two parameters.

    For each parameter name, `reg[name]` is the selected regime label and
    `d_regimes[name][reg[name]]` is the value handed to `latex_regime`.
    For a pair, the two descriptions are joined with ", ".
    """
    def _describe(name: str):
        return latex_regime(name, d_regimes[name][reg[name]])

    if isinstance(param, str):
        return _describe(param)
    return _describe(param[0]) + ", " + _describe(param[1])

def regime_str(param: Union[str, Tuple[str, str]], reg: Dict[str, str]):
    """Return an identifier naming the regime selected for one or two parameters.

    Each parameter contributes "<param_str(name)>_<reg[name]>". For a pair, the
    two parts are joined with an underscore.
    """
    names = [param] if isinstance(param, str) else [param[0], param[1]]
    return "_".join(param_str(name) + "_" + reg[name] for name in names)

In [4]:
# Make sure the figure folders exist ("figures" first, then this notebook's folder)
for folder in ("figures", figures_dir):
    os.makedirs(folder, exist_ok=True)

# Start from an empty folder for every parameter: delete a previous one, then recreate it
for param in l_params:
    path = os.path.join(figures_dir, param_str(param))
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.mkdir(path)

### Helpers

In [5]:
# Figure settings read by plot_condh and plot_mi
dpi = 200       # passed as `dpi` to plt.subplots
cmap = 'OrRd'   # colormap name; alternative: 'viridis'
xscale = 1.0    # multiplies the 6.4 figure width
yscale = 1.0    # multiplies the 4.8 figure height

def _plot_upper_triangle(
    lines: List[str],
    values: List[List[float]],
    colormap: str,
    label: str,
    show_diag: bool,
) -> Figure:
    """Draw a square matrix as an image with its lower triangle blanked out.

    Shared implementation of `plot_condh` and `plot_mi`.

    Parameters
    ----------
    lines : labels of the rows/columns, formatted with `latex_line`.
    values : square matrix (nested lists or array) to display.
    colormap : matplotlib colormap name passed to `imshow`.
    label : text of the colourbar label.
    show_diag : when False, the diagonal is blanked out as well.

    Returns
    -------
    The matplotlib Figure holding the image and its colourbar.
    """
    fig, ax = plt.subplots(1, 1, figsize=(xscale*6.4, yscale*4.8), dpi=dpi)

    values = np.array(values)
    # Entries strictly below the diagonal (and the diagonal itself when show_diag
    # is False) are multiplied by NaN; all other entries are multiplied by 1.
    mask = np.where(
        np.tril(np.ones_like(values), k=-1 if show_diag else 0),
        float('nan'), 1.
    )
    im = ax.imshow(mask * values, origin='lower', cmap=colormap)

    cbar = fig.colorbar(im)
    cbar.set_label(label, labelpad=30, rotation=270)

    # Both axes carry the same ticks and labels, so build them once
    ticks = np.arange(values.shape[0])
    tick_labels = [latex_line(l, short=True, equation_mode=True) for l in lines]
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels, rotation=45, fontsize=12)
    ax.set_yticks(ticks)
    ax.set_yticklabels(tick_labels, rotation=45, fontsize=12)

    return fig

def plot_condh(
    lines: List[str],
    hs: List[List[float]],
    show_diag: bool=True,
) -> Figure:
    """Plot a matrix of conditional entropies (in bits) between pairs of lines.

    Uses the reversed version of the global colormap (`cmap + '_r'`).
    Only the upper triangle is drawn; the diagonal is included when
    `show_diag` is True. Returns the created Figure.
    """
    return _plot_upper_triangle(
        lines, hs, cmap+'_r', 'Conditional entropy (bits)', show_diag
    )

def plot_mi(
    lines: List[str],
    mis: List[List[float]],
    show_diag: bool=True,
) -> Figure:
    """Plot a matrix of mutual information (in bits) between pairs of lines.

    Uses the global colormap `cmap`. Only the upper triangle is drawn; the
    diagonal is included when `show_diag` is True. Returns the created Figure.
    """
    return _plot_upper_triangle(
        lines, mis, cmap, 'Mutual information (bits)', show_diag
    )

In [6]:
# A regime is ranked in the summaries only when its entry's "pixels" value is strictly greater than this
min_pixels = 100

### Without regime restriction

In [7]:
# No restriction: every conditioning parameter uses the regime 'all'
reg = {key: 'all' for key in regimes}

for param in l_params:
    print("param:", param)

    # Symmetric matrix of mutual information: single-line values (from ref_1) on
    # the diagonal, two-line values (from data) elsewhere
    mis = np.zeros((len(l_lines), len(l_lines)))
    for line1, line2 in itt.combinations_with_replacement(l_lines, r=2):
        if line1 == line2:
            entry = select_data(ref_1, param, line1, reg)
        else:
            entry = select_data(data, param, (line1, line2), reg)
        i1, i2 = l_lines.index(line1), l_lines.index(line2)
        mis[i1, i2] = mis[i2, i1] = entry["mi"]

    # Output folder for the unrestricted figures of this parameter
    path = os.path.join(figures_dir, param_str(param), "regime_all")
    os.makedirs(path, exist_ok=True)

    fig = plot_mi(l_lines, mis)
    plt.title(f"Mutual information between ${param_latex(param)}$ and lines intensity")
    plt.savefig(os.path.join(path, f"{param_str(param)}_mi"), bbox_inches="tight")

    plt.close('all')

param: dust-av
param: dust-g0
param: ('dust-av', 'dust-g0')


### With restriction on one parameter

In [8]:
# Positions of the line pairs below the diagonal (-1 excludes the diagonal itself)
tril_x, tril_y = np.tril_indices(len(l_lines), -1)

# Target parameter x conditioning parameter (the first two entries of l_params)
for param, param_regime in itt.product(l_params, l_params[:2]):
    print("param:", param)
    print("param regime:", param_regime)

    regime_labels = regimes[param_regime]

    # One slot per regime: ranked line pairs and their probabilities,
    # or None when the regime is not ranked
    best_lines = [None] * len(regime_labels)
    confidences = [None] * len(regime_labels)

    # Output folder shared by every regime of this (param, param_regime) combination
    path = os.path.join(figures_dir, param_str(param), f"regime_{param_str(param_regime)}")
    os.makedirs(path, exist_ok=True)

    for i in tqdm(range(len(regime_labels))):

        reg = {param_regime: regime_labels[i]}

        # Single-line entries, fetched once per line and reused for the diagonal and
        # for the uncertainties (assumes select_data is a pure lookup -- TODO confirm)
        ref_entries = [select_data(ref_1, param, line, reg) for line in l_lines]

        mis = np.zeros((len(l_lines), len(l_lines)))
        sigmas_mi = np.zeros((len(l_lines), len(l_lines)))
        for line1, line2 in itt.combinations_with_replacement(l_lines, r=2):
            i1, i2 = l_lines.index(line1), l_lines.index(line2)
            if line1 == line2:
                entry = ref_entries[i1]
            else:
                entry = select_data(data, param, (line1, line2), reg)
            mis[i1, i2] = mis[i2, i1] = entry["mi"]

            # Uncertainty of a pair: the larger of the two single-line uncertainties
            try:
                sigma = max(
                    ref_entries[i1]["sigma-mi"],
                    ref_entries[i2]["sigma-mi"],
                )
            except TypeError:
                # An uncertainty is missing (not comparable): store NaN rather than
                # reusing the value computed for a previous pair
                sigma = np.nan
            sigmas_mi[i1, i2] = sigmas_mi[i2, i1] = sigma

        # `entry` is the last one fetched above, i.e. the single-line entry of the
        # last line. NOTE(review): its "pixels" value decides for the whole regime;
        # presumably every entry of a regime shares it -- confirm.
        # Missing MI values are stored as NaN in `mis`; NaN probabilities are zeroed below.
        if entry["pixels"] > min_pixels:
            probs = prob_higher(mis[tril_x, tril_y], sigmas_mi[tril_x, tril_y], approx=True)
            probs[np.isnan(probs)] = 0.
            order = np.argsort(probs)[::-1]
            order = order[probs[order] > 0.10] # We take the probabilities higher than 10%
            order = order[:3] # We take only the 3 first probabilities for display reasons

            best_lines[i] = [(l_lines[tril_x[k]], l_lines[tril_y[k]]) for k in order]
            confidences[i] = [probs[k] for k in order]
        else:
            best_lines[i] = None
            confidences[i] = None

        # Mutual information matrix for this regime
        fig = plot_mi(l_lines, mis)
        plt.title(f"Mutual information between ${param_latex(param)}$ and lines intensity ({regime_latex(param_regime, d_regimes, reg)})")
        plt.savefig(os.path.join(path, f"{param_str(param)}_regime_{regime_str(param_regime, reg)}_mi"), bbox_inches="tight")

        plt.close('all')

    # Summary over all regimes of the conditioning parameter
    fig = plot_summary_1d(param, {param_regime: d_regimes[param_regime]}, best_lines, confidences)
    plt.title(f"Most informative line on ${param_latex(param)}$ for each regime of ${param_latex(param_regime)}$")
    plt.savefig(os.path.join(path, f"{param_str(param)}_regime_{param_str(param_regime)}_summary"), bbox_inches="tight")

    plt.close(fig)

    print()

param: dust-av
param regime: dust-av


100%|██████████| 4/4 [00:07<00:00,  1.96s/it]



param: dust-av
param regime: dust-g0


100%|██████████| 4/4 [00:14<00:00,  3.54s/it]



param: dust-g0
param regime: dust-av


100%|██████████| 4/4 [00:05<00:00,  1.31s/it]



param: dust-g0
param regime: dust-g0


100%|██████████| 4/4 [00:42<00:00, 10.73s/it]



param: ('dust-av', 'dust-g0')
param regime: dust-av


100%|██████████| 4/4 [00:05<00:00,  1.39s/it]



param: ('dust-av', 'dust-g0')
param regime: dust-g0


100%|██████████| 4/4 [00:26<00:00,  6.69s/it]





### With restriction on two parameters

In [9]:
# Positions of the line pairs below the diagonal (-1 excludes the diagonal itself)
tril_x, tril_y = np.tril_indices(len(l_lines), -1)

# Target parameter x pair of conditioning parameters (entries of l_params from index 2 on)
for param, params_regime in itt.product(l_params, l_params[2:]):
    print("param:", param)
    print("params regime:", params_regime)

    labels_0 = regimes[params_regime[0]]
    labels_1 = regimes[params_regime[1]]

    # One slot per (i, j) regime combination: ranked line pairs and their
    # probabilities, or None when the combination is not ranked
    best_lines = [[None] * len(labels_1) for _ in labels_0]
    confidences = [[None] * len(labels_1) for _ in labels_0]

    # Output folder shared by every regime combination of this (param, params_regime)
    path = os.path.join(figures_dir, param_str(param), f"regime_{param_str(params_regime)}")
    os.makedirs(path, exist_ok=True)

    for i, j in tqdm(list(itt.product(range(len(labels_0)), range(len(labels_1))))):

        reg = {
            params_regime[0]: labels_0[i],
            params_regime[1]: labels_1[j]
        }

        # Single-line entries, fetched once per line and reused for the diagonal and
        # for the uncertainties (assumes select_data is a pure lookup -- TODO confirm)
        ref_entries = [select_data(ref_1, param, line, reg) for line in l_lines]

        mis = np.zeros((len(l_lines), len(l_lines)))
        sigmas_mi = np.zeros((len(l_lines), len(l_lines)))
        for line1, line2 in itt.combinations_with_replacement(l_lines, r=2):
            i1, i2 = l_lines.index(line1), l_lines.index(line2)
            if line1 == line2:
                entry = ref_entries[i1]
            else:
                entry = select_data(data, param, (line1, line2), reg)
            mis[i1, i2] = mis[i2, i1] = entry["mi"]

            # Uncertainty of a pair: the larger of the two single-line uncertainties
            try:
                sigma = max(
                    ref_entries[i1]["sigma-mi"],
                    ref_entries[i2]["sigma-mi"],
                )
            except TypeError:
                # An uncertainty is missing (not comparable): store NaN
                sigma = np.nan
            sigmas_mi[i1, i2] = sigmas_mi[i2, i1] = sigma

        # `entry` is the last one fetched above, i.e. the single-line entry of the
        # last line. NOTE(review): its "pixels" value decides for the whole regime;
        # presumably every entry of a regime shares it -- confirm.
        # Missing MI values are stored as NaN in `mis`; NaN probabilities are zeroed below.
        if entry["pixels"] > min_pixels:
            probs = prob_higher(mis[tril_x, tril_y], sigmas_mi[tril_x, tril_y], approx=True)
            probs[np.isnan(probs)] = 0.
            order = np.argsort(probs)[::-1]
            order = order[probs[order] > 0.10] # We take the probabilities higher than 10%
            order = order[:3] # We take only the 3 first probabilities for display reasons

            best_lines[i][j] = [(l_lines[tril_x[k]], l_lines[tril_y[k]]) for k in order]
            confidences[i][j] = [probs[k] for k in order]
        else:
            best_lines[i][j] = None
            confidences[i][j] = None

        # Mutual information matrix for this regime combination. Plotting stays
        # best-effort, but a failure is reported instead of being silently dropped,
        # and KeyboardInterrupt is no longer swallowed.
        try:
            fig = plot_mi(l_lines, mis)
            plt.title(f"Mutual information between ${param_latex(param)}$ and lines intensity ({regime_latex(params_regime, d_regimes, reg)})")
            plt.savefig(os.path.join(path, f"{param_str(param)}_regime_{regime_str(params_regime, reg)}_mi"), bbox_inches="tight")
        except Exception as exc:
            print(f"Could not save the MI figure for param={param}, regime={reg}: {exc!r}")

        plt.close('all')

    # Summary over all regime combinations of the two conditioning parameters
    fig = plot_summary_2d(param, {key: d_regimes[key] for key in params_regime}, best_lines, confidences)
    plt.title(f"Most informative line on ${param_latex(param)}$ for different regimes of ${param_latex(params_regime[0])}$ and ${param_latex(params_regime[1])}$")
    plt.savefig(os.path.join(path, f"{param_str(param)}_regime_{param_str(params_regime)}_summary"), bbox_inches="tight")

    plt.close(fig)

    print()

param: dust-av
params regime: ('dust-av', 'dust-g0')


100%|██████████| 16/16 [00:41<00:00,  2.62s/it]



param: dust-g0
params regime: ('dust-av', 'dust-g0')


100%|██████████| 16/16 [01:25<00:00,  5.33s/it]



param: ('dust-av', 'dust-g0')
params regime: ('dust-av', 'dust-g0')


100%|██████████| 16/16 [01:40<00:00,  6.31s/it]



