# Data pipeline
This file contains all the data processing steps used to generate the training, validation and testing datasets


In [1]:
import pandas as pd
import numpy as np
import torch
import shutil
import os
import math
import time
from torch.utils.data import random_split
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.sdss import SDSS

## Options 
Select any part of the data pipeline to redo, otherwise this notebook will run without saving or overwriting any files.

In [2]:
REGENERATE_MATCHED_CATALOG = False #Match GZ1 and DESI objects
REGENERATE_QUERIED_CATALOG = False #Run SDSS query on matched catalog to return data and cut objects - VERY SLOW
REGENERATE_LOCAL_SUBSET_CATALOG = False #Create balanced subset of 1500 best S/Z/El galaxies
CREATE_LOCAL_SUBSET_COPY = False #Copy all galaxies in local subset to local folder
REGENERATE_BEST_SUBSET_CATALOG = False #Create balanced catalog of 15000 best S/Z/El galaxies
REGENERATE_CUT_CATALOG = False #Create catalog of matched galaxies after astro cuts
REGENERATE_TEST_TRAIN_CATALOG = False #Create testing and training datasets from cut catalog
REGENERATE_DOWNSAMPLED_CATALOG = False #Downsample training catalog to improve balance

## Utility functions

In [3]:
def get_metrics(catalog, cat_name):
    total = catalog.shape[0]
    largest_prob_class = catalog[['P_CW','P_ACW','P_OTHER']].idxmax(axis=1)
    CW_galaxies = np.count_nonzero(largest_prob_class=='P_CW')
    ACW_galaxies = np.count_nonzero(largest_prob_class=='P_ACW')
    OTHER_galaxies = np.count_nonzero(largest_prob_class=='P_OTHER')
    print(f"{cat_name} contains {total} galaxies. CW: {CW_galaxies} ({CW_galaxies/total:.1%}), ACW: {ACW_galaxies} ({ACW_galaxies/total:.1%}), Other: {OTHER_galaxies} ({OTHER_galaxies/total:.1%})")

def create_balanced_subset(catalog, threshold, N_CW,N_ACW,N_EL):
    very_CW_galaxies = catalog[catalog['P_CW']>threshold]
    very_ACW_galaxies = catalog[catalog['P_ACW']>threshold]
    very_EL_galaxies = catalog[catalog['P_EL']>threshold]
    print(f"Total Very CW: {very_CW_galaxies.shape[0]}, Very ACW: {very_ACW_galaxies.shape[0]}, Very EL: {very_EL_galaxies.shape[0]}")

    galaxy_subset = pd.concat([very_CW_galaxies[0:N_CW],very_ACW_galaxies[0:N_ACW],very_EL_galaxies[0:N_EL]])
    galaxy_subset = galaxy_subset.reset_index()
    print(f"Number of galaxies in best subset catalog: {galaxy_subset.shape[0]}")
    galaxy_subset.loc[:,['P_OTHER']] = galaxy_subset[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)
    galaxy_subset.reset_index()
    return galaxy_subset

def get_filepath_by_id(dr8_id,folder_path):
    brick_id = dr8_id.split('_')[0]
    file_loc = f"{folder_path}/{brick_id}/{dr8_id}.jpg"
    return file_loc

def split_dataframe(data, no_of_batches):
    batch_size = math.ceil(data.shape[0] / no_of_batches)
    batched_df = [data[i:i+batch_size] for i in range(0,data.shape[0], batch_size)]
    return batched_df

