In [None]:
%config InlineBackend.figure_format = 'retina'

import os
import numpy as np
import pandas as pd
import dask
import dask.dataframe as dd
from dask import delayed

import itertools as it
from functools import reduce
import treecorr

import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Use the same resolution (120 dpi) for figures shown inline and figures saved to disk.
matplotlib.rcParams.update({'savefig.dpi': 120, 'figure.dpi': 120})

In [None]:
from dask.distributed import Client, LocalCluster

#cluster = LocalCluster(n_workers=8, 
#                       threads_per_worker=1,
#                       memory_limit='6Gb')
#client = Client(cluster)

# Dashboard link template: the URL is built from the JupyterHub service prefix
# plus a proxy/<host>:<port>/status path.
dashboard_link = "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status"
dask.config.config["distributed"]["dashboard"]["link"] = dashboard_link

# Connect to the scheduler described by this scheduler file.
# NOTE(review): the path is user-specific; the scheduler must already be running.
scheduler_file = '/global/cscratch1/sd/cwalter/scheduler.json'
client = Client(scheduler_file=scheduler_file)

client

In [None]:
from scipy.special import comb

scratch = os.environ["SCRATCH"]

file_path = '/global/cscratch1/sd/cwalter/parquet-with-healpixels/'
selected = ['galaxy_id', 'mag_i', 'redshift_true', 'ra', 'dec', 'shear_1', 'shear_2']
rename_map = {'galaxy_id':'id', 'redshift_true':'z', 'shear_1':'g1', 'shear_2':'g2'}

# Fraction of rows kept, and the seed that makes the random subsample
# identical after a kernel restart (without a seed every run draws new rows).
SAMPLE_FRACTION = .1
SAMPLE_SEED = 42

# Read only the needed columns, shorten their names, then subsample.
df = dd.read_parquet(file_path+'skysim-*.parquet', columns=selected)
df = df.rename(columns=rename_map)
#df = df.sample(frac=.000001)
df = df.sample(frac=SAMPLE_FRACTION, random_state=SAMPLE_SEED)

# Alternative input kept for reference (files under $SCRATCH, already renamed columns):
#selected = ['id', 'mag_r', 'z', 'ra', 'dec', 'g1', 'g2']
#df = dd.read_parquet(f'{scratch}/parquet-skysim/*.parquet', columns=selected, engine='pyarrow')

#df = df.persist()

# Row count of the sampled frame; this triggers a computation.
number_in_df = int(df.index.size.compute())
# The count is printed in units of 1e9 rows, so say so in the label.
print('Columns:', df.columns.values, '#Rows (billions):', number_in_df/1e9)
print(f'There are {number_in_df:,d} elements in the area with {comb(number_in_df, 2, exact=True):,d} total combinations')

In [None]:
from healpy.pixelfunc import ang2pix
from healpy.pixelfunc import pix2ang
from healpy.rotator import angdist

# HEALPix resolution parameter passed to ang2pix / pix2ang by the helpers in this cell.
NSIDE = 32

def add_healpixels(dataframe):
    """Return the HEALPix pixel number of every row of ``dataframe``.

    The ``ra`` and ``dec`` columns are converted to NumPy arrays and passed
    to ``ang2pix`` as longitude / latitude (``lonlat=True``) at resolution
    ``NSIDE``.
    """
    longitudes = dataframe.ra.to_numpy()
    latitudes = dataframe.dec.to_numpy()
    return ang2pix(NSIDE, longitudes, latitudes, lonlat=True)

def angular_distance(pairs):
    """Angular separation, in arcminutes, between the two pixels of each pair.

    ``pairs`` is indexed as ``pairs[:, 0]`` and ``pairs[:, 1]``, so it must be
    a 2-D array whose first two columns hold pixel numbers at resolution
    ``NSIDE``.
    """
    first_direction = pix2ang(NSIDE, pairs[:, 0])
    second_direction = pix2ang(NSIDE, pairs[:, 1])

    # Scale the angdist result from radians to degrees, then to arcminutes.
    return angdist(first_direction, second_direction) * 180 / np.pi * 60

@delayed
def cross(dataframe1, dataframe2, pixel1, pixel2):
    """Accumulate the shear-shear correlation for one pair of pixels.

    Parameters
    ----------
    dataframe1, dataframe2 : DataFrame-like
        Tables exposing ``ra``, ``dec``, ``g1`` and ``g2`` columns; the
        coordinates are declared to treecorr as degrees.
    pixel1, pixel2 :
        Identifiers of the pixels the two tables belong to. Equal
        identifiers mean both arguments describe the same pixel.

    Returns
    -------
    treecorr.GGCorrelation
        The correlation object after a single ``process_auto`` or
        ``process_cross`` call. ``finalize`` is not called here.
    """
    def shear_catalog(dataframe):
        # Single place to build a catalog so both branches use identical settings.
        return treecorr.Catalog(ra=dataframe.ra, dec=dataframe.dec,
                                g1=dataframe.g1, g2=dataframe.g2,
                                flip_g2=False, ra_units='deg', dec_units='deg')

    gg = treecorr.GGCorrelation(min_sep=1., max_sep=200., nbins=20, num_threads=1, sep_units='arcmin')

    # Object identity alone only works while the scheduler hands the very same
    # object to both parameters. Also accept matching pixel ids, so a pixel is
    # never cross-correlated with a separately materialised copy of itself.
    same_pixel = dataframe1 is dataframe2 or (pixel1 is not None and pixel1 == pixel2)

    if same_pixel:
        gg.process_auto(shear_catalog(dataframe1))
    else:
        gg.process_cross(shear_catalog(dataframe1), shear_catalog(dataframe2))

    return gg

