In [1]:
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np
from scipy import spatial
from astropy.table import Table
from astropy.cosmology import LambdaCDM as Cos
from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy import units as u
import matplotlib.pyplot as plt
from IPython import display 
from scipy import stats
from scipy.interpolate import interp1d
from scipy.stats import norm
import fitsio

import pickle
import dask

In [2]:
# Start a local Dask cluster and attach a client to it. While this client is
# alive, the dask.compute() calls further down are executed on its workers.
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=9)
client = Client(cluster)
client  # last expression: rich repr with scheduler address and dashboard link

0,1
Client  Scheduler: tcp://127.0.0.1:42767  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 9  Cores: 18  Memory: 67.31 GB


In [3]:
# Load the per-sweep centre table; its first 437 rows are the southern sweeps,
# the remainder the northern ones.
table_of_centers = pd.read_csv("local_table_of_centers.csv")
N_SOUTH_SWEEPS = 437
table_of_centers_south = table_of_centers.iloc[:N_SOUTH_SWEEPS]
table_of_centers_north = table_of_centers.iloc[N_SOUTH_SWEEPS:]

# NOTE(review): not read by the cells shown; cluster_finder uses its own local `clusters`.
clusters = []

# Subset of sweeps used for test runs of the pipeline.
testing_centers = pd.read_csv("testing_sweeps2.csv")
# testing_centers = testing_centers[0:1]

In [None]:
#np.seterr(divide = 'ignore', invalid = "ignore") 
#np.seterr(divide = 'warn', invalid = 'warn')

In [4]:
# Finding neighbouring sweeps: a ball tree over the southern sweep centres.
# kneighbors() on it returns the 9 centres closest to a query point, using the
# default Euclidean metric directly on (RA, Dec) in degrees (no cos(Dec) scaling,
# no RA wrap-around at 0/360).
sweep_center_coords = table_of_centers_south[["mean_RA", "mean_DEC"]]
nbrs = NearestNeighbors(n_neighbors=9, algorithm="ball_tree").fit(sweep_center_coords)

#Mass fitting parameters and equations
# The functions below read these descriptive constants rather than the one-letter
# globals, so a later notebook cell that happens to rebind `a`, `b`, `j`, `k` or
# `l` (e.g. a loop variable) can no longer silently change the mass cuts.
MASS_LIMIT_SLOPE = 1.3620186928378857
MASS_LIMIT_INTERCEPT = 9.968545069745126
MASS_LIMIT_CAP = 11.2
MASS_CORR_Z2 = 1.04935943
MASS_CORR_Z1 = 0.39573094
MASS_CORR_Z0 = 0.28347756

# Short aliases kept for backward compatibility with any code that reads them.
a = MASS_LIMIT_SLOPE
b = MASS_LIMIT_INTERCEPT
j = MASS_CORR_Z2
k = MASS_CORR_Z1
l = MASS_CORR_Z0


def mass_limit(z):
    """Return the log10 mass threshold at redshift ``z``.

    Linear in ``z`` (slope MASS_LIMIT_SLOPE, intercept MASS_LIMIT_INTERCEPT) and
    capped at MASS_LIMIT_CAP (11.2). Works on scalars and NumPy arrays.
    """
    return np.minimum(MASS_LIMIT_SLOPE * z + MASS_LIMIT_INTERCEPT, MASS_LIMIT_CAP)


def mass_coefficient(z):
    """Return the multiplicative mass-correction factor exp(j*z**2 + k*z + l).

    Works on scalars and NumPy arrays.
    """
    return np.exp(MASS_CORR_Z2 * z**2 + MASS_CORR_Z1 * z + MASS_CORR_Z0)

# Radii (in units of the 1-Mpc aperture). NOTE(review): cluster_finder above
# hard-codes its 0.5 and 0.1 factors and does not read these names.
radius, small_radius, mini_radius = 1, 0.5, 0.1

# Padding (degrees) added around each sweep when importing its neighbours,
# taken from interpolating the maximum cluster radius at z = 0.05.
buffer = 0.285

In [5]:
@dask.delayed
def data_import(maxx, maxy, minx, miny, row2):
    """Load one sweep's catalogues and cut them to a padded box around a target.

    Parameters
    ----------
    maxx, maxy, minx, miny : float
        RA/Dec extent (degrees) of the target sweep.
    row2 : row of the table of centres
        Provides the ``patch`` (sweep FITS), ``photoz`` (photo-z FITS) and
        ``masses`` (``.npy``) file paths.

    Returns
    -------
    pandas.DataFrame
        The photo-z table with a ``mass`` column and the selected sweep columns
        attached, keeping only rows strictly inside the target box padded by the
        module-level ``buffer`` (degrees; the RA padding is widened by 1/cos(Dec)).

    NOTE(review): the three files are combined row-by-row, so this assumes they
    list the same objects in the same order -- confirm upstream. RA wrap-around
    at 0/360 is not handled.
    """
    # Sweep column -> name in the returned frame (insertion order is preserved).
    column_map = {
        'RELEASE': 'RELEASE',
        'BRICKID': 'BRICKID',
        'BRICKNAME': 'BRICKNAME',
        'OBJID': 'OBJID',
        'TYPE': 'TYPE',
        'RA': 'RA',
        'DEC': 'DEC',
        'FLUX_G': 'FLUX_G',
        'FLUX_R': 'FLUX_R',
        'FLUX_Z': 'FLUX_Z',
        'FLUX_W1': 'FLUX_W1',
        'MASKBITS': 'MASKBITS',
        'GAIA_PHOT_G_MEAN_MAG': 'gaia_phot_g_mean_mag',
        'GAIA_ASTROMETRIC_EXCESS_NOISE': 'gaia_astrometric_excess_noise',
    }

    # Close the sweep file once read; previously the handle was left open,
    # leaking a descriptor per task on long-lived dask workers.
    with fitsio.FITS(row2.patch) as fits_data:
        sweep = fits_data[1].read(columns=list(column_map))
    with fits.open(row2.photoz) as data:
        pz = pd.DataFrame(data[1].data)

    pz['mass'] = np.load(row2.masses)
    for sweep_col, out_col in column_map.items():
        pz[out_col] = sweep[sweep_col]

    # RA padding grows as 1/cos(Dec) so the buffer is roughly constant on the sky.
    ra_pad = buffer / np.cos(pz.DEC * (np.pi / 180))
    in_box = np.logical_and.reduce((
        pz.RA < maxx + ra_pad,
        pz.DEC < maxy + buffer,
        pz.RA > minx - ra_pad,
        pz.DEC > miny - buffer,
    ))
    return pz[in_box]

