In [None]:
import logging
import os
import re
import time
import warnings

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from scipy.ndimage import binary_dilation, label

In [None]:
import joblib
import pywt

## Paths

In [None]:
# Root locations: raw tile data and the project's home directory.
data_dir = '/arc/projects/unions/ssl/data/raw/tiles/dwarforge'
main_dir = '/arc/home/heestersnick/dwarforge'

# Project sub-directories, all derived from main_dir.
table_dir = os.path.join(main_dir, 'tables')
figure_dir = os.path.join(main_dir, 'figures')
tile_info_dir = os.path.join(main_dir, 'tile_info')

# Input tables. Every table path is built from table_dir so that moving the
# project only requires changing main_dir above.
train_df = pd.read_csv(os.path.join(table_dir, 'train_df.csv'))
class_df = pd.read_csv(os.path.join(table_dir, 'class_df.csv'))
master = pd.read_parquet(os.path.join(table_dir, 'unions_master.parquet'))

In [None]:
# Quick sanity check: show the 'unique_id' column for the first rows of `master`.
master.head()['unique_id']

## Functions

In [None]:
import os

import pandas as pd

# Per-band configuration, keyed by '<survey>-<band>'.
# Field meanings are not established by the code in this section; presumably:
#   'name'      file-name prefix of the tile files,
#   'band'      band label,
#   'vos'       remote (VOSpace-style) directory holding the tiles,
#   'suffix'    file-name ending,
#   'delimiter' separator between the file-name fields,
#   'fits_ext'  index of the FITS extension holding the image,
#   'zfill'     zero-padding width of the tile numbers,
#   'zp'        photometric zero point.
# TODO confirm against the code that consumes this dictionary.
band_dictionary = {
    'cfis-u': {
        'name': 'CFIS',
        'band': 'u',
        'vos': 'vos:cfis/tiles_DR5/',
        'suffix': '.u.fits',
        'delimiter': '.',
        'fits_ext': 0,
        'zfill': 3,
        'zp': 30.0,
    },
    'whigs-g': {
        'name': 'calexp-CFIS',
        'band': 'g',
        'vos': 'vos:cfis/whigs/stack_images_CFIS_scheme/',
        'suffix': '.fits',
        'delimiter': '_',
        'fits_ext': 1,
        'zfill': 0,
        'zp': 27.0,
    },
    'cfis_lsb-r': {
        'name': 'CFIS_LSB',
        'band': 'r',
        'vos': 'vos:cfis/tiles_LSB_DR5/',
        'suffix': '.r.fits',
        'delimiter': '.',
        'fits_ext': 0,
        'zfill': 3,
        'zp': 30.0,
    },
    'ps-i': {
        'name': 'PS-DR3',
        'band': 'i',
        'vos': 'vos:cfis/panstarrs/DR3/tiles/',
        'suffix': '.i.fits',
        'delimiter': '.',
        'fits_ext': 0,
        'zfill': 3,
        'zp': 30.0,
    },
    'wishes-z': {
        'name': 'WISHES',
        'band': 'z',
        'vos': 'vos:cfis/wishes_1/coadd/',
        'suffix': '.z.fits',
        'delimiter': '.',
        'fits_ext': 1,
        'zfill': 0,
        'zp': 27.0,
    },
    # NOTE(review): unlike every other entry, 'band' below repeats the survey
    # prefix ('ps-z' instead of 'z') — confirm this is intentional.
    'ps-z': {
        'name': 'PSS.DR4',
        'band': 'ps-z',
        'vos': 'vos:cfis/panstarrs/DR4/resamp/',
        'suffix': '.z.fits',
        'delimiter': '.',
        'fits_ext': 0,
        'zfill': 3,
        'zp': 30.0,
    },
}


def read_h5(file_path, needed_datasets=('ra', 'dec', 'images', 'zoobot_pred')):
    """
    Reads cutout data from HDF5 file with optimized dataset selection

    Args:
        file_path: path to HDF5 file
        needed_datasets: iterable of dataset names to read (None = read all).
            Names that are not present in the file are silently skipped, so
            callers must check the returned dictionary for the keys they need.

    Returns:
        cutout_data: dictionary mapping each dataset name that was found to a
            numpy array. For the 'images' dataset non-finite values are
            sanitised with ``np.nan_to_num`` (NaN becomes 0.0; infinities are
            replaced by numpy's default finite substitutes).
    """
    cutout_data = {}

    with h5py.File(file_path, 'r') as f:
        # Determine which datasets to load
        if needed_datasets is None:
            # Materialise the key view so the selection is a plain list.
            datasets_to_read = list(f.keys())
        else:
            datasets_to_read = [d for d in needed_datasets if d in f]

        # Loop through and load only needed datasets
        for dataset_name in datasets_to_read:
            data = np.array(f[dataset_name])
            if dataset_name == 'images':
                data = np.nan_to_num(data, nan=0.0)
            cutout_data[dataset_name] = data

    return cutout_data


def get_tile_numbers(name):
    """
    Pull the numeric tile indices out of a tile file name.
    :param name: .fits file name of a given tile
    :return tuple of ints, one per digit group matched in the name
    """
    # Names starting with 'calexp' delimit the numbers with '_' / '-' (and a
    # trailing '.'); every other naming scheme wraps them in dots.
    if name.startswith('calexp'):
        regex = r'(?<=[_-])(\d+)(?=[_.])'
    else:
        regex = r'(?<=\.)(\d+)(?=\.)'

    return tuple(int(digits) for digits in re.findall(regex, name))


