# Trial DESI trained neural network
## 1. Extract Data from DESI data release
- refer to `intro_to_DESI_EDR_files.ipynb`

In [1]:
# import some helpful python packages 
import os
import gc
import numpy as np

from astropy.io import fits
from astropy.table import Table
from astropy.convolution import convolve, Gaussian1DKernel

import matplotlib
import matplotlib.pyplot as plt

plt.style.use('./others/desi.mplstyle')

In [2]:
# import DESI related modules - 
from desimodel.footprint import radec2pix      # For getting healpix values
import desispec.io                             # Input/Output functions related to DESI spectra
from desispec import coaddition                # Functions related to coadding the spectra

# DESI targeting masks - 
from desitarget.sv1 import sv1_targetmask    # For SV1
from desitarget.sv2 import sv2_targetmask    # For SV2
from desitarget.sv3 import sv3_targetmask    # For SV3

In [3]:
# Release directory path

specprod = 'iron'    # Internal name for the EDR
specprod_dir = '/global/cfs/cdirs/desi/spectro/redux/' + specprod
print(specprod_dir)

/global/cfs/cdirs/desi/spectro/redux/iron


In [4]:
zpix_cat_is_galaxy = np.load("./saves/zpix_cat_is_galaxy.npy")
zpix_cat = Table.read(f'{specprod_dir}/zcatalog/zall-pix-{specprod}.fits', hdu="ZCATALOG")[zpix_cat_is_galaxy]

In [5]:
print(type(zpix_cat), len(zpix_cat))

<class 'astropy.table.table.Table'> 21696490


In [6]:
zpix_cat[50:55]