In [6]:
#@dask.delayed
def cluster_finder(massive_sample, indexable, maxRA, maxDEC, minRA, minDEC):
    """Find galaxy-cluster candidates around massive galaxies and assign members.

    Parameters
    ----------
    massive_sample : pandas.DataFrame
        Candidate central galaxies with a 0..n-1 index. Read columns include x, y,
        z_phot_median, z_phot_std, mass, gid, RA, DEC, RELEASE, BRICKID, OBJID,
        MASKBITS. The frame is not modified; statistics go into an internal copy.
    indexable : numpy.ndarray
        Object array of all galaxies in the padded region, columns by position:
        0 z_phot_median, 1 x, 2 y, 3 mass (log10), 4 gauss_z (array of photo-z
        draws), 5 gid, 6 z_phot_std (the order the main loop builds). x/y are
        assumed to be in radians, as the main loop constructs them.
    maxRA, maxDEC, minRA, minDEC : float
        Sweep extent in degrees; only cluster centres strictly inside it are
        returned (presumably so padded overlaps are not reported twice).

    Returns
    -------
    clusters_final : pandas.DataFrame
        One row per cluster centre inside the sweep bounds.
    memberspd : pandas.DataFrame
        Columns galaxy, cluster, galaxy_z, cluster_z, galaxy_z_std, prob.
    """
    # Copy so the caller's frame is not mutated by the per-galaxy writes below.
    sample = massive_sample.copy()

    z_all = indexable[:, 0].astype(float)
    z_max = z_all.max()

    #Cosmological Parameters: radius and cylinder length
    cosmo = Cos(H0 = 70, Om0 = .286, Ode0 = .714)
    # (1 + z) / D_C = 1 / D_A: the angle in radians subtended by 1 proper Mpc.
    z_grid = np.linspace(1e-2, z_max, 500)
    mpc_angle = (1 + z_grid) / cosmo.comoving_distance(z_grid).value
    radius_threshold = interp1d(z_grid, mpc_angle, kind = "linear", fill_value = "extrapolate")

    # Cylinder half-length: median photo-z std per redshift bin, capped at 0.1,
    # interpolated at the *bin centres* (the old 99-point grid did not line up
    # with the bins). Empty bins give NaN medians; dropping them keeps NaN out of
    # the threshold, which would otherwise zero every weight.
    binned_std = stats.binned_statistic(z_all, indexable[:, -1].astype(float), "median",
                                        bins = np.linspace(0.05, z_max, 100))
    edges = binned_std.bin_edges
    bin_centres = 0.5 * (edges[:-1] + edges[1:])
    capped_std = np.minimum(binned_std.statistic, 0.1)
    filled = np.isfinite(capped_std)
    z_threshold = interp1d(bin_centres[filled], capped_std[filled], kind = "linear", fill_value = "extrapolate")

    z_stat_columns = ("z_average_no_wt", "z_average_prob", "z_average_mass_prob",
                      "z_std_no_wt", "z_std_prob", "z_std_mass_prob")

    #Tree Algorithm: per-central neighbour counts, masses and redshift stats
    tree = spatial.cKDTree(indexable[:, 1:3].astype(float), copy_data = True)
    for i, row in massive_sample.iterrows():
        r_mpc = float(radius_threshold(row.z_phot_median))
        neighbors = tree.query_ball_point([row.x, row.y], r_mpc)
        if len(neighbors) == 0:
            continue
        local_data = indexable[neighbors]

        # Weight = fraction of each galaxy's photo-z draws inside +-2 z_c of the
        # central's redshift. Averaging over the draw axis uses the actual number
        # of draws instead of a global `oversample`.
        z_c = z_threshold(row.z_phot_median)
        draws = np.vstack(local_data[:, 4]).astype(float)
        weights = (np.abs(draws - row.z_phot_median) < 2*z_c).mean(axis = 1)

        keep = weights > 0
        cluster = local_data[keep]
        w = weights[keep]
        z = cluster[:, 0].astype(float)
        mass = cluster[:, 3].astype(float)
        dist = np.sqrt((cluster[:, 1].astype(float) - row.x)**2 + (cluster[:, 2].astype(float) - row.y)**2)
        small = dist < 0.5*r_mpc
        mini = dist < 0.1*r_mpc

        if len(cluster):
            sample.at[i, "z_average_no_wt"] = np.mean(z)
            sample.at[i, "z_average_prob"] = np.average(z, weights = w)
            sample.at[i, "z_average_mass_prob"] = np.average(z, weights = w*mass)
            sample.at[i, "z_std_no_wt"] = np.std(z)
            sample.at[i, "z_std_prob"] = float(np.sqrt(np.cov(z, aweights = w)))
            sample.at[i, "z_std_mass_prob"] = float(np.sqrt(np.cov(z, aweights = w*mass)))
        else:
            # No galaxy (not even the central) falls in the cylinder: np.average
            # would raise ZeroDivisionError, so record the stats as undefined.
            for col in z_stat_columns:
                sample.at[i, col] = np.nan

        sample.at[i, "neighbors"] = w.sum()
        sample.at[i, "local_neighbors"] = w[small].sum()
        sample.at[i, "ultra_local_neighbors"] = w[mini].sum()

        # NOTE(review): the central is also in `indexable`, so when its own
        # weight is > 0 its mass enters both the weighted sum and `central`.
        mass_co = mass_coefficient(row.z_phot_median)
        sample.at[i, "correction_factor"] = mass_co
        central = 10**row.mass
        weighted_mass = (10**mass)*w
        above_limit = mass > mass_limit(row.z_phot_median)
        local_total = weighted_mass[small].sum() + central
        ultra_total = weighted_mass[mini].sum() + central
        sample.at[i, "neighbor_mass"] = np.log10((weighted_mass[above_limit].sum() + central)*mass_co)
        sample.at[i, "local_neighbor_mass"] = np.log10(local_total)
        sample.at[i, "ultra_local_neighbor_mass"] = np.log10(ultra_total)
        sample.at[i, "corr_local_neighbor_mass"] = np.log10(local_total*mass_co)
        sample.at[i, "corr_ultra_local_neighbor_mass"] = np.log10(ultra_total*mass_co)

        # Candidate membership rows: (galaxy gid, central gid, galaxy z,
        # central z, galaxy z std) for every galaxy inside the aperture.
        n_local = len(local_data)
        sample.at[i, "neighbor_gids"] = np.column_stack((
            local_data[:, -2].astype(float),
            np.full(n_local, row.gid, dtype = float),
            local_data[:, 0].astype(float),
            np.full(n_local, row.z_phot_median, dtype = float),
            local_data[:, -1].astype(float),
        ))

    #Thresholding: keep centrals whose counts exceed mean + c*sqrt(mean) of
    #their redshift slice (window +-0.025 around 0.01-spaced centres)
    z_centres = np.arange(0.05, sample.z_phot_median.max(), 0.01)
    threshold1 = np.empty(len(z_centres))
    threshold2 = np.empty(len(z_centres))
    for n, zc in enumerate(z_centres):
        window = sample[np.logical_and(sample.z_phot_median >= zc - .025, sample.z_phot_median <= zc + .025)]
        mean_all = np.mean(window.neighbors)
        mean_local = np.mean(window.local_neighbors)
        threshold1[n] = mean_all + 1.8*np.sqrt(mean_all)
        threshold2[n] = mean_local + 1.2*np.sqrt(mean_local)
    thresh1 = interp1d(z_centres, threshold1, kind = "linear", fill_value = "extrapolate")
    thresh2 = interp1d(z_centres, threshold2, kind = "linear", fill_value = "extrapolate")
    passes = np.logical_and(sample.neighbors >= thresh1(sample.z_phot_median),
                            sample.local_neighbors >= thresh2(sample.z_phot_median))
    clusters = sample[passes].copy()
    clusters = clusters.sort_values("local_neighbor_mass", ascending = False).reset_index(drop = True)

    #Aggregation: walking from the highest local_neighbor_mass down, each
    #unassigned candidate seeds a group and absorbs unassigned candidates within
    #1.5x its aperture and 2 z_c in redshift; their memberships are appended.
    clusters["ncluster"] = np.zeros(len(clusters))
    if len(clusters):
        tree = spatial.cKDTree(clusters[["x", "y"]], copy_data = True)
        clusternum = 1
        for i, row in clusters.copy().iterrows():
            if clusters.at[i, "ncluster"] != 0:
                continue
            clusters.at[i, "ncluster"] = clusternum
            z_window = 2*z_threshold(row.z_phot_median)
            for index in tree.query_ball_point([row.x, row.y], 1.5*radius_threshold(row.z_phot_median)):
                same_slice = np.abs(clusters.at[index, "z_phot_median"] - row.z_phot_median) < z_window
                if clusters.at[index, "ncluster"] == 0 and same_slice:
                    clusters.at[index, "ncluster"] = clusternum
                    clusters.at[i, "neighbor_gids"] = np.concatenate(
                        (clusters.at[i, "neighbor_gids"], clusters.at[index, "neighbor_gids"]), axis = 0)
            clusternum += 1

    #Results: one centre per group (highest ultra_local_neighbor_mass), inside the sweep
    cluster_center = clusters.sort_values(by = ['ncluster', 'ultra_local_neighbor_mass'],
                                          ascending = [True, False]).groupby('ncluster').head(1).copy()
    inside = np.logical_and.reduce((cluster_center.RA < maxRA, cluster_center.RA > minRA,
                                    cluster_center.DEC < maxDEC, cluster_center.DEC > minDEC))
    cluster_center_selected = cluster_center[inside].copy()

    #Membership: one concatenation instead of growing an array in a loop
    blocks = list(cluster_center_selected.neighbor_gids.values)
    membership_data = np.concatenate(blocks, axis = 0) if blocks else np.empty((0, 5))
    membershippd = pd.DataFrame(membership_data, columns = ["galaxy", "cluster", "galaxy_z", "cluster_z", "galaxy_z_std"], dtype = float)
    membershippd["z_dist"] = np.abs(membershippd.galaxy_z - membershippd.cluster_z)
    # A galaxy claimed by several clusters goes to the one closest in redshift.
    membershippd = membershippd.sort_values("z_dist", ascending = True).drop_duplicates(subset = "galaxy").reset_index(drop = True)
    # Two-sided Gaussian tail probability of the redshift offset; 0.0027 is the
    # two-sided 3-sigma level.
    membershippd["prob"] = 2*norm.sf(membershippd.z_dist, loc = 0, scale = membershippd.galaxy_z_std)
    memberspd = membershippd[membershippd.prob > 0.0027].astype({"galaxy": "int64", "cluster": "int64"}).drop(columns = ["z_dist"])

    #Cleaning things up: select and rename the output columns
    output_columns = {
        "RA": "RA_central",
        "DEC": "DEC_central",
        "z_phot_median": "z_median_central",
        "z_average_no_wt": "z_average_no_wt",
        "z_average_prob": "z_average_prob",
        "z_average_mass_prob": "z_average_mass_prob",
        "z_phot_std": "z_std_central",
        "z_std_no_wt": "z_std_no_wt",
        "z_std_prob": "z_std_prob",
        "z_std_mass_prob": "z_std_mass_prob",
        "RELEASE": "RELEASE",
        "BRICKID": "BRICKID",
        "OBJID": "OBJID",
        "MASKBITS": "MASKBITS",
        "gid": "gid",
        "mass": "mass_central",
        "neighbor_mass": "neighbor_mass",
        "corr_local_neighbor_mass": "local_neighbor_mass",
        "corr_ultra_local_neighbor_mass": "ultra_local_neighbor_mass",
        "correction_factor": "correction_factor",
        "neighbors": "neighbors",
        "local_neighbors": "local_neighbors",
        "ultra_local_neighbors": "ultra_local_neighbors",
    }
    clusters_final = cluster_center_selected[list(output_columns)].rename(columns = output_columns)

    return clusters_final, memberspd

