In [None]:
from collections import defaultdict
from functools import reduce
import os
from os import listdir
from re import split

import altair as alt
import numpy as np
import pandas as pd
import torch
from gollum.phoenix import PHOENIXSpectrum

from blase.emulator import SparseLinearEmulator as SLE
from blase.optimizer import default_clean
from blase.utils import doppler_grid

In [None]:
from functools import partial
from pickle import dump, load
import numpy as np

def x(x1, x2):
    """Pack two values into a length-2 NumPy array, in the order ``[x1, x2]``."""
    pair = (x1, x2)
    return np.array(pair)

# Three partials of `x` with the first element fixed to 0.5, 0.6 and 0.7.
z = [partial(x, x1=v) for v in (0.5, 0.6, 0.7)]

# Round-trip the partials through pickle. If z.pkl is missing, write it first so the
# cell also runs on a fresh kernel/checkout; an existing file is loaded unchanged.
if not os.path.exists('z.pkl'):
    with open('z.pkl', 'wb') as fh:
        dump(z, fh)
# NOTE(review): pickle.load can execute arbitrary code -- only load a z.pkl you created.
with open('z.pkl', 'rb') as fh:
    z1 = load(fh)
z1[2](x2=4)

In [None]:
# Directory of saved state dicts, one file per model. Set BLASE_CLONES_DIR to read a
# different directory without editing the notebook; the default is the original path.
path = os.environ.get('BLASE_CLONES_DIR', "/home/sujay/data/10K_12.5K_clones")

# Output column -> state-dict key for the per-line arrays (insertion order fixes the
# DataFrame column order built in a later cell).
LINE_PARAM_KEYS = {
    'center': 'pre_line_centers',
    'shift_center': 'lam_centers',
    'amp': 'amplitudes',
    'sigma': 'sigma_widths',
    'gamma': 'gamma_widths',
}

# Column name -> list with one entry per state file: scalars for the model parameters,
# NumPy arrays for the per-line quantities.
line_stats = defaultdict(list)
# os.listdir order is arbitrary; sort for a reproducible row order.
for state_file in sorted(listdir(path)):
    # Load straight onto the CPU: every tensor is converted to NumPy below, so routing
    # through 'cuda:0' only made the notebook require a GPU.
    state_dict = torch.load(os.path.join(path, state_file), map_location='cpu')
    # Drop the extension, then split the stem on the letters T, G and Z; tokens 1-3 are
    # read as teff, logg and Z. Presumably names look like "T<teff>G<logg>Z<Z>.pt" -- confirm.
    tokens = split('[TGZ]', os.path.splitext(state_file)[0])
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))

    for column, key in LINE_PARAM_KEYS.items():
        line_stats[column].append(state_dict[key].cpu().numpy())

# Fail here with a clear message instead of a cryptic error in the next cell.
if not line_stats:
    raise ValueError(f'No state files found in {path}')

In [None]:
# Sorted, de-duplicated union of every line center seen in any model. For two or more
# files this equals reduce(np.union1d, ...), but it sorts once rather than once per file.
# It also sorts/de-duplicates the single-file case, which reduce() returned untouched.
line_set = np.unique(np.concatenate([np.ravel(c) for c in line_stats['center']]))
# One row per model (the per-line columns hold arrays); keep solar metallicity only.
df = pd.DataFrame(line_stats).query('Z == 0')
df

In [None]:
# One row per (model, line): the per-model array columns are exploded in lockstep.
# ignore_index=True gives each detection its own row label; otherwise every line of a
# model repeats that model's label, making label-based indexing (.loc, alignment) ambiguous.
df = df.explode(['center', 'amp', 'sigma', 'gamma', 'shift_center'], ignore_index=True)
df

