<a href="https://colab.research.google.com/github/larasauser/master/blob/main/Whitt_8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Run the `nvidia-smi` shell command (IPython `!` syntax) and capture its output.
gpu_info = !nvidia-smi
# Join the captured output lines into one newline-separated string.
gpu_info = '\n'.join(gpu_info)
# The substring 'failed' in the command output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Mon Nov  3 10:18:06 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   65C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Install the third-party packages with pip (IPython `!` shell syntax, quiet mode).
# NOTE(review): scikit-image and scikit-learn are installed here but never
# imported in the visible cells — confirm they are still needed.
!pip install rasterio tqdm scikit-image scikit-learn --quiet

import os
from glob import glob
from datetime import datetime
from tqdm import tqdm
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from scipy.linalg import lstsq
import warnings
# Silences every warning for the rest of the session, including NumPy
# RuntimeWarnings raised by the NaN-aware reductions used in later cells.
warnings.filterwarnings('ignore')

# Mount Google Drive so the '/content/drive/...' paths used below are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# === 1) Google Drive folders ===
DRIVE_L8_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_herens_Landsat8'
DRIVE_S2_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_herens_Sentinel2'
DRIVE_MODIS_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_herens_MODIS'
DRIVE_MASKED_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_herens_Landsat8_holes'
OUTPUT_FOLDER = '/content/drive/MyDrive/Whitt/egfwh_herens_outputs_multisource'
# Create the output folder up front; no error if it already exists.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# === 2) Global parameters ===
# Window size; later halved and divided by the reference pixel size to obtain
# the neighbourhood radius in pixels.
WINDOW_METERS = 400.0
KAPPA = 5.0  # Whittaker filter parameter
# NOTE(review): MIN_OVERLAP is not referenced anywhere in the visible code —
# confirm whether it was meant to be the minimum number of paired observations.
MIN_OVERLAP = 3

In [None]:
def parse_date_from_filename(fn):
    """Return the date encoded in a file name, else the file's modification date.

    The base name of ``fn`` is scanned for ``YYYY-MM-DD`` patterns and the
    first one that is a real calendar date is returned. Digit groups that only
    look like a date (e.g. month 13) are skipped instead of raising.

    Args:
        fn: Path to the file. Only its base name is searched for a date.

    Returns:
        A ``datetime.date``.

    Raises:
        OSError: If no valid date is found in the name and the file's
            modification time cannot be read (e.g. the file does not exist).
    """
    import re

    for m in re.finditer(r'(\d{4})-(\d{2})-(\d{2})', os.path.basename(fn)):
        try:
            return datetime.strptime('-'.join(m.groups()), '%Y-%m-%d').date()
        except ValueError:
            # Matched the digit pattern but is not a valid date; try the next one.
            continue
    # No usable date in the name: fall back to the file's modification time.
    return datetime.fromtimestamp(os.path.getmtime(fn)).date()

def read_singleband_tif(path):
    """Read band 1 of a raster as float32 and replace its nodata value by NaN.

    Args:
        path: Path of the raster file to open with rasterio.

    Returns:
        ``(band, profile)``: the float32 array of band 1 and a copy of the
        dataset's ``meta`` dictionary.
    """
    with rasterio.open(path) as dataset:
        band = dataset.read(1).astype(np.float32)
        profile = dataset.meta.copy()

    nodata_value = profile.get('nodata', None)
    if nodata_value is None:
        # No declared nodata value: return the pixels untouched.
        return band, profile

    band[band == nodata_value] = np.nan
    return band, profile

