In [1]:
import multiprocessing
from pathlib import Path
from copy import deepcopy
from collections import Counter
import functools

import numpy as np

from astropy import units as u
from astropy import constants as cnst
from astropy.io import fits
from astropy import modeling

from specutils import Spectrum1D, manipulation

from astropy.visualization import quantity_support
quantity_support()

from matplotlib import pyplot as plt

from tqdm.auto import tqdm

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


In [3]:
# Directory (relative path) holding the wavelength file and the model spectra read below.
datapath = Path('../fullgrid/')

In [4]:
# Wavelength grid file; the 'UNIT' keyword of its header gives the unit of the array.
wlspath = datapath / 'WAVE_PHOENIX-ACES-AGSS-COND-2011.fits'

wls_unit = u.Unit(fits.getheader(wlspath)['UNIT'])
wlsraw = fits.getdata(wlspath)
# `<<` attaches the unit to the raw array, giving a Quantity.
model_wls = wlsraw << wls_unit

# Show the grid together with its number of points.
model_wls, len(model_wls)

(<Quantity [  500.  ,   500.1 ,   500.2 , ..., 54999.25, 54999.5 , 54999.75] Angstrom>,
 1569128)

In [5]:
# Sort the matches: directory iteration order is arbitrary, and later cells index
# into this list (e.g. `allspecpaths[0]`), so a fixed order keeps runs reproducible.
allspecpaths = sorted(datapath.glob('lte*HiRes.fits'))
len(allspecpaths)

7559

In [6]:
def _bunit_or_none(specfile):
    """Return the 'BUNIT' header value of `specfile`, or None if the keyword is absent."""
    try:
        return fits.getval(specfile, 'BUNIT')
    except KeyError:
        return None

# One entry per path in `allspecpaths`, in the same order.
units = [_bunit_or_none(specfile) for specfile in allspecpaths]

# Tally how many files carry each unit string (None = keyword missing).
Counter(units)

Counter({'erg/s/cm^2/cm': 7508, None: 51})

In [7]:
# Print the primary header of the first file without a unit and of the first
# file with one, in whichever order they are encountered.
printed = set()
for path, unit in zip(allspecpaths, units):
    has_unit = unit is not None
    if has_unit in printed:
        continue
    printed.add(has_unit)
    print(repr(fits.getheader(path, 0)))
    print('')

SIMPLE  =                    T / conforms to FITS standard                      
BITPIX  =                  -32 / array data type                                
NAXIS   =                    1 / number of array dimensions                     
NAXIS1  =              1569128                                                  
EXTEND  =                    T                                                  
EXTNAME = 'PRIMARY '                                                            
WAVE    = '../../WAVE_PHOENIX-ACES-AGSS-COND-2011.fits' / Wavelength array      
PHXTEFF =               2300.0 / [K] effective temperature                      
PHXLOGG =                  0.0 / [cm/s^2] log (surface gravity)                 
PHXM_H  =                  0.5 / [M/H] metallicity (rel. sol. - Asplund &a 2009)
PHXALPHA=                  0.0 / [a/M] alpha element enhancement                
PHXDUST =                    F / Dust in atmosphere                             
PHXEOS  = 'ACES    '        

Some files have no `BUNIT` keyword (these seem to be interpolated models?). Most files do have a unit that makes sense, so we will assume that unit applies to all of them:

In [8]:
# Distinct unit strings, ignoring files whose header has no BUNIT. Filtering here
# (rather than `set.remove(None)`) avoids a KeyError when every file has a unit.
unitss = {unit for unit in units if unit is not None}
assert len(unitss) == 1, f'expected exactly one flux unit, found {sorted(unitss)}'
# Read the single element without emptying the set, so `unitss` stays inspectable.
specunit = u.Unit(next(iter(unitss)))



This wavelength grid roughly covers the NIRSpec data range at ~2x oversampling.

In [9]:
# Target wavelength grid: 16384 evenly spaced points from 9000 to 19500 Angstrom.
data_wls = np.linspace(9000, 19500, 16384) *u.angstrom

# Experiment with saving a single spectrum

In [10]:
# Resampler used below to rebin a model spectrum onto `data_wls`.
resampler = manipulation.FluxConservingResampler()

