# Matching SpenderQ output to Picca "delta_extraction" format

This notebook takes SpenderQ outputs and reformats them to match the "delta files" format produced by picca's "delta_extraction", making them compatible with picca for computing correlation functions. [2024 October 25]

In [15]:
import pickle
import numpy as np
import glob
import torch
import math

In [16]:
from astropy.io import fits
from astropy.table import Table

In [17]:
from spenderq import load_model
from spenderq import util as U
from spenderq import lyalpha as LyA

In [18]:
import matplotlib as mpl
import matplotlib.pyplot as plt

## Add paths, define variables, load catalogue

In [19]:
# Quasar catalogue (FITS file) of the production of interest.
# TARGETID, Z, RA and DEC are read from its first extension further down.

## EDR
#quasar_catalogue = '/global/cfs/projectdirs/desi/users/sgontcho/lya/spender-lya/QSO_cat_EDR_n_M2_main_dark_healpix_BAL_n_DLA_cuts.fits'

## Y1 LONDON MOCKS
quasar_catalogue = '/global/cfs/cdirs/desicollab/mocks/lya_forest/develop/london/qq_desi/v9.0_Y1/v9.0.9.9.9/desi-4.124-4-prod/zcat.fits'

In [20]:
# SpenderQ output directory and file prefix; these must belong to the same
# production as the quasar catalogue. `outpath` is where the delta files go.

path = '/global/cfs/projectdirs/desi/users/chahah/spender_qso'
#outpath = '/global/cfs/projectdirs/desi/users/sgontcho/lya/spender-lya/spenderq-to-deltas'
outpath = '/global/cfs/projectdirs/desi/users/sgontcho/lya/spender-lya/iron_comparison/spenderq_prod_20241126_v0/Delta'

## Y1 LONDON MOCKS
file_prefix = 'spenderq_london_v0/DESIlondon_highz.rebin.iter3'
# every pickle file of this production that matches "<prefix>_*.pkl"
spender_output_files = glob.glob(f'{path}/{file_prefix}_*.pkl')

## EDR
#file_prefix = 'DESI.edr.qso_highz'
#spender_output_files = glob.glob('/global/cfs/projectdirs/desi/users/chahah/spender_qso/DESIedr.qso_highz_*.pkl')

In [21]:
# Observed-frame LAMBDA range defined in picca (lower and upper bound)
picca_lambda_min, picca_lambda_max = 3600., 5772.

# Rest-frame FOREST range (default is LyA)
forest_lambda_min, forest_lambda_max = 1040., 1205.

In [22]:
# Open the quasar catalogue; its first extension is read in the next cell.
# NOTE(review): this handle is never explicitly closed, and the name `hdul`
# is rebound later when the delta files are written.
hdul = fits.open(quasar_catalogue)

In [23]:
# Pull the columns needed for the delta-file metadata out of extension 1
# of the quasar catalogue.
catalogue_TID, catalogue_Z, catalogue_RA, catalogue_DEC = (
    hdul[1].data[column_name]
    for column_name in ('TARGETID', 'Z', 'RA', 'DEC')
)

##EDR
#catalogue_RA = hdul[1].data['TARGET_RA']
#catalogue_DEC = hdul[1].data['TARGET_DEC']

In [24]:
# Wavelength grid written to the LAMBDA HDU of the delta files: steps of 0.8
# starting at picca_lambda_min, with the stop value set one step past
# picca_lambda_max.
# NOTE(review): np.arange with a non-integer step can gain or lose the last
# element through floating-point rounding — confirm len(LAMBDA) matches the
# number of pixels selected by `desi_wave <= LAMBDA[-1]` further down.
LAMBDA = np.arange(picca_lambda_min,picca_lambda_max+0.8,0.8)

## Grab values from SpenderQ files

In [25]:
# Number of SpenderQ batches to read.
# NOTE(review): hard-coded; len(spender_output_files) is the data-driven
# alternative that was left commented out in the original loop.
n_batches = 50

# Per-batch accumulators, concatenated into single arrays after the loop.
normalised_flux, pipeline_weight, z, tid, normalisation, zerr = [], [], [], [], [], []

sq_reconstructions = []

for ibatch in range(n_batches):
#for ibatch in range(len(spender_output_files)):

    # Load the batch: each pickle holds a 6-element sequence that is unpacked
    # below. The file name is built with a single f-string so that a literal
    # '%' in `path` or `file_prefix` cannot break the formatting.
    # NOTE(security): pickle.load can execute arbitrary code — only read
    # files from a trusted source.
    with open(f'{path}/{file_prefix}_{ibatch}.pkl', 'rb') as f:
        _normalised_flux, _pipeline_weight, _z, _tid, _normalisation, _zerr = pickle.load(f)
    normalised_flux.append(np.array(_normalised_flux))
    pipeline_weight.append(np.array(_pipeline_weight))
    z.append(np.array(_z))
    tid.append(np.array(_tid))
    normalisation.append(np.array(_normalisation))
    zerr.append(np.array(_zerr))

    # Load the SpenderQ reconstructions belonging to the same batch
    _sq_reconstructions = np.load(f'{path}/{file_prefix}_{ibatch}.recons.npy')
    sq_reconstructions.append(np.array(_sq_reconstructions))

