In [1]:
import os
import pandas as pd
import numpy as np
from glob import glob
from tqdm.notebook import tqdm
import ast
from astropy import units as u
from astropy.coordinates import SkyCoord
from scipy.spatial import cKDTree
from astropy.io import fits
from astropy.wcs import WCS
from sklearn.neighbors import NearestNeighbors

# Paths

In [2]:
data_dir = '/arc/projects/unions/ssl/data/raw/tiles/dwarforge'
table_dir = '/arc/home/heestersnick/dwarforge/tables'
dwarf_cat_file = 'all_known_dwarfs_v2_processed.csv'
dwarf_cat = pd.read_csv(os.path.join(table_dir, dwarf_cat_file))

# Functions

In [5]:
def zfill_tile(tile):
    return f'{str(tile[0]).zfill(3)}_{str(tile[1]).zfill(3)}'

def labels_to_df(parent_folder, tile_list, dwarf_df):
    k = 0
    unmatched_dwarf_counter = 0
    unmatched_tile_counter = 0
    additional_dwarf_counter = 0
    for tile in tqdm(tile_list):
        # Convert tile tuple to folder name format
        folder_name = zfill_tile(tile)
        
        # Construct the full path to the parquet file
        tile_nums_zfill = folder_name.split('_')
        file_path = os.path.join(parent_folder, folder_name, "cfis_lsb-r", f"CFIS_LSB.{tile_nums_zfill[0]}.{tile_nums_zfill[1]}.r_rebin_det_params.parquet")
        fits_name = f'CFIS_LSB.{tile_nums_zfill[0]}.{tile_nums_zfill[1]}.r_rebin_seg.fits'
        fits_path = os.path.join(parent_folder, folder_name, "cfis_lsb-r", fits_name)
        
        if os.path.exists(file_path):
            try:
                # Attempt to read the parquet file
                det_df = pd.read_parquet(file_path)
                det_df_updated = det_df.copy()
                dwarfs_in_tile = dwarf_df[dwarf_df['tile'] == str(tile)].reset_index(drop=True)
                _, header = open_fits(fits_path, fits_ext=0)
                additional_dwarfs = check_objects_in_neighboring_tiles(str(tile), dwarf_df, header)

                if not additional_dwarfs.empty:
                    dwarfs_in_tile = pd.concat([dwarfs_in_tile, additional_dwarfs]).reset_index(drop=True)

                det_idx_lsb, lsb_matches, lsb_unmatches, _ = match_cats(det_df_updated, dwarfs_in_tile, tile, header, max_sep=15.0)

                # add lsb labels to detections dataframe
                det_df_updated['lsb'] = np.nan
                det_df_updated['ID_known'] = np.nan

                if len(det_idx_lsb) > 0:
                    print(f'Found {len(det_idx_lsb)} lsb detections for tile {tile}.')
                    det_df_updated.loc[det_idx_lsb, 'lsb'] = 1
                    # Initialize the column to accept strings
                    det_df_updated['ID_known'] = det_df_updated['ID_known'].astype(object)
                    det_df_updated.loc[det_idx_lsb, 'ID_known'] = lsb_matches['ID'].values
                    # print(
                    #     f'Added {np.count_nonzero(~np.isnan(det_df_updated["lsb"]))} LSB labels to the detection dataframe for tile {tile}.'
                    # )
                    k += 1
                    additional_dwarf_counter += len(additional_dwarfs)

                if len(lsb_unmatches) > 0:
                    print(f'Found {len(lsb_unmatches)} unmatched dwarf for tile: {tile}.')
                    unmatched_tile_counter += 1
                    unmatched_dwarf_counter += len(lsb_unmatches)

                # Save updated dataframe
                det_df_updated.to_parquet(file_path, index=False)
            except Exception as e:
                print(f'Something went wrong for tile {tile}: {e}')
    print(f'Was able to match {k}/{len(tile_list)} tiles.')
    print(f'There were {unmatched_dwarf_counter} unmatched dwarfs in {unmatched_tile_counter} tiles.')
    print(f'{additional_dwarf_counter} dwarfs are in multiple tiles.')

