In [1]:
import altair as alt
import numpy as np
import pandas as pd
import torch

from blase.optimizer import default_clean
from collections import defaultdict
from gollum.phoenix import PHOENIXSpectrum
from pathlib import Path
from re import split
from tqdm import tqdm

# Route chart data through the VegaFusion transformer. Presumably chosen because the
# exploded per-line tables built below are large; Altair's default transformer refuses
# datasets above its row limit.
alt.data_transformers.enable('vegafusion')

PyTorch 2.2.0 active


DataTransformerRegistry.enable('vegafusion')

# [2] **Line Count Heatmap**

In [None]:
# Gather the fitted per-line parameters stored in every emulator state file.
# Stems are split on the letters T, G and Z; tokens[1:4] are read as teff, logg and [Fe/H]
# (presumed stem layout 'T<teff>G<logg>Z<feh>' — TODO confirm against the HPC job output).
STATE_DIR = Path('../../experiments/08_blase3D_HPC_test/emulator_states')
# Column name in `line_stats` -> key of the per-line array in the saved state_dict.
LINE_PARAM_KEYS = {
    'center': 'pre_line_centers',
    'shift_center': 'lam_centers',
    'amp': 'amplitudes',
    'sigma': 'sigma_widths',
    'gamma': 'gamma_widths',
}

# Sorted for a deterministic row order across filesystems; the progress-bar total
# follows the real number of files.
state_files = sorted(STATE_DIR.glob('*'))

line_stats = defaultdict(list)
for file in tqdm(state_files, total=len(state_files)):
    # The tensors are converted to NumPy immediately, so load them straight onto the CPU:
    # same values, and the cell no longer requires a CUDA device.
    # NOTE: torch.load unpickles arbitrary objects — only use it on trusted files.
    state_dict = torch.load(file, map_location='cpu')
    tokens = split('[TGZ]', file.stem)
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))
    for column, key in LINE_PARAM_KEYS.items():
        line_stats[column].append(state_dict[key].numpy())

# One row per (grid point, line): the per-line arrays are exploded in lockstep, then cast
# from object to float so downstream arithmetic and plotting see numeric columns.
df = (
    pd.DataFrame(line_stats)
    .explode(list(LINE_PARAM_KEYS))
    .astype({column: float for column in LINE_PARAM_KEYS})
)
# One row per grid point (stellar parameters only).
df_gp = pd.DataFrame({name: line_stats[name] for name in ('teff', 'logg', 'Z')})
# Offset of each fitted (shifted) line center from its pre-shift center.
df['jitter'] = df.shift_center - df.center

In [None]:
# Number of fitted lines at each (teff, logg, [Fe/H]) grid point, busiest grid points first.
(
    df.groupby(['teff', 'logg', 'Z'])
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=False)
)

In [None]:
# Figure 2: log-scaled heatmap of how many lines were fitted at each (teff, logg) grid point,
# one panel per metallicity. The "Surface Gravity" y-title is deliberately split across the
# two stacked panels ('Gravity' on the top panel, 'Surface' on the bottom one).
solar_counts = df.query('Z == 0').groupby(['teff', 'logg']).size().reset_index(name='n_lines')
metal_poor_counts = df.query('Z == -0.5').groupby(['teff', 'logg']).size().reset_index(name='n_lines')

c1 = (
    alt.Chart(solar_counts, width=1000, height=180, title='[Fe/H] = 0')
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='', axis=alt.Axis(values=list(np.arange(2400, 12001, 200)), labels=False)),
        y=alt.Y('logg:O', title='Gravity', axis=alt.Axis(titleAlign='right')),
        color=alt.Color('n_lines:Q', title='Count', scale=alt.Scale(type='log')),
    )
)
c2 = (
    alt.Chart(metal_poor_counts, width=1000, height=200, title='[Fe/H] = -0.5')
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=list(np.arange(2400, 12001, 400)))),
        y=alt.Y('logg:O', title='Surface', axis=alt.Axis(titleAlign='left')),
        color=alt.Color('n_lines:Q', title='Count', scale=alt.Scale(type='log')).legend(gradientLength=370),
    )
)

