In [None]:
import sys, os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ipywidgets as widgets
from IPython.display import display, clear_output

# PATH SETUP & IMPORTS
# Make the sibling ``main-pipeline`` checkout importable so ``config`` and ``lib`` resolve.
_parent_dir_candidate = os.path.join(os.getcwd(), os.pardir, 'main-pipeline')
pipeline_dir = os.path.abspath(_parent_dir_candidate)
if pipeline_dir not in sys.path:
    sys.path.append(pipeline_dir)

import config
from lib import plotting

# LOAD DATA
# Component-tracker table. Set the SS433_TRACKER_CSV environment variable to load a
# different table without editing the notebook; the default is the original path.
csv_path = os.environ.get(
    'SS433_TRACKER_CSV',
    '/Users/leodrake/Documents/ss433/HRC_2024/2Dfits/comp tracker tables/comp-tracker-table-4comp-1sigma-mcmc-200k-bin0p25.csv',
)
# Parse obs_id as text: if the column has any missing value, pandas would otherwise
# read it as float and astype(str) would yield IDs like "26568.0".
tracker_df = pd.read_csv(csv_path, dtype={'obs_id': str})
tracker_df['obs_id'] = tracker_df['obs_id'].astype(str)

# CORRECT SCALE FOR HRC + BINNING FACTOR
# Native HRC pixel scale (arcsec / native pixel). If you have a value in config, use it instead.
HRC_ARCSEC_PER_PIXEL = 0.1318

import re


def _parse_bin_factor(filename):
    """Return the bin factor encoded in *filename*, or None if it carries no bin tag.

    ``p`` stands for the decimal point: ``bin0p25`` -> 0.25, ``bin0p5`` -> 0.5.
    Integer bins are also recognised: ``bin1`` -> 1.0, ``bin2`` -> 2.0.
    """
    match = re.search(r'bin(\d+)(?:p(\d+))?', filename)
    if match is None:
        return None
    whole, frac = match.groups()
    return float(f"{whole}.{frac}") if frac is not None else float(whole)


# Parse bin factor from the (lower-cased) filename; fall back to 0.25 when absent.
fn = os.path.basename(csv_path).lower()
bin_factor = _parse_bin_factor(fn)
if bin_factor is None:
    bin_factor = 0.25
    print("No bin factor found in filename; assuming bin_factor =", bin_factor)
if bin_factor <= 0:
    # A zero scale would silently collapse every blob onto the core.
    raise ValueError(f"Non-positive bin factor {bin_factor!r} parsed from {fn!r}")

# Key point:
# bin_factor is "fraction of native pixel". bin=1 => native pixel scale.
ARCSEC_PER_TABLE_PIXEL = bin_factor * HRC_ARCSEC_PER_PIXEL

print("HRC_ARCSEC_PER_PIXEL =", HRC_ARCSEC_PER_PIXEL)
print("bin_factor (fraction of native pixel) =", bin_factor)
print("ARCSEC_PER_TABLE_PIXEL =", ARCSEC_PER_TABLE_PIXEL)

