In [1]:
#----------------------------------------------------
# Initial notebook for making the Illustris cosmic graph from galaxy catalogues
# Author: Christian Kragh Jespersen
# First created: 11/03/23 @KITP
# note: to make into routine later
#----------------------------------------------------

import h5py
from torch_geometric.data import Data, DataLoader
import scipy.spatial as ss
import os, sys
import os.path as osp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Stellar-particle threshold used by sim_graph: a subhalo is kept as a galaxy
# only if it has strictly more than this many star particles (Nstar > Nstar_th).
Nstar_th = 20

# Compute KDTree and get edges and edge features
def get_edges(pos, r_link, use_loops):
    """Build the edge list and edge features of a galaxy graph in a periodic unit box.

    Galaxies closer than ``r_link`` (with periodic wrapping) are linked in both
    directions. Each edge carries three features:

    * the pair separation divided by ``r_link``,
    * the cosine between the centroid-relative directions of the two galaxies,
    * the cosine between the centroid-relative direction of the source galaxy
      and the separation vector.

    Parameters
    ----------
    pos : array-like, shape (n_galaxies, n_dim)
        Positions normalised by the box size.
    r_link : float
        Linking radius, in box-size units.
    use_loops : bool
        If true, append one self-loop per galaxy with attributes ``[0, 1, 0]``.

    Returns
    -------
    edge_index : ndarray of int, shape (2, n_edges)
        Source and target node of every edge; each pair ``(i, j)`` is directly
        followed by its reverse ``(j, i)``, and self-loops (if any) come last.
    edge_attr : ndarray, shape (n_edges, 3)
        Edge features in the order listed above.
    """
    pos = np.asarray(pos)

    # 1. Get edges

    # Create the KDTree and look for pairs within a distance r_link.
    # Boxsize is normalised to 1; the tree is given 1.0001 (presumably so that
    # positions equal to 1 are still accepted by the periodic tree).
    kd_tree = ss.KDTree(pos, leafsize=16, boxsize=1.0001)
    pairs = kd_tree.query_pairs(r=r_link, output_type="ndarray").reshape(-1, 2)

    # Add reverse pairs and write in pytorch-geometric (2, n_edges) format:
    # sources are i0, j0, i1, j1, ... and targets are j0, i0, j1, i1, ...
    edge_index = np.stack([pairs.ravel(), pairs[:, ::-1].ravel()]).astype(int)

    # 2. Get edge attributes

    row, col = edge_index
    diff = pos[row] - pos[col]

    # Take into account periodic boundary conditions, correcting the distances.
    # Both masks are evaluated before diff is modified so that every component
    # is shifted at most once.
    wrap_down = diff > r_link
    wrap_up = diff < -r_link
    diff[wrap_down] -= 1.  # Boxsize normalised to 1
    diff[wrap_up] += 1.    # Boxsize normalised to 1

    # Get translational and rotational invariant features
    # Distance
    dist = np.linalg.norm(diff, axis=1)
    # Centroid of galaxy catalogue
    centroid = np.mean(pos, axis=0)
    # Unit vectors of node, neighbor and difference vector
    rel_row = pos[row] - centroid
    rel_col = pos[col] - centroid
    unitrow = rel_row / np.linalg.norm(rel_row, axis=1).reshape(-1, 1)
    unitcol = rel_col / np.linalg.norm(rel_col, axis=1).reshape(-1, 1)
    unitdiff = diff / dist.reshape(-1, 1)
    # Row-wise dot products between unit vectors
    cos1 = np.einsum("ij,ij->i", unitrow, unitcol)
    cos2 = np.einsum("ij,ij->i", unitrow, unitdiff)
    # Normalize distance by linking radius
    dist /= r_link

    # Concatenate to get all edge attributes
    edge_attr = np.column_stack([dist, cos1, cos2])

    # Add loops
    if use_loops:
        n_nodes = pos.shape[0]
        loops = np.tile(np.arange(n_nodes, dtype=int), (2, 1))
        atrloops = np.tile(np.array([0., 1., 0.]), (n_nodes, 1))
        edge_index = np.append(edge_index, loops, 1)
        edge_attr = np.append(edge_attr, atrloops, 0)
    edge_index = edge_index.astype(int)

    return edge_index, edge_attr


