# Histograms

In [None]:
import pickle

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import dipsy

In [None]:
from dipsy.data import T17_Fmm_corr, T17_Reff_corr
import dipsy.data_frame_functions as ddff

Define constants

In [None]:
# One year in cgs units (presumably seconds). Below it converts times given in
# years (e.g. `3e5 * year`) into the units of the simulations' time arrays.
year = dipsy.cgs_constants.year

# Define some more functions

Test the correlations

In [None]:
# Observational reference dataset (presumably the Tripathi et al. 2017 sample).
# Its `plot_rosotti()` method creates the figure/axes onto which the simulated
# radius-flux tracks are drawn further below.
t = dipsy.data.Tripathi2017()

# Load the data

The old way of loading the data, kept for reference:
```python
with open('df_dict.pickle', 'rb') as fid:
    df = pickle.load(fid)

d = df['smooth_ricci']
```

In [None]:
%%time 
d = pd.DataFrame(dipsy.utils.read_from_hdf5('dustlines_analysis_lam870_q3.5_f68.hdf5'))

In [None]:
# Quick look at the first rows of the loaded table (pandas' default is 5 rows).
d.head(n=5)

# Process parameters

Define which keys are the parameters.

In [None]:
# Keys of `d` that hold the input parameters of each simulation.
param_names = [
    'v_frag',
    'alpha',
    'Mdisk',
    'r_c',
    'M_star',
]

We also define a "nice" (LaTeX) name for each parameter, used for the plot labels.

In [None]:
# LaTeX axis labels, keyed by parameter name.
param_label = dict(
    alpha=r'$\alpha$',
    v_frag=r'$v_\mathsf{frag}$',
    Mdisk=r'$M_\mathsf{disk}$',
    M_star=r'$M_\star$',
    r_c=r'$r_\mathsf{c}$',
)

In [None]:
# Values taken by each parameter in `d`.
# NOTE(review): used below as a mapping name -> sequence of values (`.items()`,
# `len(...)`, `[-2]` indexing); presumably the unique grid values, sorted --
# confirm in dipsy.data_frame_functions.get_param_values.
param_values = ddff.get_param_values(d, param_names)

In [None]:
# Summarize the parameter grid: "name(count):" padded to 15 characters,
# followed by the comma-separated values.
print('we have the following parameters:')
for name, values in param_values.items():
    header = f'{name}({len(values)}):'
    print(header.ljust(15) + ', '.join(map(str, values)))

#### Define the time array

In [None]:
# Time grid of the first simulation; assumed (not checked here) to be shared
# by all rows of `d`.
time = d.iloc[0].time
# Snapshot indices bounding the 0.3 - 3 Myr window used for the filtering.
i0, i1 = time.searchsorted([3e5 * year, 3e6 * year])

For the filtering below: how many sigma around the correlation do we allow?

In [None]:
# Width of the accepted band around the correlation, in units of its sigma
# (passed as `corr=` to the filter and as `sigma=` to the correlation curves).
n_sig = 1

# Check the filter visually

In [None]:
# Inspect one simulation by hand and see whether it passes the filter.
i = 38399  # row position in `d` of the simulation to inspect
row = d.iloc[i]
print(row)
# Kept as the cell's last expression so the notebook displays the result.
ddff.filter_function(
    row,
    i0=i0,
    i1=i1,
    alpha=[0.0001, 0.001],
    M_star=2,
    Mdisk=param_values['Mdisk'][-2],
    corr=n_sig,
)

In [None]:
# Observations plus the correlation shifted by +n_sig and -n_sig sigma
# (dashed), with the selected simulation's track in the i0..i1 window on top.
f, ax = t.plot_rosotti()
ref = np.logspace(1, 2.5, 50)
log_ref = np.log10(ref)
for sign in (1, -1):
    ax.plot(log_ref, np.log10(T17_Fmm_corr(ref, sigma=sign * n_sig)), 'k--')
ax.plot(np.log10(row['rf_t'][i0:i1]), np.log10(row['flux_t'][i0:i1]));

## Now apply it

We keep all simulations whose snapshots between `i0` and `i1` lie within `n_sig` sigma of the correlation.

In [None]:
def _within_correlation(row):
    """Return the filter verdict for one simulation row (snapshots i0..i1, n_sig band)."""
    return ddff.filter_function(row, i0=i0, i1=i1, corr=n_sig)