# PREPARE BLOBS PER OBS (TABLE PIXELS -> ARCSEC ONCE)
def prepare_blobs_from_df(df, arcsec_per_table_pixel, *, default_center=(80.5, 80.5),
                          rad_err=0.1, pa_err=5.0):
    """Convert tracker-table pixel positions into core-relative polar coordinates.

    Within each ``obs_id`` group, offsets are measured from the first row whose
    ``component`` is ``core`` or ``c1`` (case-insensitive). If no such row exists,
    ``default_center`` (table pixels) is used. Offsets are scaled to arcsec and
    turned into a radius and a position angle (0 = north, +90 = east). x is
    negated because table x increases to the right while east is to the left.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide ``obs_id``, ``mjd``, ``component``, ``xpos`` and ``ypos``.
    arcsec_per_table_pixel : float
        Angular size of one table pixel.
    default_center : tuple of float, optional
        Fallback (x, y) reference in table pixels when a group has no core row.
    rad_err, pa_err : float, optional
        Symmetric placeholder uncertainties (arcsec, degrees) attached to every
        blob. Per-component fit errors are not read from ``df`` here.

    Returns
    -------
    obs_blobs : dict[str, list[dict]]
        Per observation, one dict per row with keys ``comp``, ``rad_obs``,
        ``pa_obs``, ``rad_err_L``/``rad_err_U``, ``pa_err_L``/``pa_err_U``.
    obs_mjd : dict[str, float]
        MJD taken from the first row of each observation.
    """
    obs_blobs = {}
    obs_mjd = {}

    for obs_key, group in df.groupby('obs_id'):
        obs_id = str(obs_key)
        obs_mjd[obs_id] = float(group.iloc[0]['mjd'])

        is_core = group['component'].astype(str).str.lower().isin(['core', 'c1'])
        core_rows = group[is_core]
        if core_rows.empty:
            x0, y0 = (float(v) for v in default_center)
        else:
            x0 = float(core_rows.iloc[0]['xpos'])
            y0 = float(core_rows.iloc[0]['ypos'])

        blobs = []
        for _, row in group.iterrows():
            dx_pix = float(row['xpos']) - x0
            dy_pix = float(row['ypos']) - y0

            # FITS X increases right; sky East is left => flip x
            dx_arc = -dx_pix * arcsec_per_table_pixel
            dy_arc = dy_pix * arcsec_per_table_pixel

            blobs.append({
                'comp': str(row['component']),
                'rad_obs': float(np.hypot(dx_arc, dy_arc)),
                'pa_obs': float(np.degrees(np.arctan2(dx_arc, dy_arc))),  # 0=N, +90=E
                'rad_err_L': rad_err, 'rad_err_U': rad_err,
                'pa_err_L': pa_err, 'pa_err_U': pa_err,
            })

        obs_blobs[obs_id] = blobs

    return obs_blobs, obs_mjd

obs_data_map, obs_mjd_map = prepare_blobs_from_df(tracker_df, ARCSEC_PER_TABLE_PIXEL)
obs_ids = sorted(obs_data_map.keys())

# Quick sanity print for the first observation. An empty table gets a message
# instead of an IndexError on obs_ids[0]; the plotters below already report
# "Missing blobs or MJD." for a missing selection.
if obs_ids:
    o0 = obs_ids[0]
    print("Sanity:", o0, "MJD", obs_mjd_map[o0])
    for b in obs_data_map[o0]:
        nm = b['comp'].lower()
        if 'east' in nm or 'west' in nm:
            print(" ", b['comp'], "rad(arcsec)=", round(b['rad_obs'], 3), "pa(deg)=", round(b['pa_obs'], 2))
else:
    print("Sanity: tracker table contains no observations.")

# INTERACTIVE PLOTTER
obs_dropdown = widgets.Dropdown(options=obs_ids, description='Obs ID:', layout=widgets.Layout(width='260px'))
lock_betas = widgets.Checkbox(value=True, description='Lock βe=βw', indent=False)

beta_e = widgets.FloatSlider(
    value=float(config.EPHEMERIS['beta']),
    min=0.15, max=0.40, step=0.0005,
    description='β east:', readout_format='.4f',
    continuous_update=False, layout=widgets.Layout(width='520px')
)
beta_w = widgets.FloatSlider(
    value=float(config.EPHEMERIS['beta']),
    min=0.15, max=0.40, step=0.0005,
    description='β west:', readout_format='.4f',
    continuous_update=False, layout=widgets.Layout(width='520px')
)


def _sync_betas(change):
    """While the lock is on, copy a β-slider change onto the other slider."""
    if lock_betas.value:
        if change['owner'] is beta_e:
            beta_w.value = beta_e.value
        elif change['owner'] is beta_w:
            beta_e.value = beta_w.value


def _sync_betas_on_lock(change):
    """When the lock is switched on, snap β west to β east.

    Without this the sliders stay unequal until one of them moves, even though
    the locked plot uses βe for both jets.
    """
    if change['new']:
        beta_w.value = beta_e.value


beta_e.observe(_sync_betas, names='value')
beta_w.observe(_sync_betas, names='value')
lock_betas.observe(_sync_betas_on_lock, names='value')

out = widgets.Output()

