In [None]:
# stdlib
import os
from pathlib import Path

# 3rd party
import tensorflow as tf
import lightkurve as lk
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LogNorm

from scipy.interpolate import interp1d

In [None]:
# From transit_detection/visualization/utils_plot_examples.py
def _add_image_panel(fig, subplot_spec, img, title, target_coords, img_norm=None):
    """
    Draw one image panel with a colorbar and the target pixel marked with a red "x".

    Args:
        fig: matplotlib Figure to draw into
        subplot_spec: GridSpec slot for the panel
        img: 2D array to display with imshow
        title: str, panel title
        target_coords: dict with "x" (row index) and "y" (col index) of the target pixel
        img_norm: matplotlib Normalize instance or None (linear scale)

    Returns:
        The created Axes.
    """
    ax = fig.add_subplot(subplot_spec)
    im = ax.imshow(img, norm=img_norm)
    # Colorbar in a dedicated axis appended to the right so it matches the image height.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    # imshow puts columns on the x axis and rows on the y axis, hence (col, row).
    ax.scatter(
        target_coords["y"], target_coords["x"], marker="x", color="r", label="Target"
    )
    ax.set_ylabel("Row")
    ax.set_xlabel("Col")
    ax.legend()
    ax.set_title(title)
    return ax


def plot_ex_flux_window_and_diff_imgs(
    flux_window,
    diff_imgs,
    midpoint,
    is_transit_ex,
    plot_fp,
    norm=False,
    snr=None,
):
    """
    Plot flux and difference image data corresponding to a transit dataset example.

    Layout: a full-width flux panel on top, then a 2x2 grid with the difference
    image, out-of-transit image, target position image and SNR image. The target
    pixel (argmax of ``target_img``) is marked on the difference, out-of-transit
    and SNR panels. The figure is saved to ``plot_fp`` and is always closed
    afterwards, even if saving fails.

    Args:
        flux_window: array-like, flux window values; plotted against evenly spaced
            x positions spanning [0, 100], one per sample
        diff_imgs: list/tuple of 2D NumPy arrays, in order: [diff_img,oot_img,snr_img,target_img]
        midpoint: float or None, time t corresponding to the midpoint of the flux window example
        is_transit_ex: boolean, True if example corresponds to a transit
        plot_fp: Path, file path to saved plot
        norm: bool, if True titles/labels say "Normalized" and the out-of-transit and
            SNR images use a linear color scale instead of LogNorm (the difference
            image is always shown on a linear scale)
        snr: float or None, corresponding to tce_model_snr; shown rounded to 2
            decimals in the SNR panel title

    Returns:
        None
    """
    diff_img, oot_img, snr_img, target_img = diff_imgs

    label = "In Transit" if is_transit_ex else "Out of Transit"

    # `is not None` so a legitimate value of 0.0 is still displayed instead of "None".
    snr_str = str(round(snr, 2)) if snr is not None else None
    midpoint_str = round(midpoint, 2) if midpoint is not None else None

    # Target pixel = position of the maximum of target_img, as (row, col).
    row, col = np.unravel_index(np.argmax(target_img), target_img.shape)
    target_coords = {"x": row, "y": col}

    fig = plt.figure(figsize=(16, 16))
    gs = gridspec.GridSpec(3, 2, figure=fig)

    # flux window
    ax = fig.add_subplot(gs[0, :])
    # One x position per flux sample so windows of any length can be plotted.
    time = np.linspace(0, 100, len(flux_window))
    ax.plot(time, flux_window, linestyle="None", marker="o", markersize=3, alpha=0.6)
    ax.set_title(f"{label} Flux Window w/ midpoint {midpoint_str} Over Time")
    ax.set_xlabel("Time")
    ax.set_ylabel("Normalized Flux value" if norm else "Flux value")

    # diff img (always linear scale)
    _add_image_panel(
        fig,
        gs[1, 0],
        diff_img,
        title="Normalized Difference Flux" if norm else "Difference Flux (e-/cadence)",
        target_coords=target_coords,
        img_norm=None,
    )

    # oot img; a fresh LogNorm per panel because a Normalize instance autoscales to its image
    _add_image_panel(
        fig,
        gs[1, 1],
        oot_img,
        title=(
            "Normalized Out-of-Transit Flux"
            if norm
            else "Out-of-Transit Flux (e-/cadence)"
        ),
        target_coords=target_coords,
        img_norm=None if norm else LogNorm(),
    )

    # target img
    ax = fig.add_subplot(gs[2, 0])
    ax.imshow(target_img)
    ax.set_ylabel("Row")
    ax.set_xlabel("Col")
    ax.set_title("Target Position")

    # snr img
    _add_image_panel(
        fig,
        gs[2, 1],
        snr_img,
        title=(
            f"Normalized Difference SNR: {snr_str}"
            if norm
            else f"Difference SNR: {snr_str}"
        ),
        target_coords=target_coords,
        img_norm=None if norm else LogNorm(),
    )

    fig.tight_layout()

    try:
        fig.savefig(plot_fp)
        print(f"Saved plots figures to {plot_fp}")
    except Exception as e:
        # Best effort: report and continue so one bad example does not stop a batch run.
        print(f"ERROR: plotting: {e}")
    finally:
        # Always release the figure; this function is called once per record in a loop.
        plt.close(fig)


