# Benchmark Filtering

This notebook is used to compare the effectiveness of different filtering algorithms, including combinations of filtering algorithms. It creates a fake data set, applies a subset of filtering algorithms, and then scores the results. It is meant to be used as an evaluation tool when testing new algorithms.

In [None]:
import numpy as np

from kbmod.analysis.plotting import plot_image
from kbmod.configuration import SearchConfiguration
from kbmod.fake_data.fake_data_creator import create_fake_times, FakeDataSet
from kbmod.run_search import SearchRunner
from kbmod.trajectory_generator import VelocityGridSearch
from kbmod.trajectory_utils import match_trajectory_sets

import timeit

# No seed is passed, so numbers drawn from this generator are not reproducible across runs.
# NOTE(review): `rng` is not referenced by any of the cells visible here — confirm it is still needed.
rng = np.random.default_rng()

## Define the image parameters

We predefine the parameters that control how the data is generated. Users may want to vary these to determine their impact.

In [None]:
# Dimensions of the synthetic data set. Bigger images yield more candidate
# results (noise included).
num_times = 20
width, height = 100, 150

# Observation timestamps, 3 per night.
times = create_fake_times(
    num_times,
    t0=60000.0,
    obs_per_day=3,
    intra_night_gap=0.04,
    inter_night_gap=1,
)

# Image characteristics.
psf_val = 3.0
noise_level = 2.0
mask_fraction = 0.1  # fraction of masked pixels (10%)
artifacts_fraction = 0.001  # fraction of pixels holding bright artifacts (0.1%)
artifacts_brightness = 250.0  # mean artifact brightness (above trj_brightness)

# Trajectories to insert into the images.
num_trjs = 20
trj_brightness = 200.0

Use the parameters to create the base data set that we will use for all of the studies.

In [None]:
# Build the base fake data set from the parameters defined above. The
# observation times (`times`) are created once with the other data parameters,
# so they are reused here instead of being recomputed with a duplicated call.
fake_ds = FakeDataSet(
    width,
    height,
    times,
    mask_fraction=mask_fraction,
    noise_level=noise_level,
    psf_val=psf_val,
    artifacts_fraction=artifacts_fraction,
    artifacts_mean=artifacts_brightness,
    # The artifact brightness spread is set to the same value as the noise level.
    artifacts_std=noise_level,
)

# Show the first science image as a sanity check of the generated data.
plot_image(fake_ds.stack_py.sci[0], title="Fake Image", show_counts=False)

## Define the search parameters

We predefine the parameters that will be used in the search.

In [None]:
# Half the number of time steps; handed to the search as "num_obs".
min_obs = int(num_times / 2)

# Search settings. The built-in filtering stages (clustering, GPU filter and
# sigmaG filter) are switched off here; the notebook applies the filters under
# test itself in later cells.
input_parameters = dict(
    cpu_only=True,
    do_clustering=False,
    generate_psi_phi=True,
    gpu_filter=False,
    lh_level=1e-8,
    max_results=100_000_000,
    near_dup_thresh=1,
    num_obs=min_obs,
    psf_val=psf_val,
    results_per_pixel=10,
    sigmaG_filter=False,
)
config = SearchConfiguration.from_dict(input_parameters)

# Candidate trajectory generator and the runner that executes the search.
# NOTE(review): the positional arguments look like (steps, min, max) for two
# velocity axes — confirm against the VelocityGridSearch signature.
trj_generator = VelocityGridSearch(41, 0.0, 20.0, 41, -10.0, 10.0)
search = SearchRunner()

Define a helper function for determining which results match which inserted fakes.

In [None]:
def _compute_match_stats(all_trjs, results, threshold, times):
    """Count how many match entries are greater than -1.

    The trajectories extracted from ``results`` (via ``make_trajectory_list``)
    are matched against ``all_trjs`` with ``match_trajectory_sets``, using
    ``threshold`` and ``times``.

    Parameters
    ----------
    all_trjs : list
        The inserted (ground truth) trajectories.
    results : object
        Search results providing ``make_trajectory_list()``.
    threshold : float
        Matching threshold forwarded to ``match_trajectory_sets``.
    times : array-like
        Observation times forwarded to ``match_trajectory_sets``.

    Returns
    -------
    int
        Number of entries in the match array that are greater than -1.
        NOTE(review): presumably -1 marks an unmatched trajectory — confirm
        against ``match_trajectory_sets``.
    """
    matches = match_trajectory_sets(
        all_trjs,
        results.make_trajectory_list(),
        threshold,
        times=times,
    )
    return np.count_nonzero(matches > -1)

We define a helper function that selects the correct filtering function from a string.

In [None]:
from kbmod.filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
from kbmod.filters.clustering_filters import apply_clustering