def _plot(obs_id, be, bw):
    """
    Draw the polar (PA, radius) view of one observation into ``out``.

    Model trajectories come from ``plotting._plot_jet_trajectories_on_ax`` using
    ``config.EPHEMERIS`` with ``beta`` replaced by the slider value(s). Observed
    components are overplotted as markers: west red, east blue, anything else
    green. ``bkg`` rows are skipped.

    Parameters
    ----------
    obs_id : str
        Key into ``obs_data_map`` / ``obs_mjd_map``.
    be, bw : float
        East / west β from the sliders. While ``lock_betas`` is ticked only ``be``
        drives the model, and the title reports it for both jets.
    """
    blobs = obs_data_map.get(obs_id, [])
    mjd_obs = obs_mjd_map.get(obs_id, np.nan)
    if not blobs or np.isnan(mjd_obs):
        with out:
            clear_output(wait=True)
            print("Missing blobs or MJD.")
        return

    locked = lock_betas.value
    if locked:
        bw = be  # the locked model uses βe for both jets; keep the title consistent

    # default=0.0: an observation holding only core/bkg rows would otherwise make
    # max() raise ValueError on an empty sequence.
    max_rad = max(
        (b['rad_obs'] for b in blobs if b['comp'].lower() not in ('core', 'c1', 'bkg')),
        default=0.0,
    )
    r_max = max(1.0, max_rad * 1.6)

    # Colour scale handed to the plotting helper (range 1-325; its meaning is set
    # by lib.plotting).
    mappable = plt.cm.ScalarMappable(
        cmap=plt.cm.rainbow,
        norm=mcolors.Normalize(vmin=1.0, vmax=325)
    )

    with out:
        clear_output(wait=True)
        plt.close('all')

        fig, ax = plt.subplots(figsize=(8.5, 8.5), subplot_kw={'projection': 'polar'})

        base_params = config.EPHEMERIS.copy()

        if locked:
            p = base_params.copy()
            p['beta'] = float(be)
            plotting._plot_jet_trajectories_on_ax(ax, mjd_obs, p, mappable, r_max)
        else:
            # NOTE(review): each call receives a full parameter set with a single
            # 'beta'. If the helper draws both jets per call, unlocked mode shows
            # both jets at βe AND at βw rather than east@βe / west@βw. Confirm
            # against lib.plotting.
            p1 = base_params.copy()
            p1['beta'] = float(be)
            p2 = base_params.copy()
            p2['beta'] = float(bw)
            plotting._plot_jet_trajectories_on_ax(ax, mjd_obs, p1, mappable, r_max)
            plotting._plot_jet_trajectories_on_ax(ax, mjd_obs, p2, mappable, r_max)

        for b in blobs:
            name_l = b['comp'].lower()
            if name_l == 'bkg':
                continue
            if 'west' in name_l:
                c = 'red'
            elif 'east' in name_l:
                c = 'blue'
            else:
                c = 'green'

            ax.plot(np.deg2rad(b['pa_obs']), b['rad_obs'], 'o',
                    color=c, markersize=8, markeredgecolor='black', zorder=10)

            if name_l not in ['core', 'c1']:
                ax.text(np.deg2rad(b['pa_obs']), b['rad_obs'] + (r_max * 0.03),
                        b['comp'], color=c, fontsize=9, fontweight='bold')

        ax.set_title(
            f"Obs {obs_id} | MJD {mjd_obs:.2f} | βe={be:.4f} βw={bw:.4f}\n"
            f"HRC {HRC_ARCSEC_PER_PIXEL}\"/pix, bin={bin_factor} → {ARCSEC_PER_TABLE_PIXEL:.5f}\"/table-pix",
            pad=18
        )
        ax.set_rmax(r_max)
        plt.show()

def _update(*_):
    """Redraw the polar plot from the current widget state; observer args are ignored."""
    selected_obs = obs_dropdown.value
    east_beta, west_beta = beta_e.value, beta_w.value
    _plot(selected_obs, east_beta, west_beta)

# Any of these controls changing triggers a redraw (registered in this order).
for _control in (obs_dropdown, beta_e, beta_w, lock_betas):
    _control.observe(_update, names='value')

_top_row = widgets.HBox([obs_dropdown, lock_betas])
ui = widgets.VBox([_top_row, beta_e, beta_w, out])

display(ui)
_update()

In [None]:
# CARTESIAN INTERACTIVE PLOTTER (x,y) + χ²=1 ERROR ELLIPSES
# MIRRORED (East-left) + EXTENDABLE JET LENGTH
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from matplotlib.patches import Ellipse
import ipywidgets as widgets
from IPython.display import display, clear_output

# Robust import for ss433_phases: try each candidate module in turn and keep the
# first one that defines it (adjust the list if your pipeline uses a different name).
_SS433_CANDIDATE_MODULES = ("lib.physics", "lib.model_kinematics", "lib.plotting")
ss433_phases = None
_import_errors = []
for _candidate in _SS433_CANDIDATE_MODULES:
    try:
        _module = __import__(_candidate, fromlist=["ss433_phases"])
        ss433_phases = getattr(_module, "ss433_phases")
    except Exception as exc:
        # Remember why this candidate failed so the final error can list them all.
        _import_errors.append((_candidate, repr(exc)))
    else:
        break

