In [None]:
import altair as alt
import numpy as np
from os import listdir
import pandas as pd
import torch

from blase.emulator import SparseLinearEmulator as SLE
from blase.optimizer import default_clean
from blase.utils import doppler_grid
from collections import defaultdict
from functools import reduce
from gollum.phoenix import PHOENIXSpectrum
from re import split
from tqdm import tqdm

# Route Altair's data handling through the VegaFusion data transformer.
# NOTE(review): presumably enabled so the large exploded DataFrames below can be
# charted without hitting Altair's default row limit for inline data — confirm.
alt.data_transformers.enable('vegafusion')

In [None]:
# Directory of per-grid-point state dicts; each file name encodes (teff, logg, Z).
# NOTE(review): absolute, machine-specific path — consider making it configurable.
path = "/home/sujay/data/10K_12.5K_clones"

# Only `.pt` files are loaded (that suffix is stripped when parsing below). They are
# sorted because os.listdir order is arbitrary: a stable row order keeps downstream
# tie-breaking (e.g. which line `value_counts` ranks first) reproducible.
state_files = sorted(f for f in listdir(path) if f.endswith('.pt'))
assert state_files, f'no .pt state files found in {path}'

line_stats = defaultdict(list)
for state_file in tqdm(state_files):
    # Load directly onto the CPU: every tensor is converted to NumPy below, so
    # mapping to 'cuda:0' only added a GPU round-trip and required a CUDA device.
    # torch.load unpickles its input — only point this at trusted files.
    state_dict = torch.load(f'{path}/{state_file}', map_location='cpu')

    # Strip ".pt", then split on the T/G/Z markers. Presumably names look like
    # "T<teff>G<logg>Z<Z>.pt" (TODO confirm), giving ['', teff, logg, Z].
    tokens = split('[TGZ]', state_file[:-3])
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))

    # Per-line parameter arrays for this grid point (presumably one element per line).
    line_stats['center'].append(state_dict['pre_line_centers'].cpu().numpy())
    line_stats['amp'].append(state_dict['amplitudes'].cpu().numpy())
    line_stats['sigma'].append(state_dict['sigma_widths'].cpu().numpy())
    line_stats['gamma'].append(state_dict['gamma_widths'].cpu().numpy())

In [None]:
# Sorted union of every line centre seen across all loaded grid points.
line_set = reduce(np.union1d, line_stats['center'])

# Columns holding one array per grid point; exploded to one row per (grid point, line).
per_line_cols = ['center', 'amp', 'sigma', 'gamma']
df = (
    pd.DataFrame(line_stats)
    .query('Z == 0')
    .explode(per_line_cols)
    # explode leaves these columns as object dtype with a repeated index. Restore
    # the source arrays' own numeric dtype (values and their printed form are
    # unchanged), and give every row a unique label.
    .astype({col: line_stats[col][0].dtype for col in per_line_cols})
    .reset_index(drop=True)
)
df

In [None]:
# Row count per line centre, sorted from most to least frequent.
line_counts = df.value_counts('center')
# Inspect the most frequently occurring line centre in the plots that follow.
current_line = line_counts.index.to_numpy()[0]

In [None]:
# All rows belonging to the selected line centre.
df_heat = df.query('center == @current_line')

# One heatmap per fitted line parameter over the (teff, logg) grid. The three
# panels previously repeated the same chart definition and differed only in the
# colour field.
heat_title = f'Line: {current_line} Angstroms ({len(df_heat)} points)'
x1, x2, x3 = (
    alt.Chart(df_heat)
    .mark_rect()
    .encode(x='teff:O', y='logg:O', color=f'{param}:Q')
    .properties(width=600, height=400, title=heat_title)
    for param in ('amp', 'sigma', 'gamma')
)

# Give each panel its own colour scale, since the three parameters need not share a range.
((x1 | x2 | x3).resolve_scale(color='independent'))

In [None]:
# Parameter surfaces from the 09_surface_fitting experiment (presumably interpolated
# onto a (teff, logg) grid, per the file name — confirm).
df_manifold = pd.read_parquet('../../experiments/09_surface_fitting/surface_visualization_grid_interp.parquet.gz')

# Overrides the `current_line` computed from the clone data; show_manifolds uses it
# for chart titles. NOTE(review): hard-coded — confirm it matches the parquet file's
# line, and beware that re-running earlier plotting cells now picks up this value.
current_line = 8185.51

# Kept as the cell's last expression: a bare name mid-cell is never displayed.
df_manifold

In [None]:
def show_manifolds(data, line=None, width=1000, height=600):
    """Plot amp / sigma / gamma heatmaps over the (teff, logg) grid, side by side.

    Parameters
    ----------
    data : pandas.DataFrame (or other Altair-compatible data)
        Must provide ``teff``, ``logg``, ``amp``, ``sigma`` and ``gamma`` columns.
    line : float, optional
        Line centre (Angstroms) shown in the panel titles. When omitted, the
        notebook-global ``current_line`` is used, matching the original behavior.
    width, height : int, optional
        Size of each individual panel (defaults match the original 1000 x 600).

    Returns
    -------
    altair.HConcatChart
        The three panels, each with an independent colour scale.
    """
    if line is None:
        # Backward compatibility: fall back to the global set earlier in the notebook.
        line = current_line

    title = f'Line: {line} Angstroms ({len(data)} points)'
    x1, x2, x3 = (
        alt.Chart(data)
        .mark_rect()
        .encode(x='teff:O', y='logg:O', color=f'{param}:Q')
        .properties(width=width, height=height, title=title)
        for param in ('amp', 'sigma', 'gamma')
    )

    # Give each parameter its own colour scale rather than one shared legend.
    return (x1 | x2 | x3).resolve_scale(color='independent')

In [None]:
# Render the amp / sigma / gamma surfaces for the loaded manifold grid.
show_manifolds(data=df_manifold)