# Fake Analysis

This notebook provides a short example of analyzing the fakes from the DEEP data using `analyze_fakes.py`.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from kbmod.analysis.analyze_fakes import FakeInfo, load_fake_info_from_ecsv
from kbmod.analysis.plotting import plot_image, plot_multiple_images, plot_time_series
from kbmod.results import Results
from kbmod.trajectory_explorer import TrajectoryExplorer
from kbmod.trajectory_utils import match_trajectory_sets
from kbmod.work_unit import WorkUnit

import logging

logging.basicConfig(level=logging.INFO)

Load the test `WorkUnit` and extract some metadata (the start time `t0` and the configuration). We also set up a `TrajectoryExplorer` object that we can later use to run simulated searches on the `WorkUnit`.

In [None]:
wu_file = "/epyc/projects/kbmod/data/20210908_B1h_047_test_data/20210908_B1h_047.wu"
wu = WorkUnit.from_fits(wu_file)

times = wu.get_all_obstimes()
t0 = times[0]
zeroed_times = np.array(times) - t0

config = wu.config
explorer = TrajectoryExplorer(wu.im_stack, wu.config)
explorer.initialize_data()

print(f"Loaded {len(wu)} images starting at time {t0}")

We can compute basic statistics for the WorkUnit.

In [None]:
wu.print_stats()

Load the fakes data from the ECSV file. For each object (unique `orbitid`), build a trajectory from its observations. We use the `load_fake_info_from_ecsv()` helper function.

In [None]:
fakes_file = "/epyc/projects/kbmod/data/20210908_B1h_047_test_data/20210908_B1h_047_inserted_fakes.ecsv"
fakes_list = load_fake_info_from_ecsv(fakes_file)

We now have a list of `FakeInfo` objects with the minimal information (the inserted fake's time, RA, and dec).  To get meaningful information, we need to join against what we know about the images. We use the images' WCS to compute the fake object's pixel positions at each time. Then we use those pixel positions to extract stamps and fit a linear trajectory.
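
As a rough illustration of the fitting step, the sketch below fits a linear trajectory to a set of pixel positions with ordinary least squares and reports the mean squared error of the fit. It is a minimal sketch using only NumPy, not the code inside `FakeInfo`. The function name `fit_linear_trajectory`, the choice of the first time stamp as the reference time, and the sample positions are all assumptions made for illustration.

```python
import numpy as np


def fit_linear_trajectory(times, xs, ys):
    """Least-squares fit of x(t) = x0 + vx * dt and y(t) = y0 + vy * dt.

    dt is measured from the first time stamp (an assumption of this sketch).
    Returns (x0, y0, vx, vy, mse).
    """
    times = np.asarray(times, dtype=float)
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    dt = times - times[0]

    # np.polyfit returns coefficients from the highest degree to the lowest.
    vx, x0 = np.polyfit(dt, xs, 1)
    vy, y0 = np.polyfit(dt, ys, 1)

    # Mean squared pixel distance between the observed and fitted positions.
    mse = np.mean((x0 + vx * dt - xs) ** 2 + (y0 + vy * dt - ys) ** 2)
    return x0, y0, vx, vy, mse


# Hypothetical pixel positions for an object moving at (2.0, -1.0) pixels per day.
times = np.array([59466.0, 59466.1, 59466.2, 59466.3])
xs = 100.0 + 2.0 * (times - times[0])
ys = 250.0 - 1.0 * (times - times[0])

x0, y0, vx, vy, mse = fit_linear_trajectory(times, xs, ys)
print(f"x0={x0:.2f}, y0={y0:.2f}, vx={vx:.3f}, vy={vy:.3f}, mse={mse:.2e}")
```

A large MSE in a fit like this would suggest that the object's motion across the images is not well described by a straight line in pixel space.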

Currently we join with each `FakeInfo` object separately (we could add a helper function here if needed).

In [None]:
for fake in fakes_list:
    fake.join_with_workunit(wu, 10)
    print(f"{fake.name}:\n  Fit:{fake.trj}\n  MSE={fake.compute_fit_mse()}")

# Examining Stamps

We can plot the stamps at the raw (x, y) that we computed from the (RA, dec) position and the image WCS. We can also look at the positions predicted by the fitted, linear trajectory.  Below we look at the stamps at the first 4 time steps.
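
To make the two kinds of positions concrete, the sketch below predicts a pixel position from a linear trajectory and cuts a square stamp around it from a plain 2D array. It is only an illustration under stated assumptions: `predict_position` and `cut_stamp` are hypothetical helpers (not KBMOD functions), positions are rounded to the nearest pixel, and pixels outside the image are filled with NaN.

```python
import numpy as np


def predict_position(x0, y0, vx, vy, dt):
    """Pixel position of a linear trajectory dt time units after its start."""
    return x0 + vx * dt, y0 + vy * dt


def cut_stamp(image, x, y, radius):
    """Cut a (2 * radius + 1) square stamp centered on the pixel nearest (x, y).

    Pixels that fall outside the image are left as NaN.
    """
    size = 2 * radius + 1
    stamp = np.full((size, size), np.nan)

    height, width = image.shape
    col, row = int(round(x)), int(round(y))

    # Bounds of the stamp in image coordinates, clipped to the image.
    r_lo, r_hi = max(row - radius, 0), min(row + radius + 1, height)
    c_lo, c_hi = max(col - radius, 0), min(col + radius + 1, width)
    if r_lo < r_hi and c_lo < c_hi:
        # Offsets translate image coordinates into stamp coordinates.
        r_off, c_off = row - radius, col - radius
        stamp[r_lo - r_off : r_hi - r_off, c_lo - c_off : c_hi - c_off] = image[r_lo:r_hi, c_lo:c_hi]
    return stamp


# A toy 10x10 image where pixel (row, col) holds the value 10 * row + col.
image = np.arange(100.0).reshape(10, 10)

x, y = predict_position(1.0, 1.0, 2.0, 1.0, 1.0)
stamp = cut_stamp(image, x, y, 1)
print(stamp)
```

Comparing a stamp cut at the raw position with one cut at the fitted position shows how far the linear approximation drifts from the inserted object at that time step.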

In [None]:
fakes_list[0].compare_stamps([0, 1, 2, 3])

# Loading the Results

We load in the results of the actual KBMOD run for comparison.

In [None]:
results_file = "/epyc/projects/kbmod/data/20210908_B1h_047_test_data/20210908_B1h_047.results.ecsv"
results = Results.read_table(results_file)
print(results)

Match the known fakes against the result set.
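
For intuition about what a match means here, the sketch below pairs two sets of linear trajectories by comparing their predicted positions at a few query times and accepting the closest pairs within a distance threshold. This is a simple greedy scheme written for illustration. It is not necessarily the algorithm used by `match_trajectory_sets`, and the tuple representation `(x, y, vx, vy)` and the helper names are assumptions.

```python
import numpy as np


def positions(trj, times):
    """Positions of a trajectory (x, y, vx, vy) at each time, as an (N, 2) array."""
    x, y, vx, vy = trj
    times = np.asarray(times, dtype=float)
    return np.stack([x + vx * times, y + vy * times], axis=1)


def greedy_match(known, found, threshold, times):
    """Match each known trajectory to at most one found trajectory.

    The distance between two trajectories is their mean pixel separation over
    the query times. Returns, for each known trajectory, the index of its match
    in `found`, or -1 if nothing lies within `threshold`.
    """
    matches = [-1] * len(known)
    if not known or not found:
        return matches

    dist = np.array(
        [
            [np.mean(np.linalg.norm(positions(k, times) - positions(f, times), axis=1)) for f in found]
            for k in known
        ]
    )

    # Repeatedly take the closest remaining pair until none is within the threshold.
    while True:
        k_idx, f_idx = np.unravel_index(np.argmin(dist), dist.shape)
        if not dist[k_idx, f_idx] <= threshold:
            break
        matches[k_idx] = int(f_idx)
        dist[k_idx, :] = np.inf
        dist[:, f_idx] = np.inf
    return matches


# Hypothetical trajectories given as (x, y, vx, vy).
known = [(0, 0, 1.0, 0.0), (100, 100, 0.0, 1.0), (500, 500, 0.0, 0.0)]
found = [(102, 99, 0.0, 1.0), (1, 1, 1.0, 0.0)]

matches = greedy_match(known, found, threshold=10.0, times=[0.0, 10.0])
print(matches)
```

Evaluating the distance at both the first and last time, as the notebook does, penalizes a candidate that starts near a fake but moves with a different velocity.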

In [None]:
known_list = [fake.trj for fake in fakes_list]
found_list = results.make_trajectory_list()

matches = match_trajectory_sets(known_list, found_list, 200.0, times=[0.0, zeroed_times[-1]])

for idx, trj in enumerate(known_list):
    m_idx = matches[idx]
    print(f"Fake {idx}: Match={m_idx}")
    print(f"  Fake TRJ: x={trj.x:5d}, y={trj.y:5d}, vx={trj.vx:8.3f}, vy={trj.vy:8.3f}")
    if m_idx != -1:
        m_trj = found_list[m_idx]
        print(f"  Res  TRJ: x={m_trj.x:5d}, y={m_trj.y:5d}, vx={m_trj.vx:8.3f}, vy={m_trj.vy:8.3f}")
        print(f"  Score: lh={m_trj.lh}, obs_count={m_trj.obs_count}")
    print("\n")

# Per Fake Investigation

Now let's dig into what happened with a single fake. The first unmatched fake is number 2.

In [None]:
fake_num = 2

trj = fakes_list[fake_num].trj
single_res = explorer.evaluate_linear_trajectory(trj.x, trj.y, trj.vx, trj.vy)[0]

In [None]:
print(f"TRJ: x={trj.x:5d}, y={trj.y:5d}, vx={trj.vx:8.3f}, vy={trj.vy:8.3f}")
print(f"GPU LH {single_res['likelihood']}")

# Plot some summary data.
fig, ax = plt.subplots(2, 3, figsize=(12.0, 8.0))
plot_image(single_res["coadd_sum"], ax=ax[0][0], figure=fig, norm=True, title="Sum Stamp")
plot_image(single_res["coadd_mean"], ax=ax[0][1], figure=fig, norm=True, title="Mean Stamp")
plot_image(single_res["coadd_median"], ax=ax[0][2], figure=fig, norm=True, title="Median Stamp")

psi = single_res["psi_curve"]
phi = single_res["phi_curve"]
valid = (phi != 0) & np.isfinite(psi) & np.isfinite(phi)

psi[~valid] = 0.0
phi[~valid] = 1e-28
lh = psi / np.sqrt(phi)

plot_time_series(psi, zeroed_times, indices=valid, ax=ax[1][0], figure=fig, title="PSI")
plot_time_series(phi, zeroed_times, indices=valid, ax=ax[1][1], figure=fig, title="PHI")
plot_time_series(lh, zeroed_times, indices=valid, ax=ax[1][2], figure=fig, title="LH")

The plots above do not account for the sigma-G filtering performed on the GPU or afterwards. So let's add that in.
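
As background, sigma-G filtering is a robust outlier rejection: it estimates the spread of a curve from its percentiles rather than its standard deviation and drops points that lie too far from the median. The sketch below is a minimal NumPy version under stated assumptions. The function name `sigma_g_mask`, the 25th and 75th percentiles with the matching coefficient 0.7413, the threshold of two sigma, and the sample values are illustrative choices and may differ from the settings used in the actual run.

```python
import numpy as np


def sigma_g_mask(values, n_sigma=2.0, percentiles=(25.0, 75.0), coeff=0.7413):
    """Return a boolean mask that is True for points kept by sigma-G clipping.

    sigma_G = coeff * (upper percentile - lower percentile). With the 25th and
    75th percentiles, coeff = 0.7413 makes sigma_G match the standard deviation
    of a Gaussian distribution. NaN values are ignored when computing the
    percentiles and are never kept.
    """
    values = np.asarray(values, dtype=float)
    lower, median, upper = np.nanpercentile(values, [percentiles[0], 50.0, percentiles[1]])
    sigma_g = coeff * (upper - lower)
    return np.abs(values - median) <= n_sigma * sigma_g


# Hypothetical per-image likelihood values with one obvious outlier.
lh_values = np.array([5.0, 5.2, 4.9, 5.1, 5.0, 30.0, 4.8, 5.3])
mask = sigma_g_mask(lh_values)
print(mask)
```

Because the percentiles barely move when a single extreme value is added, one bright artifact in the curve does not inflate the threshold and is still rejected.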

In [None]:
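# Replot the curves, keeping only the points that also pass the sigma-G filter.
fig, ax = plt.subplots(1, 3, figsize=(12.0, 4.0))

sigma_g = single_res["sigma_g_res"]
# Reuse `valid` from the previous cell: psi and phi were overwritten in place there,
# so recomputing the finite and nonzero checks here would mark every point as valid.
valid2 = valid & sigma_g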

plot_time_series(psi, zeroed_times, indices=valid2, ax=ax[0], figure=fig, title="PSI")
plot_time_series(phi, zeroed_times, indices=valid2, ax=ax[1], figure=fig, title="PHI")
plot_time_series(lh, zeroed_times, indices=valid2, ax=ax[2], figure=fig, title="LH")

In [None]:
num_stamps = len(single_res["all_stamps"])

num_cols = 4
num_rows = int(np.ceil(num_stamps / num_cols))  # Row count as an integer.
img_width = 2.0

labels = []
for idx, t in enumerate(zeroed_times):
    label = f"{idx}={t:.4f}\n"
    if not valid[idx]:
        label += "MASKED"
    elif not sigma_g[idx]:
        label += "SIGMA G"
    else:
        label += "VALID"
    labels.append(label)

fig = plt.figure(layout="tight", figsize=(img_width * num_cols, img_width * num_rows))
plot_multiple_images(single_res["all_stamps"], fig, columns=num_cols, labels=labels, norm=True)