# Routine to create a cosmic graph from a galaxy catalogue
# simnumber: number of simulation
# param_file: file with the value of the cosmological + astrophysical parameters
# hparams: hyperparameters class
def sim_graph(simnumber, param_file, hparams):
    """Build the graph of one simulation from its subhalo catalogue.

    Relies on names that are not defined inside this function and must exist at
    module level when it is called: ``simpathroot``, ``boxsize``, ``Nstar_th``,
    ``torch`` and, depending on ``hparams.outmode``, ``normalize_params``
    ("cosmo") or ``ps_size`` ("ps").

    Parameters
    ----------
    simnumber : int
        Index of the simulation; used in the catalogue path and as the row read
        from ``param_file``.
    param_file : str
        For ``outmode == "cosmo"`` a text table read with ``np.loadtxt``; for
        ``outmode == "ps"`` a ``.npy`` file read with ``np.load``.
    hparams : object
        Hyperparameters; the attributes ``simsuite``, ``simset``, ``snap``,
        ``r_link``, ``only_positions``, ``outmode`` and ``pred_params`` are read.

    Returns
    -------
    torch_geometric.data.Data
        Graph with node features ``x``, target ``y``, global feature ``u``
        (log10 of the number of galaxies), ``edge_index`` and ``edge_attr``.

    Raises
    ------
    ValueError
        If ``hparams.outmode`` is neither "cosmo" nor "ps".
    """
    # Get some hyperparameters
    simsuite, simset = hparams.simsuite, hparams.simset
    r_link, only_positions = hparams.r_link, hparams.only_positions
    outmode, pred_params = hparams.outmode, hparams.pred_params

    # Fail early with a clear message instead of leaving `y` unbound further down
    if outmode not in ("cosmo", "ps"):
        raise ValueError("Unknown outmode %r, expected 'cosmo' or 'ps'" % (outmode,))

    # Name of the galaxy catalogue
    simpath = simpathroot + simsuite + "/"+simset+"_"
    catalogue = simpath + str(simnumber)+"/fof_subhalo_tab_0"+hparams.snap+".hdf5"

    # Read the catalogue; the context manager closes the file even if a read fails
    with h5py.File(catalogue, 'r') as f:
        pos   = f['/Subhalo/SubhaloPos'][:]/boxsize
        Mstar = f['/Subhalo/SubhaloMassType'][:,4] #Msun/h
        Rstar = f["Subhalo/SubhaloHalfmassRadType"][:,4]
        Metal = f["Subhalo/SubhaloStarMetallicity"][:]
        Vmax  = f["Subhalo/SubhaloVmax"][:]
        Nstar = f['/Subhalo/SubhaloLenType'][:,4]       #number of stars

    # Some simulations are slightly outside the box, correct it
    pos[pos < 0.0] += 1.0
    pos[pos > 1.0] -= 1.0

    # Select only galaxies with more than Nstar_th star particles
    keep = Nstar > Nstar_th
    pos, Mstar, Rstar, Metal, Vmax = pos[keep], Mstar[keep], Rstar[keep], Metal[keep], Vmax[keep]

    # Get the output to be predicted by the GNN, either the cosmo parameters or the power spectrum
    if outmode == "cosmo":
        # Read the value of the cosmological & astrophysical parameters
        paramsfile = np.loadtxt(param_file, dtype=str)
        params = np.array(paramsfile[simnumber,1:-1], dtype=np.float32)
        params = normalize_params(params)
        params = params[:pred_params]   # Consider only the first parameters, up to pred_params
        y = np.reshape(params, (1, params.shape[0]))
    else:
        # Read the power spectra (outmode == "ps")
        ps = np.load(param_file)
        ps = np.log10(ps[simnumber])
        #ps = normalize_ps(ps)
        y = np.reshape(ps, (1, ps_size))

    # Number of galaxies as global feature
    u = np.log10(pos.shape[0]).reshape(1,1)

    # Node features, log-compressed
    Mstar = np.log10(1. + Mstar)
    Rstar = np.log10(1. + Rstar)
    Metal = np.log10(1. + Metal)
    Vmax = np.log10(1. + Vmax)

    tab = np.column_stack((Mstar, Rstar, Metal, Vmax))
    #tab = Vmax.reshape(-1,1)       # For using only Vmax
    x = torch.tensor(tab, dtype=torch.float32)

    # Self-loops are used only when node features are considered
    if only_positions:
        # NOTE(review): x has already been built above, so this placeholder never
        # reaches the graph and x keeps the four galaxy features even in
        # positions-only mode. Confirm whether x should be built after this branch.
        tab = np.zeros_like(pos[:,:1])   # Node features not really used
        use_loops = False
    else:
        use_loops = True

    # Get edges and edge features
    edge_index, edge_attr = get_edges(pos, r_link, use_loops)

    # Construct the graph
    graph = Data(x=x,
                 y=torch.tensor(y, dtype=torch.float32),
                 u=torch.tensor(u, dtype=torch.float32),
                 edge_index=torch.tensor(edge_index, dtype=torch.long),
                 edge_attr=torch.tensor(edge_attr, dtype=torch.float32))

    return graph