In [15]:
# Display the test sweeps (file paths and centres) that the main loop below iterates over.
testing_centers

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Unnamed: 0.1.1,patch,mean_RA,mean_DEC,photoz,masses
0,0,42.0,42.0,/data/mjb299/sweep/sweep-190p030-200p035.fits,195.206692,31.979525,/data/mjb299/pz/sweep-190p030-200p035-pz.fits,/data/mjb299/pz/sweep-190p030-200p035_stellar_...
1,1,92.0,92.0,/data/mjb299/sweep/sweep-170p025-180p030.fits,175.036076,27.486449,/data/mjb299/pz/sweep-170p025-180p030-pz.fits,/data/mjb299/pz/sweep-170p025-180p030_stellar_...
2,2,216.0,216.0,/data/mjb299/sweep/sweep-180p030-190p035.fits,184.901037,32.099699,/data/mjb299/pz/sweep-180p030-190p035-pz.fits,/data/mjb299/pz/sweep-180p030-190p035_stellar_...
3,3,300.0,300.0,/data/mjb299/sweep/sweep-180p025-190p030.fits,184.882132,27.488065,/data/mjb299/pz/sweep-180p025-190p030-pz.fits,/data/mjb299/pz/sweep-180p025-190p030_stellar_...
4,4,316.0,316.0,/data/mjb299/sweep/sweep-170p030-180p035.fits,175.040271,32.140876,/data/mjb299/pz/sweep-170p030-180p035-pz.fits,/data/mjb299/pz/sweep-170p030-180p035_stellar_...
5,5,394.0,394.0,/data/mjb299/sweep/sweep-190p025-200p030.fits,194.986337,27.503652,/data/mjb299/pz/sweep-190p025-200p030-pz.fits,/data/mjb299/pz/sweep-190p025-200p030_stellar_...
6,6,467.0,467.0,/data/mjb299/sweep-north/sweep-190p030-200p035...,194.931985,33.023176,/data/mjb299/pz-north/sweep-190p030-200p035-pz...,/data/mjb299/pz-north/sweep-190p030-200p035_st...
7,7,496.0,496.0,/data/mjb299/sweep-north/sweep-170p025-180p030...,174.629494,29.651472,/data/mjb299/pz-north/sweep-170p025-180p030-pz...,/data/mjb299/pz-north/sweep-170p025-180p030_st...
8,8,573.0,573.0,/data/mjb299/sweep-north/sweep-180p030-190p035...,184.932171,32.894182,/data/mjb299/pz-north/sweep-180p030-190p035-pz...,/data/mjb299/pz-north/sweep-180p030-190p035_st...
9,9,631.0,631.0,/data/mjb299/sweep-north/sweep-180p025-190p030...,184.702577,29.64062,/data/mjb299/pz-north/sweep-180p025-190p030-pz...,/data/mjb299/pz-north/sweep-180p025-190p030_st...


