In [39]:
import altair as alt
import numpy as np
import pandas as pd
import torch

from collections import defaultdict
from pathlib import Path
from re import split
from tqdm import tqdm

# [1] **Line Count Heatmap**

In [40]:
line_stats = defaultdict(list)
for file in tqdm(Path('../../experiments/08_blase3D_HPC_test/emulator_states').glob('*0.0.pt'), total=657):
    state_dict = torch.load(file, 'cuda')
    tokens = split('[TGZ]', file.stem)
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))
    line_stats['center'].append(state_dict['pre_line_centers'].cpu().numpy())
    line_stats['shift_center'].append(state_dict['lam_centers'].cpu().numpy())
    line_stats['amp'].append(state_dict['amplitudes'].cpu().numpy())
    line_stats['sigma'].append(state_dict['sigma_widths'].cpu().numpy())
    line_stats['gamma'].append(state_dict['gamma_widths'].cpu().numpy())
df = pd.DataFrame(line_stats).explode(['center', 'amp', 'sigma', 'gamma', 'shift_center'])
df['jitter'] = df.shift_center - df.center

100%|██████████| 657/657 [00:01<00:00, 392.37it/s]


In [41]:
alt.Chart(df.groupby(['teff', 'logg']).size().reset_index(name='n_lines')).mark_rect()\
    .encode(x=alt.X('teff:O', title='Effective Temperature [K]'), y=alt.Y('logg:O', title='Surface Gravity'), color=alt.Color('n_lines:Q', title='Line Count', scale=alt.Scale(type='log')))\
    .properties(width=1000, height=400)\
    .configure_axis(labelFontSize=15, titleFontSize=24)\
    .configure_legend(labelFontSize=15, titleFontSize=15)

# [2] **PHOENIX Subset Discrete Manifold**

In [45]:
first_line = df.value_counts('center').index[1]
(alt.Chart(df.query('center == @first_line')).mark_rect()\
    .encode(x=alt.X('teff:O', title='Effective Temperature [K]'), y=alt.Y('logg:O', title='Surface Gravity'), color=alt.Color(f'amp:Q', title='Log-Amplitude'))\
    .properties(width=1000, height=400, title=f'Spectral Line at {first_line} Å') &\
alt.Chart(df.query('center == @first_line')).mark_rect()\
    .encode(x=alt.X('teff:O', title='Effective Temperature [K]'), y=alt.Y('logg:O', title='Surface Gravity'), color=alt.Color(f'sigma:Q', title='Gaussian Width'))\
    .properties(width=1000, height=400, title=f'Spectral Line at {first_line} Å'))\
    .configure_axis(labelFontSize=15, titleFontSize=24)\
    .configure_legend(labelFontSize=15, titleFontSize=15)\
    .configure_title(fontSize=25).resolve_scale(color='independent')

In [None]:
for row in df.sort_values(['teff', 'logg', 'Z']).query('center == 11617.66').iterrows():
    print(row)

(107, teff                    2300
logg                     2.0
Z                        0.0
center              11617.66
shift_center    11617.650814
amp                -2.179514
sigma              -3.297426
gamma              -2.201346
jitter             -0.009186
Name: 107, dtype: object)
(526, teff                    2300
logg                     2.5
Z                        0.0
center              11617.66
shift_center    11617.647985
amp                -2.288572
sigma              -3.363384
gamma               -2.33677
jitter             -0.012015
Name: 526, dtype: object)
(156, teff                    2300
logg                     3.0
Z                        0.0
center              11617.66
shift_center    11617.650375
amp                -2.345414
sigma              -3.422233
gamma              -2.472242
jitter             -0.009625
Name: 156, dtype: object)
(426, teff                    2300
logg                     3.5
Z                        0.0
center              11617.66