# Stack all batches along the first axis (one entry per quasar)
normalised_flux=np.concatenate(normalised_flux,axis=0)
pipeline_weight=np.concatenate(pipeline_weight,axis=0)
z=np.concatenate(z)
tid=np.concatenate(tid)
normalisation=np.concatenate(normalisation,axis=0)
zerr=np.concatenate(zerr,axis=0)
sq_reconstructions=np.concatenate(sq_reconstructions,axis=0)

## Create variables needed for delta files

In [30]:
#nb_of_quasars = 10
nb_of_quasars = len(z)

# Number of pixels of the observed-frame wavelength grid built below
n_pix = 7781

# Per-quasar arrays, NaN outside the forest range
_tff = np.full((nb_of_quasars,n_pix),np.nan)
_weight = np.full((nb_of_quasars,n_pix),np.nan)
_sq_cont = np.full((nb_of_quasars,n_pix),np.nan)

wrecon = np.load(f'{path}/{file_prefix}.wave_recon.npy')
desi_wave = np.linspace(3600, 9824, n_pix) #obs

# Numerator and denominator of the weighted mean transmitted flux fraction
_tff_bar_up = np.zeros(n_pix)
_tff_bar_down = np.zeros(n_pix)

# Bin edges handed to the rebinning; they do not depend on the quasar, so they
# are built once instead of on every iteration.
edges = np.linspace(wrecon[0], wrecon[-1], n_pix + 1)

#for iqso in range(10):
for iqso in range(nb_of_quasars):

    #create wavelength grids
    desi_wave_rest = desi_wave/(1+z[iqso]) #rest

    #rebin spender reconstruction
    spenderq_rebin = U.trapz_rebin(wrecon, sq_reconstructions[iqso], edges = edges)

    #keep only part of spec that is within the lya range
    mask_desi_rest = (desi_wave_rest >= forest_lambda_min) & (desi_wave_rest <= forest_lambda_max)

    # flux / continuum on the full grid, computed once and reused below
    tff_full = normalised_flux[iqso]/spenderq_rebin

    _tff[iqso][mask_desi_rest] = tff_full[mask_desi_rest]
    _weight[iqso][mask_desi_rest] = pipeline_weight[iqso][mask_desi_rest]
    _sq_cont[iqso][mask_desi_rest] = spenderq_rebin[mask_desi_rest]

    # NOTE(review): the accumulators below use every pixel, not only the
    # forest pixels selected by mask_desi_rest, and a single non-finite
    # flux/continuum value makes that pixel's sum non-finite for good.
    # Confirm both are intended.
    _tff_bar_up += tff_full * pipeline_weight[iqso]
    _tff_bar_down += pipeline_weight[iqso]


In [31]:
# Weighted average transmitted flux fraction, restricted to the pixels of the
# observed grid that do not exceed the last LAMBDA value.
picca_range = (desi_wave <= LAMBDA[-1])

fbar_numerator = _tff_bar_up[picca_range]
fbar_denominator = _tff_bar_down[picca_range]

# Pixels whose summed weight is zero keep the initial value 0.
tff_bar = np.zeros_like(fbar_numerator)
np.divide(fbar_numerator, fbar_denominator, out=tff_bar, where=(fbar_denominator != 0))

In [32]:
#picca structure: nan filled arrays
n_picca_pix = int(np.count_nonzero(picca_range))
delta = np.full((nb_of_quasars,n_picca_pix),np.nan)
weight = np.full((nb_of_quasars,n_picca_pix),np.nan)
sq_cont = np.full((nb_of_quasars,n_picca_pix),np.nan)

meta_los_id, meta_ra, meta_dec, meta_z, meta_meansnr, meta_targetid, meta_night, meta_petal, meta_tile = [], [], [], [], [], [], [], [], []

# TARGETID -> catalogue row, built once so each quasar costs one dictionary
# lookup instead of a scan over the whole catalogue. setdefault keeps the
# FIRST row of a duplicated TARGETID, matching np.where(...)[0][0].
catalogue_row_of = {}
for irow, catalogue_target in enumerate(catalogue_TID.tolist()):
    catalogue_row_of.setdefault(int(catalogue_target), irow)

