In [None]:
import glob
import os
from functools import partial
from multiprocessing import Manager, Pool

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.coordinates import SkyCoord, match_coordinates_sky
from scipy.spatial import cKDTree
from tqdm import tqdm

# Paths

In [None]:
# Where the raw tile data and the catalog tables live
data_dir = '/arc/projects/unions/ssl/data/raw/tiles/dwarforge'
table_dir = '/arc/home/heestersnick/dwarforge/tables'

# Build full paths of files inside table_dir
table_path = partial(os.path.join, table_dir)

# Training table, classification table and list of known dwarfs
train_df = pd.read_csv(table_path('train_df.csv'))
class_df = pd.read_csv(table_path('class_df_v2.csv'))
all_dwarfs = pd.read_csv(table_path('all_known_dwarfs_v3_processed.csv'))

In [None]:
def match_coordinates(detected_ra, detected_dec, train_ra, train_dec, max_separation=10.0):
    """
    Nearest-neighbour sky match of detections against a reference list.

    Args:
        detected_ra, detected_dec: Detection coordinates in degrees (array-like).
        train_ra, train_dec: Reference coordinates in degrees (array-like).
        max_separation: Largest on-sky separation, in arcseconds, at which a
            detection still counts as matched.

    Returns:
        tuple: ``(matches, idx)``.
            ``matches`` is a boolean array with one entry per detection, True
            where the nearest reference object lies within ``max_separation``.
            ``idx`` holds, for every detection (matched or not), the position
            of its nearest reference object.
            If either input list is empty, ``matches`` is all False and
            ``idx`` is an empty integer array.
    """
    n_detected = len(detected_ra)
    if n_detected == 0 or len(train_ra) == 0:
        # Nothing to match. Keep idx integer-typed so it stays usable as an indexer.
        return np.zeros(n_detected, dtype=bool), np.array([], dtype=int)

    detected_coords = SkyCoord(ra=detected_ra * u.degree, dec=detected_dec * u.degree)
    train_coords = SkyCoord(ra=train_ra * u.degree, dec=train_dec * u.degree)

    # For each detection: index of and separation to the closest reference object
    idx, sep, _ = detected_coords.match_to_catalog_sky(train_coords)
    matches = sep.arcsec <= max_separation

    return matches, idx


def process_single_file(file_path, train_df, mismatches_list):
    """
    Load one detections parquet file and flag the objects that are in the training set.

    The tile name is the grandparent directory of ``file_path``. Training
    objects of that tile are split into non-dwarfs (``known_id`` starting with
    ``'non_dwarf'``) and dwarfs (everything else), and each group is
    coordinate-matched against the detections with ``match_coordinates``.

    Args:
        file_path: Path of the parquet file to read.
        train_df: Training table; the columns ``tile``, ``known_id``, ``ra``,
            ``dec`` and ``label`` are read.
        mismatches_list: List-like object that receives one dict per training
            dwarf whose ``known_id`` is either absent from the file's
            ``ID_known`` column or only present on rows that were not
            coordinate-matched.

    Returns:
        tuple: ``(df, match_stats)``.
            ``df`` is the detection table with the added columns ``tile``,
            ``in_training_data`` (0/1) and ``visual_class`` (training label of
            the matched object, NaN otherwise).
            ``match_stats`` has the keys ``dwarfs_found`` and
            ``non_dwarfs_found``: the number of distinct training objects of
            this tile that have at least one coordinate-matched detection.
            If anything raises, the error is printed and an empty DataFrame
            with zeroed statistics is returned.
    """
    try:
        df = pd.read_parquet(file_path)

        match_stats = {
            'dwarfs_found': 0,  # objects with real IDs
            'non_dwarfs_found': 0,  # objects with non_dwarf IDs
        }

        # .../<tile>/<subdir>/<file>: the tile is two directory levels up
        tile = os.path.basename(os.path.dirname(os.path.dirname(file_path)))
        df['tile'] = tile
        df['in_training_data'] = 0
        df['visual_class'] = np.nan

        tile_train_data = train_df[train_df['tile'] == tile]
        if tile_train_data.empty:
            # No training objects on this tile, nothing to flag
            return df, match_stats

        is_non_dwarf = tile_train_data['known_id'].str.startswith('non_dwarf', na=False)
        dwarfs = tile_train_data[~is_non_dwarf]
        non_dwarfs = tile_train_data[is_non_dwarf]

        if len(dwarfs) > 0:
            dwarf_matches, dwarf_idx = match_coordinates(
                df['ra'].values,
                df['dec'].values,
                dwarfs['ra'].values,
                dwarfs['dec'].values,
            )
            matched_train_pos = dwarf_idx[dwarf_matches]
            df.loc[dwarf_matches, 'in_training_data'] = 1
            df.loc[dwarf_matches, 'visual_class'] = dwarfs['label'].values[matched_train_pos]
            # Several detections can sit within the match radius of the same
            # training object; count each training object once.
            match_stats['dwarfs_found'] = len(np.unique(matched_train_pos))

            # Cross-check: does the row carrying the known ID coincide with a
            # coordinate-matched row?
            coord_matched_rows = set(np.flatnonzero(dwarf_matches).tolist())
            for _, train_obj in dwarfs.iterrows():
                id_matches = df['ID_known'] == train_obj['known_id']
                id_matched_indices = np.where(id_matches)[0]

                if len(id_matched_indices) > 1:
                    print(
                        f"\nFound multiple matches for ID {train_obj['known_id']} in tile {tile}"
                    )
                    print(f'Found in rows: {id_matched_indices}')

                matched_by_coords = any(i in coord_matched_rows for i in id_matched_indices)

                # An empty ID match also lands here because any() of nothing is False
                if not matched_by_coords:
                    mismatches_list.append(
                        {
                            'train_id': train_obj['known_id'],
                            'train_ra': train_obj['ra'],
                            'train_dec': train_obj['dec'],
                            'train_tile': tile,
                            'coord_matched': matched_by_coords,
                            'id_matched': len(id_matched_indices) > 0,
                            'n_id_matches': len(id_matched_indices),
                        }
                    )

        if len(non_dwarfs) > 0:
            non_dwarf_matches, non_dwarf_idx = match_coordinates(
                df['ra'].values,
                df['dec'].values,
                non_dwarfs['ra'].values,
                non_dwarfs['dec'].values,
            )
            matched_train_pos = non_dwarf_idx[non_dwarf_matches]
            df.loc[non_dwarf_matches, 'in_training_data'] = 1
            df.loc[non_dwarf_matches, 'visual_class'] = non_dwarfs['label'].values[
                matched_train_pos
            ]
            match_stats['non_dwarfs_found'] = len(np.unique(matched_train_pos))

        return df, match_stats

    except Exception as e:
        # Best effort: one unreadable file must not abort the whole run
        print(f'Error processing {file_path}: {e}')
        return pd.DataFrame(), {'dwarfs_found': 0, 'non_dwarfs_found': 0}