def resample_to_target(src_arr, src_meta, target_meta):
    """Reproject an array onto the grid described by ``target_meta`` (bilinear).

    Args:
        src_arr: Source pixel array.
        src_meta: Mapping providing the source ``'transform'`` and ``'crs'``.
        target_meta: Mapping providing the target ``'height'``, ``'width'``,
            ``'transform'`` and ``'crs'``.

    Returns:
        A float32 array of shape ``(height, width)`` on the target grid.
    """
    # NOTE(review): no src_nodata/dst_nodata is passed to reproject — confirm how
    # NaN source pixels and pixels outside the source footprint come out.
    out_shape = (target_meta['height'], target_meta['width'])
    resampled = np.full(out_shape, np.nan, dtype=np.float32)

    georef = {
        'src_transform': src_meta['transform'],
        'src_crs': src_meta['crs'],
        'dst_transform': target_meta['transform'],
        'dst_crs': target_meta['crs'],
    }
    reproject(
        source=src_arr,
        destination=resampled,
        resampling=Resampling.bilinear,
        **georef
    )
    return resampled

In [None]:
# === Load the input files ===
landsat_files = sorted(glob(os.path.join(DRIVE_L8_FOLDER, '*.tif')))
sentinel_files = sorted(glob(os.path.join(DRIVE_S2_FOLDER, '*.tif')))
modis_files = sorted(glob(os.path.join(DRIVE_MODIS_FOLDER, '*.tif')))
masked_files = sorted(glob(os.path.join(DRIVE_MASKED_FOLDER, '*.tif')))

# Fail early with a readable message instead of an IndexError / empty-stack
# error further down when one of the mandatory folders holds no GeoTIFF.
for _label, _found, _folder in (
    ('Landsat 8', landsat_files, DRIVE_L8_FOLDER),
    ('Sentinel-2', sentinel_files, DRIVE_S2_FOLDER),
    ('MODIS', modis_files, DRIVE_MODIS_FOLDER),
):
    if not _found:
        raise FileNotFoundError(f"No {_label} GeoTIFF (*.tif) found in {_folder}")


def _load_stack(files, grid_meta, grid_shape):
    """Read every file, resample those whose shape differs from the grid, and stack.

    Returns ``(stack, dates)`` where ``stack`` has one slice per file (in the
    order given) and ``dates`` holds the date parsed for each file.
    """
    arrays, dates = [], []
    for path in files:
        band, band_meta = read_singleband_tif(path)
        if band.shape != grid_shape:
            band = resample_to_target(band, band_meta, grid_meta)
        arrays.append(band)
        dates.append(parse_date_from_filename(path))
    return np.stack(arrays), dates


# --- Use the first Landsat image as the spatial reference ---
ref_arr, ref_meta = read_singleband_tif(landsat_files[0])
H_ref, W_ref = ref_arr.shape
# First affine coefficient of the reference transform.
# NOTE(review): used as the pixel size in metres — assumes a projected CRS
# with metre units; confirm.
PIXEL_SIZE_M = abs(ref_meta['transform'][0])
print("Référence Landsat:", landsat_files[0], "shape:", (H_ref, W_ref))

# --- Load Landsat and Sentinel-2 on the reference grid ---
landsat_stack, landsat_dates = _load_stack(landsat_files, ref_meta, (H_ref, W_ref))
sentinel_stack, sentinel_dates = _load_stack(sentinel_files, ref_meta, (H_ref, W_ref))

# --- Landsat/Sentinel fusion by Maximum Value Composite (MVC) ---
# NOTE(review): the two stacks are paired by position (slice k with slice k),
# not by acquisition date — confirm that both folders are date-aligned.
T = min(len(landsat_dates), len(sentinel_dates))
# np.fmax is the element-wise NaN-ignoring maximum (NaN only where both inputs
# are NaN), i.e. the same values as nanmax over a 2-slice stack without
# allocating the intermediate (2, T, H, W) array.
SL_stack = np.fmax(landsat_stack[:T], sentinel_stack[:T])
SL_dates = landsat_dates[:T]
print("SL-NDVI fusionné:", SL_stack.shape)

# --- MODIS alignment: nearest MODIS date for each SL date ---
modis_stack = np.full_like(SL_stack, np.nan)
modis_map = {parse_date_from_filename(f): f for f in modis_files}
modis_dates_list = sorted(modis_map.keys())

