<a href="https://colab.research.google.com/github/larasauser/master/blob/main/Whitt_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab notebook script: EGF (faithful) + Whittaker (kappa=1) for Landsat + MODIS
# - Loads Landsat clear (NDVI) and Landsat masked (NDVI with NaNs) from Google Drive
# - Downloads MODIS NDVI via Earth Engine, resamples it onto the Landsat grid and aligns it in time
# - Computes M_reference (Eq. 3-4), estimates a, a0 (Eq. 5), generates SLM, then combines and smooths with Whittaker (kappa=1)
# - Exports / saves the 6 reconstructions and computes RMSE, R2, MAE, MS-SSIM, % reconstruction
# Dependencies: earthengine-api, rasterio, numpy, scipy, scikit-image, tqdm, scikit-learn (PyPI name; imported as `sklearn`)

In [2]:
# === 0) Installer / importer libs et monter Drive / authentifier Earth Engine ===
!pip install rasterio tqdm scikit-image sklearn

import os
from google.colab import drive
drive.mount('/content/drive')

import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.enums import Resampling as RIO_RES
from glob import glob
from datetime import datetime
from tqdm import tqdm
from scipy.linalg import lstsq
from skimage.metrics import multiscale_structural_similarity as mssim
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import warnings
warnings.filterwarnings('ignore')