def gather_data(base_dir, train_df, num_processes=4):
    """
    Process every matching parquet file under ``base_dir`` and combine the results.

    Files are located with the pattern
    ``<base_dir>/*_*/gri/*_matched_detections.parquet`` and handed to
    ``process_single_file`` in a process pool. Files are processed in sorted
    path order so the row order of the combined table is reproducible.

    Args:
        base_dir: Directory containing the per-tile sub-directories.
        train_df: Training table passed through to ``process_single_file``;
            its ``known_id`` column is also used for the printed totals.
        num_processes: Number of worker processes.

    Returns:
        tuple: ``(combined_df, mismatches_df)``: the concatenated per-file
        tables and a table of the mismatch records collected by the workers.

    Raises:
        ValueError: If no file matches the pattern.
    """
    pattern = os.path.join(base_dir, '*_*', 'gri', '*_matched_detections.parquet')

    # glob returns files in arbitrary, filesystem-dependent order
    parquet_files = sorted(glob.glob(pattern))

    if not parquet_files:
        raise ValueError(f'No parquet files found matching pattern: {pattern}')

    print(f'Found {len(parquet_files)} parquet files')

    dfs = []
    n_empty = 0
    total_dwarfs_found = 0
    total_non_dwarfs_found = 0

    # The manager list is shared so that every worker can append mismatch records
    with Manager() as manager:
        mismatches_list = manager.list()

        process_file = partial(
            process_single_file, train_df=train_df, mismatches_list=mismatches_list
        )

        with Pool(processes=num_processes) as pool:
            for df, stats in tqdm(
                pool.imap(process_file, parquet_files),
                total=len(parquet_files),
                desc='Processing files',
                unit='file',
            ):
                if df.empty:
                    # Failed loads come back as empty frames; keep them out of concat
                    n_empty += 1
                else:
                    dfs.append(df)
                total_dwarfs_found += stats['dwarfs_found']
                total_non_dwarfs_found += stats['non_dwarfs_found']

        # Copy out of the proxy before the manager shuts down
        mismatches_df = pd.DataFrame(list(mismatches_list))

    if n_empty:
        print(f'Skipped {n_empty} files that failed to load or contained no rows')

    print('\nCombining dataframes...')
    combined_df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()

    # Totals in the training table, split the same way as in process_single_file
    is_non_dwarf = train_df['known_id'].str.startswith('non_dwarf', na=False)
    dwarfs = train_df[~is_non_dwarf]
    non_dwarfs = train_df[is_non_dwarf]

    print('\nTraining object statistics:')
    print('Dwarfs (objects with real IDs):')
    print(f'  - Total in training set: {len(dwarfs)}')
    print(f'  - Found in results: {total_dwarfs_found}')
    print(f'  - Missing: {len(dwarfs) - total_dwarfs_found}')

    print('\nNon-dwarfs:')
    print(f'  - Total in training set: {len(non_dwarfs)}')
    print(f'  - Found in results: {total_non_dwarfs_found}')
    print(f'  - Missing: {len(non_dwarfs) - total_non_dwarfs_found}')

    return combined_df, mismatches_df


