In [None]:
from astropy.table import Column, Table
import numpy as np
import os

# Different catalogs use different names for their RA and DEC columns.
# This function keeps a small database of catalog names (the catalog file names
# without the suffix) and returns the RA and DEC column names of the catalog
# whose name appears in CAT_NAME.
def get_ra_dec_names(cat_name):
    """Return the (right ascension, declination) column names for a catalog.

    Matching is by substring: a database entry matches when it occurs anywhere
    in ``cat_name``.  A matching entry is ignored when a longer matching entry
    contains it, so the specific ``"Gaia_DR3"`` entry takes precedence over the
    generic ``"Gaia"`` entry.  Otherwise the earliest entry in the table below
    wins ("ra"/"dec" list first, then "RAJ2000"/"DEJ2000", then the ICRS list).

    Parameters
    ----------
    cat_name : str
        Catalog name, e.g. the catalog file name without its suffix.

    Returns
    -------
    tuple of str
        ``(ra_column_name, dec_column_name)`` when the catalog is known.
    str
        The message ``"Catalog was not found in the database!"`` when no entry
        matches.  This is kept as a string return (not an exception) for
        compatibility with existing callers.
    """
    # (catalog-name fragments, (RA column, DEC column)), in tie-break priority order
    column_conventions = (
        (("2MASS", "AKARI", "Gaia", "IRAS", "IRS", "PACS", "Planck", "SEIP",
          "SPIRE", "USNO", "VSA", "WISE"), ("ra", "dec")),
        (("DENIS",), ("RAJ2000", "DEJ2000")),
        (("DANCE", "Gaia_DR3"), ("RA_ICRS", "DE_ICRS")),
    )

    # Every fragment that occurs in the catalog name, in priority order
    matches = [
        (fragment, columns)
        for fragments, columns in column_conventions
        for fragment in fragments
        if fragment in cat_name
    ]

    for fragment, columns in matches:
        # A generic fragment (e.g. "Gaia") yields to a more specific matching
        # fragment that contains it (e.g. "Gaia_DR3"); previously the generic
        # entry always won and the specific entry was unreachable.
        superseded = any(
            other != fragment and fragment in other for other, _ in matches
        )
        if not superseded:
            return columns

    # If the input catalog is not found in the database
    return "Catalog was not found in the database!"

# After the cross-match function runs on a catalog, the column names it shares
# with the catalog it is matched against get changed. This function restores the
# first (main) catalog's column names.
def reset_main_catalog_columns(main_catalog, final_catalog):
    """Rebuild FINAL_CATALOG column by column, restoring MAIN_CATALOG names.

    A FINAL_CATALOG column whose name is a MAIN_CATALOG column name followed by
    one trailing underscore is renamed by dropping that underscore. The original
    comments say the underscore is the character added during the HSTACK step.
    Every other column keeps its own name. If that name is already taken in the
    table being built, the column is added under ``<name>_2_2`` instead.

    Parameters
    ----------
    main_catalog : Table
        The first catalog of the cross-match; only its ``colnames`` are read.
    final_catalog : Table
        The aggregated catalog whose columns are copied over.

    Returns
    -------
    Table
        A new Table; neither input is modified by this function.
    """
    # MAIN_CATALOG names as they look after the underscore was appended
    suffixed_main_names = {name + "_" for name in main_catalog.colnames}

    rebuilt = Table()

    for column in final_catalog.itercols():
        if column.name not in suffixed_main_names:
            try:
                rebuilt.add_column(column)
            except ValueError:
                # The name is already used in REBUILT: keep both by suffixing this copy
                duplicate = Column.copy(column)
                duplicate.name = column.name + "_2_2"
                rebuilt.add_column(duplicate)
            continue

        # Strip the trailing underscore to get back the MAIN_CATALOG name
        original_name = column.name[:-1]
        rebuilt.add_column(column)
        rebuilt.rename_column(column.name, original_name)

    return rebuilt

# Used before vertically stacking two catalogs: columns that share a name but
# not a dtype are renamed in the second catalog.
# cat_1 and cat_2 are the two catalogs being stacked; cat_2_name is the name of
# the second catalog, used as the suffix for the renamed columns.
def same_name_diff_types(cat_1, cat_2, cat_2_name):
    """Rename CAT_2 columns whose dtype differs from the same-named CAT_1 column.

    Each such column is renamed in place to ``<name>_<cat_2_name>``, so it is
    kept apart from its CAT_1 counterpart from then on.

    Parameters
    ----------
    cat_1 : Table
        First catalog; read only.
    cat_2 : Table
        Second catalog; modified in place.
    cat_2_name : str
        Name of the second catalog, appended after an underscore.

    Returns
    -------
    Table
        CAT_2 itself (the same object that was passed in).
    """
    names_in_cat_2 = cat_2.colnames

    # Names present in both catalogs, in CAT_1's column order
    shared = [name for name in cat_1.colnames if name in names_in_cat_2]

    for name in shared:
        if cat_1[name].dtype == cat_2[name].dtype:
            continue
        cat_2.rename_column(name, name + "_" + cat_2_name)

    return cat_2