In [11]:
%%time

# Keep the primary header of the first model file as a plain dict.
header = dict(fits.getheader(allspecpaths[0], ext=0))
# Build the spectrum on the native model grid, attaching the common flux unit.
spec = Spectrum1D(
    flux=fits.getdata(allspecpaths[0], ext=0) << specunit,
    spectral_axis=model_wls,
)
# Rebin onto the target grid.
newspec = resampler(spec, data_wls)

CPU times: user 467 ms, sys: 14.6 ms, total: 482 ms
Wall time: 483 ms


In [12]:
%%time 
# Convert the resampled flux values (unit stripped via `.value`) to a torch tensor.
totensor = torch.tensor(newspec.flux.value)

CPU times: user 250 μs, sys: 892 μs, total: 1.14 ms
Wall time: 654 μs


In [13]:
%%time
# Cast the flux tensor to single precision (timed by the cell magic).
# NOTE(review): `tosingle` is not used afterwards; the save cell writes `totensor`,
# which the loaded output shows as float64 — confirm whether float32 storage was intended.
tosingle = totensor.to(torch.float32)

CPU times: user 117 μs, sys: 17 μs, total: 134 μs
Wall time: 138 μs


In [14]:
%%time

# Write the header, wavelength grid and flux of one spectrum to a test file.
pth = Path('test.pt')
data = dict(
    header=header,
    wl=torch.tensor(data_wls),
    flux=totensor,
)
torch.save(data, pth)

CPU times: user 0 ns, sys: 1.11 ms, total: 1.11 ms
Wall time: 876 μs


In [15]:
# Extrapolate from the single test file: its size in bytes times the number of
# spectra, scaled by 2**-20 (bytes -> MiB).
pth.stat().st_size * len(allspecpaths) *  2**-20  #MB for the full set of spectra

1903.6341953277588

In [16]:
%%time

# Read back the test file through `pth` (the path it was saved to) instead of
# repeating the filename, so the save and load cells cannot drift apart.
torch.load(pth, weights_only=True)

CPU times: user 1.13 ms, sys: 22 μs, total: 1.16 ms
Wall time: 977 μs


{'header': {'SIMPLE': True,
  'BITPIX': -32,
  'NAXIS': 1,
  'NAXIS1': 1569128,
  'EXTEND': True,
  'EXTNAME': 'PRIMARY',
  'WAVE': '../../WAVE_PHOENIX-ACES-AGSS-COND-2011.fits',
  'PHXTEFF': 2300.0,
  'PHXLOGG': 0.0,
  'PHXM_H': 0.5,
  'PHXALPHA': 0.0,
  'PHXDUST': False,
  'PHXEOS': 'ACES',
  'PHXBUILD': '02/Aug/2010',
  'PHXVER': '16.01.00B',
  'DATE': '2012-03-16 15:16:21',
  'PHXXI_L': 0.0,
  'PHXXI_M': 0.0,
  'PHXXI_N': 0.0,
  'PHXMASS': 1.5802e+33,
  'PHXREFF': 10266000000000.0,
  'PHXLUM': 2.1025e+36,
  'PHXMXLEN': 1.0,
  'PHXCONV': False,
  'BUNIT': 'erg/s/cm^2/cm'},
 'wl': tensor([ 9000.0000,  9000.6409,  9001.2818,  ..., 19498.7182, 19499.3591,
         19500.0000], dtype=torch.float64),
 'flux': tensor([3.5468e+12, 4.1099e+12, 2.4918e+12,  ..., 5.1086e+12, 2.5230e+12,
         2.4912e+12], dtype=torch.float64)}

In [17]:
# Base name of the first file: everything before the first '.PHOENIX' in its filename.
allspecpaths[0].name.split('.PHOENIX')[0]

'lte02300-0.00+0.5'

# Writing out the tensors

In [15]:
# Objects used as default arguments by `interpolate_and_save_spectrum` below.
resampler = manipulation.FluxConservingResampler()
data_wls = np.linspace(9000, 19500, 16384) *u.angstrom
outdir = Path('resampled_tensors')
# Create the output directory if it does not already exist.
outdir.mkdir(exist_ok=True)