def count_intervals(values):
    """
    Histogram ``values`` into 0.05-wide bins between 0 and 1.

    One line per bin is printed as ``<start> - <end>: <count>``.

    Returns:
        Array with the number of values in each bin.
    """
    edges = np.arange(0, 1.05, 0.05)
    counts = np.histogram(values, bins=edges)[0]

    # Walk over (lower edge, upper edge, count) triples
    for lower, upper, n_in_bin in zip(edges[:-1], edges[1:], counts):
        print(f'{round(lower, 2):.2f} - {round(upper, 2):.2f}: {n_in_bin}')

    return counts

# Gather all data

In [None]:
# Run the matching over every parquet file found under data_dir
result_df, mismatch_df = gather_data(data_dir, class_df, num_processes=16)

# Save the combined dataset
# result_df.to_parquet("unions_lsb_catalog_all.parquet")

n_training_objects = result_df['in_training_data'].sum()
print(f'Combined shape: {result_df.shape}')
print(f'Number of training objects: {n_training_objects}')

In [None]:
# Inspect the mismatch records returned by gather_data
mismatch_df

In [None]:
# Keep only the rows that carry a known ID
non_na = result_df.dropna(subset=['ID_known']).reset_index(drop=True)

In [None]:
# Rows of non_na whose ID_known begins with '1.24'
non_na[non_na['ID_known'].str.startswith('1.24')]

In [None]:
# Classified objects whose known_id does not appear among the result IDs
is_in_results = class_df['known_id'].isin(result_df['ID_known'])
missing = class_df.loc[~is_in_results].reset_index(drop=True)

In [None]:
# Number of result rows that carry a known ID
int(result_df['ID_known'].notna().sum())

In [None]:
# Missing objects that are non-dwarfs; na=False treats a missing known_id as
# "not a non-dwarf", as the matching functions do, instead of producing a mask with NaN
missing[missing['known_id'].str.startswith('non_dwarf', na=False)]

In [None]:
# Look up a single object by its known ID
result_df[result_df['ID_known'] == '1234502+293314']

In [None]:
# All result rows that carry a known ID
result_df[result_df['ID_known'].notna()]

In [None]:
# remove objects that were used to train the model
# result_no_train = result_df[(result_df['in_training_data'] == 0) & (result_df['lsb'].isna())].reset_index(drop=True)
# save to file
# result_no_train.to_parquet(os.path.join(table_dir, 'combined_no_train.parquet'), index=False)
# drop the rows that have no v2 prediction
result_df = result_df.dropna(subset=['zoobot_pred_v2']).reset_index(drop=True)

In [None]:
# Rows that have no value in the 'zoobot_pred' column
result_df[result_df['zoobot_pred'].isna()]

In [None]:
# Distribution of the v2 predictions in 0.05-wide bins, log-scaled counts
fig, ax = plt.subplots(figsize=(6, 6))
ax.hist(result_df['zoobot_pred_v2'], bins=np.arange(0, 1.05, 0.05), ec='black')
ax.set_xticks(np.arange(0, 1.1, 0.1))
ax.set_xlabel('Prediction')
ax.set_ylabel('Count')
ax.set_yscale('log')
fig.tight_layout()
# plt.savefig(os.path.join(figures, 'prediction_dist_example.png'), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Distribution of the 'zoobot_pred' predictions in 0.05-wide bins, log-scaled counts
fig, ax = plt.subplots(figsize=(6, 6))
ax.hist(result_df['zoobot_pred'], bins=np.arange(0, 1.05, 0.05), ec='black')
ax.set_xticks(np.arange(0, 1.1, 0.1))
ax.set_xlabel('Prediction')
ax.set_ylabel('Count')
ax.set_yscale('log')
fig.tight_layout()
# plt.savefig(os.path.join(figures, 'prediction_dist_example.png'), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Print the number of 'zoobot_pred_v2' values per 0.05-wide bin between 0 and 1
# and keep the per-bin counts
counts = count_intervals(result_df['zoobot_pred_v2'].values)

In [None]:
def analyze_string_matches(arr1, arr2):
    """
    Sort the entries of ``arr1`` by whether they occur in ``arr2``.

    ``None`` entries and float NaN entries are only counted. Every other entry
    goes either to the list of matches (it is contained in ``arr2``) or to the
    list of non-matches. The four tallies are printed.

    Returns:
        tuple: ``(matches_list, non_matches_list)`` in the order of ``arr1``.
    """
    valid_strings = set(arr2)

    matches_list, non_matches_list = [], []
    nones = nans = 0

    for x in arr1:
        if x is None:
            nones += 1
            continue
        if isinstance(x, float) and np.isnan(x):
            nans += 1
            continue
        bucket = matches_list if x in valid_strings else non_matches_list
        bucket.append(x)

    print(f'Matches: {len(matches_list)}')
    print(f'None values: {nones}')
    print(f'NaN values: {nans}')
    print(f'Non-matching strings: {len(non_matches_list)}')

    return matches_list, non_matches_list

In [None]:
# Compare the IDs in the results against the known_id values of class_df
matches, non_matches = analyze_string_matches(result_df['ID_known'], class_df['known_id'])

In [None]:
# Inspect the rows of one non-matching ID (position 16 of the list)
# NOTE(review): raises IndexError when fewer than 17 non-matching IDs exist
result_df[result_df['ID_known'] == non_matches[16]]

