# Step 11: V08 — Centroid shift (per sector)

Goal: measure whether the flux-weighted centroid position during transit shifts relative to its out-of-transit position.

Why this matters:
- A significant centroid shift can indicate the transit signal originates from a nearby source rather than the target.

Notes:
- This check requires a target pixel file (TPF) for each sector.
- The tutorial dataset includes a TPF for sector 83; other sectors are downloaded and cached to `persistent_cache/tutorial_toi-5807-incremental/tpfs/`.


In [None]:
from __future__ import annotations

from pathlib import Path
import json
import sys

import numpy as np

tutorial_dir = Path('docs/tutorials/tutorial_toi-5807-incremental').resolve()
sys.path.insert(0, str(tutorial_dir))

import toi5807_shared as sh
import bittr_tess_vetter.api as btv

# Load the tutorial dataset through the shared helper module; `ds` provides the
# per-sector light curves (`lc_by_sector`) and TPF stamps (`tpf_by_sector`) used below.
ds = sh.load_dataset()

# Build the candidate that every V08 run in this notebook step is given.
# Presumably: stitch the per-sector PDCSAP light curves, then estimate the transit
# depth in ppm from the stitched curve (helper semantics live in toi5807_shared — confirm there).
stitched = sh.stitch_pdcsap(ds)
# The second value returned by estimate_depth_ppm is intentionally discarded.
depth_ppm, _ = sh.estimate_depth_ppm(stitched)
candidate = sh.make_candidate(depth_ppm)

# -----------------------------------------------------------------------------
# Load per-sector TPF stamps (download missing sectors once, then cache on disk)
# -----------------------------------------------------------------------------
# Start from the stamps bundled with the dataset, keyed by integer sector number.
tpf_by_sector: dict[int, btv.TPFStamp] = {int(k): v for k, v in ds.tpf_by_sector.items()}
cache_dir = Path('persistent_cache/tutorial_toi-5807-incremental/tpfs')
cache_dir.mkdir(parents=True, exist_ok=True)

sectors = sorted(int(s) for s in ds.lc_by_sector.keys())
missing = [s for s in sectors if s not in tpf_by_sector]

for sector in missing:
    npz = cache_dir / f'sector{sector}_tpf.npz'

    # 1) Prefer the on-disk cache written by a previous run.
    if npz.exists():
        try:
            # NOTE(security): allow_pickle=True is required because the WCS header is stored
            # as a pickled object array. Only load cache files this notebook wrote itself.
            # The context manager closes the underlying zip file once the arrays are read.
            with np.load(npz, allow_pickle=True) as d:
                wcs_header = d['wcs_header'].item() if 'wcs_header' in d else None
                if wcs_header is None:
                    # Older caches may hold an explicit None header. WCS(None) would build a
                    # default WCS, so keep "no WCS" as None instead.
                    wcs = None
                else:
                    try:
                        from astropy.wcs import WCS

                        wcs = WCS(wcs_header)
                    except Exception:
                        # astropy missing or header not parseable: hand over the raw header.
                        wcs = wcs_header

                cached_stamp = btv.TPFStamp(
                    time=np.asarray(d['time'], dtype=np.float64),
                    flux=np.asarray(d['flux'], dtype=np.float64),
                    flux_err=np.asarray(d['flux_err'], dtype=np.float64) if 'flux_err' in d else None,
                    wcs=wcs,
                    aperture_mask=np.asarray(d['aperture_mask'], dtype=bool) if 'aperture_mask' in d else None,
                    quality=np.asarray(d['quality'], dtype=np.int32) if 'quality' in d else None,
                )
        except Exception as exc:
            # A truncated or otherwise unreadable cache file must not break the tutorial:
            # report it and fall through to a fresh download, which rewrites the cache.
            print(f'Ignoring unreadable TPF cache {npz} ({exc}); re-downloading sector {sector}')
        else:
            tpf_by_sector[sector] = cached_stamp
            continue

    # 2) Not cached (or cache unusable): download the 120 s cadence TPF.
    # Imported lazily so runs served entirely from the cache do not need lightkurve.
    import lightkurve as lk

    search = lk.search_targetpixelfile(f'TIC {sh.TIC_ID}', sector=int(sector), exptime=120)
    tpf = search.download() if len(search) else None
    if tpf is None:
        print(f'No TPF available for sector {sector}')
        continue

    stamp = btv.TPFStamp(
        time=np.asarray(tpf.time.value, dtype=np.float64),
        flux=np.asarray(tpf.flux.value, dtype=np.float64),
        flux_err=np.asarray(tpf.flux_err.value, dtype=np.float64),
        wcs=tpf.wcs,
        aperture_mask=np.asarray(tpf.pipeline_mask, dtype=bool),
        quality=np.asarray(tpf.quality, dtype=np.int32),
    )
    tpf_by_sector[sector] = stamp

    # 3) Cache to disk (best-effort: a failed write only costs a re-download next time).
    try:
        payload = {
            'time': stamp.time,
            'flux': stamp.flux,
            'flux_err': stamp.flux_err,
            'aperture_mask': stamp.aperture_mask,
            'quality': stamp.quality,
        }
        if getattr(tpf, 'wcs', None) is not None:
            # Only store a header when one exists, so the reader never sees a pickled None.
            payload['wcs_header'] = dict(tpf.wcs.to_header())
        np.savez(npz, **payload)
    except Exception as exc:
        print(f'Could not cache TPF for sector {sector} to {npz}: {exc}')