if ss433_phases is None:
    _attempts = "\n".join(f" - {name}: {err}" for name, err in _import_errors)
    raise ImportError(
        "Couldn't find ss433_phases. Tried:\n"
        + _attempts
        + "\n\nUpdate the module list above to where ss433_phases actually lives."
    )


def _age_curve_extended(params, n, min_age=0.0, extend_factor=1.3):
    """
    Age grid from min_age to extend_factor * precession_period.
    endpoint=False avoids duplicating the phase boundary exactly at integer cycles.
    """
    P = float(params["precession_period"])
    max_age = float(extend_factor) * P
    if max_age <= min_age:
        raise ValueError("Bad extend_factor/min_age")
    return np.linspace(min_age, max_age, n, endpoint=False)


def _calc_xy_and_sigmas(blob):
    """
    Match fitter's definition exactly:
      x = r sin(PA), y = r cos(PA)   (arcsec)
      σx², σy² from linear propagation (no covariance)
    """
    pa_rad = np.deg2rad(float(blob["pa_obs"]))
    r = float(blob["rad_obs"])

    x_obs = r * np.sin(pa_rad)
    y_obs = r * np.cos(pa_rad)

    sig_r = (float(blob["rad_err_U"]) + abs(float(blob["rad_err_L"]))) / 2.0
    sig_pa_deg = (float(blob["pa_err_U"]) + abs(float(blob["pa_err_L"]))) / 2.0
    sig_pa = np.deg2rad(sig_pa_deg)

    sig_x_sq = (np.sin(pa_rad) * sig_r) ** 2 + (r * np.cos(pa_rad) * sig_pa) ** 2
    sig_y_sq = (np.cos(pa_rad) * sig_r) ** 2 + (r * -np.sin(pa_rad) * sig_pa) ** 2

    sig_x_sq = max(sig_x_sq, 1e-9)
    sig_y_sq = max(sig_y_sq, 1e-9)

    return x_obs, y_obs, sig_x_sq, sig_y_sq


def _jet_curve_xy(mjd_obs, params, side, beta_override=None, n_age=8000, extend_factor=1.3):
    """
    Sample one model jet as an (x, y) curve in arcsec.

    Ages run from 0 up to (but excluding) ``extend_factor`` precession periods.
    Each age maps to an ejection Julian date ``(mjd_obs + 2400000.5) - age``,
    which is passed to ``ss433_phases``. ``side == "east"`` selects the east jet;
    any other value selects the west jet. ``beta_override``, when given, replaces
    ``params["beta"]`` for the selected side only.

    Returns ``(x, y, age)`` with x = r sin(PA), y = r cos(PA). ``age`` is in days
    since ejection and is what callers colour the curve by.
    """
    age = _age_curve_extended(params, n=n_age, min_age=0.0, extend_factor=extend_factor)
    jd_ejection = (mjd_obs + 2400000.5) - age

    want_east = side == "east"
    beta_east = beta_west = float(params["beta"])
    if beta_override is not None:
        if want_east:
            beta_east = float(beta_override)
        else:
            beta_west = float(beta_override)

    east_ra, east_dec, west_ra, west_dec, _, _ = ss433_phases(
        jd_ejection, params, beta_east=beta_east, beta_west=beta_west
    )
    mu_ra, mu_dec = (east_ra, east_dec) if want_east else (west_ra, west_dec)

    # NOTE(review): assumes mu_ra / mu_dec are velocity components in units of c,
    # so |mu| * c * age is the distance travelled, converted to an angle at the
    # source distance. Confirm against ss433_phases.
    rad = (np.sqrt(mu_ra**2 + mu_dec**2) * config.C_PC_PER_DAY * age / config.D_SS433_PC) * config.ARCSEC_PER_RADIAN
    pa = np.arctan2(mu_ra, mu_dec)  # radians, measured from +dec toward +ra

    return rad * np.sin(pa), rad * np.cos(pa), age


def _colored_line(ax, x, y, age, cmap, norm, lw=2.2, alpha=0.95):
    """
    Add the polyline through (x, y) to ``ax`` as a LineCollection coloured by ``age``.

    Segment i joins point i to point i+1 and is coloured by ``age[i]`` through
    ``cmap``/``norm``. Returns the LineCollection that was added.
    """
    points = np.column_stack([x, y])                             # (N, 2)
    segments = np.stack([points[:-1], points[1:]], axis=1)       # (N-1, 2, 2)
    collection = LineCollection(segments, cmap=cmap, norm=norm, linewidths=lw, alpha=alpha)
    collection.set_array(age[:-1])
    ax.add_collection(collection)
    return collection


