In [None]:
# ------------------------------------------------------------------------
#
# TITLE - investigate_tng_api.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Try out the TNG Illustris python API
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os
import h5py
import glob
import copy
import dill as pickle

## Matplotlib
from matplotlib import pyplot as plt
import matplotlib.image as mpimg

## Fitting
import scipy.optimize

sys.path.insert(0,'../../src/')
from tng_dfs import util as putil
from tng_dfs.util import get

In [None]:
### Notebook setup
# Render matplotlib figures inline in the notebook output
%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
# Request high-DPI ('retina') output from the inline backend
%config InlineBackend.figure_format = 'retina'
# Re-import modules before each cell runs, so edits to the project source
# that was put on sys.path (../../src/) are picked up without a kernel restart
%load_ext autoreload
%autoreload 2

In [None]:
# Project configuration: load it once, then pull out the entries used by the
# rest of the notebook, in the order listed in `keywords`
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'RO', 'VO', 'ZO', 'LITTLE_H']
config_values = putil.parse_config_dict(cdict, keywords)
data_dir, ro, vo, zo, h = config_values

# API Setup

In [None]:
# base URL
# Entry point of the TNG web API. The simulation, snapshot and subhalo URLs
# used later in this notebook are all taken from responses that start here.
baseURL = 'http://www.tng-project.org/api/'

In [None]:
# Ask the API root for the catalogue of available simulations
r = get(baseURL)
sim_names = [entry['name'] for entry in r['simulations']]
# Positions of TNG50-1 ... TNG50-4 within the catalogue, and their API URLs
# in that same order
tng50_indices = [sim_names.index('TNG50-'+str(level)) for level in (1, 2, 3, 4)]
tng50_urls = [r['simulations'][idx]['url'] for idx in tng50_indices]
# Work with the first of them, i.e. TNG50-1
tng50_url = tng50_urls[0]

# Get Milky Way Analogs from TNG50-1

In [None]:
# Get the simulation
sim = get( tng50_url )
# Shows properties of the simulation
# sim.keys()

# Get a list of snapshots for the simulation
# Each element is dict with some info
snaps = get( sim['snapshots'] )
# snaps[-1].keys()

# Redshifts corresponding to snapshots
snap_zs = [snap['redshift'] for snap in snaps]

# Pick the present-day snapshot (the last entry of the snapshot list)
snap0 = get( snaps[-1]['url'] )
# Shows properties of the snapshot
# snap0.keys()

# Query the API for subhalos with stellar mass in a range near that of the 
# Milky Way. The range is converted to the mass units of the API query by the
# project helper before the query string is built.
mw_mass_range = np.array([5,7])*1e10
mw_mass_range_code = putil.mass_physical_to_code(mw_mass_range,h=h,e10=True)
mw_search_query = '?mass_stars__gt='+str(mw_mass_range_code[0])+\
                  '&mass_stars__lt='+str(mw_mass_range_code[1])+\
                  '&primary_flag__gt=0'
# NOTE(review): only the 'results' entry of the response is used. If the API
# paginates search results, matches beyond the first page are not fetched --
# TODO confirm against the API's paging behaviour.
mw_search_results = get( snap0['subhalos']+mw_search_query )['results']
n_mw = len(mw_search_results)
print(str(n_mw)+' Milky way like galaxies found')

# Get subhalo data, either from the API or from the local pickle cache.
# Set force_mwsubs = True to ignore the cache and download again.
force_mwsubs = False
mwsubs_path = data_dir+'subs/mwsubs.pkl'
if force_mwsubs or not os.path.exists(mwsubs_path):
    # Create the cache directory up front so the (slow) download cannot be
    # lost to a missing-directory error when it is written out
    os.makedirs(os.path.dirname(mwsubs_path), exist_ok=True)

    print('Downloading subhalo data')
    mwsubs = [get( result['url'], timeout=None ) for result in mw_search_results]

    # Save subhalo data
    print('Saving subhalo data to '+mwsubs_path)
    with open(mwsubs_path,'wb') as f:
        pickle.dump(mwsubs,f)
