# KBMOD Search for Multi-Night Results
  
A basic notebook to demonstrate searching for results that appear across multiple nights.

Note that this notebook is intended to be run using shared data on baldur.

# Setup demo

Before importing, make sure you have installed kbmod using `pip install .` in the root `KBMOD` directory.  Also be sure you are running with python3 and using the correct notebook kernel.

In [1]:
import math
import matplotlib.pyplot as plt
import numpy as np
import os

import kbmod
from kbmod.analysis.plotting import *
from kbmod.data_interface import load_deccam_layered_image
from kbmod.search import ImageStack, PSF, StampCreator, Trajectory
from kbmod.results import Results
from kbmod.work_unit import WorkUnit
from kbmod.trajectory_generator import VelocityGridSearch
from kbmod.trajectory_utils import trajectory_predict_skypos

from kbmod.trajectory_explorer import TrajectoryExplorer

from astropy.coordinates import SkyCoord, search_around_sky
import astropy.units as u
from astropy.table import Table
import astropy.time 

In [2]:
# Data paths (absolute locations on the shared filesystem this notebook is meant to run on)
wu_path = "/epyc/projects/kbmod/runs/wbeebe/20240609_42au/slice0/reprojected_wu.fits"  # A reflex-corrected WorkUnit
# Directory holding the saved KBMOD run output; `results.ecsv` is read from here below.
res_path = "/epyc/projects/kbmod/runs/wbeebe/20240609_42au/slice0"

# Path to known fakes (with reflex-corrected coordinates) on the dates used in this KBMOD search.
fakes_path = "/epyc/projects/kbmod/runs/wbeebe/fakes_42_au_2019_04_02_and_2019_05_07.csv"

# Reload a Saved KBMOD WorkUnit
Note that this WorkUnit was reflex-corrected with a guess distance of 42 AU.

By reloading the WorkUnit, we're able to recreate the stamps for individual observations in any given trajectory.

In [None]:
# Rebuild the WorkUnit from the FITS file saved by the KBMOD run.
wu = WorkUnit.from_fits(wu_path)

# Keep a module-level handle on the image stack; later cells and helper
# functions read `stack` (and `wu.wcs`) as globals.
stack = wu.im_stack

print(f"Loaded stack with {stack.img_count()} images")

Setting unknown parameter: cluster_function
Setting unknown parameter: num_cores


# Reload KBMOD Results

A KBMOD run will store several files in its results directory, but we can generate a `Results` object (which wraps an astropy Table) from the `results.ecsv` file alone.

In [None]:
results = Results.read_table(os.path.join(res_path, "results.ecsv"))
results

Let's reconstruct the trajectories from our saved results and recreate the stamps for each observation

In [None]:
# One trajectory object per row of the results table; later cells index this
# list with the same row index used for `results`.
trajectories = results.make_trajectory_list()
# Store the per-image stamps for every result in a new "all_stamps" column.
# NOTE(review): the literal 10 is presumably the stamp radius in pixels -- confirm
# against StampCreator.get_stamps.
results.table["all_stamps"] = [StampCreator.get_stamps(stack, trj, 10) for trj in trajectories]

Add a column to our results for the number of unique days observed for each result based on the MJDs 

In [None]:
def mjd_to_day(mjd):
    """Return the calendar-date portion of an MJD as a string.

    The MJD is converted to a datetime through astropy, rendered with
    ``str`` and cut at the first whitespace, so only the leading date
    token is kept and the time of day is dropped.
    """
    as_datetime = astropy.time.Time(mjd, format='mjd').to_value('datetime')
    rendered = str(as_datetime)
    return rendered.split()[0]

# Count, for every result, how many distinct calendar days its valid
# observations fall on.
num_days = []
for row_idx in range(len(results)):
    # Per-image flags: was this observation kept as part of the result?
    valid_flags = results[row_idx]["obs_valid"]

    # A set collapses repeated observations taken on the same day.
    observed_days = {
        mjd_to_day(stack.get_obstime(img_idx))
        for img_idx, flag in enumerate(valid_flags)
        if flag
    }
    num_days.append(len(observed_days))

# Add as a column in the results table
results.table["num_days"] = num_days

Now for each result, we're interested in plotting the cumulative coadd as well as the coadds consisting only of the observations within a given day for that result.

