## Comparing DR1 and GalaxyZoo datasets

In [None]:
import pandas as pd
import numpy as np
import astropy.units as u
from astropy.coordinates import SkyCoord

import os
import math

# Destination CSV for the cross-matched catalogue produced by this notebook.
output_path = "test.csv"
# Uncomment to delete a CSV left over from a previous run before writing
# (os.remove raises FileNotFoundError if the file does not exist).
#os.remove("test.csv")

In [None]:
# Input locations. `desi_path` (directory of DR8 JPEG cutouts) is not read in
# this cell.
desi_path = '/share/nas2/walml/galaxy_zoo/decals/dr8/jpg'
desi_cat_path = "../Data/gz_desi_deep_learning_catalog_friendly.parquet"
gz1_cat = "../Data/GalaxyZoo1_DR_table2.csv"
# Kept so the original name still refers to the DESI catalogue path. The
# SkyCoord cell rebinds `desi_cat` to a SkyCoord, so the read below uses
# `desi_cat_path`; this lets this cell be re-run after the matching cells.
desi_cat = desi_cat_path

# Set to True to work on a random 10,000-row sample of each catalogue
# (fixed random_state values keep the sample reproducible).
SUBSET = False

# Read both catalogues. The index is reset to 0..N-1 because the matching step
# uses row positions into these frames.
desi_data = pd.read_parquet(desi_cat_path).reset_index(drop=True)
gz1_data = pd.read_csv(gz1_cat).reset_index(drop=True)
print(desi_data.columns)

print(f"Number of galaxies in DESI catalogue: {len(desi_data)}")
print(f"Number of galaxies in GZ1 catalogue: {len(gz1_data)}")

if SUBSET:
    # Re-reset the index after sampling so positions and labels agree again.
    gz1_data = (gz1_data.sample(10000,random_state=1)).reset_index(drop=True)
    desi_data = (desi_data.sample(10000,random_state=3)).reset_index(drop=True)

In [None]:
#convert the data into skycoord objects
ra1 = gz1_data['RA'].to_numpy()
dec1 = gz1_data['DEC'].to_numpy()
zoo_cat = SkyCoord(ra=ra1, dec=dec1, unit=(u.hourangle, u.deg))

ra2 = desi_data['ra'].to_numpy()
dec2 = desi_data['dec'].to_numpy()
desi_cat = SkyCoord(ra=ra2, dec=dec2, unit=u.deg)

#print(zoo_cat)
#print(desi_cat)

In [None]:
# For every GZ1 galaxy find its nearest DESI galaxy on the sky. `idx` holds
# row positions into `desi_cat` (and so into `desi_data`); `d2d` is the
# on-sky separation of each pair.
idx, d2d, d3d = zoo_cat.match_to_catalog_sky(desi_cat)
# Only pairs closer than this are treated as the same galaxy.
max_sep = 10 * u.arcsec
sep_constraint = d2d < max_sep
print(f"{sep_constraint.sum()} matches found")

zoo_match = gz1_data[sep_constraint]  # GZ1 rows that have a DESI match
# DESI rows aligned row-for-row with `zoo_match`. `idx` is positional, so use
# `.iloc` rather than label-based `.loc`. The same DESI row can appear more
# than once if it is the nearest neighbour of several GZ1 galaxies.
desi_match = desi_data.iloc[idx[sep_constraint]]

In [None]:
# Attach each match's DESI `dr8_id` to its GZ1 row. Both frames are keyed and
# sorted by the matched DESI row label so pd.concat(axis=1) aligns them; a
# label repeated in `idx` appears equally often in both frames.
zoo_match_sort = zoo_match.set_index(idx[sep_constraint]).sort_index()
desi_match_sort = desi_match.sort_index()
big_cat = pd.concat(
    [zoo_match_sort, desi_match_sort['dr8_id']],
    axis=1,
)
big_cat = big_cat.reset_index(drop=True)
big_cat.head(5)

## Writing the cross-matched catalogue to a CSV file

In [None]:
def split_dataframe(data, no_of_batches):
    """Split ``data`` into at most ``no_of_batches`` consecutive row batches.

    Every batch except possibly the last holds
    ``ceil(len(data) / no_of_batches)`` rows. Together the batches cover every
    row exactly once, in the original order.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame to split. Rows are selected by position, not by index label.
    no_of_batches : int
        Requested number of batches; must be at least 1.

    Returns
    -------
    list of pandas.DataFrame
        The batches. Fewer than ``no_of_batches`` are returned when there are
        too few rows to fill them; an empty frame yields an empty list.

    Raises
    ------
    ValueError
        If ``no_of_batches`` is less than 1.
    """
    if no_of_batches < 1:
        raise ValueError(f"no_of_batches must be >= 1, got {no_of_batches}")
    n_rows = data.shape[0]
    if n_rows == 0:
        # The batch size would be 0 here, and range() rejects a zero step.
        return []
    batch_size = math.ceil(n_rows / no_of_batches)
    return [
        data.iloc[start:start + batch_size]
        for start in range(0, n_rows, batch_size)
    ]

In [None]:
batched_df = split_dataframe(big_cat, 10)

# Write the catalogue in batches. The first batch overwrites `output_path`
# and writes the header; later batches are appended without one. Re-running
# this cell therefore replaces the file instead of appending duplicate rows.
# The default index=True still writes the 0..N-1 index as the first column.
for batch_no, batch in enumerate(batched_df):
    is_first = batch_no == 0
    batch.to_csv(output_path, mode='w' if is_first else 'a', header=is_first)