In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 2-download_primary_cutouts.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstring:
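'''Download cutouts of the primary (Milky Way analog) subhalos along their 
main progenitor branches, and plot face-on and edge-on stellar density maps 
of each primary.
'''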

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, pdb, glob, dill as pickle, h5py

## Matplotlib
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Project-specific
src_path = 'src/'
while True:
    if os.path.exists(src_path): break
    # Give up once the directory that would contain src/ is the project root
    # (tng-dfs) or the filesystem root. Note the last component of src_path
    # itself is always 'src', so the parent directory is what must be checked,
    # otherwise this loop never terminates when src/ cannot be found.
    parent_dir = os.path.dirname(os.path.realpath(src_path))
    if os.path.basename(parent_dir) == 'tng-dfs' or \
        os.path.dirname(parent_dir) == parent_dir:
            raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
# Use the directory actually found rather than a hard-coded relative path
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import tree as ptree
from tng_dfs import util as putil

### Notebook setup

%matplotlib inline
# Project-wide matplotlib style. NOTE(review): this path assumes src/ sits two
# directories above the working directory; it does not use src_path.
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
# Render inline figures at high resolution
%config InlineBackend.figure_format = 'retina'
# Re-import changed project modules (e.g. tng_dfs) automatically before each cell
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: read the project config and unpack the values this notebook uses
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR','MW_ANALOG_DIR','RO','VO','ZO','LITTLE_H',
            'MW_MASS_RANGE']
data_dir,mw_analog_dir,ro,vo,zo,h,mw_mass_range = \
    putil.parse_config_dict(cdict,keywords)

# MW Analog sample. Later cells read mwsubs[i]['id'] / mwsubs[i]['snap'] and
# mwsubs_vars['sim']['snapshots'] from what prepare_mwsubs returns.
mwsubs,mwsubs_vars = putil.prepare_mwsubs(mw_analog_dir,h=h,
    mw_mass_range=mw_mass_range,return_vars=True,force_mwsubs=False)

# Logging setup
log_dir = './log/'
os.makedirs(log_dir,exist_ok=True)

# Figure path. Create it here because the plotting cell saves figures into it
fig_dir = './fig/'
os.makedirs(fig_dir,exist_ok=True)

### Get the list of simulation snapshots

In [None]:
# Query the snapshot list of the simulation. Later cells read each entry's
# 'number' and 'url', and index the list by snapshot number (snaps[sn]).
snaps = putil.get(
    mwsubs_vars['sim']['snapshots']
)

### Get the sublink files

In [None]:
# One SubLink tree file per MW analog. Sorting fixes the order, so that
# sublink_files[i] can be matched to mwsubs[i] (asserted in the download loop).
tree_dir = ''.join((data_dir,'mw_analogs/sublink_trees/full/'))
sublink_files = np.sort(glob.glob(tree_dir+'*.hdf5'))
n_mw = sublink_files.size
# NOTE(review): this global count check is disabled; the download loop still
# compares the z=0 subfind id of each tree with mwsubs[i]['id'].
# assert n_mw == len(mwsubs), 'Consistency check failed.'

### Download the cutouts for the primary subhalos

In [None]:
# If True, re-download cutouts even when the file already exists on disk
force_download_cutouts = False

# Prepare directory structure. Cutouts are saved under
# data_dir+'mw_analogs/cutouts/snap_<N>/' in the loop below, so create exactly
# those directories (one per snapshot) up front.
cutout_dir = data_dir+'mw_analogs/cutouts/'
for snap in snaps:
    os.makedirs(cutout_dir+'snap_'+str(snap['number'])+'/',exist_ok=True)

# (exception, snapshot number, subfind ID) for every failed request
exceptions = []

def _log(log_file, txt, echo=True):
    '''Write one line to log_file and, if echo is True, print it as well.'''
    if echo:
        print(txt)
    log_file.write(txt+'\n')

# Open the log in a with-block so it is flushed and closed even if one of
# the consistency assertions below aborts the loop
with open(os.path.join(log_dir,'download_cutouts.log'),'w') as log_file:

    # First loop over all primaries and download cutouts for each
    _log(log_file,'Downloading cutouts for all primaries...\n'+'-'*50)
    for i in range(n_mw):

        # Load the tree, get the main-branch subfind IDs and snapshot numbers
        tree = ptree.SublinkTree(sublink_files[i])
        sids = tree.get_property('SubfindID')[tree.main_branch_mask]
        snapnums = tree.get_property('SnapNum')[tree.main_branch_mask]
        # sublink_files[i] must describe the same primary as mwsubs[i]
        assert mwsubs[i]['id'] == sids[0], 'Consistency check failed.'
        z0_sid = sids[0]
        _log(log_file,'Getting cutouts for primary with z=0 subfind id: '+\
            str(z0_sid)+'...')

        # Loop over all main-branch snapshots and download the cutouts
        n_snapnums = len(snapnums)
        primary_has_data = np.zeros(n_snapnums,dtype=bool)
        for j in range(n_snapnums):
            sn = snapnums[j]
            sid = sids[j]
            snap_path = cutout_dir+'snap_'+str(sn)+'/'
            # NOTE(review): assumes putil.get saves the cutout as
            # cutout_<sid>.hdf5 inside snap_path - confirm against putil.get
            snap_filename = snap_path+'cutout_'+str(sid)+'.hdf5'
            # Skip cutouts that already exist unless re-downloading is forced
            if os.path.isfile(snap_filename) and not force_download_cutouts:
                primary_has_data[j] = True
                _log(log_file,'Already have cutout for subhalo '+str(sid)+\
                    ' of snapshot '+str(sn),echo=False)
                continue
            # Fetch the subhalo
            try:
                subhalo = putil.get(snaps[sn]['url']+'subhalos/'+str(sid),
                    timeout=None)
            except Exception as e:
                exceptions.append((e,sn,sid))
                _log(log_file,'Exception raised for subhalo '+str(sid)+\
                    ' of snapshot '+str(sn)+' while fetching subhalo')
                continue
            # Some consistency checks
            assert sn == subhalo['snap']
            assert sid == subhalo['id']
            # Download the cutout
            _log(log_file,'Downloading subhalo '+str(sid)+' of snapshot '+\
                str(sn))
            try:
                _=putil.get(subhalo['cutouts']['subhalo'],directory=snap_path,
                    timeout=None)
            except Exception as e:
                exceptions.append((e,sn,sid))
                _log(log_file,'Exception raised for subhalo '+str(sid)+\
                    ' of snapshot '+str(sn)+' while downloading cutout')
        # Communicate if all cutouts already exist
        if np.all(primary_has_data):
            print('Already have all cutouts for primary')

    # Summarize failures; the details stay in the exceptions list
    _log(log_file,'Finished with '+str(len(exceptions))+' failed requests')

### Plot face-on (x-y) and edge-on (x-z) stellar density maps of each primary

In [None]:
# Output directory for the per-galaxy maps; savefig does not create it
analog_fig_dir = os.path.join(fig_dir,'analog_sample/')
os.makedirs(analog_fig_dir,exist_ok=True)

# Histogram settings shared by both projections
hist_range = [[-50,50],[-50,50]] # [kpc], matches the axis labels below
hist_bins = 100
log_n_lims = (1.,5.) # colour limits in log10(number of star particles)

def _plot_log_density(ax, a, b, xlabel, ylabel):
    '''Show log10 of the 2D histogram of positions (a, b) on ax and label it.'''
    N, aedges, bedges = np.histogram2d(a, b, bins=hist_bins, range=hist_range)
    # Empty pixels give log10(0) = -inf, which lies below the colour limits;
    # silence the divide-by-zero warning instead of repeating it per galaxy
    with np.errstate(divide='ignore'):
        log_N = np.log10(N.T)
    ax.imshow(log_N, origin='lower', cmap=plt.cm.gray_r,
        vmin=log_n_lims[0], vmax=log_n_lims[1],
        extent=[aedges[0], aedges[-1], bedges[0], bedges[-1]])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

for i in range(n_mw):
    print('Working on galaxy '+str(i+1)+' of '+str(n_mw))

    # Load the primary's cutout and center/rotate it. NOTE(review): presumably
    # center_and_rectify aligns z with the disc, making x-y face-on - confirm.
    snapnum = mwsubs[i]['snap']
    z0_sid = mwsubs[i]['id']
    primary_filename = putil.get_cutout_filename(mw_analog_dir,
        snapnum, z0_sid)
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    orbs = co.get_orbs(ptype='stars')
    x, y, z = orbs.x().value, orbs.y().value, orbs.z().value

    # Create the figure only once the data is loaded, so a failed load does
    # not leave an unclosed figure behind
    fig = plt.figure(figsize=(10, 5))
    axs = fig.subplots(nrows=1, ncols=2)
    _plot_log_density(axs[0], x, y, 'X [kpc]', 'Y [kpc]')
    _plot_log_density(axs[1], x, z, 'X [kpc]', 'Z [kpc]')
    axs[0].annotate(str(z0_sid), xy=(0.05, 0.95),
        xycoords='axes fraction', fontsize=15)

    fig.tight_layout()
    figname = os.path.join(analog_fig_dir,str(z0_sid)+'_stellar_density.png')
    fig.savefig(figname,bbox_inches='tight',dpi=300)
    # Close each figure to avoid accumulating memory over many galaxies
    plt.close(fig)