print(json.dumps({'tpf_sectors_loaded': sorted(tpf_by_sector.keys()), 'tpf_cache_dir': str(cache_dir)}, indent=2, sort_keys=True))

# -----------------------------------------------------------------------------
# Run V08 per sector
# -----------------------------------------------------------------------------
# Metrics copied from each V08 result into the summary; a metric the result does
# not provide is reported as None.
v08_metric_keys = (
    'centroid_shift_pixels',
    'shift_uncertainty_pixels',
    'significance_sigma',
    'centroid_shift_arcsec',
    'n_in_transit_cadences',
    'n_out_of_transit_cadences',
)

by_sector = {}
for sector in sectors:
    sector_lc = ds.lc_by_sector[sector]

    # Use the light curve's own quality flags when present, otherwise all-zero flags.
    quality_flags = getattr(sector_lc, 'quality', None)
    if quality_flags is None:
        quality_flags = np.zeros(len(sector_lc.time))
    quality_i32 = np.asarray(quality_flags, dtype=np.int32)

    lc_sec = btv.LightCurve(
        time=sector_lc.time,
        flux=sector_lc.flux,
        flux_err=sector_lc.flux_err,
        quality=quality_i32,
    )

    # Without a pixel stamp the check cannot run; record the sector as skipped.
    stamp = tpf_by_sector.get(sector)
    if stamp is None:
        by_sector[str(sector)] = {'status': 'skipped', 'flags': ['MISSING_TPF'], 'metrics': {}}
        continue

    v08_session = btv.VettingSession.from_api(
        lc=lc_sec,
        candidate=candidate,
        stellar=sh.STELLAR,
        tpf=stamp,
        network=False,
        tic_id=sh.TIC_ID,
    )
    result = v08_session.run('V08')
    metrics = dict(result.metrics)

    by_sector[str(sector)] = {
        'status': result.status,
        'flags': result.flags,
        'metrics': {key: metrics.get(key) for key in v08_metric_keys},
    }

print(json.dumps(by_sector, indent=2, sort_keys=True))


<details>
<summary><b>Expected Output (per-sector summary)</b></summary>

```text
{
  "55": {
    "flags": [],
    "metrics": {
      "centroid_shift_arcsec": 0.07941180258977476,
      "centroid_shift_pixels": 0.0037815144090368935,
      "n_in_transit_cadences": 243,
      "n_out_of_transit_cadences": 18391,
      "shift_uncertainty_pixels": 0.008188661269528373,
      "significance_sigma": 0.4617988563171682
    },
    "status": "ok"
  },
  "75": {
    "flags": [],
    "metrics": {
      "centroid_shift_arcsec": 0.303498360843917,
      "centroid_shift_pixels": 0.014452302897329383,
      "n_in_transit_cadences": 242,
      "n_out_of_transit_cadences": 18984,
      "shift_uncertainty_pixels": 0.0076705035355055596,
      "significance_sigma": 1.8841400477076813
    },
    "status": "ok"
  },
  "82": {
    "flags": [],
    "metrics": {
      "centroid_shift_arcsec": 0.09029882097528287,
      "centroid_shift_pixels": 0.004299943855965851,
      "n_in_transit_cadences": 243,
      "n_out_of_transit_cadences": 17648,
      "shift_uncertainty_pixels": 0.0027112940160366667,
      "significance_sigma": 1.5859378697156021
    },
    "status": "ok"
  },
  "83": {
    "flags": [],
    "metrics": {
      "centroid_shift_arcsec": 0.35706258426488147,
      "centroid_shift_pixels": 0.017002980203089595,
      "n_in_transit_cadences": 244,
      "n_out_of_transit_cadences": 16841,
      "shift_uncertainty_pixels": 0.0052391155564065735,
      "significance_sigma": 3.245391329896848
    },
    "status": "ok"
  }
}
```