figure2 = alt.vconcat(c1, c2, spacing=3)
figure2 = figure2.configure_title(fontSize=27, dy=110)
figure2 = figure2.configure_axis(labelFontSize=18, titleFontSize=27)
figure2 = figure2.configure_legend(labelFontSize=16, titleFontSize=18)
figure2.save('figure2.png', scale_factor=4.0)

# [3, 4] **PHOENIX Subset Discrete Manifold**
Fitted line parameters across the PHOENIX grid for the C I line (Figure 3) and the He II line (Figure 4).

In [None]:
# Figure 3: fitted parameters of the C I 11617.66 Å line across the [Fe/H] = 0 slice of the
# grid, one stacked panel per parameter.
C_I_LINE = 11617.66  # Å, the wavelength quoted in the figure title

# Select the line by wavelength rather than by frequency rank (rank ties among equally
# common lines are broken arbitrarily), so the plotted data always matches the title.
# nanargmin ignores NaN centers that explode() produces for empty line arrays.
centers = df['center'].astype(float).unique()
target_line = centers[np.nanargmin(np.abs(centers - C_I_LINE))]
assert abs(target_line - C_I_LINE) < 0.01, f'no fitted line within 0.01 Å of {C_I_LINE} Å (closest: {target_line})'

LINE_PANELS = [('amp', ['Log-', 'Amplitude']), ('sigma', ['Gaussian', 'Width']),
               ('gamma', ['Lorentzian', 'Width']), ('jitter', ['Line', 'Offset'])]
line_rows = df.query('center == @target_line & Z == 0')

cs = []
for i, (field, legend_title) in enumerate(LINE_PANELS):
    is_bottom = i == len(LINE_PANELS) - 1  # only the bottom panel shows x labels/title
    # "Surface Gravity" is split over panels 1 and 2: 'Gravity' above, 'Surface' below.
    y_title = {1: 'Gravity', 2: 'Surface'}.get(i, '')
    cs.append(
        alt.Chart(line_rows, width=1000, height=180).mark_rect().encode(
            x=alt.X('teff:O', title='Effective Temperature [K]' if is_bottom else '',
                    axis=alt.Axis(values=list(np.arange(2400, 12001, 400)), labels=is_bottom)),
            y=alt.Y('logg:O', title=y_title,
                    axis=alt.Axis(titleAnchor='start' if i == 2 else 'end',
                                  titleAlign='right' if i == 2 else 'left')),
            color=alt.Color(f'{field}:Q', title=legend_title).legend(gradientLength=135),
        )
    )

alt.vconcat(*cs, spacing=3).properties(title=f'C I Line at {C_I_LINE:.2f} Å').resolve_scale(color='independent')\
    .configure_axis(labelFontSize=18, titleFontSize=27).configure_legend(labelFontSize=16, titleFontSize=18)\
    .configure_title(fontSize=27, anchor='middle').save('figure3.png', scale_factor=4.0)

In [None]:
# Figure 4: the same per-parameter panels as figure 3, for the He II 10123.66 Å line,
# restricted to 3400–6800 K. A white layer drawn from every grid point (df_gp) sits under
# each panel so the axes span all grid points in range, even where the line has no data.
target_line = 10123.66
teff_ticks = list(np.arange(3400, 6801, 200))
panel_specs = [('amp', ['Log-', 'Amplitude']), ('sigma', ['Gaussian', 'Width']),
               ('gamma', ['Lorentzian', 'Width']), ('jitter', ['Line', 'Offset'])]

