<a href="https://colab.research.google.com/github/larasauser/master/blob/main/Whitt_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# === 0) Installer / importer libs et monter Drive ===
!pip install rasterio tqdm scikit-image scikit-learn piq joblib --quiet

import os
import re
from glob import glob
from datetime import datetime
from tqdm import tqdm
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from scipy.linalg import lstsq
import warnings
warnings.filterwarnings('ignore')

import torch
import piq
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import pandas as pd

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# === 1) Paramètres principaux ===
DRIVE_CLEAR_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_Landsat8'
DRIVE_MASKED_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_Landsat8_hole'
DRIVE_MODIS_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_MODIS'
DRIVE_SENTINEL_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_S2'
OUTPUT_FOLDER = '/content/drive/MyDrive/Whitt/egfwh_grancy_outputs_ok'
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

WINDOW_METERS = 200.0   # Taille fenêtre pour pixels voisins
MODIS_WINDOW_DAYS = 32  # ± jours autour de chaque Landsat
KAPPA = 5.0             # Whittaker smoothing
MIN_OVERLAP = 3         # nb min de points valides pour corrélation


In [3]:
# === 2) Fonctions helpers ===
def parse_date_from_filename(fn):
    m = re.search(r'(\d{4}-\d{2}-\d{2})', os.path.basename(fn))
    if m:
        return datetime.strptime(m.group(1), '%Y-%m-%d').date()
    return datetime.fromtimestamp(os.path.getmtime(fn)).date()

def read_singleband_tif(path):
    with rasterio.open(path) as src:
        arr = src.read(1).astype(np.float32)
        meta = src.meta.copy()
    nd = meta.get('nodata', None)
    if nd is not None:
        arr[arr==nd] = np.nan
    return arr, meta

def resample_to_reference(src_arr, src_meta, target_meta):
    dst_arr = np.full((target_meta['height'], target_meta['width']), np.nan, dtype=np.float32)
    reproject(
        source=src_arr,
        destination=dst_arr,
        src_transform=src_meta['transform'],
        src_crs=src_meta['crs'],
        dst_transform=target_meta['transform'],
        dst_crs=target_meta['crs'],
        resampling=Resampling.bilinear
    )
    return dst_arr

def pearson_r(a, b):
    valid = ~np.isnan(a) & ~np.isnan(b)
    if valid.sum() < 2: return np.nan
    x, y = a[valid], b[valid]
    xm, ym = x.mean(), y.mean()
    num = np.sum((x - xm)*(y - ym))
    den = np.sqrt(np.sum((x - xm)**2)*np.sum((y - ym)**2))
    if den == 0: return np.nan
    return num / den

def estimate_linear_transfer(M_ref, SL_ts):
    valid = ~np.isnan(M_ref) & ~np.isnan(SL_ts)
    if valid.sum() < 2: return None
    A = np.vstack([M_ref[valid], np.ones(valid.sum())]).T
    y = SL_ts[valid]
    sol, *_ = lstsq(A, y)
    return sol[0], sol[1]

def whittaker_smoother(y, kappa=5.0):
    n = y.shape[0]
    mask = ~np.isnan(y)
    if mask.sum()==0: return np.full_like(y,np.nan)
    out = np.full_like(y,np.nan,dtype=float)
    idx = np.where(mask)[0]
    runs = np.split(idx, np.where(np.diff(idx)!=1)[0]+1)
    for run in runs:
        s,e = run[0], run[-1]+1
        seg = y[s:e].astype(float)
        m = len(seg)
        if m<=2: out[s:e]=seg; continue
        D = np.zeros((m-2,m))
        for r in range(m-2): D[r,r]=1; D[r,r+1]=-2; D[r,r+2]=1
        A_mat = np.eye(m)+kappa*(D.T@D)
        out[s:e] = np.linalg.solve(A_mat, seg)
    return out

In [4]:
# === 3) Charger Sentinel-2 comme référence ===
sentinel_files = sorted(glob(os.path.join(DRIVE_SENTINEL_FOLDER, '*.tif')))
sentinel_dates = [parse_date_from_filename(f) for f in sentinel_files]

# Charger Sentinel-2 et définir la grille référence
sentinel_stack = Parallel(n_jobs=-1)(
    delayed(lambda f: read_singleband_tif(f)[0])(f) for f in tqdm(sentinel_files)
)
with rasterio.open(sentinel_files[0]) as src:
    ref_meta = src.meta.copy()
    H_ref, W_ref = src.height, src.width
