In [1]:
# Imports and notebook-wide configuration (plot styling and data/figure paths).
import astropy.units as u
import numpy as np
import popsims #custom library for plotting aesthetics
import matplotlib.pyplot as plt
#%matplotlib notebook
from tqdm import tqdm

#import HSS
import seaborn as sns
import matplotlib as mpl

import astropy.coordinates as astro_coord
import glob
import pandas as pd
# Plot styling: seaborn "dark" theme and larger tick labels on both axes.
sns.set_style("dark")
mpl.rc('xtick', labelsize=16) 
mpl.rc('ytick', labelsize=16) 
# NOTE(review): `font` is only defined here; nothing in the visible code passes
# it to matplotlib (e.g. mpl.rc / rcParams.update) -- confirm it is applied elsewhere.
font = {'axes.titlesize'      : 'large',   # font size of the axes title
        'axes.labelsize'      : 'large', # font size of the x and y labels
        'size'   : 20}


import itertools
from scipy import stats
from shapey import Box
# NOTE(review): duplicate of the `from scipy import stats` two lines above (harmless).
from scipy import stats
from findthegap.gapper import Gapper
import torch
# Paths. These are absolute, machine-specific locations; adjust them before
# running the notebook on another machine.
path_plot = '/users/caganze/research/stellarstreams/figures/paper/'
#path_data = '/users/caganze/research/stellarstreams/data/rotating/'
# Directory the cutout loaders below read from (they append '/<file name>').
path_data = '/users/caganze/research/stellarstreams/data/stream/'
isochrone_path='/users/caganze/research/stellarstreams/data/isochrones/'


In [None]:
def get_cutout_m31(rgc, mhalo, mag_limit):
    """
    Load a stream cutout file for the M31 case as a plain array.

    The file name is built from ``rgc``, ``mag_limit`` and ``mhalo`` (the
    latter formatted in scientific notation with two decimals) and looked up
    under the module-level ``path_data`` directory.

    Returns
    -------
    numpy.ndarray
        The ``.values`` of the table parsed by ``pandas.read_csv`` with its
        default settings.
    """
    cutout_name = '/gaps_at_M31{} mlimit {}Mhalo={:.2e}_cutout'.format(rgc, mag_limit, mhalo)
    cutout_table = pd.read_csv(path_data + cutout_name)
    return cutout_table.values
                       
    
def get_cutout_distance(mhalo, mag_limit, dmod):
    """
    Load a stream cutout file for the "OTHER" (non-M31) case as a plain array.

    The file name is built from ``dmod``, ``mhalo`` (formatted in scientific
    notation with two decimals) and ``mag_limit`` and looked up under the
    module-level ``path_data`` directory.

    Returns
    -------
    numpy.ndarray
        The ``.values`` of the table parsed by ``pandas.read_csv`` with its
        default settings.
    """
    cutout_name = '/gaps_at_OTHER{}Mhalo={:.2e}_maglimit{}_cutout.txt'.format(dmod, mhalo, mag_limit)
    cutout_table = pd.read_csv(path_data + cutout_name)
    return cutout_table.values