else:
    # Unpickling can execute arbitrary code: only load cache files that were
    # written by this notebook
    print('Loading subhalo data from '+mwsubs_path)
    with open(mwsubs_path,'rb') as f:
        mwsubs = pickle.load(f)
    print(mwsubs_path+' has '+str(len(mwsubs))+' subhalos')
    # Later cells loop over n_mw, so flag a cache that no longer matches the
    # current search
    if len(mwsubs) != n_mw:
        print('Warning: cached subhalo count differs from the '+str(n_mw)+
              ' search results; set force_mwsubs = True to refresh the cache')

# Convert to recarray. The raw list of API dictionaries is kept under its own
# name because `mwsubs` is rebound to the record array.
mwsubs_dict = copy.deepcopy(mwsubs)
mwsubs = putil.subhalo_list_to_recarray(mwsubs)

# Examine the Milky Way analogs compared to a range of masses

In [None]:
# List the field names available in the Milky Way analog record array
print(mwsubs.dtype.names)

In [None]:
# Other mass ranges to consider (one [low, high] row per comparison sample)
gal_mass_ranges = np.array([[1,3],[3,5],[7,9],[9,11]])*1e10

# Subhalos in each range are downloaded once and cached together with the
# ranges that produced them. Set force_mass_range = True to download again.
mass_range_filename = './data/massrange.pkl'
force_mass_range = False
if force_mass_range or not os.path.exists(mass_range_filename):
    # Create the cache directory before downloading so the results can be saved
    os.makedirs(os.path.dirname(mass_range_filename), exist_ok=True)

    massrangesubs = []
    for gal_mass_range in gal_mass_ranges:
        gal_mass_range_code = putil.mass_physical_to_code(gal_mass_range,h=h,
                                                          e10=True)
        gal_search_query = '?mass_stars__gt='+str(gal_mass_range_code[0])+\
                           '&mass_stars__lt='+str(gal_mass_range_code[1])+\
                           '&primary_flag__gt=0'
        # NOTE(review): only the 'results' entry is read; if the API paginates
        # search results, later pages are not fetched -- TODO confirm
        gal_search_results = get( snap0['subhalos']+gal_search_query )['results']
        # The range is divided by 1e10 so the printed value agrees with the
        # 'x1e10' suffix
        print(str(len(gal_search_results))+' Galaxies found in mass range '
              +str(gal_mass_range/1e10)+'x1e10')
        print('Downloading subhalo data')
        galsubs = [get( result['url'], timeout=None )
                   for result in gal_search_results]
        massrangesubs.append(putil.subhalo_list_to_recarray(galsubs))
    with open(mass_range_filename,'wb') as f:
        pickle.dump([massrangesubs,gal_mass_ranges],f)
else:
    # Unpickling can execute arbitrary code: only load cache files that were
    # written by this notebook. The cached ranges replace the ones defined
    # above so they always describe the loaded samples.
    print('Loading mass range data from '+mass_range_filename)
    with open(mass_range_filename,'rb') as f:
        massrangesubs,gal_mass_ranges = pickle.load(f)

In [None]:
# Compare the Milky Way analogs with the other stellar-mass samples in a 2x3
# grid of scatter plots, one colour per sample.
fig = plt.figure(figsize=(15,10))
axs = fig.subplots(nrows=2,ncols=3)

colors = ['DarkGreen','Red','DarkOrange','DodgerBlue','Navy']
labels = ['MW','1-3','3-5','7-9','9-11']

# Milky Way analogs first, then the comparison samples. The record arrays are
# only read here, so no copies are made.
samples = [mwsubs] + list(massrangesubs)

