In [None]:
import altair as alt
import numpy as np
from os import listdir
import pandas as pd
import torch

from blase.emulator import SparseLinearEmulator as SLE
from blase.optimizer import default_clean
from blase.utils import doppler_grid
from collections import defaultdict
from functools import reduce
from gollum.phoenix import PHOENIXSpectrum
from re import split

In [None]:
# Gather the per-line parameters of every saved emulator state into columns.
# State files are named like "T3700G5.0Z0.0.pt": the numbers after T, G and Z
# are parsed as Teff, log g and Z respectively.
# NOTE(review): other cells load states from 'emulator_states/' relative to the
# CWD; that is presumably the same directory when the notebook runs from
# 08_blase3D_HPC_test — confirm.
state_dir = '../../experiments/08_blase3D_HPC_test/emulator_states'
line_stats = defaultdict(list)
# sorted() makes the row order reproducible (listdir order is arbitrary).
for state_file in sorted(listdir(state_dir)):
    if not state_file.endswith('.pt'):
        continue  # the name parsing below assumes a ".pt" state file
    # Load from the directory that was actually listed; map_location='cpu' lets
    # states containing GPU tensors load on a CPU-only machine.
    state_dict = torch.load(f'{state_dir}/{state_file}', map_location='cpu')
    tokens = split('[TGZ]', state_file[:-3])  # 'T3700G5.0Z0.0' -> ['', '3700', '5.0', '0.0']
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))

    # Per-line parameter arrays; this order fixes the downstream column order.
    for column, key in (
        ('center', 'pre_line_centers'),
        ('shift_center', 'lam_centers'),
        ('amp', 'amplitudes'),
        ('sigma', 'sigma_widths'),
        ('gamma', 'gamma_widths'),
    ):
        line_stats[column].append(state_dict[key].cpu().numpy())

In [None]:
# Union of the line centres found in any emulator state (np.union1d yields
# sorted unique values when more than one state is combined).
line_set = reduce(lambda seen, centers: np.union1d(seen, centers), line_stats['center'])
# One row per emulator state; the per-line columns still hold whole arrays here.
df = pd.DataFrame(data=line_stats)
df

In [None]:
# Expand the per-line arrays so each row is one (stellar model, spectral line) pair.
# explode() leaves the exploded columns with dtype=object; cast them back to floats
# so arithmetic, query() and Altair's quantitative (:Q) encodings see real numbers.
line_columns = ['center', 'amp', 'sigma', 'gamma', 'shift_center']
df = df.explode(line_columns).astype({column: float for column in line_columns})
df

In [None]:
# Keep only Z == 0 models and record how far each line centre moved between
# 'center' (pre_line_centers) and 'shift_center' (lam_centers), in the same units.
# assign() builds a new frame instead of writing into query()'s result, which
# avoids pandas' SettingWithCopyWarning.
df_solar = df.query('Z == 0').assign(jitter=lambda d: (d['center'] - d['shift_center']).abs())
df_solar

In [None]:
# How many solar-metallicity rows share each line centre, most frequent first.
counts = df_solar.value_counts(subset='center', sort=True, ascending=False)

In [None]:
# Pick one line by how often its centre occurs across the solar-metallicity grid
# and map each of its fitted parameters over (Teff, log g).
LINE_RANK = 20  # position in `counts` (0 = most frequently occurring line centre)
if LINE_RANK >= len(counts):
    raise IndexError(f'LINE_RANK={LINE_RANK} but only {len(counts)} distinct line centres are available')
current_line = counts.index[LINE_RANK]
df_heat = df_solar.query('center == @current_line')
heat_title = f'Line: {current_line} Angstroms ({len(df_heat)} points)'


def _param_heatmap(data, column, title):
    """Return a Teff x log g heatmap of `data`, coloured by the quantitative `column`."""
    return (
        alt.Chart(data)
        .mark_rect()
        .encode(x='teff:O', y='logg:O', color=f'{column}:Q')
        .properties(width=600, height=400, title=title)
    )


x1, x2, x3, x4 = (_param_heatmap(df_heat, column, heat_title) for column in ('amp', 'sigma', 'gamma', 'jitter'))
# Independent colour scales: each parameter has its own range.
(x1 | x2).resolve_scale(color='independent') & (x3 | x4).resolve_scale(color='independent')

In [None]:
# Count the fitted lines of each (Teff, log g) model and show the counts as a heatmap.
df_n_lines = (
    df_solar
    .groupby(['teff', 'logg'])
    .size()
    .rename('n_lines')
    .reset_index()
)
(
    alt.Chart(df_n_lines)
    .mark_rect()
    .encode(x='teff:O', y='logg:O', color='n_lines:Q')
    .properties(width=800, height=400, title=f'Number of lines ({len(df_solar)} points)')
)

In [None]:
# Overlay PHOENIX spectra (Altair's default colour) and the emulator reconstruction
# (orange) within `line_spread` of `current_line` for three contrasting models.
lo, hi = 11000-60, 11180+60  # NOTE(review): 60-unit padding around 11000-11180; confirm intended band
line_spread = 7  # half-width of the plotted window, in the same units as the wavelength grid
wl_grid = doppler_grid(lo, hi)


def _near_line(wl, flux, center, spread):
    """Return an (x, y) frame of the samples satisfying |center - x| <= spread."""
    frame = pd.DataFrame({'x': wl, 'y': flux})
    return frame.query('abs(@center - x) <= @spread')


def _compare_panel(teff, logg):
    """Build the PHOENIX-vs-emulator overlay for one (teff, logg) model.

    `logg` is passed to PHOENIXSpectrum unchanged and is formatted with one
    decimal place to locate the matching Z0.0 emulator state file.

    Returns:
        (spec, df_spec, opt, df_opt, chart): the cleaned PHOENIX spectrum, its
        windowed frame, the emulator, its windowed frame and the layered chart.
    """
    spec = default_clean(PHOENIXSpectrum(teff=teff, logg=logg, wl_lo=lo, wl_hi=hi))
    df_spec = _near_line(spec.wavelength.value, spec.flux.value, current_line, line_spread)

    # map_location='cpu' lets states containing GPU tensors load for this CPU emulator.
    state = torch.load(f'emulator_states/T{teff}G{float(logg):.1f}Z0.0.pt', map_location='cpu')
    # NOTE(review): wing_cut_pixels presumably should match the value used when fitting; confirm.
    opt = SLE(wl_native=wl_grid, wing_cut_pixels=6000, init_state_dict=state, device='cpu')
    with torch.no_grad():  # forward pass for plotting only; no autograd graph needed
        model_flux = opt.forward()
    df_opt = _near_line(opt.wl_native, model_flux.detach().cpu().numpy(), current_line, line_spread)

    title = f'T{teff}:G{logg}'
    # The same fixed y-range in every panel keeps the three models directly comparable.
    data_layer = (
        alt.Chart(df_spec)
        .mark_line()
        .encode(x='x:Q', y=alt.Y('y:Q', scale=alt.Scale(domain=(0.8, 1.025))))
        .properties(title=title)
    )
    model_layer = alt.Chart(df_opt).mark_line(color='orange').encode(x='x:Q', y='y:Q').properties(title=title)
    return spec, df_spec, opt, df_opt, data_layer + model_layer


g1 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='amp:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')
spec1, df_spec1, opt1, df_opt1, c1 = _compare_panel(3700, 5)
spec2, df_spec2, opt2, df_opt2, c2 = _compare_panel(6200, 6)
spec3, df_spec3, opt3, df_opt3, c3 = _compare_panel(8600, 2.5)

g1 & (c1 | c2 | c3)