TARGETID,SURVEY,PROGRAM,HEALPIX,SPGRPVAL,Z,ZERR,ZWARN,CHI2,COEFF,NPIXELS,SPECTYPE,SUBTYPE,NCOEFF,DELTACHI2,COADD_FIBERSTATUS,TARGET_RA,TARGET_DEC,PMRA,PMDEC,REF_EPOCH,FA_TARGET,FA_TYPE,OBJTYPE,SUBPRIORITY,OBSCONDITIONS,RELEASE,BRICKNAME,BRICKID,BRICK_OBJID,MORPHTYPE,EBV,FLUX_G,FLUX_R,FLUX_Z,FLUX_W1,FLUX_W2,FLUX_IVAR_G,FLUX_IVAR_R,FLUX_IVAR_Z,FLUX_IVAR_W1,FLUX_IVAR_W2,FIBERFLUX_G,FIBERFLUX_R,FIBERFLUX_Z,FIBERTOTFLUX_G,FIBERTOTFLUX_R,FIBERTOTFLUX_Z,MASKBITS,SERSIC,SHAPE_R,SHAPE_E1,SHAPE_E2,REF_ID,REF_CAT,GAIA_PHOT_G_MEAN_MAG,GAIA_PHOT_BP_MEAN_MAG,GAIA_PHOT_RP_MEAN_MAG,PARALLAX,PHOTSYS,PRIORITY_INIT,NUMOBS_INIT,CMX_TARGET,DESI_TARGET,BGS_TARGET,MWS_TARGET,SCND_TARGET,SV1_DESI_TARGET,SV1_BGS_TARGET,SV1_MWS_TARGET,SV1_SCND_TARGET,SV2_DESI_TARGET,SV2_BGS_TARGET,SV2_MWS_TARGET,SV2_SCND_TARGET,SV3_DESI_TARGET,SV3_BGS_TARGET,SV3_MWS_TARGET,SV3_SCND_TARGET,PLATE_RA,PLATE_DEC,COADD_NUMEXP,COADD_EXPTIME,COADD_NUMNIGHT,COADD_NUMTILE,MEAN_DELTA_X,RMS_DELTA_X,MEAN_DELTA_Y,RMS_DELTA_Y,MEAN_FIBER_RA,STD_FIBER_RA,MEAN_FIBER_DEC,STD_FIBER_DEC,MEAN_PSF_TO_FIBER_SPECFLUX,TSNR2_GPBDARK_B,TSNR2_ELG_B,TSNR2_GPBBRIGHT_B,TSNR2_LYA_B,TSNR2_BGS_B,TSNR2_GPBBACKUP_B,TSNR2_QSO_B,TSNR2_LRG_B,TSNR2_GPBDARK_R,TSNR2_ELG_R,TSNR2_GPBBRIGHT_R,TSNR2_LYA_R,TSNR2_BGS_R,TSNR2_GPBBACKUP_R,TSNR2_QSO_R,TSNR2_LRG_R,TSNR2_GPBDARK_Z,TSNR2_ELG_Z,TSNR2_GPBBRIGHT_Z,TSNR2_LYA_Z,TSNR2_BGS_Z,TSNR2_GPBBACKUP_Z,TSNR2_QSO_Z,TSNR2_LRG_Z,TSNR2_GPBDARK,TSNR2_ELG,TSNR2_GPBBRIGHT,TSNR2_LYA,TSNR2_BGS,TSNR2_GPBBACKUP,TSNR2_QSO,TSNR2_LRG,MAIN_NSPEC,MAIN_PRIMARY,SV_NSPEC,SV_PRIMARY,ZCAT_NSPEC,ZCAT_PRIMARY
int64,bytes7,bytes6,int32,int32,float64,float64,int64,float64,float64[10],int64,bytes6,bytes20,int64,float64,int32,float64,float64,float32,float32,float32,int64,uint8,bytes3,float64,int32,int16,bytes8,int32,int32,bytes4,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,int16,float32,float32,float32,float32,int64,bytes2,float32,float32,float32,float32,bytes1,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,int64,float64,float64,int16,float32,int16,int16,float32,float32,float32,float32,float64,float32,float64,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,int32,bool,int32,bool,int64,bool
39628483705438705,cmx,other,2153,2153,0.9549765465799412,6.337646627836233e-05,0,8277.82517927885,25.182798895163756 .. 2.449810721298287,7928,GALAXY,--,10,14.312310665845873,0,24.58350306600249,30.216567171317408,0.0,0.0,2020.9597,2048,1,TGT,0.9827835723915774,3,9010,0247p302,497017,497,REX,0.046512615,0.2737542,0.3430835,0.2783253,1.9005281,1.0421883,2061.1145,471.4325,62.777065,3.3555176,0.73318565,0.14113112,0.1768731,0.14348768,0.14117967,0.17693171,0.14355366,0,1.0,0.49542147,0.0,0.0,0,--,0.0,0.0,0.0,0.0,S,3000,1,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24.58350306600249,30.216567171317408,4,3600.0,1,1,0.00525,0.008760707,0.00075,0.004769696,24.5834789553704,0.13638209,30.21656981469769,0.06616776,0.76944447,542.6087,0.4709213,100.765854,439.72195,2308.856,681.3403,12.753676,3.945053,37538.395,81.67198,6661.2754,0.18977702,7619.6934,40402.59,24.900114,117.4269,5.9244776e-05,303.3242,1.1048722e-05,0.0,12617.437,7.5980206e-05,61.369797,134.76157,38081.004,385.4671,6762.041,439.91174,22545.984,41083.93,99.02359,256.1335,0,False,0,False,1,True
39628483701249266,cmx,other,2153,2153,0.9122369740101364,0.000212347373126,0,8758.555644992739,54.49493300752991 .. 3.040363332377466,7928,GALAXY,--,10,166.5843929760158,0,24.53251052236921,30.197231711858777,0.0,0.0,2020.9597,2048,1,TGT,0.9040460585375006,3,9010,0244p302,497016,5362,REX,0.047126617,0.4060862,0.8378453,2.7442741,9.807552,6.666676,621.8876,199.89471,20.723118,2.8979256,0.68580663,0.1537954,0.31731382,1.039328,0.15379566,0.31731454,1.0393325,0,1.0,0.74291134,0.0,0.0,0,--,0.0,0.0,0.0,0.0,S,3000,1,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24.53251052236921,30.197231711858777,4,3600.0,1,1,0.00375,0.007053368,-0.0025,0.005,24.532493007316297,0.116333865,30.197221971875877,0.060646076,0.75882584,489.52963,0.42781433,91.140045,400.69818,2103.175,623.69525,11.615481,3.5744662,35055.68,76.47339,6230.1064,0.17641726,7105.2236,38093.586,23.315447,109.80748,5.611017e-05,286.96964,1.0477452e-05,0.0,11933.254,7.262718e-05,58.150017,127.36008,35545.21,363.87085,6321.2466,400.8746,21141.652,38717.28,93.08095,240.74202,0,False,0,False,1,True
39628483705440516,cmx,other,2153,2153,0.5523515369528929,1.1002414956197951e-06,0,15212.64501953125,246.23372812182905 .. -14.177052410218682,7928,GALAXY,--,10,120530.73322296144,0,24.6817564539086,30.35761664452527,0.0,0.0,2020.9597,36028797018968064,1,TGT,0.259630378567685,7,9010,0247p302,497017,2308,PSF,0.046624493,1.7112678,2.5731533,2.3794425,9.441193,20.3394,942.6277,260.7799,40.30578,3.188242,0.6917823,1.3327067,2.0039287,1.8530699,1.3327067,2.0039287,1.8530699,0,0.0,0.0,0.0,0.0,0,--,0.0,0.0,0.0,0.0,S,3400,1,36028797018968064,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24.6817564539086,30.35761664452527,4,3600.0,1,1,0.0035,0.0055677644,-0.003,0.0035355338,24.68173997589153,0.082922995,30.357604812822352,0.025805095,0.789,520.5887,0.37932727,96.81095,289.0088,1991.287,671.9463,10.197584,3.812738,38155.734,84.7787,6776.477,0.1922295,7662.279,42038.83,25.607819,120.53768,6.368398e-05,311.54083,1.1846729e-05,0.0,13103.11,8.2662045e-05,63.787136,138.97253,38676.324,396.69885,6873.288,289.20102,22756.676,42710.773,99.592545,263.32294,0,False,0,False,2,True
616089241234964843,cmx,other,2153,2153,0.932752477574551,2.232146096148817e-46,519,9.000000000000002e+99,0.0 .. 0.0,0,GALAXY,--,10,0.0,512,24.774719360827007,30.451198618850484,0.0,0.0,0.0,4294967296,4,SKY,0.9985601934567682,63,9010,0247p305,498263,363,--,0.05223905,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.009525567,-0.005927755,0.03759135,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0,--,0.0,0.0,0.0,0.0,--,-1,-1,4294967296,4294967296,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24.774719360827007,30.451198618850484,0,0.0,0,0,0.0,0.0,0.0,0.0,24.774719360827,0.0,30.451198618850498,0.0,0.79036385,403.01196,0.2936096,74.652664,226.46931,1543.0046,548.83136,8.054763,2.9468231,30647.582,68.139366,5425.91,0.15740857,6120.605,35578.652,21.012077,96.16447,5.1967127e-05,254.73347,9.607358e-06,0.0,10713.21,7.052646e-05,53.38695,113.1652,31050.594,323.16644,5500.563,226.62672,18376.82,36127.484,82.4538,212.27649,0,False,0,False,1,True
39628483705440434,cmx,other,2153,2153,0.2888847661877839,7.83513818147459e-05,4,8683.654407080263,7.595722946895678 .. -0.9218795682464365,7928,GALAXY,--,10,5.059906933456659,0,24.67746783368929,30.335151532387563,0.0,0.0,2020.9597,2048,1,TGT,0.8709354893059359,3,9010,0247p302,497017,2226,PSF,0.045590498,0.27525312,0.26038677,0.5035696,0.09540175,5.1039367,2439.9775,622.9312,64.016464,3.3915884,0.7444912,0.21399611,0.20243824,0.39150125,0.21405712,0.20250997,0.39161864,0,0.0,0.0,0.0,0.0,0,--,0.0,0.0,0.0,0.0,S,3000,1,2048,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24.67746783368929,30.335151532387563,4,3600.0,1,1,0.00125,0.0040926766,-0.00325,0.0042130747,24.67746172120114,0.07472323,30.335138852024823,0.037340038,0.789,497.3344,0.35752708,92.708084,269.42096,1890.4734,647.3471,9.604626,3.614011,36530.05,81.09669,6503.726,0.18437743,7370.127,40553.75,24.519749,115.338776,6.131522e-05,303.15295,1.1424892e-05,0.0,12747.478,8.0033875e-05,62.03047,135.12991,37027.387,384.60718,6596.434,269.60535,22008.078,41201.098,96.15485,254.0827,0,False,0,False,1,True