# Split training and validation sets
def split_datasets(dataset):
    """Shuffle ``dataset`` and split it into training, validation and test loaders.

    The list is shuffled in place, so the caller's list is reordered as well.
    The fractions ``valid_size`` and ``test_size`` and the ``batch_size`` are
    read from module-level names that must be defined before calling.

    Parameters
    ----------
    dataset : list
        Graphs to split; must support in-place shuffling and slicing.

    Returns
    -------
    tuple
        ``(train_loader, valid_loader, test_loader)``; the first
        ``valid_size`` fraction goes to validation, the next ``test_size``
        fraction to test and the remainder to training.
    """
    # `random` is not imported at notebook level, so bring it in here
    import random

    random.shuffle(dataset)

    num_graphs = len(dataset)
    split_valid = int(np.floor(valid_size * num_graphs))
    split_test = split_valid + int(np.floor(test_size * num_graphs))

    valid_dataset = dataset[:split_valid]
    test_dataset = dataset[split_valid:split_test]
    train_dataset = dataset[split_test:]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_loader, valid_loader, test_loader

######################################################################################

# Main routine to load data and create the dataset
def create_dataset(hparams):
    """Build the list of simulation graphs for the task selected in ``hparams``.

    For ``outmode == "cosmo"`` one graph per simulation is created. For
    ``outmode == "ps"`` graphs are created for both suites at the initial
    snapshot and then for both suites at snapshots 18 and 10.

    Side effects: in "ps" mode ``hparams.simsuite`` and ``hparams.snap`` are
    modified and left at their last values. A summary of the galaxy counts is
    printed.

    Parameters
    ----------
    hparams : object
        Hyperparameters; ``outmode``, ``simsuite``, ``snap``, ``n_sims`` and
        ``flip_suite()`` are used here, plus whatever ``sim_graph`` reads.

    Returns
    -------
    list
        The graphs produced by ``sim_graph``, in creation order.

    Raises
    ------
    ValueError
        If ``hparams.outmode`` is neither "cosmo" nor "ps".
    """
    def ps_file():
        # Power-spectrum file for the suite and snapshot currently set in hparams
        return "PS_files/Pk_galaxies_"+hparams.simsuite+"_LH_"+hparams.snap+"_kmax=20.0.npy"

    dataset = []

    def add_all_sims(param_file):
        # One graph per simulation for the current suite/snapshot
        for simnumber in range(hparams.n_sims):
            dataset.append(sim_graph(simnumber, param_file, hparams))

    # Target file depending on the task: inferring cosmo parameters or predicting power spectrum
    if hparams.outmode == "cosmo":
        param_file = "/projects/QUIJOTE/CAMELS/Sims/CosmoAstroSeed_params_"+hparams.simsuite+".txt"
    elif hparams.outmode == "ps":
        param_file = ps_file()
    else:
        raise ValueError("Unknown outmode %r, expected 'cosmo' or 'ps'" % (hparams.outmode,))

    add_all_sims(param_file)

    # Add the other suite for predicting the power spectrum
    if hparams.outmode == "ps":
        hparams.simsuite = hparams.flip_suite()
        add_all_sims(ps_file())

        # Add other snapshots from other redshifts
        # Snapshot redshift
        # 004: z=3, 010: z=2, 014: z=1.5, 018: z=1, 024: z=0.5, 033: z=0
        #for snap in [24,18,14,10]:
        for snap in [18,10]:

            hparams.snap = str(snap)
            add_all_sims(ps_file())

            hparams.simsuite = hparams.flip_suite()
            add_all_sims(ps_file())

    gals = np.array([graph.x.shape[0] for graph in dataset])
    print("Total of galaxies", gals.sum(0), "Mean of", gals.mean(0),"per simulation, Std of", gals.std(0))

    # Hand the graphs back; the original bare `return` discarded them
    return dataset