def detect_gap_by_boostrap(bws, data, xlims, ylims, rescale=False, n_boot=5,
                           min_bw_x=0.05, min_bw_y=0.01, rng=None):
    """
    Detect gaps by bootstrapping and taking a median, for several bandwidths.

    For every bandwidth in ``bws`` the function
      1. builds a ``Gapper(data, bw, bounds)`` and evaluates its KDE
         (``kde.score_samples``) on a regular grid,
      2. draws ``n_boot`` bootstrap resamples of ``data`` (with replacement),
         builds a ``Gapper`` on each and evaluates ``get_PiHPi`` at every grid
         point,
      3. keeps the largest eigenvalue of each PiHPi matrix and takes the
         median over the bootstrap runs.

    Parameters
    ----------
    bws : iterable
        Bandwidths handed to ``Gapper``; they are also the keys of the result.
    data : numpy.ndarray
        2-D array of points. The grid is built from the first two columns
        (presumably x and y positions -- confirm against the caller).
    xlims, ylims : sequence of two floats
        Only used to derive the number of grid nodes per axis,
        ``int((hi - lo) / min_bw)``. The grid itself spans the min/max of
        ``data`` in each column.
    rescale : bool, optional
        If True, the grid is additionally min-max rescaled to [0, 1] using
        the per-column range of ``data`` and returned under ``'grid_rescaled'``.
        This replaces a call to an un-imported ``MinMaxScaler`` that raised
        ``NameError``.
    n_boot : int, optional
        Number of bootstrap resamples per bandwidth (default 5).
    min_bw_x, min_bw_y : float, optional
        Grid spacing used to derive the number of grid nodes per axis.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness for the bootstrap. When None, the global
        ``numpy.random`` state is used.

    Returns
    -------
    dict
        ``{bw: {'density', 'max_eigen', 'meshgrid', 'PiHPi'}}`` where
        ``density`` and ``max_eigen`` are reshaped to the grid shape,
        ``meshgrid`` is the ``indexing='ij'`` mesh and ``PiHPi`` is the list
        of per-bootstrap PiHPi arrays. When ``rescale`` is True each entry
        also carries ``'grid_rescaled'``.

    Raises
    ------
    ValueError
        If ``n_boot`` is smaller than 1 or the limits yield an empty grid.
    """
    if n_boot < 1:
        raise ValueError("n_boot must be at least 1, got {}".format(n_boot))

    #Boundaries for the Gapper (if none are provided, this is the default mode)
    bounds = np.array([[np.min(data[:, d]), np.max(data[:, d])] for d in range(data.shape[1])])

    gridding_size = [int((xlims[1] - xlims[0]) / min_bw_x),
                     int((ylims[1] - ylims[0]) / min_bw_y)]
    if min(gridding_size) < 1:
        raise ValueError(
            "xlims/ylims are too narrow for the grid spacing: grid size {}".format(gridding_size))

    # NOTE: the grid spans the data bounds, not xlims/ylims.
    #could use a rectangular grid instead
    grid_linspace = [np.linspace(bounds[d][0], bounds[d][1], gridding_size[d]) for d in range(2)]

    meshgrid = np.meshgrid(*grid_linspace, indexing='ij')
    grid_data = np.hstack([xi.ravel().reshape(-1, 1) for xi in meshgrid])

    # Min-max rescaling of the grid with the per-column range of `data`
    # (numpy equivalent of fitting a default min-max scaler on `data`);
    # zero-width columns are left unscaled to avoid dividing by zero.
    grid_data_resc = None
    if rescale:
        lower = bounds[:, 0]
        span = bounds[:, 1] - bounds[:, 0]
        span = np.where(span == 0, 1.0, span)
        grid_data_resc = (grid_data - lower) / span

    # Default keeps the original behaviour (global numpy random state).
    sampler = np.random if rng is None else rng
    n_points = data.shape[0]
    point_index = np.arange(n_points)

    res = dict.fromkeys(bws)

    for bw in bws:

        #run for multiple bandwidths
        gapper_base = Gapper(data, bw, bounds)

        #compute density along the grid
        grid_density = gapper_base.kde.score_samples(torch.tensor(grid_data))

        #density matrix
        density_matr = grid_density.reshape((gridding_size[0], gridding_size[1]))

        #compute PiHPi matrix by bootstrapping
        maxeigval_PiHPi_boots = []
        PiHPi_boots = []
        for i in range(n_boot):

            ## Sample with replacement: bootstrap
            boot_indx = sampler.choice(point_index, n_points, replace=True)

            gapper_ = Gapper(data[boot_indx], bw, bounds)
            PiHPis_grid = []
            eigval_PiHPi = []

            for pt in grid_data:
                _pihpi = gapper_.get_PiHPi(pt)
                PiHPis_grid.append(_pihpi)
                # Only the eigenvalues are needed, so skip the eigenvectors.
                eigval_PiHPi.append(np.linalg.eigvalsh(_pihpi))

            PiHPis_grid, eigval_PiHPi = np.array(PiHPis_grid), np.array(eigval_PiHPi)
            maxeigval_PiHPi_boots.append(np.max(eigval_PiHPi, axis=1))
            PiHPi_boots.append(PiHPis_grid)
            print(f'Run {i} finished')

        maxeigval_PiHPi_boots = np.array(maxeigval_PiHPi_boots)
        print(maxeigval_PiHPi_boots.shape)

        #median over the bootstrap runs, reshaped onto the grid
        med_maxeigval_pihpi = np.median(maxeigval_PiHPi_boots, axis=0)
        med_maxeigval_pihpi_resh = med_maxeigval_pihpi.reshape((gridding_size[0], gridding_size[1]))

        entry = {'density': density_matr,
                 'max_eigen': med_maxeigval_pihpi_resh,
                 'meshgrid': meshgrid,
                 'PiHPi': PiHPi_boots}
        if rescale:
            entry['grid_rescaled'] = grid_data_resc
        res[bw] = entry

    return res
    