In [7]:
# Selecting All galaxy targets
is_galaxy = zpix_cat["SPECTYPE"] == "GALAXY"
# np.save("./saves/zpix_cat_is_galaxy", is_galaxy)

In [8]:
print(type(is_galaxy), len(is_galaxy), len(zpix_cat))

<class 'numpy.ndarray'> 21696490 21696490


In [9]:
gc.collect()

0

In [10]:
# flag of confidence -> Δχ^2>40 or not
is_confident = zpix_cat["DELTACHI2"] > 40
# not_confident = zpix_cat["DELTACHI2"] < 40
# not_confident = ~is_confident

In [11]:
print(zpix_cat["SHAPE_R"].data, type(is_confident.astype(int)))
# zpix_cat["SHAPE_R"].data.shape


[0.         0.96060556 0.         ... 0.25930887 0.34113547 0.        ] <class 'numpy.ndarray'>


## 2. Transfer data to DataLoader for training

In [12]:
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
import torch
from torch import nn

In [39]:
np_in_data = np.stack((zpix_cat["SHAPE_R"].data.astype(np.float32), zpix_cat["GAIA_PHOT_G_MEAN_MAG"].data.astype(np.float32)), axis=-1)
np_out_data = is_confident.astype(np.intp)
dataset = TensorDataset(
    torch.as_tensor(np_in_data),
    torch.as_tensor(np_out_data),
)
# np_in_data.shape, np_out_data.shape