In [2]:
## get illustris positions
import h5py
import illustris_python as il
# Root directory of the TNG50 group catalogues. "~" is expanded to the user's home
# directory and the "../.." segments then climb out of it before entering scratch/.
tng_base_path = osp.expanduser("~/../../scratch/gpfs/cj1223/TNG50")
# Snapshot number passed to the illustris_python group-catalogue loaders.
snapshot = 99
# Fields requested for the first, positions-only look at the catalogue.
subhalo_fields = ["SubhaloPos", "SubhaloMassType"]

In [3]:
# Load only the requested subhalo fields from the group catalogue.
subhalos = il.groupcat.loadSubhalos(tng_base_path, snapshot, fields=subhalo_fields) 

# Raw catalogue positions (first three columns), not yet rescaled by h.
pos = subhalos["SubhaloPos"][:,:3]
pos

array([[ 7307.2407, 24550.361 , 21302.58  ],
       [ 6811.0596, 24897.572 , 21219.457 ],
       [ 6774.9155, 23864.58  , 21089.37  ],
       ...,
       [30997.785 , 10243.298 , 32577.902 ],
       [33109.76  , 24575.389 ,  6742.7163],
       [31731.814 , 23896.273 ,  6157.3115]], dtype=float32)

In [4]:
# Extent of the raw catalogue positions, rounded to the nearest integer
min_box, max_box = np.rint(np.min(pos)), np.rint(np.max(pos))

h = 0.7

# Box size with the /(h*1000) conversion (pos units are in kpc).
# The previous expression divided by h**1e3, i.e. h to the 1000th power
# (about 1e-155), which made box_size astronomically large.
box_size = max_box / (h * 1e3)

In [5]:
# normalization_params = dict(
#     minimum_n_star_particles=10., # min star particles to be considered a galaxy
#     norm_half_mass_radius=8., 
#     norm_velocity=100., # note: use value of 1 if `use_central_galaxy_frame=True`
# )

In [6]:
# Catalogue fields needed to build the galaxy table
subhalo_fields = [
    "SubhaloPos",
    "SubhaloMassType",
    "SubhaloLenType",
    "SubhaloHalfmassRadType",
    "SubhaloVel",
    "SubhaloVmax",
    "SubhaloGrNr",
    "SubhaloFlag",
]
subhalos = il.groupcat.loadSubhalos(tng_base_path, snapshot, fields=subhalo_fields)

halo_fields = ["Group_M_Crit200", "GroupFirstSub", "GroupPos", "GroupVel"]
halos = il.groupcat.loadHalos(tng_base_path, snapshot, fields=halo_fields)

# Subhalo quantities. Positions are divided by h*1000 (comoving kpc -> Mpc);
# column 4 of the per-type arrays is the stellar component, column 1 the halo mass used below.
subhalo_pos = subhalos["SubhaloPos"][:] / (h*1e3)
subhalo_stellarmass = subhalos["SubhaloMassType"][:,4]
subhalo_halomass = subhalos["SubhaloMassType"][:,1]
subhalo_n_stellar_particles = subhalos["SubhaloLenType"][:,4]
subhalo_stellarhalfmassradius = subhalos["SubhaloHalfmassRadType"][:,4]  # normalize?
subhalo_vel = subhalos["SubhaloVel"][:]  # normalize?
subhalo_vmax = subhalos["SubhaloVmax"][:]  # normalize?
subhalo_flag = subhalos["SubhaloFlag"][:]
halo_id = subhalos["SubhaloGrNr"][:]

# Host-halo quantities
halo_mass = halos["Group_M_Crit200"][:]
halo_primarysubhalo = halos["GroupFirstSub"][:]  # currently not used but might be good for magnitude gap
group_pos = halos["GroupPos"][:] / (h*1e3)
group_vel = halos["GroupVel"][:]  # normalize?