PIXEL_SIZE_M = abs(ref_meta['transform'][0])
sentinel_stack = np.stack(sentinel_stack, axis=0)

100%|██████████| 216/216 [00:39<00:00,  5.42it/s]


In [5]:
# === 4) Charger et resampler Landsat et MODIS vers grille Sentinel-2 ===
def load_and_resample_stack(files, dates, ref_meta):
    stack = []
    for f in files:
        arr, meta = read_singleband_tif(f)
        if arr.shape != (ref_meta['height'], ref_meta['width']):
            arr = resample_to_reference(arr, meta, ref_meta)
        stack.append(arr)
    stack = np.stack(stack, axis=0)
    return stack, dates

clear_files = sorted(glob(os.path.join(DRIVE_CLEAR_FOLDER, '*.tif')))
masked_files = sorted(glob(os.path.join(DRIVE_MASKED_FOLDER, '*.tif')))
clear_dates = [parse_date_from_filename(f) for f in clear_files]
masked_dates = [parse_date_from_filename(f) for f in masked_files]
all_dates = sorted(list(set(clear_dates + masked_dates)))

landsat_clear_stack, _ = load_and_resample_stack(clear_files, clear_dates, ref_meta)
landsat_masked_stack, _ = load_and_resample_stack(masked_files, masked_dates, ref_meta)

modis_files = sorted(glob(os.path.join(DRIVE_MODIS_FOLDER, '*.tif')))
modis_dates = [parse_date_from_filename(f) for f in modis_files]
modis_stack_full = Parallel(n_jobs=-1)(
    delayed(lambda f: resample_to_reference(*read_singleband_tif(f), ref_meta))(f)
    for f in tqdm(modis_files)
)
modis_stack_full = np.stack(modis_stack_full, axis=0)

100%|██████████| 493/493 [00:31<00:00, 15.41it/s]


In [6]:
# === 5) Construire stack Landsat complet sur timeline ===
ndvi_stack = np.full((len(all_dates), H_ref, W_ref), np.nan, dtype=np.float32)
for i, date in enumerate(all_dates):
    if date in clear_dates:
        idx = clear_dates.index(date)
        ndvi_stack[i] = landsat_clear_stack[idx]
    elif date in masked_dates:
        idx = masked_dates.index(date)
        ndvi_stack[i] = landsat_masked_stack[idx]

In [10]:
# === Paramètres ===
radius_pix = int(round(WINDOW_METERS / 2 / PIXEL_SIZE_M))
Mref_stack = np.full((len(all_dates), H_ref, W_ref), np.nan, dtype=np.float32)
A = np.full((H_ref, W_ref), np.nan, dtype=np.float32)
A0 = np.full((H_ref, W_ref), np.nan, dtype=np.float32)

# === Fonction sécurisée M_ref ===
def compute_M_reference_pixel(i, j, landsat_ts, all_dates,
                              modis_stack, modis_dates,
                              sentinel_stack, sentinel_dates,
                              radius_pix=3, corr_thresh=0.3, window_days=32):
    T = len(all_dates)
    Mref = np.full(T, np.nan)

    for t, landsat_date in enumerate(all_dates):
        # Dates MODIS/Sentinel proches
        modis_idx = [k for k, d in enumerate(modis_dates) if abs((d - landsat_date).days) <= window_days]
        sentinel_idx = [k for k, d in enumerate(sentinel_dates) if abs((d - landsat_date).days) <= window_days]

        if len(modis_idx) == 0 and len(sentinel_idx) == 0:
            continue

        # Pixels voisins
        i0, i1 = max(0, i - radius_pix), min(modis_stack.shape[1], i + radius_pix + 1)
        j0, j1 = max(0, j - radius_pix), min(modis_stack.shape[2], j + radius_pix + 1)

        values, corrs = [], []

        # MODIS
        for k in modis_idx:
            block = modis_stack[k, i0:i1, j0:j1].reshape(-1)
            valid = ~np.isnan(block) & ~np.isnan(landsat_ts[t])
            if valid.sum() < MIN_OVERLAP:
                continue
            r = pearson_r(block, np.full_like(block, landsat_ts[t]))
            if r >= corr_thresh:
                values.append(np.nanmean(block))
                corrs.append(r)

        # Sentinel
        for k in sentinel_idx:
            block = sentinel_stack[k, i0:i1, j0:j1].reshape(-1)
            valid = ~np.isnan(block) & ~np.isnan(landsat_ts[t])
            if valid.sum() < MIN_OVERLAP:
                continue
            r = pearson_r(block, np.full_like(block, landsat_ts[t]))
            if r >= corr_thresh:
                values.append(np.nanmean(block))
                corrs.append(r)

        if len(values) > 0:
            corrs = np.array(corrs)
            weights = (corrs - corrs.min()) / (corrs.max() - corrs.min()) if corrs.max() != corrs.min() else np.ones_like(corrs)
            weights /= weights.sum()
            Mref[t] = np.sum(np.array(values) * weights)

    return Mref

