In [2]:
import altair as alt
import numpy as np
from os import listdir
import pandas as pd
import torch

from blase.emulator import SparseLinearEmulator as SLE
from blase.optimizer import default_clean
from blase.utils import doppler_grid
from collections import defaultdict
from functools import reduce
from gollum.phoenix import PHOENIXSpectrum
from re import split

PyTorch 2.0.1 active


In [3]:
line_stats = defaultdict(list)
for state_file in listdir('emulator_states'):
    state_dict = torch.load(f'emulator_states/{state_file}')
    tokens = split('[TGZ]', state_file[:-3])
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))

    line_stats['center'].append(state_dict['pre_line_centers'].cpu().numpy())
    line_stats['shift_center'].append(state_dict['lam_centers'].cpu().numpy())
    line_stats['amp'].append(state_dict['amplitudes'].cpu().numpy())
    line_stats['sigma'].append(state_dict['sigma_widths'].cpu().numpy())
    line_stats['gamma'].append(state_dict['gamma_widths'].cpu().numpy())

In [8]:
line_set = reduce(np.union1d, line_stats['center'])
df = pd.DataFrame(line_stats)
df
len(line_set)

7499

In [5]:
df = df.explode(['center', 'amp', 'sigma', 'gamma', 'shift_center'])
df

Unnamed: 0,teff,logg,Z,center,shift_center,amp,sigma,gamma
0,2300,2.0,-0.5,10895.12,10895.123052,-3.857617,-3.780783,-3.492216
0,2300,2.0,-0.5,10895.26,10895.262,-4.199675,-3.873594,-3.622967
0,2300,2.0,-0.5,10895.44,10895.441643,-4.678689,-3.604816,-3.3678
0,2300,2.0,-0.5,10895.58,10895.582443,-3.896488,-3.99923,-3.733276
0,2300,2.0,-0.5,10895.66,10895.666839,-3.200734,-3.662038,-3.258528
...,...,...,...,...,...,...,...,...
396,3700,5.0,-0.5,11302.98,11302.976732,-2.967514,-3.238581,-2.935286
396,3700,5.0,-0.5,11303.6,11303.598489,-3.423583,-3.503259,-3.267775
396,3700,5.0,-0.5,11303.72,11303.71992,-5.831678,-4.017862,-3.842542
396,3700,5.0,-0.5,11303.94,11303.939779,-5.195273,-3.623826,-3.442364


In [6]:
df_solar = df.query('Z == 0')
df_solar['jitter'] = abs(df_solar.center - df_solar.shift_center)
counts = df_solar.value_counts('center')
df_solar

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_solar['jitter'] = abs(df_solar.center - df_solar.shift_center)


Unnamed: 0,teff,logg,Z,center,shift_center,amp,sigma,gamma,jitter
1,2300,2.0,0.0,10895.12,10895.12808,-3.393281,-3.694007,-3.246374,0.00808
1,2300,2.0,0.0,10895.26,10895.265955,-3.676929,-3.72983,-3.344152,0.005955
1,2300,2.0,0.0,10895.46,10895.462622,-4.62411,-3.622986,-3.348344,0.002622
1,2300,2.0,0.0,10895.58,10895.584836,-3.673342,-3.966864,-3.617482,0.004836
1,2300,2.0,0.0,10895.66,10895.676867,-2.780189,-3.593445,-2.951067,0.016867
...,...,...,...,...,...,...,...,...,...
394,3700,4.5,0.0,11303.28,11303.279926,-6.776049,-2.971076,-2.795911,0.000074
394,3700,4.5,0.0,11303.6,11303.598641,-3.357307,-3.457116,-3.225527,0.001359
394,3700,4.5,0.0,11303.72,11303.719942,-5.916493,-4.043244,-3.868762,0.000058
394,3700,4.5,0.0,11303.94,11303.939818,-5.186052,-3.615989,-3.435862,0.000182


In [7]:
current_line = counts.index[3]
df_heat = df_solar.query('center == @current_line')
x1 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='amp:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')
x2 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='sigma:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')
x3 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='gamma:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')
x4 = alt.Chart(df_heat).mark_rect().encode(x='teff:O', y='logg:O', color='jitter:Q')\
    .properties(width=600, height=400, title=f'Line: {current_line} Angstroms ({len(df_heat)} points)')
((x1 | x2).resolve_scale(color='independent') & (x3 | x4).resolve_scale(color='independent'))