white_bg = (
    alt.Chart(df_gp.query('teff >= 3400 & teff <= 6800'), width=500, height=180)
    .mark_rect()
    .encode(
        x=alt.X('teff:O', axis=alt.Axis(values=teff_ticks)),
        y=alt.Y('logg:O', axis=alt.Axis(values=list(np.arange(2, 6, 0.5)))),
        color=alt.value('white'),
    )
)
he_rows = df.query('center == @target_line & Z == 0')

cs = []
for idx, (field, legend_title) in enumerate(panel_specs):
    bottom_panel = idx == 3
    if idx == 1:
        y_title = 'Gravity'
    elif idx == 2:
        y_title = 'Surface'
    else:
        y_title = ''
    x_enc = alt.X('teff:O', title=('Effective Temperature [K]' if bottom_panel else ''),
                  axis=alt.Axis(values=teff_ticks, labels=bottom_panel))
    y_enc = alt.Y('logg:O', title=y_title,
                  axis=alt.Axis(titleAnchor=('start' if idx == 2 else 'end'),
                                titleAlign=('right' if idx == 2 else 'left')))
    color_enc = alt.Color(f'{field}:Q', title=legend_title).legend(gradientLength=135)
    panel = alt.Chart(he_rows, width=500, height=180).mark_rect().encode(x=x_enc, y=y_enc, color=color_enc)
    cs.append(white_bg + panel)

figure4 = alt.vconcat(*cs, spacing=3).properties(title='He II Line at 10123.66 Å').resolve_scale(color='independent')
figure4 = figure4.configure_axis(labelFontSize=18, titleFontSize=27)
figure4 = figure4.configure_legend(labelFontSize=16, titleFontSize=18)
figure4 = figure4.configure_title(fontSize=27, anchor='middle')
figure4.save('figure4.png', scale_factor=4.0)

# [5, 6] **Spectral Reconstruction Demo**

In [None]:
# Figure 5: the He II 10123.66 Å line in the 5500 K and 5900 K (log g = 4.5) PHOENIX spectra,
# with the line-interpolated reconstruction at T = 5772 K, log g = 4.4374 overlaid (dashed).
lo = 10123.275  # Å; plotting window, also reused by the next cell
hi = 10123.5
s1 = default_clean(PHOENIXSpectrum(teff=5500, logg=4.5))
s2 = default_clean(PHOENIXSpectrum(teff=5900, logg=4.5))
# Label columns are added with .assign: assigning via [] into the frame returned by
# .query() raises pandas' SettingWithCopyWarning.
s3 = (
    pd.read_parquet('../../experiments/10_end_to_end/reconstruction_demo.parquet.gz')
    .query('wavelength <= @hi and wavelength >= @lo')
    .assign(**{'Stellar Parameters': 'T: 5772 K, log(g): 4.4374, [Fe/H]: 0',
               'Type': 'Line Interpolation'})
)


def phoenix_window(spec, label):
    """Tabulate `spec` as wavelength/flux rows tagged with `label`, clipped to [lo, hi] (inclusive)."""
    table = pd.DataFrame({'wavelength': spec.wavelength.value, 'flux': spec.flux.value, 'Stellar Parameters': label})
    return table[table['wavelength'].between(lo, hi)]


c1 = alt.Chart(phoenix_window(s1, 'T: 5500 K, log(g): 4.5, [Fe/H]: 0'), width=400, height=400, title='He II Line at 10123.66 Å')\
    .mark_line(strokeWidth=8)\
    .encode(x=alt.X('wavelength', title='Wavelength [Å]'), y=alt.Y('flux', title='Normalized Flux').scale(zero=False), color='Stellar Parameters')
c2 = alt.Chart(phoenix_window(s2, 'T: 5900 K, log(g): 4.5, [Fe/H]: 0'))\
    .mark_line(strokeWidth=8)\
    .encode(x='wavelength', y='flux', color=alt.Color('Stellar Parameters', scale=alt.Scale(scheme='inferno')))