In [None]:
# Solar-metallicity detections plus "jitter" = |center - shift_center| per line.
# .assign returns a new frame, so (unlike writing a column into a query() result)
# this raises no SettingWithCopyWarning.
df_solar = df.query('Z == 0').assign(jitter=lambda d: (d.center - d.shift_center).abs())
# Rows per line center, most frequent first (value_counts sorts descending by default).
counts = df_solar.value_counts('center')
# NOTE(review): presumably enabled because Altair's default transformer rejects large
# datasets (>5000 rows) -- the exploded table is plotted below.
alt.data_transformers.enable('vegafusion')


In [None]:

# Detections per line center in rank order: order 0 is the most frequently detected line.
# rename('count') pins the column name; pandas < 2 leaves the value_counts Series
# unnamed, which would produce a column named 0 and break the 'count:Q' encoding.
df_hist = counts.rename('count').reset_index()
df_hist['order'] = df_hist.index
# For a log-scaled y axis use alt.Y('count:Q', scale=alt.Scale(type='log')).
alt.Chart(df_hist).mark_bar().encode(
    x=alt.X('order:Q', title='Line rank (most to least detected)'),
    y=alt.Y('count:Q', title='Number of detections'),
)

In [None]:
# Rank (position in `counts`, most frequently detected first) of the next line the
# browsing cell will plot; that cell advances it on every run. Re-run this to restart.
i = 0

In [None]:
# Browse lines one at a time: plot the line at rank `i` in `counts`, then advance `i`
# so the next run shows the next line. Intentionally not idempotent; reset `i` to restart.
current_line = counts.index[i]
i += 1
df_heat = df_solar.query('center == @current_line')

# Ticks every 200 K below 7000 K and every 400 K from 7000 K up.
teff = alt.X('teff:O', title='Effective Temperature [K]',
             axis=alt.Axis(values=[*range(2200, 7000, 200), *range(7000, 12000, 400)]))
logg = alt.Y('logg:O', title='Surface Gravity [dex]')


def _line_heatmap(data, field, title, x_enc, y_enc):
    """Return a 600x400 rect heatmap of `data` on (x_enc, y_enc), coloured by column `field`."""
    return (
        alt.Chart(data)
        .mark_rect()
        .encode(x=x_enc, y=y_enc, color=f'{field}:Q')
        .properties(width=600, height=400, title=title)
    )


# (column, title suffix) for each panel, in reading order: top-left, top-right, bottom row.
_panels = [
    ('amp', 'Amplitude'),
    ('sigma', 'Gaussian Shape'),
    ('gamma', 'Lorentzian Shape'),
    ('jitter', 'Jitter'),
]
x1, x2, x3, x4 = [
    _line_heatmap(df_heat, field, f'Line at {current_line} Angstroms: {label}', teff, logg)
    for field, label in _panels
]
# Each row gets independent colour scales so every parameter uses its full colour range.
(
    (x1 | x2).resolve_scale(color='independent')
    & (x3 | x4).resolve_scale(color='independent')
).configure_axis(labelFontSize=16, titleFontSize=16).configure_title(fontSize=20)

In [None]:
# Rows of df_solar per (teff, logg) model, plotted on a log colour scale.
df_n_lines = (
    df_solar
    .groupby(['teff', 'logg'])
    .size()
    .reset_index(name='n_lines')
)
(
    alt.Chart(df_n_lines)
    .mark_rect()
    .encode(
        x=alt.X(
            'teff:O',
            title='Effective Temperature [K]',
            # Ticks every 200 K below 7000 K and every 400 K from 7000 K up.
            axis=alt.Axis(values=[*range(2200, 7000, 200), *range(7000, 12000, 400)]),
        ),
        y=alt.Y('logg:O', title='Surface Gravity [dex]'),
        color=alt.Color('n_lines:Q', title='# Lines', scale=alt.Scale(type='log')),
    )
    .properties(
        width=800,
        height=400,
        # len(df_solar) is the total number of rows across all models.
        title=f'Number of Lines: {len(df_solar)} Total Line Detections',
    )
    .configure_axis(labelFontSize=20, titleFontSize=20)
    .configure_title(fontSize=25)
)