# UI controls (mirrors your polar UI)
obs_dropdown_xy = widgets.Dropdown(options=obs_ids, description="Obs ID:", layout=widgets.Layout(width="260px"))
lock_betas_xy = widgets.Checkbox(value=True, description="Lock βe=βw", indent=False)

beta_e_xy = widgets.FloatSlider(
    value=float(config.EPHEMERIS["beta"]),
    min=0.15, max=0.40, step=0.0005,
    description="β east:", readout_format=".4f",
    continuous_update=False, layout=widgets.Layout(width="520px")
)
beta_w_xy = widgets.FloatSlider(
    value=float(config.EPHEMERIS["beta"]),
    min=0.15, max=0.40, step=0.0005,
    description="β west:", readout_format=".4f",
    continuous_update=False, layout=widgets.Layout(width="520px")
)

# Model age range, in precession periods (read by _plot_xy via _update_xy).
extend_factor = widgets.FloatSlider(
    value=1.30, min=1.0, max=2.5, step=0.05,
    description="Extend ×P:", readout_format=".2f",
    continuous_update=False, layout=widgets.Layout(width="520px")
)

show_ellipses = widgets.Checkbox(value=True, description="Show χ²=1 ellipses", indent=False)
show_centers  = widgets.Checkbox(value=True, description="Show blob centers", indent=False)
mirror_x      = widgets.Checkbox(value=True, description="Mirror x (East-left)", indent=False)

out_xy = widgets.Output()


def _sync_betas_xy(change):
    """While the lock is on, copy a β-slider change onto the other slider."""
    if lock_betas_xy.value:
        if change["owner"] is beta_e_xy:
            beta_w_xy.value = beta_e_xy.value
        elif change["owner"] is beta_w_xy:
            beta_e_xy.value = beta_w_xy.value


def _sync_betas_xy_on_lock(change):
    """When the lock is switched on, snap β west to β east.

    Without this the sliders stay unequal until one of them moves, even though
    the locked plot uses βe for both jets.
    """
    if change["new"]:
        beta_w_xy.value = beta_e_xy.value


beta_e_xy.observe(_sync_betas_xy, names="value")
beta_w_xy.observe(_sync_betas_xy, names="value")
lock_betas_xy.observe(_sync_betas_xy_on_lock, names="value")


