In [None]:
import altair as alt
import numpy as np
from os import listdir
import pandas as pd
import torch

from collections import defaultdict
from functools import reduce
from re import split

In [None]:
# Build one record per emulator state file: stellar parameters parsed from the file name,
# plus that model's per-line parameter arrays (one numpy array per key).
# NOTE(review): file names appear to follow "T<teff>G<logg>Z<Z>" plus a 3-character
# extension (inferred from the '[TGZ]' split and the [:-3] slice below) — TODO confirm.
line_stats = defaultdict(list)
# os.listdir returns entries in arbitrary order; sorting makes the row order of every
# downstream frame reproducible across machines and runs.
for state_file in sorted(listdir('emulator_states')):
    # map_location='cpu' lets states saved from GPU tensors load on CPU-only machines
    # (the .cpu() calls below already expect tensors that may live on another device).
    # torch.load unpickles the file, so only point this at trusted state files.
    state_dict = torch.load(f'emulator_states/{state_file}', map_location='cpu')
    # tokens[0] is whatever precedes the first 'T' (presumably empty); then teff, logg, Z.
    tokens = split('[TGZ]', state_file[:-3])
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))

    line_stats['center'].append(state_dict['pre_line_centers'].cpu().numpy())
    line_stats['shift_center'].append(state_dict['lam_centers'].cpu().numpy())
    # Sign flipped relative to the stored 'amplitudes' (presumably so absorption
    # strengths come out positive — confirm against the emulator code).
    line_stats['amp'].append(-state_dict['amplitudes'].cpu().numpy())
    line_stats['sigma'].append(state_dict['sigma_widths'].cpu().numpy())
    line_stats['gamma'].append(state_dict['gamma_widths'].cpu().numpy())

In [None]:
# Sorted, de-duplicated set of every line center seen in any model.
# np.unique over the flattened concatenation is what a chain of np.union1d calls computes.
# It also sorts/de-duplicates when only one state file was loaded (reduce would return
# that single array untouched), and it yields an empty array rather than raising when
# no state files were found.
all_centers = line_stats['center']
line_set = np.unique(np.concatenate(all_centers, axis=None)) if all_centers else np.array([])

# One row per model; the per-line columns hold whole numpy arrays at this stage.
df = pd.DataFrame(line_stats)
df

In [None]:
# Unpack the per-model arrays in lockstep so each row is a single (model, line) pair.
# Exploded columns come out with dtype object, and every line keeps its model's
# original index label, so labels repeat.
per_line_cols = ['center', 'amp', 'sigma', 'gamma', 'shift_center']
df = df.explode(column=per_line_cols)
df

In [None]:
# Keep only the Z == 0 models (presumably solar metallicity) and add `jitter`: the
# difference between 'pre_line_centers' (center) and 'lam_centers' (shift_center).
# .assign builds a new frame. Setting a column directly on the .query() result is
# chained assignment and triggers pandas' SettingWithCopyWarning.
df_solar = df.query('Z == 0').assign(jitter=lambda d: d.center - d.shift_center)
df_solar

In [None]:
# Number of df_solar rows carrying each line center, most frequent first.
# counts.index[0] is therefore the line present in the most rows. This replaces an
# earlier argmax over `line_set`, which has been removed.
# NOTE(review): the two lists below look like hand-picked positions in `counts`
# (compare counts.index[78] in the next cell) — TODO record what distinguishes them.
# [2, 5, 7, 11, 15, 18, 23, 33, 35, 39, 40]
# [17, 20, 24, 31]
counts = df_solar.value_counts('center')

In [None]:
# Heatmaps of one line's fitted parameters over the (teff, logg) grid of Z == 0 models.
# LINE_RANK selects the line by its position in `counts` (0 = line found in the most rows).
LINE_RANK = 78
assert LINE_RANK < len(counts), f'only {len(counts)} distinct line centers available'
current_line = counts.index[LINE_RANK]
df_heat = df_solar.query('center == @current_line')
heat_title = f'Line: {current_line} Angstroms ({len(df_heat)} points)'


def line_param_heatmap(data, field, title):
    """Return a 400x400 teff-vs-logg rect heatmap of `data`, coloured by quantitative `field`."""
    return alt.Chart(data).mark_rect().encode(x='teff:O', y='logg:O', color=f'{field}:Q')\
        .properties(width=400, height=400, title=title)


# One panel per parameter, laid out 2x2; each row gets its own independent colour scale.
x1, x2, x3, x4 = (line_param_heatmap(df_heat, field, heat_title)
                  for field in ('amp', 'sigma', 'gamma', 'jitter'))
((x1 | x2).resolve_scale(color='independent') & (x3 | x4).resolve_scale(color='independent'))

In [None]:
# How many lines each Z == 0 model carries, drawn as a teff-vs-logg heatmap.
# NOTE(review): the "(N points)" in the title counts rows of df_solar (total line
# measurements), not the number of grid cells plotted — confirm that is intended.
df_n_lines = (
    df_solar
    .groupby(['teff', 'logg'])
    .size()
    .reset_index(name='n_lines')
)
n_lines_chart = (
    alt.Chart(df_n_lines)
    .mark_rect()
    .encode(x='teff:O', y='logg:O', color='n_lines:Q')
    .properties(width=400, height=400, title=f'Number of lines ({len(df_solar)} points)')
)
n_lines_chart