# Figures of the most informative lines

In [13]:
import os
import sys
import yaml
import json
import shutil
import itertools as itt
from typing import List, Dict, Union, Tuple

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from tqdm import tqdm

# Make the directory two levels up and the parent directory importable (the
# former taking precedence), so the project packages imported below resolve.
for _position, _directory in enumerate((os.path.join("..", ".."), "..")):
    sys.path.insert(_position, _directory)

from infovar import DiscreteHandler
from infovar.stats.ranking import prob_higher
from infovar.stats.statistics import MI
from infovar.stats.resampling import Subsampling

from infobs.plots import Plotter
from orion_util import get_pixels_of_region, latex_line, latex_param 

plt.rc("text", usetex=True)  # render every figure text with LaTeX

# Environment studied: key into envs.yaml and output directory for the figures.
env_name = "pdr3"

# NOTE(review): not used in the visible part of this notebook — confirm it is still needed.
data_dir = os.path.join("..", "data", "raw", "emir_simulations")

# Idempotent on re-run and free of the exists()/mkdir() check-then-create race.
os.makedirs(env_name, exist_ok=True)

## Plotter configuration

In [14]:
plotter = Plotter(
    line_formatter=latex_line,
    param_formatter=latex_param,
)


def latex_comb_lines(ls):
    """Format a combination of lines with the plotter, using long (non-short) names."""
    return plotter.lines_comb_formatter(ls, short=False)


def latex_comb_params(ps):
    """Format a parameter (or combination of parameters) with the plotter.

    The result is embedded between ``$...$`` in figure titles.
    """
    return plotter.params_comb_formatter(ps)

## Environment

In [15]:
# Environment definitions, keyed by environment name (e.g. `env_name`).
with open('envs.yaml', mode='r') as envs_file:
    envs_dict = yaml.safe_load(envs_file)

## Simulate environment

In [16]:
n_samples = 100_000

# Linear-scale parameter bounds of the selected environment, taken from the
# envs.yaml content already loaded into `envs_dict` (no need to re-read the file).
low = envs_dict[env_name]["lower_bounds_lin"]
upp = envs_dict[env_name]["upper_bounds_lin"]

# NOTE(review): `simulate` is neither imported nor defined anywhere in this
# notebook, so "Restart Kernel -> Run All" fails here. TODO: import it from the
# module that provides it.
df = simulate(
    n_samples,
    lower_bounds_lin=low,
    upper_bounds_lin=upp,
    obs_time=0.75,
)

In [17]:
# The simulated frame stores the physical parameters first, then one column per line.
n_params = 3
columns = df.columns.to_list()
params, lines = columns[:n_params], columns[n_params:]

for label, names in (("Number of parameters:", params), ("Number of lines:", lines)):
    print(label, len(names))

Number of parameters: 3
Number of lines: 38


In [18]:
def param_str(params: Union[str, List[str]]) -> str:
    """Build a filename-friendly identifier for one or several parameters.

    A single name is returned unchanged; a sequence of names is joined with
    underscores, e.g. ``["P", "radm"]`` gives ``"P_radm"``.
    """
    return params if isinstance(params, str) else "_".join(params)

def regime_str(params: Union[str, Tuple[str, str]], reg: Dict[str, str]) -> str:
    """Build an identifier describing the regime of one or several parameters.

    Each parameter contributes ``"<param>_<reg[param]>"``; several parameters
    are joined with underscores, e.g. ``("P", "radm")`` with
    ``{"P": "low", "radm": "high"}`` gives ``"P_low_radm_high"``.
    """
    group = (params,) if isinstance(params, str) else params
    return "_".join(param_str(p) + "_" + reg[p] for p in group)

## Plot without regime restriction

In [19]:
# Samples restriction: keep the first 10_000 rows. `.iloc` is positional and
# end-exclusive, whereas `.loc[:10_000]` is label-based and end-inclusive
# (10_001 rows for a 0-based integer index).
_df = df.iloc[:10_000]

In [20]:
# MI estimator, and the subsampling helper whose sigmas serve as error bars.
mi = MI()
subsampling = Subsampling()

# Keyword arguments forwarded to `subsampling.compute_sigma` for every estimate.
subsampling_params = dict(
    stat=mi,
    n=5,
    min_samples=20,
    min_subsets=5,
    decades=2,
)

In [21]:
# Single-line ranking: for each physical parameter, estimate the mutual
# information (MI) between the parameter and every line intensity, with a
# subsampling error bar, save the sorted bar chart, and keep the ranking as the
# reference order used by the combination section.
ref_lines = {}   # param -> line names sorted by decreasing MI
ref_mis = {}     # param -> MI values, same order
ref_sigmas = {}  # param -> subsampling errors, same order