# Consecutive SL dates often share the same nearest MODIS image; remember the
# last resampled one so it is not read and reprojected again.
_last_md, _last_resampled = None, None
for ti, d in enumerate(SL_dates):
    # Ties resolve to the earliest MODIS date (first minimum in sorted order).
    nearest_md = min(modis_dates_list, key=lambda md: abs((md - d).days))
    if nearest_md != _last_md:
        arr, meta = read_singleband_tif(modis_map[nearest_md])
        _last_resampled = resample_to_target(arr, meta, ref_meta)
        _last_md = nearest_md
    modis_stack[ti] = _last_resampled

print("MODIS aligné:", modis_stack.shape)

In [None]:
def pearson_r(a, b):
    """Pearson correlation of two arrays over the positions where neither is NaN.

    Returns NaN when fewer than two paired values exist or when either series
    has zero variance over the paired positions.
    """
    paired = ~(np.isnan(a) | np.isnan(b))
    if paired.sum() < 2:
        return np.nan

    xs = a[paired]
    ys = b[paired]
    dx = xs - xs.mean()
    dy = ys - ys.mean()

    denominator = np.sqrt(np.sum(dx**2) * np.sum(dy**2))
    if denominator == 0:
        return np.nan
    return np.sum(dx * dy) / denominator

def compute_M_reference_pixel(i, j, modis_stack, sl_target_ts, radius_pix, corr_threshold=0.3):
    """Build the correlation-weighted MODIS reference series for pixel (i, j).

    Every MODIS pixel inside the square window of half-width ``radius_pix``
    around (i, j), clipped to the image, is correlated (Pearson, NaN pairs
    ignored) with ``sl_target_ts``. Series whose correlation exceeds
    ``corr_threshold`` are kept, their correlations are min-max normalised
    into weights, and the weighted mean of the kept series is returned.

    At each time step the weighted sum is divided by the total weight of the
    series that actually have a value there, so a NaN in one neighbour no
    longer drags the reference towards zero.

    All arithmetic is done in float64.

    Args:
        i, j: Row and column of the target pixel.
        modis_stack: Array indexed as (time, row, column).
        sl_target_ts: Target time series, same length as the time axis.
        radius_pix: Half-width of the search window, in pixels.
        corr_threshold: Minimum correlation (exclusive) for a neighbour to be used.

    Returns:
        A float64 array with one value per time step (NaN where no weighted
        neighbour has data), or ``None`` when no neighbour passes the threshold.
    """
    T, H, W = modis_stack.shape
    i0, i1 = max(0, i - radius_pix), min(H, i + radius_pix + 1)
    j0, j1 = max(0, j - radius_pix), min(W, j + radius_pix + 1)

    # One column per window pixel, in row-major (row, column) order.
    window = modis_stack[:, i0:i1, j0:j1].reshape(T, -1).astype(float)
    target = np.asarray(sl_target_ts, dtype=float)

    # Pearson correlation of every column with the target, restricted per
    # column to the time steps where both values are present.
    pair_ok = ~np.isnan(window) & ~np.isnan(target)[:, None]
    n_pairs = pair_ok.sum(axis=0)
    usable = n_pairs >= 2
    if not usable.any():
        return None
    safe_n = np.where(usable, n_pairs, 1)  # avoids 0-division for unusable columns
    x = np.where(pair_ok, window, 0.0)
    y = np.where(pair_ok, target[:, None], 0.0)
    dx = np.where(pair_ok, x - x.sum(axis=0) / safe_n, 0.0)
    dy = np.where(pair_ok, y - y.sum(axis=0) / safe_n, 0.0)
    num = (dx * dy).sum(axis=0)
    den = np.sqrt((dx * dx).sum(axis=0) * (dy * dy).sum(axis=0))
    usable &= den != 0  # zero variance -> correlation undefined

    r = np.full(usable.shape, np.nan)
    r[usable] = num[usable] / den[usable]
    selected = np.zeros(usable.shape, dtype=bool)
    selected[usable] = r[usable] > corr_threshold  # NaN compares False
    if not selected.any():
        return None

    R = r[selected]
    series = window[:, selected].T  # (n_selected, T)

    # Min-max normalise the correlations; equal correlations get equal weights.
    spread = R.max() - R.min()
    R = (R - R.min()) / spread if spread != 0 else np.ones_like(R)
    weights = R / R.sum()

    has_value = ~np.isnan(series)
    weighted_sum = (np.where(has_value, series, 0.0) * weights[:, None]).sum(axis=0)
    support = (has_value * weights[:, None]).sum(axis=0)

    # Renormalise by the weight actually present at each time step.
    Mref = np.full(T, np.nan)
    covered = support > 0
    Mref[covered] = weighted_sum[covered] / support[covered]
    return Mref