def extract_tile_numbers(tile_dict, in_dict):
    """
    Extract tile numbers from .fits file names.

    Args:
        tile_dict: lists of file names from the different bands, keyed by band
        in_dict: band dictionary; its key order defines the order of the output

    Returns:
        num_lists (list): list of arrays (one per band in ``in_dict``) containing
            the tile numbers parsed from that band's file names

    Raises:
        KeyError: if a band of ``in_dict`` has no entry in ``tile_dict``
    """
    # Iterate the dictionary directly: round-tripping the keys through a numpy
    # array only coerces them to numpy string scalars without any benefit.
    return [
        np.array([get_tile_numbers(name) for name in tile_dict[band]])
        for band in in_dict
    ]


def load_available_tiles(path, in_dict):
    """
    Load tile lists from disk.

    For every band in ``in_dict`` the file ``<path>/<band>_tiles.txt`` is read
    with ``np.loadtxt`` as an array of strings.

    Args:
        path (str): path to files
        in_dict (dict): band dictionary

    Returns:
        dictionary of available tiles for the selected bands; every value is a
        1-D string array, even when the file lists a single tile
    """
    band_tiles = {}
    # Iterate the dictionary directly so the result is keyed by the original
    # (plain str) band names rather than numpy string scalars.
    for band in in_dict:
        # ndmin=1 keeps a one-line file from collapsing into a 0-d array,
        # which could not be iterated by downstream code.
        band_tiles[band] = np.loadtxt(
            os.path.join(path, f'{band}_tiles.txt'), dtype=str, ndmin=1
        )

    return band_tiles


def relate_coord_tile(coords=None, nums=None):
    """
    Conversion between tile numbers and coordinates.

    Args:
        coords (sequence): right ascension and declination in degrees. When
            given, it takes precedence over ``nums``.
        nums (sequence): first and second tile numbers

    Returns:
        tuple: the (truncated, integer) tile numbers when ``coords`` is given,
            otherwise the ra and dec (rounded to 12 decimals) for ``nums``

    Raises:
        TypeError: if neither ``coords`` nor ``nums`` is provided
    """
    # Compare against None instead of relying on truthiness: `if coords:` raises
    # for numpy arrays, whose truth value is ambiguous.
    if coords is not None:
        ra, dec = coords
        xxx = ra * 2 * np.cos(np.radians(dec))
        yyy = (dec + 90) * 2
        return int(xxx), int(yyy)

    if nums is None:
        raise TypeError('relate_coord_tile() requires either coords or nums')

    xxx, yyy = nums
    dec = yyy / 2 - 90
    ra = xxx / 2 / np.cos(np.radians(dec))
    return np.round(ra, 12), np.round(dec, 12)


class TileAvailability:
    """
    Bookkeeping of which tiles exist in which bands.

    Builds a (n_unique_tiles x n_bands) 0/1 matrix from one array of tile-number
    pairs per band and answers availability queries against it. Column ``j`` of
    the matrix corresponds to the ``j``-th entry of ``tile_nums`` and to the
    ``j``-th key of ``in_dict``.

    Logging goes through a module-level ``logger`` object that is not defined
    in this class — presumably created elsewhere in the file; TODO confirm.
    """

    def __init__(self, tile_nums, in_dict, at_least=False, band=None):
        """
        Args:
            tile_nums: sequence with one array of tile-number pairs per band
            in_dict (dict): band dictionary; key order must match ``tile_nums``
            at_least (bool): if True, ``counts`` holds cumulative ("available
                in at least N bands") numbers instead of exact ones
            band: unused by the constructor (kept for interface compatibility)
        """
        self.all_tiles = tile_nums
        self.tile_num_sets = [set(map(tuple, tile_array)) for tile_array in self.all_tiles]
        # set().union(*sets) also works for an empty list of bands, whereas
        # set.union(*[]) raises because it lacks its first argument.
        self.unique_tiles = sorted(set().union(*self.tile_num_sets))
        self.availability_matrix = self._create_availability_matrix()
        self.counts = self._calculate_counts(at_least)
        self.band_dict = in_dict

    def _create_availability_matrix(self):
        """Return the 0/1 matrix: rows follow ``unique_tiles``, columns the bands."""
        array_shape = (len(self.unique_tiles), len(self.all_tiles))
        availability_matrix = np.zeros(array_shape, dtype=int)

        for i, tile in enumerate(self.unique_tiles):
            for j, tile_num_set in enumerate(self.tile_num_sets):
                availability_matrix[i, j] = int(tile in tile_num_set)

        return availability_matrix

    def _calculate_counts(self, at_least):
        """Map 'number of bands' -> 'number of tiles' (cumulative if ``at_least``)."""
        counts = np.sum(self.availability_matrix, axis=1)
        bands_available, tile_counts = np.unique(counts, return_counts=True)

        if at_least:
            # Reverse cumulative sum: entry i counts tiles with >= bands_available[i] bands.
            tile_counts = np.cumsum(tile_counts[::-1])[::-1]

        return dict(zip(bands_available, tile_counts))

    def get_availability(self, tile_nums):
        """
        Look up the bands in which a tile is available.

        Returns:
            tuple: (list of band names, array of band column indices); two empty
            lists if the tile is unknown or ``tile_nums`` cannot be made a tuple.
        """
        try:
            index = self.unique_tiles.index(tuple(tile_nums))
        except ValueError:
            logger.warning(f'Tile number {tile_nums} not available in any band.')
            return [], []
        except TypeError:
            return [], []
        bands_available = np.where(self.availability_matrix[index] == 1)[0]
        return [list(self.band_dict.keys())[i] for i in bands_available], bands_available

    def band_tiles(self, band=None):
        """Return the tiles (as tuples) available in ``band``, a key of the band dictionary."""
        tile_array = np.array(self.unique_tiles)[
            self.availability_matrix[:, list(self.band_dict.keys()).index(band)] == 1
        ]
        return [tuple(tile) for tile in tile_array]

    def get_tiles_for_bands(self, bands=None):
        """
        Get all tiles that are available in specified bands.
        If no bands are specified, return all unique tiles.

        Args:
            bands (str or list): Band name(s) to check for availability.
                                 Can be a single band name or a list of band names.

        Returns:
            list: List of tuples representing the tiles available in all specified bands.
                  An empty list is returned (and an error logged) for unknown band names.
        """
        if bands is None:
            return self.unique_tiles

        if isinstance(bands, str):
            bands = [bands]

        try:
            band_indices = [list(self.band_dict.keys()).index(band) for band in bands]
        except ValueError as e:
            logger.error(f'Invalid band name: {e}')
            return []

        # Get tiles available in all specified bands
        available_tiles = np.where(self.availability_matrix[:, band_indices].all(axis=1))[0]

        return [self.unique_tiles[i] for i in available_tiles]

    def stats(self, band=None):
        """
        Log per-band tile counts, the distribution over number of bands and,
        if ``band`` is given, the counts for every band combination containing it.
        """
        # Local import: `combinations` is used below but not imported in the
        # visible part of the file.
        from itertools import combinations

        logger.info('Number of currently available tiles per band:')
        max_band_name_length = max(map(len, self.band_dict.keys()))  # for output format
        for band_name, count in zip(
            self.band_dict.keys(), np.sum(self.availability_matrix, axis=0)
        ):
            logger.info(f'{band_name.ljust(max_band_name_length)}: {count}')

        logger.info('Number of tiles available in different bands:')
        for bands_available, count in sorted(self.counts.items(), reverse=True):
            logger.info(f'In {bands_available} bands: {count}')

        logger.info(f'Number of unique tiles available: {len(self.unique_tiles)}')

        if band:
            logger.info(f'Number of tiles available in combinations containing the {band}-band:\n')

            all_bands = list(self.band_dict.keys())
            all_combinations = []
            for r in range(1, len(all_bands) + 1):
                all_combinations.extend(combinations(all_bands, r))
            combinations_w_r = [x for x in all_combinations if band in x]

            for band_combination in combinations_w_r:
                # Label each combination by the band letters after the survey prefix.
                band_combination_str = ''.join([str(x).split('-')[-1] for x in band_combination])
                band_indices = [all_bands.index(band_c) for band_c in band_combination]
                common_tiles = np.sum(self.availability_matrix[:, band_indices].all(axis=1))
                logger.info(f'{band_combination_str}: {common_tiles}')