def get_SDSS_info_batch(catalog,save_path,radius = "1 arcsec",overwrite=True,batch_start=0):
    batched_df = split_dataframe(catalog,200) #30s per batch, more than this seems to fail

    if os.path.exists(save_path) and overwrite:
        os.remove(save_path)

    for i in range(batch_start, len(batched_df)):    
        batch = batched_df[i]        
        coords = SkyCoord(batch["RA"],batch["DEC"],unit=(u.hourangle, u.deg))
        results = pd.DataFrame(SDSS.query_region(coords,data_release=7,radius=radius,photoobj_fields=["objID","ra","dec","err_r","petroR50_r","petroR50Err_r"]).to_pandas())
    
        #Clean up OBJID fields
        batch.loc[:,'OBJID'] = batch['OBJID'].astype(str).str.strip()
        results.loc[:,'objID'] = results['objID'].astype(str).str.strip()
        
        k=0
        j=0
        rows_list = []
        while k < len(batch)-1: #Run through each item in batch
            batch_row = batch.iloc[k]
            results_row = results.iloc[j]
            
            if batch_row['OBJID'] == results_row['objID']: #If OBJIDs match
                #print(f"Match at row {k}")
                if batch.iloc[k+1]['OBJID'] == results.iloc[j+1]['objID']: #If next object OBJIDs match
                    #print(f"Adding row {k} as next row matches")
                    batch_dict = batch_row.to_dict()
                    results_dict = results_row.to_dict()
                    batch_dict.update(results_dict)# Add matching rows from batch and results
                    rows_list.append(batch_dict)
                else:
                    #print(f"Skipping row {k} as next row does not match")
                    while batch.iloc[k+1]['OBJID'] != results.iloc[j+1]['objID']:
                        j += 1 # Move through results until match found
            else:
                #print(f"Skipping row {k} as next row does not match")
                while batch.iloc[k+1]['OBJID'] != results.iloc[j+1]['objID']:
                    j += 1 # Move through results until match found
            k += 1 #Move on to next i
            j += 1 # Move on to next j

        final_columns = batch.columns.to_list()+results.columns.to_list()
        final = pd.DataFrame(rows_list,columns= final_columns)
        reduced = final.drop(["Unnamed: 0","objID","ra","dec"],axis=1)
        
        time.sleep(1)
        reduced.to_csv(save_path, mode='a', header=not os.path.exists(save_path),index=False)
        print(f"Processing batch {i} ({len(batch)} items, {len(results)} results found, cut to {len(reduced)}))")

## Load in SDSS and GZ1 catalogs

In [4]:
DESI_CATALOG_PATH = '../../Data/DESI/gz_desi_deep_learning_catalog_friendly.parquet' #Available from https://doi.org/10.5281/zenodo.8360385
GZ_CATALOG_PATH = '../../Data/GalaxyZoo1_DR_table2.csv' # Available from ui.adsabs.harvard.edu/abs/2011MNRAS.410..166L/abstract or https://data.galaxyzoo.org/

gz_catalog = pd.read_csv(GZ_CATALOG_PATH)
gz_catalog.loc[:,['P_OTHER']] = gz_catalog[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)

print(f"DESI catalog contains 8.7 million objects.")
get_metrics(gz_catalog,"GZ1 catalog")

DESI catalog contains 8.7 million objects.
GZ1 catalog contains 667944 galaxies. CW: 32102 (4.8%), ACW: 33795 (5.1%), Other: 602047 (90.1%)


## Step 1: Cross-matching DESI DR8 & GZ1 (SDSS DR7)
DESI images are organised by dr8_id, whereas GZ1 uses SDSS OBJID. Use astropy to match objects across both catalogs & add a 'dr8_id' column to the GZ1 catalog.

In [5]:
MATCHED_CATALOG = '../../Data/gz1_desi_cross_cat.csv'

if REGENERATE_MATCHED_CATALOG:
    desi_data = pd.read_parquet(DESI_CATALOG_PATH).reset_index(drop=True)
    gz1_data = pd.read_csv(GZ_CATALOG_PATH).reset_index(drop=True)
    ra1 = gz1_data['RA'].to_numpy() #Convert to skycoords
    dec1 = gz1_data['DEC'].to_numpy()
    zoo_cat = SkyCoord(ra=ra1, dec=dec1, unit=(u.hourangle, u.deg))

    ra2 = desi_data['ra'].to_numpy()
    dec2 = desi_data['dec'].to_numpy()
    desi_cat = SkyCoord(ra=ra2, dec=dec2, unit=u.deg)

    idx, d2d, d3d = zoo_cat.match_to_catalog_sky(desi_cat) #idx is index in desi_cat closest to zoo_cat
    max_sep = 10 * u.arcsec
    sep_constraint = d2d < max_sep
    print(str(sep_constraint.sum()) + " matches found")

    zoo_match = gz1_data[sep_constraint] #zoo df that has matches 
    desi_match = desi_data.loc[idx[sep_constraint]]
    #get dr8 id from desi stack to zoo
    desi_match_sort = desi_match.sort_index()
    zoo_match_sort = zoo_match.set_index(idx[sep_constraint]).sort_index()
    matched_catalog = pd.concat([zoo_match_sort, desi_match_sort['dr8_id']], axis=1).reset_index(drop=True)
    matched_catalog.to_csv(MATCHED_CATALOG)