# Assemble one row per subhalo; subhalo_id is the row number in the catalogue
subhalo_table = np.column_stack([
    halo_id,
    subhalo_flag,
    np.arange(len(subhalo_stellarmass)),
    subhalo_pos,
    subhalo_vel,
    subhalo_n_stellar_particles,
    subhalo_stellarmass,
    subhalo_halomass,
    subhalo_stellarhalfmassradius,
    subhalo_vmax,
])
subhalo_columns = [
    'halo_id', 'subhalo_flag', 'subhalo_id',
    'subhalo_x', 'subhalo_y', 'subhalo_z',
    'subhalo_vx', 'subhalo_vy', 'subhalo_vz',
    'subhalo_n_stellar_particles', 'subhalo_stellarmass', 'subhalo_halomass',
    'subhalo_stellarhalfmassradius', 'subhalo_vmax',
]
subhalos = pd.DataFrame(subhalo_table, columns=subhalo_columns)

# Keep only subhalos whose flag is non-zero, restore integer ids, then drop the flag
subhalos = subhalos.loc[subhalos["subhalo_flag"] != 0].copy()
subhalos = subhalos.astype({'halo_id': int, 'subhalo_id': int})
subhalos = subhalos.drop(columns="subhalo_flag")

In [7]:
# Selection thresholds applied to the subhalo table (each used with a strict ">")
cuts = dict(
    minimum_log_stellar_mass=9,
    minimum_log_halo_mass=8,
    minimum_n_star_particles=100,
)

In [8]:
# impose stellar mass and particle cuts
subhalos = subhalos[subhalos["subhalo_n_stellar_particles"] > cuts["minimum_n_star_particles"]].copy()

# Log quantities. NOTE(review): the +10 offset presumably converts masses stored
# in units of 1e10 into plain log10 masses - confirm against the catalogue units.
subhalos["subhalo_logstellarmass"] = np.log10(subhalos["subhalo_stellarmass"])+10
subhalos["subhalo_loghalomass"] = np.log10(subhalos["subhalo_halomass"])+10
subhalos["subhalo_logvmax"] = np.log10(subhalos["subhalo_vmax"])
subhalos["subhalo_logstellarhalfmassradius"] = np.log10(subhalos["subhalo_stellarhalfmassradius"])

subhalos = subhalos[subhalos["subhalo_loghalomass"] > cuts["minimum_log_halo_mass"]].copy()
subhalos = subhalos[subhalos["subhalo_logstellarmass"] > cuts["minimum_log_stellar_mass"]].copy()

# reset_index returns a new frame, so its result has to be kept (it was discarded
# before). df is built as a separate frame without the extraneous columns, which
# leaves `subhalos` intact and lets this cell be re-run without a KeyError.
df = subhalos.reset_index(drop=True).drop(
    columns=["subhalo_n_stellar_particles", "subhalo_stellarmass", "subhalo_halomass"]
)


  result = getattr(ufunc, method)(*inputs, **kwargs)


In [9]:
# Positions of the selected galaxies, one row per galaxy
position_columns = ['subhalo_x', 'subhalo_y', 'subhalo_z']
pos = np.vstack(df.loc[:, position_columns].to_numpy())

In [10]:
pos, len(pos)

(array([[10.43891525, 35.07194519, 30.4322567 ],
        [ 9.73008537, 35.56795883, 30.31351089],
        [ 9.67845058, 34.09225845, 30.12767029],
        ...,
        [11.85899734, 34.01246262, 28.89342117],
        [10.23868847, 28.60673904, 30.6593132 ],
        [12.11908245, 35.60639191, 31.51264191]]),
 2497)

In [40]:
# Linking radius: pairs of galaxies closer than this are connected (same units as pos).
r_link = 4

# NOTE(review): pos now holds the rescaled positions (catalogue values / (h*1e3)),
# whereas max_box was measured on the raw catalogue positions. boxsize is therefore
# far larger than the data extent, so no pair is linked across the periodic
# boundary - confirm that this is intended.
kd_tree = ss.KDTree(pos, leafsize=16, boxsize=max_box)
# Pairs within r_link, returned as an (n_pairs, 2) array of node indices.
edge_index = kd_tree.query_pairs(r=r_link, output_type="ndarray")

In [41]:
len(edge_index)

32553

In [42]:
undirected = True
periodic = False
use_loops = False

# edge_index arrives as an (n_pairs, 2) array of (i, j) pairs and leaves this cell
# in pytorch-geometric (2, n_edges) format whether or not the graph is undirected.
if undirected:
    # Add reverse pairs: every (i, j) is directly followed by (j, i)
    edge_index = np.stack([edge_index.ravel(), edge_index[:, ::-1].ravel()])