# Loop-invariant; every line is kept in the ranking and the count also sets the
# figure width. Defined before the loop so it exists even when `params` is empty.
n_lines = len(lines)

for param in params:
    print("param:", param)
    y = _df[param].to_numpy().reshape(-1, 1)

    mis = []
    sigmas = []
    for line in tqdm(lines):
        x = _df[line].to_numpy().reshape(-1, 1)
        # Keep the MI-then-sigma call order (the estimators may be stochastic).
        mis.append(mi(x, y))
        sigmas.append(subsampling.compute_sigma(x, y, **subsampling_params))

    # Sort lines by decreasing MI (all lines are kept).
    idx = np.argsort(mis)[::-1]
    _lines = [lines[i] for i in idx]
    _mis = [mis[i] for i in idx]
    _sigmas = [sigmas[i] for i in idx]

    plt.figure(figsize=(n_lines/10 * 6.4, 4.8), dpi=150)
    _ = plotter.plot_mi_bar(_lines, _mis, sorted=False, errs=_sigmas, short_names=False)
    plt.title(f"Mutual information between ${latex_comb_params(param)}$ and lines intensity")
    plt.savefig(os.path.join(env_name, f"{param_str(param)}_mi"), bbox_inches="tight")
    plt.close('all')

    ref_lines[param] = _lines
    ref_mis[param] = _mis
    ref_sigmas[param] = _sigmas

param: P


100%|██████████| 38/38 [00:22<00:00,  1.66it/s]


param: radm


100%|██████████| 38/38 [00:22<00:00,  1.67it/s]


param: Avmax


100%|██████████| 38/38 [00:22<00:00,  1.69it/s]


In [22]:
# Display, for each parameter, the line names sorted by decreasing MI.
ref_lines

