<a href="https://colab.research.google.com/github/larasauser/master/blob/main/Whitt_8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Mon Nov  3 10:18:06 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   65C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [1]:
# === 0) Installer / importer libs et monter Drive ===
!pip install rasterio tqdm scikit-image scikit-learn piq joblib --quiet

import os
from glob import glob
from datetime import datetime
from tqdm import tqdm
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from scipy.linalg import lstsq
import warnings
warnings.filterwarnings('ignore')

import torch
import piq
from joblib import Parallel, delayed

# --- Device selection: use CUDA when PyTorch can see a GPU, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)
if use_cuda:
    # Card name plus the current allocator statistics of device 0, in MB (bytes / 1e6).
    print("GPU name:", torch.cuda.get_device_name(0))
    allocated_mb = round(torch.cuda.memory_allocated(0) / 1e6, 1)
    reserved_mb = round(torch.cuda.memory_reserved(0) / 1e6, 1)
    print("Memory allocated (MB):", allocated_mb)
    print("Memory cached (MB):", reserved_mb)

# --- Mount Google Drive (Colab only) at /content/drive
from google.colab import drive
drive.mount('/content/drive')

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/22.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.4/22.3 MB[0m [31m12.7 MB/s[0m eta [36m0:00:02[0m[2K   [91m━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.7/22.3 MB[0m [31m54.0 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━[0m [32m12.3/22.3 MB[0m [31m230.6 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m22.2/22.3 MB[0m [31m283.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m22.3/22.3 MB[0m [31m273.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.3/22.3 MB[0m [31m125.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/106.9 kB[0m [31m?[0m eta [36m-:

In [22]:
# === 1) Main parameters ===
# Input folders on the mounted Drive; the loading cells glob each of them for '*.tif'.
DRIVE_CLEAR_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_Landsat8'
DRIVE_MASKED_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_Landsat8_hole'
DRIVE_MODIS_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_MODIS'
DRIVE_SENTINEL_FOLDER = '/content/drive/MyDrive/Whitt/NDVI_grancy_S2'
# Destination of the reconstructed GeoTIFFs; created here if it does not exist yet.
OUTPUT_FOLDER = '/content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu'
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Tuning constants. In the cells visible here only KAPPA is referenced (it is passed
# to the Whittaker smoother); the other three are defined but not used.
# NOTE(review): units/meaning are inferred from the names only — confirm.
WINDOW_METERS = 200.0
MODIS_WINDOW_DAYS = 32
KAPPA = 5.0
MIN_OVERLAP = 3

In [23]:
# ===============================
# 2) Fonctions helpers
# ===============================
import re

def parse_date_from_filename(fn):
    """Extract the first ``YYYY-MM-DD`` date found in the base name of ``fn``.

    Only the file name is searched, never the parent directories.

    Returns
    -------
    datetime.date or None
        The parsed date, or ``None`` when the base name contains no such pattern.

    Raises
    ------
    ValueError
        If the matched digits do not form a valid calendar date.
    """
    found = re.search(r'(\d{4}-\d{2}-\d{2})', os.path.basename(fn))
    if found is None:
        return None
    stamp = datetime.strptime(found.group(1), '%Y-%m-%d')
    return stamp.date()

def read_singleband_tif(path):
    """Read band 1 of a raster as float32, replacing the declared nodata value by NaN.

    Parameters
    ----------
    path : path of a raster that ``rasterio.open`` can read.

    Returns
    -------
    (numpy.ndarray, dict)
        The band as a float32 array and a copy of the dataset's ``meta`` mapping.
        When the dataset declares no nodata value the pixels are returned as read.
    """
    with rasterio.open(path) as dataset:
        band = dataset.read(1).astype(np.float32)
        profile = dataset.meta.copy()
    nodata_value = profile.get('nodata', None)
    if nodata_value is None:
        return band, profile
    band[band == nodata_value] = np.nan
    return band, profile

def resample_to_target(src_arr, src_meta, target_meta):
    """Warp a single-band array onto the grid described by ``target_meta``.

    Parameters
    ----------
    src_arr : 2-D array to resample; missing pixels are expected to be NaN.
    src_meta : mapping providing the source ``'transform'`` and ``'crs'``.
    target_meta : mapping providing the target ``'height'``, ``'width'``,
        ``'transform'`` and ``'crs'``.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(height, width)``, bilinearly resampled.
        Pixels with no valid source data are NaN.
    """
    # NaN can only act as the nodata marker on a floating-point source.
    src_arr = np.asarray(src_arr, dtype=np.float32)
    dst_arr = np.full((target_meta['height'], target_meta['width']), np.nan, dtype=np.float32)
    reproject(
        source=src_arr,
        destination=dst_arr,
        src_transform=src_meta['transform'],
        src_crs=src_meta['crs'],
        # Declare NaN as nodata so gaps are excluded from the bilinear kernel
        # instead of contaminating their valid neighbours.
        src_nodata=np.nan,
        dst_transform=target_meta['transform'],
        dst_crs=target_meta['crs'],
        # Without dst_nodata the destination is initialised to 0 (GDAL default),
        # which would overwrite the NaN pre-fill wherever the source has no coverage.
        dst_nodata=np.nan,
        resampling=Resampling.bilinear
    )
    return dst_arr

def write_tif_nan(path, arr, meta_template):
    """Write ``arr`` as a single-band float32 raster at ``path``.

    The raster profile is a copy of ``meta_template`` with ``count`` forced to 1,
    ``dtype`` to ``'float32'`` and ``nodata`` to ``None``; NaN pixels are therefore
    written as-is. ``meta_template`` itself is not modified.
    """
    profile = meta_template.copy()
    profile['count'] = 1
    profile['dtype'] = 'float32'
    profile['nodata'] = None
    with rasterio.open(path, 'w', **profile) as sink:
        sink.write(arr.astype(np.float32), 1)

In [24]:
# ===============================
# 3) Charger et aligner Landsat
# ===============================
# Sorted lists of the '*.tif' files found in the clear and masked Landsat folders.
clear_files = sorted(glob(os.path.join(DRIVE_CLEAR_FOLDER, '*.tif')))
masked_files = sorted(glob(os.path.join(DRIVE_MASKED_FOLDER, '*.tif')))

# Find the smallest image (fewest pixels) and use it as the reference grid
sizes = []
for f in clear_files + masked_files:
    # Every raster is read in full here just to learn its shape.
    arr, meta = read_singleband_tif(f)
    sizes.append((arr.shape[0]*arr.shape[1], f, arr.shape, meta))
# Tuples sort by pixel count first, then by path (ties never reach the meta dict
# because paths are distinct).
sizes.sort()
# NOTE(review): raises IndexError when both folders are empty.
_, ref_file, ref_shape, ref_meta = sizes[0]
H_ref, W_ref = ref_shape
# First coefficient of the affine transform, i.e. the pixel width in CRS units
# (presumably metres, given the name — confirm against the CRS).
PIXEL_SIZE_M = abs(ref_meta['transform'][0])
print("Reference image:", ref_file, "shape:", ref_shape)

# Full timeline: one entry per file, so a date present in both folders appears twice.
# NOTE(review): a file name without a YYYY-MM-DD date yields None, which sorted()
# cannot order against real dates.
all_files = clear_files + masked_files
all_dates = sorted([parse_date_from_filename(f) for f in all_files])

# Charger les stacks Landsat
def load_stack(files, all_dates):
    """Build a time-ordered stack with one layer per entry of ``all_dates``.

    Parameters
    ----------
    files : iterable of raster paths; the date of each is taken from its file name
        with ``parse_date_from_filename``. When several files share a date, the
        first one in ``files`` is used.
    all_dates : sequence of dates defining the layers, in order.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(all_dates),) + ref_shape``. Dates without a matching
        file give an all-NaN float32 layer; rasters whose shape differs from
        ``ref_shape`` are resampled onto the reference grid.

    Notes
    -----
    Reads the notebook globals ``ref_shape`` and ``ref_meta``.
    """
    # Parse each file name once instead of once per (date, file) pair.
    first_file_by_date = {}
    for path in files:
        first_file_by_date.setdefault(parse_date_from_filename(path), path)

    stack = []
    for d in all_dates:
        path = first_file_by_date.get(d)
        if path is None:
            arr = np.full(ref_shape, np.nan, dtype=np.float32)
        else:
            arr, meta = read_singleband_tif(path)
            # NOTE(review): only the shape is compared; a raster with the reference
            # shape but a different transform/CRS is stacked without resampling.
            if arr.shape != ref_shape:
                arr = resample_to_target(arr, meta, ref_meta)
        stack.append(arr)
    return np.stack(stack, axis=0)

# Stack every Landsat file (clear first, then masked) along the time axis.
ndvi_stack = load_stack(clear_files + masked_files, all_dates)  # shape: T x H x W
# Number of time steps in the stack.
T = len(all_dates)

Reference image: /content/drive/MyDrive/Whitt/NDVI_grancy_Landsat8/NDVI_2013-06-05.tif shape: (344, 319)


In [25]:
# ===============================
# 4) Charger MODIS (pas de moyenne)
# ===============================
# Sorted MODIS rasters and the date parsed from each file name.
# modis_dates is not used by the cells visible in this notebook.
modis_files = sorted(glob(os.path.join(DRIVE_MODIS_FOLDER, '*.tif')))
modis_dates = [parse_date_from_filename(f) for f in modis_files]

modis_stack = []
for f in modis_files:
    arr, meta = read_singleband_tif(f)
    # Bring rasters whose shape differs from the reference onto the reference grid.
    if arr.shape != ref_shape:
        arr = resample_to_target(arr, meta, ref_meta)
    modis_stack.append(arr)
# NOTE(review): np.stack raises ValueError when the folder contains no '*.tif'.
modis_stack_full = np.stack(modis_stack, axis=0)  # shape: M x H x W

In [27]:
# ===============================
# 5) Charger Sentinel-2
# ===============================
# Sorted Sentinel-2 rasters and the date parsed from each file name.
# sentinel_dates is not used by the cells visible in this notebook.
sentinel_files = sorted(glob(os.path.join(DRIVE_SENTINEL_FOLDER, '*.tif')))
sentinel_dates = [parse_date_from_filename(f) for f in sentinel_files]

sentinel_stack = []
for f in sentinel_files:
    arr, meta = read_singleband_tif(f)
    # Bring rasters whose shape differs from the reference onto the reference grid.
    if arr.shape != ref_shape:
        arr = resample_to_target(arr, meta, ref_meta)
    sentinel_stack.append(arr)
# NOTE(review): np.stack raises ValueError when the folder contains no '*.tif'.
sentinel_stack_full = np.stack(sentinel_stack, axis=0)  # S x H x W

In [28]:
# ===============================
# 6) Convertir les stacks en GPU tensors
# ===============================
# Wrap each NumPy stack as a float32 tensor and move it to the selected device.
ndvi_stack_t = torch.from_numpy(ndvi_stack).float().to(device)
modis_stack_t = torch.from_numpy(modis_stack_full).float().to(device)
sentinel_stack_t = torch.from_numpy(sentinel_stack_full).float().to(device)

In [29]:
# ===============================
# 7) Calcul vectorisé M_ref et coefficients linéaires
# ===============================
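# NOTE: the intended design weights M_ref by its correlation with MODIS + Sentinel
# and then applies a GPU-ready Whittaker smoothing. For now, as a simplification,
# the function below is a placeholder that returns a plain NaN-aware temporal mean.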
def compute_Mref_vectorized(ndvi_ts, modis_stack, sentinel_stack, corr_thresh=0.3):
    """Return the NaN-aware mean over all MODIS and Sentinel layers (placeholder).

    Parameters
    ----------
    ndvi_ts : accepted for the planned correlation weighting; currently unused.
    modis_stack, sentinel_stack : tensors concatenated along dim 0, so their
        remaining dimensions must agree (M x H x W and S x H x W).
    corr_thresh : accepted for the planned correlation weighting; currently unused.

    Returns
    -------
    torch.Tensor
        H x W tensor: per-pixel mean over the (M + S) layers, ignoring NaN.
    """
    # TODO: replace the plain mean with the real correlation-based weighting.
    auxiliary = torch.cat((modis_stack, sentinel_stack), dim=0)  # (M+S) x H x W
    return auxiliary.nanmean(dim=0)

In [30]:
# ===============================
# 8) Whittaker smoothing GPU-ready
# ===============================
def whittaker_smoother_torch(y, kappa=5.0):
    """Placeholder for the Whittaker smoother: returns an unchanged copy of ``y``.

    Parameters
    ----------
    y : tensor (T x H x W in this notebook); NaN marks invalid samples.
    kappa : smoothing strength reserved for the real implementation; currently unused.

    Returns
    -------
    torch.Tensor
        A new tensor with the same values as ``y``; NaN entries stay NaN.
        No smoothing is performed yet.
    """
    # The previous per-time-step loop re-assigned the NaN entries of an exact clone,
    # which changed nothing; the clone alone yields the same result.
    # TODO: solve the penalised least-squares system so that kappa takes effect.
    return y.clone()

In [31]:
# ===============================
# 9) Reconstruction finale
# ===============================
# Reference image M_ref (H x W) computed from the MODIS and Sentinel stacks.
Mref_gpu = compute_Mref_vectorized(ndvi_stack_t, modis_stack_t, sentinel_stack_t)
# Simple linear reconstruction: SL = M_ref (placeholder, to be adapted).
# The same M_ref image is repeated for each of the T dates.
SLM_gpu = Mref_gpu.unsqueeze(0).repeat(T,1,1)

# Fill the gaps: NaN pixels of the Landsat stack take the value of SLM at that pixel
final_stack_t = ndvi_stack_t.clone()
gap_mask = torch.isnan(ndvi_stack_t)
final_stack_t[gap_mask] = SLM_gpu[gap_mask]

# Apply Whittaker smoothing
final_stack_t = whittaker_smoother_torch(final_stack_t, kappa=KAPPA)

In [32]:
# ===============================
# 10) Export GeoTIFF
# ===============================
# Bring the result back to host memory as a NumPy array.
final_stack = final_stack_t.cpu().numpy()
# Write one GeoTIFF per timeline entry, using the reference image's metadata.
# The file name embeds str(date), so repeated dates write to the same path.
for t, date in enumerate(all_dates):
    outpath = os.path.join(OUTPUT_FOLDER, f'recon_{date}.tif')
    write_tif_nan(outpath, final_stack[t], ref_meta)
    print("Saved", outpath)

Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-04-18.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-04-25.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-05-27.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-06-05.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-06-12.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-07-07.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-07-14.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-08-15.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-08-31.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-09-25.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-10-18.tif
Saved /content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu/recon_2013-11-12.tif
Saved /content/drive/MyDrive

In [33]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from rasterio.warp import reproject, Resampling
import rasterio
import os
from glob import glob

# === Folders ===
hole_dir = '/content/drive/MyDrive/Whitt/NDVI_grancy_Landsat8_hole'
gt_dir = '/content/drive/MyDrive/Whitt/NDVI_grancy_full'
# Must be the folder the export cell writes to (OUTPUT_FOLDER, which ends in
# lowercase '_gpu'); the previous '_GPU' spelling pointed at a different path.
recon_dir = '/content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu'

# === File lists (sorted by path) ===
hole_files = sorted(glob(os.path.join(hole_dir, '*.tif')))
gt_files = sorted(glob(os.path.join(gt_dir, '*.tif')))
recon_files = sorted(glob(os.path.join(recon_dir, 'recon_*.tif')))

# === Fonctions utilitaires ===
def read_tif_with_meta(path):
    """Read band 1 of a GeoTIFF as float32 and turn nodata pixels into NaN.

    Parameters
    ----------
    path : path of a raster that ``rasterio.open`` can read.

    Returns
    -------
    (numpy.ndarray, dict)
        The band as a float32 array and a copy of the dataset's ``meta`` mapping.
        Pixels equal to the declared nodata value become NaN; when the file
        declares none, -9999 is treated as nodata.
    """
    with rasterio.open(path) as src:
        arr = src.read(1).astype(np.float32)
        meta = src.meta.copy()
    # rasterio always stores a 'nodata' key (None when undeclared), so the default
    # of dict.get never applied; fall back to -9999 explicitly.
    nodata = meta.get('nodata')
    if nodata is None:
        nodata = -9999
    arr[arr == nodata] = np.nan
    return arr, meta

def resample_to_match(source_arr, source_meta, target_meta):
    """Resample an image onto the grid of a target image.

    Parameters
    ----------
    source_arr : 2-D array to resample; missing pixels are expected to be NaN.
    source_meta : mapping providing the source ``'transform'`` and ``'crs'``.
    target_meta : mapping providing the target ``'height'``, ``'width'``,
        ``'transform'`` and ``'crs'``.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(height, width)``, bilinearly resampled.
        Pixels with no valid source data are NaN.
    """
    # NaN can only act as the nodata marker on a floating-point source.
    source_arr = np.asarray(source_arr, dtype=np.float32)
    # Start from NaN rather than uninitialised memory.
    dst_arr = np.full((target_meta['height'], target_meta['width']), np.nan, dtype=np.float32)
    reproject(
        source=source_arr,
        destination=dst_arr,
        src_transform=source_meta['transform'],
        src_crs=source_meta['crs'],
        # Exclude NaN gaps from the bilinear kernel instead of spreading them.
        src_nodata=np.nan,
        dst_transform=target_meta['transform'],
        dst_crs=target_meta['crs'],
        # Without dst_nodata, uncovered pixels are initialised to 0 (GDAL default)
        # and would show up as valid data.
        dst_nodata=np.nan,
        resampling=Resampling.bilinear
    )
    return dst_arr

# === Visualization ===
def _date_token(path):
    """Return the last '_'-separated token of the file name, without '.tif'."""
    return os.path.basename(path).split('_')[-1].replace('.tif', '')

# Index the GT and reconstruction files by date. The three folders do not hold the
# same set of dates, so pairing the sorted lists by position would mix up dates.
gt_by_date = {}
for _path in gt_files:
    gt_by_date.setdefault(_date_token(_path), _path)
recon_by_date = {}
for _path in recon_files:
    recon_by_date.setdefault(_date_token(_path), _path)

# Work on a private copy so the red NaN colour does not leak into the shared 'gray' colormap.
cmap_gray_red = plt.cm.gray.copy()
cmap_gray_red.set_bad(color='red')  # NaN shown in red

for hole_path in hole_files:
    fname = os.path.basename(hole_path)
    date = fname.replace('NDVI_', '').replace('.tif', '')
    date_key = _date_token(hole_path)

    if date_key not in gt_by_date:
        print(f"⚠️ Pas de GT pour {date}, skip")
        continue
    if date_key not in recon_by_date:
        print(f"⚠️ Pas de reconstruction pour {date}, skip")
        continue

    # Read the image with holes
    hole, hole_meta = read_tif_with_meta(hole_path)

    # Read GT and reconstruction for the same date and align them on the hole grid
    gt, gt_meta = read_tif_with_meta(gt_by_date[date_key])
    gt_resampled = resample_to_match(gt, gt_meta, hole_meta)

    recon, recon_meta = read_tif_with_meta(recon_by_date[date_key])
    recon_resampled = resample_to_match(recon, recon_meta, hole_meta)

    # Gap mask (NaN in hole)
    gap_mask = np.isnan(hole)

    # Reconstruction restricted to the gaps
    recon_only_gaps = np.full_like(hole, np.nan)
    recon_only_gaps[gap_mask] = recon_resampled[gap_mask]

    # Error restricted to gap pixels where both GT and reconstruction are valid
    err = np.full_like(hole, np.nan)
    valid_mask = gap_mask & ~np.isnan(gt_resampled) & ~np.isnan(recon_only_gaps)
    err[valid_mask] = gt_resampled[valid_mask] - recon_only_gaps[valid_mask]

    # === Figure ===
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))

    # 1) GAP
    axs[0].imshow(np.ma.masked_invalid(hole), cmap=cmap_gray_red, vmin=-1, vmax=1)
    axs[0].set_title(f'GAP - {date}', fontweight='bold')
    axs[0].axis('off')

    # 2) GT
    axs[1].imshow(np.ma.masked_invalid(gt_resampled), cmap=cmap_gray_red, vmin=-1, vmax=1)
    axs[1].set_title(f'GT - {date}', fontweight='bold')
    axs[1].axis('off')

    # 3) Reconstruction (inside the holes only)
    axs[2].imshow(np.ma.masked_invalid(recon_only_gaps), cmap=cmap_gray_red, vmin=-1, vmax=1)
    axs[2].set_title(f'RE GAPS - {date}', fontweight='bold')
    axs[2].axis('off')

    # 4) Error (GT - RE) on the gaps
    im4 = axs[3].imshow(np.ma.masked_invalid(err), cmap='RdYlGn', vmin=-1, vmax=1)
    axs[3].set_title('ERR (GT - RE) sur gaps', fontweight='bold')
    axs[3].axis('off')
    plt.colorbar(im4, ax=axs[3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure when looping over many dates

In [34]:
import os

# Extract the date strings from the file names (strip the prefix and the '.tif' suffix)
hole_dates = [os.path.basename(f).replace('NDVI_', '').replace('.tif','') for f in hole_files]
recon_dates = [os.path.basename(f).replace('recon_', '').replace('.tif','') for f in recon_files]

# Hole images that have a reconstruction with the same date string
reconstructed = [d for d in hole_dates if d in recon_dates]

# Hole images without any reconstruction
missing = [d for d in hole_dates if d not in recon_dates]

# Both lists print empty only when hole_files itself is empty.
print("✅ Images reconstruites :", reconstructed)
print("⚠️ Images manquantes :", missing)

✅ Images reconstruites : []
⚠️ Images manquantes : []


In [35]:
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
import os
from glob import glob

# === Folders ===
hole_dir = '/content/drive/MyDrive/Whitt/NDVI_grancy_Landsat8_hole'
gt_dir = '/content/drive/MyDrive/Whitt/NDVI_grancy_full'
# Must be the folder the export cell writes to (OUTPUT_FOLDER, which ends in
# lowercase '_gpu'); the previous '_GPU' spelling pointed at a different path.
recon_dir = '/content/drive/MyDrive/Whitt/egfwh_grancy_outputs_gpu'

# === File lists (sorted by path) ===
hole_files = sorted(glob(os.path.join(hole_dir, '*.tif')))
gt_files   = sorted(glob(os.path.join(gt_dir, '*.tif')))
recon_files= sorted(glob(os.path.join(recon_dir, 'recon_*.tif')))

# === Fonctions utilitaires ===
def read_tif_with_meta(path):
    """Read band 1 of a GeoTIFF as float32 and turn nodata pixels into NaN.

    Parameters
    ----------
    path : path of a raster that ``rasterio.open`` can read.

    Returns
    -------
    (numpy.ndarray, dict)
        The band as a float32 array and a copy of the dataset's ``meta`` mapping.
        Pixels equal to the declared nodata value become NaN; when the file
        declares none, -9999 is treated as nodata.
    """
    with rasterio.open(path) as src:
        arr = src.read(1).astype(np.float32)
        meta = src.meta.copy()
    # rasterio always stores a 'nodata' key (None when undeclared), so the default
    # of dict.get never applied; fall back to -9999 explicitly.
    nodata = meta.get('nodata')
    if nodata is None:
        nodata = -9999
    arr[arr == nodata] = np.nan
    return arr, meta

def resample_to_match(source_arr, source_meta, target_meta):
    """Resample an image onto the grid of a target image.

    Parameters
    ----------
    source_arr : 2-D array to resample; missing pixels are expected to be NaN.
    source_meta : mapping providing the source ``'transform'`` and ``'crs'``.
    target_meta : mapping providing the target ``'height'``, ``'width'``,
        ``'transform'`` and ``'crs'``.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(height, width)``, bilinearly resampled.
        Pixels with no valid source data are NaN.
    """
    # NaN can only act as the nodata marker on a floating-point source.
    source_arr = np.asarray(source_arr, dtype=np.float32)
    # Start from NaN rather than uninitialised memory.
    dst_arr = np.full((target_meta['height'], target_meta['width']), np.nan, dtype=np.float32)
    reproject(
        source=source_arr,
        destination=dst_arr,
        src_transform=source_meta['transform'],
        src_crs=source_meta['crs'],
        # Exclude NaN gaps from the bilinear kernel instead of spreading them.
        src_nodata=np.nan,
        dst_transform=target_meta['transform'],
        dst_crs=target_meta['crs'],
        # Without dst_nodata, uncovered pixels are initialised to 0 (GDAL default)
        # and would show up as valid data.
        dst_nodata=np.nan,
        resampling=Resampling.bilinear
    )
    return dst_arr

def extract_date(fname):
    """Return the text after the last '_' of the file name, with '.tif' removed.

    Directories are ignored; a name without '_' is returned whole (minus '.tif').
    """
    base_name = os.path.basename(fname)
    last_token = base_name.rsplit('_', 1)[-1]
    return last_token.replace('.tif', '')

# Date string of every file, in the same order as the corresponding file list.
hole_dates = [extract_date(f) for f in hole_files]
gt_dates   = [extract_date(f) for f in gt_files]
recon_dates= [extract_date(f) for f in recon_files]

# Work on a private copy so the red NaN colour does not leak into the shared 'gray' colormap.
cmap_gray_red = plt.cm.gray.copy()
cmap_gray_red.set_bad(color='red')  # NaN shown in red

# === Visualization ===
for i, date in enumerate(hole_dates):
    hole, hole_meta = read_tif_with_meta(hole_files[i])

    # Matching GT. Only the lookup is guarded, so a ValueError raised while
    # reading or reprojecting is not misreported as a missing file.
    try:
        gt_idx = gt_dates.index(date)
    except ValueError:
        print(f"⚠️ Pas de GT pour {date}, skip")
        continue
    gt, gt_meta = read_tif_with_meta(gt_files[gt_idx])
    gt_resampled = resample_to_match(gt, gt_meta, hole_meta)

    # Matching reconstruction
    try:
        recon_idx = recon_dates.index(date)
    except ValueError:
        print(f"⚠️ Pas de reconstruction pour {date}, skip")
        continue
    recon, recon_meta = read_tif_with_meta(recon_files[recon_idx])
    recon_resampled = resample_to_match(recon, recon_meta, hole_meta)

    # Gap mask (NaN in hole)
    gap_mask = np.isnan(hole)

    # Error restricted to gap pixels where both GT and reconstruction are valid
    err = np.full_like(hole, np.nan)
    valid_mask = gap_mask & ~np.isnan(gt_resampled) & ~np.isnan(recon_resampled)
    err[valid_mask] = gt_resampled[valid_mask] - recon_resampled[valid_mask]

    # === Figure ===
    fig, axs = plt.subplots(1, 4, figsize=(16,4))

    axs[0].imshow(np.ma.masked_invalid(hole), cmap=cmap_gray_red, vmin=-1, vmax=1)
    axs[0].set_title(f'GAP - {date}', fontweight='bold'); axs[0].axis('off')

    axs[1].imshow(np.ma.masked_invalid(gt_resampled), cmap=cmap_gray_red, vmin=-1, vmax=1)
    axs[1].set_title(f'GT - {date}', fontweight='bold'); axs[1].axis('off')

    axs[2].imshow(np.ma.masked_invalid(recon_resampled), cmap=cmap_gray_red, vmin=-1, vmax=1)
    axs[2].set_title(f'RE - {date}', fontweight='bold'); axs[2].axis('off')

    # The error panel is auto-scaled (no vmin/vmax), unlike the NDVI panels.
    im4 = axs[3].imshow(np.ma.masked_invalid(err), cmap='RdYlGn')
    axs[3].set_title('ERR (GT - RE)', fontweight='bold'); axs[3].axis('off')
    plt.colorbar(im4, ax=axs[3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure when looping over many dates