In [1]:
import altair as alt
import numpy as np
import pandas as pd
import torch

from blase.optimizer import default_clean
from collections import defaultdict
from gollum.phoenix import PHOENIXSpectrum
from pathlib import Path
from re import split
from tqdm import tqdm

# Route chart data through the VegaFusion data transformer.
# NOTE(review): presumably needed because the exploded line table built below is far larger
# than Altair's default row limit — confirm before removing.
alt.data_transformers.enable('vegafusion')

PyTorch 2.2.0 active


DataTransformerRegistry.enable('vegafusion')

# [2] **Line Count Heatmap**

In [2]:
# Gather the per-line parameters stored in every saved emulator state into one long table.
# Each file stem is split on the letters T, G and Z; tokens 1-3 are read as teff, logg and Z.
state_dir = Path('../../experiments/08_blase3D_HPC_test/emulator_states')
# Materialise the listing so tqdm can size the progress bar from the files actually present
# (instead of a hard-coded count) while keeping the directory's own iteration order.
state_files = list(state_dir.glob('*'))
if not state_files:
    raise FileNotFoundError(f'No emulator state files found in {state_dir}')

# DataFrame column -> key of the tensor inside the saved state dict.
tensor_keys = {
    'center': 'pre_line_centers',
    'shift_center': 'lam_centers',
    'amp': 'amplitudes',
    'sigma': 'sigma_widths',
    'gamma': 'gamma_widths',
}

line_stats = defaultdict(list)
for file in tqdm(state_files):
    # Load onto the CPU: every tensor is converted to NumPy right away, so no GPU is needed.
    # NOTE(review): torch.load unpickles the file — only point this at trusted checkpoints.
    state_dict = torch.load(file, map_location='cpu')
    tokens = split('[TGZ]', file.stem)
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))
    for column, key in tensor_keys.items():
        line_stats[column].append(state_dict[key].cpu().numpy())

# One row per emulator state; the array-valued columns are exploded to one row per line.
states_wide = pd.DataFrame(line_stats)
df = states_wide.explode(['center', 'amp', 'sigma', 'gamma', 'shift_center'])
# Grid-point table (one row per loaded state), used as a chart backdrop further down.
df_gp = states_wide[['teff', 'logg', 'Z']]
# Offset between the stored 'lam_centers' and 'pre_line_centers' values of each line.
df['jitter'] = df.shift_center - df.center

100%|██████████| 1314/1314 [00:03<00:00, 367.08it/s]


In [23]:
# Number of line rows per (teff, logg, Z) grid point, largest first.
(
    df.groupby(['teff', 'logg', 'Z'])
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=False)
)

Unnamed: 0,teff,logg,Z,count
5,2300,3.0,0.0,34551
3,2300,2.5,0.0,34538
0,2300,2.0,-0.5,34397
2,2300,2.5,-0.5,34336
7,2300,3.5,0.0,34091
...,...,...,...,...
1276,11600,6.0,-0.5,276
1292,11800,5.5,-0.5,273
1294,11800,6.0,-0.5,270
1310,12000,5.5,-0.5,266


In [16]:
def _line_count_heatmap(selection, heading):
    """Build a log-coloured heatmap of line counts per (teff, logg) cell.

    `selection` is a `DataFrame.query` expression choosing the rows of `df` to count;
    `heading` becomes the chart title.
    """
    counts = df.query(selection).groupby(['teff', 'logg']).size().reset_index(name='n_lines')
    teff_axis = alt.Axis(values=list(np.arange(2400, 12001, 200)))
    return alt.Chart(counts, width=1000, height=200, title=heading).mark_rect().encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=teff_axis),
        y=alt.Y('logg:O', title='Surface Gravity'),
        color=alt.Color('n_lines:Q', title='Count', scale=alt.Scale(type='log')),
    )


# One panel per metallicity value, stacked vertically and written to disk.
c1 = _line_count_heatmap('Z == 0', '[Fe/H] = 0 dex')
c2 = _line_count_heatmap('Z == -0.5', '[Fe/H] = -0.5 dex')
figure2 = (
    (c1 & c2)
    .configure_title(fontSize=25)
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_legend(labelFontSize=15, titleFontSize=15)
)
figure2.save('figure2.png', scale_factor=4.0)

# [3, 4] **PHOENIX Subset Discrete Manifold**
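Figure 3 maps the line parameters (amplitude, widths, center offset) of a C I line across the grid; Figure 4 does the same for a He II line.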

In [18]:
# Pick the line by its rank (index 1) in the per-center row counts and keep its Z == 0 rows.
target_line = df.value_counts('center').index[1]
c_i_rows = df.query('center == @target_line & Z == 0')

# (column in df, colour-legend title) for each stacked heatmap panel.
c_i_panels = (
    ('amp', 'Log-Amplitude'),
    ('sigma', 'Gaussian Width'),
    ('gamma', 'Lorentzian Width'),
    ('jitter', 'Line Center Offset'),
)
cs = [
    alt.Chart(c_i_rows, width=1000, height=200, title=f'C I Line at {target_line} Å')
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=list(np.arange(2400, 12001, 200)))),
        y=alt.Y('logg:O', title='Surface Gravity'),
        color=alt.Color(f'{column}:Q', title=label),
    )
    for column, label in c_i_panels
]

# Each panel keeps its own colour scale because the four quantities are not comparable.
figure3 = (
    alt.vconcat(*cs)
    .resolve_scale(color='independent')
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_axisX(title='Eff')
    .configure_legend(labelFontSize=15, titleFontSize=15)
    .configure_title(fontSize=25)
)
figure3.save('figure3.png', scale_factor=4.0)

In [26]:
# Pick the line by its rank (index 3840) in the per-center row counts.
target_line = df.value_counts('center').index[3840]
he_ii_title = f'He II Line at {target_line} Å'

# Backdrop: a white rect for every row of df_gp, so the axes cover every grid point in that
# table even where the selected line has no rows.
white_bg = alt.Chart(df_gp, width=1000, height=200, title=he_ii_title).mark_rect()\
    .encode(x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=list(np.arange(2400, 12001, 200)))),
            y=alt.Y('logg:O', title='Surface Gravity', axis=alt.Axis(values=list(np.arange(2, 6, 0.5)))),
            color=alt.value('white'))

# One layered heatmap per parameter. The overlay previously carried a 'C I Line' title copied
# from the figure 3 cell, contradicting the He II backdrop; both layers now share one title.
cs = [white_bg + alt.Chart(df.query('center == @target_line & Z == 0'), width=1000, height=200, title=he_ii_title).mark_rect()\
    .encode(x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=list(np.arange(2400, 12001, 200)))),
            y=alt.Y('logg:O', title='Surface Gravity'),
            color=alt.Color(f'{param[0]}:Q', title=param[1])) for param in (('amp', 'Log-Amplitude'), ('sigma', 'Gaussian Width'), ('gamma', 'Lorentzian Width'), ('jitter', 'Line Center Offset'))]

# Independent colour scales: the four quantities are not comparable with each other.
alt.vconcat(*cs).resolve_scale(color='independent').configure_axis(labelFontSize=15, titleFontSize=24).configure_axisX(title='Eff').configure_legend(labelFontSize=15, titleFontSize=15).configure_title(fontSize=25).save('figure4.png', scale_factor=4.0)

In [7]:
# Two-panel variant of the He II figure: amplitude on top, Gaussian width below.
# NOTE(review): unlike the four-panel figure 4 cell, no `Z == 0` filter is applied here, and the
# output goes to 'images/figure4.png' rather than 'figure4.png' — confirm which is intended.
target_line = df.value_counts('center').index[3840]
he_ii_rows = df.query('center == @target_line')

# Backdrop: a white rect for every row of df_gp, drawn underneath each heatmap.
white_bg = (
    alt.Chart(df_gp, width=1000, height=400, title=f'He II Line at {target_line} Å')
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=list(np.arange(2400, 12001, 200)))),
        y=alt.Y('logg:O', title='Surface Gravity', axis=alt.Axis(values=list(np.arange(2, 6, 0.5)))),
        color=alt.value('white'),
    )
)
amp_chart = (
    alt.Chart(he_ii_rows)
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=list(np.arange(2400, 12001, 200)))),
        y=alt.Y('logg:O', title='Surface Gravity'),
        color=alt.Color('amp:Q', title='Log-Amplitude'),
    )
)
sigma_chart = (
    alt.Chart(he_ii_rows)
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=list(np.arange(2400, 12001, 200)))),
        y=alt.Y('logg:O', title='Surface Gravity'),
        color=alt.Color('sigma:Q', title='Gaussian Width'),
    )
)

# `+` binds tighter than `&`: each heatmap is layered on the backdrop, then the two are stacked.
stacked = (white_bg + amp_chart) & (white_bg + sigma_chart)
stacked = stacked.configure_axis(labelFontSize=15, titleFontSize=24)
stacked = stacked.configure_legend(labelFontSize=15, titleFontSize=15)
stacked = stacked.configure_title(fontSize=25).resolve_scale(color='independent')
stacked.save('images/figure4.png', scale_factor=4.0)

# [5] **Spectral Reconstruction Demo**

In [41]:
# Wavelength window [Å] shown in the figure; referenced as @lo / @hi by the queries below.
lo = 8561.9
hi = 8562.3

# Two PHOENIX spectra (5700 K and 5800 K, logg 4.5) passed through default_clean.
s1 = default_clean(PHOENIXSpectrum(teff=5700, logg=4.5))
s2 = default_clean(PHOENIXSpectrum(teff=5800, logg=4.5))
# Reconstruction-demo spectrum read from disk, trimmed to the window and labelled for the legend.
s3 = pd.read_parquet('../../experiments/10_end_to_end/reconstruction_demo.parquet.gz').query('wavelength <= @hi and wavelength >= @lo')
s3['Stellar Parameters'] = 'T: 5772 K, log(g): 4.4374, [Fe/H]: 0.0'

s1_window = pd.DataFrame({
    'wavelength': s1.wavelength.value,
    'flux': s1.flux.value,
    'Stellar Parameters': 'T: 5700 K, log(g): 4.5, [Fe/H]: 0.0',
}).query('wavelength <= @hi and wavelength >= @lo')
s2_window = pd.DataFrame({
    'wavelength': s2.wavelength.value,
    'flux': s2.flux.value,
    'Stellar Parameters': 'T: 5800 K, log(g): 4.5, [Fe/H]: 0.0',
}).query('wavelength <= @hi and wavelength >= @lo')

# The first layer carries the figure size, title and axis titles.
c1 = (
    alt.Chart(s1_window, width=1000, height=400, title='Fe I Line at 8562.09 Å')
    .mark_line(strokeWidth=8)
    .encode(
        x=alt.X('wavelength', title='Wavelength [Å]'),
        y=alt.Y('flux', title='Normalized Flux').scale(zero=False),
        color='Stellar Parameters',
    )
)
c2 = (
    alt.Chart(s2_window)
    .mark_line(strokeWidth=8)
    .encode(
        x='wavelength',
        y='flux',
        color=alt.Color('Stellar Parameters', scale=alt.Scale(scheme='inferno')),
    )
)
# The spectrum read from the parquet file is drawn dashed to set it apart from the two above.
c3 = (
    alt.Chart(s3)
    .mark_line(strokeWidth=8, strokeDash=[10, 4])
    .encode(x='wavelength', y='flux', color='Stellar Parameters')
)

figure5 = (
    (c1 + c2 + c3)
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_title(fontSize=25)
    .configure_legend(
        labelFontSize=16,
        titleFontSize=16,
        padding=5,
        fillColor='white',
        orient='bottom-right',
        labelLimit=0,
        symbolStrokeWidth=8,
    )
)
figure5.save('figure5.png', scale_factor=4.0)

# [6] **PHOENIX Generator Reconstruction Performance**

In [30]:
# Timings [s] for N reconstructions in each mode. Kept under their own names so this cell no
# longer overwrites the notebook-wide line table `df` that the figure 2-4 cells query.
timing_modes = ['Pseudo-Parallel', 'Serial']
timing_wide = pd.DataFrame({'N': [2, 5, 10, 25, 50, 75, 100],
                            'Pseudo-Parallel': [16.572938919067383, 18.667826108634472, 21.52281529456377, 28.28887704014778, 39.83675589412451, 53.722645066678524, 62.9894732311368],
                            'Serial': [29.681056290864944, 80.29932584613562, 156.88104890286922, 381.16286554932594, 747.9954723417759, 1159.3977541476488, 1518.2086955830455]})
# R: ratio of serial to pseudo-parallel time, plotted as the speedup.
timing_wide = timing_wide.assign(R=timing_wide['Serial'] / timing_wide['Pseudo-Parallel'])
# Same table with both timing columns divided by N (time per reconstructed spectrum).
per_spectrum_wide = timing_wide.assign(**{mode: timing_wide[mode] / timing_wide['N'] for mode in timing_modes})

# Long format (N, Set, Time) for Altair; the 'R' rows ride along and are filtered per chart.
timing_long = timing_wide.melt(id_vars='N', var_name='Set', value_name='Time')
per_spectrum_long = per_spectrum_wide.melt(id_vars='N', var_name='Set', value_name='Time')

# Top panel: total time on a log axis.
c1 = alt.Chart(timing_long.query('Set != "R"'), width=400, height=400).mark_line().encode(
    x=alt.X('N', title='', axis=alt.Axis(labels=False)),
    y=alt.Y('Time', title='Time [s]', scale=alt.Scale(domain=[16, 1600], type='log')),
    color=alt.Color('Set', title=' Reconstruction Type'))
# Bottom panel: speedup; the only panel that shows the shared x-axis labels and title.
c2 = alt.Chart(timing_long.query('Set == "R"'), width=400, height=100).mark_line().encode(
    x=alt.X('N', title='Number of Reconstructions'),
    y=alt.Y('Time', title='Speedup'),
    color=alt.value('purple'))
# Middle panel: time per spectrum.
c3 = alt.Chart(per_spectrum_long.query('Set != "R"'), width=400, height=400).mark_line().encode(
    x=alt.X('N', title='', axis=alt.Axis(labels=False)),
    y=alt.Y('Time', title='Time per Spectrum [s]'),
    color=alt.Color('Set', title=' Reconstruction Type'))

alt.vconcat(c1, c3, c2, spacing=5).configure_axis(labelFontSize=15, titleFontSize=24).configure_legend(labelFontSize=16, titleFontSize=16, offset=-400, padding=5, fillColor='white').configure_title(fontSize=25).save('figure6.png', scale_factor=4.0)