def interpolate_and_save_spectrum(specpath, data_wls=data_wls, resampler=resampler, outdir=outdir, suffix='_16k'):
    """Resample one model spectrum onto `data_wls` and save it with `torch.save`.

    The native wavelength grid is taken from the module-level `model_wls`.

    Parameters
    ----------
    specpath : pathlib.Path
        FITS file whose extension 0 holds the flux array. Its header must
        contain a 'BUNIT' keyword, which is used as the flux unit.
    data_wls : Quantity
        Target wavelength grid handed to the resampler.
    resampler : callable
        Called as ``resampler(spectrum, data_wls)``; must return an object
        with a ``flux`` Quantity.
    outdir : pathlib.Path
        Directory the output file is written into (it is not created here).
    suffix : str
        Text appended to the output file's stem; a falsy value appends nothing.

    Returns
    -------
    pathlib.Path
        Path of the written ``.pt`` file. The file stores a dict with the keys
        'header' (the FITS header as a plain dict), 'wl' and 'flux' (tensors
        of the bare numeric values, units stripped).

    Raises
    ------
    KeyError
        If the FITS header has no 'BUNIT' keyword.
    """
    header = dict(fits.getheader(specpath, ext=0))
    spec = Spectrum1D(spectral_axis=model_wls, flux=fits.getdata(specpath, ext=0) << u.Unit(header['BUNIT']))
    newspec = resampler(spec, data_wls)

    # Strip units explicitly before handing the arrays to torch.
    fluxtensor = torch.tensor(newspec.flux.value)
    wltensor = torch.tensor(data_wls.value)

    # Everything before the first '.PHOENIX' in the filename identifies the model.
    basename = specpath.name.split('.PHOENIX')[0]

    tosave = {'header': header, 'wl': wltensor, 'flux': fluxtensor}
    pthout = outdir / (basename + '.pt')
    if suffix:
        pthout = pthout.with_stem(pthout.stem + suffix)
    torch.save(tosave, pthout)

    return pthout

In [15]:
# Try the function on a single spectrum before running it over the whole list.
interpolate_and_save_spectrum(allspecpaths[0])

PosixPath('resampled_tensors/lte02300-0.00+0.5.pt')

In [16]:
# Keep only spectra whose header declares a flux unit. `units` (built earlier, one
# entry per path in `allspecpaths`, None when 'BUNIT' is missing) already records
# this, so reuse it instead of re-opening every FITS file.
tointerpolate = [path for path, unit in zip(allspecpaths, units) if unit is not None]
len(tointerpolate), len(allspecpaths)

(7508, 7559)

In [17]:
# Process every selected spectrum with a pool of 16 workers, showing a progress
# bar while the output paths are collected.
with multiprocessing.Pool(16) as pool:
    saved_iter = pool.imap(interpolate_and_save_spectrum, tointerpolate)
    results = [saved for saved in tqdm(saved_iter, total=len(tointerpolate))]
len(results)

  0%|          | 0/7508 [00:00<?, ?it/s]

7508

We also write a 2k (2048-point) version for quicker experiments:

In [24]:
# Coarser target grid: 2048 evenly spaced points over the same 9000-19500 Angstrom range.
data_wls_2k = np.linspace(9000, 19500, 2048) *u.angstrom

# Try the 2k variant on a single spectrum; the '_2k' suffix is appended to the output file's stem.
interpolate_and_save_spectrum(allspecpaths[0], data_wls=data_wls_2k, suffix='_2k')



PosixPath('resampled_tensors/lte02300-0.00+0.5_2k.pt')

In [36]:
# Bind the 2k grid and suffix so the pool can map a one-argument callable.
interpolate_and_save_spectrum_2k = functools.partial(
    interpolate_and_save_spectrum,
    data_wls=data_wls_2k,
    suffix='_2k',
)

# Same parallel run as for the 16k grid, now writing the 2k files.
with multiprocessing.Pool(16) as pool:
    saved_iter_2k = pool.imap(interpolate_and_save_spectrum_2k, tointerpolate)
    results = [saved for saved in tqdm(saved_iter_2k, total=len(tointerpolate))]
len(results)

  0%|          | 0/7508 [00:00<?, ?it/s]













7508