def _plot_xy(obs_id, be, bw, extf):
    """
    Draw the Cartesian (x, y) view of one observation into ``out_xy``.

    Uses x = r sin(PA), y = r cos(PA) in arcsec, so +x points east; ``mirror_x``
    inverts the x axis to put east on the left. Model jets are coloured by age
    since ejection. Blobs are drawn as markers and, optionally, as axis-aligned
    1-sigma ("χ²=1") ellipses. Axis limits cover the blobs (±3σ) and the
    extended model curves.

    Parameters
    ----------
    obs_id : str
        Key into ``obs_data_map`` / ``obs_mjd_map``.
    be, bw : float
        East / west β. While ``lock_betas_xy`` is ticked, ``be`` drives both jets
        and the title reports it for both.
    extf : float
        Model age range as a multiple of ``config.EPHEMERIS["precession_period"]``.
    """
    blobs = obs_data_map.get(obs_id, [])
    mjd_obs = obs_mjd_map.get(obs_id, np.nan)

    with out_xy:
        clear_output(wait=True)
        plt.close("all")

        if not blobs or np.isnan(mjd_obs):
            print("Missing blobs or MJD.")
            return

        # (x, y, var_x, var_y) for every non-background component, used for limits.
        xy = [_calc_xy_and_sigmas(b) for b in blobs if str(b["comp"]).lower() != "bkg"]
        if not xy:
            print("No non-bkg blobs to plot.")
            return

        # Setup colormap for age
        P = float(config.EPHEMERIS["precession_period"])
        age_norm = mcolors.Normalize(vmin=0.0, vmax=float(extf) * P)
        cmap = plt.cm.rainbow

        # Locked: βe drives both jets and is also written into the params' 'beta'.
        # Unlocked: keep the ephemeris 'beta' and override each side separately.
        p = config.EPHEMERIS.copy()
        if lock_betas_xy.value:
            p["beta"] = float(be)
            bw = be  # the β actually used for the west jet; keeps the title consistent

        fig, ax = plt.subplots(figsize=(8.7, 8.7))
        ax.set_aspect("equal", adjustable="box")

        # Model curves (extended)
        xe, ye, ae = _jet_curve_xy(mjd_obs, p, "east", beta_override=float(be), extend_factor=float(extf))
        xw, yw, aw = _jet_curve_xy(mjd_obs, p, "west", beta_override=float(bw), extend_factor=float(extf))
        _colored_line(ax, xe, ye, ae, cmap, age_norm)
        _colored_line(ax, xw, yw, aw, cmap, age_norm)

        cb = fig.colorbar(plt.cm.ScalarMappable(norm=age_norm, cmap=cmap), ax=ax, fraction=0.046, pad=0.04)
        cb.set_label("Age since ejection (days)")

        # Plot blobs + χ²=1 ellipses
        for b in blobs:
            name = str(b["comp"])
            name_l = name.lower()
            if name_l == "bkg":
                continue

            if "west" in name_l:
                c = "red"
            elif "east" in name_l:
                c = "blue"
            else:
                c = "green"

            x0, y0, sx2, sy2 = _calc_xy_and_sigmas(b)

            if show_ellipses.value and name_l not in ["core", "c1"]:
                ell = Ellipse(
                    (x0, y0),
                    width=2.0 * np.sqrt(sx2),
                    height=2.0 * np.sqrt(sy2),
                    angle=0.0,  # no x/y covariance is propagated, so axis-aligned
                    facecolor="none",
                    edgecolor=c,
                    linewidth=2.0,
                    alpha=0.75,
                    zorder=9,
                )
                ax.add_patch(ell)

            if show_centers.value:
                ax.plot(x0, y0, "o", color=c, markersize=8, markeredgecolor="black", zorder=10)

            if name_l not in ["core", "c1"]:
                ax.text(x0, y0, " " + name, color=c, fontsize=9, fontweight="bold", zorder=11)

        # Limits: include blobs (with 3σ) AND the model curves. Without the curves
        # the "Extend ×P" slider mostly extends the jets off-screen.
        lim_blob = max(
            max(abs(x) + 3 * np.sqrt(sx2) for x, _, sx2, _ in xy),
            max(abs(y) + 3 * np.sqrt(sy2) for _, y, _, sy2 in xy),
            0.5,
        )
        curve_abs = np.abs(np.concatenate([np.ravel(a) for a in (xe, ye, xw, yw)]).astype(float))
        curve_abs = curve_abs[np.isfinite(curve_abs)]
        lim_curve = 1.05 * float(curve_abs.max()) if curve_abs.size else 0.0
        lim = max(lim_blob, lim_curve)
        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)

        ax.axhline(0, linewidth=1.0, alpha=0.25)
        ax.axvline(0, linewidth=1.0, alpha=0.25)

        if mirror_x.value:
            ax.invert_xaxis()  # East-left (matches your polar plot convention)

        ax.set_xlabel('x (arcsec)  [x = r sin(PA)]')
        ax.set_ylabel('y (arcsec)  [y = r cos(PA)]')
        ax.set_title(
            f"Cartesian Jets | Obs {obs_id} | MJD {mjd_obs:.2f}\n"
            f"βe={be:.4f}  βw={bw:.4f}  Extend={extf:.2f}×P  MirrorX={mirror_x.value}",
            pad=14
        )

        plt.show()


def _update_xy(*_):
    """Redraw the Cartesian plot from the current widget values; observer args are ignored."""
    selected_obs = obs_dropdown_xy.value
    east_beta, west_beta = beta_e_xy.value, beta_w_xy.value
    _plot_xy(selected_obs, east_beta, west_beta, extend_factor.value)

# Every control below triggers a redraw (registered in this order).
for _control_xy in (
    obs_dropdown_xy,
    beta_e_xy,
    beta_w_xy,
    lock_betas_xy,
    extend_factor,
    show_ellipses,
    show_centers,
    mirror_x,
):
    _control_xy.observe(_update_xy, names="value")

_toolbar_xy = widgets.HBox([obs_dropdown_xy, lock_betas_xy, mirror_x, show_ellipses, show_centers])
ui_xy = widgets.VBox([_toolbar_xy, beta_e_xy, beta_w_xy, extend_factor, out_xy])

display(ui_xy)
_update_xy()