for thissubs, color, label in zip(samples, colors, labels):

    u_mag = thissubs['stellarphotometrics_u']
    g_mag = thissubs['stellarphotometrics_g']
    r_mag = thissubs['stellarphotometrics_r']
    spin_x = thissubs['spin_x']
    spin_y = thissubs['spin_y']
    spin_z = thissubs['spin_z']
    spin_tot = np.sqrt( spin_x**2 + spin_y**2 + spin_z**2 )
    hmr_stars = thissubs['halfmassrad_stars']
    mass_stars = thissubs['mass_stars']
    metals_halfrad = thissubs['starmetallicityhalfrad']
    sfr_halfrad = thissubs['sfrinhalfrad']

    point_style = dict(c=color, edgecolor='Black')

    # Top row: colour-magnitude diagrams and star formation rate
    axs[0,0].scatter(g_mag, g_mag-r_mag, label=label, **point_style)
    axs[0,1].scatter(u_mag, u_mag-r_mag, **point_style)
    axs[0,2].scatter(mass_stars, sfr_halfrad, **point_style)

    # Bottom row: spin and metallicity
    axs[1,0].scatter(mass_stars, spin_tot, **point_style)
    axs[1,1].scatter(hmr_stars, spin_tot, **point_style)
    axs[1,2].scatter(mass_stars, metals_halfrad, **point_style)

# Axis decoration is done once, after all samples are drawn
axs[0,0].set_xlabel(r'$G$')
axs[0,0].set_ylabel(r'$G-R$')
axs[0,1].set_xlabel(r'$U$')
axs[0,1].set_ylabel(r'$U-R$')
axs[0,2].set_xlabel('Stellar mass')
axs[0,2].set_ylabel('SFR')
axs[1,0].set_xlabel('Stellar mass')
axs[1,0].set_ylabel('Total spin')
axs[1,1].set_xlabel('Stellar half mass radius')
axs[1,1].set_ylabel('Total spin')
# This panel plots metallicity against stellar mass
axs[1,2].set_xlabel('Stellar mass')
axs[1,2].set_ylabel('Metallicity')

# invert_xaxis() toggles the direction, so it must be called exactly once per
# axis (brighter magnitudes to the right)
axs[0,0].invert_xaxis()
axs[0,1].invert_xaxis()

axs[0,0].legend(fontsize=9)
os.makedirs('./fig', exist_ok=True)
fig.savefig('./fig/analog_properties.png',dpi=300)
fig.show()

# Compare snapshot vs. redshift

In [None]:
# Snapshot numbers and redshifts as arrays, one entry per snapshot
snap_nums = np.array([snap['number'] for snap in snaps])
snap_zs = np.array([snap['redshift'] for snap in snaps])

fig = plt.figure(figsize=(8,8))
ax1 = fig.add_subplot(2,1,1)
ax2 = fig.add_subplot(2,1,2)

# Only entries from index 20 onward are drawn; both panels share one marker style
late = slice(20, None)
marker_style = dict(facecolor='Black', edgecolor=None, s=24, alpha=0.25)

# Top panel: z against snapshot number. Bottom panel: log10(1+z).
for panel, yvals, ylabel in ((ax1, snap_zs[late], r'$z$'),
                             (ax2, np.log10(snap_zs[late]+1), r'$\log(1+z)$')):
    panel.scatter(snap_nums[late], yvals, **marker_style)
    panel.set_xlabel(r'Snapshot number')
    panel.set_ylabel(ylabel)

# Disabled experiment: fit an analytic form to the snapshot-redshift relation

# def power_law(x, A, n, x0, D):
#     return A*(x/x0)**n+D

# def exponential(x, A, n, x0, D):
#     return A*np.exp(n*(x-x0))+D

# popt, pcov = scipy.optimize.curve_fit(exponential, snap_nums, snap_zs, 
#     p0=[10, -1, 1, 1])

# ax1.plot(np.linspace(0,100,100), exponential(np.linspace(0,100,100), *popt),
#     c='Red', lw=2, ls='--')
# ax2.plot(np.linspace(0,100,100), 
#     np.log10(exponential(np.linspace(0,100,100), *popt)+1), c='Red', lw=2, 
#     ls='--')