# Rebuilds a catalog keeping only the requested columns.
def make_catalog(input_catalog, columns):
    """Return a new Table holding copies of the named INPUT_CATALOG columns.

    Parameters
    ----------
    input_catalog : Table
        Source catalog; not modified.
    columns : iterable of str
        Names of the columns to keep, in the order they are inserted.

    Returns
    -------
    Table
        A fresh Table whose columns are ``Column.copy`` copies of the
        selected INPUT_CATALOG columns.
    """
    subset = Table()

    for name in columns:
        source_column = input_catalog[name]
        # Copy so later edits to SUBSET never reach INPUT_CATALOG
        subset[name] = Column.copy(source_column)

    return subset

# After the cross-match and the column-name reset, removes the columns that the
# second catalog gained from the catalog it was cross-matched with.
def remove_new_columns(cat_1, cat_2):
    """Delete from CAT_2 every column whose name does not occur in CAT_1.

    Parameters
    ----------
    cat_1 : Table
        Reference catalog; its column names are the ones kept. Not modified.
    cat_2 : Table
        Catalog pruned in place with ``del``.

    Returns
    -------
    Table
        CAT_2 itself, now holding only the columns whose names CAT_1 also has.
    """
    names_to_keep = set(cat_1.colnames)

    # Snapshot the extra names first; deleting while iterating colnames is avoided
    extra_names = [name for name in cat_2.colnames if name not in names_to_keep]

    for name in extra_names:
        del cat_2[name]

    return cat_2

# Line equation of the form a*x + b*y + c = 0 through (x0, y0) and (x1, y1).
# Note the argument order: both x coordinates first, then both y coordinates.
def get_line_eq(x0, x1, y0, y1):
    """Return ``(a, b, c)`` with ``a*x + b*y + c = 0`` through both points.

    ``a*x + b*y + c`` equals the 2-D cross product
    ``(P1 - P0) x (P - P0)`` for a point ``P = (x, y)``. With x to the right
    and y up, it is therefore:

    * ``> 0`` for points to the left of the directed line from
      ``(x0, y0)`` to ``(x1, y1)``,
    * ``< 0`` for points to the right of it,
    * ``0`` for points on the line.

    NOTE(review): an earlier comment said the value is ``< 0`` on the left.
    That only holds in a mirrored frame (e.g. RA plotted increasing to the
    left), so confirm the convention at the call sites.

    The expression is plain arithmetic, so NumPy arrays are handled
    element-wise as well as scalars.
    """
    a = y0 - y1
    b = x1 - x0
    c = x0 * y1 - x1 * y0
    return a, b, c

# Retrieves the sources in a catalog that lie inside a user-defined RA/DEC box.
def sources_in_region(ra_min, ra_max, dec_min, dec_max, catalog, ra, dec):
    """Return a copy of CATALOG keeping only the rows inside an RA/DEC box.

    A row is removed when its RA is below RA_MIN or above RA_MAX, or its DEC
    is below DEC_MIN or above DEC_MAX, so the bounds are inclusive. Any
    comparison with NaN is False, so rows with NaN coordinates are not
    flagged as outside and are kept.

    Parameters
    ----------
    ra_min, ra_max, dec_min, dec_max :
        Box limits, in the same units as the RA and DEC columns.
    catalog : Table
        Input catalog; it is copied, not modified.
    ra, dec : str
        Names of the RA and DEC columns in CATALOG.

    Returns
    -------
    Table
        A copy of CATALOG without the rows that fall outside the box.

    Notes
    -----
    RA wrap-around is not handled. With RA_MIN > RA_MAX, no row with a
    finite RA is kept.
    """
    sources = Table.copy(catalog)

    ra_values = sources[ra]
    dec_values = sources[dec]

    # One mask for all four cuts; same result as removing each cut in turn
    outside = (
        (ra_values < ra_min)
        | (ra_values > ra_max)
        | (dec_values < dec_min)
        | (dec_values > dec_max)
    )

    # remove_rows takes a slice, an int or an array of ints. The previous code
    # passed a list wrapping a boolean mask. That only worked through NumPy's
    # deprecated non-tuple multidimensional indexing, which raises IndexError
    # since NumPy 1.23.
    sources.remove_rows(np.nonzero(outside)[0])

    return sources