In [None]:
import altair as alt
import numpy as np
import pandas as pd
import torch

from collections import defaultdict
from pathlib import Path
from re import split
from tqdm import tqdm

# Switch Altair's active data transformer to 'vegafusion' for every chart built below.
# NOTE(review): presumably needed because the exploded per-line table is too large for
# Altair's default transformer — confirm against the Altair/VegaFusion docs.
alt.data_transformers.enable('vegafusion')

# [1] **Line Count Heatmap**

In [None]:
# Collect the checkpoint files up front so tqdm can derive the progress-bar total from
# the list length instead of relying on a hard-coded file count.
state_files = list(Path('../../experiments/08_blase3D_HPC_test/emulator_states').glob('*0.0.pt'))

# One list per output column; each loop iteration appends exactly one entry to every list,
# so all columns stay aligned row-for-row.
line_stats = defaultdict(list)
for file in tqdm(state_files):
    # Map the checkpoint directly onto the CPU: the tensors are only converted to numpy
    # arrays below, so a GPU is not required to run this cell.
    state_dict = torch.load(file, map_location='cpu')

    # Splitting the stem on the letters T, G and Z yields the grid coordinates as
    # tokens[1..3] (assumes stems shaped like 'T<teff>G<logg>Z<metallicity>' — TODO confirm).
    tokens = split('[TGZ]', file.stem)
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))

    # Per-line arrays stored in the checkpoint; each becomes one list-valued cell per file.
    line_stats['center'].append(state_dict['pre_line_centers'].numpy())
    line_stats['shift_center'].append(state_dict['lam_centers'].numpy())
    line_stats['amp'].append(state_dict['amplitudes'].numpy())
    line_stats['sigma'].append(state_dict['sigma_widths'].numpy())
    line_stats['gamma'].append(state_dict['gamma_widths'].numpy())

# Explode the array-valued columns together so the frame holds one row per spectral line
# per checkpoint (the scalar teff/logg/Z values are repeated on every exploded row).
df = pd.DataFrame(line_stats).explode(['center', 'amp', 'sigma', 'gamma', 'shift_center'])

# Offset between the 'lam_centers' and 'pre_line_centers' values of each line.
df['jitter'] = df.shift_center - df.center

In [None]:
# Number of rows in `df` (one row per fitted line) for every (teff, logg) pair.
lines_per_model = df.groupby(['teff', 'logg']).size().reset_index(name='n_lines')

# Heatmap of the line count over the temperature / gravity grid, coloured on a log scale.
count_heatmap = (
    alt.Chart(lines_per_model)
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]'),
        y=alt.Y('logg:O', title='Surface Gravity'),
        color=alt.Color('n_lines:Q', title='Count', scale=alt.Scale(type='log')),
    )
    .properties(width=1000, height=400)
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_legend(labelFontSize=15, titleFontSize=15)
)

# Written to disk only; the cell itself displays nothing.
count_heatmap.save('images/figure1.png', scale_factor=4.0)

# [2] **PHOENIX Subset Discrete Manifold**

In [None]:
# Pick the line to showcase from the per-center row counts.
# NOTE(review): index[1] is the *second* entry of value_counts despite the name
# `first_line` — confirm that skipping entry 0 is intentional.
first_line = df.value_counts('center').index[1]

# Pieces shared by both panels: the rows for the chosen line, the x-axis ticks and the title.
line_rows = df.query('center == @first_line')
teff_ticks = list(np.arange(2400, 12001, 200))
panel_title = f'Spectral Line at {first_line} Å'

# Two heatmaps over the same (teff, logg) grid that differ only in the colour-encoded column.
amp_panel, sigma_panel = (
    alt.Chart(line_rows)
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=teff_ticks)),
        y=alt.Y('logg:O', title='Surface Gravity'),
        color=alt.Color(f'{field}:Q', title=legend),
    )
    .properties(width=1000, height=400, title=panel_title)
    for field, legend in (('amp', 'Log-Amplitude'), ('sigma', 'Gaussian Width'))
)

# Stack the panels vertically; independent colour scales keep each legend on its own range.
manifold_figure = (
    (amp_panel & sigma_panel)
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_legend(labelFontSize=15, titleFontSize=15)
    .configure_title(fontSize=25)
    .resolve_scale(color='independent')
)

# Written to disk only; the cell itself displays nothing.
manifold_figure.save('images/figure2.png', scale_factor=4.0)

# [3] **PHOENIX Generator Continuous Manifold**

In [None]:
# Interpolated line parameters on the (teff, logg) grid, read from the end-to-end experiment.
line_grid = pd.read_parquet('../../experiments/10_end_to_end/interpolated_line.parquet.gz')

# Keep only usable grid points. The original cell blanked rows with amp < -100 to NaN and
# then called `df.dropna()` without assigning the result, so those rows were never removed.
# Filtering explicitly and keeping the dropna() result actually drops them, and avoids
# writing NaN into every column of the masked rows.
# NOTE(review): -100 is presumably a sentinel for an effectively absent line — confirm.
df = (
    line_grid.loc[line_grid.amp >= -100]
    .dropna()
    # Round logg to one decimal so grid values collapse onto clean ordinal categories.
    .assign(logg=lambda grid: grid.logg.round(1))
)

# Pieces shared by both panels: explicit axis tick values and the (hard-coded) line title.
teff_ticks = list(np.arange(2400, 12001, 200))
logg_ticks = list(np.arange(2, 6, 0.5))
panel_title = 'Spectral Line at 11617.66 Å'

# Two heatmaps over the same grid that differ only in the colour-encoded column.
amp_panel, sigma_panel = (
    alt.Chart(df)
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=teff_ticks)),
        y=alt.Y('logg:O', title='Surface Gravity', axis=alt.Axis(values=logg_ticks)),
        color=alt.Color(f'{field}:Q', title=legend),
    )
    .properties(width=1000, height=400, title=panel_title)
    for field, legend in (('amp', 'Log-Amplitude'), ('sigma', 'Gaussian Width'))
)

# Stack the panels vertically; independent colour scales keep each legend on its own range.
continuous_figure = (
    (amp_panel & sigma_panel)
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_legend(labelFontSize=15, titleFontSize=15)
    .configure_title(fontSize=25)
    .resolve_scale(color='independent')
)

# Written to disk only; the cell itself displays nothing.
continuous_figure.save('images/figure3.png', scale_factor=4.0)