class TileWCS:
    """
    Class to create a WCS object for a tile.

    The WCS is a TAN projection on a 10000 x 10000 pixel grid with the
    reference pixel at (5000, 5000); the reference sky position starts at
    (0, 0) and is moved with :meth:`set_coords`.
    """

    def __init__(self, wcs_keywords=None):
        """
        Args:
            wcs_keywords (mapping, optional): extra WCS keywords. The mapping is
                copied, so neither the caller's object nor a shared default is
                modified. NOTE(review): the fixed tile keywords below are applied
                on top of the supplied ones, i.e. they override caller values
                with the same key — confirm that this precedence is intended.
        """
        # Work on a copy: the previous mutable default ({}) was updated in place
        # and shared between all instances, and caller dictionaries were mutated.
        keywords = {} if wcs_keywords is None else wcs_keywords.copy()
        keywords.update(
            {
                'NAXIS': 2,
                'CTYPE1': 'RA---TAN',
                'CTYPE2': 'DEC--TAN',
                'CRVAL1': 0,
                'CRVAL2': 0,
                'CRPIX1': 5000.0,
                'CRPIX2': 5000.0,
                'CD1_1': -5.160234650248e-05,
                'CD1_2': 0.0,
                'CD2_1': 0.0,
                'CD2_2': 5.160234650248e-05,
                'NAXIS1': 10000,
                'NAXIS2': 10000,
            }
        )

        self.wcs_tile = WCS(keywords)

    def set_coords(self, coords):
        """Set the reference sky position (CRVAL) to ``coords[0]``, ``coords[1]``."""
        self.wcs_tile.wcs.crval = [coords[0], coords[1]]


def find_all_tiles(tiles, coords, tile_info_dir, loaded_tree=None):
    """
    Find all tiles containing a given coordinate (up to 4).

    The four nearest tile centres are taken from a KD-tree built on cartesian
    unit vectors, and each candidate is kept only if the coordinate falls inside
    the footprint of that tile's WCS.

    Args:
        tiles (list): list of tile numbers as tuples. Tree indices are used to
            index this list, so it presumably has to be in the same order as the
            points the tree was built from — TODO confirm.
        coords (list): ra, dec of object to query
        tile_info_dir (str): directory containing 'kdtree_xyz.joblib'
        loaded_tree: pre-loaded KD-tree (optional)

    Returns:
        list: list of tile number tuples that contain the coordinate
    """
    # Only load the tree if it wasn't provided.
    # NOTE: joblib.load unpickles the file — only point this at trusted data.
    if loaded_tree is None:
        loaded_tree = joblib.load(os.path.join(tile_info_dir, 'kdtree_xyz.joblib'))
    coord_c = SkyCoord(coords[0], coords[1], unit='deg', frame='icrs')
    coord_xyz = coord_c.cartesian.xyz.value

    # Query 4 nearest neighbors since that's the maximum by the survey design
    dists, indices = loaded_tree.query(coord_xyz, k=4)

    matching_tiles = []
    wcs = TileWCS()

    for dist, idx in zip(np.atleast_1d(dists), np.atleast_1d(indices)):
        # A tree holding fewer than 4 points pads the result with infinite
        # distances and an out-of-range index; skip those placeholders instead
        # of failing with an IndexError.
        if not np.isfinite(dist) or idx >= len(tiles):
            continue
        tile_nums = tiles[idx]
        wcs.set_coords(relate_coord_tile(nums=tile_nums))
        if wcs.wcs_tile.footprint_contains(coord_c):
            matching_tiles.append(tile_nums)

    return matching_tiles


