In [2]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

import matplotlib
from matplotlib import rcParams, rc
import matplotlib.pyplot as plt
from matplotlib import colors
import smplotlib
import pandas as pd
import numpy as np
import cmasher as cmr
cmap = cmr.dusk

from tqdm import tqdm
import joblib 

import sys
sys.path.append("/home/jdli/AspGap")

import time
import os
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import KFold
from aspgap.xpformer import MLP, MLP_upsampling
from aspgap.data import XPAP4l, XPAP4l_infer
from aspgap.utils import *
from aspgap.vis import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Read Gaia's catalog

In [25]:
df = pd.read_csv("/nfsdata/share/gaiadr3_spec_v2.psv", sep="|", nrows=20)

print(df.shape, df.columns)

(20, 8) Index(['source_id', 'ra', 'dec', 'bp_coefficients', 'bp_coefficient_errors',
       'rp_coefficients', 'rp_coefficient_errors', 'phot_g_mean_mag'],
      dtype='object')


In [36]:
bp_scaler_name    = "/home/jdli/AspGap/models/scaler_bp_gmagand_0330.gz"
rp_scaler_name    = "/home/jdli/AspGap/models/scaler_rp_gmagand_0330.gz"
label_scaler_name = "/home/jdli/AspGap/models/scaler_labels_0330.gz"

scaler_bp    = joblib.load(bp_scaler_name)
scaler_rp    = joblib.load(rp_scaler_name)
scaler_label = joblib.load(label_scaler_name)

def recover_coef(series):
    return [list(map(float, s[1:-1].split(','))) for s in series]


def recover_coef_2(series):
    coef_recv = []
    
    for s in series:
        
        if type(s) == str:
            coef_recv.append(list(map(float, s[1:-1].split(','))))
            
        elif type(s) == list:
            coef_recv.append(s)
            
        else:
            print(type(s))
            # print(s)
            
    return coef_recv

def l2norm(y):
    return np.sqrt(np.sum(y**2, axis=1))

coef_names = ['bp_coefficients', 'bp_coefficient_errors', 
              'rp_coefficients', 'rp_coefficient_errors']

for col in coef_names:
    df[col] = recover_coef_2(df[col])
    
xp_coeff_array = np.zeros([4, len(df), 55])

for i in range(4):
    xp_coeff_array[i] = np.array([np.array(x) for x in df[coef_names[i]].values])


In [47]:
tr_file = "test100.npy"
device = torch.device('cuda:0')

fold = 0
epoch = 1000

n_lat = 1024
n_xp  =  110
n_labels = 8

model_dir = f"/data/jdli/gaia/model/0418/"
enc_point = model_dir + f"xp2_lat_{fold}_ep{epoch}.pt"
dec_point = model_dir + f"lat2_4lerr_{fold}_ep{epoch}.pt"


decoder = MLP(n_lat, n_labels).to(device)
encoder = MLP_upsampling(n_xp, n_lat, hidden_size=256).to(device)

encoder.load_state_dict(
    remove_prefix(
        torch.load(enc_point)
    )
)
decoder.load_state_dict(
    remove_prefix(
        torch.load(dec_point)
    )
)

def recover_scale_label(y, e_y, scaler_name='../models/scaler_labels_0330.gz'):
    scaler_labels = joblib.load(scaler_name)
    y   = scaler_labels.inverse_transform(y)
    e_y = e_y * scaler_labels.scale_
    return y, e_y

def inference_from_model(data, source_id, encoder, decoder, 
                         transcale_method=recover_scale_label, 
                         device=torch.device('cuda:0'), 
                         pred_names   = ['teff_xp', 'logg_xp',  'moh_xp', 'aom_xp'],
                         e_pred_names = ['e_teff_xp', 'e_logg_xp',  'e_moh_xp', 'e_aom_xp']):
    
    with torch.no_grad():
        
        z = encoder(torch.tensor(data.astype(np.float32), device=device))
        output = decoder(z).cpu().numpy()
    
    y_pred, e_pred = transcale_method(output[:,:4], output[:,4:])
    
    rdf = pd.DataFrame(np.c_[y_pred, np.abs(e_pred), source_id], 
                       columns=pred_names+e_pred_names+['source_id'])
    
    rdf.source_id = rdf.source_id.astype(np.int64)
    rdf = rdf.reset_index(drop=True)
    return rdf

In [9]:
# Set the chunk size
chunksize = 2**20  # read 524,288 rows at a time

psv_file_path = "/nfsdata/share/gaiadr3_spec_v2.psv"
# total: 219197643 rows
# total: 209 * 2**20 rows
chunks = []

chunks = pd.read_csv(psv_file_path, chunksize=chunksize, sep="|")

# print(len(df_raw_tot))

# chunksize = 2**20  # read 1048576 rows at a time

# len_chunk = int(len(df_raw_tot)/chunksize)
# print(len_chunk)

# 219197643
# 209

