In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 2-download_secondary_cutouts.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstring:
'''Download subhalo cutouts for the secondaries of the major mergers onto the
Milky Way analogs: first the main branch of each secondary, then the full
progenitor branch of each secondary rooted where it merges with the primary.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, pdb, glob, dill as pickle, h5py

## Matplotlib
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Project-specific
# Search upward from the working directory for the project's src/ directory
# and put it on the import path so the tng_dfs package can be imported.
# Stop (with an error) once the directory being searched is the repository
# root (tng-dfs) or the filesystem root, instead of looping forever.
src_path = 'src/'
while not os.path.exists(src_path):
    # Directory currently being searched, i.e. the one expected to hold src/
    search_dir = os.path.dirname(os.path.realpath(src_path))
    if os.path.basename(search_dir) == 'tng-dfs' or \
        search_dir == os.path.dirname(search_dir):
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
# Use the discovered path rather than a hard-coded relative one
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import tree as ptree
from tng_dfs import util as putil

### Notebook setup

%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Read the project configuration and unpack the keywords used below
cdict = putil.load_config_to_dict()
keywords = [
    'DATA_DIR', 'MW_ANALOG_DIR',
    'RO', 'VO', 'ZO', 'LITTLE_H',
    'MW_MASS_RANGE',
]
(data_dir, mw_analog_dir, ro, vo, zo, h,
    mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# Milky Way analog subhalos, plus supporting variables (mwsubs_vars['sim'] is
# used later to look up the simulation's snapshots)
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Output paths: figure directories and the log directory. All are created up
# front; the download cells below open log files inside log_dir, which would
# fail if the directory did not exist.
fig_dir = './fig/sample/'
log_dir = './log/'
# NOTE(review): machine-specific absolute path; makedirs will fail on hosts
# without /epsen_data
epsen_fig_dir = '/epsen_data/scr/lane/projects/tng-dfs/figs/notebooks/sample/'
os.makedirs(fig_dir,exist_ok=True)
os.makedirs(log_dir,exist_ok=True)
os.makedirs(epsen_fig_dir,exist_ok=True)
show_plots = False

# Load the pickled merger-tree products: the primaries (one per MW analog,
# indexed like mwsubs) and their major mergers
tree_primary_filename, tree_major_mergers_filename = (
    os.path.join(mw_analog_dir,'major_mergers/tree_primaries.pkl'),
    os.path.join(mw_analog_dir,'major_mergers/tree_major_mergers.pkl'),
)
with open(tree_primary_filename,'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename,'rb') as handle:
    tree_major_mergers = pickle.load(handle)
n_mw = len(tree_primaries)

### Get some information about the simulation snapshots

In [None]:
# Fetch the simulation's snapshot records; the download cells below read each
# record's 'number' and 'url' fields
snaps = putil.get(mwsubs_vars['sim']['snapshots'])

### Download the cutouts for the main branch of each secondary

In [None]:
force_download_cutouts = False # Set True to re-download existing cutouts

# Cutouts are stored as <data_dir>mw_analogs/cutouts/snap_<N>/cutout_<sfid>.hdf5
cutout_dir = data_dir+'mw_analogs/cutouts/'

# Look up snapshot URLs by snapshot number instead of relying on the list
# index matching the snapshot number
snap_urls = {snap['number']:snap['url'] for snap in snaps}

# Prepare directory structure: one directory per snapshot, matching the
# snap_path used when checking for / downloading cutouts below
for snap in snaps:
    os.makedirs(cutout_dir+'snap_'+str(snap['number'])+'/',exist_ok=True)
os.makedirs(log_dir,exist_ok=True)

# Failures are collected as (exception, snapshot number, subfind id) so that
# one bad request does not stop the whole download
exceptions = []

log_filename = os.path.join(log_dir,
    'download_cutouts_secondary_main_branches.log')
with open(log_filename,'w') as log_file:

    # Loop over all primaries and download cutouts for their secondaries
    txt = 'Downloading cutouts for all secondary main branches...\n'+'-'*50
    print(txt)
    log_file.write(txt+'\n')
    for i in range(n_mw):

        # Load the tree, get the subfind IDs and snapshot numbers
        primary = tree_primaries[i]
        tree = ptree.SublinkTree(primary.tree_filename)
        sfids = tree.get_property('SubfindID')
        snapnums = tree.get_property('SnapNum')
        mlpids = tree.get_property('MainLeafProgenitorID')
        # The first tree entry must be this MW analog's z=0 subhalo
        assert mwsubs[i]['id'] == sfids[0], 'Consistency check failed.'
        z0_sid = sfids[0]

        txt = 'Getting cutouts for primary with z=0 subfind id: '+\
            str(z0_sid)+'...'
        print(txt)
        log_file.write(txt+'\n')

        # Loop over all secondaries and download their main branch cutouts.
        # A secondary's main branch is every subhalo in the primary's tree
        # that shares the secondary's MainLeafProgenitorID.
        n_major_mergers = primary.n_major_mergers
        for j in range(n_major_mergers):
            secondary = primary.tree_major_mergers[j]
            secondary_mask = (mlpids == secondary.secondary_mlpid)
            secondary_snapnums = snapnums[secondary_mask]
            secondary_sfids = sfids[secondary_mask]
            if len(secondary_sfids) == 0:
                txt = 'No subhalos found in tree for secondary with MLPID '+\
                    str(secondary.secondary_mlpid)
                print(txt)
                log_file.write(txt+'\n')
                continue
            secondary_has_data = np.zeros(len(secondary_sfids),dtype=bool)

            # Loop over all secondary snapshots
            for k,(ssn,sfid) in enumerate(zip(secondary_snapnums,
                                              secondary_sfids)):
                snap_path = cutout_dir+'snap_'+str(ssn)+'/'
                snap_filename = snap_path+'cutout_'+str(sfid)+'.hdf5'

                # Skip cutouts that already exist unless forcing a download
                if os.path.isfile(snap_filename) and \
                    not force_download_cutouts:
                    secondary_has_data[k] = True
                    txt = 'Already have cutout for subhalo '+str(sfid)+\
                        ' of snapshot '+str(ssn)
                    log_file.write(txt+'\n')
                    continue

                # Fetch the subhalo
                try:
                    subhalo = putil.get(snap_urls[ssn]+'subhalos/'+str(sfid),
                        timeout=None)
                except Exception as e:
                    exceptions.append((e,ssn,sfid))
                    txt = 'Exception raised for subhalo '+str(sfid)+\
                        ' of snapshot '+str(ssn)+\
                        ' while fetching subhalo information'
                    print(txt)
                    log_file.write(txt+'\n')
                    continue

                # Some consistency checks
                assert ssn == subhalo['snap'], 'Snapshot number mismatch.'
                assert sfid == subhalo['id'], 'Subfind ID mismatch.'

                # Download the cutout (return value not needed; avoid
                # assigning to `_`, which shadows IPython's output history)
                txt = 'Downloading subhalo '+str(sfid)+' of snapshot '+str(ssn)
                print(txt)
                log_file.write(txt+'\n')
                try:
                    putil.get(subhalo['cutouts']['subhalo'],
                        directory=snap_path,timeout=None)
                except Exception as e:
                    exceptions.append((e,ssn,sfid))
                    txt = 'Exception raised for subhalo '+str(sfid)+\
                        ' of snapshot '+str(ssn)+' while downloading cutout'
                    print(txt)
                    log_file.write(txt+'\n')

            # Communicate if all cutouts already exist
            if np.all(secondary_has_data):
                print('Already have all cutouts for secondary')

    # Summarize failures so they are not silently lost
    txt = str(len(exceptions))+' exception(s) raised, see `exceptions`'
    print(txt)
    log_file.write(txt+'\n')

### Download the cutouts for the entire secondary branch rooted at the point where the secondary merges with the primary

In [None]:
force_download_cutouts = False # Set True to re-download existing cutouts

# Only the first max_primaries primaries are processed (a development limit);
# set to None to process every primary
max_primaries = 11
n_primaries = n_mw if max_primaries is None else min(n_mw,max_primaries)

# Cutouts are stored as <data_dir>mw_analogs/cutouts/snap_<N>/cutout_<sfid>.hdf5
cutout_dir = data_dir+'mw_analogs/cutouts/'

# Look up snapshot URLs by snapshot number instead of relying on the list
# index matching the snapshot number
snap_urls = {snap['number']:snap['url'] for snap in snaps}

# Prepare directory structure: one directory per snapshot, matching the
# snap_path used when checking for / downloading cutouts below
for snap in snaps:
    os.makedirs(cutout_dir+'snap_'+str(snap['number'])+'/',exist_ok=True)
os.makedirs(log_dir,exist_ok=True)

# Failures are collected as (exception, snapshot number, subfind id) so that
# one bad request does not stop the whole download
exceptions = []

log_filename = os.path.join(log_dir,'download_cutouts_secondary_all.log')
with open(log_filename,'w') as log_file:

    # Loop over the primaries and download cutouts for their secondaries
    txt = 'Downloading cutouts for all secondary branches...\n'+'-'*50
    print(txt)
    log_file.write(txt+'\n')
    for i in range(n_primaries):

        # Load the tree, get the IDs and snapshot numbers
        primary = tree_primaries[i]
        tree = ptree.SublinkTree(primary.tree_filename)
        sids = tree.get_property('SubhaloID')
        sfids = tree.get_property('SubfindID')
        snapnums = tree.get_property('SnapNum')
        mlpids = tree.get_property('MainLeafProgenitorID')
        lpids = tree.get_property('LastProgenitorID')
        # The first tree entry must be this MW analog's z=0 subhalo
        assert mwsubs[i]['id'] == sfids[0], 'Consistency check failed.'
        z0_sid = sfids[0]

        txt = 'Getting cutouts for primary with z=0 subfind id: '+\
            str(z0_sid)+'...'
        print(txt)
        log_file.write(txt+'\n')

        # Loop over all secondaries and download the cutouts for every
        # subhalo in their progenitor branch
        n_major_mergers = primary.n_major_mergers
        for j in range(n_major_mergers):
            secondary = primary.tree_major_mergers[j]
            secondary_main_mask = (mlpids == secondary.secondary_mlpid)
            if not np.any(secondary_main_mask):
                txt = 'No subhalos found in tree for secondary with MLPID '+\
                    str(secondary.secondary_mlpid)
                print(txt)
                log_file.write(txt+'\n')
                continue

            # Root of the secondary branch: the main-branch subhalo with the
            # smallest SubhaloID (presumably its last snapshot before merging,
            # since SubLink assigns IDs depth-first). Its progenitors are the
            # subhalos with SubhaloID in [root SubhaloID, root LastProgenitorID].
            # Taking the argmin does not rely on the tree's storage order.
            main_idx = np.flatnonzero(secondary_main_mask)
            root_idx = main_idx[np.argmin(sids[main_idx])]
            secondary_sid = sids[root_idx]
            secondary_lpid = lpids[root_idx]
            secondary_branch_mask = (sids >= secondary_sid) & \
                (sids <= secondary_lpid)
            secondary_branch_snapnums = snapnums[secondary_branch_mask]
            secondary_branch_sfids = sfids[secondary_branch_mask]
            # Size by what is actually present in the tree, not by the nominal
            # ID range, so the arrays indexed below are never overrun
            n_branch = len(secondary_branch_sfids)
            secondary_has_data = np.zeros(n_branch,dtype=bool)

            txt = 'Getting '+str(n_branch)+' cutouts for '+\
                'secondary with MLPID '+str(secondary.secondary_mlpid)+'...'
            print(txt)
            log_file.write(txt+'\n')

            # Loop over all subhalos in the secondary branch
            for k,(ssn,sfid) in enumerate(zip(secondary_branch_snapnums,
                                              secondary_branch_sfids)):
                snap_path = cutout_dir+'snap_'+str(ssn)+'/'
                snap_filename = snap_path+'cutout_'+str(sfid)+'.hdf5'

                # Skip cutouts that already exist unless forcing a download
                if os.path.isfile(snap_filename) and \
                    not force_download_cutouts:
                    secondary_has_data[k] = True
                    txt = 'Already have cutout for subhalo '+str(sfid)+\
                        ' of snapshot '+str(ssn)
                    log_file.write(txt+'\n')
                    continue

                # Fetch the subhalo
                try:
                    subhalo = putil.get(snap_urls[ssn]+'subhalos/'+str(sfid),
                        timeout=None)
                except Exception as e:
                    exceptions.append((e,ssn,sfid))
                    txt = 'Exception raised for subhalo '+str(sfid)+\
                        ' of snapshot '+str(ssn)+\
                        ' while fetching subhalo information'
                    print(txt)
                    log_file.write(txt+'\n')
                    continue

                # Some consistency checks
                assert ssn == subhalo['snap'], 'Snapshot number mismatch.'
                assert sfid == subhalo['id'], 'Subfind ID mismatch.'

                # Download the cutout (return value not needed; avoid
                # assigning to `_`, which shadows IPython's output history)
                txt = 'Downloading subhalo '+str(sfid)+' of snapshot '+str(ssn)
                print(txt)
                log_file.write(txt+'\n')
                try:
                    putil.get(subhalo['cutouts']['subhalo'],
                        directory=snap_path,timeout=None)
                except Exception as e:
                    exceptions.append((e,ssn,sfid))
                    txt = 'Exception raised for subhalo '+str(sfid)+\
                        ' of snapshot '+str(ssn)+' while downloading cutout'
                    print(txt)
                    log_file.write(txt+'\n')

            # Communicate if all cutouts already exist
            if np.all(secondary_has_data):
                print('Already have all cutouts for secondary')

    # Summarize failures so they are not silently lost
    txt = str(len(exceptions))+' exception(s) raised, see `exceptions`'
    print(txt)
    log_file.write(txt+'\n')