def estimate_linear_transfer(M_ref, SL_ts):
    """Least-squares fit of ``SL_ts ≈ slope * M_ref + intercept``.

    Only the positions where both series are non-NaN take part in the fit.

    Returns:
        ``(slope, intercept)``, or ``None`` when fewer than two paired
        observations are available.
    """
    paired = ~(np.isnan(M_ref) | np.isnan(SL_ts))
    n_obs = paired.sum()
    if n_obs < 2:
        return None

    # Design matrix: one column for the reference values, one for the constant term.
    design = np.column_stack((M_ref[paired], np.ones(n_obs)))
    coeffs = lstsq(design, SL_ts[paired])[0]
    return coeffs[0], coeffs[1]

def whittaker_smoother(y, kappa=5.0):
    """Smooth each contiguous run of non-NaN values with a second-difference penalty.

    For a run ``seg`` of length m > 2 the result solves
    ``(I + kappa * D.T @ D) z = seg`` where ``D`` is the (m-2, m)
    second-difference matrix. Runs of one or two samples are copied unchanged
    and NaN positions stay NaN.

    Args:
        y: 1-D series, possibly containing NaN.
        kappa: Weight of the roughness penalty.

    Returns:
        An array shaped like ``y`` holding the smoothed values.
    """
    smoothed = np.full_like(y, np.nan)
    observed = np.flatnonzero(~np.isnan(y))
    if observed.size == 0:
        return smoothed

    # Cut the observed indices wherever two neighbours are not consecutive.
    cut_points = np.flatnonzero(np.diff(observed) != 1) + 1
    for run in np.split(observed, cut_points):
        start, stop = run[0], run[-1] + 1
        segment = y[start:stop]
        length = len(segment)
        if length > 2:
            # Rows of the second difference of the identity are [1, -2, 1] stencils.
            second_diff = np.diff(np.eye(length), n=2, axis=0)
            system = np.eye(length) + kappa * (second_diff.T @ second_diff)
            smoothed[start:stop] = np.linalg.solve(system, segment)
        else:
            smoothed[start:stop] = segment
    return smoothed

In [None]:
# === 6) EGF full execution + Whittaker smoothing ONLY on gaps + export ===
from tqdm import tqdm
import math
import rasterio

# --- Safety checks (adapt if variable names differ) ---
# SL_stack : (T, H, W) fused Landsat/Sentinel NDVI (SL-NDVI)
# SL_dates : list of datetime.date for each SL_stack slice
# modis_stack : (T, H, W) MODIS NDVI aligned to SL_dates
# masked_files : list of masked Landsat paths (with NaNs for gaps)
# ref_meta : rasterio meta of reference grid (Landsat)

assert 'SL_stack' in globals(), "SL_stack not found. Exécute le bloc de chargement/fusion (Bloc 4)."
assert 'SL_dates' in globals(), "SL_dates not found. Exécute le bloc de chargement/fusion (Bloc 4)."
assert 'modis_stack' in globals(), "modis_stack not found. Exécute l'alignement MODIS (Bloc 4)."
assert 'ref_meta' in globals(), "ref_meta not found. Charge une image de référence Landsat."

