# Master Flat & Bad-Pixel Mask Creation

## Overview

This notebook builds the calibration files (flat-field gain maps + master bad-pixel mask) required by **remirpipe**.
It processes twilight-flat sequences acquired by REMIR across multiple nights.

**Note**: If you need to analyse a specific target over many months or years, split the data into 1–2 month time blocks. Create a separate calibration folder for each block (e.g. calibration_jan_feb/, calibration_mar_apr/, …) and use the corresponding calibration files when reducing each temporal subset — this accounts for the possible differences in the detector response over time.
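
One way to keep the blocks consistent is to derive the calibration folder name from the date of each night. The sketch below is only an illustration: `block_folder_for_night` is a hypothetical helper (not part of remirpipe), and it assumes that nights are labelled `YYYYMMDD`.

```python
from datetime import datetime

MONTHS = ["jan", "feb", "mar", "apr", "may", "jun",
          "jul", "aug", "sep", "oct", "nov", "dec"]

def block_folder_for_night(night, block_months=2):
    """Map a night label (YYYYMMDD) to a calibration folder name.

    Hypothetical helper: the naming scheme follows the example
    calibration_jan_feb/, calibration_mar_apr/, ...
    """
    month = datetime.strptime(night, "%Y%m%d").month        # 1..12
    first = ((month - 1) // block_months) * block_months    # 0-based first month of the block
    last = min(first + block_months - 1, 11)                # 0-based last month of the block
    if first == last:
        return f"calibration_{MONTHS[first]}"
    return f"calibration_{MONTHS[first]}_{MONTHS[last]}"

print(block_folder_for_night("20260101"))   # calibration_jan_feb
print(block_folder_for_night("20260315"))   # calibration_mar_apr
```

The folder name ignores the year, so a multi-year data set would need the year added to the name to keep blocks from different years apart.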

## Input Data

Download the REMIR flat files from the [REM archive](http://www.rem.inaf.it) for the period you need (typically **1–2 months**).
Unzip the main folder — you will end up with this structure:

```
folder/
├── 20260101/
│   ├── Flat_J_dith0.fits.gz
│   ├── Flat_J_dith72.fits.gz
│   ├── ...
│   └── Flat_K_dithN.fits.gz
├── 20260102/
│   ├── Flat_J_dith0.fits.gz
│   ├── Flat_J_dith72.fits.gz
│   ├── ...
│   └── Flat_K_dithN.fits.gz
└── ...
```

Each subfolder is one night.
Files are grouped automatically by **filter × dither-wedge position** using FITS header keywords.
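
The grouping rule can be illustrated without any FITS files. In the sketch below, plain dictionaries stand in for FITS headers and the values are invented for the example; `group_by_filter_and_dither` is a hypothetical name. As in the notebook code, the dither value is rounded to 0.1 so that small numerical differences fall into the same group, and frames without a dither keyword are skipped.

```python
from collections import defaultdict

def group_by_filter_and_dither(headers, dither_key="DITHANGL"):
    """Group header-like dicts by (filter, dither rounded to 0.1)."""
    groups = defaultdict(list)
    for name, hdr in headers.items():
        dither = hdr.get(dither_key)
        if dither is None:              # no dither keyword: skip the frame
            continue
        key = (hdr.get("FILTER", "UNKNOWN"), round(float(dither), 1))
        groups[key].append(name)
    return dict(groups)

# Plain dicts stand in for FITS headers (illustrative values only).
headers = {
    "a.fits": {"FILTER": "J", "DITHANGL": 0.0},
    "b.fits": {"FILTER": "J", "DITHANGL": 72.04},
    "c.fits": {"FILTER": "J", "DITHANGL": 71.96},
    "d.fits": {"FILTER": "K", "DITHANGL": 0.0},
    "e.fits": {"FILTER": "K"},
}
groups = group_by_filter_and_dither(headers)
for key, names in groups.items():
    print(key, names)
```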

## Algorithm

For each (filter, dither) group the notebook:

1. **Illumination** — computes the 3σ-clipped median of the central window of each frame (the middle 20 % along each axis, i.e. rows and columns from 40 % to 60 %).
2. **Sorts** frames by illumination (ascending).
3. **Trims** the lowest 10 % and highest 10 % — only the central 80 % enters the fit.
4. **Source masking** — normalises frames by illumination, builds a pixel-wise median reference, and flags pixels deviating > 4σ (MAD) as NaN (removes star trails).
5. **Per-pixel linear fit** — `pixel = gain × illumination + intercept` (NaN-aware OLS).
6. **Normalises** the gain map so that its median = 1.
7. **Flags bad pixels** — gain outside `[0.5, 2.0]`, non-finite, or fewer than 3 valid samples.
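
Steps 5 to 7 can be checked on a synthetic stack where the true gain is known. The sketch below is a toy illustration under simplifying assumptions (noise-free frames, no NaN samples, a 4 × 4 detector, invented values), not the notebook's implementation: it fits `pixel = gain × illumination + intercept` for all pixels at once with the closed-form least-squares formulas, normalises the gain to median 1, and flags pixels outside `[0.5, 2.0]`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic detector: 4x4 pixels with a known relative gain and a constant offset.
true_gain = rng.uniform(0.9, 1.1, size=(4, 4))
true_gain[0, 0] = 0.1                       # one deliberately "dead" pixel
offset = 50.0
illum = np.linspace(1000.0, 5000.0, 8)      # illumination of each frame (ADU)

# Stack of frames: pixel = gain * illumination + offset (noise-free for clarity).
stack = true_gain[None, :, :] * illum[:, None, None] + offset

# Closed-form least-squares slope and intercept, evaluated per pixel.
n = len(illum)
sx, sxx = illum.sum(), (illum**2).sum()
sy = stack.sum(axis=0)
sxy = (illum[:, None, None] * stack).sum(axis=0)
denom = n * sxx - sx**2
gain = (n * sxy - sx * sy) / denom
intercept = (sy * sxx - sx * sxy) / denom

# Normalise to median 1, then flag pixels outside [0.5, 2.0].
gain /= np.median(gain)
good = (gain >= 0.5) & (gain <= 2.0)

print("bad pixels:", int((~good).sum()))
```

Because the intercept is fitted separately, a constant offset in the frames does not bias the gain map.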

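The master bad-pixel mask is the **AND** of all individual masks: a pixel is good only if it is good in every group. The notebook then also forces a triangular region in the bottom-left corner (`x + y <= 110`) to bad; this geometric mask is optional and can be removed.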
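
As a minimal illustration of the AND combination, the sketch below merges three toy masks (invented 2 × 2 arrays, `1` = good, `0` = bad). A pixel that is bad in any one group ends up bad in the master mask.

```python
import numpy as np

# Three toy per-group masks (1 = good, 0 = bad), illustrative values only.
mask_j0  = np.array([[1, 1], [1, 0]], dtype=np.uint8)
mask_j72 = np.array([[1, 0], [1, 1]], dtype=np.uint8)
mask_k0  = np.array([[1, 1], [1, 1]], dtype=np.uint8)

# A pixel stays good only if every group marks it good.
master = np.bitwise_and.reduce([mask_j0, mask_j72, mask_k0])
print(master)
```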

## Output

All products are saved to `OUTPUT_FOLDER`:

| File | Description |
|------|-------------|
| `{FILTER}_dither{ANGLE}_flat.fits` | Normalised gain map (one per group) |
| `pixel_mask.fits` | Master bad-pixel mask (`0` = bad, `1` = good) |
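
To make the mask convention concrete, the sketch below shows how a gain map and a `0` = bad / `1` = good mask are commonly applied to a frame: divide by the gain and blank the bad pixels. The arrays are invented toy values, and this is an assumption about typical usage; remirpipe may apply the products differently.

```python
import numpy as np

raw  = np.array([[1000.0, 1100.0], [900.0, 50.0]])   # toy science frame (ADU)
gain = np.array([[1.0, 1.1], [0.9, 0.2]])            # normalised gain map
mask = np.array([[1, 1], [1, 0]], dtype=np.uint8)    # 1 = good, 0 = bad

# Divide by the gain where the pixel is good; set bad pixels to NaN.
corrected = np.where(mask == 1, raw / gain, np.nan)
print(corrected)
```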

## Connecting to remirpipe

Copy or move the output files into the pipeline's data folder, then point the configuration to it:

```yaml
# config.yaml
paths:
  data_folder: path/to/OUTPUT_FOLDER   # folder containing your new pixel_mask.fits and *_flat.fits
```
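
Before running the pipeline it can help to confirm that the folder really contains the expected products. The sketch below is a hypothetical check (`check_calibration_folder` is not part of remirpipe) based only on the file names listed in the Output table; the temporary directory with empty files is there just to make the example runnable.

```python
import tempfile
from pathlib import Path

def check_calibration_folder(folder):
    """Return (has_mask, flat_names) for a calibration folder."""
    folder = Path(folder)
    has_mask = (folder / "pixel_mask.fits").is_file()
    flat_names = sorted(p.name for p in folder.glob("*_flat.fits"))
    return has_mask, flat_names

# Demonstration on a temporary folder holding empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("pixel_mask.fits", "J_dither0_flat.fits", "K_dither72_flat.fits"):
        (Path(tmp) / name).touch()
    has_mask, flats = check_calibration_folder(tmp)

print(has_mask, flats)
```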
  
  
## Calibration of data before May 2025

Set `old = True` in the configuration cell for data taken before May 2025: older headers use `DITHID` instead of `DITHANGL`.
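
When `old = True`, the notebook converts the `DITHID` value to an angle with `DITHID × 72 − 72` before building the output file names. The sketch below reproduces that conversion; the function names are hypothetical, and the interpretation that `DITHID` is an index starting at 1 is an assumption inferred from the formula.

```python
def dithid_to_angle(dithid):
    """Convert an old-style DITHID value to a wedge angle in degrees."""
    return float(dithid) * 72 - 72

def flat_basename(filt, dither, old=False):
    """Build the flat file name: {FILTER}_dither{ANGLE}_flat.fits."""
    angle = dithid_to_angle(dither) if old else float(dither)
    return f"{filt}_dither{angle:.0f}_flat.fits"

print(flat_basename("J", 2, old=True))   # J_dither72_flat.fits
print(flat_basename("K", 144.0))         # K_dither144_flat.fits
```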


In [None]:
# ============================================================================
# Flat-Field Calibration Pipeline
# ============================================================================
#
# For each (filter, dither-wedge) group:
#   1. Compute illumination = median of 3σ-clipped central 20 % of image
#   2. Sort frames by illumination
#   3. Keep only the central 80 % (trim lowest 10 % + highest 10 %)
#   4. Mask stray sources (stars): normalise by illumination, median-stack
#      to get a source-free reference, σ-clip deviations → NaN
#   5. NaN-aware per-pixel linear fit:
#          pixel_value = gain × illumination + intercept
#   6. Normalise gain map so that median = 1
#   7. Flag bad pixels (gain outside thresholds, non-finite, few samples)
#
# Output: one flat (gain map) per (filter, dither), one master bad-pixel mask.

import numpy as np
import glob
import os
from astropy.io import fits
from astropy.stats import sigma_clip
import matplotlib.pyplot as plt
from collections import defaultdict

# ============================================================================
# CONFIGURATION
# ============================================================================

DATA_FOLDER   = "path/to/your/flat"
OUTPUT_FOLDER = "./calibration_folder/"

FILTER_KEY = "FILTER"

# Set to True for data taken before May 2025
# (old headers use DITHID instead of DITHANGL).
old = False

if old:
    DITHER_KEY = "DITHID"
else:
    DITHER_KEY = "DITHANGL"

# Algorithm parameters
MIN_FLATS    = 5      # minimum number of frames required after trimming
GAIN_MIN     = 0.5    # normalised gain below this → bad pixel
GAIN_MAX     = 2.0    # normalised gain above this → bad pixel
SOURCE_SIGMA = 4.0    # σ threshold for masking stars

In [None]:
def compute_illumination(data):
    """
    Illumination level for one frame:
    3σ-clipped median of the central window (middle 20 % along each axis,
    i.e. rows and columns from 40 % to 60 % of the frame).
    """
    ny, nx = data.shape
    y0, y1 = int(ny * 0.4), int(ny * 0.6)
    x0, x1 = int(nx * 0.4), int(nx * 0.6)
    central = data[y0:y1, x0:x1].ravel()
    central = central[np.isfinite(central)]
    clipped = sigma_clip(central, sigma=3, cenfunc="median", stdfunc="std")
    return float(np.median(clipped.data[~clipped.mask]))


def load_flat_files(data_folder):
    """
    Discover all .fits.gz / .fits files under *data_folder* (recursive).
    Group by (filter, dither).  Each entry stores the image array, its
    illumination level, and the FITS header.
    Groups are returned sorted by illumination (ascending).
    """
    pattern_gz   = os.path.join(data_folder, "**", "*.fits.gz")
    pattern_fits = os.path.join(data_folder, "**", "*.fits")
    files = sorted(
        glob.glob(pattern_gz, recursive=True)
        + glob.glob(pattern_fits, recursive=True)
    )
    print(f"Found {len(files)} FITS files")

    groups = defaultdict(list)

    for filepath in files:
        try:
            with fits.open(filepath) as hdul:
                header = hdul[0].header
                data   = hdul[0].data.astype(np.float64)

                filt   = header.get(FILTER_KEY, "UNKNOWN")
                dither = header.get(DITHER_KEY)
                if dither is None:
                    print(f"  Skipping (no {DITHER_KEY}): {os.path.basename(filepath)}")
                    continue
                dither = round(float(dither), 1)

                illum = compute_illumination(data)
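                # Reject non-positive and non-finite levels (e.g. all-NaN centre).
                if not np.isfinite(illum) or illum <= 0:
                    print(f"  Skipping (invalid illumination {illum}): {os.path.basename(filepath)}")
                    continue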

                groups[(filt, dither)].append({
                    "filepath":     filepath,
                    "data":         data,
                    "illumination": illum,
                    "header":       header,
                })
                print(f"  {os.path.basename(filepath)} | {filt} | dither={dither} | illum={illum:.1f}")

        except Exception as e:
            print(f"  Error loading {filepath}: {e}")

    # Sort each group by illumination
    for key in groups:
        groups[key] = sorted(groups[key], key=lambda f: f["illumination"])

    return groups

In [None]:
print("Loading flat files...")
flat_groups = load_flat_files(DATA_FOLDER)

print(f"\nFound {len(flat_groups)} (filter, dither) groups:")
for key, group in flat_groups.items():
    illums = [f["illumination"] for f in group]
    print(f"  {key}: {len(group)} frames, illum range [{min(illums):.0f} – {max(illums):.0f}]")

In [None]:
def mask_sources(data_stack, x, sigma_thresh=4.0):
    """
    Detect and mask stray sources (stars) in the flat-field stack.

    1. Normalise each frame by its illumination → all frames look like
       the same flat-field pattern.
    2. Pixel-wise median of the normalised stack → source-free reference.
    3. For each frame, compute deviation from that reference.
    4. Estimate per-pixel noise via the MAD across frames.
    5. Flag samples deviating > sigma_thresh × sigma_MAD → set to NaN.
    """
    norm_stack = data_stack / x[:, None, None]           # (N, ny, nx)
    ref = np.nanmedian(norm_stack, axis=0)               # (ny, nx)

    diff     = norm_stack - ref[None, :, :]              # (N, ny, nx)
    med_diff = np.nanmedian(diff, axis=0)                # (ny, nx)
    abs_dev  = np.abs(diff - med_diff[None, :, :])       # (N, ny, nx)
    mad      = np.nanmedian(abs_dev, axis=0)             # (ny, nx)
    sigma    = 1.4826 * mad                              # MAD → Gaussian σ

    outlier  = abs_dev > (sigma_thresh * sigma[None, :, :])
    n_masked = int(np.sum(outlier & np.isfinite(data_stack)))

    return np.where(outlier, np.nan, data_stack), n_masked


def create_calibration_maps(flat_group, source_sigma=SOURCE_SIGMA):
    """
    Build gain map, intercept map, and bad-pixel map.

    1. Frames are sorted by illumination.
    2. Trim lowest / highest 10 % → central 80 %.
    3. Mask stray sources with MAD-based sigma-clipping.
    4. NaN-aware OLS per pixel:  pixel = gain × illumination + intercept.
    5. Normalise gain to median = 1.
    6. Flag bad pixels.
    """
    n_total = len(flat_group)

    # --- Trim 10 % from each end (always at least one frame per end) ---
    lo = max(1, int(round(n_total * 0.10)))
    hi = n_total - lo
    trimmed = flat_group[lo:hi]
    n_used  = len(trimmed)

    if n_used < MIN_FLATS:
        print(f"  Warning: only {n_used} frames after trimming (need {MIN_FLATS}), skipping")
        return None, None, None, 0

    shape = trimmed[0]["data"].shape
    data_stack = np.array([f["data"] for f in trimmed], dtype=np.float64)
    x = np.array([f["illumination"] for f in trimmed], dtype=np.float64)

    print(f"  Using {n_used}/{n_total} frames (trimmed 10 % each end)")
    print(f"  Illumination range: {x.min():.1f} – {x.max():.1f} ADU")
    print(f"  Image size: {shape[0]} x {shape[1]}")

    # --- Mask stray sources (stars) ---
    data_stack, n_masked = mask_sources(data_stack, x, sigma_thresh=source_sigma)
    frac = 100 * n_masked / data_stack.size
    print(f"  Source masking: {n_masked} samples masked ({frac:.3f} %)")

    # --- NaN-aware OLS: y = slope * x + intercept ---
    x3      = x[:, None, None]                            # (N, 1, 1)
    valid   = np.isfinite(data_stack).astype(np.float64)   # (N, ny, nx)
    n_valid = valid.sum(axis=0)                            # (ny, nx)

    safe = np.where(np.isfinite(data_stack), data_stack, 0.0)

    sum_x  = np.sum(x3 * valid, axis=0)
    sum_y  = np.sum(safe, axis=0)
    sum_xx = np.sum(x3**2 * valid, axis=0)
    sum_xy = np.sum(x3 * safe, axis=0)

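    denom     = n_valid * sum_xx - sum_x**2
    # Pixels with denom == 0 are set to NaN by np.where; silence the division warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        slope     = np.where(denom > 0, (n_valid * sum_xy - sum_x * sum_y) / denom, np.nan)
        intercept = np.where(denom > 0, (sum_y * sum_xx - sum_x * sum_xy) / denom, np.nan)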

    # --- Normalise gain so median = 1 ---
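    median_gain = np.nanmedian(slope)
    if median_gain > 0:
        slope /= median_gain
        print(f"  Normalised gain (raw median slope = {median_gain:.4f})")
    else:
        print(f"  Warning: median slope = {median_gain:.4f} is not positive; gain map left unnormalised")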

    # --- Bad-pixel mask: 1 = good, 0 = bad ---
    badpix = np.ones(shape, dtype=np.uint8)
    badpix[slope < GAIN_MIN]    = 0
    badpix[slope > GAIN_MAX]    = 0
    badpix[~np.isfinite(slope)] = 0
    badpix[n_valid < 3]         = 0

    n_bad = int(np.sum(badpix == 0))
    print(f"  Bad pixels: {n_bad} ({100 * n_bad / badpix.size:.2f} %)")

    return slope, intercept, badpix, n_used

In [None]:
def plot_diagnostics(gain, intercept, badpix, filt, dither):
    """Diagnostic 2x2 panel for one (filter, dither) calibration."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    ax = axes[0, 0]
    im = ax.imshow(gain, origin="lower", cmap="viridis", vmin=0.8, vmax=1.2)
    ax.set_title("Gain Map (Flat Field)")
    plt.colorbar(im, ax=ax, label="Relative Gain")

    ax = axes[0, 1]
    valid = gain[np.isfinite(gain)].ravel()
    ax.hist(valid, bins=100, range=(0.5, 1.5), alpha=0.7)
    ax.axvline(1.0, color="r", linestyle="--", label="Unity")
    ax.axvline(GAIN_MIN, color="orange", linestyle=":", label=f"Bad < {GAIN_MIN}")
    ax.axvline(GAIN_MAX, color="orange", linestyle=":", label=f"Bad > {GAIN_MAX}")
    ax.set_xlabel("Gain"); ax.set_ylabel("N pixels")
    ax.set_title("Gain Distribution"); ax.legend()

    ax = axes[1, 0]
    vmin, vmax = np.nanpercentile(intercept, [5, 95])
    im = ax.imshow(intercept, origin="lower", cmap="RdBu_r", vmin=vmin, vmax=vmax)
    ax.set_title("Intercept Map (Thermal / Offset)")
    plt.colorbar(im, ax=ax, label="ADU")

    ax = axes[1, 1]
    ax.imshow(badpix, origin="lower", cmap="gray", vmin=0, vmax=1)
    n_bad = np.sum(badpix == 0)
    ax.set_title(f"Bad Pixel Map ({n_bad} bad, {100*n_bad/badpix.size:.2f} %)")

    if old:
        plt.suptitle(f"Diagnostics: {filt}  dither={dither*72-72:.0f}", fontsize=14)
    else:
        plt.suptitle(f"Diagnostics: {filt}  dither={dither:.0f}", fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
print("=" * 80)
print("Creating Calibration Maps")
print("=" * 80)
print(f"Gain thresholds: [{GAIN_MIN}, {GAIN_MAX}]")
print(f"Source masking sigma: {SOURCE_SIGMA}")
print()

os.makedirs(OUTPUT_FOLDER, exist_ok=True)

results     = {}
badpix_maps = []

for (filt, dither), group in flat_groups.items():
    print(f"\nProcessing: {filt}  dither={dither}  ({len(group)} frames)")

    gain, intercept, badpix, n_used = create_calibration_maps(group)

    if gain is not None:
        results[(filt, dither)] = {
            "gain":     gain,
            "n_frames": n_used,
            "header":   group[0]["header"],
        }
        badpix_maps.append(badpix)
        plot_diagnostics(gain, intercept, badpix, filt, dither)

print("\n" + "=" * 80)
print(f"Processed {len(results)} (filter, dither) groups.")

In [None]:
print("=" * 80)
print("Master Bad Pixel Mask")
print("=" * 80)

if not badpix_maps:
    print("ERROR: No bad-pixel maps to combine!")
else:
    master_badpix = badpix_maps[0].copy()
    for bpm in badpix_maps[1:]:
        master_badpix &= bpm
    
    # --- Apply geometric mask (triangular corner) ---
    # Optional: masks the whole bottom-left corner, which is full of bad pixels.
    # Remove this block if you do not need it.
    ny, nx = master_badpix.shape
    y, x = np.indices((ny, nx))
    geo_mask = (x + y <= 110)
    master_badpix[geo_mask] = 0
    n_geo = int(np.sum(geo_mask))
    print(f"  Geometric mask: {n_geo} pixels forced bad (x + y <= 110)")

    n_bad   = int(np.sum(master_badpix == 0))
    n_total = master_badpix.size
    print(f"  Total pixels:  {n_total}")
    print(f"  Bad pixels:    {n_bad} ({100*n_bad/n_total:.2f} %)")
    print(f"  Good pixels:   {n_total - n_bad} ({100*(n_total-n_bad)/n_total:.2f} %)")

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(master_badpix, origin="lower", cmap="gray", vmin=0, vmax=1)
    axes[0].set_title(f"Master Bad Pixel Map\n({n_bad} bad, {100*n_bad/n_total:.2f} %)")
    axes[0].set_xlabel("X pixel"); axes[0].set_ylabel("Y pixel")

    axes[1].imshow((master_badpix == 0).astype(float), origin="lower", cmap="Reds")
    axes[1].set_title("Bad Pixel Locations (red = bad)")
    axes[1].set_xlabel("X pixel"); axes[1].set_ylabel("Y pixel")

    axes[2].plot(np.sum(master_badpix == 0, axis=1), label="per row", alpha=0.7)
    axes[2].plot(np.sum(master_badpix == 0, axis=0), label="per col", alpha=0.7)
    axes[2].set_xlabel("Row / Column"); axes[2].set_ylabel("N bad pixels")
    axes[2].set_title("Bad Pixel Distribution")
    axes[2].legend(); axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
print("=" * 80)
print("Saving final products")
print("=" * 80)

os.makedirs(OUTPUT_FOLDER, exist_ok=True)

for (filt, dither), res in results.items():
    if old:
        base = f"{filt}_dither{dither*72-72:.0f}"
    else:
        base = f"{filt}_dither{dither:.0f}"

    flat_path = os.path.join(OUTPUT_FOLDER, f"{base}_flat.fits")

    hdu = fits.PrimaryHDU(res["gain"].astype(np.float32))
    hdu.header["FILTER"]  = filt
    hdu.header["DITHER"]  = (dither * 72 - 72) if old else dither
    hdu.header["NFRAMES"] = res["n_frames"]
    hdu.header["PRODUCT"] = "GAIN_MAP"
    hdu.header["BUNIT"]   = "relative gain"
    hdu.header["GAINMIN"] = GAIN_MIN
    hdu.header["GAINMAX"] = GAIN_MAX
    hdu.header["SRCSIG"]  = SOURCE_SIGMA
    ref = res["header"]
    for key in ["INSTRUME", "TELESCOP", "DETECTOR"]:
        if key in ref:
            hdu.header[key] = ref[key]
    hdu.writeto(flat_path, overwrite=True)
    print(f"  Saved flat: {flat_path}")

# master_badpix only exists if at least one group produced a bad-pixel map.
if badpix_maps:
    mask_path = os.path.join(OUTPUT_FOLDER, "pixel_mask.fits")
    hdu = fits.PrimaryHDU(master_badpix)
    hdu.header["PRODUCT"]  = "MASTER_BADPIX"
    hdu.header["BUNIT"]    = "mask"
    hdu.header["MASKCONV"] = "0=bad, 1=good"
    hdu.header["NMAPS"]    = len(badpix_maps)
    hdu.header["NBAD"]     = int(np.sum(master_badpix == 0))
    hdu.header["NGOOD"]    = int(np.sum(master_badpix == 1))
    hdu.header["COMMENT"]  = "AND of all filter/dither bad-pixel maps"
    hdu.writeto(mask_path, overwrite=True)
    print(f"  Saved pixel mask: {mask_path}")
else:
    print("  No bad-pixel maps were produced; pixel_mask.fits not written")

print("\nDone!")