In [7]:
# Main driver (test run). For each target sweep:
#   1. read its RA/Dec extent;
#   2. load the 9 nearest southern sweeps in parallel (padded by `buffer`);
#   3. apply the photometric and quality cuts;
#   4. draw Monte-Carlo photo-z samples;
#   5. run cluster_finder on the galaxies with mass > 11.2.
# NOTE(review): the bar is sized for all of testing_centers, but the loop runs
# only testing_centers[0:1]. Also, the `pbar.progress` update sits inside the
# disabled string at the end of the loop, so the bar never advances.
pbar = display.ProgressBar(len(testing_centers))
pbar.display()

# Results buffer for batched writing (used only by the disabled code below).
delayed_results = []

#Remember to change testing_centers to table_of_centers
for index, row in testing_centers[0:1].iterrows():
    # NOTE(review): this FITS handle is never closed; `with fitsio.FITS(...) as f:` would release it.
    fits_data = fitsio.FITS(row.patch)
    sweep = fits_data[1].read(columns=['RA', 'DEC'])
    print(len(sweep))
    # Extent of the target sweep, passed to data_import to build the padded import box.
    # NOTE(review): builtin max/min iterate element-by-element; np.max/np.min are much faster.
    maxx = max(sweep['RA'])
    maxy = max(sweep['DEC'])
    minx = min(sweep['RA'])
    miny = min(sweep['DEC'])
    
    # Same extent again; cluster_finder keeps only cluster centres strictly inside it.
    maxRA = max(sweep['RA'])
    maxDEC = max(sweep['DEC'])
    minRA = min(sweep['RA'])
    minDEC = min(sweep['DEC'])
    
    list_of_imports = []
    #Neighbors:
    # NOTE(review): this always uses the southern tree and table (nbrs,
    # table_of_centers_south), although testing_centers also lists sweep-north
    # patches. Confirm how north targets are meant to be handled.
    distances, indices = nbrs.kneighbors([row[["mean_RA", "mean_DEC"]]])
    patches = table_of_centers_south.iloc[indices[0]]
    print(patches)
    # One delayed load per neighbouring sweep; dask.compute runs them in parallel.
    for index2, row2 in patches.iterrows():
        delayed_import = data_import(maxx, maxy, minx, miny, row2)
        list_of_imports.append(delayed_import)
    
    imports = dask.compute(*list_of_imports)
    ra_dec = pd.concat(imports)
    print(len(ra_dec))
    #Initial sample cuts
    # z-band magnitude (22.5 zero point, presumably nanomaggy fluxes). Non-finite
    # values (from non-positive flux) are set to 99 so they fail the z < 21 cut.
    zmag=np.array(22.5-2.5*np.log10(ra_dec.FLUX_Z))
    zmag[np.where(~np.isfinite(zmag))]=99.
    #whgood=np.where(np.logical_and(zmag < 21,ra_dec.mass > 0 ))
    # Keep z < 21 and positive mass.
    isgood=np.logical_and(zmag < 21,ra_dec.mass > 0 )
    ra_dec = ra_dec[isgood]
    print(len(ra_dec))
    
    #Further sample cuts
    # Keep MASKBITS == 0 or exactly 4096 (only bit 12 set).
    ra_dec = ra_dec[np.logical_or(ra_dec.MASKBITS == 0, ra_dec.MASKBITS == 4096)]
    # Keep objects that satisfy any of:
    #   - Gaia G > 19;
    #   - astrometric excess noise > 10**0.5;
    #   - excess noise == 0 (presumably no Gaia match -- confirm).
    ra_dec = ra_dec[np.logical_or(np.logical_or(ra_dec.gaia_phot_g_mean_mag > 19, ra_dec.gaia_astrometric_excess_noise > 10**.5), ra_dec.gaia_astrometric_excess_noise==0)]
    # NOTE(review): adding columns to a filtered frame can trigger pandas'
    # SettingWithCopyWarning; a .copy() after the last cut would avoid it.
    ra_dec["magR"] = 22.5-2.5*np.log10(ra_dec.FLUX_R)
    ra_dec["magZ"] = 22.5-2.5*np.log10(ra_dec.FLUX_Z)
    ra_dec["magW1"] = 22.5-2.5*np.log10(ra_dec.FLUX_W1)
    # NOTE(review): l_mask is computed but never applied to ra_dec. In addition,
    # np.isfinite on a boolean Series is always True, so the next line is a
    # no-op. Confirm whether this colour cut was meant to filter the sample.
    l_mask = (ra_dec.magR - ra_dec.magW1) > 1.8*(ra_dec.magR-ra_dec.magZ)-0.6
    l_mask[~np.isfinite(l_mask)] = False
    # Repeats the z < 21 cut already applied above and adds z_phot_median > 0.01.
    ra_dec = ra_dec[np.logical_and(22.5 - 2.5*np.log10(ra_dec.FLUX_Z)<21, ra_dec.z_phot_median>0.01)]
    
    #Coordinates
    ra_dec["RA_r"] = (np.pi/180)*ra_dec["RA"]
    ra_dec["DEC_r"] = (np.pi/180)*ra_dec["DEC"]
    # Galaxy id built from RA and Dec, each rounded to 6 decimals.
    # NOTE(review): for RA above about 1 degree the float64 sum exceeds 2**53, so
    # the low-order Dec digits are rounded away. Nearby galaxies can then share a
    # gid, and cluster_finder de-duplicates members by gid.
    ra_dec["gid"] = np.round(ra_dec.RA, 6)*10**16 + np.round(ra_dec.DEC + 90, 6)*10**6
    
    #Oversampling
    # Reset to 0..n-1 so the pd.Series built from `gauss` below aligns row-for-row.
    ra_dec.reset_index(inplace = True, drop = True)
    # Number of Monte-Carlo photo-z draws per galaxy (columns of `gauss`).
    # NOTE(review): cluster_finder's membership weights must be normalised by this same count.
    oversample = 30
    over = np.array([ra_dec.z_phot_median.values]).T*np.ones((len(ra_dec), oversample))
    sigma = np.array([ra_dec.z_phot_std.values]).T*np.ones((len(ra_dec), oversample))
    # NOTE(review): unseeded global RNG, so results are not reproducible between runs.
    random = np.random.normal(loc = 0, scale = 1, size = (len(ra_dec), oversample))
    # Draws ~ N(z_phot_median, z_phot_std): one array of `oversample` values per galaxy.
    gauss = over + sigma*random
    ra_dec["gauss_z"] = pd.Series(list(gauss))
    
    #Coordinate transform to prevent zeros
    # Offsets (radians) from the mean position, RA scaled by cos(Dec), shifted by +50.
    ra_dec["y"] = ra_dec["DEC_r"] - np.mean(ra_dec["DEC_r"]) + 50
    ra_dec["x"] = (ra_dec["RA_r"] - np.mean(ra_dec["RA_r"]))*np.cos(ra_dec["DEC_r"]) + 50
    #ra_dec = ra_dec[ra_dec.z_phot_median < 1.5].copy()
    #Creating array for indexing
    # cluster_finder reads these columns by position:
    #   0 z_phot_median, 1 x, 2 y, 3 mass, 4 gauss_z, 5 gid, 6 z_phot_std.
    # gauss_z holds arrays, so `.values` yields an object array.
    indexable = ra_dec[["z_phot_median", "x", "y", "mass", "gauss_z", "gid", "z_phot_std"]].values.copy()
    
    #Creating massive sample
    # Candidate centrals (mass > 11.2), with the per-galaxy statistic columns that
    # cluster_finder writes pre-created.
    # NOTE(review): "corrected_neighbor_mass" is never written by cluster_finder.
    massive_sample = ra_dec[ra_dec.mass > 11.2].copy()
    massive_sample["neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["local_neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["ultra_local_neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["corrected_neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["neighbors"] = np.zeros(len(massive_sample))
    massive_sample["local_neighbors"] = np.zeros(len(massive_sample))
    massive_sample["ultra_local_neighbors"] = np.zeros(len(massive_sample))
    massive_sample["neighbor_gids"] = np.empty((len(massive_sample)), dtype = "object")
    massive_sample["z_average_no_wt"] = np.zeros(len(massive_sample))
    massive_sample["z_average_prob"] = np.zeros(len(massive_sample))
    massive_sample["z_average_mass_prob"] = np.zeros(len(massive_sample))
    massive_sample["z_std_no_wt"] = np.zeros(len(massive_sample))
    massive_sample["z_std_prob"] = np.zeros(len(massive_sample))
    massive_sample["z_std_mass_prob"] = np.zeros(len(massive_sample))
    massive_sample.reset_index(inplace=True, drop = True)
    
    # The dask.delayed decorator on cluster_finder is commented out in this copy,
    # so this is an eager call. The member table is discarded here.
    delayed_result, _ = cluster_finder(massive_sample, indexable, maxRA, maxDEC, minRA, minDEC)
    print(len(delayed_result))
    # Disabled batching/writing code, kept as a no-op string literal. It would
    # compute the buffered results and write test_table_c{index} and
    # test_table_m{index} FITS files when index % 3 == 0 and after the last sweep.
    """delayed_results.append(delayed_result)
    
    if (index)%3 == 0:
        results = dask.compute(*delayed_results)
        if len(results)>1:
            cluster_centrals = pd.concat([select[0] for select in results])
            cluster_members = pd.concat([select[1] for select in results])
        else:
            cluster_centrals = results[0][0]
            cluster_members = results[0][1]
        print(len(cluster_centrals), len(cluster_members))
        cc = Table.from_pandas(cluster_centrals)
        cm = Table.from_pandas(cluster_members)
        cc.write(f'test_table_c{index}.fits', format = 'fits')
        cm.write(f'test_table_m{index}.fits', format = 'fits')
        delayed_results = []
    
    if index+1 == len(testing_centers):
        results = dask.compute(*delayed_results)
        if len(results)>1:
            cluster_centrals = pd.concat([select[0] for select in results])
            cluster_members = pd.concat([select[1] for select in results])
        else:
            cluster_centrals = results[0][0]
            cluster_members = results[0][1]
        print(len(cluster_centrals), len(cluster_members))
        cc = Table.from_pandas(cluster_centrals)
        cm = Table.from_pandas(cluster_members)
        cc.write(f'test_table_c{index}.fits', format = 'fits')
        cm.write(f'test_table_m{index}.fits', format = 'fits')
    
    pbar.progress = index + 1"""

5522115
     Unnamed: 0                                          patch    mean_RA  \
22           22  /data/mjb299/sweep/sweep-000m005-010p000.fits   5.027669   
362         362  /data/mjb299/sweep/sweep-000m010-010m005.fits   4.998628   
126         126  /data/mjb299/sweep/sweep-000p000-010p005.fits   4.974314   
32           32  /data/mjb299/sweep/sweep-000m015-010m010.fits   5.126173   
379         379  /data/mjb299/sweep/sweep-000p005-010p010.fits   5.039875   
314         314  /data/mjb299/sweep/sweep-010m005-020p000.fits  14.970043   
382         382  /data/mjb299/sweep/sweep-010m010-020m005.fits  15.011904   
238         238  /data/mjb299/sweep/sweep-010p000-020p005.fits  14.979433   
100         100  /data/mjb299/sweep/sweep-010p005-020p010.fits  15.009611   

      mean_DEC                                         photoz  \
22   -2.508936  /data/mjb299/pz/sweep-000m005-010p000-pz.fits   
362  -7.250919  /data/mjb299/pz/sweep-000m010-010m005-pz.fits   
126   2.421131  /data/mjb2

KeyboardInterrupt: 



In [None]:
# @@ Cell 1
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np
from scipy import spatial
from astropy.table import Table
from astropy.cosmology import LambdaCDM as Cos
from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy import units as u
import matplotlib.pyplot as plt
from IPython import display 
from scipy import stats
from scipy.interpolate import interp1d
from scipy.stats import norm
import fitsio

import pickle
import dask

# @@ Cell 2
# Start a local Dask cluster and attach a client to it. While this client is
# alive, the dask.compute() calls further down are executed on its workers.
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=9)
client = Client(cluster)
client  # last expression: rich repr with scheduler address and dashboard link

# @@ Cell 3
# Load the per-sweep centre table; its first 437 rows are the southern sweeps,
# the remainder the northern ones.
table_of_centers = pd.read_csv("local_table_of_centers.csv")
N_SOUTH_SWEEPS = 437
table_of_centers_south = table_of_centers.iloc[:N_SOUTH_SWEEPS]
table_of_centers_north = table_of_centers.iloc[N_SOUTH_SWEEPS:]

# NOTE(review): not read by the cells shown; cluster_finder uses its own local `clusters`.
clusters = []

# Subset of sweeps used for test runs of the pipeline.
testing_centers = pd.read_csv("testing_sweeps2.csv")
# testing_centers = testing_centers[0:1]

# @@ Cell 4
testing_centers.shape[0]  # number of test sweeps (rows)

# @@ Cell 5
#np.seterr(divide = 'ignore', invalid = "ignore") 
#np.seterr(divide = 'warn', invalid = 'warn')

# @@ Cell 6
# Ball tree over the southern sweep centres. kneighbors() returns the 9 closest
# centres, using the default Euclidean metric on (RA, Dec) in degrees.
sweep_center_coords = table_of_centers_south[["mean_RA", "mean_DEC"]]
nbrs = NearestNeighbors(n_neighbors=9, algorithm="ball_tree").fit(sweep_center_coords)

# @@ Cell 7
#Mass fitting parameters and equations
# The functions below read these descriptive constants rather than the one-letter
# globals, so a later notebook cell that happens to rebind `a`, `b`, `j`, `k` or
# `l` (e.g. a loop variable) can no longer silently change the mass cuts.
MASS_LIMIT_SLOPE = 1.3620186928378857
MASS_LIMIT_INTERCEPT = 9.968545069745126
MASS_LIMIT_CAP = 11.2
MASS_CORR_Z2 = 1.04935943
MASS_CORR_Z1 = 0.39573094
MASS_CORR_Z0 = 0.28347756

# Short aliases kept for backward compatibility with any code that reads them.
a = MASS_LIMIT_SLOPE
b = MASS_LIMIT_INTERCEPT
j = MASS_CORR_Z2
k = MASS_CORR_Z1
l = MASS_CORR_Z0


def mass_limit(z):
    """Return the log10 mass threshold at redshift ``z``.

    Linear in ``z`` (slope MASS_LIMIT_SLOPE, intercept MASS_LIMIT_INTERCEPT) and
    capped at MASS_LIMIT_CAP (11.2). Works on scalars and NumPy arrays.
    """
    return np.minimum(MASS_LIMIT_SLOPE * z + MASS_LIMIT_INTERCEPT, MASS_LIMIT_CAP)


def mass_coefficient(z):
    """Return the multiplicative mass-correction factor exp(j*z**2 + k*z + l).

    Works on scalars and NumPy arrays.
    """
    return np.exp(MASS_CORR_Z2 * z**2 + MASS_CORR_Z1 * z + MASS_CORR_Z0)

# @@ Cell 8
# Radii (in units of the 1-Mpc aperture). NOTE(review): cluster_finder below
# hard-codes its 0.5 and 0.1 factors and does not read these names.
radius, small_radius, mini_radius = 1, 0.5, 0.1



# @@ Cell 10
@dask.delayed
def cluster_finder(massive_sample, indexable, maxRA, maxDEC, minRA, minDEC):
    """Find galaxy-cluster candidates around massive galaxies and assign members.

    Wrapped in ``dask.delayed``: calling it returns a Delayed that evaluates to
    ``(clusters_final, memberspd)``.

    Parameters
    ----------
    massive_sample : pandas.DataFrame
        Candidate central galaxies with a 0..n-1 index. Read columns include x, y,
        z_phot_median, z_phot_std, mass, gid, RA, DEC, RELEASE, BRICKID, OBJID,
        MASKBITS. Statistics are written into an internal copy.
    indexable : numpy.ndarray
        Object array of all galaxies in the padded region, columns by position:
        0 z_phot_median, 1 x, 2 y, 3 mass (log10), 4 gauss_z (array of photo-z
        draws), 5 gid, 6 z_phot_std. x/y are assumed to be in radians -- confirm
        against the caller.
    maxRA, maxDEC, minRA, minDEC : float
        Sweep extent in degrees; only cluster centres strictly inside it are
        returned (presumably so padded overlaps are not reported twice).

    Returns
    -------
    clusters_final : pandas.DataFrame
        One row per cluster centre inside the sweep bounds.
    memberspd : pandas.DataFrame
        Columns galaxy, cluster, galaxy_z, cluster_z, galaxy_z_std, prob.
    """
    # Copy so the input frame is never mutated by the per-galaxy writes below.
    sample = massive_sample.copy()

    z_all = indexable[:, 0].astype(float)
    z_max = z_all.max()

    #Cosmological Parameters: radius and cylinder length
    cosmo = Cos(H0 = 70, Om0 = .286, Ode0 = .714)
    # (1 + z) / D_C = 1 / D_A: the angle in radians subtended by 1 proper Mpc.
    z_grid = np.linspace(1e-2, z_max, 500)
    mpc_angle = (1 + z_grid) / cosmo.comoving_distance(z_grid).value
    radius_threshold = interp1d(z_grid, mpc_angle, kind = "linear", fill_value = "extrapolate")

    # Cylinder half-length: median photo-z std per redshift bin, capped at 0.1,
    # interpolated at the *bin centres* (the old 99-point grid did not line up
    # with the bins). Empty bins give NaN medians; dropping them keeps NaN out of
    # the threshold, which would otherwise zero every weight.
    binned_std = stats.binned_statistic(z_all, indexable[:, -1].astype(float), "median",
                                        bins = np.linspace(0.05, z_max, 100))
    edges = binned_std.bin_edges
    bin_centres = 0.5 * (edges[:-1] + edges[1:])
    capped_std = np.minimum(binned_std.statistic, 0.1)
    filled = np.isfinite(capped_std)
    z_threshold = interp1d(bin_centres[filled], capped_std[filled], kind = "linear", fill_value = "extrapolate")

    z_stat_columns = ("z_average_no_wt", "z_average_prob", "z_average_mass_prob",
                      "z_std_no_wt", "z_std_prob", "z_std_mass_prob")

    #Tree Algorithm: per-central neighbour counts, masses and redshift stats
    tree = spatial.cKDTree(indexable[:, 1:3].astype(float), copy_data = True)
    for i, row in massive_sample.iterrows():
        r_mpc = float(radius_threshold(row.z_phot_median))
        neighbors = tree.query_ball_point([row.x, row.y], r_mpc)
        if len(neighbors) == 0:
            continue
        local_data = indexable[neighbors]

        # Weight = fraction of each galaxy's photo-z draws inside +-2 z_c of the
        # central's redshift. Averaging over the draw axis uses the actual number
        # of draws instead of a global `oversample` (fragile under dask pickling).
        z_c = z_threshold(row.z_phot_median)
        draws = np.vstack(local_data[:, 4]).astype(float)
        weights = (np.abs(draws - row.z_phot_median) < 2*z_c).mean(axis = 1)

        keep = weights > 0
        cluster = local_data[keep]
        w = weights[keep]
        z = cluster[:, 0].astype(float)
        mass = cluster[:, 3].astype(float)
        dist = np.sqrt((cluster[:, 1].astype(float) - row.x)**2 + (cluster[:, 2].astype(float) - row.y)**2)
        small = dist < 0.5*r_mpc
        mini = dist < 0.1*r_mpc

        if len(cluster):
            sample.at[i, "z_average_no_wt"] = np.mean(z)
            sample.at[i, "z_average_prob"] = np.average(z, weights = w)
            sample.at[i, "z_average_mass_prob"] = np.average(z, weights = w*mass)
            sample.at[i, "z_std_no_wt"] = np.std(z)
            sample.at[i, "z_std_prob"] = float(np.sqrt(np.cov(z, aweights = w)))
            sample.at[i, "z_std_mass_prob"] = float(np.sqrt(np.cov(z, aweights = w*mass)))
        else:
            # No galaxy (not even the central) falls in the cylinder: np.average
            # would raise ZeroDivisionError, so record the stats as undefined.
            for col in z_stat_columns:
                sample.at[i, col] = np.nan

        sample.at[i, "neighbors"] = w.sum()
        sample.at[i, "local_neighbors"] = w[small].sum()
        sample.at[i, "ultra_local_neighbors"] = w[mini].sum()

        # NOTE(review): if the central is present in `indexable` with weight > 0,
        # its mass enters both the weighted sum and `central` -- confirm intended.
        mass_co = mass_coefficient(row.z_phot_median)
        sample.at[i, "correction_factor"] = mass_co
        central = 10**row.mass
        weighted_mass = (10**mass)*w
        above_limit = mass > mass_limit(row.z_phot_median)
        local_total = weighted_mass[small].sum() + central
        ultra_total = weighted_mass[mini].sum() + central
        sample.at[i, "neighbor_mass"] = np.log10((weighted_mass[above_limit].sum() + central)*mass_co)
        sample.at[i, "local_neighbor_mass"] = np.log10(local_total)
        sample.at[i, "ultra_local_neighbor_mass"] = np.log10(ultra_total)
        sample.at[i, "corr_local_neighbor_mass"] = np.log10(local_total*mass_co)
        sample.at[i, "corr_ultra_local_neighbor_mass"] = np.log10(ultra_total*mass_co)

        # Candidate membership rows: (galaxy gid, central gid, galaxy z,
        # central z, galaxy z std) for every galaxy inside the aperture.
        n_local = len(local_data)
        sample.at[i, "neighbor_gids"] = np.column_stack((
            local_data[:, -2].astype(float),
            np.full(n_local, row.gid, dtype = float),
            local_data[:, 0].astype(float),
            np.full(n_local, row.z_phot_median, dtype = float),
            local_data[:, -1].astype(float),
        ))

    #Thresholding: keep centrals whose counts exceed mean + c*sqrt(mean) of
    #their redshift slice (window +-0.025 around 0.01-spaced centres)
    z_centres = np.arange(0.05, sample.z_phot_median.max(), 0.01)
    threshold1 = np.empty(len(z_centres))
    threshold2 = np.empty(len(z_centres))
    for n, zc in enumerate(z_centres):
        window = sample[np.logical_and(sample.z_phot_median >= zc - .025, sample.z_phot_median <= zc + .025)]
        mean_all = np.mean(window.neighbors)
        mean_local = np.mean(window.local_neighbors)
        threshold1[n] = mean_all + 1.8*np.sqrt(mean_all)
        threshold2[n] = mean_local + 1.2*np.sqrt(mean_local)
    thresh1 = interp1d(z_centres, threshold1, kind = "linear", fill_value = "extrapolate")
    thresh2 = interp1d(z_centres, threshold2, kind = "linear", fill_value = "extrapolate")
    passes = np.logical_and(sample.neighbors >= thresh1(sample.z_phot_median),
                            sample.local_neighbors >= thresh2(sample.z_phot_median))
    clusters = sample[passes].copy()
    clusters = clusters.sort_values("local_neighbor_mass", ascending = False).reset_index(drop = True)

    #Aggregation: walking from the highest local_neighbor_mass down, each
    #unassigned candidate seeds a group and absorbs unassigned candidates within
    #1.5x its aperture and 2 z_c in redshift; their memberships are appended.
    clusters["ncluster"] = np.zeros(len(clusters))
    if len(clusters):
        tree = spatial.cKDTree(clusters[["x", "y"]], copy_data = True)
        clusternum = 1
        for i, row in clusters.copy().iterrows():
            if clusters.at[i, "ncluster"] != 0:
                continue
            clusters.at[i, "ncluster"] = clusternum
            z_window = 2*z_threshold(row.z_phot_median)
            for index in tree.query_ball_point([row.x, row.y], 1.5*radius_threshold(row.z_phot_median)):
                same_slice = np.abs(clusters.at[index, "z_phot_median"] - row.z_phot_median) < z_window
                if clusters.at[index, "ncluster"] == 0 and same_slice:
                    clusters.at[index, "ncluster"] = clusternum
                    clusters.at[i, "neighbor_gids"] = np.concatenate(
                        (clusters.at[i, "neighbor_gids"], clusters.at[index, "neighbor_gids"]), axis = 0)
            clusternum += 1

    #Results: one centre per group (highest ultra_local_neighbor_mass), inside the sweep
    cluster_center = clusters.sort_values(by = ['ncluster', 'ultra_local_neighbor_mass'],
                                          ascending = [True, False]).groupby('ncluster').head(1).copy()
    inside = np.logical_and.reduce((cluster_center.RA < maxRA, cluster_center.RA > minRA,
                                    cluster_center.DEC < maxDEC, cluster_center.DEC > minDEC))
    cluster_center_selected = cluster_center[inside].copy()

    #Membership: one concatenation instead of growing an array in a loop
    blocks = list(cluster_center_selected.neighbor_gids.values)
    membership_data = np.concatenate(blocks, axis = 0) if blocks else np.empty((0, 5))
    membershippd = pd.DataFrame(membership_data, columns = ["galaxy", "cluster", "galaxy_z", "cluster_z", "galaxy_z_std"], dtype = float)
    membershippd["z_dist"] = np.abs(membershippd.galaxy_z - membershippd.cluster_z)
    # A galaxy claimed by several clusters goes to the one closest in redshift.
    membershippd = membershippd.sort_values("z_dist", ascending = True).drop_duplicates(subset = "galaxy").reset_index(drop = True)
    # Two-sided Gaussian tail probability of the redshift offset; 0.0027 is the
    # two-sided 3-sigma level.
    membershippd["prob"] = 2*norm.sf(membershippd.z_dist, loc = 0, scale = membershippd.galaxy_z_std)
    memberspd = membershippd[membershippd.prob > 0.0027].astype({"galaxy": "int64", "cluster": "int64"}).drop(columns = ["z_dist"])

    #Cleaning things up: select and rename the output columns
    output_columns = {
        "RA": "RA_central",
        "DEC": "DEC_central",
        "z_phot_median": "z_median_central",
        "z_phot_std": "z_std_central",
        "mass": "mass_central",
        "RELEASE": "RELEASE",
        "BRICKID": "BRICKID",
        "OBJID": "OBJID",
        "MASKBITS": "MASKBITS",
        "gid": "gid",
        "neighbor_mass": "neighbor_mass",
        "corr_local_neighbor_mass": "local_neighbor_mass",
        "corr_ultra_local_neighbor_mass": "ultra_local_neighbor_mass",
        "correction_factor": "correction_factor",
        "neighbors": "neighbors",
        "local_neighbors": "local_neighbors",
        "ultra_local_neighbors": "ultra_local_neighbors",
    }
    clusters_final = cluster_center_selected[list(output_columns)].rename(columns = output_columns)

    return clusters_final, memberspd

# @@ Cell 11
# Per-sweep results buffer (presumably filled and flushed by the loop below).
delayed_results = []

# Progress bar sized for every test sweep.
pbar = display.ProgressBar(len(testing_centers))
pbar.display()

#Remember to change testing_centers to table_of_centers
for index, row in testing_centers[-1:].iterrows():
    fits_data = fitsio.FITS(row.patch)
    sweep = fits_data[1].read(columns=['RA', 'DEC'])
    
    maxx = max(sweep['RA'])
    maxy = max(sweep['DEC'])
    minx = min(sweep['RA'])
    miny = min(sweep['DEC'])
    
    maxRA = max(sweep['RA'])
    maxDEC = max(sweep['DEC'])
    minRA = min(sweep['RA'])
    minDEC = min(sweep['DEC'])
    
    list_of_imports = []
    buffer = 0.285
    #Neighbors:
    distances, indices = nbrs.kneighbors([row[["mean_RA", "mean_DEC"]]])
    patches = table_of_centers_south.iloc[indices[0]]
    for index2, row2 in patches.iterrows():
        delayed_import = data_import(index2, row2)
        list_of_imports.append(delayed_import)
    
    imports = dask.compute(*list_of_imports)
    ra_dec = pd.concat(imports)
    ra_dec = ra_dec[np.logical_and.reduce((ra_dec.RA < maxRA, ra_dec.RA > minRA, ra_dec.DEC < maxDEC, ra_dec.DEC > minDEC))].copy()
    #Initial sample cuts
    zmag=np.array(22.5-2.5*np.log10(ra_dec.FLUX_Z))
    zmag[np.where(~np.isfinite(zmag))]=99.
    #whgood=np.where(np.logical_and(zmag < 21,ra_dec.mass > 0 ))
    isgood=np.logical_and(zmag < 21,ra_dec.mass > 0 )
    ra_dec = ra_dec[isgood]
    
    #Further sample cuts
    ra_dec = ra_dec[np.logical_or(ra_dec.MASKBITS == 0, ra_dec.MASKBITS == 4096)]
    ra_dec = ra_dec[np.logical_or(np.logical_or(ra_dec.gaia_phot_g_mean_mag > 19, ra_dec.gaia_astrometric_excess_noise > 10**.5), ra_dec.gaia_astrometric_excess_noise==0)]
    ra_dec["magR"] = 22.5-2.5*np.log10(ra_dec.FLUX_R)
    ra_dec["magZ"] = 22.5-2.5*np.log10(ra_dec.FLUX_Z)
    ra_dec["magW1"] = 22.5-2.5*np.log10(ra_dec.FLUX_W1)
    l_mask = (ra_dec.magR - ra_dec.magW1) > 1.8*(ra_dec.magR-ra_dec.magZ)-0.6
    l_mask[~np.isfinite(l_mask)] = False
    ra_dec = ra_dec[np.logical_and(22.5 - 2.5*np.log10(ra_dec.FLUX_Z)<21, ra_dec.z_phot_median>0.01)]
    
    #Coordinates
    ra_dec["RA_r"] = (np.pi/180)*ra_dec["RA"]
    ra_dec["DEC_r"] = (np.pi/180)*ra_dec["DEC"]
    ra_dec["gid"] = np.round(ra_dec.RA, 6)*10**16 + np.round(ra_dec.DEC + 90, 6)*10**6
    print(len(ra_dec))
    
    #Oversampling
    ra_dec.reset_index(inplace = True, drop = True)
    oversample = 30
    over = np.array([ra_dec.z_phot_median.values]).T*np.ones((len(ra_dec), oversample))
    sigma = np.array([ra_dec.z_phot_std.values]).T*np.ones((len(ra_dec), oversample))
    random = np.random.normal(loc = 0, scale = 1, size = (len(ra_dec), oversample))
    gauss = over + sigma*random
    ra_dec["gauss_z"] = pd.Series(list(gauss))
    
    #Coordinate transform to prevent zeros
    ra_dec["y"] = ra_dec["DEC_r"] - np.mean(ra_dec["DEC_r"]) + 50
    ra_dec["x"] = (ra_dec["RA_r"] - np.mean(ra_dec["RA_r"]))*np.cos(ra_dec["DEC_r"]) + 50
    #ra_dec = ra_dec[ra_dec.z_phot_median < 1.5].copy()
    #Creating array for indexing
    indexable = ra_dec[["z_phot_median", "x", "y", "mass", "gauss_z", "gid", "z_phot_std"]].values.copy()
    #dindexable = dask.delayed([indexable])
    
    #Creating massive sample
    massive_sample = ra_dec[ra_dec.mass > 11.2].copy()
    massive_sample["neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["local_neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["ultra_local_neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["corrected_neighbor_mass"] = np.zeros(len(massive_sample))
    massive_sample["neighbors"] = np.zeros(len(massive_sample))
    massive_sample["local_neighbors"] = np.zeros(len(massive_sample))
    massive_sample["ultra_local_neighbors"] = np.zeros(len(massive_sample))
    massive_sample["neighbor_gids"] = np.empty((len(massive_sample)), dtype = "object")
    massive_sample["z_average_no_wt"] = np.zeros(len(massive_sample))
    massive_sample["z_average_prob"] = np.zeros(len(massive_sample))
    massive_sample["z_average_mass_prob"] = np.zeros(len(massive_sample))
    massive_sample["z_std_no_wt"] = np.zeros(len(massive_sample))
    massive_sample["z_std_prob"] = np.zeros(len(massive_sample))
    massive_sample["z_std_mass_prob"] = np.zeros(len(massive_sample))
    massive_sample.reset_index(inplace=True, drop = True)
    #dmassive_sample = dask.delayed([massive_sample])
    
    cluster_centrals, cluster_members = cluster_finder(massive_sample, indexable, maxRA, maxDEC, minRA, minDEC)
    """delayed_results.append(delayed_result)
    
    if (index)%3 == 0:
        results = dask.compute(*delayed_results)
        if len(results)>1:
            cluster_centrals = pd.concat([select[0] for select in results])
            cluster_members = pd.concat([select[1] for select in results])
        else:
            cluster_centrals = results[0][0]
            cluster_members = results[0][1]
        print(len(cluster_centrals), len(cluster_members))
        cc = Table.from_pandas(cluster_centrals)
        cm = Table.from_pandas(cluster_members)
        cc.write(f'test_table_c{index}.fits', format = 'fits')
        cm.write(f'test_table_m{index}.fits', format = 'fits')
        delayed_results = []
    
    if index+1 == len(testing_centers):
        results = dask.compute(*delayed_results)
        if len(results)>1:
            cluster_centrals = pd.concat([select[0] for select in results])
            cluster_members = pd.concat([select[1] for select in results])
        else:
            cluster_centrals = results[0][0]
            cluster_members = results[0][1]
        print(len(cluster_centrals), len(cluster_members))
        cc = Table.from_pandas(cluster_centrals)
        cm = Table.from_pandas(cluster_members)
        cc.write(f'test_table_c{index}.fits', format = 'fits')
        cm.write(f'test_table_m{index}.fits', format = 'fits')"""
    
    pbar.progress = index + 1