Collecting rasterio
  Using cached rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting sklearn
  Using cached sklearn-0.0.post12.tar.gz (2.6 kB)
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the package mentioned above, not pip.
[1;36mhint[0m: See above for details.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


ModuleNotFoundError: No module named 'rasterio'

In [None]:
# === 1) Main parameters (edit as needed) ===
DRIVE_CLEAR_FOLDER = '/content/drive/MyDrive/landsat_clear'     # GeoTIFFs Landsat clear NDVI
DRIVE_MASKED_FOLDER = '/content/drive/MyDrive/landsat_masked'   # GeoTIFFs Landsat masked NDVI (with NaN holes)
OUTPUT_FOLDER = '/content/drive/MyDrive/egf_ws_outputs'         # outputs saved here
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# MODIS settings (change collection if you prefer another).
# Collection 6.1 is used because the 006 version of MOD13Q1 is deprecated in the
# Earth Engine catalog and no longer receives new acquisitions.
MODIS_COLLECTION = 'MODIS/061/MOD13Q1'  # 16-day NDVI 250m (common choice); can be changed
# Window used in paper: 200 m x 200 m
WINDOW_METERS = 200.0
PIXEL_SIZE_M = None  # will be inferred from first Landsat image (expected 30m)

# Other parameters
R_MIN_THRESHOLD = 0.0  # we'll use normalization per paper; keep threshold low but can be tuned
MIN_OVERLAP = 3        # minimal time overlap to compute correlation/regression
KAPPA = 1.0            # Whittaker smoothing parameter (paper uses kappa = 1)
TILE_SIZE = 256        # tile size if tiling is wanted (not used by default)
# Date parsing pattern: expects YYYYMMDD somewhere in the filename
import re
date_pattern = re.compile(r'(\d{4})(\d{2})(\d{2})')

def parse_date_from_filename(fn):
    """Return the acquisition date encoded in a file name.

    Every 8-digit run matched by ``date_pattern`` in the base name is tried in
    order, and the first one that is a real calendar date (YYYYMMDD) wins.
    Runs that merely look like a date (e.g. an ID such as ``99999999``) are
    skipped instead of raising.

    Parameters
    ----------
    fn : str
        Path or file name.

    Returns
    -------
    datetime.date
        The parsed date, or the file's modification date when the name holds
        no valid YYYYMMDD token (this fallback requires the file to exist).
    """
    for match in date_pattern.finditer(os.path.basename(fn)):
        try:
            return datetime.strptime(''.join(match.groups()), '%Y%m%d').date()
        except ValueError:
            # Eight digits that are not a valid date: keep looking.
            continue
    # fallback: file mtime
    return datetime.fromtimestamp(os.path.getmtime(fn)).date()

In [None]:
# === 2) List the Landsat clear & masked files (their dates are assumed to match) ===
clear_files = sorted(glob(os.path.join(DRIVE_CLEAR_FOLDER, '*.tif')))
masked_files = sorted(glob(os.path.join(DRIVE_MASKED_FOLDER, '*.tif')))
if not clear_files:
    raise SystemExit("Aucun fichier trouvé dans DRIVE_CLEAR_FOLDER. Mets tes GeoTIFFs NDVI clear.")
if not masked_files:
    raise SystemExit("Aucun fichier trouvé dans DRIVE_MASKED_FOLDER. Mets tes GeoTIFFs NDVI masked.")

# One acquisition date per file, in the same order as the file lists
clear_dates = list(map(parse_date_from_filename, clear_files))
masked_dates = list(map(parse_date_from_filename, masked_files))
print("Found {} clear files and {} masked files".format(len(clear_files), len(masked_files)))

# The unified time axis is the sorted UNION of the dates found in either collection
# (despite its name, `common_dates` is not an intersection).
common_dates = sorted(set(clear_dates).union(masked_dates))
print("Unified dates length (union):", len(common_dates))


In [None]:
# === 3) Helper: read a single-band GeoTIFF into array + meta ===
def read_singleband_tif(path):
    """Read band 1 of a GeoTIFF as float32 together with the dataset metadata.

    Pixels equal to the dataset's declared nodata value (when one is set) are
    replaced by NaN.

    Returns
    -------
    (numpy.ndarray, dict)
        The band array and a copy of the rasterio ``meta`` dictionary.
    """
    with rasterio.open(path) as dataset:
        band = dataset.read(1).astype(np.float32)
        profile = dataset.meta.copy()
    nodata_value = profile.get('nodata', None)
    if nodata_value is None:
        return band, profile
    is_nodata = band == nodata_value
    band[is_nodata] = np.nan
    return band, profile

# The first clear image defines the reference grid (shape, transform, CRS)
arr0, meta0 = read_singleband_tif(clear_files[0])
H, W = arr0.shape
# First coefficient of the transform = pixel width in CRS units
PIXEL_SIZE_M = abs(meta0['transform'][0])
print("Inferred grid shape:", H, W, "pixel size (m):", PIXEL_SIZE_M)

In [None]:
# === 4) Build time-indexed stacks (T,H,W) aligned on common_dates ===
# date -> file lookup tables (a later file with the same date replaces an earlier one)
clear_map = {parse_date_from_filename(path): path for path in clear_files}
masked_map = {parse_date_from_filename(path): path for path in masked_files}

T = len(common_dates)
landsat_clear_stack = np.full((T, H, W), np.nan, dtype=np.float32)
landsat_masked_stack = np.full((T, H, W), np.nan, dtype=np.float32)
# ISO strings of the time axis, used later as result keys / file names
dates = [day.isoformat() for day in common_dates]

# Time slices without a file for that date stay all-NaN
for ti, d in enumerate(common_dates):
    for lookup, stack in ((clear_map, landsat_clear_stack), (masked_map, landsat_masked_stack)):
        if d not in lookup:
            continue
        arr, _ = read_singleband_tif(lookup[d])
        stack[ti] = arr

print("Stacks shapes:", landsat_clear_stack.shape, landsat_masked_stack.shape)

In [None]:
# === 5) Download MODIS NDVI time series via Earth Engine, resample to Landsat grid & align temporal axis ===
# MODIS NDVI is requested for the Landsat date range and footprint (taken from meta0).
start_date = common_dates[0].isoformat()
end_date = common_dates[-1].isoformat()
print("MODIS date range:", start_date, end_date)

# Bounding box of the Landsat grid, derived from the affine transform of the first image
transform = meta0['transform']
crs = meta0['crs']
left = transform[2]
top = transform[5]
right = left + W * transform[0]
bottom = top + H * transform[4]
# Earth Engine expects the projection as a string (e.g. 'EPSG:32632'), not a rasterio CRS object.
crs_for_ee = crs if isinstance(crs, str) else crs.to_string()
geom = ee.Geometry.Rectangle([left, bottom, right, top], proj=crs_for_ee)

# filterDate() excludes its end date and only sees composite START dates, so the window
# is padded by one 16-day compositing period on both sides; otherwise the composites
# nearest to the first/last Landsat acquisition would be missing from the collection.
MODIS_PAD_DAYS = 16
modis_col = (
    ee.ImageCollection(MODIS_COLLECTION)
    .filterDate(ee.Date(start_date).advance(-MODIS_PAD_DAYS, 'day'),
                ee.Date(end_date).advance(MODIS_PAD_DAYS, 'day'))
    .filterBounds(geom)
)
# Convert collection to a list of images (dates are resolved further down)
modis_list = modis_col.toList(modis_col.size())
n_mod = modis_col.size().getInfo()
print("Number of MODIS images in period (collection):", n_mod)

# A MODIS stack aligned on common_dates is produced next:
# for each date in common_dates the MODIS image nearest in time is used; Landsat dates
# farther than 30 days from every MODIS image are left empty (see the matching loop).
def ee_image_to_array(img, out_shape, out_transform, out_crs):
    """Download a single-band Earth Engine image and resample it onto a target grid.

    The image is fetched over the module-level ``geom`` footprint at roughly
    ``PIXEL_SIZE_M`` resolution, then warped (nearest neighbour) onto the grid
    described by ``out_shape`` / ``out_transform`` / ``out_crs`` so the result
    can be written straight into a (T, H, W) stack.

    Parameters
    ----------
    img : ee.Image
        Single-band image to download.
    out_shape : tuple of int
        (rows, cols) of the target grid.
    out_transform : affine transform of the target grid (as used by rasterio).
    out_crs : str or rasterio CRS
        Coordinate reference system of the target grid.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``out_shape``; pixels without data are NaN.

    Raises
    ------
    requests.HTTPError
        If the download request does not succeed.
    """
    import io
    import zipfile

    import requests
    from rasterio.io import MemoryFile

    # Earth Engine needs the CRS as a string; rasterio CRS objects are converted.
    crs_for_ee = out_crs if isinstance(out_crs, str) else out_crs.to_string()
    url = img.getDownloadURL({
        'scale': int(PIXEL_SIZE_M),
        'crs': crs_for_ee,
        'region': geom.toGeoJSONString(),
        'format': 'GEO_TIFF'
    })
    response = requests.get(url, timeout=300)
    response.raise_for_status()
    payload = response.content

    # 'GEO_TIFF' is documented as a bare GeoTIFF (the zipped variant is
    # 'ZIPPED_GEO_TIFF'), so the payload is only unpacked when it really is a zip archive.
    if zipfile.is_zipfile(io.BytesIO(payload)):
        with zipfile.ZipFile(io.BytesIO(payload)) as archive:
            tif_name = next(n for n in archive.namelist() if n.endswith('.tif'))
            payload = archive.read(tif_name)

    with MemoryFile(payload) as memfile, memfile.open() as src:
        src_arr = src.read(1).astype(np.float32)
        if src.nodata is not None:
            src_arr[src_arr == src.nodata] = np.nan
        src_transform = src.transform
        src_crs = src.crs

    # The downloaded raster generally does not share the target grid's exact
    # origin/size, so it is warped onto that grid explicitly.
    resampled = np.full(out_shape, np.nan, dtype=np.float32)
    reproject(
        source=src_arr,
        destination=resampled,
        src_transform=src_transform,
        src_crs=src_crs,
        src_nodata=np.nan,
        dst_transform=out_transform,
        dst_crs=out_crs,
        dst_nodata=np.nan,
        resampling=Resampling.nearest,
    )
    return resampled

# Build MODIS stack: for each common_date, use the nearest MODIS image in the collection
from datetime import timedelta

MODIS_MAX_GAP_DAYS = 30                 # Landsat dates farther than this from any MODIS image are skipped
MODIS_SCALE_FACTOR = 0.0001             # scale applied to the downloaded integer NDVI values
MODIS_NDVI_VALID_RANGE = (-0.2, 1.0)    # MOD13Q1 valid range after scaling; the -3000 fill value falls outside

modis_stack = np.full((T, H, W), np.nan, dtype=np.float32)

# Fetch every image start time in ONE server round trip (instead of one getInfo per image).
# 'system:time_start' is in milliseconds since the Unix epoch (UTC).
modis_times_ms = modis_col.aggregate_array('system:time_start').getInfo()
unix_epoch = datetime(1970, 1, 1)
modis_dates = [(unix_epoch + timedelta(milliseconds=ms)).date() for ms in modis_times_ms]
modis_images = [ee.Image(modis_list.get(k)) for k in range(len(modis_dates))]

# Several Landsat dates can share the same nearest MODIS composite: download each one once.
modis_cache = {}
for ti, d in enumerate(common_dates):
    if not modis_dates:
        break
    # find nearest MODIS date
    diffs = [abs((md - d).days) for md in modis_dates]
    min_idx = int(np.argmin(diffs))
    if diffs[min_idx] > MODIS_MAX_GAP_DAYS:
        # too far -> skip
        continue
    if min_idx not in modis_cache:
        print(f"Downloading MODIS for date {d} (nearest {modis_dates[min_idx]}) ...")
        try:
            ndvi_band = modis_images[min_idx].select('NDVI')
            raw = ee_image_to_array(ndvi_band, (H, W), transform, crs)
            if raw.shape != (H, W):
                raise ValueError(f"MODIS array shape {raw.shape} does not match Landsat grid {(H, W)}")
            ndvi = raw.astype(np.float32) * MODIS_SCALE_FACTOR
            lo, hi = MODIS_NDVI_VALID_RANGE
            ndvi[(ndvi < lo) | (ndvi > hi)] = np.nan
            modis_cache[min_idx] = ndvi
        except Exception as e:
            # Best effort: leave this date empty; failures are not cached so a later
            # Landsat date mapping to the same composite retries the download.
            print("MODIS download failed for date", d, e)
            continue
    modis_stack[ti] = modis_cache[min_idx]

print("MODIS stack built:", modis_stack.shape)

In [None]:
# === 6) EGF faithful implementation (Eq 3-5) helpers ===
from math import floor

def pearson_r(a, b):
    """Pearson correlation of two 1-D arrays, ignoring positions where either is NaN.

    Returns np.nan when fewer than two paired samples remain or when one of the
    two series has zero variance over the paired samples.
    """
    both_present = ~(np.isnan(a) | np.isnan(b))
    if both_present.sum() < 2:
        return np.nan
    x = a[both_present].astype(float)
    y = b[both_present].astype(float)
    # Deviations from the respective means
    dx = x - x.mean()
    dy = y - y.mean()
    den = np.sqrt(np.sum(dx**2) * np.sum(dy**2))
    if den == 0:
        return np.nan
    return np.sum(dx * dy) / den

def compute_M_reference_pixel(i, j, modis_stack, sl_target_ts, radius_pixels):
    """Compute the M_reference series of pixel (i, j) following Eq. (3)-(4).

    Every MODIS pixel inside the square window of half-width ``radius_pixels``
    around (i, j) is correlated (Pearson, NaN-aware) with the target series.
    The correlations are min-max normalised to [0, 1] and used as weights of a
    weighted mean of the neighbouring MODIS series.

    At time steps where some neighbours are NaN, the mean is renormalised by
    the total weight of the neighbours that ARE present, so a missing neighbour
    no longer drags the reference value towards zero.

    Parameters
    ----------
    i, j : int
        Row / column of the target pixel.
    modis_stack : numpy.ndarray, shape (T, H, W)
        MODIS series on the target grid (NaN = missing).
    sl_target_ts : numpy.ndarray, shape (T,)
        Target (Landsat) series at pixel (i, j).
    radius_pixels : int
        Half-width of the search window, in pixels.

    Returns
    -------
    numpy.ndarray of shape (T,) or None
        The reference series (NaN where no weighted neighbour has data), or
        None when no neighbour yields a defined correlation (caller falls back).
    """
    T, H, W = modis_stack.shape
    i0, i1 = max(0, i - radius_pixels), min(H, i + radius_pixels + 1)
    j0, j1 = max(0, j - radius_pixels), min(W, j + radius_pixels + 1)

    # (T, K) matrix: one column per neighbouring MODIS pixel (row-major window order)
    neighbours = modis_stack[:, i0:i1, j0:j1].reshape(T, -1).astype(float)
    target = np.asarray(sl_target_ts, dtype=float)

    # Pearson r of every column against the target, using only jointly valid samples
    paired = ~np.isnan(neighbours) & ~np.isnan(target)[:, None]
    n_paired = paired.sum(axis=0)
    with np.errstate(invalid='ignore', divide='ignore'):
        mean_x = np.where(paired, neighbours, 0.0).sum(axis=0) / n_paired
        mean_y = np.where(paired, target[:, None], 0.0).sum(axis=0) / n_paired
        dx = np.where(paired, neighbours - mean_x, 0.0)
        dy = np.where(paired, target[:, None] - mean_y, 0.0)
        num = (dx * dy).sum(axis=0)
        den = np.sqrt((dx ** 2).sum(axis=0) * (dy ** 2).sum(axis=0))
        # Undefined with fewer than 2 paired samples or zero variance
        corrs = np.where((n_paired >= 2) & (den > 0), num / den, np.nan)

    usable = ~np.isnan(corrs)
    if not usable.any():
        return None
    cor_vals = corrs[usable]
    series = neighbours[:, usable]

    # Min-max normalisation of the correlations; equal weights if they are all identical
    spread = cor_vals.max() - cor_vals.min()
    if spread == 0:
        R = np.ones_like(cor_vals)
    else:
        R = (cor_vals - cor_vals.min()) / spread
    weights = R / R.sum()

    # Weighted mean over the neighbours that have data at each time step
    present = ~np.isnan(series)
    weighted_sum = np.where(present, series, 0.0) @ weights
    support = present.astype(float) @ weights
    with np.errstate(invalid='ignore', divide='ignore'):
        Mref = np.where(support > 0, weighted_sum / support, np.nan)
    return Mref

def estimate_linear_transfer(M_ref, SL_ts):
    """Fit SL_ts ~ a * M_ref + a0 by least squares over jointly non-NaN samples.

    Returns the tuple (a, a0), or None when fewer than two paired samples exist.
    """
    usable = ~(np.isnan(M_ref) | np.isnan(SL_ts))
    n_usable = usable.sum()
    if n_usable < 2:
        return None
    # Design matrix: [M_ref, 1] per usable time step
    design = np.column_stack((M_ref[usable], np.ones(n_usable)))
    coeffs = lstsq(design, SL_ts[usable], cond=None)[0]
    return coeffs[0], coeffs[1]

In [None]:
# === 7) Loop to compute a(x,y), a0(x,y), Mref_stack and SLM ===
radius_pix = int(round((WINDOW_METERS/2.0) / PIXEL_SIZE_M))
T, H, W = modis_stack.shape
print("Window radius pixels:", radius_pix)

# Observed Landsat series: on dates that have a masked file, the masked image (with its
# holes) is the observation; elsewhere the clear image is. Fitting on the clear stack at
# those dates would feed the very pixels used as evaluation truth into the model.
has_masked_file = np.array([d in masked_map for d in common_dates], dtype=bool)
landsat_observed = np.where(has_masked_file[:, None, None], landsat_masked_stack, landsat_clear_stack)

A = np.full((H,W), np.nan, dtype=float)    # per-pixel slope a
A0 = np.full((H,W), np.nan, dtype=float)   # per-pixel intercept a0
Mref_stack = np.full((T, H, W), np.nan, dtype=float)

# Per-pixel loop over the whole image (slow; tqdm shows progress over rows)
print("Estimating M_reference and linear transfer a,a0 per pixel (this may take time)...")
for i in tqdm(range(H)):
    for j in range(W):
        sl_target_ts = landsat_observed[:, i, j]  # SL target = observed Landsat series for this pixel
        Mref = compute_M_reference_pixel(i, j, modis_stack, sl_target_ts, radius_pix)
        if Mref is None:
            # fallback: local mean MODIS over window
            i0, i1 = max(0, i-radius_pix), min(H, i+radius_pix+1)
            j0, j1 = max(0, j-radius_pix), min(W, j+radius_pix+1)
            block = modis_stack[:, i0:i1, j0:j1]
            mref_alt = np.nanmean(block.reshape(T, -1), axis=1)
            if np.all(np.isnan(mref_alt)):
                continue
            Mref = mref_alt
        Mref_stack[:, i, j] = Mref
        est = estimate_linear_transfer(Mref, sl_target_ts)
        if est is None:
            continue
        a, a0 = est
        A[i,j] = a
        A0[i,j] = a0

# Compute SLM = a*Mref + a0 for all pixels at once; pixels whose a or a0 is NaN
# (no estimate) come out as NaN through NaN propagation, as do NaN Mref samples.
SLM = Mref_stack * A[None, :, :] + A0[None, :, :]

In [None]:
# === 8) Combine SLM and SL (use SL where present, else SLM) ===
# "SL present" must mean the OBSERVED Landsat data: on dates that have a masked file the
# masked image (with holes) is the observation, otherwise the clear image is. Filling from
# the clear stack on those dates would copy the evaluation truth into the reconstruction.
has_masked_file = np.array([d in masked_map for d in common_dates], dtype=bool)
landsat_observed = np.where(has_masked_file[:, None, None], landsat_masked_stack, landsat_clear_stack)
integrated = np.where(np.isnan(landsat_observed), SLM, landsat_observed)

In [None]:
# === 9) Whittaker smoothing per-pixel (kappa = 1, second-difference penalty) ===
def whittaker_smoother(y, kappa=1.0):
    """Whittaker-smooth a 1-D series that may contain NaNs.

    Each run of consecutive non-NaN samples is smoothed independently by solving
    (I + kappa * D^T D) z = y, with D the second-difference operator. Runs of
    one or two samples are copied unchanged; NaN positions stay NaN.
    """
    present = ~np.isnan(y)
    if not present.any():
        return np.full_like(y, np.nan)
    smoothed_series = np.full_like(y, np.nan, dtype=float)
    valid_positions = np.flatnonzero(present)
    # Split the valid indices wherever two neighbours are not adjacent
    breaks = np.flatnonzero(np.diff(valid_positions) != 1) + 1
    for run in np.split(valid_positions, breaks):
        start, stop = run[0], run[-1] + 1
        segment = y[start:stop].astype(float)
        length = segment.size
        if length <= 2:
            # Too short for a second difference: keep as is
            smoothed_series[start:stop] = segment
            continue
        # Second-difference operator, shape (length-2, length): rows [1, -2, 1]
        second_diff = np.diff(np.eye(length), n=2, axis=0)
        system = np.eye(length) + kappa * (second_diff.T @ second_diff)
        smoothed_series[start:stop] = np.linalg.solve(system, segment)
    return smoothed_series

print("Applying Whittaker smoothing per pixel (this may be slow)...")
# Same shape as `integrated`, float, NaN until a pixel has been processed
smoothed = np.full_like(integrated, np.nan, dtype=float)
for row in tqdm(range(H)):
    for col in range(W):
        pixel_series = integrated[:, row, col]
        smoothed[:, row, col] = whittaker_smoother(pixel_series, kappa=KAPPA)

In [None]:
# === 10) For the target masked images, compute metrics vs truth (landsat_clear_stack) ===
# Target dates are the dates for which a masked file exists.
target_indices = [ti for ti, d in enumerate(common_dates) if d in masked_map]
print("Found target masked dates at indices:", target_indices)
if len(target_indices) == 0:
    # fallback: first 6 dates whose masked image contains NaNs
    nontriv = [ti for ti in range(T) if np.any(np.isnan(landsat_masked_stack[ti]))]
    target_indices = nontriv[:6]
print("Evaluating on indices:", target_indices)


def _scale01(x, xmin=-0.2, xmax=1.0):
    """Clip NDVI to [xmin, xmax] and map it linearly to [0, 1] (NaNs stay NaN)."""
    return (np.clip(x, xmin, xmax) - xmin) / (xmax - xmin)


def _ms_ssim(img_a, img_b, data_range=1.0, sigma=1.5,
             weights=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
    """Multi-scale SSIM of two 2-D images with an 11-pixel Gaussian window.

    Contrast-structure terms are averaged at every scale and the luminance
    term is included at the coarsest scale only; scales are combined as a
    weighted product. Between scales both images are reduced by 2x2 block
    averaging. Pixels that are NaN in either image are set to the same value
    in both, so they contribute no difference. Returns np.nan when the images
    are too small for the coarsest scale to hold the 11-pixel window.
    """
    from scipy.ndimage import gaussian_filter

    a = np.asarray(img_a, dtype=float)
    b = np.asarray(img_b, dtype=float)
    missing = np.isnan(a) | np.isnan(b)
    a = np.where(missing, 0.0, a)
    b = np.where(missing, 0.0, b)

    n_levels = len(weights)
    if min(a.shape) < 11 * 2 ** (n_levels - 1):
        return np.nan

    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    def blur(img):
        # truncate=3.5 with sigma=1.5 gives a radius of 5 pixels (11-pixel window)
        return gaussian_filter(img, sigma, truncate=3.5)

    score = 1.0
    for level, weight in enumerate(weights):
        mu_a, mu_b = blur(a), blur(b)
        var_a = blur(a * a) - mu_a ** 2
        var_b = blur(b * b) - mu_b ** 2
        cov_ab = blur(a * b) - mu_a * mu_b
        contrast_structure = (2 * cov_ab + c2) / (var_a + var_b + c2)
        if level == n_levels - 1:
            luminance = (2 * mu_a * mu_b + c1) / (mu_a ** 2 + mu_b ** 2 + c1)
            term = np.mean(luminance * contrast_structure)
        else:
            term = np.mean(contrast_structure)
            # 2x2 block average (odd trailing row/column is dropped)
            h2, w2 = (a.shape[0] // 2) * 2, (a.shape[1] // 2) * 2
            a = a[:h2, :w2].reshape(h2 // 2, 2, w2 // 2, 2).mean(axis=(1, 3))
            b = b[:h2, :w2].reshape(h2 // 2, 2, w2 // 2, 2).mean(axis=(1, 3))
        # A negative mean cannot be raised to a fractional power: clamp at 0
        score *= max(term, 0.0) ** weight
    return float(score)


results = {}
for t in target_indices:
    truth = landsat_clear_stack[t]
    recon = smoothed[t]
    mask_was_missing = np.isnan(landsat_masked_stack[t])
    # restrict to pixels that were masked and have truth (non-NaN)
    eval_mask = mask_was_missing & ~np.isnan(truth)
    if eval_mask.sum() == 0:
        print("No eval pixels for date index", t)
        continue
    # The error metrics reject NaN inputs, so only pixels that were actually
    # reconstructed are scored; coverage is reported separately below.
    scored = eval_mask & ~np.isnan(recon)
    if scored.any():
        y_true = truth[scored].ravel()
        y_pred = recon[scored].ravel()
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        r2 = r2_score(y_true, y_pred)
        mae = mean_absolute_error(y_true, y_pred)
    else:
        rmse = r2 = mae = np.nan
    # MS-SSIM on the full images, NDVI clamped to -0.2..1 and mapped to 0..1
    msssim = _ms_ssim(_scale01(truth), _scale01(recon), data_range=1.0)
    percent_recon = np.count_nonzero(~np.isnan(recon[mask_was_missing])) / max(1, np.count_nonzero(mask_was_missing)) * 100.0
    results[dates[t]] = {'RMSE': float(rmse), 'R2': float(r2), 'MAE': float(mae), 'MS_SSIM': float(msssim), '%reconstruction': float(percent_recon)}
    print(dates[t], results[dates[t]])

In [None]:
# === 11) Save reconstructed images (the smoothed results) for the target dates to Drive ===
def write_tif(path, arr, meta_template):
    """Write a 2-D array as a single-band float32 GeoTIFF.

    The georeferencing comes from ``meta_template``; NaNs are stored as the
    nodata value -9999.
    """
    profile = {**meta_template, 'count': 1, 'dtype': 'float32', 'nodata': -9999}
    band = np.nan_to_num(arr, nan=-9999).astype(np.float32)
    with rasterio.open(path, 'w', **profile) as dataset:
        dataset.write(band, 1)

# One GeoTIFF per target date, georeferenced like the first clear image
for t in target_indices:
    outpath = os.path.join(OUTPUT_FOLDER, f'recon_{dates[t]}.tif')
    write_tif(outpath, smoothed[t], meta0)
    print("Saved", outpath)

# Save metrics CSV: one row per evaluated date
import csv
metric_columns = ['RMSE', 'R2', 'MAE', 'MS_SSIM', '%reconstruction']
csv_path = os.path.join(OUTPUT_FOLDER, 'metrics_results.csv')
with open(csv_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['date'] + metric_columns)
    writer.writerows(
        [d] + [m[column] for column in metric_columns]
        for d, m in results.items()
    )
print("Metrics saved to", csv_path)

print("Finished. Résultats et reconstructions disponibles dans:", OUTPUT_FOLDER)