In [None]:
# Generate coadds per each day to sanity check against the fakes.
def plot_daily_coadds(result_table, result_idx):
    """Plot a result's full coadd next to one summed stamp per observing day.

    Reads the module-level ``stack`` for observation times. Only images
    flagged in the row's "obs_valid" column contribute to a daily sum.
    """
    row = result_table[result_idx]

    # day string -> running sum of that day's stamps
    coadds_by_day = {}
    for img_idx in range(stack.img_count()):
        if not row["obs_valid"][img_idx]:
            continue

        stamp = row["all_stamps"][img_idx]
        # Stamps may be stored as plain arrays or as objects exposing `.image`.
        if not isinstance(stamp, np.ndarray):
            stamp = stamp.image

        day = mjd_to_day(stack.get_obstime(img_idx))
        if day in coadds_by_day:
            coadds_by_day[day] += stamp
        else:
            # Copy so the in-place sums never modify the stored stamp.
            coadds_by_day[day] = stamp.copy()

    # The full coadd goes first, followed by one panel per day.
    imgs = [result_table["stamp"][result_idx], *coadds_by_day.values()]
    labels = [f'Coadd for result {result_idx}', *(str(day) for day in coadds_by_day)]

    plot_multiple_images(imgs, labels=labels)

In [None]:
# Preview at most the first ten results.
num_to_preview = min(len(results), 10)
for i in range(num_to_preview):
    plot_daily_coadds(results, i)

In [None]:
# Plot only the results whose valid observations span two or more days.
for i in range(len(results)):
    if results[i]["num_days"] < 2:
        continue
    plot_daily_coadds(results, i)

# Search for results that are near known fakes

We have a table of fakes that are present in the data, uniquely identified by the 'ORBITID' column

In [None]:
# Load the catalog of known fakes from CSV into an astropy Table.
fakes = Table.read(fakes_path, format="csv")
# Sort by the unique ORBITID of each fake and then by observation time, so
# the rows of one fake are contiguous and in time order.
fakes.sort(["ORBITID", "mjd_mid"]) # Sort by the unique ORBITID for each fake and then observation time
fakes

astropy allows us to take two catalogs of coordinates (represented by `SkyCoord` objects) and easily search for nearest neighbors between them.

First, as a simple approximation let's translate the initial (x, y) of each of our results into an (ra, dec). Note that we are using a reflex-corrected WCS from our `WorkUnit` so the (ra, dec) will be in reflex-corrected space.

In [None]:
def get_ra_decs_from_trj(idx, result_table, trajectories):
    """Predict sky positions for one result at each of its valid observation times.

    Reads the module-level ``stack`` (observation times) and ``wu`` (WCS).
    Returns whatever ``trajectory_predict_skypos`` produces for the
    trajectory at row ``idx``.
    """
    valid_flags = result_table[idx]["obs_valid"]

    # Keep the times of only those images that contributed to this result.
    valid_obstimes = [
        stack.get_obstime(img_idx)
        for img_idx, flag in enumerate(valid_flags)
        if flag
    ]

    # Convert the trajectory to (ra, dec) through the work unit's WCS.
    return trajectory_predict_skypos(trajectories[idx], wu.wcs, valid_obstimes)

# Predicted sky positions for every result, one sequence per row.
ra_decs = [get_ra_decs_from_trj(i, results, trajectories) for i in range(len(results))]

# The first and last predicted positions become the start and end columns.
results.table["ra_dec_start"] = [positions[0] for positions in ra_decs]
results.table["ra_dec_end"] = [positions[-1] for positions in ra_decs]

In [None]:
# Names of the fakes-table columns used as RA and Dec (in degrees) by the
# cells and helper functions below.
# NOTE(review): the "42.0" suffix presumably refers to the 42 AU reflex-correction
# guess distance -- confirm against how the fakes table was generated.
fakes_guess_ra = "RA_42.0"
fakes_guess_dec = "Dec_42.0"

Now we can translate our (ra, dec) pairs into single `SkyCoord` objects.

Then we can use astropy's `search_around_sky` to find which KBMOD results are near our known fakes, with a maximum separation of 1 arcsecond.