fig.savefig('./fig/redshift_snapshot_relation.png',dpi=300)
fig.show()

# Download Sublink Trees

In [None]:
# Set download_trees = True to fetch the sublink tree file of every Milky Way
# analog into tree_dir; otherwise the files already on disk are used.
download_trees = False
tree_dir = data_dir+'sublink_trees/full/'
if download_trees:
    os.makedirs(tree_dir, exist_ok=True)
    # Index the raw API dictionaries kept in mwsubs_dict: `mwsubs` itself was
    # rebound to a record array when the subhalos were loaded.
    # NOTE(review): assumes each API subhalo dictionary carries the
    # ['trees']['sublink'] entry that the original lookup relied on.
    for i, subhalo in enumerate(mwsubs_dict):
        get( subhalo['trees']['sublink'], directory=tree_dir, timeout=None )
        print('Done '+str(i))

# Glob the files. glob returns them in arbitrary order, so sort to make
# "file number i" mean the same thing on every run.
sublink_files = sorted(glob.glob(tree_dir+'*.hdf5'))

# Load an example tree

In [None]:
# The subhalo ID is the part of the file name between 'sublink_' and the
# extension. str.strip('.hdf5') removes any of the characters . h d f 5 from
# both ends, which mangles IDs that start or end with 5, so split instead.
subhalo_n = sublink_files[0].split('sublink_')[-1].split('.')[0]
# The handle is left open on purpose: the following cells read from `f`
f = h5py.File(sublink_files[0],'r')
# List the datasets stored in the tree file
for name in f:
    print(name)

In [None]:
# # first lets get the 'main branch' i.e. all the galaxies with 
# # SubhaloID_mostmassive < ID < MainLeafProgenitorID
# endpoints = np.unique(f['LastProgenitorID']) # this gets all of the branch 'end-points'
# #SubhaloID between the main progenitor and its MainLeafProgenitorID are the main branch
# mainbranchIDs = np.arange(f['SubhaloID'][0], f['MainLeafProgenitorID'][0]+1) 
# inmain = np.in1d(f['SubhaloID'], mainbranchIDs)
# sort = np.argsort(f['SnapNum'][inmain])

# #find all the galaxies that merge to the main branch (i.e. DescendantID is in the main branch IDs)
# merges_to_main = np.in1d(f['DescendantID'][~inmain], mainbranchIDs)
# merger_mass = f['SubhaloMassType'][:,4][~inmain][merges_to_main]*1e10
# subhalo_arr = np.ones_like(merger_mass)*int(subhalo_n)

# Crawl the trees looking for mergers above a given mass ratio

In [None]:
# Main branch and end-points of the example tree held in `f`

# Distinct LastProgenitorID values, i.e. the end-points of the tree's branches
tree_endpoints = np.unique(f['LastProgenitorID'])

# The main branch runs from the first SubhaloID in the file to that entry's
# MainLeafProgenitorID. With depth-first ID numbering every ID in between lies
# on the main branch, so the branch is a contiguous range of IDs. One entry is
# expected per snapshot of the simulation, hence 100.
tree_mb_first = f['SubhaloID'][0]
tree_mb_last = f['MainLeafProgenitorID'][0]
tree_mbIDs = np.arange(tree_mb_first, tree_mb_last+1)
tree_ismb = np.isin(f['SubhaloID'], tree_mbIDs)
assert len(tree_mbIDs) == 100 # Sanity

# Main-branch mass and snapshot number.
# NOTE(review): the original note calls 'Mass' "DM only", while the plotting
# loop at the end of the notebook labels the same field M_tot -- confirm.
tree_mbmass, tree_mbsnap = f['Mass'][tree_ismb], f['SnapNum'][tree_ismb]