def open_fits(file_path, fits_ext):
    """
    Open fits file and return data and header.

    Args:
        file_path (str): name of the fits file
        fits_ext (int): extension of the fits file

    Returns:
        data (numpy.ndarray): image data
        header (fits header): header of the fits file
    """
    # logger.debug(f'Opening fits file {os.path.basename(file_path)}..')
    with fits.open(file_path, memmap=True) as hdul:
        data = hdul[fits_ext].data.astype(np.float32)  # type: ignore
        header = hdul[fits_ext].header  # type: ignore
    # logger.debug(f'Fits file {os.path.basename(file_path)} opened.')
    return data, header
                
def check_objects_in_neighboring_tiles(tile, dwarfs_df, header):
    wcs = WCS(header)
    # Get neighboring tile numbers
    neighboring_tiles = get_neighboring_tile_numbers(tile)

    # Filter dwarfs in neighboring tiles
    neighboring_dwarfs = dwarfs_df[dwarfs_df['tile'].isin(neighboring_tiles)]

    # Check which of these dwarfs are actually within the current tile's boundaries
    dwarfs_in_current_tile = neighboring_dwarfs[
        neighboring_dwarfs.apply(
            lambda row: wcs.footprint_contains(
                SkyCoord(row['ra'], row['dec'], unit='deg', frame='icrs')
            ),
            axis=1,
        )
    ]

    return dwarfs_in_current_tile

def get_neighboring_tile_numbers(tile):
    tile = ast.literal_eval(tile)
    x, y = map(int, tile)
    neighbors = [
        (x - 1, y - 1),
        (x - 1, y),
        (x - 1, y + 1),
        (x, y - 1),
        (x, y + 1),
        (x + 1, y - 1),
        (x + 1, y),
        (x + 1, y + 1),
    ]
    return [f'({nx:03d}, {ny:03d})' for nx, ny in neighbors if 0 <= nx < 1000 and 0 <= ny < 1000]

def dwarfs_to_df(parent_folder):
    # Pattern to match all relevant parquet files
    pattern = os.path.join(parent_folder, "*_*", "cfis_lsb-r", "CFIS_LSB.*.r_rebin_det_params.parquet")
    
    # List to store filtered dataframes
    filtered_dfs = []
    
    # Iterate through all matching files
    for file in tqdm(glob(pattern)):
        try:
            # Attempt to read the parquet file
            df = pd.read_parquet(file)
            # Check if 'label' column exists
            if 'lsb' in df.columns:
                # Filter rows where label is 1
                df_filtered = df[df['lsb'] == 1]
                
                if not df_filtered.empty:
                    filtered_dfs.append(df_filtered)
            # If 'label' column doesn't exist, we skip this file
        
        except Exception as e:
            print(f"Error processing file {file}: {str(e)}")
            continue
        
        # The file is automatically closed after reading
    
    # Combine all dataframes
    if filtered_dfs:
        final_df = pd.concat(filtered_dfs, ignore_index=True)
        return final_df
    else:
        return pd.DataFrame()  # Return an empty dataframe if no data found