#for iqso in range(5):
for iqso in range(nb_of_quasars):

    # delta = F / <F> - 1 on the picca wavelength range
    delta[iqso] = (_tff[iqso][picca_range]/tff_bar) - 1.
    weight[iqso] = _weight[iqso][picca_range]
    sq_cont[iqso] = _sq_cont[iqso][picca_range]

    # Row of this quasar in the catalogue; a missing TARGETID raises
    # IndexError, as the original array lookup did.
    target = int(tid[iqso])
    if target not in catalogue_row_of:
        raise IndexError(f'TARGETID {target} not found in the quasar catalogue')
    los_id_idx = catalogue_row_of[target]

    meta_los_id.append(int(catalogue_TID[los_id_idx]))
    # RA and DEC are passed through math.radians before being stored
    meta_ra.append(float(math.radians(catalogue_RA[los_id_idx])))
    meta_dec.append(float(math.radians(catalogue_DEC[los_id_idx])))
    meta_z.append(float(catalogue_Z[los_id_idx]))
    # MEANSNR, NIGHT, PETAL and TILE are filled with placeholders
    meta_meansnr.append(np.nan)
    meta_targetid.append(int(catalogue_TID[los_id_idx]))
    meta_night.append('')
    meta_petal.append('')
    meta_tile.append('')
    

## Create the delta files

In [36]:
# Empty primary HDU, shared by every delta file
primary_hdu = fits.PrimaryHDU(data=None)

# LAMBDA HDU: the wavelength grid plus two header cards describing it
# (linear solution, step of 0.8)
hdu_wave = fits.ImageHDU(LAMBDA, name='LAMBDA')
hdu_wave.header['HIERARCH WAVE_SOLUTION'] = 'lin'
hdu_wave.header['HIERARCH DELTA_LAMBDA'] = 0.8

# Maximum number of quasars (rows) written to a single delta file
chunk_size = 1024

def chunks(lst, n):
    """Split a sliceable sequence into consecutive pieces of length `n`.

    The last piece is shorter when len(lst) is not a multiple of `n`;
    an empty input gives an empty list.
    """
    pieces = []
    for start in range(0, len(lst), n):
        pieces.append(lst[start:start + n])
    return pieces

# Cut the three per-quasar arrays into blocks of at most chunk_size rows;
# each block ends up in its own delta file.
_delta_chunks, _weight_chunks, _cont_chunks = (
    chunks(per_quasar_array, chunk_size)
    for per_quasar_array in (delta, weight, sq_cont)
)

nb_of_chunks = len(_delta_chunks)
print(nb_of_chunks)

# Write one delta file per chunk of quasars.
for ichunk in range(nb_of_chunks):
    # Rows of the metadata lists that belong to this chunk. The slice is
    # derived from ichunk so the METADATA table lines up row-for-row with
    # the DELTA/WEIGHT/CONT images of the same file. (A counter reset to 0 on
    # every iteration used to put the first chunk's metadata in every file.)
    row_start = ichunk * chunk_size
    row_stop = row_start + chunk_size

    c1 = fits.Column(name='LOS_ID', format='K', array=np.array(meta_los_id[row_start:row_stop]))
    c2 = fits.Column(name='RA', format='D', array=np.array(meta_ra[row_start:row_stop]))
    c3 = fits.Column(name='DEC', format='D', array=np.array(meta_dec[row_start:row_stop]))
    c4 = fits.Column(name='Z', format='D', array=np.array(meta_z[row_start:row_stop]))
    c5 = fits.Column(name='MEANSNR', format='D', array=meta_meansnr[row_start:row_stop])
    c6 = fits.Column(name='TARGETID', format='K', array=np.array(meta_targetid[row_start:row_stop]))
    c7 = fits.Column(name='NIGHT', format='12A', array=meta_night[row_start:row_stop])
    c8 = fits.Column(name='PETAL', format='12A', array=meta_petal[row_start:row_stop])
    c9 = fits.Column(name='TILE', format='12A', array=meta_tile[row_start:row_stop])
    hdu_meta = fits.BinTableHDU.from_columns([c1, c2, c3, c4, c5, c6, c7, c8, c9], name='METADATA')
    hdu_meta.header['BLINDING'] = 'none'

    # Image HDUs: one row per quasar of this chunk; FBAR is the same in every file
    hdu_delta = fits.ImageHDU(_delta_chunks[ichunk], name='DELTA')
    hdu_weight = fits.ImageHDU(_weight_chunks[ichunk], name='WEIGHT')
    hdu_cont = fits.ImageHDU(_cont_chunks[ichunk], name='CONT')
    hdu_tff_bar = fits.ImageHDU(tff_bar, name='FBAR')

    # Combine all HDUs into an HDUList
    # NOTE(review): this rebinds `hdul`, the name that held the opened quasar catalogue.
    hdul = fits.HDUList([primary_hdu, hdu_wave, hdu_meta, hdu_delta, hdu_weight, hdu_cont, hdu_tff_bar])

    # Write the HDUList to a new FITS file (an existing file is overwritten)
    hdul.writeto(f'{outpath}/delta-{ichunk}.fits', overwrite=True)

50