c3 = alt.Chart(s3).mark_line(strokeWidth=5, strokeDash=[10, 5]).encode(x='wavelength', y='flux', color='Stellar Parameters')
(c1 + c2 + c3).configure_axis(labelFontSize=15, titleFontSize=24).configure_title(fontSize=25)\
    .configure_legend(labelFontSize=16, titleFontSize=16, offset=5, fillColor='white', orient='top-right', labelLimit=0, symbolStrokeWidth=6)\
    .save('figure5.png', scale_factor=4.0)


In [None]:
# Figure 6: naive pixel-by-pixel linear interpolation between the 5700 K and 5800 K PHOENIX
# spectra, compared with the line-interpolated reconstruction `s3` (and window `lo`/`hi`)
# from the previous cell.
T_LO, T_HI, T_TARGET = 5700, 5800, 5772  # K
s4 = default_clean(PHOENIXSpectrum(teff=T_LO, logg=4.5))
s5 = default_clean(PHOENIXSpectrum(teff=T_HI, logg=4.5))
# Linear weight of the hotter spectrum: (5772 - 5700) / (5800 - 5700) = 0.72 (so 0.28 on the
# cooler one). Pixel-wise mixing assumes both cleaned spectra share one wavelength grid —
# TODO confirm default_clean keeps the PHOENIX grid unchanged.
w_hi = (T_TARGET - T_LO) / (T_HI - T_LO)
f_int = s4.flux.value * (1 - w_hi) + s5.flux.value * w_hi
c4 = alt.Chart(pd.DataFrame({'wavelength': s4.wavelength.value, 'flux': f_int, 'Type': 'Pixel Interpolation'}).query('wavelength <= @hi and wavelength >= @lo'),
               width=400, height=400, title='He II Line at 10123.66 Å')\
    .mark_line(strokeWidth=5)\
    .encode(x=alt.X('wavelength', title='Wavelength [Å]'), y=alt.Y('flux', title='Normalized Flux', scale=alt.Scale(zero=False)),
            color=alt.Color('Type', scale=alt.Scale(scheme='inferno')))
c5 = alt.Chart(s3).mark_line(strokeWidth=5).encode(x='wavelength', y='flux', color='Type')
(c4 + c5).configure_axis(labelFontSize=15, titleFontSize=24).configure_title(fontSize=25)\
    .configure_legend(labelFontSize=16, titleFontSize=16, offset=5, fillColor='white', orient='top-right', labelLimit=0, symbolStrokeWidth=6)\
    .save('figure6.png', scale_factor=4.0)

# [7] **PHOENIX Generator Reconstruction Performance**

In [None]:
# Figure 7: hard-coded wall-clock timings for reconstructing N spectra pseudo-parallel vs.
# serially, the per-spectrum times, and the speedup ratio R = Serial / Pseudo-Parallel.
# Uses its own names so the per-line table `df` built earlier is not overwritten and the
# earlier figure cells can still be re-run after this one.
timing = pd.DataFrame({'N': [2, 5, 10, 25, 50, 75, 100],
                       'Pseudo-Parallel': [16.572938919067383, 18.667826108634472, 21.52281529456377, 28.28887704014778, 39.83675589412451, 53.722645066678524, 62.9894732311368],
                       'Serial': [29.681056290864944, 80.29932584613562, 156.88104890286922, 381.16286554932594, 747.9954723417759, 1159.3977541476488, 1518.2086955830455]})
timing['R'] = timing['Serial'] / timing['Pseudo-Parallel']
# Per-spectrum times; the R column is carried along but filtered out before plotting.
timing_per_spectrum = timing.assign(**{'Pseudo-Parallel': timing['Pseudo-Parallel'] / timing.N,
                                       'Serial': timing['Serial'] / timing.N})
timing_long = timing.melt(id_vars='N', var_name='Set', value_name='Time')
per_spectrum_long = timing_per_spectrum.melt(id_vars='N', var_name='Set', value_name='Time')