In [None]:
# Get the reflex-corrected (ra, dec) positions for our fakes
fake_coords = SkyCoord(ra=fakes[fakes_guess_ra] * u.degree, dec=fakes[fakes_guess_dec] * u.degree)

# Search for trajectory start positions that lie within 1 arcsecond of any fake.
# idx1 indexes into the results' start positions, idx2 into fake_coords.
idx1, idx2, sep2dAngle, dist3d = search_around_sky(results["ra_dec_start"], fake_coords, 1 * u.arcsecond)

`idx1` holds one entry per (result, fake) match found by the search. Each value is an index into our results table, identifying a result whose start position lies within 1 arcsecond of some fake.

In [None]:
print(len(idx1))
idx1

In [None]:
idx2

In [None]:
potential_fake_starts = np.unique(idx1)
potential_fake_starts

Let's first examine the first result which we believe might be a fake.

In [None]:
results[potential_fake_starts[0]]

In [None]:
# helper function to plot a row of the results table
plot_result_row(results[potential_fake_starts[0]])

In [None]:
plot_daily_coadds(results, potential_fake_starts[0])

`idx2` shows the inverse mapping of which fakes might be potential results. Here its values are indices within our fakes table.

So taking the first potential match we examined above, we can use the corresponding index (in this case 0) to inspect within our fakes table.


In [None]:
fakes[idx2[0]]["ORBITID"]

In [None]:
def find_fakes_matches(fakes_table, idx1, idx2):
    """Map each matched result index to the set of fake ORBITIDs it matched.

    Parameters
    ----------
    fakes_table : table-like
        Indexable by row, with each row indexable by the "ORBITID" key.
    idx1 : sequence
        Indices into the results table, one per match.
    idx2 : sequence
        Indices into ``fakes_table``, parallel to ``idx1``.

    Returns
    -------
    dict
        Keys are the values of ``idx1``; each value is the set of ORBITIDs
        of the fakes paired with that result.
    """
    matches = {}
    for result_idx, fake_idx in zip(idx1, idx2):
        # Look the fake up in the table that was passed in, not in a
        # module-level global, so the function works for any fakes table.
        matches.setdefault(result_idx, set()).add(fakes_table[fake_idx]["ORBITID"])

    return matches
    
result_start_to_fakes = find_fakes_matches(fakes, idx1, idx2)
result_start_to_fakes
    

In [None]:
fakes[fakes["ORBITID"] == 4661373]

# Now see if the end points of any of the trajectories are also near the endpoints of our fakes

In [None]:
idx1_end, idx2_end, sep2dAngle_end, dist3d_end = search_around_sky(results["ra_dec_end"], fake_coords, 1 * u.arcsecond)

In [None]:
idx1_end

In [None]:
idx2_end

In [None]:
potential_fake_ends = np.unique(idx1_end)
potential_fake_ends

In [None]:
result_end_to_fakes = find_fakes_matches(fakes, idx1_end, idx2_end)
result_end_to_fakes

# Look for Fakes Matching Both the Start and Endpoint

In [None]:
# Result indices whose start AND end coordinates were matched to the same fake,
# mapped to the ORBITIDs they share.
result_full_fake_match = {}
for r, start_fakes in result_start_to_fakes.items():
    # Skip results that had no match at their end coordinate.
    if r not in result_end_to_fakes:
        continue
    shared_fakes = start_fakes.intersection(result_end_to_fakes[r])
    if shared_fakes:
        result_full_fake_match[r] = shared_fakes
result_full_fake_match

In [None]:
for idx in result_start_to_fakes:
    plot_daily_coadds(results, idx)

Result 237 is interesting, since it's a multi-day result matched to a fake, but doesn't have a clean coadd on its second day. Let's get its fake ORBITID and examine it some more.

In [None]:
result_full_fake_match[237]

In [None]:
4661373

# Examine a given Fake

Let's evaluate the linearity and velocities of our fake

In [None]:
CURR_ORBIT_ID = 4661373