def gather_training_data(parent_folder, band='cfis_lsb-r', n_neighbors=1):
    pattern = os.path.join(parent_folder, "*_*", band, "CFIS_LSB.*.r_rebin_det_params.parquet")
    
    all_examples = []
    
    for file in tqdm(glob(pattern)):
        try:
            filename = os.path.basename(file)
            tile_numbers = filename.split('.')[1:3]
            tile_id = f"{tile_numbers[0]}.{tile_numbers[1]}"
            
            df = pd.read_parquet(file)
            
            if 'lsb' in df.columns:
                positive_examples = df[df['lsb'] == 1].copy()
                potential_negatives = df[df['lsb'].isna()].copy()
                
                if not positive_examples.empty and not potential_negatives.empty:
                    nn = NearestNeighbors(n_neighbors=len(potential_negatives), metric='euclidean')
                    nn.fit(potential_negatives[['ra', 'dec']])
                    
                    used_negatives = set()  # Set to keep track of used negative examples in this field
                    all_file_examples = []
                    
                    for idx, lsb_obj in positive_examples.iterrows():
                        lsb_df = pd.DataFrame({'ra': [lsb_obj['ra']], 'dec': [lsb_obj['dec']]})
                        
                        distances, indices = nn.kneighbors(lsb_df)
                        
                        # Find n_neighbors unique negative examples within this field
                        unique_negatives = []
                        for index in indices[0]:
                            if index not in used_negatives:
                                unique_negatives.append(index)
                                used_negatives.add(index)
                                if len(unique_negatives) == n_neighbors:
                                    break
                        
                        # If we couldn't find enough unique negatives, continue to the next positive example
                        if len(unique_negatives) < n_neighbors:
                            continue
                        
                        nearest_neighbors = potential_negatives.iloc[unique_negatives].copy()
                        
                        lsb_obj['example_id'] = f"{tile_id}.{lsb_obj['ID']}"
                        
                        nearest_neighbors['example_id'] = nearest_neighbors['ID'].apply(lambda x: f"{tile_id}.{x}")
                        nearest_neighbors['lsb'] = 0  # Set to 0 for negative examples
                        nearest_neighbors['associated_lsb_ra'] = lsb_obj['ra']
                        nearest_neighbors['associated_lsb_dec'] = lsb_obj['dec']
                        
                        all_file_examples.append(pd.concat([lsb_obj.to_frame().T, nearest_neighbors]))
                    
                    if all_file_examples:
                        all_examples.append(pd.concat(all_file_examples))
            
        except Exception as e:
            print(f"Error processing file {file}: {str(e)}")
            continue
    
    if all_examples:
        final_df = pd.concat(all_examples, ignore_index=True)
        return final_df
    else:
        return pd.DataFrame()

def create_cartesian_kdtree(ra, dec):
    """
    Create a KD-Tree using Cartesian coordinates converted from RA and Dec.
    
    :param ra: Right Ascension in degrees
    :param dec: Declination in degrees
    :return: cKDTree object and the corresponding SkyCoord object
    """
    coords = SkyCoord(ra, dec, unit='deg', frame='icrs')
    xyz = coords.cartesian.xyz.value.T
    tree = cKDTree(xyz)
    return tree, coords

def match_cats(df_det, df_label, tile, pixel_scale, max_sep=15.0, re_multiplier=4.0):
    tree, _ = create_cartesian_kdtree(df_det['ra'].values, df_det['dec'].values)
    matches = []
    potential_matches_df = pd.DataFrame()
    for idx, known in df_label.iterrows():
        known_coords = SkyCoord(known['ra'], known['dec'], unit='deg')
        known_coords_xyz = known_coords.cartesian.xyz.value
        
        # Calculate base search radius in degrees
        base_search_radius = max_sep / 3600  # Convert arcseconds to degrees
        
        # Adaptive search radius (if 're' is available)
        if 're' in known and known['re'] is not None and not np.isnan(known['re']) and known['re'] > 0:
            adaptive_radius = known['re'] * re_multiplier / 3600  # Convert to degrees
            search_radius = max(base_search_radius, adaptive_radius)
        else:
            search_radius = base_search_radius
            
        search_radius_chord = 2 * np.sin(np.deg2rad(search_radius) / 2)
        
        potential_match_indices = tree.query_ball_point(known_coords_xyz, search_radius_chord)
        potential_matches = df_det.iloc[potential_match_indices]
        
        print(f'potential matches for {known["ID"]}: {len(potential_matches)}')
        
        potential_matches_df = pd.concat([potential_matches_df, potential_matches])
        if len(potential_matches) > 0:
            potential_matches_coords = SkyCoord(potential_matches['ra'], potential_matches['dec'], unit='deg')
            distances = known_coords.separation(potential_matches_coords).arcsec
            max_n_pix = potential_matches['n_pix'].max()
            max_mu = potential_matches['mu'].max()
            scores = []
            for i, det in potential_matches.iterrows():
                size_score = np.log1p(det['n_pix']) / np.log1p(max_n_pix)
                lsb_score = det['mu'] / max_mu
                distance = distances[potential_matches.index.get_loc(i)]
                distance_score = 1 - (distance / (3600 * search_radius))  # Normalized distance score
                score = (
                    lsb_score * 0.2
                    + size_score * 0.4
                    + distance_score * 0.4
                )
                print(f'object: {det["ID"]}; lsb score: {lsb_score:.4f}, size score: {size_score:.4f}, distance score: {distance_score:.4f}')
                print(f'object: {det["ID"]}; total score: {score:.4f}; distance: {distance:.2f} arcsec')
                scores.append((i, score, distance))
            best_match = max(scores, key=lambda x: x[1])
            matches.append((idx, best_match[0], best_match[2]))
    
    if matches:
        label_match_idx, det_match_idx, match_distances = zip(*matches)
    else:
        label_match_idx, det_match_idx, match_distances = [], [], []
    label_matches = df_label.loc[list(label_match_idx)].reset_index(drop=True)
    label_unmatches = df_label.drop(list(label_match_idx)).reset_index(drop=True)
    det_matches = df_det.loc[list(det_match_idx)].reset_index(drop=True)
    det_matches['match_distance'] = match_distances
    return list(det_match_idx), label_matches, label_unmatches, det_matches