In [40]:
gc.collect()

837

In [41]:
# Use random_split to create the splits
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.8, 0.1, 0.1])

In [42]:
batch_size = 6400
# construct instances for class `DataLoader`, with parameters 
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# `DataLoader` instances are iterable, but not subscpriptable
print(type(train_dataloader), len(train_dataloader), len(train_dataloader)*64, len(zpix_cat)*0.8)

for X, y in train_dataloader:
    print(f"info of input data: {type(X)}, {X.shape}, {X.dtype}")
    print(f"info of output flag: {type(y)}, {y.shape} {y.dtype}")
    break  # only show index 0

<class 'torch.utils.data.dataloader.DataLoader'> 2713 173632 17357192.0
info of input data: <class 'torch.Tensor'>, torch.Size([6400, 2]), torch.float32
info of output flag: <class 'torch.Tensor'>, torch.Size([6400]) torch.int64


## 3. Define neural network and training

In [43]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            nn.Linear(in_features=2, out_features=16),
            nn.LeakyReLU(0.1),
            nn.Linear(in_features=16, out_features=32),
            nn.LeakyReLU(0.1),
            nn.Linear(in_features=32, out_features=1),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        return self.model(x)

model = NeuralNetwork().to(device)
print(model)

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (model): Sequential(
    (0): Linear(in_features=2, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.1)
    (2): Linear(in_features=16, out_features=32, bias=True)
    (3): LeakyReLU(negative_slope=0.1)
    (4): Linear(in_features=32, out_features=1, bias=True)
    (5): Sigmoid()
  )
)


In [44]:
# To train a model, also need a loss function and an optimizer (besides dataset and neural network)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

In [45]:
# In a single training loop, the model makes predictions on the training dataset (fed to it in batches),
# and backpropagates the prediction error to adjust the model’s parameters.
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    
    model.train() # default is true, means training, false means evaluation mode
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # clear gradients for the optimizer, for next time optimization
        optimizer.zero_grad()
        
        # forward pass to predict the output with current network
        pred = model(X)
        
        # Compute the loss, output is torch.Tensor
        loss = loss_fn(pred.squeeze(), y.squeeze())

        # Backpropagation
        loss.backward()
        
        # update the weights of the network
        optimizer.step()

        if batch % 1000 == 0:
            print(f"{type(loss):}")
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


# check the model’s performance against the test dataset to ensure it is learning.
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [46]:
epochs = 1
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------


../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [7,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [8,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [9,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_f

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