def _sigma_g_filter(results, lower=0.25, upper=0.75, sigma=2, clip_negative=True):
    """Run clipped sigmaG filtering on ``results``.

    A ``SigmaGClipping`` object is built from ``lower``, ``upper``, ``sigma``
    and ``clip_negative`` and handed to ``apply_clipped_sigma_g`` together with
    ``results``. Nothing is returned; the benchmark loop inspects ``results``
    after the call, so the filter is expected to modify it in place.
    """
    apply_clipped_sigma_g(
        SigmaGClipping(lower, upper, sigma, clip_negative),
        results,
    )

def _cluster_filter(results, cluster_type="all", cluster_eps=10):
    """Run clustering-based filtering on ``results``.

    Assembles the parameter dictionary for ``apply_clustering`` from
    ``cluster_type`` and ``cluster_eps``, a fixed ``cluster_v_scale`` of 1.0,
    and the notebook-level ``times`` array. Nothing is returned; the benchmark
    loop inspects ``results`` after the call, so the filter is expected to
    modify it in place.
    """
    # `times` is read from the notebook's global namespace.
    apply_clustering(
        results,
        {
            "cluster_type": cluster_type,
            "cluster_eps": cluster_eps,
            "cluster_v_scale": 1.0,
            "times": times,
        },
    )

# Registry of the filters under test: name -> (filter function, keyword arguments).
# Start with the sigmaG variants (percentile bounds and sigma coefficient).
filters = {
    "sigma_g_25_75_2": (_sigma_g_filter, dict(lower=0.25, upper=0.75, sigma=2)),
    "sigma_g_25_75_3": (_sigma_g_filter, dict(lower=0.25, upper=0.75, sigma=3)),
    "sigma_g_10_90_2": (_sigma_g_filter, dict(lower=0.10, upper=0.90, sigma=2)),
    "sigma_g_10_90_3": (_sigma_g_filter, dict(lower=0.10, upper=0.90, sigma=3)),
}

# Then register every (clustering type, eps) combination.
for cluster_type in ("position", "start_end_position", "nn_start_end", "nn_start"):
    for cluster_eps in (1.0, 2.0, 5.0, 10.0, 20.0):
        filter_name = f"cluster_{cluster_type}_{cluster_eps}"
        filters[filter_name] = (
            _cluster_filter,
            dict(cluster_type=cluster_type, cluster_eps=cluster_eps),
        )

# Fixed ordering of the filter names; the index is the row in the result arrays.
all_filters = [*filters]
num_filters = len(all_filters)

## Run and evaluate filtering algorithms

Iterate through the different filtering algorithms along with their threshold parameter. Each of these clustering calls takes a while, so we print out progress markers for each run.

**Note:** Timing only uses a single clustering run (instead of the average of a bunch), so it will be noisy.

In [None]:
num_iterations = 10

# Measurement matrices: one row per filter, one column per iteration.
num_res = np.zeros((num_filters, num_iterations))
num_matched = np.zeros((num_filters, num_iterations))
run_time = np.zeros((num_filters, num_iterations))

for itr in range(num_iterations):
    print(f"Iteration {itr + 1} of {num_iterations}")

    # Regenerate the data set: reset it, insert a new random set of objects,
    # and rerun the core search so every filter in this iteration sees the
    # same unfiltered results.
    fake_ds.reset()
    all_trjs = fake_ds.insert_random_objects_from_generator(num_trjs, trj_generator, trj_brightness)
    results = search.do_core_search(config, fake_ds.stack_py, trj_generator)

    for f_idx, filter_name in enumerate(all_filters):
        print(f"  Testing filter: {filter_name}")
        # Work on a copy so one filter's output does not feed into the next.
        tmp_res = results.copy()

        # Apply the filter, timing exactly one call. `number=1` is required:
        # timeit defaults to 1,000,000 repetitions, which would re-filter the
        # already filtered results over and over. The elapsed time is kept in
        # its own variable so the `run_time` array is not overwritten.
        filter_func, filter_params = filters[filter_name]
        elapsed = timeit.timeit(lambda: filter_func(tmp_res, **filter_params), number=1)

        num_res[f_idx, itr] = len(tmp_res)
        num_matched[f_idx, itr] = _compute_match_stats(all_trjs, tmp_res, 5.0, times)
        run_time[f_idx, itr] = elapsed


In [None]:
# Collect the per-filter averages (mean over iterations) in an astropy Table.
from astropy.table import Table

results_table = Table()
results_table["filter"] = all_filters
results_table["num_results"] = np.mean(num_res, axis=1)
results_table["num_matched"] = np.mean(num_matched, axis=1)
results_table["run_time"] = np.mean(run_time, axis=1)

# Precision: matched count relative to the number of results kept by the filter.
# Recall: matched count relative to the number of inserted trajectories.
results_table["Precision"] = results_table["num_matched"] / results_table["num_results"]
results_table["Recall"] = results_table["num_matched"] / num_trjs

print(results_table)