def match_coordinates(query_ra, query_dec, catalog_ra, catalog_dec, max_separation=5.0):
    """Flag catalog entries that lie close to any of the query positions.

    Args:
        query_ra (float or array-like): Right ascension of query position(s) in degrees.
        query_dec (float or array-like): Declination of query position(s) in degrees.
        catalog_ra (array-like): Right ascension values of catalog in degrees.
        catalog_dec (array-like): Declination values of catalog in degrees.
        max_separation (float, optional): Maximum separation in arcseconds. Defaults to 5.0.

    Returns:
        tuple:
            - boolean array with one entry per catalog source, True where the
              nearest query position is within ``max_separation``
            - for each catalog source, the index of its nearest query position
              (an empty integer array if either input is empty)
    """
    # Promote scalars so that single positions and arrays share one code path
    q_ra, q_dec, c_ra, c_dec = (
        np.atleast_1d(values) for values in (query_ra, query_dec, catalog_ra, catalog_dec)
    )

    # Nothing to match against: all-False mask, no indices
    if min(len(q_ra), len(c_ra)) == 0:
        return np.zeros(len(c_ra), dtype=bool), np.array([], dtype=int)

    sky_query = SkyCoord(q_ra, q_dec, unit='deg', frame='icrs')
    sky_catalog = SkyCoord(c_ra, c_dec, unit='deg', frame='icrs')

    # Nearest query position (and its on-sky distance) for every catalog source
    nearest_query, separation, _ = sky_catalog.match_to_catalog_sky(sky_query)

    within_limit = separation.arcsec <= max_separation
    return within_limit, nearest_query


def detect_anomaly(
    image,
    zero_threshold=0.005,
    min_size=50,
    replace_anomaly=True,
    dilate_mask=True,
    dilation_iters=1,
    band='cfis_lsb-r',
):
    """
    Zero out large regions of an image whose wavelet detail is (almost) flat.

    The image is decomposed with a single-level 2D Haar wavelet transform. In
    each detail band, connected regions of at least ``min_size`` coefficients
    with absolute value <= ``zero_threshold`` are treated as anomalies, mapped
    back to pixel space (each coefficient covers a 2x2 pixel block) and merged
    into one mask.

    The input array is modified IN PLACE (NaNs and, optionally, anomalous
    pixels are set to 0.0) and the same array is returned.

    Args:
        image (numpy.ndarray): 2D image
        zero_threshold (float): detail coefficients at or below this magnitude count as flat
        min_size (int): minimum connected-region size, in coefficients
        replace_anomaly (bool): if True, set the masked pixels to 0.0
        dilate_mask (bool): if True, grow the mask by binary dilation
        dilation_iters (int): number of dilation iterations
        band: unused

    Returns:
        numpy.ndarray: the (modified) input image
    """
    # replace nan values with zeros
    image[np.isnan(image)] = 0.0

    # Perform a 2D Discrete Wavelet Transform using Haar wavelets; only the
    # detail coefficients are needed.
    _, (cH, cV, cD) = pywt.dwt2(image, 'haar')

    # Binary masks where the wavelet coefficients are below the threshold
    # (order matters because dilation is applied between bands, see below).
    detail_masks = [
        np.abs(cD) <= zero_threshold,
        np.abs(cH) <= zero_threshold,
        np.abs(cV) <= zero_threshold,
    ]

    n_rows, n_cols = image.shape[0], image.shape[1]
    upscale_kernel = np.ones((2, 2), dtype=bool)
    global_mask = np.zeros_like(image, dtype=bool)

    for mask in detail_masks:
        # Connected-component labeling of the flat regions (label 0 = background)
        labeled_array, _ = label(mask)  # type: ignore
        component_sizes = np.bincount(labeled_array.ravel())

        # Lookup table: True for every non-background label that is large enough
        is_significant = np.zeros(component_sizes.shape, dtype=bool)
        is_significant[1:] = component_sizes[1:] >= min_size
        if not is_significant.any():
            continue

        # Union of all significant components, upscaled to pixel resolution
        feature_mask = is_significant[labeled_array]
        upscaled_mask = np.kron(feature_mask, upscale_kernel)

        # For odd image dimensions the transform pads by one sample, so the
        # upscaled mask is one row/column larger than the image; crop it.
        global_mask |= upscaled_mask[:n_rows, :n_cols]

        # Dilate the mask to catch some odd pixels on the outskirts of the anomaly.
        # NOTE(review): this runs once per detail band that contains an anomaly,
        # so the accumulated mask can be dilated up to three times. Behaviour
        # kept as is — confirm whether a single dilation after the loop was meant.
        if dilate_mask:
            global_mask = binary_dilation(global_mask, iterations=dilation_iters)

    # Replace the anomaly with zeros
    if replace_anomaly:
        image[global_mask] = 0.0

    return image