In [None]:
# tree_mbsnap_indx = np.argsort(tree_mbsnap)
# tree_mbsnap_sorted = tree_mbsnap[tree_mbsnap_indx]
# tree_mbsnap_sorted_indx = np.searchsorted(tree_mbsnap_sorted,tree_sbsnap)
# tree_sbsnap_mbmap = np.take(tree_mbsnap_indx, tree_mbsnap_sorted_indx, mode='clip')

In [None]:
# Find the off-main-branch entries whose mass, relative to the main branch at
# the same snapshot, exceeds a threshold. The 'Mass' field is used for now
# (described as DM mass in the original notes), which should be fine.

# Sub-branch properties: everything that is not on the main branch
tree_sbmass = f['Mass'][~tree_ismb] # DM only
tree_sbsnap = f['SnapNum'][~tree_ismb]

# Build, for every sub-branch element, the index of the main-branch element at
# the matching snapshot number, so that
#   f['Property'][~tree_ismb][i]  <->  f['Property'][tree_ismb][tree_sbsnap_mbmap[i]]
# Steps: sort the main-branch snapshots, locate each sub-branch snapshot in the
# sorted list, then translate that position back to the unsorted main-branch
# index (out-of-range positions are clipped).
tree_mbsnap_indx = np.argsort(tree_mbsnap)
tree_mbsnap_sorted = tree_mbsnap[tree_mbsnap_indx]
tree_mbsnap_sorted_indx = np.searchsorted(tree_mbsnap_sorted, tree_sbsnap)
tree_sbsnap_mbmap = tree_mbsnap_indx.take(tree_mbsnap_sorted_indx, mode='clip')

# Ratio of each sub-branch mass to the main-branch mass at the same snapshot
threshold_mratio = 0.05 # 1:20 merger ratio
tree_sbmratio = tree_sbmass / f['Mass'][tree_ismb][tree_sbsnap_mbmap]

# Index sets: above the threshold, between the threshold and 1, and above 1
above_threshold = tree_sbmratio > threshold_mratio
tree_sbismajor_all = np.where(above_threshold)[0]
tree_sbismajor_range = np.where(above_threshold & (tree_sbmratio < 1))[0]
tree_sbismajor_plus = np.where(tree_sbmratio > 1)[0]

In [None]:
# Plot the distribution of sub-branch to main-branch mass ratios
fig = plt.figure()
ax = fig.add_subplot(111)

# The histogram is taken in log10 of the ratio, with logarithmic counts
ax.hist(np.log10(tree_sbmratio), bins=50, histtype='step')
ax.set_yscale('log')
# Shade the window between the threshold and a ratio of 1 (log10 = 0)
ax.axvspan(np.log10(threshold_mratio), 0, alpha=0.2, color='Grey')
# The x values are log10 of the ratio, so the label says so
ax.set_xlabel(r'$\log_{10}(\mathrm{mass\ ratio})$')
ax.set_ylabel(r'$N$')

os.makedirs('./fig', exist_ok=True)
fig.savefig('./fig/analog_merger_mass_ratios.png',dpi=300)
fig.show()

In [None]:
# Redshift distribution of the sub-branch entries above the mass-ratio
# threshold: select their snapshot numbers, then convert them to redshift
fig = plt.figure()
ax = fig.add_subplot(1,1,1)

major_snaps = tree_sbsnap[tree_sbismajor_all]
redshift_sbismajor = putil.snapshot_to_redshift(major_snaps)
ax.hist(redshift_sbismajor, histtype='step', bins=20)
# ax.set_yscale('log')
ax.set_xlabel('Redshift')

fig.savefig('./fig/analog_merger_redshifts.png',dpi=300)
fig.show()

# Make some plots

In [None]:
# For every downloaded tree: find the subhalos that merge onto the main
# branch, collect their stellar masses, and optionally save diagnostic plots.
make_plots = True

# Figures are written below fig/; create it if it does not exist yet
os.makedirs('fig', exist_ok=True)