In [None]:
def evaluate_fake(fake_table, fake_orbit_id, verbose=True):
    """Measure how linear a fake's pixel track is and estimate its velocity.

    The fake's rows are sorted by "mjd_mid", converted to pixel coordinates
    with the module-level ``wu.wcs``, and a least-squares line y = m*x + c is
    fitted. Velocities are taken from the first and last positions divided
    by the elapsed "mjd_mid" time.

    Parameters
    ----------
    fake_table : astropy Table
        Fakes catalog with "ORBITID", "mjd_mid" and the RA/Dec columns named
        by ``fakes_guess_ra`` / ``fakes_guess_dec``.
    fake_orbit_id : int
        ORBITID of the fake to evaluate.
    verbose : bool
        When True (the default) print a one-line summary.

    Returns
    -------
    tuple
        ``(r_squared, fake_vx, fake_vy, fake_v)`` with velocities in pixels
        per unit of "mjd_mid" (reported as pixel/day in the summary).

    Raises
    ------
    ValueError
        If no row of ``fake_table`` has the requested ORBITID.
    """
    our_fake = fake_table[fake_table["ORBITID"] == fake_orbit_id]
    if len(our_fake) == 0:
        # Fail early with a clear message instead of an IndexError further down.
        raise ValueError(f"No rows with ORBITID={fake_orbit_id} in the fakes table")

    our_fake.sort("mjd_mid")
    fake_x, fake_y = wu.wcs.world_to_pixel(SkyCoord(ra=our_fake[fakes_guess_ra]*u.deg, dec=our_fake[fakes_guess_dec]*u.degree))

    # Least-squares fit of y = m*x + c over the pixel positions.
    A = np.vstack([fake_x, np.ones(len(fake_x))]).T
    m, c = np.linalg.lstsq(A, fake_y, rcond=None)[0]

    # Make predictions
    y_pred = m * fake_x + c

    # Compute R-squared
    ss_res = np.sum((fake_y - y_pred) ** 2)
    ss_tot = np.sum((fake_y - np.mean(fake_y)) ** 2)
    r_squared = 1 - (ss_res / ss_tot)

    # Mean velocity between the first and last observation.
    elapsed_time = our_fake["mjd_mid"][-1] - our_fake["mjd_mid"][0]
    fake_vx = float(fake_x[-1] - fake_x[0]) / elapsed_time
    fake_vy = float(fake_y[-1] - fake_y[0]) / elapsed_time
    fake_v = math.sqrt(fake_vx*fake_vx + fake_vy*fake_vy)

    # Honor the `verbose` flag; previously the summary was always printed.
    if verbose:
        print(f'For fake ORBITID={fake_orbit_id}, R^2={r_squared} and its pixel/day velocities are vx={fake_vx} vy={fake_vy}, v={fake_v}')
    return r_squared, fake_vx, fake_vy, fake_v
    

evaluate_fake(fakes, CURR_ORBIT_ID)

# Plot Fake by Day

In [None]:
def manual_coadd(stamps, stamp_indices, plot_me=False):
    """Sum the selected stamps into a single coadded image.

    Parameters
    ----------
    stamps : sequence
        Objects exposing an ``image`` array attribute.
    stamp_indices : iterable of int
        Indices into ``stamps`` to add together. Any iterable is accepted
        (list, range, generator, ...).
    plot_me : bool
        When True, also show the coadd with ``plt.imshow``.

    Returns
    -------
    The summed image. It is a copy, so the input stamps are not modified.

    Raises
    ------
    IndexError
        If ``stamp_indices`` is empty.
    """
    indices = iter(stamp_indices)
    try:
        first = next(indices)
    except StopIteration:
        raise IndexError("manual_coadd requires at least one stamp index") from None

    # Copy the first stamp so the in-place additions below leave it untouched.
    result_stamp = stamps[first].image.copy()
    for i in indices:
        result_stamp += stamps[i].image
    if plot_me:
        # 'gray' is accepted by every matplotlib release; the 'grey' alias
        # only exists in newer ones.
        plt.imshow(result_stamp, cmap='gray')
    return result_stamp