def process_channels(
    img,
    scaling_type='asinh',
    stretch=0.008,
    Q=7.0,
    gamma=0.25,
):
    """
    Scale the three channels of an (H, W, 3) image with a common intensity
    stretch so that relative channel intensities are preserved.

    Channels that are entirely zero are passed through untouched and are left
    out of the mean intensity. Supported ``scaling_type`` values are 'asinh'
    and 'linear' (anything else raises ValueError). If ``gamma`` is not None a
    sign-preserving power law is applied afterwards. Returns a float32 array
    with the channels stacked along the last axis.
    """
    frac = 0.1
    with np.errstate(divide='ignore', invalid='ignore'):
        planes = [img[:, :, idx] for idx in range(3)]
        has_signal = [not np.all(plane == 0) for plane in planes]

        # Mean intensity over the channels that actually carry data
        lit_planes = [plane for plane, lit in zip(planes, has_signal) if lit]
        if lit_planes:
            i_mean = sum(lit_planes) / len(lit_planes)
        else:
            i_mean = np.zeros_like(planes[0])  # All channels are zero

        # Pick the per-channel intensity scaling
        if scaling_type == 'asinh':

            def rescale(plane):
                return (
                    plane
                    * np.arcsinh(Q * i_mean / stretch)
                    * frac
                    / (np.arcsinh(frac * Q) * i_mean)
                )

        elif scaling_type == 'linear':

            def rescale(plane):
                # Linear scaling without normalization
                return plane * stretch

        else:
            raise ValueError(f'Unknown scaling type: {scaling_type}')

        def apply_gamma(plane):
            # Values that are numerically zero are pinned to exactly zero
            tiny = abs(plane) <= 1e-9
            plane = np.sign(plane) * (abs(plane) ** gamma)  # Preserve sign
            plane[tiny] = 0
            return plane

        processed = []
        for plane, lit in zip(planes, has_signal):
            if lit:
                plane = rescale(plane)
                if gamma is not None:
                    plane = apply_gamma(plane)
            processed.append(plane)

        result = np.stack(processed, axis=-1).astype(np.float32)
    return result


def adjust_flux_with_zp(flux, current_zp, standard_zp):
    """
    Rescale a flux by the factor 10 ** (-0.4 * (current_zp - standard_zp)).

    Args:
        flux: flux value(s) to rescale
        current_zp: zero point the flux is currently expressed in
        standard_zp: zero point to convert to

    Returns:
        the rescaled flux
    """
    scale = 10 ** (-0.4 * (current_zp - standard_zp))
    return flux * scale


def preprocess_cutout(
    cutout,
    scaling='asinh',
    Q=7,
    stretch=125,
    gamma=0.25,
    mode='vis',
):
    """
    Turn a three-band cutout into a scaled RGB image.

    Channel 2 of the cutout is mapped to red, channel 1 to green and channel 0
    to blue. The blue channel is rescaled from zero point 27.0 to 30.0, every
    channel is passed through ``detect_anomaly`` and the result is scaled with
    ``process_channels``. ``detect_anomaly`` writes into the array it is given,
    so the red and green slices of ``cutout`` are modified in place; the blue
    slice is too, unless it was replaced by its zero-point-rescaled copy.

    While the function runs, selected numpy RuntimeWarnings are ignored and
    remaining RuntimeWarning/UserWarning messages are redirected to
    ``logging.warning``. Both the warning filters and the ``showwarning`` hook
    are restored when the function exits, including on error.

    Args:
        cutout (numpy.ndarray): cutout data, indexed as (channel, y, x)
        scaling (str, optional): scaling type. Defaults to 'asinh'. Valid options are 'linear' or 'asinh'.
        Q (float, optional): softening parameter for asinh scaling. Defaults to 7.
        stretch (float, optional): scaling factor. Defaults to 125.
        gamma (float, optional): gamma correction factor. Defaults to 0.25.
        mode (str, optional): mode of operation. Defaults to 'vis', which fills
            one missing channel from the other two for visualization; any other
            value (e.g. 'training') leaves the channels as they are.

    Returns:
        numpy.ndarray: preprocessed image cutout with shape (channel, y, x)
    """
    # Handler that was active before this call; used for warnings we don't log.
    previous_showwarning = warnings.showwarning

    def local_warn_handler(message, category, filename, lineno, file=None, line=None):
        if category in [RuntimeWarning, UserWarning]:  # Filter specific warning types
            # Custom message with context about where the warning came from
            log = f'Warning: {filename}:{lineno}: {category.__name__}: {message}'
            logging.warning(log)
        else:
            # Let other warnings through normally
            previous_showwarning(message, category, filename, lineno, file, line)

    # catch_warnings() snapshots the filter list and the showwarning hook and
    # restores both on exit, so nothing leaks out of this function even if the
    # processing below raises.
    with warnings.catch_warnings():
        ignored_messages = (
            'invalid value encountered in log',
            'invalid value encountered in power',
            'invalid value encountered in cast',
            'divide by zero encountered',
            # The filter is matched against the start of the warning text, which
            # does not include the category name.
            'invalid value encountered in divide',
        )
        for ignored in ignored_messages:
            warnings.filterwarnings('ignore', category=RuntimeWarning, message=ignored)
        warnings.showwarning = local_warn_handler

        # map out bands to RGB
        cutout_red = cutout[2]  # i-band
        cutout_green = cutout[1]  # r-band
        cutout_blue = cutout[0]  # g-band

        # adjust zero-point for the g-band
        if np.count_nonzero(cutout_blue) > 0:
            cutout_blue = adjust_flux_with_zp(cutout_blue, 27.0, 30.0)

        # replace anomalies
        cutout_red = detect_anomaly(cutout_red)
        cutout_green = detect_anomaly(cutout_green)
        cutout_blue = detect_anomaly(cutout_blue)

        # synthesize missing channel from the existing ones
        # longest valid wavelength is mapped to red, middle to green, shortest to blue
        if mode == 'vis':
            if np.count_nonzero(cutout_red > 1e-10) == 0:
                cutout_red = cutout_green
                cutout_green = (cutout_green + cutout_blue) / 2
            elif np.count_nonzero(cutout_green > 1e-10) == 0:
                cutout_green = (cutout_red + cutout_blue) / 2
            elif np.count_nonzero(cutout_blue > 1e-10) == 0:
                cutout_blue = cutout_red
                cutout_red = (cutout_red + cutout_green) / 2

        rgb = np.stack([cutout_red, cutout_green, cutout_blue], axis=-1)

        # Create RGB image
        img_rgb = process_channels(
            rgb,
            scaling_type=scaling,
            stretch=stretch,
            Q=Q,
            gamma=gamma,
        )

        # restore original cutout shape (channel, cutout_size, cutout_size)
        img_rgb = np.moveaxis(img_rgb, -1, 0)

    return img_rgb