# Remove duplicates

In [None]:
def remove_duplicates(df, max_separation_arcsec=10.0, priority_column='zoobot_pred'):
    """
    Remove duplicate entries with fixed-radius Friends-of-Friends clustering.

    Objects closer than ``max_separation_arcsec`` are linked, links are merged
    transitively into clusters, and per cluster the row with the highest
    ``priority_column`` value is kept. Rows without a priority value rank last.

    NOTE: a later cell of this notebook defines another function with the same
    name (effective-radius based); once that cell has run it replaces this one.

    Args:
        df: Table with at least the columns ``ra``, ``dec`` (degrees), ``tile``
            and ``priority_column``.
        max_separation_arcsec: Linking length in arcseconds.
        priority_column: Column that decides which cluster member is kept.

    Returns:
        tuple: ``(dedup_df, groups_df)``. ``dedup_df`` has one row per cluster
        with a fresh index. ``groups_df`` has one row per member of every
        cluster with more than one object, for inspection.
    """
    coords = SkyCoord(ra=df['ra'].values * u.degree, dec=df['dec'].values * u.degree)

    # Unit vectors on the sphere, so a Euclidean KD-tree can be used
    x = np.cos(coords.dec.radian) * np.cos(coords.ra.radian)
    y = np.cos(coords.dec.radian) * np.sin(coords.ra.radian)
    z = np.sin(coords.dec.radian)
    points = np.column_stack((x, y, z))

    # An angle theta on the unit sphere corresponds to a chord of 2*sin(theta/2)
    max_angle_rad = (max_separation_arcsec * u.arcsec).to(u.radian).value
    max_chord_length = 2 * np.sin(max_angle_rad / 2)

    tree = cKDTree(points)
    pairs = tree.query_pairs(max_chord_length, output_type='set')

    # Union-Find over row positions
    parent = np.arange(len(df))

    def find(node):
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # Path compression
            node = parent[node]
        return node

    for node1, node2 in pairs:
        root1 = find(node1)
        root2 = find(node2)
        if root1 != root2:
            parent[root2] = root1

    cluster_labels = np.array([find(i) for i in range(len(parent))], dtype=int)
    df = df.copy()
    df['_cluster'] = cluster_labels

    # Best row per cluster. Missing priorities are ranked lowest so that a
    # cluster without any valid priority still yields a row.
    rank_key = df[priority_column].fillna(-np.inf)
    best_idx_by_cluster = rank_key.groupby(df['_cluster']).idxmax()

    # Only members of clusters with more than one object need to be described.
    # Grouping them once avoids building a full-length mask per cluster.
    cluster_sizes = np.bincount(cluster_labels, minlength=len(df))
    members = df[cluster_sizes[cluster_labels] > 1]

    group_columns = [
        'cluster_id',
        'object_id',
        'ra',
        'dec',
        'tile',
        'separation_to_best',
        priority_column,
        'is_best',
        'duplicate_type',
    ]
    groups_info = []
    same_tile_duplicates = 0
    overlap_duplicates = 0

    for cluster_id, cluster_objects in tqdm(
        members.groupby('_cluster'), desc='Processing clusters'
    ):
        best_idx = best_idx_by_cluster.loc[cluster_id]

        # All members on one tile: same-tile duplicate, otherwise tile overlap
        unique_tiles = cluster_objects['tile'].unique()
        duplicate_type = 'same_tile' if len(unique_tiles) == 1 else 'overlap'
        if duplicate_type == 'same_tile':
            same_tile_duplicates += len(cluster_objects) - 1
        else:
            overlap_duplicates += len(cluster_objects) - 1

        best_coord = SkyCoord(
            ra=df.loc[best_idx, 'ra'] * u.degree,
            dec=df.loc[best_idx, 'dec'] * u.degree,
        )
        cluster_coords = SkyCoord(
            ra=cluster_objects['ra'].values * u.degree,
            dec=cluster_objects['dec'].values * u.degree,
        )
        separations = cluster_coords.separation(best_coord).arcsec

        for (idx, row), sep_to_best in zip(cluster_objects.iterrows(), separations):
            groups_info.append(
                {
                    'cluster_id': cluster_id,
                    'object_id': idx,
                    'ra': row['ra'],
                    'dec': row['dec'],
                    'tile': row['tile'],
                    'separation_to_best': sep_to_best,
                    f'{priority_column}': row[priority_column],
                    'is_best': idx == best_idx,
                    'duplicate_type': duplicate_type,
                }
            )

    # Passing the columns keeps them present even when no duplicates exist
    groups_df = pd.DataFrame(groups_info, columns=group_columns)

    total_duplicates = same_tile_duplicates + overlap_duplicates
    print('\nDuplicate Detection Summary:')
    print(f'Total objects in input: {len(df)}')
    print(f'Total duplicate objects found: {total_duplicates}')
    print(f'  - Same tile duplicates: {same_tile_duplicates}')
    print(f'  - Tile overlap duplicates: {overlap_duplicates}')
    print(f'Objects after deduplication: {len(df) - total_duplicates}')

    dedup_df = df.loc[best_idx_by_cluster.to_numpy()]
    dedup_df = dedup_df.drop(columns=['_cluster']).reset_index(drop=True)

    return dedup_df, groups_df