{'P': ['cn_n1_j1d5__n0_j0d5',
  'cn_n1_j0d5__n0_j0d5',
  'cs_j3__j2',
  'co_v0_j3__v0_j2',
  'cs_j2__j1',
  'cs_j5__j4',
  '13c_o_j3__j2',
  'hcop_j3__j2',
  'co_v0_j2__v0_j1',
  'cp_el2p_j3_2__el2p_j1_2',
  'c_18o_j2__j1',
  'hcop_j2__j1',
  '13c_o_j2__j1',
  'co_v0_j1__v0_j0',
  'c_el3p_j1__el3p_j0',
  'hnc_j1__j0',
  'hcop_j1__j0',
  'cn_n2_j2d5__n1_j1d5',
  'c_el3p_j2__el3p_j1',
  'c_18o_j1__j0',
  'hcn_j1_f2__j0_f1',
  'c2h_n3d0_j3d5_f4d0__n2d0_j2d5_f3d0',
  'c_18o_j3__j2',
  '13c_o_j1__j0',
  'hcn_j3_f3__j2_f3',
  'cn_n3_j2d5__n2_j1d5',
  'cs_j6__j5',
  'c2h_n4d0_j3d5_f4d0__n3d0_j2d5_f3d0',
  'cs_j7__j6',
  'c2h_n3d0_j2d5_f3d0__n2d0_j1d5_f2d0',
  'cn_n3_j3d5__n2_j2d5',
  'hnc_j3__j2',
  'cn_n2_j1d5__n1_j0d5',
  'c2h_n2d0_j2d5_f3d0__n1d0_j1d5_f2d0',
  'hcop_j4__j3',
  'hcn_j4_f4__j3_f3',
  'c2h_n4d0_j4d5_f5d0__n3d0_j3d5_f4d0',
  'hcn_j2_f3__j1_f2'],
 'radm': ['cp_el2p_j3_2__el2p_j1_2',
  'co_v0_j1__v0_j0',
  'co_v0_j2__v0_j1',
  'co_v0_j3__v0_j2',
  'c_el3p_j2__el3p_j1',
  'c_el3p

### Two-line combinations

In [23]:
# Samples restriction: keep the first 1_000 rows. `.iloc` is positional and
# end-exclusive (`.loc[:1_000]` would be label-based and include row 1_000).
_df = df.iloc[:1_000] # TODO 10_000

In [24]:
# Two-line combinations. For each parameter, estimate the MI between the
# parameter and every pair of lines (the diagonal uses a single line), then save:
#   1. the probability of each of the best pairs being the most informative,
#   2. the MI of the best pairs with error bars,
#   3. the MI matrix restricted to the lines appearing in the best pairs.
n_combs = 20      # number of best pairs shown in the bar charts
n_lines_mat = 10  # target number of distinct lines in the MI matrix

for param in params:
    print("param:", param)
    y = _df[param].to_numpy().reshape(-1, 1)

    # Symmetric matrices indexed like `lines`; entry (i, i) uses line i alone.
    mis = np.zeros((len(lines), len(lines)), dtype=float)
    sigmas = np.zeros_like(mis)
    # Iterate over index pairs directly instead of calling `lines.index` per pair.
    for i1, i2 in tqdm(list(itt.combinations_with_replacement(range(len(lines)), r=2))):
        x1 = _df[lines[i1]].to_numpy()
        if i1 != i2:
            x = np.column_stack((x1, _df[lines[i2]].to_numpy()))
        else:
            x = x1.reshape(-1, 1)

        # Keep the MI-then-sigma call order (the estimators may be stochastic).
        mis[i1, i2] = mis[i2, i1] = mi(x, y)
        sigmas[i1, i2] = sigmas[i2, i1] = subsampling.compute_sigma(
            x, y, **subsampling_params
        )

    # Strict lower triangle (-1 drops the diagonal): each unordered pair of
    # distinct lines appears exactly once.
    tril_x, tril_y = np.tril_indices(len(lines), -1)

    _mis = mis[tril_x, tril_y]
    _sigmas = sigmas[tril_x, tril_y]
    probs = prob_higher(_mis, _sigmas, approx=True, pbar=True)

    # Keep the n_combs pairs with the highest MI.
    order = np.argsort(_mis)[::-1][:n_combs]
    # order = np.argsort(probs)[::-1] TODO
    # order = order[probs[order] > 0.10] # We take the probabilities higher than 10%
    # order = order[:min(order.size, 3)] # We take only the 3 first probabilities for display reasons

    best_lines = [(lines[tril_x[k]], lines[tril_y[k]]) for k in order]
    best_lines_mis = [_mis[k] for k in order]
    best_lines_sigmas = [_sigmas[k] for k in order]
    best_lines_probs = [probs[k] for k in order]

    # Within each pair, list the lines in their single-line ranking order.
    # Iterates over the pairs actually selected: there are fewer than n_combs
    # when there are few lines.
    ref_rank = ref_lines[param]
    best_lines = [sorted(pair, key=ref_rank.index) for pair in best_lines]

    # Same width as the single-line charts, without relying on `n_lines`
    # defined in another cell.
    fig_width = len(lines)/10 * 6.4

    plt.figure(figsize=(fig_width, 4.8), dpi=150)
    _ = plotter.plot_prob_bar(best_lines, best_lines_probs, short_names=False)
    plt.title(f"Probabilities of being the most informative observables on ${latex_comb_params(param)}$")
    plt.savefig(os.path.join(env_name, f"{param_str(param)}_prob_comb"), bbox_inches="tight")
    plt.close('all')

    plt.figure(figsize=(fig_width, 4.8), dpi=150)
    _ = plotter.plot_mi_bar(best_lines, best_lines_mis, sorted=False, errs=best_lines_sigmas, short_names=False)
    plt.title(f"Highest mutual informations between ${latex_comb_params(param)}$ and lines intensity")
    plt.savefig(os.path.join(env_name, f"{param_str(param)}_mi_comb_bar"), bbox_inches="tight")
    plt.close('all')

    # Distinct lines of the best pairs, in pair order, until at least
    # n_lines_mat are collected. A pair can add two lines, so the result may
    # hold n_lines_mat + 1 lines. Stops cleanly if the best pairs involve fewer
    # distinct lines instead of indexing past the end of `best_lines`.
    list_best_lines = []
    for pair in best_lines:
        if len(list_best_lines) >= n_lines_mat:
            break
        for line in pair:
            if line not in list_best_lines:
                list_best_lines.append(line)

    idx = [lines.index(line) for line in list_best_lines]
    _mis = mis[idx][:, idx]
    _sigmas = sigmas[idx][:, idx]

    plt.figure(figsize=(fig_width, 4.8), dpi=150)
    _ = plotter.plot_mi_matrix(list_best_lines, _mis, short_names=False)
    plt.title(f"Mutual information between ${latex_comb_params(param)}$ and lines intensity")
    plt.savefig(os.path.join(env_name, f"{param_str(param)}_mi_comb_mat"), bbox_inches="tight")
    plt.close('all')

param: P


100%|██████████| 741/741 [00:57<00:00, 12.86it/s]
100%|██████████| 703/703 [00:09<00:00, 75.19it/s]


param: radm


 64%|██████▎   | 472/741 [00:37<00:20, 12.90it/s]