In [48]:
def infer_pipeline(chunk, j,
                   coef_names = ['bp_coefficients', 'bp_coefficient_errors', 
                                 'rp_coefficients', 'rp_coefficient_errors']):
    
    # 1. begin 
    start = time.time()
    
    # 2. drop null value
    chunk = chunk.dropna()
    chunk.reset_index(inplace=True)
    
    # 3. data normalization
    for col in coef_names:
        chunk[col] = recover_coef_2(chunk[col])
        
    xp_coeff_array = np.zeros([4, len(chunk), 55])

    for i in range(4):
        xp_coeff_array[i] = np.array([np.array(x) for x in chunk[coef_names[i]].values])
        
    gmag_norm = 10**((15.-chunk['phot_g_mean_mag'].values)*0.4)

    norm_bp = xp_coeff_array[0,:,:]/gmag_norm[:,None]
    norm_rp = xp_coeff_array[2,:,:]/gmag_norm[:,None]

    norm_bp = scaler_bp.transform(norm_bp)
    norm_rp = scaler_rp.transform(norm_rp)

    snr_rp_global = l2norm(xp_coeff_array[2,:,:])/l2norm(xp_coeff_array[3,:,:])
    
    xp_dict = {
        "xp":np.c_[norm_bp, norm_rp],
        "source_id":chunk['source_id'].values.astype(np.int64),
        "snr_rp":snr_rp_global
    }

    data_dir = "/nfsdata/users/jdli_ny/xp_chunks_v2_3/"
    fname    = f"gxp_chunk_1m_{j}.npy"
    
    np.save(data_dir+fname, xp_dict)
    
    # 4. inference labels
    
    rdf = inference_from_model(
        xp_dict['xp'], xp_dict['source_id'].astype(np.int64), 
        encoder, decoder, device=device, 
        transcale_method=recover_scale_label
    )
    
    rdf = rdf.assign(
        ra=chunk['ra'], dec=chunk['dec'],
        snr_rp=snr_rp_global
    )
    
    save_csv_file = f"gxp_chunk_1m_{j}.csv"

    rdf.to_csv(f"/nfsdata/users/jdli_ny/xp_infers_v2_3/"+save_csv_file, index=False)
    
    print(f"{save_csv_file}")
    print("inference %.4f s"%(time.time()-start))
    return rdf

In [49]:
# test

infer_pipeline(df, 0)

gxp_chunk_1m_0.csv
inference 0.0540 s


Unnamed: 0,teff_xp,logg_xp,moh_xp,aom_xp,e_teff_xp,e_logg_xp,e_moh_xp,e_aom_xp,source_id,ra,dec,snr_rp
0,3975.310059,4.532993,-0.054076,-0.056719,49.376586,0.059108,0.065527,0.01101,1362344320414889472,264.126679,45.218317,170.569658
1,4389.575684,4.56296,-0.070604,-0.019476,32.910546,0.045011,0.052882,0.011156,1362344389134369792,264.001796,45.154085,305.196531
2,4088.169678,2.620805,0.023279,0.043482,221.858306,0.39751,0.175058,0.043264,1362344423494108416,264.013233,45.166039,89.323299
3,5194.047852,4.418731,-0.046342,0.024083,137.892325,0.155291,0.085996,0.036174,1362344492213586944,264.005863,45.179811,154.505595
4,4155.721191,4.564548,-0.089939,-0.036373,73.653778,0.075586,0.075012,0.016573,1362344556637634560,264.051443,45.192208,92.723461
5,5187.661621,4.470252,-0.15646,0.002572,110.897945,0.094193,0.070313,0.022892,1362344560933061376,264.046145,45.189548,367.083265
6,5365.573242,4.36902,0.096144,-0.009051,153.396724,0.217769,0.087009,0.022718,1362344728436327680,263.986907,45.191721,60.76759
7,5448.05957,4.274558,-0.568827,0.109509,98.697246,0.089359,0.084114,0.045187,1362344831515494400,264.048456,45.217448,277.728444
8,3613.481934,4.142218,-0.430462,0.004934,170.547059,0.262585,0.164563,0.038408,1362345076327526912,264.115274,45.225474,103.386644
9,4788.350098,3.491538,-0.371178,0.167414,68.630397,0.166428,0.078717,0.047747,1362345076329134080,264.129124,45.231002,299.2065


In [None]:
# Iterate over the CSV file in chunks and append each chunk to the list


for j,df in enumerate(chunks):
    
    
    rdf = infer_pipeline(df, j)

gxp_chunk_1m_0.csv
inference 260.9092 s
gxp_chunk_1m_1.csv
inference 264.5487 s
gxp_chunk_1m_2.csv
inference 249.6582 s
gxp_chunk_1m_3.csv
inference 249.9922 s
gxp_chunk_1m_4.csv
inference 243.0332 s
gxp_chunk_1m_5.csv
inference 234.6982 s
gxp_chunk_1m_6.csv
inference 243.7788 s
gxp_chunk_1m_7.csv
inference 241.0805 s
gxp_chunk_1m_8.csv
inference 240.5681 s
gxp_chunk_1m_9.csv
inference 225.6564 s
gxp_chunk_1m_10.csv
inference 231.3998 s
gxp_chunk_1m_11.csv
inference 231.6539 s
gxp_chunk_1m_12.csv
inference 230.1653 s
gxp_chunk_1m_13.csv
inference 228.8529 s
gxp_chunk_1m_14.csv
inference 227.9130 s
gxp_chunk_1m_15.csv
inference 237.4127 s
gxp_chunk_1m_16.csv
inference 225.4224 s
gxp_chunk_1m_17.csv
inference 235.8121 s
gxp_chunk_1m_18.csv
inference 233.4452 s
gxp_chunk_1m_19.csv
inference 227.0047 s
gxp_chunk_1m_20.csv
inference 229.1247 s
gxp_chunk_1m_21.csv
inference 224.2300 s
gxp_chunk_1m_22.csv
inference 237.0925 s
gxp_chunk_1m_23.csv
inference 222.3349 s
gxp_chunk_1m_24.csv
infere

In [None]:
!ls /nfsdata/users/jdli_ny/

In [7]:
! mkdir /nfsdata/users/jdli_ny/xp_infers_v2_3
# ! mkdir /nfsdata/users/jdli_ny/xp_chunks_v2

In [63]:
! ls /nfsdata/users/jdli_ny/xp_infers_v2_2

gxp_chunk_1m_0.csv  gxp_chunk_1m_1.csv


In [50]:
! ls /nfsdata/users/jdli_ny/xp_chunks_v2_3


gxp_chunk_1m_0.npy