c1 = alt.Chart(timing_long.query('Set != "R"'), width=400, height=400).mark_line(strokeWidth=6)\
    .encode(x=alt.X('N', title='', axis=alt.Axis(labels=False)),
            y=alt.Y('Time', title='Time [s]', scale=alt.Scale(domain=[16, 1600], type='log')),
            color=alt.Color('Set', title=' Reconstruction Type', scale=alt.Scale(scheme='inferno')))
c2 = alt.Chart(timing_long.query('Set == "R"'), width=400, height=100).mark_line(strokeWidth=6)\
    .encode(x=alt.X('N', title='Number of Reconstructions'), y=alt.Y('Time', title='Speedup'), color=alt.value('crimson'))
c3 = alt.Chart(per_spectrum_long.query('Set != "R"'), width=400, height=400).mark_line(strokeWidth=6)\
    .encode(x=alt.X('N', title='', axis=alt.Axis(labels=False)), y=alt.Y('Time', title='Time per Spectrum [s]'),
            color=alt.Color('Set', title=' Reconstruction Type', scale=alt.Scale(scheme='inferno')))
alt.vconcat(c1, c3, c2, spacing=3).configure_axis(labelFontSize=15, titleFontSize=24)\
    .configure_legend(labelFontSize=16, titleFontSize=16, offset=-400, padding=5, fillColor='white', symbolStrokeWidth=8)\
    .configure_title(fontSize=25).save('figure7.png', scale_factor=4.0)

# [8] **Inference Residuals**

In [151]:
# Load inference results: columns t, g, z and ti, gi, zi (presumably true vs. inferred
# teff, logg, [Fe/H] — TODO confirm against the script that writes out.txt) plus a run time.
# Residuals are inferred-minus-true per parameter; time is divided by 60
# (assumed seconds -> minutes).
df = (
    pd.read_csv('../../out.txt', header=None, names=['t', 'g', 'z', 'ti', 'gi', 'zi', 'time'])
    .assign(
        tr=lambda d: d['ti'] - d['t'],
        gr=lambda d: d['gi'] - d['g'],
        zr=lambda d: d['zi'] - d['z'],
        time=lambda d: d['time'] / 60,
    )
)
df2 = df  # name used by the figure 8 cell
# Mean run time; kept as the last expression so the notebook actually displays it.
df['time'].mean()

In [152]:
# Figure 8: residuals of each inferred parameter (tr, gr, zr from the previous cell) against
# the true effective temperature `t`. In each panel, the line is the mean residual per
# temperature and the error bars span mean ± one standard deviation.
RESIDUAL_PANELS = [('tr', 'Effective Temperature [K]'), ('gr', 'Surface Gravity'), ('zr', '[Fe/H]')]

cs = []
for panel_idx, (resid_col, panel_title) in enumerate(RESIDUAL_PANELS):
    spread = (
        alt.Chart(df2)
        .mark_errorbar(color='purple', extent='stdev', thickness=3)
        .encode(x='t', y=alt.Y(resid_col, title=''))
    )
    x_title = 'Effective Temperature [K]' if panel_idx == 1 else ''  # shared x-title on the middle panel
    y_title = 'Residuals' if panel_idx == 0 else ''                  # y-title only on the left panel
    mean_trend = (
        alt.Chart(df2, width=400, height=400, title=panel_title)
        .mark_line(color='purple', strokeWidth=5)
        .encode(
            x=alt.X('t', title=x_title, scale=alt.Scale(zero=False)),
            y=alt.Y(f'mean({resid_col})', scale=alt.Scale(zero=False), title=y_title,
                    axis=alt.Axis(titleAnchor='middle')),
        )
    )
    cs.append(spread + mean_trend)

figure8 = alt.hconcat(*cs, spacing=3)
figure8 = figure8.configure_axis(labelFontSize=20, titleFontSize=30)
figure8 = figure8.configure_title(fontSize=30)
figure8.save('figure8.png', scale_factor=4.0)