# One (n_mergers, 2) array per tree: column 0 holds the stellar mass of each
# subhalo merging onto the main branch, column 1 the subhalo ID parsed from the
# file name. They are stacked once after the loop instead of growing an array
# inside it.
merger_rows = []

# At most n_mw files are processed; slicing avoids an IndexError when fewer
# tree files than Milky Way analogs are on disk
for i, sublink_file in enumerate(sublink_files[:n_mw]):

    subhalo_n = sublink_file.split('sublink_')[-1].split('.')[0]

    # Read the needed datasets once and close the file again. A separate
    # handle name is used so the example tree that earlier cells keep open in
    # `f` is not replaced.
    with h5py.File(sublink_file,'r') as tree:
        subhalo_ids = tree['SubhaloID'][:]
        main_leaf_id = tree['MainLeafProgenitorID'][0]
        descendant_ids = tree['DescendantID'][:]
        snapnum = tree['SnapNum'][:]
        total_mass = tree['Mass'][:]*1e10
        # Column 4 of SubhaloMassType is used as the stellar mass
        stellar_mass = tree['SubhaloMassType'][:,4]*1e10
        sfr = tree['SubhaloSFR'][:]

    # first lets get the 'main branch' i.e. all the galaxies with
    # SubhaloID_mostmassive < ID < MainLeafProgenitorID.
    # SubhaloID between the main progenitor and its MainLeafProgenitorID are
    # the main branch
    mainbranchIDs = np.arange(subhalo_ids[0], main_leaf_id+1)
    inmain = np.isin(subhalo_ids, mainbranchIDs)
    sort = np.argsort(snapnum[inmain])

    # find all the galaxies that merge to the main branch (i.e. DescendantID
    # is in the main branch IDs)
    merges_to_main = np.isin(descendant_ids[~inmain], mainbranchIDs)
    merger_mass = stellar_mass[~inmain][merges_to_main]
    # Double precision so the ID stays exact even if the masses are stored in
    # single precision
    subhalo_arr = np.ones(merger_mass.shape, dtype=float)*int(subhalo_n)
    merger_rows.append(np.array([merger_mass,subhalo_arr]).T)

    if make_plots:
        # Plot the mass accumulation along the main branch
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.plot(snapnum[inmain][sort],np.log10(total_mass[inmain][sort]))
        ax.set_ylabel(r'$\log_{10}(M_{\mathrm{tot}})\ \mathrm{[M_{\odot}]}$')
        ax.set_xlabel(r'snapshot number')
        fig.savefig('fig/subhalo_'+subhalo_n+'_mass_accumulation.png')
        plt.close(fig)

        #can plot SFR as well
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.plot(snapnum[inmain][sort],np.log10(sfr[inmain][sort]))
        ax.set_ylabel(r'$\log_{10}(SFR)\ \mathrm{[M_{\odot}\ yr^{-1}]}$')
        ax.set_xlabel(r'snapshot number')
        fig.savefig('fig/subhalo_'+subhalo_n+'_sfr.png')
        plt.close(fig)

        # Plot the stellar mass of the merger remnants. The stellar masses
        # collected above are used (not the total 'Mass' field), matching the
        # axis label and file name. Entries with zero stellar mass are dropped
        # because their log10 is -inf, which a histogram cannot bin.
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.hist(np.log10(merger_mass[merger_mass > 0]), log=True)
        ax.axvline(6.5, linestyle='dashed', color='Black') #rough resolution limit
        ax.set_xlabel(r'$\log_{10}(M_{*})\ \mathrm{[M_{\odot}]}$')
        ax.set_ylabel(r'$N$')
        fig.savefig('fig/subhalo_'+subhalo_n+'_mass_stars_mergers.png')
        plt.close(fig)

    print('Done '+str(i)+', subhalo: '+str(subhalo_n))

# Stack the per-tree rows into a single (N, 2) array
if merger_rows:
    merger_mass_inds = np.concatenate(merger_rows, axis=0)
else:
    merger_mass_inds = np.empty((0,2))
    