def plot_daily_fake_stamps(stack, fakes_table, fake_orbit_id):
    """Plot coadded stamps of a fake: one over all matched images, then one per day.

    Each observation of the fake (by "mjd_mid") is paired with the image in
    ``stack`` whose obstime is closest, provided the two are within 0.01 of
    each other. A stamp is cut from that image's science layer at the fake's
    pixel position (computed with the module-level ``wu.wcs``). Observations
    with no image that close are skipped.

    Parameters
    ----------
    stack : image stack
        Must provide ``get_images()``; each image provides ``get_obstime()``
        and ``get_science()``.
    fakes_table : astropy Table
        Fakes catalog with "ORBITID", "mjd_mid" and the RA/Dec columns named
        by ``fakes_guess_ra`` / ``fakes_guess_dec``.
    fake_orbit_id : int
        ORBITID of the fake to plot.
    """
    # Pull the sorted fakes table
    our_fake = fakes_table[fakes_table["ORBITID"] == fake_orbit_id]
    our_fake.sort("mjd_mid")

    fake_x, fake_y = wu.wcs.world_to_pixel(SkyCoord(ra=our_fake[fakes_guess_ra]*u.deg, dec=our_fake[fakes_guess_dec]*u.degree))

    imgs = stack.get_images()
    stack_obstimes = [img.get_obstime() for img in imgs]
    fake_times = our_fake["mjd_mid"]

    # Largest allowed gap between a fake's time and an image's obstime.
    epsilon = 0.01

    fake_stamps = []
    # Integer obstime (day) -> positions in `fake_stamps` taken on that day.
    # Indices are recorded as stamps are appended, so they stay valid even
    # when some fake observations have no matching image.
    stamps_by_day = {}
    for i, curr_fake_time in enumerate(fake_times):
        # Find the image closest in time, keeping the first one on ties.
        best_j = None
        best_diff = None
        for j, curr_time in enumerate(stack_obstimes):
            curr_diff = abs(curr_time - curr_fake_time)
            if curr_diff <= epsilon and (best_j is None or curr_diff < best_diff):
                best_j = j
                best_diff = curr_diff
        if best_j is None:
            continue

        curr_img = imgs[best_j].get_science()
        stamps_by_day.setdefault(int(stack_obstimes[best_j]), []).append(len(fake_stamps))
        fake_stamps.append(curr_img.create_stamp(fake_x[i], fake_y[i], 10, False))

    if not fake_stamps:
        # Nothing to coadd; report it rather than failing inside manual_coadd.
        print(f'No images within {epsilon} of any observation of ORBITID={fake_orbit_id}')
        return

    # Plot the coadds for the whole set and then for each day in time order
    img_to_plot = [manual_coadd(fake_stamps, range(len(fake_stamps)), plot_me=False)]
    labels = [f'ORBITID={fake_orbit_id}']
    for day in sorted(stamps_by_day):
        img_to_plot.append(manual_coadd(fake_stamps, stamps_by_day[day], plot_me=False))
        labels.append(str(day))
    plot_multiple_images(img_to_plot, labels=labels)

plot_daily_fake_stamps(stack, fakes, CURR_ORBIT_ID)

# Plot a Given KBMOD Result Trajectory Alongside a Given Fake

In the next cell we provide a variety of helper functions for evaluating a KBMOD trajectory alongside a given fake

In [None]:
def get_x_y_from_trj(idx, result_table, trajectories, times):
    """Predict the pixel positions of one trajectory at the given times.

    Positions are linear in the time elapsed since ``times[0]``:
    ``x + vx * dt`` and ``y + vy * dt``, using the trajectory at
    ``trajectories[idx]``. ``result_table`` is accepted but not used.

    Returns a pair ``(x_vals, y_vals)`` of arrays, one entry per time.
    """
    obstimes = np.array(times)
    # Time elapsed since the first requested time.
    elapsed = obstimes - obstimes[0]

    trj = trajectories[idx]
    return trj.x + trj.vx * elapsed, trj.y + trj.vy * elapsed

def plot_res_traj(result_table, res_idx, trajectories):
    """Scatter-plot the predicted pixel track of one KBMOD result.

    Only the observation times flagged in the row's "obs_valid" column are
    used; they are read from the module-level ``stack``.
    """
    valid_flags = result_table[res_idx]["obs_valid"]
    valid_obstimes = [
        stack.get_obstime(i) for i in range(len(valid_flags)) if valid_flags[i]
    ]

    res_x, res_y = get_x_y_from_trj(res_idx, result_table, trajectories, valid_obstimes)

    # Whole-number span between the earliest and latest valid observation.
    day_sep = int(max(valid_obstimes)) - int(min(valid_obstimes))

    plt.scatter(res_x, res_y, color='blue', label="KBMOD trajectory", marker='.')
    plt.xlabel('X (pixels)')
    plt.ylabel('Y (pixels)')
    plt.title(f'KBMOD Result Trajectory {res_idx} ({day_sep} day separation)')
    plt.show()