In [None]:
# Objects that were not used for training and have no 'lsb' value.
# Restored from the commented-out definition further up, because the cells
# below need deduped_catalog and group_cat.
result_no_train = result_df[
    (result_df['in_training_data'] == 0) & (result_df['lsb'].isna())
].reset_index(drop=True)

# NOTE: needs the fixed-radius remove_duplicates defined directly above; the
# effective-radius version defined later does not accept max_separation_arcsec.
deduped_catalog, group_cat = remove_duplicates(result_no_train, max_separation_arcsec=10.0)

# save to file
deduped_catalog.to_parquet(os.path.join(table_dir, 'combined_no_dups.parquet'), index=False)
# group_cat.to_parquet(os.path.join(table_dir, 'duplicate_groups.parquet'), index=False)

In [None]:
# Print the number of 'zoobot_pred' values of the deduplicated catalog per
# 0.05-wide bin between 0 and 1 and keep the per-bin counts
counts = count_intervals(deduped_catalog['zoobot_pred'].values)

In [None]:
# Counts of the last four bins (0.80 to 1.00)
counts[-4:]

In [None]:
# Clusters that contain at least one object scoring above 0.8
is_high_score = group_cat['zoobot_pred'] > 0.8
high_score_clusters = group_cat.loc[is_high_score, 'cluster_id'].unique()

# Print every member of the first 20 of these clusters
columns_to_show = ['ra', 'dec', 'tile', 'zoobot_pred', 'duplicate_type']
for cluster_id in high_score_clusters[:20]:
    group = group_cat.loc[group_cat['cluster_id'] == cluster_id]
    print(f'\nCluster {cluster_id}:')
    print('Number of objects in group:', len(group))
    print(group[columns_to_show])
    print('-' * 80)

In [None]:
# Objects that have neither a class label nor a spectroscopic redshift entry
is_unclassified = deduped_catalog['class_label'].isna() & deduped_catalog['zspec'].isna()
not_known = deduped_catalog.loc[is_unclassified].reset_index(drop=True)

In [None]:
# Display the objects that have no class_label and no zspec value
not_known

In [None]:
# Write not_known to 'combined_unknown.parquet' inside table_dir
not_known.to_parquet(os.path.join(table_dir, 'combined_unknown.parquet'), index=False)

In [None]:
# First 15 rows of the deduplicated catalog with zoobot_pred above 0.7
deduped_catalog[deduped_catalog['zoobot_pred'] > 0.7][:15]

In [None]:
# Distribution of the 'zoobot_pred' predictions of not_known, log-scaled counts
fig, ax = plt.subplots(figsize=(6, 6))
ax.hist(not_known['zoobot_pred'], bins=np.arange(0, 1.05, 0.05), ec='black')
ax.set_xticks(np.arange(0, 1.1, 0.1))
ax.set_xlabel('Prediction')
ax.set_ylabel('Count')
ax.set_yscale('log')
fig.tight_layout()
# plt.savefig(os.path.join(figures, 'prediction_dist_example.png'), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Print the number of 'zoobot_pred' values of not_known per 0.05-wide bin
# between 0 and 1 and keep the per-bin counts
counts = count_intervals(not_known['zoobot_pred'].values)

In [None]:
# Total number of objects in the last four bins (0.80 to 1.00)
sum(counts[-4:])

In [None]:
# Objects with a prediction strictly between 0.8 and 1
not_known[(not_known['zoobot_pred'] > 0.8) & (not_known['zoobot_pred'] < 1)]

# Remove duplicates based on effective radius