matched_catalog = pd.read_csv(MATCHED_CATALOG)
matched_catalog.loc[:,['P_OTHER']] = matched_catalog[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)
print(f"Number of galaxies in matched catalog: {matched_catalog.shape[0]}, removed {gz_catalog.shape[0]-matched_catalog.shape[0]}")
get_metrics(gz_catalog,"Matched catalog")

Number of galaxies in matched catalog: 647837, removed 20107
Matched catalog contains 667944 galaxies. CW: 32102 (4.8%), ACW: 33795 (5.1%), Other: 602047 (90.1%)


## Step 1a: Creating balanced local subset of 1500 most S, Z and El images for testing

In [6]:
LOCAL_SUBSET_CATALOG_PATH = '../../Data/gz1_desi_cross_cat_local_subset.csv'
DESI_DATA_PATH = '/share/nas2/walml/galaxy_zoo/decals/dr8/jpg'
SUBSET_DATA_PATH = '../../Data/Subset'

if REGENERATE_LOCAL_SUBSET_CATALOG:
    #local_subset_catalog = local_subset_catalog = local_subset_catalog.loc[:, ~local_subset_catalog.columns.str.contains('^Unnamed')]
    local_subset_catalog = create_balanced_subset(matched_catalog, threshold=0.8, N_CW=500,N_ACW=500,N_EL=500)
    local_subset_catalog.loc[:,['P_OTHER']] = matched_catalog[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)
    local_subset_catalog.to_csv(LOCAL_SUBSET_CATALOG_PATH,index=False)

local_subset_catalog = pd.read_csv(LOCAL_SUBSET_CATALOG_PATH)
local_subset_catalog.loc[:,['P_OTHER']] = matched_catalog[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)

if CREATE_LOCAL_SUBSET_COPY:
    for index, galaxy in local_subset_catalog.iterrows():
        g_dr8_id = galaxy['dr8_id']
        galaxy_path = get_filepath_by_id(g_dr8_id,DESI_DATA_PATH) #ORIGINAL
        new_path = get_filepath_by_id(g_dr8_id,SUBSET_DATA_PATH) #NEW

        #MAKE FOLDER IN NEW
        os.makedirs(os.path.dirname(new_path), exist_ok=True)
        shutil.copy(galaxy_path, new_path) #COPY
        
get_metrics(local_subset_catalog,"Local subset catalog")

Local subset catalog contains 1500 galaxies. CW: 241 (16.1%), ACW: 251 (16.7%), Other: 1008 (67.2%)


## Step 1b: Create balanced subset catalog of 15000 most S, Z & El galaxies for testing

In [7]:
BEST_SUBSET_CATALOG_PATH = '../../Data/gz1_desi_cross_cat_best_subset.csv'

if REGENERATE_BEST_SUBSET_CATALOG:
    best_subset_catalog = create_balanced_subset(matched_catalog, threshold=0.8, N_CW=5000,N_ACW=5000,N_EL=5000)
    best_subset_catalog.loc[:,['P_OTHER']] = matched_catalog[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)
    best_subset_catalog.to_csv(BEST_SUBSET_CATALOG_PATH,index=False)

best_subset_catalog = pd.read_csv(BEST_SUBSET_CATALOG_PATH)
best_subset_catalog.loc[:,['P_OTHER']] = matched_catalog[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)
get_metrics(best_subset_catalog,"Best subset catalog")

Best subset catalog contains 15000 galaxies. CW: 2418 (16.1%), ACW: 2479 (16.5%), Other: 10103 (67.4%)


## Step 2: Cut objects that have another object within 1 arcsec
Query SDSS via astroquery to get r-band values, and cut objects that have another object within 1 arcsec

In [8]:
QUERIED_CATALOG_PATH = '../../Data/gz1_desi_cross_cat_queried.csv'

if REGENERATE_QUERIED_CATALOG:
    get_SDSS_info_batch(matched_catalog,QUERIED_CATALOG_PATH)

