# Fake Analysis

This notebook provides a short example of analyzing the fakes (synthetic objects inserted into the images) from the DEEP data using `analyze_fakes.py`.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from kbmod.analysis.analyze_fakes import FakeInfo, load_fake_info_from_ecsv
from kbmod.analysis.plotting import plot_image, plot_multiple_images, plot_time_series
from kbmod.configuration import SearchConfiguration
from kbmod.results import Results
from kbmod.trajectory_explorer import TrajectoryExplorer
from kbmod.trajectory_generator import create_trajectory_generator
from kbmod.trajectory_utils import match_trajectory_sets, find_closest_velocity
from kbmod.work_unit import WorkUnit

## Load the WorkUnit

Load the test `WorkUnit` for this run and extract some time metadata: the observation times, the first time (t0), the times relative to t0, and the total time span. In the following cells we also build the search configuration and set up a `TrajectoryExplorer` object that we can later use to run simulated searches on the `WorkUnit`.

In [None]:
wu_file = "/epyc/projects/kbmod/data/20210908_B1h_047_test_data/20210908_B1h_047.wu"
wu = WorkUnit.from_fits(wu_file)

times = wu.get_all_obstimes()
t0 = times[0]
zeroed_times = np.array(times) - t0
max_dt = np.max(zeroed_times)

print(f"Loaded {len(wu)} images starting at time {t0}")
wu.print_stats()

Use a configuration that matches what was used in the search (with changes noted as comments).

In [None]:
config_data = {
    "chunk_size": 1000000,
    "clip_negative": False,
    "cluster_eps": 120.0,
    "cluster_type": "all",
    "cluster_v_scale": 1.0,
    "coadds": ["sum", "mean", "median", "weighted"],
    "debug": True,  # Use debugging output
    "do_clustering": True,
    "do_mask": True,
    "encode_num_bytes": -1,
    "generator_config": {
        "angle_units": "degree",
        "angles": [-270, -90, 64],
        "given_ecliptic": None,
        "name": "EclipticCenteredSearch",
        "velocities": [80.0, 500.0, 64],
        "velocity_units": "pix / d",
    },
    "gpu_filter": True,
    "lh_level": 7.0,
    "max_lh": 1000.0,
    "num_obs": 50,  # Set lower
    "psf_val": 1.4,
    "result_filename": None,
    "results_per_pixel": 8,
    "save_all_stamps": False,
    "sigmaG_lims": [25, 75],
    "stamp_radius": 10,
    "stamp_type": "sum",
    "track_filtered": False,
    "x_pixel_bounds": None,
    "x_pixel_buffer": None,
    "y_pixel_bounds": None,
}
config = SearchConfiguration(config_data)
wu.config = config

explorer = TrajectoryExplorer(wu.im_stack, config)
explorer.initialize_data()

We can use the configuration (now attached to the `WorkUnit`) to enumerate the candidate trajectories that the search evaluates.
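To get a feel for where the candidate count comes from, here is a minimal sketch of how a speed-and-angle grid like the `generator_config` above could expand into (vx, vy) pairs. It assumes the ranges are `[min, max, steps]`, that both are sampled with `np.linspace` (endpoints included), and that the angle offsets are measured from a hypothetical ecliptic angle of 0. `sketch_velocity_grid` is a hypothetical helper, not part of KBMOD, and the real `EclipticCenteredSearch` may sample the grid differently.

```python
import numpy as np


def sketch_velocity_grid(vel_range, ang_range_deg, ecliptic_angle_deg=0.0):
    """Expand [min, max, steps] speed and angle ranges into (vx, vy) pairs.

    Hypothetical helper for illustration. Assumes both ranges are sampled
    with np.linspace (endpoints included) and that the angle offsets are
    measured from a fixed ecliptic angle.
    """
    v_min, v_max, v_steps = vel_range
    a_min, a_max, a_steps = ang_range_deg
    speeds = np.linspace(v_min, v_max, int(v_steps))
    angles = np.radians(ecliptic_angle_deg + np.linspace(a_min, a_max, int(a_steps)))

    # Every (angle, speed) pair becomes one candidate velocity.
    ang_grid, spd_grid = np.meshgrid(angles, speeds, indexing="ij")
    vx = (spd_grid * np.cos(ang_grid)).ravel()
    vy = (spd_grid * np.sin(ang_grid)).ravel()
    return vx, vy


# Ranges copied from the generator_config above (ecliptic angle assumed 0).
vx, vy = sketch_velocity_grid([80.0, 500.0, 64], [-270, -90, 64])
print(len(vx))  # 64 angles x 64 speeds = 4096 candidates in this sketch
```

Because every angle is paired with every speed, the number of candidates in this sketch is simply the product of the two step counts.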

In [None]:
trj_generator = create_trajectory_generator(config, work_unit=wu)
print(trj_generator)

candidates = [trj for trj in trj_generator]
print(f"Searching {len(candidates)} trajectories.")

# Plot a scatter plot of the candidate velocities.
cand_vx = [trj.vx for trj in candidates]
cand_vy = [trj.vy for trj in candidates]

fig, ax = plt.subplots(figsize=(10.0, 10.0))
ax.plot(cand_vx, cand_vy, color="black", marker=".", linewidth=0)
ax.set_title("Candidate Velocity Distribution")
ax.set_xlabel("vx (pixels / day)")
ax.set_ylabel("vy (pixels / day)")

## Load the Inserted Fakes

Load the fakes data from the ECSV file. For each object (unique orbitid), the `load_fake_info_from_ecsv()` helper function collects that object's observations into a `FakeInfo` object.
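The grouping step can be pictured with a small, self-contained sketch that uses plain Python dictionaries in place of the ECSV table. Only `orbitid` comes from the text above; the `mjd`, `ra`, and `dec` keys are hypothetical column names, and `load_fake_info_from_ecsv()` may group and sort the rows differently.

```python
from collections import defaultdict


def group_by_orbit(rows, id_key="orbitid", time_key="mjd"):
    """Group observation rows by object ID and sort each group by time.

    `rows` is any iterable of dict-like records. The "mjd" key (and the
    "ra"/"dec" keys in the example) are hypothetical column names.
    """
    groups = defaultdict(list)
    for row in rows:
        groups[row[id_key]].append(row)
    # Time-order each object's observations so they can be fit in sequence.
    return {oid: sorted(obs, key=lambda r: r[time_key]) for oid, obs in groups.items()}


rows = [
    {"orbitid": 7, "mjd": 59465.2, "ra": 350.01, "dec": -5.0},
    {"orbitid": 3, "mjd": 59465.1, "ra": 351.00, "dec": -4.5},
    {"orbitid": 7, "mjd": 59465.1, "ra": 350.00, "dec": -5.0},
]
grouped = group_by_orbit(rows)
print(sorted(grouped))  # [3, 7]
print([r["mjd"] for r in grouped[7]])  # [59465.1, 59465.2]
```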

In [None]:
fakes_file = "/epyc/projects/kbmod/data/20210908_B1h_047_test_data/20210908_B1h_047_inserted_fakes.ecsv"
fakes_list = load_fake_info_from_ecsv(fakes_file)

We now have a list of `FakeInfo` objects with the minimal information (the inserted fake's time, RA, and dec).  To get meaningful information, we need to join against what we know about the images. We use the images' WCS to compute the fake object's pixel positions at each time. Then we use those pixel positions to extract stamps and fit a linear trajectory.

Currently we join with each `FakeInfo` object separately (we could add a helper function here if needed).
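The linear fit mentioned above can be sketched as an ordinary least-squares line in each pixel coordinate. This sketch assumes the WCS conversion has already produced pixel positions and measures fit quality as the mean squared distance between observed and fitted positions. The helper name and that error definition are assumptions and may not match what `FakeInfo` or `compute_fit_mse()` actually do.

```python
import numpy as np


def fit_linear_trajectory(times, xs, ys):
    """Least-squares fit of x(t) = x0 + vx * t and y(t) = y0 + vy * t.

    `times` are offsets from the first observation (days) and xs, ys are
    pixel positions. Hypothetical sketch; FakeInfo may fit differently.
    """
    times = np.asarray(times, dtype=float)
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    # np.polyfit with deg=1 returns [slope, intercept].
    vx, x0 = np.polyfit(times, xs, 1)
    vy, y0 = np.polyfit(times, ys, 1)

    # Mean squared distance (pixels^2) between observed and fitted positions.
    resid_sq = (x0 + vx * times - xs) ** 2 + (y0 + vy * times - ys) ** 2
    return x0, y0, vx, vy, float(np.mean(resid_sq))


# Perfectly linear synthetic positions, so the fit error is essentially zero.
t = np.array([0.0, 0.5, 1.0, 2.0])
x0, y0, vx, vy, mse = fit_linear_trajectory(t, 100.0 + 120.0 * t, 50.0 - 30.0 * t)
print(f"x0={x0:.2f}, y0={y0:.2f}, vx={vx:.2f}, vy={vy:.2f}, mse={mse:.2e}")
```

A large fit error for a real fake would suggest that its motion over the observation window is not well described by a single linear trajectory in pixel space.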

In [None]:
for idx, fake in enumerate(fakes_list):
    fake.join_with_workunit(wu, 10)
    print(f"{idx}, {fake.name}:\n  {fake.trj}")
    print(f"  Times Seen={fake.num_times_seen} of {len(fake)}")
    print(f"  MSE={fake.compute_fit_mse()}\n")

We can also plot some summary information.  Here we plot three aspects:
  * RA vs time to show how linear the trajectory is in that dimension.
  * Dec vs time to show how linear the trajectory is in that dimension.
  * Magnitude vs time to show how the signal is changing over time.

In [None]:
for idx, fake in enumerate(fakes_list):
    fake.plot_summary(title=f"\n---------------\n({idx}) {fake.name}")

We can also see how the x and y velocities of the fakes align with our search space.

In [None]:
fig, ax = plt.subplots(figsize=(6.0, 6.0))
ax.plot(cand_vx, cand_vy, color="black", marker=".", linewidth=0)

fake_vx = [fake.trj.vx for fake in fakes_list]
fake_vy = [fake.trj.vy for fake in fakes_list]
ax.plot(fake_vx, fake_vy, color="red", marker=".", linewidth=0)

ax.set_title("Candidate (black) and Fake (red) Velocities")
ax.set_xlabel("vx (pixels / day)")
ax.set_ylabel("vy (pixels / day)")

## Examining Stamps

We can plot the stamps at the raw (x, y) that we computed from the (RA, dec) position and the image WCS. We can also look at the positions predicted by the fitted, linear trajectory.  Below we look at the stamps at the first 4 time steps.
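As a rough illustration of what comparing stamps involves, the sketch below computes where a linear trajectory predicts the object at a given time and cuts a square stamp around that pixel. The helper names, the NaN padding at image edges, and rounding to the nearest pixel are assumptions for illustration; KBMOD's own stamp code may handle these details differently.

```python
import numpy as np


def extract_stamp(image, x, y, radius):
    """Return a (2 * radius + 1)-pixel square cutout centered on pixel (x, y).

    Pixels outside the image are filled with NaN. Rows index y and columns
    index x, following the usual numpy image convention.
    """
    size = 2 * radius + 1
    stamp = np.full((size, size), np.nan)
    height, width = image.shape
    for dy in range(-radius, radius + 1):
        for dx in range(-radius, radius + 1):
            px, py = x + dx, y + dy
            if 0 <= px < width and 0 <= py < height:
                stamp[dy + radius, dx + radius] = image[py, px]
    return stamp


def predicted_pixel(x0, y0, vx, vy, t):
    """Pixel position of a linear trajectory at time t (rounding is an assumption)."""
    return int(round(x0 + vx * t)), int(round(y0 + vy * t))


image = np.arange(100, dtype=float).reshape(10, 10)
px, py = predicted_pixel(2.0, 3.0, 4.0, 2.0, t=1.0)  # (6, 5)
stamp = extract_stamp(image, px, py, radius=1)
print(stamp)
```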

In [None]:
fakes_list[0].compare_stamps([0, 1, 2, 3])

## Loading the Results

We load in the results of the actual KBMOD run for comparison.

In [None]:
results_file = "/epyc/projects/kbmod/data/20210908_B1h_047_test_data/20210908_B1h_047.results.ecsv"
results = Results.read_table(results_file)
print(results)

We match the known fakes against the result set. If there is a match, we display the result information (trajectory, likelihood, and observation count). 
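To make the matching criterion concrete, here is a simplified sketch: two linear trajectories match when the mean distance between their predicted positions at the chosen times is within a pixel threshold, and each found trajectory is claimed by at most one fake. The greedy assignment and the `(x, y, vx, vy)` tuple representation are assumptions for illustration; `match_trajectory_sets()` may use a different assignment strategy.

```python
import numpy as np


def predict_positions(trj, times):
    """Positions (x, y) of a linear trajectory (x, y, vx, vy) at the given times."""
    x, y, vx, vy = trj
    times = np.asarray(times, dtype=float)
    return np.column_stack([x + vx * times, y + vy * times])


def greedy_match(known, found, threshold, times):
    """Match each known trajectory to at most one found trajectory.

    A pair is eligible when the mean distance between their predicted
    positions at `times` is <= threshold. Pairs are claimed greedily from
    the smallest mean distance upward. Entry i of the result is the index
    of the found trajectory matched to known[i], or -1 if there is none.
    """
    eligible = []
    for i, k in enumerate(known):
        pk = predict_positions(k, times)
        for j, f in enumerate(found):
            dist = np.mean(np.linalg.norm(pk - predict_positions(f, times), axis=1))
            if dist <= threshold:
                eligible.append((dist, i, j))

    matches = [-1] * len(known)
    used = set()
    for dist, i, j in sorted(eligible):
        if matches[i] == -1 and j not in used:
            matches[i] = j
            used.add(j)
    return matches


known = [(100, 200, 150.0, -20.0), (500, 500, 300.0, 10.0)]
found = [(102, 199, 149.0, -21.0), (900, 100, 100.0, 0.0)]
print(greedy_match(known, found, 10.0, times=[0.0, 1.0]))  # [0, -1]
```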

If there is no match, we look up the candidate velocity closest to the fake's fitted velocity. This is the best the search could do even if it started at exactly the right pixel. From this candidate velocity, we compute how far its predicted position drifts from the fake's position by the last time step, which is again a best-case distance.
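Because the best case assumes both trajectories start at the same pixel, the drift depends only on the velocity difference: after `max_dt` days it is `max_dt * |dv|`. The sketch below shows that arithmetic together with a nearest-velocity lookup. Using Euclidean distance in (vx, vy) space is an assumption here; `find_closest_velocity()` may use a different metric, and the helper names are hypothetical.

```python
import numpy as np


def closest_velocity(vx, vy, cand_vx, cand_vy):
    """Index of the candidate whose (vx, vy) is nearest in Euclidean distance."""
    dv = np.hypot(np.asarray(cand_vx) - vx, np.asarray(cand_vy) - vy)
    return int(np.argmin(dv))


def best_case_drift(vx, vy, cand_vx, cand_vy, max_dt):
    """Pixel offset after max_dt days when both start at the same pixel."""
    return max_dt * np.hypot(cand_vx - vx, cand_vy - vy)


cand_vx = np.array([100.0, 150.0, 200.0])
cand_vy = np.array([0.0, -20.0, 10.0])
best = closest_velocity(148.0, -23.0, cand_vx, cand_vy)  # 1
drift = best_case_drift(148.0, -23.0, cand_vx[best], cand_vy[best], max_dt=2.0)
print(f"closest={best}, drift={drift:.2f} pixels")  # 2 * sqrt(2**2 + 3**2)
```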

In [None]:
known_list = [fake.trj for fake in fakes_list]
found_list = results.make_trajectory_list()

# Match the known fakes to the found trajectories. A pair matches when the mean
# distance between their positions at t_0 and t_last is <= 10.0 pixels.
match_times = [0.0, zeroed_times[-1]]
matches = match_trajectory_sets(known_list, found_list, 10.0, times=match_times)

for idx, trj in enumerate(known_list):
    m_idx = matches[idx]
    print(f"Fake {idx} ({fakes_list[idx].name}): Match={m_idx}")
    print(f"  Ave Mag: {np.mean(fakes_list[idx].mag)}")
    print(f"  Times Seen: {fakes_list[idx].num_times_seen}")
    print(f"  Fake TRJ: x={trj.x:5d}, y={trj.y:5d}, vx={trj.vx:8.3f}, vy={trj.vy:8.3f}")
    if m_idx != -1:
        m_trj = found_list[m_idx]
        print(f"  Res  TRJ: x={m_trj.x:5d}, y={m_trj.y:5d}, vx={m_trj.vx:8.3f}, vy={m_trj.vy:8.3f}")
        print(f"  Result Score: lh={m_trj.lh}, obs_count={m_trj.obs_count}")
    else:
        # What is the closest candidate this COULD have matched with. Since we
        # could start at any pixel, just account for the velocity.
        closest = find_closest_velocity(trj, candidates)
        m_trj = candidates[closest]
        print(f"  Closest Candidate Vel: vx={m_trj.vx:8.3f}, vy={m_trj.vy:8.3f}")

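        # Best-case drift: both start at the same pixel, so the offset after
        # max_dt days depends only on the velocity difference.
        dx_max = max_dt * (m_trj.vx - trj.vx)
        dy_max = max_dt * (m_trj.vy - trj.vy)
        dist = np.sqrt(dx_max * dx_max + dy_max * dy_max)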
        print(f"  Distance after dt={max_dt:.4f} is {dist:.2f} pixels.")

    print("\n")

We can also visualize the intersection of the search space (black dots), the fakes (red dots), and the found results (blue dots).

In [None]:
fig, ax = plt.subplots(figsize=(6.0, 6.0))
ax.plot(cand_vx, cand_vy, color="black", marker=".", linewidth=0)
ax.plot(fake_vx, fake_vy, color="red", marker=".", linewidth=0)

res_vx = [trj.vx for trj in found_list]
res_vy = [trj.vy for trj in found_list]
ax.plot(res_vx, res_vy, color="blue", marker=".", linewidth=0)

ax.set_title("Candidate (black), Fake (red), and Result (blue) Velocities")
ax.set_xlabel("vx (pixels / day)")
ax.set_ylabel("vy (pixels / day)")

## Per-Fake Investigation

Now let's look more closely at what happened with each fake by evaluating a search on exactly its fitted trajectory parameters. For each fake we display:
  * The coadded stamps (sum, mean, and median) from the fitted trajectory.
  * The psi, phi, and lh curves, computed with masked points zeroed out (psi is set to 0.0 and phi to a tiny value to avoid dividing by zero). The red dots indicate points that are either masked or rejected by sigma-G filtering.
  * The individual stamp at each time, labeled by whether that time step was valid, masked (e.g. bad pixel), or rejected by sigma-G filtering.
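The sigma-G filtering referenced above can be sketched as a percentile-based clip. sigma_G estimates the standard deviation from the spread between two percentiles, scaled so it equals sigma for Gaussian data; with the configuration's `sigmaG_lims` of `[25, 75]` that scale factor is about 0.7413. The clipping width `n_sigma=2`, the helper names, and clipping on the per-point likelihood are assumptions for illustration, and KBMOD's GPU implementation may differ in its details. The kept points are combined as `sum(psi) / sqrt(sum(phi))`, consistent with the per-point `psi / np.sqrt(phi)` used in the cell below.

```python
import numpy as np
from statistics import NormalDist


def sigma_g_mask(values, percentiles=(25, 75), n_sigma=2.0):
    """Boolean mask of the points kept by a sigma-G style clip.

    Points farther than n_sigma * sigma_G from the median are rejected.
    n_sigma=2 is an assumption for illustration.
    """
    values = np.asarray(values, dtype=float)
    lo, hi = percentiles
    p_lo, median, p_hi = np.percentile(values, [lo, 50, hi])
    # Scale the percentile spread so it equals sigma for Gaussian data.
    z = NormalDist()
    coeff = 1.0 / (z.inv_cdf(hi / 100.0) - z.inv_cdf(lo / 100.0))
    sigma_g = coeff * (p_hi - p_lo)
    return np.abs(values - median) <= n_sigma * sigma_g


def combined_likelihood(psi, phi, keep):
    """Trajectory likelihood from the kept points: sum(psi) / sqrt(sum(phi))."""
    psi, phi = np.asarray(psi, dtype=float), np.asarray(phi, dtype=float)
    return psi[keep].sum() / np.sqrt(phi[keep].sum())


psi = np.array([1.0, 1.1, 0.9, 1.0, 8.0])  # the last point is an outlier
phi = np.ones(5)
keep = sigma_g_mask(psi / np.sqrt(phi))
print(keep)  # the 8.0 outlier is rejected
print(combined_likelihood(psi, phi, keep))  # approximately 4.0 / sqrt(4) = 2.0
```

A single bright artifact would otherwise dominate the sum, which is why a robust, percentile-based width is used instead of the ordinary standard deviation.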

In [None]:
subplot_size = 3.0

for idx, fake in enumerate(fakes_list):
    trj = fake.trj
    single_res = explorer.evaluate_linear_trajectory(trj.x, trj.y, trj.vx, trj.vy)[0]

    # Extract the basic information to use in the title string.
    data_str = f"\n---------------------------\n{idx}: Orbit ID={fake.name}\n"
    if matches[idx] == -1:
        data_str += "  Status: NOT FOUND\n"
        data_str += f"  FAKE TRJ: x={trj.x:5d}, y={trj.y:5d}, vx={trj.vx:8.3f}, vy={trj.vy:8.3f}\n"
    else:
        # Look up the matched result for this fake.
        m_trj = found_list[matches[idx]]
        data_str += f"  Status: RECOVERED (result={matches[idx]})\n"
        data_str += f"   FAKE TRJ: x={trj.x:5d}, y={trj.y:5d}, vx={trj.vx:8.3f}, vy={trj.vy:8.3f}\n"
        data_str += f"  FOUND TRJ: x={m_trj.x:5d}, y={m_trj.y:5d}, vx={m_trj.vx:8.3f}, vy={m_trj.vy:8.3f}\n"
        data_str += f"  Score: lh={m_trj.lh}, obs_count={m_trj.obs_count}\n"
    data_str += f"  FAKE GPU LH: {single_res['likelihood']}\n"

    # Plot some summary data.
    fig = plt.figure(layout="tight", figsize=(3 * subplot_size, 2 * subplot_size))
    ax = fig.subplots(2, 3)
    fig.suptitle(data_str)

    plot_image(
        single_res["coadd_sum"], ax=ax[0][0], figure=fig, norm=True, title="Sum Stamp", show_counts=False
    )
    plot_image(
        single_res["coadd_mean"], ax=ax[0][1], figure=fig, norm=True, title="Mean Stamp", show_counts=False
    )
    plot_image(
        single_res["coadd_median"],
        ax=ax[0][2],
        figure=fig,
        norm=True,
        title="Median Stamp",
        show_counts=False,
    )

    # Compute the psi, phi, and LH curves without sigma-G filtering. Only
    # account for the masked points (valid array).
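    # Copy the curves so the edits below do not modify the results table.
    psi = np.array(single_res["psi_curve"], dtype=float)
    phi = np.array(single_res["phi_curve"], dtype=float)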
    valid = (phi != 0) & np.isfinite(psi) & np.isfinite(phi)

    psi[~valid] = 0.0
    phi[~valid] = 1e-28
    lh = psi / np.sqrt(phi)

    # Use the sigma-G results computed during the evaluation and mark any
    # points that are either masked or rejected by sigma-G filtering.
    sigma_g = np.asarray(single_res["sigma_g_res"], dtype=bool)
    valid2 = valid & sigma_g

    plot_time_series(psi, zeroed_times, indices=valid2, ax=ax[1][0], figure=fig, title="Sigma G PSI")
    plot_time_series(phi, zeroed_times, indices=valid2, ax=ax[1][1], figure=fig, title="Sigma G PHI")
    plot_time_series(lh, zeroed_times, indices=valid2, ax=ax[1][2], figure=fig, title="Sigma G LH")

    # Plot the stamps as their own figure.
    num_stamps = len(single_res["all_stamps"])
    num_cols = 5
    num_rows = int(np.ceil(num_stamps / num_cols))
    img_width = 2.0

    labels = []
    for t_idx, t in enumerate(zeroed_times):
        label = f"{t_idx}={t:.4f}\n"
        if not valid[t_idx]:
            label += "MASKED"
        elif not sigma_g[t_idx]:
            label += "SIGMA G"
        else:
            label += "VALID"
        labels.append(label)

    fig = plt.figure(layout="tight", figsize=(img_width * num_cols, img_width * num_rows))
    plot_multiple_images(single_res["all_stamps"], fig, columns=num_cols, labels=labels, norm=True)

plt.show()