In [None]:

def run_gap_diagnostics(gap_bw, stream_bw, gap_threshold=99, stream_threshold=1, galaxy='M31', mag_limit=27.15, mhalo=5e6, rgc='50_60', ):
    """
    Run the bootstrap gap detection on an M31 cutout and summarise the result.

    Parameters
    ----------
    gap_bw : float
        Bandwidth whose max-eigenvalue map is thresholded to locate the gap.
    stream_bw : float
        Bandwidth whose max-eigenvalue map is thresholded to select the stream.
    gap_threshold : float, optional
        Percentile (0-100); grid nodes with a max eigenvalue above it are
        flagged as gap.
    stream_threshold : float, optional
        Percentile (0-100); grid nodes with a max eigenvalue below it are
        flagged as stream.
    galaxy : str, optional
        Currently unused; the M31 cutout loader is always called.
    mag_limit, mhalo, rgc
        Forwarded to ``get_cutout_m31`` to pick the input file.

    Returns
    -------
    dict
        ``'gap_loc'``     : (x, y) mean grid position of the gap nodes,
        ``'gap_size'``    : standard deviation of the x grid coordinate of
                            the gap nodes,
        ``'gap_mask'``    : boolean gap mask with the grid shape,
        ``'stream_mask'`` : flattened boolean stream mask over the grid nodes,
        ``'res'``         : the raw output of ``detect_gap_by_boostrap``.
        ``gap_loc`` / ``gap_size`` are NaN when no node exceeds the gap
        threshold.
    """
    #first read in the data
    data = get_cutout_m31(rgc, mhalo, mag_limit)

    xlims = [data[:, 0].min(), data[:, 0].max()]
    ylims = [data[:, 1].min(), data[:, 1].max()]

    # Run each distinct bandwidth once; the detection is expensive and the
    # result is keyed by bandwidth anyway.
    bandwidths = list(dict.fromkeys([gap_bw, stream_bw]))

    #detect the gap
    res = detect_gap_by_boostrap(bandwidths, data, xlims, ylims)

    #select grid nodes above the gap threshold
    gap_eigen = res[gap_bw]['max_eigen']
    gap_mask = gap_eigen > np.percentile(gap_eigen, gap_threshold)
    in_gap = gap_mask.flatten().astype(bool)

    # Flattened grid coordinates, in the same (C-order, 'ij') ordering as the
    # flattened eigenvalue map.
    grid_data = np.hstack([xi.ravel().reshape(-1, 1) for xi in res[gap_bw]['meshgrid']])

    #compute gap location and extent from the flagged grid nodes
    gap_loc = (np.nanmean(grid_data[:, 0][in_gap]),
               np.nanmean(grid_data[:, 1][in_gap]))
    # The mask lives on the grid, so it must index the grid coordinates with a
    # boolean mask (an integer 0/1 index into `data` only picked rows 0 and 1).
    gap_size = np.nanstd(grid_data[:, 0][in_gap])

    #select grid nodes in the stream
    stream_eigen = res[stream_bw]['max_eigen']
    stream_mask = (stream_eigen < np.percentile(stream_eigen, stream_threshold)).flatten()

    return {'gap_loc': gap_loc,
            'gap_size': gap_size,
            'gap_mask': gap_mask,
            'stream_mask': stream_mask,
            'res': res}