def get_cutout_data(
    avail, main_dir, data_dir, tile_info_dir, coords=None, df=None, ra_key='ra', dec_key='dec'
):
    """
    Fetch, for every requested sky position, the matching cutout with the
    highest prediction score, opening each tile's H5 file only once.

    Args:
        avail: object exposing ``unique_tiles`` (the candidate tile numbers)
        main_dir: unused
        data_dir: root directory of the per-tile H5 files
        tile_info_dir: directory holding 'kdtree_xyz.joblib'
        coords: single (ra, dec) pair; takes precedence over ``df``
        df: table with the coordinate columns ``ra_key`` / ``dec_key``
        ra_key, dec_key: column names used when ``df`` is given

    Returns:
        tuple: (cutouts, predictions, tile names) for the matched positions,
        the list of all (ra, dec) pairs in input order, and a dict mapping the
        input index of each matched position to its row in the output arrays.

    Raises:
        ValueError: if neither ``coords`` nor ``df`` is provided
    """
    candidate_tiles = avail.unique_tiles

    # Normalise the two input styles into parallel ra / dec sequences
    if coords is not None:
        ra_values, dec_values = [coords[0]], [coords[1]]
    elif df is not None:
        ra_values, dec_values = df[ra_key].values, df[dec_key].values
    else:
        raise ValueError('Either coords or df must be provided')

    all_coords = list(zip(ra_values, dec_values))

    # The KD-tree is loaded a single time and shared by all lookups
    loaded_tree = joblib.load(os.path.join(tile_info_dir, 'kdtree_xyz.joblib'))

    # tile -> positions (indices into all_coords) that fall inside it
    tile_to_coords = {}
    for position, (ra, dec) in enumerate(all_coords):
        for tile_nums in find_all_tiles(candidate_tiles, (ra, dec), tile_info_dir, loaded_tree):
            tile_to_coords.setdefault((tile_nums[0], tile_nums[1]), []).append(position)

    # Best candidate found so far for each position (-1 marks "no match yet")
    n_positions = len(all_coords)
    best_cutouts = [None] * n_positions
    best_preds = [-1] * n_positions
    best_tiles = [None] * n_positions

    for tile_nums, positions in tile_to_coords.items():
        tile_dir = f'{str(tile_nums[0]).zfill(3)}_{str(tile_nums[1]).zfill(3)}'
        h5_path = os.path.join(
            data_dir, tile_dir, 'gri', f'{tile_dir}_matched_cutouts_full_res_final.h5'
        )

        if not os.path.exists(h5_path):
            continue

        try:
            # One read per tile, reused for every position inside the tile
            tile_data = read_h5(h5_path)
            h5_ra = tile_data['ra']
            h5_dec = tile_data['dec']
            h5_images = tile_data['images']
            h5_preds = tile_data['zoobot_pred']

            for position in positions:
                ra, dec = all_coords[position]
                is_match, _ = match_coordinates(ra, dec, h5_ra, h5_dec, max_separation=5.0)

                if not np.any(is_match):
                    continue
                candidate_preds = h5_preds[is_match]
                if len(candidate_preds) == 0:
                    continue

                # Highest-scoring candidate within this tile
                winner = np.argmax(candidate_preds)
                winner_pred = candidate_preds[winner]

                # Keep it only if it beats what other tiles offered
                if winner_pred > best_preds[position]:
                    best_preds[position] = winner_pred
                    row = np.where(is_match)[0][winner]
                    # Slice (not index) so the leading axis is kept for concatenation
                    best_cutouts[position] = h5_images[row : row + 1]
                    best_tiles[position] = tile_dir
        except Exception as e:
            print(f'Error processing tile {tile_dir}: {e}')

    # Split positions into matched ones and those to report as missing
    matched_positions = []
    for position, (ra, dec) in enumerate(all_coords):
        if best_preds[position] > -1:
            matched_positions.append(position)
        else:
            print(f'No match found for object at RA/Dec: {ra:.4f}, {dec:.4f}')

    if matched_positions:
        final_cutouts = np.concatenate([best_cutouts[p] for p in matched_positions], axis=0)
        final_preds = np.array([best_preds[p] for p in matched_positions])
        final_tiles = np.array([best_tiles[p] for p in matched_positions])
    else:
        final_cutouts = np.array([])
        final_preds = np.array([])
        final_tiles = np.array([])

    # input index -> row in the output arrays
    index_map = {position: row for row, position in enumerate(matched_positions)}

    return final_cutouts, final_preds, final_tiles, all_coords, index_map