def plot_fake(fake_table, orbit_id):
    """Scatter-plot the pixel positions of every observation of one fake.

    Sky coordinates are converted to pixels with the module-level ``wu.wcs``.
    """
    rows = fake_table[fake_table["ORBITID"] == orbit_id]
    sky_positions = SkyCoord(ra=rows[fakes_guess_ra]*u.deg, dec=rows[fakes_guess_dec]*u.degree)
    fake_x, fake_y = wu.wcs.world_to_pixel(sky_positions)

    # Whole-number span between the first and last observation.
    mjds = rows["mjd_mid"]
    first_mjd = min(mjds)
    last_mjd = max(mjds)
    day_sep = int(last_mjd) - int(first_mjd)

    plt.scatter(fake_x, fake_y, color='red', label="fakes", marker='.')
    plt.xlabel('X (pixels)')
    plt.ylabel('Y (pixels)')
    plt.title(f'Fake ORBIT_ID={orbit_id} ({day_sep} day separation)')
    plt.show()
    
def plot_fake_ra_dec(fake_table, orbit_id):
    """Scatter-plot the sky coordinates of every observation of one fake.

    The values plotted are taken directly from the table columns named by
    ``fakes_guess_ra`` and ``fakes_guess_dec``.
    """
    rows = fake_table[fake_table["ORBITID"] == orbit_id]

    # Whole-number span between the first and last observation.
    mjds = rows["mjd_mid"]
    first_mjd = min(mjds)
    last_mjd = max(mjds)
    day_sep = int(last_mjd) - int(first_mjd)

    plt.scatter(rows[fakes_guess_ra], rows[fakes_guess_dec], color='red', label="fakes", marker='.')
    plt.xlabel('ra (degrees)')
    plt.ylabel('dec (degrees)')
    plt.title(f'Fake ORBIT_ID={orbit_id} ({day_sep} day separation)')
    plt.show()
    
def plot_fake_ra_dec_single_night(fake_table, orbit_id, night):
    """Scatter-plot one fake's sky coordinates for a single observing night.

    Parameters
    ----------
    fake_table : astropy Table
        Fakes catalog with "ORBITID", "local_obsnight" and the RA/Dec
        columns named by ``fakes_guess_ra`` / ``fakes_guess_dec``.
    orbit_id : int
        ORBITID of the fake to plot.
    night : str
        Value of "local_obsnight" to select, e.g. "2019-04-02".
    """
    our_fake = fake_table[fake_table["ORBITID"] == orbit_id]
    our_fake = our_fake[our_fake["local_obsnight"] == night]

    # The unused min/max date computation was removed: its result was never
    # shown and it raised ValueError when the night had no rows.
    plt.scatter(our_fake[fakes_guess_ra], our_fake[fakes_guess_dec], color='red', label="fakes", marker='.')
    plt.xlabel('ra (degrees)')
    plt.ylabel('dec (degrees)')
    plt.title(f'Fake ORBIT_ID={orbit_id} ({night})')
    plt.show()
    
def plot_fake_x_y_single_night(fake_table, orbit_id, night):
    """Scatter-plot one fake's pixel positions for a single observing night.

    Sky coordinates are converted to pixels with the module-level ``wu.wcs``.

    Parameters
    ----------
    fake_table : astropy Table
        Fakes catalog with "ORBITID", "local_obsnight" and the RA/Dec
        columns named by ``fakes_guess_ra`` / ``fakes_guess_dec``.
    orbit_id : int
        ORBITID of the fake to plot.
    night : str
        Value of "local_obsnight" to select, e.g. "2019-04-02".
    """
    our_fake = fake_table[fake_table["ORBITID"] == orbit_id]
    our_fake = our_fake[our_fake["local_obsnight"] == night]
    fake_x, fake_y = wu.wcs.world_to_pixel(SkyCoord(ra=our_fake[fakes_guess_ra]*u.deg, dec=our_fake[fakes_guess_dec]*u.degree))

    # The unused min/max date computation was removed: its result was never
    # shown and it raised ValueError when the night had no rows.
    plt.scatter(fake_x, fake_y, color='red', label="fakes", marker='.')
    # The plotted values are pixel coordinates, so label the axes in pixels
    # (they were previously labelled as ra/dec in degrees).
    plt.xlabel('X (pixels)')
    plt.ylabel('Y (pixels)')
    plt.title(f'Fake ORBIT_ID={orbit_id} ({night})')
    plt.show()
    
    