queried_catalog = pd.read_csv(QUERIED_CATALOG_PATH)
queried_catalog.loc[:,['P_OTHER']] = queried_catalog[['P_EL','P_EDGE','P_DK','P_MG']].sum(axis=1).round(3)
print(f"Number of galaxies with no objects within 1 arcsec: {queried_catalog.shape[0]}, removed {matched_catalog.shape[0]-queried_catalog.shape[0]}")
get_metrics(queried_catalog,"Queried catalog")

Number of galaxies with no objects within 1 arcsec: 213744, removed 434093
Queried catalog contains 213744 galaxies. CW: 8688 (4.1%), ACW: 9190 (4.3%), Other: 195866 (91.6%)


## Step 3: Cut objects using magnitude/r-band

Apply the following cuts
-  r-band magnitude error >0 & <1
- r-band half-light radius r50 >1 arcsec
- relative r-band half-light radius error >0 & <0.25

In [9]:
CUT_CATALOG_PATH = '../../Data/gz1_desi_cross_cat_cut.csv'
Jia_final = 173097

reduced = queried_catalog[np.logical_and(queried_catalog["err_r"]>0,queried_catalog["err_r"]<1)]
print(f"Number of galaxies with r-band magnitude error >0 & <1: {reduced.shape[0]}, removed {queried_catalog.shape[0]-reduced.shape[0]}")

reduced2 = reduced[reduced["petroR50_r"]>1]
print(f"Number of galaxies with r-band half-light radius r50 >1 arcsec: {reduced2.shape[0]}, removed {reduced.shape[0]-reduced2.shape[0]}")

r_band_err = reduced2["petroR50Err_r"]/reduced2["petroR50_r"]
cut_catalog = reduced2[np.logical_and(r_band_err>0,r_band_err<0.25)]
print(f"Number of galaxies with relative r-band half-light radius error >0 & <0.25: {cut_catalog.shape[0]}, removed {reduced2.shape[0]-cut_catalog.shape[0]}")

print(f"Jia et al (2023) Final Number: {Jia_final}. Difference: {Jia_final-cut_catalog.shape[0]}")

if REGENERATE_CUT_CATALOG:
    cut_catalog.to_csv(CUT_CATALOG_PATH)

get_metrics(cut_catalog,"Cut catalog")

Number of galaxies with r-band magnitude error >0 & <1: 213744, removed 0
Number of galaxies with r-band half-light radius r50 >1 arcsec: 213369, removed 375
Number of galaxies with relative r-band half-light radius error >0 & <0.25: 208682, removed 4687
Jia et al (2023) Final Number: 173097. Difference: -35585
Cut catalog contains 208682 galaxies. CW: 8520 (4.1%), ACW: 9023 (4.3%), Other: 191139 (91.6%)


## Step 4: Select testing dataset

Select 15% of the cut data for a reserved test set, and, with a set seed

In [10]:
TESTING_CATALOG_PATH = '../../Data/gz1_desi_cross_cat_testing.csv'
TRAIN_VAL_CATALOG_PATH = '../../Data/gz1_desi_cross_cat_train_val.csv'
get_metrics(cut_catalog,"Cut catalog")
generator1 = torch.Generator().manual_seed(42)
testing_catalog, train_val_catalog = random_split(cut_catalog, [0.20,0.80], generator=generator1)
#Convert from subsets
testing_catalog = testing_catalog.dataset.iloc[testing_catalog.indices]
train_val_catalog = train_val_catalog.dataset.iloc[train_val_catalog.indices]
get_metrics(testing_catalog,"Testing catalog")
get_metrics(train_val_catalog,"Training & validation catalog")

if REGENERATE_TEST_TRAIN_CATALOG:
    #Probably can drop unneeded columns
    testing_catalog.to_csv(TESTING_CATALOG_PATH,index=False)
    train_val_catalog.to_csv(TRAIN_VAL_CATALOG_PATH,index=False)

Cut catalog contains 208682 galaxies. CW: 8520 (4.1%), ACW: 9023 (4.3%), Other: 191139 (91.6%)
Testing catalog contains 41737 galaxies. CW: 1682 (4.0%), ACW: 1840 (4.4%), Other: 38215 (91.6%)
Training & validation catalog contains 166945 galaxies. CW: 6838 (4.1%), ACW: 7183 (4.3%), Other: 152924 (91.6%)


## Step 5: Downsampling