@delayed
def size_test(dataframe1, dataframe2):
    """Delayed no-op with the same data arguments as ``cross``; always returns ``1``.

    Neither argument is read.
    NOTE(review): presumably used to time the scheduling/transfer of the
    partitions without running treecorr -- confirm.
    """
    return 1

def calculateVariance(dataframe):
    """Return a one-row frame ``[varg * sumw, sumw]`` for ``dataframe``.

    A treecorr catalog is built from the ``ra``, ``dec``, ``g1`` and ``g2``
    columns. Column 0 of the result is the catalog's ``varg`` multiplied by
    its ``sumw``; column 1 is ``sumw`` itself.
    """
    catalog = treecorr.Catalog(
        ra=dataframe.ra,
        dec=dataframe.dec,
        g1=dataframe.g1,
        g2=dataframe.g2,
        flip_g2=False,
        ra_units='deg',
        dec_units='deg',
    )
    shear_variance = catalog.varg
    weighted_variance = shear_variance * catalog.sumw
    return pd.DataFrame([[weighted_variance, catalog.sumw]])

In [None]:
# Distinct index values of `df`, materialised as a Python list.
# NOTE(review): presumably the index holds HEALPix pixel ids (the input
# directory is named "parquet-with-healpixels") -- confirm.
pixel_list = list(df.index.unique())

In [None]:
# One delayed object per partition of `df`.
delayed_list = df.to_delayed()
# Map each pixel value to its position in `pixel_list`; that position is later
# used to index `delayed_list`.
# NOTE(review): assumes `pixel_list` is in partition order -- TODO confirm.
partition_map = dict(zip(pixel_list, range(len(pixel_list))))

In [None]:
# Sum the per-partition [varg*sumw, sumw] rows produced by calculateVariance,
# then divide column 0 by column 1 to get the weight-averaged shear variance.
elements = df.map_partitions(calculateVariance).sum().compute()
varg = elements[0]/elements[1]

In [None]:
# Every unordered pair of pixels, including each pixel paired with itself.
pairs = np.array(list(it.combinations_with_replacement(pixel_list, 2)))
# Keep only the pairs whose angular_distance is below 600 arcmin.
selected_pairs = pairs[angular_distance(pairs) < 600]

In [None]:
%%time
a = [cross( delayed_list[partition_map[i[0]]], delayed_list[partition_map[i[1]]], i[0], i[1] ) for i in selected_pairs]

In [None]:
#a = [size_test(delayed_list[partition_map[i[0]]], delayed_list[partition_map[i[1]]]) for i in selected_pairs]
#a = [size_test(delayed_list[partition_map[i[0]]], delayed_list[partition_map[i[1]]]) for i in selected_pairs]

In [None]:
%%time
# Only tasks 10000..14999 of `a` are computed here, not the full list.
# NOTE(review): the slice bounds are hard-coded; if `a` has 10000 entries or
# fewer the slice is empty and `gg_list` becomes an empty tuple -- confirm
# whether this batching is intentional.
gg_list = dask.compute(*a[10000:15000])

In [None]:
%%time
gg = reduce(treecorr.GGCorrelation.__iadd__, gg_list)

In [None]:
#del gg_list
# NOTE(review): `gg_list` comes from dask.compute, so it presumably holds
# finished results rather than futures; confirm this cancel call has any effect.
client.cancel(gg_list)

In [None]:
%%time
# Finalize the summed correlation, passing the same shear variance for both fields.
# NOTE(review): finalize acts on `gg` in place; presumably it should run only
# once per accumulated `gg` -- confirm before re-running this cell.
gg.finalize(varg, varg)

In [None]:
# Plot xi_+ and xi_- against the mean separation, with 1-sigma error bars
# taken from the square root of the variance estimates.
fig, ax = plt.subplots()
ax.errorbar(gg.meanr, gg.xip, yerr=np.sqrt(gg.varxip), marker='.', markersize=9, label=r'$\xi_{+}$', ls='none')
ax.errorbar(gg.meanr, gg.xim, yerr=np.sqrt(gg.varxim), marker='.', markersize=9, label=r'$\xi_{-}$', ls='none')

# Raw string: '\g' is not a valid escape sequence, so a plain literal triggers
# an invalid-escape warning. The rendered title text is unchanged.
ax.set_title(r'$\gamma \gamma$ Correlation')
ax.set_xscale('log')
#ax.set_yscale('log')
ax.legend()
ax.set_ylabel(r'$\xi$')
ax.set_xlabel(r'$\theta$ (arcmin)');

In [None]:
import ctypes

def trim_memory() -> int:
    """Call ``malloc_trim(0)`` from ``libc.so.6`` and return its result.

    The shared library is loaded by its fixed name ``libc.so.6``, so the call
    fails on systems that do not provide that library.
    """
    return ctypes.CDLL("libc.so.6").malloc_trim(0)

# Run trim_memory through the dask client; the call's return value is the cell output.
client.run(trim_memory)