</details>


In [None]:
# Plot V08: Centroid shift (single sector example)
# `out` collects what this cell produced (or why plotting was skipped) and is printed as JSON.
out = {}

# Plotting is optional: if matplotlib or the plotting helper cannot be imported,
# record the reason and skip the plot instead of failing the cell.
try:
    import matplotlib.pyplot as plt
    from bittr_tess_vetter.api import plot_centroid_shift
    PLOTTING_AVAILABLE = True
except Exception as exc:
    PLOTTING_AVAILABLE = False
    out['plotting_error'] = str(exc)

if PLOTTING_AVAILABLE:
    # Plot the most recent sector that actually has a TPF. The per-sector run tolerates
    # sectors without a TPF (they are marked "skipped"), so this cell should not fail
    # merely because the latest sector's TPF could not be obtained.
    all_sectors = sorted(int(s) for s in ds.lc_by_sector.keys())
    sectors_with_tpf = [s for s in all_sectors if tpf_by_sector.get(s) is not None]
    if not sectors_with_tpf:
        raise RuntimeError(f'Missing TPF for every sector ({all_sectors}); cannot plot V08')

    sector = sectors_with_tpf[-1]
    if sector != all_sectors[-1]:
        print(f'Missing TPF for sector {all_sectors[-1]}; plotting sector {sector} instead')

    lc0 = ds.lc_by_sector[sector]
    # Use the light curve's own quality flags when present, otherwise all-zero flags.
    q = np.asarray(
        lc0.quality if getattr(lc0, 'quality', None) is not None else np.zeros(len(lc0.time)),
        dtype=np.int32,
    )
    lc_sec = btv.LightCurve(time=lc0.time, flux=lc0.flux, flux_err=lc0.flux_err, quality=q)
    tpf = tpf_by_sector[sector]

    session = btv.VettingSession.from_api(
        lc=lc_sec,
        candidate=candidate,
        stellar=sh.STELLAR,
        tpf=tpf,
        network=False,
        tic_id=sh.TIC_ID,
    )
    r = session.run('V08')

    # The figure is always written to the run directory and, when the helper
    # returns one, also to the docs directory.
    run_out_dir, docs_out_dir = sh.artifact_dirs(step_id='11_v08_centroid_shift')
    run_path = run_out_dir / 'V08_centroid_shift.png'
    docs_path = (docs_out_dir / 'V08_centroid_shift.png') if docs_out_dir is not None else None

    fig, ax = plt.subplots(figsize=(8, 5))
    plot_centroid_shift(r, ax=ax)
    ax.set_title(f'V08: Centroid shift (sector {sector})')
    fig.tight_layout()
    fig.savefig(run_path, dpi=150, bbox_inches='tight')
    if docs_path is not None:
        fig.savefig(docs_path, dpi=150, bbox_inches='tight')
    plt.show()

    out['sector_plotted'] = int(sector)
    out['run_plot_path'] = str(run_path)
    out['docs_plot_path'] = str(docs_path) if docs_path is not None else None

print(json.dumps(out, indent=2, sort_keys=True))


**Pre-rendered plot (no execution required):** `../artifacts/tutorial_toi-5807-incremental/11_v08_centroid_shift/V08_centroid_shift.png`

![V08: Centroid shift](../artifacts/tutorial_toi-5807-incremental/11_v08_centroid_shift/V08_centroid_shift.png)


<details>
<summary><b>Expected Output (plot cell)</b></summary>

```text
{
  "docs_plot_path": "docs/tutorials/artifacts/tutorial_toi-5807-incremental/11_v08_centroid_shift/V08_centroid_shift.png",
  "run_plot_path": "persistent_cache/tutorial_toi-5807-incremental/11_v08_centroid_shift/V08_centroid_shift.png",
  "sector_plotted": 83
}
```

</details>


<details>
<summary><b>Analysis</b></summary>

- **Flags:** none.
- **Per-sector summary:** shifts are small (all <0.02 px); the largest is sector 83 at ∼0.017 px with ∼3.25σ significance.
- **Why it’s useful:** a large, high-significance shift would strongly suggest an off-target source; we don’t see that.
- **Interpretation:** centroid evidence is mildly suggestive at most; interpret alongside V09 and V10.
- **Next step:** V09 (difference image localization).

</details>