def plot_res_traj_and_fake(result_table, res_idx, trajectories, fake_table, orbit_id):
    """Overlay a KBMOD result's predicted pixel track on a fake's pixel positions.

    Reads the module-level ``stack`` for observation times and ``wu.wcs``
    to convert the fake's sky coordinates to pixels.
    """
    # Times of the images that contributed to this result.
    valid_flags = result_table[res_idx]["obs_valid"]
    valid_obstimes = [
        stack.get_obstime(i) for i in range(len(valid_flags)) if valid_flags[i]
    ]

    res_x, res_y = get_x_y_from_trj(res_idx, result_table, trajectories, valid_obstimes)

    rows = fake_table[fake_table["ORBITID"] == orbit_id]
    sky_positions = SkyCoord(ra=rows[fakes_guess_ra]*u.deg, dec=rows[fakes_guess_dec]*u.degree)
    fake_x, fake_y = wu.wcs.world_to_pixel(sky_positions)

    # Whole-number span covering both the result and the fake.
    earliest = min(min(valid_obstimes), min(rows["mjd_mid"]))
    latest = max(max(valid_obstimes), max(rows["mjd_mid"]))
    day_sep = int(latest) - int(earliest)

    # Different marker shapes keep both sets visible where they overlap.
    kbmod_plt = plt.scatter(res_x, res_y, color='blue', label="KBMOD trajectory", marker='|',  alpha=0.5)
    fake_plt = plt.scatter(fake_x, fake_y, color='red', label="fakes", marker='_',  alpha=0.5)
    plt.xlabel('X (pixels)')
    plt.ylabel('Y (pixels)')
    plt.title(f'KBMOD Result Trajectory {res_idx} vs Fake ORBIT_ID={orbit_id} ({day_sep} day separation)')
    plt.legend((kbmod_plt, fake_plt), ("KBMOD trajectory", "fake"))
    plt.show()
    
    
def plot_fake_ra_dec_sanity(fake_table, orbit_id, max_rows=90):
    """Scatter-plot the sky coordinates of a fake's earliest observations.

    Parameters
    ----------
    fake_table : astropy Table
        Fakes catalog with "ORBITID", "mjd_mid" and the RA/Dec columns named
        by ``fakes_guess_ra`` / ``fakes_guess_dec``.
    orbit_id : int
        ORBITID of the fake to plot.
    max_rows : int
        Number of rows to keep after sorting by "mjd_mid". Defaults to 90,
        the limit that was previously hard-coded.
    """
    our_fake = fake_table[fake_table["ORBITID"] == orbit_id]
    our_fake.sort("mjd_mid")
    our_fake = our_fake[:max_rows]

    # The unused min/max date computation was removed: its result was never
    # shown and it raised ValueError when the fake had no rows.
    plt.scatter(our_fake[fakes_guess_ra], our_fake[fakes_guess_dec], color='red', label="fakes", marker='.')
    plt.xlabel('ra (degrees)')
    plt.ylabel('dec (degrees)')
    # The title previously ended with an unmatched ')'.
    plt.title(f'Fake ORBIT_ID={orbit_id}')
    plt.show()


In [None]:
plot_fake(fakes, CURR_ORBIT_ID)

## Plot the result and fake trajectories

In [None]:
plot_res_traj_and_fake(results, 237, trajectories, fakes, CURR_ORBIT_ID)

In [None]:
plot_fake_ra_dec(fakes, CURR_ORBIT_ID)

In [None]:
plot_fake_x_y_single_night(fakes, CURR_ORBIT_ID, "2019-04-02")

In [None]:
plot_fake_x_y_single_night(fakes, CURR_ORBIT_ID, "2019-04-02")

In [None]:
plot_fake_x_y_single_night(fakes, CURR_ORBIT_ID, "2019-05-07")