In [11]:
# === Fonction pixel pour calcul parallèle ===
def process_pixel(i,j):
    landsat_ts = ndvi_stack[:, i, j]
    Mref = compute_M_reference_pixel(i, j, landsat_ts, all_dates,
                                     modis_stack_full, modis_dates,
                                     sentinel_stack, sentinel_dates,
                                     radius_pix=radius_pix,
                                     corr_thresh=0.3,
                                     window_days=MODIS_WINDOW_DAYS)
    est = estimate_linear_transfer(Mref, landsat_ts)
    return i, j, Mref, est

# === Calcul parallèle ===
results = Parallel(n_jobs=-1, verbose=10)(
    delayed(process_pixel)(i,j) for i in range(H_ref) for j in range(W_ref)
)

# === Remplir stacks ===
for i,j,Mref,est in results:
    Mref_stack[:, i, j] = Mref
    if est is not None:
        A[i,j], A0[i,j] = est

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:    3.7s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:    4.1s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:    5.1s
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:    6.0s
[Parallel(n_jobs=-1)]: Done  21 tasks      | elapsed:    8.1s
[Parallel(n_jobs=-1)]: Done  28 tasks      | elapsed:   10.0s
[Parallel(n_jobs=-1)]: Done  37 tasks      | elapsed:   12.2s
[Parallel(n_jobs=-1)]: Done  46 tasks      | elapsed:   13.8s
[Parallel(n_jobs=-1)]: Done  57 tasks      | elapsed:   16.0s
[Parallel(n_jobs=-1)]: Done  68 tasks      | elapsed:   18.0s
[Parallel(n_jobs=-1)]: Done  81 tasks      | elapsed:   22.4s
[Parallel(n_jobs=-1)]: Done  94 tasks      | elapsed:   26.7s
[Parallel(n_jobs=-1)]: Done 109 tasks      | elapsed:   31.6s
[Parallel(n_jobs=-1)]: Done 124 tasks      | elapsed:   34.9s
[Parallel(n_jobs=-1)]: Done 141 tasks      | elapsed:   

KeyboardInterrupt: 

In [None]:
# === Reconstruction Landsat sur timeline ===
SLM_stack = np.full_like(Mref_stack, np.nan, dtype=np.float32)
for i in range(H_ref):
    for j in range(W_ref):
        if not np.isnan(A[i,j]) and not np.isnan(A0[i,j]):
            SLM_stack[:, i, j] = Mref_stack[:, i,j]*A[i,j] + A0[i,j]

In [None]:
# === Injection des valeurs reconstruites dans les gaps + Whittaker smoothing ===
final_stack = ndvi_stack.copy()
for t, date in enumerate(all_dates):
    gap_mask = np.isnan(final_stack[t])
    smoothed_gap = np.full((H_ref, W_ref), np.nan)
    for i in range(H_ref):
        for j in range(W_ref):
            if gap_mask[i,j]:
                smoothed_gap[i,j] = whittaker_smoother(SLM_stack[:, i,j], kappa=KAPPA)[t]
    final_stack[t, gap_mask] = smoothed_gap[gap_mask]

In [None]:
def write_tif_nan(path, arr, meta_template):
    """
    Sauvegarde un array 2D en GeoTIFF, avec NaN conservés.
    """
    meta = meta_template.copy()
    meta.update({
        'count': 1,
        'dtype': 'float32',
        'nodata': None  # on laisse les NaN tels quels
    })
    with rasterio.open(path, 'w', **meta) as dst:
        dst.write(arr.astype(np.float32), 1)

In [None]:
# === Sauvegarde GeoTIFFs ===
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
for t, date in enumerate(all_dates):
    outpath = os.path.join(OUTPUT_FOLDER, f'recon_{date}.tif')
    write_tif_nan(outpath, final_stack[t], ref_meta)
    print("Saved", outpath)

print("✅ Reconstruction complète avec Sentinel-2 et parallélisation terminée.")