else:
    edge_index = edge_index.T

edge_index = edge_index.astype(int)
num_pairs = edge_index.shape[1]

In [43]:
row, col = edge_index

# Separation vector of every edge (source minus target)
diff = pos[row]-pos[col]

# Selects which columns become node features and which become targets later on
use_gal = True

if periodic:
    # Take into account periodic boundary conditions: components longer than the
    # linking radius belong to pairs linked across the box, so shift them by one
    # box size. Both masks are taken before diff is modified so that each
    # component is shifted at most once.
    wrap_down = diff > r_link
    wrap_up = diff < -r_link
    diff[wrap_down] -= box_size
    diff[wrap_up] += box_size

# Distances are measured after the periodic correction so they match diff
dist = np.linalg.norm(diff, axis=1)


In [44]:
# Define an arbitrary reference point: invariant to translation/rotation shifts, but not stretches
centroid = np.mean(pos,axis=0)
# centroid+=1.2

# Unit vectors of node, neighbor and difference vector
rel_row = pos[row]-centroid
rel_col = pos[col]-centroid
unitrow = rel_row/np.linalg.norm(rel_row, axis=1).reshape(-1,1)
unitcol = rel_col/np.linalg.norm(rel_col, axis=1).reshape(-1,1)
unitdiff = diff/dist.reshape(-1,1)

# Row-wise dot products between unit vectors, computed for all edges at once
cos1 = np.einsum("ij,ij->i", unitrow, unitcol)
cos2 = np.einsum("ij,ij->i", unitrow, unitdiff)

In [45]:
# Edge features, one row per edge: [separation, cos1, cos2]
edge_attr = np.column_stack((dist, cos1, cos2))

In [46]:
# Optionally append one self-loop per node, with attributes [0, 1, 0]
if use_loops:
    n_nodes = pos.shape[0]
    node_ids = np.arange(n_nodes, dtype=int)
    loops = np.vstack((node_ids, node_ids))
    atrloops = np.tile(np.array([0., 1., 0.]), (n_nodes, 1))
    edge_index = np.append(edge_index, loops, 1)
    edge_attr = np.append(edge_attr, atrloops, 0)
edge_index = edge_index.astype(int)

In [47]:
from torch_geometric.data import Data
import torch

In [48]:
# Column groups: kinematics are always inputs; galaxy and halo properties swap
# roles between node features and targets depending on use_gal.
kinematic_cols = ['subhalo_x', 'subhalo_y', 'subhalo_z', 'subhalo_vx', 'subhalo_vy', 'subhalo_vz']
galaxy_cols = ['subhalo_logstellarmass', 'subhalo_stellarhalfmassradius']
halo_cols = ['subhalo_loghalomass', 'subhalo_logvmax']

if use_gal:
    use_cols = kinematic_cols + galaxy_cols
    y_cols = list(halo_cols)
else:
    use_cols = kinematic_cols + halo_cols
    y_cols = list(galaxy_cols)

node_features = np.vstack(df[use_cols].to_numpy())
node_targets = np.vstack(df[y_cols].to_numpy())

x = torch.tensor(node_features, dtype=torch.float)
y = torch.tensor(node_targets, dtype=torch.float)
edge_index = torch.tensor(edge_index, dtype=torch.long)
edge_attr = torch.tensor(edge_attr, dtype=torch.float)

In [49]:
# Bundle node features, targets and edges into a single PyTorch Geometric graph
data = Data(
    x=x,
    y=y,
    edge_index=edge_index,
    edge_attr=edge_attr,
)

In [50]:
data

Data(x=[2497, 8], edge_index=[2, 65106], edge_attr=[65106, 3], y=[2497, 2])

In [51]:
import torch_geometric.utils as tg_utils
# Convert the graph with torch_geometric's to_networkx helper.
# NOTE(review): the result is bound to `nx`, the alias conventionally used for the
# networkx module itself - consider another name if networkx is imported later.
nx = tg_utils.to_networkx(data)