T, H, W = SL_stack.shape
print("T,H,W =", T, H, W)

# --- Make sure we have a landsat_masked_stack aligned to ref_meta (load if needed) ---
# If landsat_masked_stack already exists from previous blocks, keep it; else build it now.
if 'landsat_masked_stack' not in globals():
    print("Building landsat_masked_stack from masked_files (resampling to reference)...")
    landsat_masked_stack = []
    masked_dates = []
    for f in masked_files:
        arr, meta = read_singleband_tif(f)
        if arr.shape != (H, W):
            arr = resample_to_target(arr, meta, ref_meta)
        landsat_masked_stack.append(arr)
        masked_dates.append(parse_date_from_filename(f))
    landsat_masked_stack = np.stack(landsat_masked_stack, axis=0)
    print("landsat_masked_stack shape:", landsat_masked_stack.shape)
else:
    # Pass the full path (not the base name) so the modification-time fallback
    # of parse_date_from_filename can still find the file.
    masked_dates = [parse_date_from_filename(p) if isinstance(p, str) else None for p in masked_files]

if landsat_masked_stack.shape[1:] != (H, W):
    raise ValueError(
        f"landsat_masked_stack spatial shape {landsat_masked_stack.shape[1:]} "
        f"does not match SL_stack {(H, W)}"
    )

# --- 6.1) Estimate M_ref, A, A0 for each pixel ---
# NOTE(review): the transfer coefficients are fitted against SL_stack, which the
# visible loading cell builds from the unmasked Landsat folder — confirm this is
# intended when the masked images are used to evaluate the reconstruction.
radius_pix = int(round((WINDOW_METERS / 2.0) / PIXEL_SIZE_M))
print("Window meters:", WINDOW_METERS, "-> radius_pix:", radius_pix)

A = np.full((H, W), np.nan, dtype=float)
A0 = np.full((H, W), np.nan, dtype=float)
Mref_stack = np.full((T, H, W), np.nan, dtype=float)

print("Estimating M_reference and linear transfer per pixel (this may take time)...")
for i in tqdm(range(H), desc='rows'):
    for j in range(W):
        sl_target_ts = SL_stack[:, i, j]              # SL-NDVI time series for pixel
        Mref = compute_M_reference_pixel(i, j, modis_stack, sl_target_ts, radius_pix, corr_threshold=0.3)
        if Mref is None:
            # fallback: local mean in window (use only valid MODIS values)
            i0, i1 = max(0, i - radius_pix), min(H, i + radius_pix + 1)
            j0, j1 = max(0, j - radius_pix), min(W, j + radius_pix + 1)
            block = modis_stack[:, i0:i1, j0:j1]         # shape (T, nrows, ncols)
            if np.all(np.isnan(block)):
                continue
            Mref = np.nanmean(block.reshape(T, -1), axis=1)
        Mref_stack[:, i, j] = Mref
        est = estimate_linear_transfer(Mref, sl_target_ts)
        if est is not None:
            A[i, j], A0[i, j] = est

# --- 6.2) Compute SLM (full but will be used only on gaps) ---
print("Computing SLM = a * Mref + a0 ...")
# Broadcasting over the time axis. Pixels without coefficients (A/A0 are NaN)
# come out as NaN, exactly as if they had been skipped.
SLM = Mref_stack * A[np.newaxis, :, :]
SLM += A0[np.newaxis, :, :]

# --- 6.3) For each masked image: apply Whittaker per-gap pixel only and inject ---
final_stack = landsat_masked_stack.copy()  # copy original masked images (NaN in gaps)
n_masked = final_stack.shape[0]
print("Merging reconstructed values into masked pixels and applying Whittaker smoothing...")

