In [None]:
import os
import numpy as np
import pandas as pd
import dask
import dask.dataframe as dd
from dask import delayed

import itertools as it
from functools import reduce
import treecorr

import matplotlib
import matplotlib.pyplot as plt

In [None]:
from dask.distributed import Client, LocalCluster

# Alternative kept for reference: start a local cluster instead of attaching
# to an already-running scheduler.
#cluster = LocalCluster(n_workers=8, 
#                       threads_per_worker=1,
#                       memory_limit='6Gb')
#client = Client(cluster)

# Override the dashboard link template through dask.config.set rather than by
# mutating the nested config dict directly: the dotted key creates any missing
# intermediate level instead of raising KeyError. The value is a template (not an
# f-string); the {…} placeholders are left for dask to fill in.
dask.config.set({"distributed.dashboard.link": "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status"})

# Attach to the scheduler described by this scheduler file.
client = Client(scheduler_file='/global/cscratch1/sd/cwalter/scheduler.json')

# Last expression of the cell: display the client summary.
client

In [None]:
from scipy.special import comb

scratch= os.environ["SCRATCH"]

file_path = '/global/cscratch1/sd/cwalter/parquet-with-healpixels/'
# Columns read from the parquet files, and the shorter names used afterwards.
selected = ['galaxy_id', 'mag_i', 'redshift_true', 'ra', 'dec', 'shear_1', 'shear_2']
rename_map = {'galaxy_id':'id', 'redshift_true':'z', 'shear_1':'g1', 'shear_2':'g2'}

# Subsampling parameters. The seed is fixed so the same rows are drawn on every
# run of the notebook; without it the subsample is not reproducible.
SAMPLE_FRACTION = .1
SAMPLE_SEED = 42

df = dd.read_parquet(file_path+'skysim-*.parquet', columns=selected)
df = df.rename(columns=rename_map)
#df = df.sample(frac=.000001)
df = df.sample(frac=SAMPLE_FRACTION, random_state=SAMPLE_SEED)

#df = df.persist()

# Coerce to a plain Python int so the ',d' format spec and the exact binomial
# below do not depend on which integer type the computation returns.
number_in_df = int(df.index.size.compute())
# The row count is printed divided by 1e9, so the label states the unit.
print('Columns:', df.columns.values, '#Rows (billions):', number_in_df/1e9)
print(f'There are {number_in_df:,d} elements in the area with {comb(number_in_df, 2, exact=True):,d} total combinations')

In [None]:
# Distinct index values of the dataframe; later cells pass pairs of these to
# pix2ang as pixel indices. The lazy dask result is materialised with an
# explicit compute() instead of relying on list() iterating the dask
# collection, which triggers the computation implicitly.
# NOTE(review): later cells assume the order of this list matches the partition
# order of df — confirm.
pixel_list = list(df.index.unique().compute())

In [None]:
from healpy.pixelfunc import pix2ang
from healpy.rotator import angdist

# HEALPix resolution parameter handed to pix2ang inside angular_distance.
# NOTE(review): presumably this must match the resolution used to build the
# dataframe's pixel index — confirm against the parquet files.
NSIDE = 32

def angular_distance(pairs):
    """Return the angular separation, in arcminutes, of each pixel pair.

    `pairs` is indexed as a 2-D array: column 0 holds the first pixel index of
    every pair and column 1 the second. Each column is converted to angles with
    `pix2ang` at resolution `NSIDE`, and the two sets of angles are compared
    with `angdist`.
    """
    first_angles, second_angles = (pix2ang(NSIDE, pairs[:, column]) for column in (0, 1))

    # Scale the angdist result by 180/pi and then by 60 to express it in arcmin.
    return angdist(first_angles, second_angles)*180/np.pi*60

@delayed
def size_test(cat1, cat2):
    """Placeholder task: takes two catalogs, ignores both, and returns 1."""
    return 1


def make_catalog(dataframe):
    """Build a treecorr.Catalog from the ra, dec, g1 and g2 columns of `dataframe`.

    Positions are declared to be in degrees and g2 is passed through unflipped.
    """
    columns = {name: getattr(dataframe, name) for name in ('ra', 'dec', 'g1', 'g2')}
    return treecorr.Catalog(flip_g2=False, ra_units='deg', dec_units='deg', **columns)


In [None]:
# Apply make_catalog to every partition of df and persist the results.
# meta=(None, 'O') declares the output as an unnamed, object-dtype series.
catalogs = df.map_partitions(make_catalog, meta=(None, 'O')).persist()

In [None]:
# One delayed handle per partition of `catalogs`.
delayed_list = catalogs.to_delayed()
# Map each pixel id to its position in pixel_list; that position is later used
# as an index into delayed_list.
# NOTE(review): this relies on pixel_list order matching partition order — confirm.
partition_map = dict(zip(pixel_list, it.count()))

In [None]:
# Every unordered pair of pixels, including each pixel paired with itself.
pairs = np.array(list(it.combinations_with_replacement(pixel_list, 2)))
# Keep only the pairs whose angular_distance is below 600 (arcmin).
within_range = angular_distance(pairs) < 600
selected_pairs = pairs[within_range]

In [None]:
# One delayed size_test task per selected pixel pair, looking up each pixel's
# delayed catalog through partition_map.
a = [
    size_test(delayed_list[partition_map[first]], delayed_list[partition_map[second]])
    for first, second in selected_pairs
]
#a = [size_test(catalogs.get_partition(partition_map[i[0]]), catalogs.get_partition(partition_map[i[1]])) for i in selected_pairs]

In [None]:
%%time
# Run every delayed task in `a` in a single dask.compute call and collect the results.
gg_list = dask.compute(*a)

In [None]:
#catalogs.get_partition(0).compute()