In [52]:
def visualize_graph(data, draw_edges = True, sizes=0.1, projection="3d", edge_index=None, fontsize = 11):
    """Plot the nodes and edges of a graph and save the figure to disk.

    Node coordinates are the leading columns of ``data.x`` (three for "3d",
    two for "2d"). The figure is written to ``figs/graph_for_John.png`` and then
    closed, so nothing is displayed inline. Edge line widths are ``sizes``
    divided by the module-level ``r_link``, which must be defined.

    Parameters
    ----------
    data : graph object
        Must expose ``x`` and ``edge_index``.
    draw_edges : bool
        Draw one line per edge when true.
    sizes : float
        Marker size of the nodes; also scales the edge line width.
    projection : {"3d", "2d"}
        Type of axes to draw on.
    edge_index : optional
        Edges to draw instead of ``data.edge_index``; ``None`` uses the graph's own.
    fontsize : int
        Tick label size.

    Raises
    ------
    ValueError
        If ``projection`` is neither "3d" nor "2d".
    """
    if projection == "3d":
        n_dim = 3
    elif projection == "2d":
        n_dim = 2
    else:
        raise ValueError("projection must be '3d' or '2d', got %r" % (projection,))
    is_3d = n_dim == 3

    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(projection="3d") if is_3d else fig.add_subplot()
    pos = data.x[:, :n_dim]

    # Honour the explicit edge list when given, otherwise use the graph's own
    edges = data.edge_index if edge_index is None else edge_index

    # Draw lines for each edge
    if edges is not None and draw_edges:
        # Tensors expose .t(); anything else is treated as an array-like of shape (2, n_edges)
        edge_list = edges.t().tolist() if hasattr(edges, "t") else np.asarray(edges).T.tolist()
        for (src, dst) in edge_list:
            src = pos[src].tolist()
            dst = pos[dst].tolist()
            if is_3d:
                ax.plot([src[0], dst[0]], [src[1], dst[1]], zs=[src[2], dst[2]], linewidth=sizes/r_link, color='black')
            else:
                ax.plot([src[0], dst[0]], [src[1], dst[1]], linewidth=sizes/r_link, color='black')

    # Plot nodes
    if is_3d:
        ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], s=sizes, zorder=1000, alpha=0.5)
    else:
        ax.scatter(pos[:, 0], pos[:, 1], s=sizes, zorder=1000, alpha=0.5)

    ax.xaxis.set_tick_params(labelsize=fontsize)
    ax.yaxis.set_tick_params(labelsize=fontsize)
    if is_3d:
        # Only 3D axes have a z axis
        ax.zaxis.set_tick_params(labelsize=fontsize)

    # Make sure the output directory exists before saving
    out_path = "figs/graph_for_John.png"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches='tight', dpi=300)
    plt.close(fig)


In [53]:
visualize_graph(data)

In [25]:
# Preview the first few (source, target) pairs instead of dumping every edge
data.edge_index.t()[:10].tolist()

[[799, 800],
 [800, 799],
 [800, 803],
 [803, 800],
 [800, 802],
 [802, 800],
 [800, 801],
 [801, 800],
 [799, 803],
 [803, 799],
 [799, 802],
 [802, 799],
 [799, 801],
 [801, 799],
 [799, 2349],
 [2349, 799],
 [802, 803],
 [803, 802],
 [801, 803],
 [803, 801],
 [803, 2349],
 [2349, 803],
 [411, 2415],
 [2415, 411],
 [408, 2415],
 [2415, 408],
 [406, 2415],
 [2415, 406],
 [801, 802],
 [802, 801],
 [408, 411],
 [411, 408],
 [406, 411],
 [411, 406],
 [801, 2349],
 [2349, 801],
 [406, 408],
 [408, 406],
 [800, 2113],
 [2113, 800],
 [799, 2113],
 [2113, 799],
 [803, 2113],
 [2113, 803],
 [407, 2415],
 [2415, 407],
 [405, 2415],
 [2415, 405],
 [802, 2113],
 [2113, 802],
 [407, 411],
 [411, 407],
 [405, 411],
 [411, 405],
 [801, 2113],
 [2113, 801],
 [407, 408],
 [408, 407],
 [405, 408],
 [408, 405],
 [406, 407],
 [407, 406],
 [405, 406],
 [406, 405],
 [729, 730],
 [730, 729],
 [1117, 1400],
 [1400, 1117],
 [1003, 2397],
 [2397, 1003],
 [405, 407],
 [407, 405],
 [823, 824],
 [824, 823],
 [82