From the training and validation catalog, keep 
- 1 in 20 galaxies with 0 < max(P_CW, P_ACW) <= 0.1
- 1 in 5 galaxies with 0.1 < max(P_CW, P_ACW) <= 0.2 
- 1 in 2 galaxies with 0.2 < max(P_CW, P_ACW) <= 0.3

In [11]:
TRAIN_VAL_DOWNSAMPLE_CATALOG_PATH = '../../Data/gz1_desi_cross_cat_train_val_downsample.csv'

get_metrics(train_val_catalog,"Training & validation catalog")

def cut_by_factor(cat,factor):
    generator1 = torch.Generator().manual_seed(42)
    kept_downsample, _ = random_split(cat, [1/factor,1-(1/factor)], generator=generator1)
    return kept_downsample.dataset.iloc[kept_downsample.indices]

sample_mask_1 = np.logical_and(train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) >= 0, train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) <= 0.1)
sample_mask_2 = np.logical_and(train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) > 0.1, train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) <= 0.2)
sample_mask_3 = np.logical_and(train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) > 0.2, train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) <= 0.3)
keep_mask = np.logical_and(train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) > 0.3, train_val_catalog[['P_CW',"P_ACW"]].max(axis=1) <= 1)

kept_galaxies = train_val_catalog[keep_mask]
downsample_set_1 = cut_by_factor(train_val_catalog[sample_mask_1],20)
downsample_set_2 = cut_by_factor(train_val_catalog[sample_mask_2],5)
downsample_set_3 = cut_by_factor(train_val_catalog[sample_mask_3],2)

train_val_downsample_catalog = pd.concat([kept_galaxies,downsample_set_1,downsample_set_2,downsample_set_3])

if REGENERATE_DOWNSAMPLED_CATALOG:
    train_val_downsample_catalog.to_csv(TRAIN_VAL_DOWNSAMPLE_CATALOG_PATH)

get_metrics(train_val_downsample_catalog,"Downsampled catalog")


Training & validation catalog contains 166945 galaxies. CW: 6838 (4.1%), ACW: 7183 (4.3%), Other: 152924 (91.6%)
Downsampled catalog contains 35988 galaxies. CW: 6838 (19.0%), ACW: 7183 (20.0%), Other: 21967 (61.0%)


## Summary of all steps

In [12]:
print("Load initial GZ1 catalog")
get_metrics(gz_catalog,"GZ1 catalog")
print("\nStep 1: Cross-match with DESI image catalog")
get_metrics(matched_catalog,"Matched catalog")
print("\nStep 2: Cut objects that have another object within 1 arcsec by querying SDSS")
get_metrics(queried_catalog,"Queried catalog")
print("\nStep 3: Cut objects using magnitude/r-band")
get_metrics(cut_catalog,"Cut catalog")
print("\nStep 4: Select testing dataset")
get_metrics(testing_catalog,"Testing catalog")
get_metrics(train_val_catalog,"Train/val catalog")
print("\nStep 5: Downsampling")
get_metrics(train_val_downsample_catalog,"Downsampled train/val Catalog")

Load initial GZ1 catalog
GZ1 catalog contains 667944 galaxies. CW: 32102 (4.8%), ACW: 33795 (5.1%), Other: 602047 (90.1%)

Step 1: Cross-match with DESI image catalog
Matched catalog contains 647837 galaxies. CW: 31594 (4.9%), ACW: 33241 (5.1%), Other: 583002 (90.0%)

Step 2: Cut objects that have another object within 1 arcsec by querying SDSS
Queried catalog contains 213744 galaxies. CW: 8688 (4.1%), ACW: 9190 (4.3%), Other: 195866 (91.6%)

Step 3: Cut objects using magnitude/r-band
Cut catalog contains 208682 galaxies. CW: 8520 (4.1%), ACW: 9023 (4.3%), Other: 191139 (91.6%)

Step 4: Select testing dataset
Testing catalog contains 41737 galaxies. CW: 1682 (4.0%), ACW: 1840 (4.4%), Other: 38215 (91.6%)
Train/val catalog contains 166945 galaxies. CW: 6838 (4.1%), ACW: 7183 (4.3%), Other: 152924 (91.6%)

Step 5: Downsampling
Downsampled train/val Catalog contains 35988 galaxies. CW: 6838 (19.0%), ACW: 7183 (20.0%), Other: 21967 (61.0%)