In [None]:
def remove_duplicates(df, radius_factor=1.5, min_separation=1.0, priority_column='zoobot_pred_v2'):
    """
    Remove duplicates with an effective-radius dependent linking length.

    Two objects are linked when their on-sky separation is at most
    ``radius_factor`` times the larger of their two effective radii. Links are
    merged transitively into clusters, and per cluster the row with the
    highest ``priority_column`` value is kept (rows without a priority value
    rank last; ties go to the earlier row).

    Parameters:
    - df: DataFrame with the columns ``ra``, ``dec`` (degrees), ``tile``,
      ``re_arcsec_cfis_lsb-r``, ``re_arcsec_whigs-g``, ``re_arcsec_ps-i`` and
      ``priority_column``. It is not modified.
    - radius_factor: Factor the effective radius is multiplied by to obtain
      the linking length.
    - min_separation: Radius in arcsec used for objects that have no
      effective radius in any of the three bands.
    - priority_column: Column used to select which duplicate to keep.

    Returns:
    - dedup_df: One row per cluster, same columns as ``df``, fresh index.
    - groups_df: One row per member of every cluster with more than one
      object. ``cluster_id`` is the smallest row position in the cluster and
      ``object_id`` is the member's index label in ``df``.
    """
    re_cols = ['re_arcsec_cfis_lsb-r', 're_arcsec_whigs-g', 're_arcsec_ps-i']
    group_columns = [
        'cluster_id',
        'object_id',
        'ra',
        'dec',
        'tile',
        'max_r_eff',
        'separation_to_best',
        priority_column,
        'is_best',
        'duplicate_type',
        'threshold_used',
    ]

    n_obj = len(df)
    if n_obj == 0:
        return df.copy().reset_index(drop=True), pd.DataFrame(columns=group_columns)

    # Largest radius over the bands. DataFrame.max skips NaN, so one missing
    # band does not hide the radii measured in the other bands.
    max_r_eff = (
        df[re_cols].max(axis=1, skipna=True).fillna(min_separation).to_numpy(dtype=float)
    )

    # The KD-tree search uses the largest linking length of any object;
    # candidate pairs are filtered with their own linking length afterwards.
    max_possible_threshold = max_r_eff.max() * radius_factor
    print(f'Maximum search threshold: {max_possible_threshold:.2f} arcsec')

    # Unit vectors on the sphere
    ra_rad = np.deg2rad(df['ra'].to_numpy(dtype=float))
    dec_rad = np.deg2rad(df['dec'].to_numpy(dtype=float))
    points = np.column_stack(
        (
            np.cos(dec_rad) * np.cos(ra_rad),
            np.cos(dec_rad) * np.sin(ra_rad),
            np.sin(dec_rad),
        )
    )

    arcsec_per_rad = 180.0 * 3600.0 / np.pi

    def chord_to_arcsec(chord):
        # A chord c between unit vectors subtends the angle 2*arcsin(c/2)
        return 2.0 * np.arcsin(np.clip(chord / 2.0, 0.0, 1.0)) * arcsec_per_rad

    max_chord_length = 2 * np.sin(max_possible_threshold / arcsec_per_rad / 2)
    pairs = cKDTree(points).query_pairs(max_chord_length, output_type='ndarray')

    if len(pairs):
        first, second = pairs[:, 0], pairs[:, 1]
        separations = chord_to_arcsec(np.linalg.norm(points[first] - points[second], axis=1))
        allowed = np.maximum(max_r_eff[first], max_r_eff[second]) * radius_factor
        pairs = pairs[separations <= allowed]

    # Union-Find over row positions. Iterative, so long chains of linked
    # objects cannot hit the recursion limit.
    parent = list(range(n_obj))

    def find(node):
        root = node
        while parent[root] != root:
            root = parent[root]
        while parent[node] != root:  # Path compression
            parent[node], node = root, parent[node]
        return root

    for node1, node2 in pairs.tolist():
        root1 = find(node1)
        root2 = find(node2)
        if root1 != root2:
            # Always keep the smaller position as root: deterministic cluster ids
            if root1 < root2:
                parent[root2] = root1
            else:
                parent[root1] = root2

    cluster_labels = np.fromiter((find(i) for i in range(n_obj)), dtype=np.intp, count=n_obj)

    # Position of the best row per cluster; missing priorities rank lowest
    priority = df[priority_column].to_numpy(dtype=float)
    rank_key = pd.Series(np.where(np.isnan(priority), -np.inf, priority))
    best_pos_by_cluster = rank_key.groupby(cluster_labels).idxmax()
    best_of_cluster = np.full(n_obj, -1, dtype=np.intp)
    best_of_cluster[best_pos_by_cluster.index.to_numpy()] = best_pos_by_cluster.to_numpy()

    # Members of clusters with more than one object, ordered by cluster and
    # then by row position
    cluster_sizes = np.bincount(cluster_labels, minlength=n_obj)
    member_pos = np.flatnonzero(cluster_sizes[cluster_labels] > 1)
    member_pos = member_pos[np.argsort(cluster_labels[member_pos], kind='stable')]
    member_cluster = cluster_labels[member_pos]
    member_best = best_of_cluster[member_cluster]
    is_best = member_pos == member_best

    # All members on one tile: same-tile duplicate, otherwise tile overlap
    if len(member_pos):
        member_tiles = pd.Series(df['tile'].to_numpy()[member_pos])
        tiles_per_cluster = member_tiles.groupby(member_cluster).nunique(dropna=False)
        same_tile = tiles_per_cluster.loc[member_cluster].to_numpy() == 1
    else:
        same_tile = np.zeros(0, dtype=bool)

    groups_df = pd.DataFrame(
        {
            'cluster_id': member_cluster,
            'object_id': df.index[member_pos].to_numpy(),
            'ra': df['ra'].to_numpy()[member_pos],
            'dec': df['dec'].to_numpy()[member_pos],
            'tile': df['tile'].to_numpy()[member_pos],
            'max_r_eff': max_r_eff[member_pos],
            'separation_to_best': chord_to_arcsec(
                np.linalg.norm(points[member_pos] - points[member_best], axis=1)
            ),
            priority_column: df[priority_column].to_numpy()[member_pos],
            'is_best': is_best,
            'duplicate_type': np.where(same_tile, 'same_tile', 'overlap'),
            'threshold_used': max_r_eff[member_pos] * radius_factor,
        },
        columns=group_columns,
    )

    # Every member that is not the kept row is a duplicate
    same_tile_duplicates = int(np.count_nonzero(same_tile & ~is_best))
    overlap_duplicates = int(np.count_nonzero(~same_tile & ~is_best))
    total_duplicates = same_tile_duplicates + overlap_duplicates

    print('\nDuplicate Detection Summary:')
    print(f'Total objects in input: {len(df)}')
    print(f'Total duplicate objects found: {total_duplicates}')
    print(f'  - Same tile duplicates: {same_tile_duplicates}')
    print(f'  - Tile overlap duplicates: {overlap_duplicates}')
    print(f'Objects after deduplication: {len(df) - total_duplicates}')

    dedup_df = df.iloc[best_pos_by_cluster.to_numpy()].reset_index(drop=True)

    return dedup_df, groups_df