In [None]:
# Input TFRecord shard and output plot directory. The defaults are the original
# author's local paths; set TFREC_FP / PLOT_DIR to run on another machine.
tfrec_fp = os.environ.get(
    "TFREC_FP", "/Users/jochoa4/Downloads/raw_shard_0001-0001.tfrecord"
)

plot_dir = Path(
    os.environ.get(
        "PLOT_DIR",
        "/Users/jochoa4/Desktop/plots/visualize_flux_examples_06-07-2025/",
    )
)
plot_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Parse every serialized tf.train.Example in the shard and save one diagnostic plot per example.
IMG_FEATURES = ["diff_img", "oot_img", "snr_img", "target_img"]  # order expected by the plot function
IMG_SHAPE = (33, 33)  # (rows, cols) every stored image is reshaped to

tfrec_dataset = tf.data.TFRecordDataset(tfrec_fp)

for str_record in tfrec_dataset.as_numpy_iterator():
    example = tf.train.Example()
    example.ParseFromString(str_record)
    features = example.features.feature

    # get example info; validate the label before doing any per-record work
    label = features["label"].bytes_list.value[0].decode("utf-8")
    assert label in ("1", "0"), (
        f"ERROR: example label {label}: {type(label)} does not meet expected format"
    )
    uid = features["uid"].bytes_list.value[0].decode("utf-8")
    midpoint = features["t"].float_list.value[0]

    # get ephemerides/tce data
    disposition = features["disposition"].bytes_list.value[0].decode("utf-8")
    tce_period = features["tce_period"].float_list.value[0]  # not used by the plot below
    tce_model_snr = features["tce_model_snr"].float_list.value[0]
    tce_uid = uid.split("_")[0]
    target_id = tce_uid.split("-")[0]
    sector_run = tce_uid.split("S")[-1]
    sector = features["sector"].bytes_list.value[0].decode("utf-8")

    # get flux window
    flux_window = np.array(features["flux"].float_list.value)

    # get diff_imgs: each is stored as a serialized float32 tensor
    diff_imgs = [
        tf.reshape(
            tf.io.parse_tensor(features[name].bytes_list.value[0], tf.float32),
            IMG_SHAPE,
        ).numpy()
        for name in IMG_FEATURES
    ]

    t_sr_s_plot_dir = plot_dir / target_id / sector_run / sector
    t_sr_s_plot_dir.mkdir(parents=True, exist_ok=True)

    plot_fp = t_sr_s_plot_dir / f"t_{round(midpoint, 2)}_{disposition}_{tce_uid}_{label}.png"

    plot_ex_flux_window_and_diff_imgs(
        flux_window=flux_window,
        diff_imgs=diff_imgs,
        midpoint=midpoint,
        is_transit_ex=(label == "1"),
        plot_fp=plot_fp,
        norm=False,
        # The `snr` argument is documented as tce_model_snr; passing None made
        # every SNR panel title read "None".
        snr=tce_model_snr,
    )