# Step 12: V09 — Difference image localization (per sector)

Goal: compute a difference image (out-of-transit minus in-transit) and identify where the strongest transit-like signal appears on the detector.

Why this matters:
- If the difference-image signal peaks away from the target location, the transit may originate from a nearby contaminant.

Notes:
- This check requires a per-sector target pixel file (TPF); sectors without one are reported as skipped with the `MISSING_TPF` flag.
- Difference images can be fragile for shallow signals or challenging apertures; interpret alongside V08 (centroid shift) and V10 (aperture dependence).


In [None]:
from __future__ import annotations

from pathlib import Path
import json
import sys

import numpy as np

# Make the tutorial's shared helper module (toi5807_shared) importable.
# The path is relative to the current working directory, so this cell assumes
# it is run from the repository root (NOTE(review): confirm for your setup).
# The sys.path insert must happen before the local import below.
tutorial_dir = Path('docs/tutorials/tutorial_toi-5807-incremental').resolve()
sys.path.insert(0, str(tutorial_dir))

import toi5807_shared as sh
import tess_vetter.api as btv

# Load the tutorial dataset; later code reads its per-sector light curves
# (ds.lc_by_sector) and any bundled per-sector TPFs (ds.tpf_by_sector).
ds = sh.load_dataset()

# Build the transit candidate from a depth estimate made on the stitched PDCSAP
# light curve. The second value returned by estimate_depth_ppm is unused here.
stitched = sh.stitch_pdcsap(ds)
depth_ppm, _ = sh.estimate_depth_ppm(stitched)
candidate = sh.make_candidate(depth_ppm)

# -----------------------------------------------------------------------------
# Load per-sector TPF stamps (download missing sectors once, then cache on disk)
# -----------------------------------------------------------------------------
# TPFs bundled with the dataset are used as-is. Any light-curve sector without
# one is read from the on-disk .npz cache or, failing that, downloaded with
# lightkurve (120 s cadence) and written to the cache for the next run.
tpf_by_sector: dict[int, btv.TPFStamp] = {int(k): v for k, v in ds.tpf_by_sector.items()}
cache_dir = Path('persistent_cache/tutorial_toi-5807-incremental/tpfs')
cache_dir.mkdir(parents=True, exist_ok=True)

sectors = sorted(int(s) for s in ds.lc_by_sector.keys())
missing = [s for s in sectors if s not in tpf_by_sector]


def _optional_array(npz, key, dtype):
    """Return ``npz[key]`` as an array of ``dtype``, or None if absent or saved as None."""
    if key not in npz:
        return None
    arr = npz[key]
    # np.savez stores a Python None as a 0-d object array; map it back to None
    # instead of letting np.asarray turn it into array(nan) / array(False).
    if arr.dtype == object and arr.shape == () and arr.item() is None:
        return None
    return np.asarray(arr, dtype=dtype)


def _wcs_from_header(header):
    """Rebuild a WCS from a cached header dict; return None when no header was cached."""
    if header is None:
        # Do not call WCS(None): astropy would silently build a blank default WCS.
        return None
    try:
        from astropy.wcs import WCS

        return WCS(header)
    except Exception:
        # Best-effort fallback kept from the original cell: pass the raw header
        # dict through. NOTE(review): confirm TPFStamp tolerates a dict here.
        return header


def _load_cached_stamp(npz_path):
    """Read a TPFStamp from an .npz file written by `_save_cached_stamp`."""
    # allow_pickle is needed for the wcs_header dict; the file is this
    # notebook's own cache, not untrusted input.
    # `with` closes the underlying zip file handle once the arrays are read.
    with np.load(npz_path, allow_pickle=True) as d:
        header = d['wcs_header'].item() if 'wcs_header' in d else None
        return btv.TPFStamp(
            time=np.asarray(d['time'], dtype=np.float64),
            flux=np.asarray(d['flux'], dtype=np.float64),
            flux_err=_optional_array(d, 'flux_err', np.float64),
            wcs=_wcs_from_header(header),
            aperture_mask=_optional_array(d, 'aperture_mask', bool),
            quality=_optional_array(d, 'quality', np.int32),
        )


def _download_stamp(sector):
    """Download the 120 s TPF for ``sector``; return ``(stamp, tpf)`` or ``(None, None)``."""
    import lightkurve as lk  # imported lazily: not needed when every sector is cached

    search = lk.search_targetpixelfile(f'TIC {sh.TIC_ID}', sector=int(sector), exptime=120)
    tpf = search.download() if len(search) else None
    if tpf is None:
        return None, None
    stamp = btv.TPFStamp(
        time=np.asarray(tpf.time.value, dtype=np.float64),
        flux=np.asarray(tpf.flux.value, dtype=np.float64),
        flux_err=np.asarray(tpf.flux_err.value, dtype=np.float64),
        wcs=tpf.wcs,
        aperture_mask=np.asarray(tpf.pipeline_mask, dtype=bool),
        quality=np.asarray(tpf.quality, dtype=np.int32),
    )
    return stamp, tpf


def _save_cached_stamp(npz_path, stamp, tpf):
    """Write ``stamp`` to ``npz_path``. Best-effort: failures are reported, not raised."""
    try:
        wcs_header = dict(tpf.wcs.to_header()) if getattr(tpf, 'wcs', None) is not None else None
        np.savez(
            npz_path,
            time=stamp.time,
            flux=stamp.flux,
            flux_err=stamp.flux_err,
            wcs_header=wcs_header,
            aperture_mask=stamp.aperture_mask,
            quality=stamp.quality,
        )
    except Exception as exc:
        # Drop any partially written file so the next run re-downloads instead
        # of tripping over a corrupt cache entry.
        try:
            npz_path.unlink(missing_ok=True)
        except OSError:
            pass
        print(f'Warning: could not cache TPF to {npz_path}: {exc}')


for sector in missing:
    npz = cache_dir / f'sector{sector}_tpf.npz'
    if npz.exists():
        try:
            tpf_by_sector[sector] = _load_cached_stamp(npz)
            continue
        except Exception as exc:  # unreadable cache: fall through to a fresh download
            print(f'Warning: ignoring unreadable TPF cache {npz}: {exc}')

    stamp, tpf = _download_stamp(sector)
    if stamp is None:
        print(f'No TPF available for sector {sector}')
        continue
    tpf_by_sector[sector] = stamp
    _save_cached_stamp(npz, stamp, tpf)

print(json.dumps({'tpf_sectors_loaded': sorted(tpf_by_sector.keys()), 'tpf_cache_dir': str(cache_dir)}, indent=2, sort_keys=True))

# -----------------------------------------------------------------------------
# Run V09 per sector
# -----------------------------------------------------------------------------
# Sector keys are normalized to int (matching how tpf_by_sector and sectors are
# keyed), so the lookups below work whether the dataset stores them as int or str.
lc_by_sector = {int(k): v for k, v in ds.lc_by_sector.items()}
# Subset of V09 metrics reported in the per-sector summary.
v09_metric_keys = ('max_depth_ppm', 'target_offset_pixels', 'target_offset_arcsec', 'significance_sigma')

by_sector = {}
for sector in sectors:
    tpf = tpf_by_sector.get(sector)
    if tpf is None:
        # V09 needs pixel data; record the sector as skipped rather than failing.
        by_sector[str(sector)] = {'status': 'skipped', 'flags': ['MISSING_TPF'], 'metrics': {}}
        continue

    lc0 = lc_by_sector[sector]
    # Fall back to an all-zero quality array (no flags set) when the light curve has none.
    q = np.asarray(
        lc0.quality if getattr(lc0, 'quality', None) is not None else np.zeros(len(lc0.time)),
        dtype=np.int32,
    )
    lc_sec = btv.LightCurve(time=lc0.time, flux=lc0.flux, flux_err=lc0.flux_err, quality=q)

    session = btv.VettingSession.from_api(
        lc=lc_sec,
        candidate=candidate,
        stellar=sh.STELLAR,
        tpf=tpf,
        network=False,
        tic_id=sh.TIC_ID,
    )
    r = session.run('V09')
    m = dict(r.metrics)
    # String keys keep the summary JSON-friendly.
    by_sector[str(sector)] = {
        'status': r.status,
        'flags': r.flags,
        'metrics': {key: m.get(key) for key in v09_metric_keys},
    }

print(json.dumps(by_sector, indent=2, sort_keys=True))


<details>
<summary><b>Expected Output (per-sector summary)</b></summary>

```text
{
  "55": {
    "flags": [
      "DIFFIMG_MAX_AT_EDGE",
      "DIFFIMG_UNRELIABLE"
    ],
    "metrics": {
      "max_depth_ppm": 3721467.367126439,
      "significance_sigma": null,
      "target_offset_arcsec": null,
      "target_offset_pixels": null
    },
    "status": "ok"
  },
  "75": {
    "flags": [
      "DIFFIMG_MAX_AT_EDGE",
      "DIFFIMG_TARGET_DEPTH_NONPOSITIVE",
      "DIFFIMG_UNRELIABLE"
    ],
    "metrics": {
      "max_depth_ppm": 557510.2548859358,
      "significance_sigma": null,
      "target_offset_arcsec": null,
      "target_offset_pixels": null
    },
    "status": "ok"
  },
  "82": {
    "flags": [
      "DIFFIMG_MAX_AT_EDGE",
      "DIFFIMG_TARGET_DEPTH_NONPOSITIVE",
      "DIFFIMG_UNRELIABLE"
    ],
    "metrics": {
      "max_depth_ppm": 546367.7780012799,
      "significance_sigma": null,
      "target_offset_arcsec": null,
      "target_offset_pixels": null
    },
    "status": "ok"
  },
  "83": {
    "flags": [
      "DIFFIMG_MAX_AT_EDGE",
      "DIFFIMG_TARGET_DEPTH_NONPOSITIVE",
      "DIFFIMG_UNRELIABLE"
    ],
    "metrics": {
      "max_depth_ppm": 224066.93494969374,
      "significance_sigma": null,
      "target_offset_arcsec": null,
      "target_offset_pixels": null
    },
    "status": "ok"
  }
}
```

</details>


In [None]:
# Plot V09: Difference image (single sector example)
# Re-runs V09 on the most recent sector that has both a light curve and a TPF,
# then saves the rendered difference image to the run artifact directory (and,
# when configured, the docs artifact directory).
out = {}

try:
    import matplotlib.pyplot as plt
    from tess_vetter.api import plot_difference_image
    PLOTTING_AVAILABLE = True
except Exception as e:
    # Plotting is optional: record why it is unavailable and keep going.
    PLOTTING_AVAILABLE = False
    out['plotting_error'] = str(e)

if PLOTTING_AVAILABLE:
    # Normalize sector keys to int (tpf_by_sector is int-keyed), then pick the
    # latest sector that actually has a TPF instead of failing when only the
    # very newest sector is missing one.
    lc_by_sector = {int(k): v for k, v in ds.lc_by_sector.items()}
    plottable = sorted(set(lc_by_sector) & set(tpf_by_sector))
    if not plottable:
        raise RuntimeError('Missing TPF for every light-curve sector')
    sector = plottable[-1]

    lc0 = lc_by_sector[sector]
    tpf = tpf_by_sector[sector]
    # Fall back to an all-zero quality array (no flags set) when the light curve has none.
    q = np.asarray(
        lc0.quality if getattr(lc0, 'quality', None) is not None else np.zeros(len(lc0.time)),
        dtype=np.int32,
    )
    lc_sec = btv.LightCurve(time=lc0.time, flux=lc0.flux, flux_err=lc0.flux_err, quality=q)

    session = btv.VettingSession.from_api(
        lc=lc_sec,
        candidate=candidate,
        stellar=sh.STELLAR,
        tpf=tpf,
        network=False,
        tic_id=sh.TIC_ID,
    )
    r = session.run('V09')

    run_out_dir, docs_out_dir = sh.artifact_dirs(step_id='12_v09_difference_image')
    run_path = run_out_dir / 'V09_difference_image.png'
    docs_path = (docs_out_dir / 'V09_difference_image.png') if docs_out_dir is not None else None

    fig, ax = plt.subplots(figsize=(7, 5))
    plot_difference_image(r, ax=ax)
    ax.set_title(f'V09: Difference image (sector {sector})')
    fig.tight_layout()
    fig.savefig(run_path, dpi=150, bbox_inches='tight')
    if docs_path is not None:
        fig.savefig(docs_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure even on backends whose show() does not (e.g. Agg in batch runs).
    plt.close(fig)

    out['sector_plotted'] = int(sector)
    out['flags'] = r.flags
    out['run_plot_path'] = str(run_path)
    out['docs_plot_path'] = str(docs_path) if docs_path is not None else None

print(json.dumps(out, indent=2, sort_keys=True))


**Pre-rendered plot (no execution required):** `../artifacts/tutorial_toi-5807-incremental/12_v09_difference_image/V09_difference_image.png`

![V09: Difference image](../artifacts/tutorial_toi-5807-incremental/12_v09_difference_image/V09_difference_image.png)


<details>
<summary><b>Expected Output (plot cell)</b></summary>

```text
{
  "docs_plot_path": "docs/tutorials/artifacts/tutorial_toi-5807-incremental/12_v09_difference_image/V09_difference_image.png",
  "flags": [
    "DIFFIMG_MAX_AT_EDGE",
    "DIFFIMG_TARGET_DEPTH_NONPOSITIVE",
    "DIFFIMG_UNRELIABLE"
  ],
  "run_plot_path": "persistent_cache/tutorial_toi-5807-incremental/12_v09_difference_image/V09_difference_image.png",
  "sector_plotted": 83
}
```

</details>


<details>
<summary><b>Analysis</b></summary>

- **Flags:** DIFFIMG_UNRELIABLE (+ related edge/target-depth flags).
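- **Result:** every sector (55, 75, 82, 83) is flagged `DIFFIMG_UNRELIABLE`. In each one, the difference-image maximum lies at the stamp edge (`DIFFIMG_MAX_AT_EDGE`). In sectors 75, 82 and 83, the depth at the target pixel is also non-positive (`DIFFIMG_TARGET_DEPTH_NONPOSITIVE`). No target offset or significance is reported for any sector.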
- **Why it’s useful:** it tells us not to over-interpret the difference image for this shallow signal/TPF setup.
- **Interpretation:** treat V09 as inconclusive here; rely more heavily on V08 (centroid shift) and V10 (aperture dependence).
- **Next step:** V10 (aperture dependence).

</details>