# Pair every masked image with the SL/MODIS time step of the same date. The
# stacks are not guaranteed to share length or ordering, so position alone is
# only used as a fallback when the date is unknown or absent from SL_dates.
sl_index_by_date = {}
for k, sl_date in enumerate(SL_dates):
    sl_index_by_date.setdefault(sl_date, k)  # first occurrence wins on duplicates

time_index = []
n_positional, n_skipped = 0, 0
for t in range(n_masked):
    masked_date = masked_dates[t] if t < len(masked_dates) else None
    k = sl_index_by_date.get(masked_date)
    if k is None:
        if t < T:
            k = t            # previous behaviour: pair by position
            n_positional += 1
        else:
            n_skipped += 1   # no time step to draw values from
    time_index.append(k)
if n_positional or n_skipped:
    print("Date alignment:", n_positional, "masked image(s) paired by position,",
          n_skipped, "left unfilled (no matching SL time step)")

# To accelerate: precompute per-pixel Whittaker once (only for pixels that have at least one gap in any masked image)
# Build a boolean map of pixels that ever need reconstruction (appears as NaN in any masked image)
need_recon_mask = np.any(np.isnan(final_stack), axis=0)  # shape (H,W)
print("Pixels requiring reconstruction (ever):", int(np.count_nonzero(need_recon_mask)), "/", H*W)

# Precompute smoothed series for required pixels only (store as array T x H x W but fill only needed pixels)
# We'll compute whittaker_smoother on SLM[:,i,j] where need_recon_mask[i,j] is True and valid SLM exists
smoothed_series = np.full((T, H, W), np.nan, dtype=float)

# Loop only over pixels that need recon
coords = np.argwhere(need_recon_mask)
for (i, j) in tqdm(coords, desc='pixels to smooth'):
    series = SLM[:, i, j]
    if np.all(np.isnan(series)):
        continue
    z = whittaker_smoother(series, kappa=KAPPA)
    smoothed_series[:, i, j] = z

# Now inject per masked date only into gaps
for t in tqdm(range(n_masked), desc='masked images'):
    k = time_index[t]
    if k is None:
        continue
    gap_mask = np.isnan(final_stack[t])  # True where pixel missing
    if not np.any(gap_mask):
        continue
    # Use smoothed_series at the matching time step for those pixels (if available)
    inject_vals = smoothed_series[k]
    valid_inject = ~np.isnan(inject_vals) & gap_mask
    final_stack[t, valid_inject] = inject_vals[valid_inject]

# --- 6.4) Exporting reconstructed GeoTIFFs (preserve geo/meta) ---
def write_tif_nan(path, arr, meta_template):
    """Write ``arr`` as band 1 of a single-band float32 raster at ``path``.

    The raster profile is taken from ``meta_template`` with ``count``,
    ``dtype`` and ``nodata`` overridden (``nodata`` is set to ``None``).
    """
    profile = meta_template.copy()
    # Override the fields that must describe the array being written.
    profile['count'] = 1
    profile['dtype'] = 'float32'
    profile['nodata'] = None
    with rasterio.open(path, 'w', **profile) as sink:
        sink.write(arr.astype(np.float32), 1)

# Reconstructions are written into the notebook's output folder.
EXPORT_FOLDER = OUTPUT_FOLDER
os.makedirs(EXPORT_FOLDER, exist_ok=True)
print("Saving reconstructed GeoTIFFs to", EXPORT_FOLDER)

# File-name stems: ISO format for objects exposing `isoformat`, plain `str` otherwise.
try:
    date_names = [
        (d.isoformat() if hasattr(d, 'isoformat') else str(d))
        for d in masked_dates
    ]
except Exception:
    date_names = list(map(str, masked_dates))

# One GeoTIFF per masked image, named after its date.
for t, name in enumerate(date_names):
    outpath = os.path.join(EXPORT_FOLDER, 'recon_' + name + '.tif')
    write_tif_nan(outpath, final_stack[t], ref_meta)
    print("Saved", outpath)

print("✅ Finished. Reconstructions and exports in:", EXPORT_FOLDER)