# Keep only the simulations whose track between snapshots i0 and i1 stays
# within n_sig sigma of the correlation. A named function replaces the former
# `f = lambda ...` (PEP 8 E731), which also overloaded `f`, elsewhere a figure.
res = d[d.apply(_within_correlation, axis=1)]
print(f'found {len(res)} matching simulations ({len(res) / len(d):.1%})')

Plot 5 randomly picked tracks

In [None]:
# Observations plus the correlation shifted by +/- n_sig sigma (dashed lines).
f, ax = t.plot_rosotti()
x = np.logspace(1, 2.5, 50)
ax.plot(np.log10(x), np.log10(T17_Fmm_corr(x, sigma=n_sig)), 'k--')
ax.plot(np.log10(x), np.log10(T17_Fmm_corr(x, sigma=-n_sig)), 'k--')

# Overlay a few randomly chosen matching tracks. np.random.choice samples
# *with* replacement by default, which can draw the same simulation twice;
# sample without replacement and cap the size so this also works when fewer
# than n_tracks simulations passed the filter (no tracks if none did).
n_tracks = 5
for i in np.random.choice(len(res), size=min(n_tracks, len(res)), replace=False):
    row = res.iloc[i]
    ax.plot(np.log10(row['rf_t'][i0:i1]), np.log10(row['flux_t'][i0:i1]))

## 2D Histogram
First, we plot a single 2D histogram.

In [None]:
# 2D histogram of the filtered simulations (`res`) over v_frag and Mdisk,
# saved to a PDF with a transparent background.
f = ddff.histogram2D(res, 'v_frag', 'Mdisk', param_values, param_label=param_label)
f.savefig('histogram2D.pdf', transparent=True, bbox_inches='tight')

## 1D histogram

Next, we collapse it to one dimension. The histogram is normalized by the number of simulations: for every value of the parameter, we count how many simulations there are in total (without applying the correlation filter). This takes a bit of time; for a full grid of 100 000 simulations with 10 values per parameter, each count should simply be 10 000.

In [None]:
# Parameter along which the 1D histogram below is built.
x_name = 'v_frag'

In [None]:
def _count_simulations(value):
    """Count rows of `d` that pass `ddff.filter_function` given only `x_name=value`.

    No correlation filter (`corr`) is passed, so this is the total number of
    simulations with this parameter value, used to normalize the 1D histogram.
    """
    mask = d.apply(lambda row: ddff.filter_function(row, **{x_name: value}), axis=1)
    return len(d[mask])


# One total count per value of `x_name`, in the order of param_values[x_name].
n_sims = [_count_simulations(value) for value in param_values[x_name]]

In [None]:
# 1D histogram of the filtered simulations along `x_name`, normalized by the
# per-value totals in `n_sims`; saved to a PDF with a transparent background.
f = ddff.histogram1d_normalized(res, x_name, param_values, param_label=param_label, n_sims=n_sims)
f.savefig('histogram1D.pdf', transparent=True, bbox_inches='tight')

# Corner plot of histograms

In [None]:
# Corner plot of histograms of the filtered simulations over the parameters in
# `param_values`; saved to a PDF with a transparent background.
f = ddff.histogram_corner(res, param_values, param_label=param_label)
f.savefig('histograms_corner.pdf', transparent=True, bbox_inches='tight')

# Heatmap

In [None]:
for _t in [3e5, 1e6, 3e6]:  # snapshot times in years
    # Index of the first snapshot at or after _t. searchsorted returns
    # len(time) when _t lies beyond the grid; skip such times instead of
    # crashing on `time[i_snap]` further down.
    i_snap = time.searchsorted(_t * year)
    if i_snap >= len(time):
        print(f'skipping t = {_t:.1e} yr: beyond the last snapshot')
        continue
    # Use the same n_sig as the filtering above rather than a hard-coded 1.
    fig, ax = ddff.heatmap(
        d, i_snap, correlation=True, observations=False, n_sig=n_sig,
        cmap='cividis', rasterized=True, vmin=0, vmax=300,
    )
    # Colorbar in a thin axes attached to the right edge of the heatmap axes
    # (presumably ax.collections[0] is the heatmap mappable).
    pos = ax.get_position()
    cax = fig.add_axes([pos.x1, pos.y0, pos.width/20, pos.height])
    plt.colorbar(ax.collections[0], cax=cax)
    # bbox_inches='tight', as for the other saved figures, so the colorbar and
    # its tick labels (outside the original axes) are not clipped.
    fig.savefig(f'heatmap_{time[i_snap] / 1e6 / year:.1f}Myr.pdf', transparent=True, bbox_inches='tight')