def plot_cutouts(
    cutouts,
    preds,
    all_coords,
    index_map,
    mode='grid',
    figsize=None,
    save_path=None,
    show_plot=False,
):
    """
    Display galaxy cutouts with prediction probabilities.

    Every entry of ``all_coords`` gets a panel (grid mode) or a row of four
    panels (channel mode) in its original position. Coordinates that have no
    entry in ``index_map`` are drawn as a light-gray placeholder labelled
    'No match' together with their RA/Dec.

    Args:
        cutouts: Indexable collection of image cutouts for matched objects.
            Each ``cutouts[k]`` is passed to ``preprocess_cutout``; the result
            is treated as channel-first (axis 0 is moved to the last axis).
            NOTE(review): presumably (3, H, W) -- confirm against
            ``preprocess_cutout``.
        preds: Indexable collection of prediction probabilities, aligned with
            ``cutouts``.
        all_coords: Sequence of all (ra, dec) coordinates in original order.
        index_map: Dictionary mapping original indices (positions in
            ``all_coords``) to positions in ``cutouts``/``preds``.
        mode: 'grid' (one RGB panel per object, at most 5 columns) or
            'channel' (one row per object: R, G, B, RGB). Case-insensitive.
        figsize: Optional (width, height) tuple. Defaults to 3x3 inches per
            grid panel, or 12 x (3 * number of rows) in channel mode.
        save_path: Optional path; if truthy the figure is saved there at
            300 dpi with a tight bounding box.
        show_plot: If True the figure is left open and returned; otherwise it
            is closed after the optional save.

    Returns:
        The matplotlib figure if ``show_plot`` is True, otherwise None.
        None is also returned when ``all_coords`` is empty.

    Raises:
        ValueError: If ``mode`` is neither 'grid' nor 'channel'.
    """
    n_total = len(all_coords)

    if n_total == 0:
        print('No coordinates to display.')
        return None

    # Validate the layout up front so a typo in `mode` fails immediately
    # instead of after every cutout has been preprocessed.
    layout = mode.lower()
    if layout not in ('grid', 'channel'):
        raise ValueError(f"Unknown mode: {mode}. Use 'grid' or 'channel'.")

    # Placeholder tile used for coordinates without a match
    placeholder_size = 64
    placeholder_center = placeholder_size // 2

    def _strip_ticks(ax):
        # Remove tick marks and tick labels from an image panel
        ax.set_xticks([])
        ax.set_yticks([])

    def _mark_no_match(ax):
        # Overlay the red 'No match' label at the centre of a placeholder tile
        ax.text(
            placeholder_center,
            placeholder_center,
            'No match',
            fontsize=12,
            ha='center',
            va='center',
            color='red',
            weight='bold',
        )

    def _label_row(ax, text, fontsize):
        # Write a rotated label just left of the panel (axes coordinates)
        ax.text(
            -0.02,
            0.5,
            text,
            rotation=90,
            verticalalignment='center',
            horizontalalignment='right',
            transform=ax.transAxes,
            fontsize=fontsize,
        )

    # Build one display image per original coordinate, in original order.
    # A prediction of None marks a coordinate without a match.
    processed_images = []
    display_preds = []
    for orig_idx in range(n_total):
        match_idx = index_map.get(orig_idx)
        if match_idx is None:
            processed_images.append(np.ones((placeholder_size, placeholder_size, 3)) * 0.8)
            display_preds.append(None)
            continue

        img_rgb = preprocess_cutout(
            cutouts[match_idx], scaling='asinh', Q=7, stretch=125, gamma=0.25, mode='vis'
        )
        # Channel-first -> channel-last for imshow, clipped to the valid RGB range
        processed_images.append(np.clip(np.moveaxis(img_rgb, 0, -1), 0, 1))
        display_preds.append(preds[match_idx])

    if layout == 'grid':
        n_cols = min(5, n_total)  # Max 5 columns
        n_rows = (n_total + n_cols - 1) // n_cols  # Ceiling division

        if figsize is None:
            figsize = (3 * n_cols, 3 * n_rows)

        # squeeze=False always yields a 2-D axes array, even for a single panel
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

        for idx, ax in enumerate(axes.ravel()):
            if idx >= n_total:
                # Unused trailing slot in the last grid row
                ax.axis('off')
                continue

            ax.imshow(processed_images[idx], origin='lower', aspect='equal')

            if display_preds[idx] is None:
                ra, dec = all_coords[idx]
                _mark_no_match(ax)
                ax.set_xlabel(f'RA={ra:.4f}, Dec={dec:.4f}', fontsize=10)
            else:
                ax.set_xlabel(f'{display_preds[idx]:.2f}', fontsize=15)

            _strip_ticks(ax)

        fig.tight_layout(pad=0.5)

    else:  # layout == 'channel'
        if figsize is None:
            figsize = (12, 3 * n_total)

        fig, axes = plt.subplots(
            n_total, 4, figsize=figsize, constrained_layout=True, squeeze=False
        )

        # Column headers go on the first row only
        for ax, title in zip(axes[0], ('Red', 'Green', 'Blue', 'RGB')):
            ax.set_title(title, fontsize=18, fontweight='bold', pad=10)

        for row, (img, pred) in enumerate(zip(processed_images, display_preds)):
            row_axes = axes[row]

            if pred is None:
                ra, dec = all_coords[row]
                # Same gray placeholder in all four panels; text only on the RGB panel
                for ax in row_axes:
                    ax.imshow(img, cmap='gray', origin='lower', aspect='equal')
                _mark_no_match(row_axes[3])
                _label_row(row_axes[0], f'RA={ra:.4f}\nDec={dec:.4f}', fontsize=10)
            else:
                # Individual R, G, B channels followed by the composite image
                for channel, ax in enumerate(row_axes[:3]):
                    ax.imshow(img[:, :, channel], cmap='gray', origin='lower', aspect='equal')
                row_axes[3].imshow(img, origin='lower', aspect='equal')
                _label_row(row_axes[0], f'{pred:.2f}', fontsize=15)

            for ax in row_axes:
                _strip_ticks(ax)

    # Operate on this function's own figure rather than pyplot's implicit
    # "current figure", which may be a different one.
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    if show_plot:
        return fig

    plt.close(fig)
    return None

