In [1]:
import altair as alt
import numpy as np
from os import listdir
import pandas as pd
import torch

from blase.emulator import SparseLinearEmulator as SLE
from blase.optimizer import default_clean
from blase.utils import doppler_grid
from collections import defaultdict
from functools import reduce
from gollum.phoenix import PHOENIXSpectrum
from re import split
from tqdm import tqdm

alt.data_transformers.enable('vegafusion')

PyTorch 2.1.0 active


DataTransformerRegistry.enable('vegafusion')

In [2]:
path = "/home/sujay/data/10K_12.5K_clones"

line_stats = defaultdict(list)
for state_file in tqdm(listdir(path)):
    state_dict = torch.load(f'{path}/{state_file}', map_location='cuda:0')
    tokens = split('[TGZ]', state_file[:-3])
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))

    line_stats['center'].append(state_dict['pre_line_centers'].cpu().numpy())
    line_stats['amp'].append(state_dict['amplitudes'].cpu().numpy())
    line_stats['sigma'].append(state_dict['sigma_widths'].cpu().numpy())
    line_stats['gamma'].append(state_dict['gamma_widths'].cpu().numpy())

100%|██████████| 1964/1964 [00:07<00:00, 273.16it/s]


In [3]:
line_set = reduce(np.union1d, line_stats['center'])
df = pd.DataFrame(line_stats).query('Z == 0').explode(['center', 'amp', 'sigma', 'gamma'])
df

Unnamed: 0,teff,logg,Z,center,amp,sigma,gamma
3,6100,4.0,0.0,8030.15,-3.761484,-3.378658,-3.206863
3,6100,4.0,0.0,8030.52,-2.764099,-2.931001,-2.763708
3,6100,4.0,0.0,8030.99,-6.879189,-3.301914,-3.130147
3,6100,4.0,0.0,8031.42,-5.298686,-3.317237,-3.146086
3,6100,4.0,0.0,8031.97,-5.949881,-3.399614,-3.228281
...,...,...,...,...,...,...,...
1958,9800,2.0,0.0,12849.66,-5.470564,-3.089381,-2.49134
1958,9800,2.0,0.0,12881.92,-7.101554,-3.072379,-2.9503
1958,9800,2.0,0.0,12900.72,-7.203372,-2.28079,-2.2004
1958,9800,2.0,0.0,12953.38,-6.970805,-2.332786,-2.162303


In [9]:
line_counts = df.value_counts('center')
current_line = line_counts.index[0]

In [10]:
df_heat = df.query('center == @current_line')
x1 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='amp:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')
x2 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='sigma:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')
x3 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='gamma:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')

((x1 | x2 | x3).resolve_scale(color='independent'))

In [11]:
df_manifold = pd.read_parquet('../../experiments/09_surface_fitting/surface_visualization_grid.parquet.gz')
df_manifold
current_line = 8185.51

In [20]:
def show_manifolds(data):
    teff = alt.X('teff:O', title='Effective Temperature [K]', axis=alt.Axis(values=[*range(2200, 7000, 200), *range(7000, 12000, 400)]))
    logg = alt.Y('logg:O', title='Surface Gravity [dex]')

    x1 = alt.Chart(data).mark_rect().encode(x=teff, y=logg, color='amp:Q')\
        .properties(width=600, height=400, title=f'Line: {current_line} Angstroms: Amplitude')
    x2 = alt.Chart(data).mark_rect().encode(x=teff, y=logg, color='sigma:Q')\
        .properties(width=600, height=400, title=f'Line: {current_line} Angstroms: Gaussian Shape')
    x3 = alt.Chart(data).mark_rect().encode(x=teff, y=logg, color='gamma:Q')\
        .properties(width=600, height=400, title=f'Line: {current_line} Angstroms: Lorentzian Shape')

    return ((x1 | x2 | x3).resolve_scale(color='independent')).configure_axis(labelFontSize=16, titleFontSize=16).configure_title(fontSize=20)

In [21]:
show_manifolds(df_manifold)