# Environment

This demo was presented on [baldur.astro.washington.edu/jupyter](https://baldur.astro.washington.edu/jupyter) with the shared Jupyter kernel `kbmod/w_2023_38`.

It assumes that the user has read access to the test data at `/epyc/projects/kbmod/data` on epyc and that the notebook is executed on baldur.

This notebook is currently stored for shared access in `/epyc/projects/kbmod/jupyter/notebooks/e2e`

# Setup

In [None]:
import kbmod

from kbmod.region_search import RegionSearch

# Inspect the butler repo's contents
While you can inspect the butler repo in a fairly straightforward manner, the `RegionSearch` module provides some static methods that can help you pick which collections and datatypes to query from the butler.

In [None]:
# Path to the butler repository that the queries in this notebook run against.
REPO_PATH = "/epyc/projects/kbmod/data/imdiff_w09_gaiadr3"

In [None]:
# Display the collection names found in the repo so we can choose one to query.
RegionSearch.get_collection_names(repo_path=REPO_PATH)

For this example, we want to pick one of the collections with fakes and we'll use 'DECam/withFakes/20210318'.

We now want to inspect how many datarefs are associated with each datatype we can query from this collection.

In [None]:
# Restrict the queries to the collection with fakes picked for this example.
collections = ["DECam/withFakes/20210318"]

# Display how many datarefs are associated with each dataset type in these collections.
RegionSearch.get_dataset_type_freq(repo_path=REPO_PATH, collections=collections)

# Fetch Data from the Butler for Region Search

From the above, 'fakes_calexp' seems a reasonable choice for a datatype we can limit our queries to.

In the following, we construct a `RegionSearch` object which will instantiate a butler for our repo and fetch the image data keyed by (Visit, Detector, Region) (aka VDR) along with some associated metadata and calculations in an astropy table.

In [None]:
# Limit the butler queries to the single dataset type chosen for this example.
dataset_types = ["fakes_calexp"]
# NOTE(review): fetch_data_on_start=True presumably makes the constructor fetch
# the data immediately so `vdr_data` is ready below — confirm against RegionSearch.
rs = RegionSearch(
    REPO_PATH, collections, dataset_types, visit_info_str="calexp.visitInfo", fetch_data_on_start=True
)

# Display the fetched (Visit, Detector, Region) data table.
rs.vdr_data

# Find Discrete Piles

In the 10 images above we want to find discrete piles: sets of images whose coordinates overlap within a given uncertainty radius.

In [None]:
# Group the fetched images into piles of overlapping coordinates, then report
# which rows of the VDR data table belong to each pile (piles are numbered from 1).
overlapping_sets = rs.find_overlapping_coords(uncertainty_radius=30)
print(f"Found {len(overlapping_sets)} discrete piles")
for pile_number, pile_indices in enumerate(overlapping_sets, start=1):
    print(
        f"In overlapping set {pile_number}, we have the following indices for images in the VDR data table: {pile_indices}"
    )

## Create an ImageCollection
The first pile has the most images, so we'll use it to create a KBMOD ImageCollection from which we can run a search.

In [None]:
# Look up the URI of every image in the first pile and build the
# ImageCollection from those targets.
first_pile = overlapping_sets[0]
uri_column = rs.vdr_data["uri"]
uris = [uri_column[row] for row in first_pile]
ic = kbmod.ImageCollection.fromTargets(uris)

# Display the resulting collection.
ic

# Create a KBMOD Workunit from the ImageCollection

Use KBMOD to search for trajectories in one of the identified discrete piles

In [None]:
from pathlib import Path
import os
import numpy as np

results_suffix = "DEMO"

# Create the output directory if it is missing. exist_ok=True replaces the
# separate existence check and still fails if the path is a regular file.
res_filepath = "./demo_results"
Path(res_filepath).mkdir(exist_ok=True)

# Velocity search grid, given as [minimum, maximum, number of steps].
# NOTE(review): the previous comment described an object moving at 10 px/day and
# a [0, 20] search range, which does not match the values below — confirm the
# intended range and units.
v_min = 3000
v_max = 4000
v_steps = 50
v_arr = [v_min, v_max, v_steps]

# angle with respect to ecliptic, in radians
# (the trailing numbers look like values from an earlier demo — TODO confirm)
ang_below = 3 * np.pi / 2  # 0
ang_above = 2 * np.pi  # 1
ang_steps = 50  # 21
ang_arr = [ang_below, ang_above, ang_steps]

# Make sure we see the object in at least this many images.
# NOTE(review): the previous comment said the demo data has 3 images — confirm
# against the size of the pile used to build the ImageCollection.
num_obs = 2

input_parameters = {
    # Required
    "res_filepath": res_filepath,
    "output_suffix": results_suffix,
    "v_arr": v_arr,
    "ang_arr": ang_arr,
    # Important
    "num_obs": num_obs,
    "do_mask": False,
    "lh_level": 10.0,
    "gpu_filter": True,
    # Fine tuning
    "sigmaG_lims": [15, 60],
    "mom_lims": [37.5, 37.5, 1.5, 1.0, 1.0],
    "peak_offset": [3.0, 3.0],
    "chunk_size": 1000000,
    "stamp_type": "cpp_median",
    "cluster_eps": 20.0,
    "clip_negative": True,
    "mask_num_images": 0,
    "cluster_type": "position",
    "average_angle": 0.0,
}

# Start from a fresh configuration and apply the overrides above.
config = kbmod.configuration.SearchConfiguration()
config.set_multiple(input_parameters)

# Build the WorkUnit for the ImageCollection with this configuration.
wunit = ic.toWorkUnit(config)

# Visualize Our ImageCollection

The following defines some helper functions for visualizing the images in our `WorkUnit`. We can quickly inspect these to sanity check.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from astropy.visualization import astropy_mpl_style
from astropy.visualization import ZScaleInterval, simple_norm, imshow_norm, ZScaleInterval, SinhStretch


def get_image(workunit, n):
    """Return the ``n``-th image of the work unit's image stack.

    Parameters
    ----------
    workunit : object
        Anything exposing ``im_stack.get_images()``.
    n : int
        Index into the sequence returned by ``get_images()``.

    Returns
    -------
    object
        The element at position ``n``.
    """
    images = workunit.im_stack.get_images()
    return images[n]


def get_science_image(workunit, n):
    """Return the ``image`` attribute of the science layer of the ``n``-th image."""
    layered_image = get_image(workunit, n)
    science_layer = layered_image.get_science()
    return science_layer.image


def get_variance_image(workunit, n):
    """Return the ``image`` attribute of the variance layer of the ``n``-th image."""
    layered_image = get_image(workunit, n)
    variance_layer = layered_image.get_variance()
    return variance_layer.image


def get_mask_image(workunit, n):
    """Return the ``image`` attribute of the mask layer of the ``n``-th image."""
    layered_image = get_image(workunit, n)
    mask_layer = layered_image.get_mask()
    return mask_layer.image


def plot_img(img):
    """Show the transpose of ``img`` on a 25x25 grayscale figure.

    The image is drawn with ``origin="lower"`` and normalised with a
    ``ZScaleInterval`` (contrast 0.5) and a ``SinhStretch``.
    """
    _, axis = plt.subplots(figsize=(25, 25))
    transposed = img.T
    display_options = {
        "cmap": "gray",
        "origin": "lower",
        "interval": ZScaleInterval(contrast=0.5),
        "stretch": SinhStretch(),
    }
    imshow_norm(transposed, axis, **display_options)
    plt.show()

## The Science Images

In [None]:
# Plot the science layer of each image in the work unit, one figure per image.
for image_index in range(len(ic)):
    science = get_science_image(wunit, image_index)
    plot_img(science)

## The Variance Images

In [None]:
# Plot the variance layer of each image in the work unit, one figure per image.
for image_index in range(len(ic)):
    variance = get_variance_image(wunit, image_index)
    plot_img(variance)

# Create a Reprojected Workunit

First we'll need to create a new initial work unit so results can be saved in a different directory

In [None]:
from pathlib import Path
import os

results_suffix = "REPROJECT_DEMO"

# Create the output directory if it is missing. exist_ok=True replaces the
# separate existence check and still fails if the path is a regular file.
res_filepath = "./reproject_demo_results"
Path(res_filepath).mkdir(exist_ok=True)

# Velocity search grid, given as [minimum, maximum, number of steps].
# NOTE(review): the previous comment described an object moving at 10 px/day and
# a [0, 20] search range, which does not match the values below — confirm the
# intended range and units.
v_min = 3000
v_max = 4000
v_steps = 50
v_arr = [v_min, v_max, v_steps]

# angle with respect to ecliptic, in radians
# (the trailing numbers look like values from an earlier demo — TODO confirm)
ang_below = 3 * np.pi / 2  # 0
ang_above = 2 * np.pi  # 1
ang_steps = 50  # 21
ang_arr = [ang_below, ang_above, ang_steps]

# Make sure we see the object in at least this many images.
# NOTE(review): the previous comment said the demo data has 3 images — confirm
# against the size of the pile used to build the ImageCollection.
num_obs = 2

input_parameters = {
    # Required
    "res_filepath": res_filepath,
    "output_suffix": results_suffix,
    "v_arr": v_arr,
    "ang_arr": ang_arr,
    # Important
    "num_obs": num_obs,
    "do_mask": False,
    "lh_level": 10.0,
    "gpu_filter": True,
    # Fine tuning
    "sigmaG_lims": [15, 60],
    "mom_lims": [37.5, 37.5, 1.5, 1.0, 1.0],
    "peak_offset": [3.0, 3.0],
    "chunk_size": 1000000,
    "stamp_type": "cpp_median",
    # NOTE(review): the configuration for the non-reprojected search sets
    # "cluster_eps": 20.0 at this position instead of "eps": 0.03 — confirm
    # which key/value the installed KBMOD version expects.
    "eps": 0.03,
    "clip_negative": True,
    "mask_num_images": 0,
    "cluster_type": "position",
    "average_angle": 0.0,
}

# Start from a fresh configuration and apply the overrides above.
config = kbmod.configuration.SearchConfiguration()
config.set_multiple(input_parameters)

# Build a second WorkUnit whose results go to the reprojection output directory.
new_wunit = ic.toWorkUnit(config)

In [None]:
%%time
from kbmod import reprojection

# Use the first entry of the work unit's per-image WCS collection as the common
# frame (presumably the WCS of the first image — TODO confirm). Note that this
# reads a private attribute of the work unit.
common_wcs = new_wunit._per_image_wcs[0]

# Reproject the work unit onto that common WCS.
uwunit = reprojection.reproject_work_unit(new_wunit, common_wcs)

# Let's visualize our reprojected images.

## The reprojected science images

In [None]:
# Plot the science layer of each image in the reprojected work unit.
for image_index in range(len(ic)):
    science = get_science_image(uwunit, image_index)
    plot_img(science)

## The reprojected variance images

In [None]:
# Plot the variance layer of each image in the reprojected work unit.
for image_index in range(len(ic)):
    variance = get_variance_image(uwunit, image_index)
    plot_img(variance)

# Run KBMOD Search without Reprojection

In [None]:
# Run the search on the original (non-reprojected) work unit.
runner = kbmod.run_search.SearchRunner()
res = runner.run_search_from_work_unit(wunit)

# Inspect the Results

In [None]:
# Order the results from highest to lowest likelihood, then keep only their trajectories.
ranked_results = sorted(res.results, key=lambda found: found.trajectory.lh, reverse=True)
trajectories = [ranked.trajectory for ranked in ranked_results]

In [None]:
# Display the likelihood-sorted trajectories.
trajectories

In [None]:
# We can create stamps for each result
imgstack = wunit.im_stack

# Create the stamps around remaining results
nres = len(trajectories)
stamp_size = 20

if nres == 0:
    # A subplot grid with zero rows cannot be created, so skip plotting.
    print("No trajectories to plot.")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single result, so
    # each `row` below is always a sequence of three axes.
    fig, axes = plt.subplots(
        nres, 3, figsize=(10, nres * 3), sharey=True, sharex=True, squeeze=False
    )

    for row, traj in zip(axes, trajectories):
        stamps = kbmod.search.StampCreator.get_stamps(imgstack, traj, stamp_size)
        # zip stops at the shorter input, so at most three stamps are drawn per result.
        for ax, stamp in zip(row, stamps):
            ax.imshow(stamp.image, interpolation=None, cmap="gist_heat")

    plt.tight_layout()

In [None]:
# The search used a low likelihood cutoff of 10, but many of the returned
# results have much larger likelihoods than that, so we raise the limit here.
# This is not uncommon, as the number of false positives returned by KBMOD is
# usually rather large.
from kbmod.filters.stats_filters import LHFilter

# Filter out all results that have a likelihood < 40.0.
lhfilter = LHFilter(40.0, None)
res.apply_filter(lhfilter)
remaining = res.num_results()
print(f"{remaining} results remaining.")

# List the trajectory of every result that survived the filter.
for kept in res.results:
    print(kept.trajectory)

In [None]:
# We can filter on stamps too, for example:
# NOTE(review): the meaning of the three positional arguments is not visible in
# this notebook — check the StampPeakFilter documentation before tuning them.
from kbmod.filters.stamp_filters import StampPeakFilter

filter2 = StampPeakFilter(10, 2.1, 0.1)
res.apply_filter(filter2)
print(f"{res.num_results()} results remaining.")

In [None]:
# Show stamps for the first remaining result, if any survived the filters.
if res.num_results() == 0:
    print("No results left to plot after filtering.")
else:
    # One row of three panels. The height is fixed at one row's worth rather
    # than scaled by `nres`, which counted the results before filtering.
    fig, axes = plt.subplots(1, 3, figsize=(10, 3), sharey=True, sharex=True)

    stamps = kbmod.search.StampCreator.get_stamps(imgstack, res.results[0].trajectory, 20)
    # zip stops at the shorter input, so at most three stamps are drawn.
    for ax, stamp in zip(axes, stamps):
        ax.imshow(stamp.image, interpolation=None, cmap="gist_heat")

    plt.tight_layout()

# Run KBMOD Search on the Reprojected Images

In [None]:
# Search the reprojected work unit (`uwunit`); passing `wunit` here would just
# repeat the non-reprojected search.
reproject_res = kbmod.run_search.SearchRunner().run_search_from_work_unit(uwunit)

In [None]:
# Collect the trajectory of every result from the reprojected search.
# NOTE(review): unlike the non-reprojected results, these are not sorted by likelihood.
reproj_traj = [t.trajectory for t in reproject_res.results]
# Display the trajectories.
reproj_traj

In [None]:
# We can create stamps for each result
imgstack = uwunit.im_stack

# Create the stamps around remaining results
nres = len(reproj_traj)
stamp_size = 20

if nres == 0:
    # A subplot grid with zero rows cannot be created, so skip plotting.
    print("No trajectories to plot.")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single result, so
    # each `row` below is always a sequence of three axes.
    fig, axes = plt.subplots(
        nres, 3, figsize=(10, nres * 3), sharey=True, sharex=True, squeeze=False
    )

    for row, traj in zip(axes, reproj_traj):
        stamps = kbmod.search.StampCreator.get_stamps(imgstack, traj, stamp_size)
        # zip stops at the shorter input, so at most three stamps are drawn per result.
        for ax, stamp in zip(row, stamps):
            ax.imshow(stamp.image, interpolation=None, cmap="gist_heat")

    plt.tight_layout()

In [None]:
# The search used a low likelihood cutoff of 10, but many of the returned
# results have much larger likelihoods than that, so we raise the limit here.
# This is not uncommon, as the number of false positives returned by KBMOD is
# usually rather large.
from kbmod.filters.stats_filters import LHFilter

# Filter out all results that have a likelihood < 40.0.
lhfilter = LHFilter(40.0, None)
reproject_res.apply_filter(lhfilter)
remaining = reproject_res.num_results()
print(f"{remaining} results remaining.")

# List the trajectory of every result that survived the filter.
for kept in reproject_res.results:
    print(kept.trajectory)

In [None]:
# We can filter on stamps too, for example:
# NOTE(review): the meaning of the three positional arguments is not visible in
# this notebook — check the StampPeakFilter documentation before tuning them.
from kbmod.filters.stamp_filters import StampPeakFilter

filter2 = StampPeakFilter(10, 2.1, 0.1)
reproject_res.apply_filter(filter2)
print(f"{reproject_res.num_results()} results remaining.")

In [None]:
# Show stamps for the first remaining reprojected result, if any survived the filters.
if reproject_res.num_results() == 0:
    print("No results left to plot after filtering.")
else:
    # One row of three panels. The height is fixed at one row's worth rather
    # than scaled by `nres`, which counted the results before filtering.
    fig, axes = plt.subplots(1, 3, figsize=(10, 3), sharey=True, sharex=True)

    stamps = kbmod.search.StampCreator.get_stamps(imgstack, reproject_res.results[0].trajectory, 20)
    # zip stops at the shorter input, so at most three stamps are drawn.
    for ax, stamp in zip(axes, stamps):
        ax.imshow(stamp.image, interpolation=None, cmap="gist_heat")

    plt.tight_layout()