In [None]:
# Deduplicate the results with the default arguments of remove_duplicates
dedup_df, groups_df = remove_duplicates(result_df)

In [None]:
# Rows that have no value in the 'zoobot_pred_v2' column
result_df[result_df['zoobot_pred_v2'].isna()]

In [None]:
# Distribution of the 'mu_whigs-g' column of the deduplicated catalog
plt.figure(figsize=(8, 8))
plt.hist(dedup_df['mu_whigs-g'], bins=np.arange(21, 30, 0.5))
# Label the axes so the figure can be read on its own
plt.xlabel('mu_whigs-g')
plt.ylabel('Count')
plt.show()

In [None]:
# groups_df comes from remove_duplicates(result_df), whose default priority
# column is 'zoobot_pred_v2'; the groups table stores the score under that
# name, so 'zoobot_pred' would raise a KeyError here.
score_col = 'zoobot_pred_v2'

# First get unique cluster IDs where at least one object has high prediction score
groups_same_tile = groups_df[groups_df['duplicate_type'] == 'same_tile']
high_score_clusters = groups_same_tile[groups_same_tile[score_col] > 0.6]['cluster_id'].unique()

# Then look at all objects in these clusters
for cluster_id in high_score_clusters[:50]:
    group = groups_same_tile[groups_same_tile['cluster_id'] == cluster_id]
    print(f'\nCluster {cluster_id}:')
    print('Number of objects in group:', len(group))
    print(group[['ra', 'dec', 'tile', score_col, 'duplicate_type']])
    print('-' * 80)

In [None]:
# Print the number of 'zoobot_pred_v2' values of the deduplicated catalog per
# 0.05-wide bin between 0 and 1 and keep the per-bin counts
counts = count_intervals(dedup_df['zoobot_pred_v2'].values)

In [None]:
# Total number of objects in the last four bins (0.80 to 1.00)
sum(counts[-4:])

In [None]:
# Deduplicated objects that were not used for training and have no 'lsb' value
is_untrained = (dedup_df['in_training_data'] == 0) & dedup_df['lsb'].isna()
no_train_dedup = dedup_df.loc[is_untrained].reset_index(drop=True)

# Of those, the ones without class label and without zspec value
is_unclassified = no_train_dedup['class_label'].isna() & no_train_dedup['zspec'].isna()
not_known = no_train_dedup.loc[is_unclassified].reset_index(drop=True)

In [None]:
# Print the number of 'zoobot_pred' values of the deduplicated catalog per
# 0.05-wide bin between 0 and 1 and keep the per-bin counts
counts = count_intervals(dedup_df['zoobot_pred'].values)

In [None]:
# Total number of objects in the last four bins (0.80 to 1.00)
sum(counts[-4:])

In [None]:
# Write the ID-filled, deduplicated catalog to 'unions_master_v3.parquet' in table_dir.
# NOTE(review): dedup_filled is only assigned further down in this notebook (by
# the fill_missing_ids call), so on a fresh kernel this cell raises NameError.
# Run that cell first, or move this cell below it.
dedup_filled.to_parquet(os.path.join(table_dir, 'unions_master_v3.parquet'), index=False)
# groups_df.to_parquet(os.path.join(table_dir, 'duplicate_groups.parquet'), index=False)
# not_known.to_parquet(os.path.join(table_dir, 'combined_unknown.parquet'), index=False)
# no_train_dedup.to_parquet(os.path.join(table_dir, 'no_train_no_duplicates.parquet'), index=False)

In [None]:
# Number of rows in the ID-filled catalog
# NOTE(review): dedup_filled is assigned in a later cell of this notebook
len(dedup_filled)

In [None]:
# Number of rows of the ID-filled catalog that carry a known ID
# NOTE(review): dedup_filled is assigned in a later cell of this notebook
np.count_nonzero(dedup_filled['ID_known'].notna())

In [None]:
# Preview the first rows of the deduplicated catalog
dedup_df.head()

### Add missing IDs for dwarfs that were in multiple tiles