def get_tile_list(dwarf_cat):
    tiles = dwarf_cat['tile'].values
    non_nan_tiles = [x for x in tiles if x is not np.nan]
    str_to_tuple = [ast.literal_eval(item) for item in non_nan_tiles]
    unique_tiles = list(set(str_to_tuple))
    return unique_tiles

def check_bands(bands_str, to_check):
    if isinstance(bands_str, str):
        try:
            bands_list = ast.literal_eval(bands_str)
            return all(band in bands_list for band in to_check)
        except:
            return False
    return False  # Return False for NaN values

def check_availability(dwarf_cat, check_for_bands):
    df_select = dwarf_cat.loc[(~dwarf_cat['tile'].isna()) & (dwarf_cat['bands'].apply(lambda x: check_bands(x, check_for_bands)))].reset_index(drop=True)
    return df_select, len(df_select)

# Analysis

In [6]:
tile_list = get_tile_list(dwarf_cat)

In [21]:
dwarf_cat['bands'][0] in 'r'

False

In [27]:
check_availability(dwarf_cat, ['r'])

(Empty DataFrame
 Columns: [host, ID, ra, dec, morph, re, zspec, tile, x, y, bands, n_bands, cutout]
 Index: [],
 0)

In [28]:
dwarf_cat

Unnamed: 0,host,ID,ra,dec,morph,re,zspec,tile,x,y,bands,n_bands,cutout
0,NGC5457,[MVA2014] DF 4,211.89083,54.71089,UDG,28.000000,,"(246, 289)",4120.276396,9087.340522,ugriz,5,0
1,NGC5457,[MVA2014] DF 5,211.11710,55.61670,UDG,38.000000,,"(239, 291)",3492.105870,7263.024500,ugriz,5,0
2,NGC5457,[MVA2014] DF 6,212.07792,55.19183,UDG,22.000000,,"(243, 290)",2244.034195,8722.400179,ugriz,5,0
3,NGC5457,[MVA2014] DF 7,211.44792,55.13258,UDG,20.000000,,"(243, 290)",9219.451934,7580.774126,ugriz,5,0
4,NGC253,NGC247,11.78340,-20.75700,,247.882787,,,,,,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
14095,NGC7541,LS-357669-3767,348.87450,4.61310,,,,"(696, 189)",8893.878219,7192.314564,u,1,0
14096,NGC7541,LS-357668-2728,348.62140,4.50730,,,,"(695, 189)",4094.690157,5141.495530,u,1,0
14097,NGC7541,LS-360540-737,348.55460,4.91510,,,,"(694, 190)",576.261089,3355.482606,u,1,0
14098,NGC7716,NSA-31702,354.35080,0.39100,,,,"(709, 181)",8152.874666,2887.720956,u,1,0


In [5]:
training_data = gather_training_data(data_dir, band='cfis_lsb-r', n_neighbors=10)
training_data.to_csv(os.path.join(table_dir, 'training_data_10x_rf.csv'), index=False)

  0%|          | 0/20789 [00:00<?, ?it/s]

In [6]:
len(training_data)

11363