In [None]:
# Project locations: raw tile data plus the tables/figures/tile_info folders
# that live under the main project directory.
main_dir = '/arc/home/heestersnick/dwarforge'
data_dir = '/arc/projects/unions/ssl/data/raw/tiles/dwarforge'
table_dir, figure_dir, tile_info_dir = (
    os.path.join(main_dir, subdir) for subdir in ('tables', 'figures', 'tile_info')
)

# Master catalogue
master = pd.read_parquet(os.path.join(table_dir, 'unions_master.parquet'))

# Rows where both `lsb` and `class_label` are missing and `in_training_data` is 0
master_unknown = master.loc[
    master['lsb'].isna() & master['class_label'].isna() & master['in_training_data'].eq(0)
].reset_index(drop=True)

# Of those, keep only rows whose `zoobot_pred` exceeds 0.8
master_unknown_high_prob = master_unknown.loc[
    master_unknown['zoobot_pred'].gt(0.8)
].reset_index(drop=True)

In [None]:
# First 40 high-probability unknown objects
test_df = master_unknown_high_prob.iloc[:40].reset_index(drop=True)

# Bands to consider, and the matching entries of the band dictionary
considered_bands = ['whigs-g', 'cfis_lsb-r', 'ps-i']
band_dict_incl = {band_key: band_dictionary.get(band_key) for band_key in considered_bands}

# Tile numbers available in the selected bands
all_bands = extract_tile_numbers(
    load_available_tiles(tile_info_dir, band_dict_incl),
    band_dict_incl,
)

# Tile availability object for the selected bands
availability = TileAvailability(all_bands, band_dict_incl, band=considered_bands)

# Fetch the cutouts for every row of test_df
cutouts, preds, tiles_query, all_coords, index_map = get_cutout_data(
    df=test_df,
    avail=availability,
    main_dir=main_dir,
    data_dir=data_dir,
    tile_info_dir=tile_info_dir,
)

# Show the cutouts as a grid, one RGB panel per object
fig = plot_cutouts(
    cutouts,
    preds,
    all_coords,
    index_map,
    mode='grid',
    save_path=None,
    show_plot=True,
)

In [None]:
obj_idx = 1009

# Coordinates of the selected catalogue row ...
obj_coords = master_unknown_high_prob[['ra', 'dec']].to_numpy()[obj_idx]
# ... which are immediately replaced by a hand-picked position.
obj_coords = np.array([205.6863, 47.6032])
print(obj_coords)

# Catalogue prediction and tile of the selected row (not used further in this cell)
master_pred = master_unknown_high_prob.loc[obj_idx, 'zoobot_pred']
master_tile = master_unknown_high_prob.loc[obj_idx, 'tile']

# First 20 high-probability unknown objects, with rows 5 and 10 moved to
# hand-picked positions.
# NOTE(review): test_df is not passed to get_cutout_data below (coords= is used).
test_df = master_unknown_high_prob.iloc[:20].reset_index(drop=True)
test_df.loc[5, ['ra', 'dec']] = [205.6863, 47.6032]
test_df.loc[10, ['ra', 'dec']] = [207.1622, 43.4118]

# Bands to consider, and the matching entries of the band dictionary
considered_bands = ['whigs-g', 'cfis_lsb-r', 'ps-i']
band_dict_incl = {band_key: band_dictionary.get(band_key) for band_key in considered_bands}

# Tile numbers available in the selected bands
all_bands = extract_tile_numbers(
    load_available_tiles(tile_info_dir, band_dict_incl),
    band_dict_incl,
)

# Tile availability object for the selected bands
availability = TileAvailability(all_bands, band_dict_incl, band=considered_bands)

# Fetch the cutout for the single hand-picked coordinate, timing the call
start_get = time.time()
cutouts, preds, tiles_query, all_coords, index_map = get_cutout_data(
    coords=obj_coords,
    avail=availability,
    main_dir=main_dir,
    data_dir=data_dir,
    tile_info_dir=tile_info_dir,
)
print(f'Got cutout data in {(time.time() - start_get):.1f} seconds.')

# Show the per-channel view, timing the plotting step
start = time.time()
# save_path = os.path.join(figure_dir, 'cutout_test_300_dpi.png')
fig = plot_cutouts(
    cutouts,
    preds,
    all_coords,
    index_map,
    mode='channel',
    save_path=None,
    show_plot=True,
)
print(f'Plot done in {(time.time() - start):.1f} seconds.')