In [None]:
def fill_missing_ids(
    df_target: pd.DataFrame,
    df_catalog: pd.DataFrame,
    ra_col: str = 'ra',
    dec_col: str = 'dec',
    id_col_target: str = 'ID_known',
    id_col_catalog: str = 'ID',
    tolerance_arcsec: float = 10.0,
    coord_frame: str = 'icrs',
    verbose: bool = True,
) -> pd.DataFrame:
    """
    Fill missing IDs in a copy of ``df_target`` by cross-matching with a catalog.

    Every target row whose ID is missing is matched to its nearest catalog
    object on the sky; if that object lies within ``tolerance_arcsec`` the
    catalog ID is copied over. ``df_target`` itself is left untouched.

    Args:
        df_target: DataFrame to fill missing IDs in. Missing means NaN/None in
            the target ID column.
        df_catalog: Catalog DataFrame with coordinates and IDs to match against.
        ra_col: Name of the Right Ascension column (decimal degrees), used for
            both tables.
        dec_col: Name of the Declination column (decimal degrees), used for
            both tables.
        id_col_target: Name of the ID column in df_target to check and fill.
        id_col_catalog: Name of the ID column in df_catalog to use for filling.
        tolerance_arcsec: Matching tolerance in arcseconds.
        coord_frame: Astropy coordinate frame (e.g., 'icrs', 'galactic').
        verbose: If True, print status messages.

    Returns:
        A copy of df_target with the matched IDs filled in. If a column is
        missing or the coordinates cannot be built, the error is printed and
        the copy is returned unchanged.
    """
    # 1. Identify rows in the target DataFrame with missing IDs
    df_target = df_target.copy()
    missing_mask = df_target[id_col_target].isna()
    df_target_missing = df_target.loc[missing_mask]

    if df_target_missing.empty:
        if verbose:
            print('No missing IDs found in the target DataFrame.')
        return df_target

    if verbose:
        print(f'Found {len(df_target_missing)} rows with missing IDs.')

    # Nearest-neighbour matching needs at least one catalog object
    if df_catalog.empty:
        if verbose:
            print('Catalog is empty; nothing to match against.')
        return df_target

    # 2. Create Astropy SkyCoord objects
    try:
        catalog_ids = df_catalog[id_col_catalog].to_numpy()
        coords_target = SkyCoord(
            ra=df_target_missing[ra_col].values * u.degree,
            dec=df_target_missing[dec_col].values * u.degree,
            frame=coord_frame,
        )
        coords_catalog = SkyCoord(
            ra=df_catalog[ra_col].values * u.degree,
            dec=df_catalog[dec_col].values * u.degree,
            frame=coord_frame,
        )
    except KeyError as e:
        print(f'Error: Column not found - {e}. Check column names.')
        return df_target  # Return unmodified copy on error
    except Exception as e:
        print(f'Error creating SkyCoord objects: {e}')
        return df_target

    # 3. Perform cross-matching
    # idx: position *in coords_catalog* of the nearest neighbor
    # sep2d: on-sky separation
    idx, sep2d, _ = match_coordinates_sky(coords_target, coords_catalog)

    # 4. Filter matches by tolerance
    tolerance = tolerance_arcsec * u.arcsec
    valid_match_mask = np.asarray(sep2d <= tolerance)

    # 5. Work with row positions throughout. idx counts catalog rows by
    # position, so looking it up by label would pick the wrong rows whenever
    # the catalog index is not 0..n-1.
    target_positions = np.flatnonzero(missing_mask.to_numpy())[valid_match_mask]
    ids_to_fill = catalog_ids[idx[valid_match_mask]]

    if len(target_positions) > 0:
        # 6. Write the matched IDs into the copy
        id_col_position = df_target.columns.get_loc(id_col_target)
        df_target.iloc[target_positions, id_col_position] = ids_to_fill
        if verbose:
            print(f'Successfully updated {len(target_positions)} IDs.')
    elif verbose:
        print(f'No matches found within the {tolerance_arcsec} arcsec tolerance.')

    return df_target

In [None]:
# Fill missing known IDs of the deduplicated catalog from the list of known dwarfs
dedup_filled = fill_missing_ids(df_target=dedup_df, df_catalog=all_dwarfs)

In [None]:
# Number of rows with a known ID before and after filling
np.count_nonzero(dedup_df['ID_known'].notna()), np.count_nonzero(dedup_filled['ID_known'].notna())

In [None]:
# Rows 100 to 149 of the objects that carry a known ID after filling
dedup_filled[dedup_filled['ID_known'].notna()][100:150]

In [None]:
# The three per-band ID columns of the deduplicated catalog
dedup_df[['ID_cfis_lsb-r', 'ID_whigs-g', 'ID_ps-i']]

In [None]:
single_detections = 0

# iterrows() yields (index, row) tuples; unpack them so that count() is the
# Series method (number of non-missing band IDs) and not tuple.count
for _, row in dedup_df[['ID_cfis_lsb-r', 'ID_whigs-g', 'ID_ps-i']].iterrows():
    print(row.count())
    break

In [None]:
# Number of non-missing values in every column of the deduplicated catalog
dedup_df.count()

In [None]:
band_id_cols = ['ID_cfis_lsb-r', 'ID_whigs-g', 'ID_ps-i']

# True where a band ID is present for the object
detections_bool = dedup_df[band_id_cols].notna()

# Number of bands with an ID, per object
detections_per_row = detections_bool.sum(axis='columns')

# Objects that have an ID in exactly one band
count_single_band_detections = detections_per_row.eq(1).sum()

print(f'\nNumber of rows with detection in exactly one band: {count_single_band_detections}')

In [None]:
# Look up a single object of the ID-filled catalog by its known ID
dedup_filled[dedup_filled['ID_known'] == '121853+654443']