## Import libraries


In [None]:
import os
from glob import glob

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.auto import tqdm

# Typography used by every figure drawn in this notebook.
plt.rcParams.update(
    {
        "font.family": "DeJavu Serif",
        "font.serif": "Times New Roman",
    }
)

import warnings

# Suppress every warning for the remainder of the session.
warnings.filterwarnings(action="ignore")

# Project layout: the repository root doubles as the working directory, raw
# inputs live under MAIN_DATA_DIR and results go to <WORK_DIR>/output.
WORK_DIR = "/beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/"
MAIN_DATA_DIR = "/beegfs/halder/DATA/"
OUT_DIR = os.path.join(WORK_DIR, "output")
os.chdir(WORK_DIR)

In [None]:
# Crop processed by this notebook. It must be one of the keys of `crop_dict`
# (used to pick the crop-type class code) and it also names the per-crop
# folders under data/interim/remote_sensing.
CROP = "silage_maize"

## Read the dataset


In [None]:
# Read the DE NUTS shapefile once. Both the NUTS1 and the NUTS3 layers used
# below are selected from this single file through its LEVL_CODE column, so
# there is no need to parse it twice.
de_nuts_gdf = gpd.read_file(os.path.join(MAIN_DATA_DIR, "DE_NUTS", "DE_NUTS_3.shp"))

# NUTS1 rows, with the ID/name columns renamed to STATE_ID / STATE_NAME.
# `rename` returns a new frame, which avoids an in-place edit of a filtered slice.
de_nuts1_gdf = de_nuts_gdf[de_nuts_gdf["LEVL_CODE"] == 1].rename(
    columns={"NUTS_ID": "STATE_ID", "NUTS_NAME": "STATE_NAME"}
)

# NUTS3 rows; explicit copy so later cells work on an independent frame.
de_nuts3_gdf = de_nuts_gdf[de_nuts_gdf["LEVL_CODE"] == 3].copy()

# Overview map: NUTS3 units coloured by name, NUTS1 outlines drawn on top.
fig, ax = plt.subplots(figsize=(8, 8))
de_nuts3_gdf.plot(
    ax=ax,
    column="NUTS_NAME",
    cmap="Set3",
    edgecolor="grey",
    linewidth=0.5,
    label="NUTS3",
)
de_nuts1_gdf.plot(ax=ax, facecolor="none", edgecolor="k", linewidth=1, label="NUTS1")
plt.show()

print(de_nuts1_gdf.shape, de_nuts3_gdf.shape)
de_nuts3_gdf.head()

## Extract the remote sensing variables (NDVI, EVI, fPAR, and LAI)


### Initialize and setup Earth Engine


In [None]:
import ee
import geemap
from geeagri.extract import TimeseriesExtractor

# One-time interactive login; uncomment when no Earth Engine credentials are
# stored on this machine yet.
# ee.Authenticate()
# Start the Earth Engine session under the "ee-geonextgis" project.
ee.Initialize(project="ee-geonextgis")

# Instantiate a Map object (satellite basemap); layers are added to it in later cells.
Map = geemap.Map(basemap="SATELLITE")
# Bare expression so the notebook renders the map as the cell output.
Map

### Extract NDVI and EVI


In [None]:
# Class code of each supported crop in the yearly crop-type rasters; the entry
# for CROP is the one compared against the raster values below.
crop_dict = {
    "winter_wheat": 1110,
    "winter_barley": 1120,
    "silage_maize": 1130,
    "winter_rapeseed": 1430,
    "winter_rye": 1150,
}

de_nuts3_ee = ee.FeatureCollection("projects/ee-geonextgis/assets/DE/DE_NUTS_3")


def _crop_presence(year):
    """Return the crop-type image of `year` masked to the pixels of CROP.

    The first band of the yearly asset is compared with the class code of the
    selected crop; `selfMask` then masks every pixel where the comparison is 0.
    """
    return (
        ee.Image(f"projects/ee-geonextgis/assets/DE/DE_Crop_Type_{year}")
        .select(0)
        .eq(crop_dict[CROP])
        .selfMask()
    )


# One presence image per year from 2017 to 2021. The individual names are kept
# so that any cell referring to a specific year keeps working.
(
    crop_type_2017,
    crop_type_2018,
    crop_type_2019,
    crop_type_2020,
    crop_type_2021,
) = (_crop_presence(year) for year in range(2017, 2022))

# Per-pixel maximum over the yearly images: a pixel stays in the mask when the
# crop was present in at least one of the years.
crop_mask = (
    ee.ImageCollection(
        [crop_type_2017, crop_type_2018, crop_type_2019, crop_type_2020, crop_type_2021]
    )
    .reduce(ee.Reducer.max())
    .selfMask()
)

Map.addLayer(crop_mask, {"min": 0, "max": 1, "palette": "yellow"}, "Cropmask")
Map.centerObject(de_nuts3_ee.geometry().bounds(), 6)


def applyBitmask(img):
    """Keep only the pixels whose SummaryQA flag marks them as good quality.

    The two lowest bits of the ``SummaryQA`` band are isolated and every pixel
    where they are not ``00`` is masked. ``system:time_start`` is copied back
    onto the result so the image keeps its timestamp.
    """
    quality_bits = img.select("SummaryQA").bitwiseAnd(3)
    is_good = quality_bits.eq(0)
    masked = img.updateMask(is_good)
    return masked.copyProperties(img, ["system:time_start"])


# MOD13Q1.061 Terra Vegetation Indices 16-Day Global 250m: restrict every image
# to the crop mask, drop pixels failing the SummaryQA check, keep NDVI and EVI.
modis = (
    ee.ImageCollection("MODIS/061/MOD13Q1")
    .select(["NDVI", "EVI", "SummaryQA"])
    .map(
        lambda image: image.updateMask(crop_mask).copyProperties(
            image, ["system:time_start"]
        )
    )
    .map(applyBitmask)
    .select(["NDVI", "EVI"])
)

# Folder receiving the extracted NDVI/EVI series of the selected crop.
output_dir_ndvi = os.path.join(
    WORK_DIR, "data", "interim", "remote_sensing", CROP, "ndvi_evi"
)

if not os.path.exists(output_dir_ndvi):
    os.makedirs(output_dir_ndvi, exist_ok=True)
    print("Directory successfully created!")
else:
    print("Directory already exists!")

# The extractor is handed the NUTS3 polygons in geographic coordinates.
de_nuts3_gdf_reprojected = de_nuts3_gdf.to_crs("EPSG:4326")

# Configure the parallel extraction of the regional mean time series.
ts_extractor = TimeseriesExtractor(
    image_collection=modis,
    sample_gdf=de_nuts3_gdf_reprojected,
    identifier="NUTS_ID",
    selectors=["NDVI", "EVI"],
    reducer="MEAN",
    scale=250,
    start_date="2000-01-01",
    end_date="2025-08-31",
    out_dir=output_dir_ndvi,
    num_processes=20,  # parallel processes
)

# Run extraction (kept disabled; uncomment to download the series again)
# ts_extractor.extract_timeseries()

### Extract fPAR and LAI


In [None]:
def applyBitmask(img):
    """Mask low-quality pixels of a fPAR/LAI image using its FparLai_QC band.

    Returns the ``Fpar_500m`` and ``Lai_500m`` bands of `img`, masked where the
    QC bits fail the checks below, with all properties of `img` copied over.
    """
    qc_band = img.select("FparLai_QC")

    # Bit 0 (MODLAND_QC): both values are accepted, so a constant-true image
    # stands in for this test.
    keep_all = ee.Image(1)

    # Bits 3-4 (CloudState): only 00 (clear) passes.
    is_clear = qc_band.rightShift(3).bitwiseAnd(3).eq(0)

    # Bits 5-7 (SCF_QC): 000 and 001 (best or good quality) pass.
    is_reliable = qc_band.rightShift(5).bitwiseAnd(7).lte(1)

    # A pixel is kept only when every test passes.
    combined = keep_all.And(is_clear).And(is_reliable)

    masked = img.updateMask(combined)
    selected = masked.select(["Fpar_500m", "Lai_500m"])
    return selected.copyProperties(img, img.propertyNames())


# MOD15A2H.061 Terra Leaf Area Index/FPAR 8-Day Global 500m: restrict every
# image to the crop mask, then apply the FparLai_QC quality mask.
modis = (
    ee.ImageCollection("MODIS/061/MOD15A2H")
    .select(["Fpar_500m", "Lai_500m", "FparLai_QC"])
    .map(
        lambda image: image.updateMask(crop_mask).copyProperties(
            image, ["system:time_start"]
        )
    )
    .map(applyBitmask)
)

# Folder receiving the extracted fPAR/LAI series of the selected crop.
output_dir_lai = os.path.join(
    WORK_DIR, "data", "interim", "remote_sensing", CROP, "lai_fpar"
)
if not os.path.exists(output_dir_lai):
    os.makedirs(output_dir_lai, exist_ok=True)
    print("Directory successfully created!")
else:
    print("Directory already exists!")


# Configure the parallel extraction of the regional mean time series.
ts_extractor = TimeseriesExtractor(
    image_collection=modis,
    sample_gdf=de_nuts3_gdf_reprojected,
    identifier="NUTS_ID",
    selectors=["Fpar_500m", "Lai_500m"],
    reducer="MEAN",
    scale=500,
    start_date="2000-01-01",
    end_date="2025-08-31",
    out_dir=output_dir_lai,
    num_processes=20,  # parallel processes
)

# Run extraction (kept disabled; uncomment to download the series again)
# ts_extractor.extract_timeseries()

## Post process and combine all the remote sensing variables


In [None]:
from scipy.signal import savgol_filter
from concurrent.futures import ProcessPoolExecutor, as_completed

# File names (sorted) found in the two extraction folders.
ndvi_file_paths = sorted(os.listdir(output_dir_ndvi))
fpar_file_paths = sorted(os.listdir(output_dir_lai))

print("Number of NDVI files:", len(ndvi_file_paths))
print("Number of FPAR files:", len(fpar_file_paths))

In [None]:
def _smooth_to_daily(file_path, band_specs, window_length=5, polyorder=2):
    """Turn one extracted time series CSV into a smoothed daily table.

    Steps: scale the raw bands, order the rows by date, smooth each band with
    a Savitzky-Golay filter, resample to one row per day with cubic
    interpolation, fill leading/trailing gaps and round to 3 decimals.

    Args:
        file_path: CSV holding a ``time`` column plus the raw band columns.
        band_specs: Mapping of output column name to a
            ``(raw column name, scale factor)`` pair. The output columns keep
            the order of this mapping.
        window_length: Savitzky-Golay window size, in observations.
        polyorder: Savitzky-Golay polynomial order.

    Returns:
        DataFrame with a ``date`` column (daily, from the first to the last
        observation date) followed by one column per entry of `band_specs`.

    Raises:
        ValueError: If the file holds fewer than `window_length` observations.
    """
    raw = pd.read_csv(file_path)

    # Scale the raw values and index them by observation date.
    scaled = pd.DataFrame(
        {out: raw[src] * factor for out, (src, factor) in band_specs.items()}
    ).set_index(pd.to_datetime(raw["time"]).rename("date"))

    # The filter works on neighbouring samples, so the rows must be in
    # temporal order regardless of the order in which they were written.
    scaled = scaled.sort_index()

    if len(scaled) < window_length:
        raise ValueError(
            f"{file_path}: {len(scaled)} observations, at least "
            f"{window_length} are required for smoothing"
        )

    # Apply the Savitzky-Golay filter band by band.
    smoothed = pd.DataFrame(
        {
            name: savgol_filter(
                scaled[name], window_length=window_length, polyorder=polyorder
            )
            for name in scaled.columns
        },
        index=scaled.index,
    )

    # Reindexing to a daily calendar creates NaN rows for the days without an
    # observation; cubic interpolation then fills them with a smooth curve.
    daily_index = pd.date_range(
        start=smoothed.index.min(), end=smoothed.index.max(), freq="D"
    )
    daily = (
        smoothed.reindex(daily_index)
        .interpolate(method="cubic")
        .bfill()  # fill any NaNs left at the very start ...
        .ffill()  # ... or at the very end
        .round(3)
    )

    return daily.rename_axis("date").reset_index()


# Function to post-process ndvi and evi data
def process_ndvi(nuts_id):
    """Return the smoothed daily NDVI/EVI table of one NUTS region.

    Reads ``<output_dir_ndvi>/<nuts_id>.csv`` and multiplies the ``NDVI`` and
    ``EVI`` columns by 0.0001 before smoothing and daily interpolation.

    Returns:
        DataFrame with the columns ``date``, ``ndvi`` and ``evi``.
    """
    return _smooth_to_daily(
        os.path.join(output_dir_ndvi, f"{nuts_id}.csv"),
        {"ndvi": ("NDVI", 0.0001), "evi": ("EVI", 0.0001)},
    )


# Function to post-process fpar and lai data
def process_fpar(nuts_id):
    """Return the smoothed daily fPAR/LAI table of one NUTS region.

    Reads ``<output_dir_lai>/<nuts_id>.csv`` and multiplies ``Fpar_500m`` by
    0.01 and ``Lai_500m`` by 0.1 before smoothing and daily interpolation.

    Returns:
        DataFrame with the columns ``date``, ``fpar`` and ``lai``.
    """
    return _smooth_to_daily(
        os.path.join(output_dir_lai, f"{nuts_id}.csv"),
        {"fpar": ("Fpar_500m", 0.01), "lai": ("Lai_500m", 0.1)},
    )


# Function to compile remote sensing data
def process_remote_sensing(nuts_id, out_dir=None):
    """Combine the daily NDVI/EVI and fPAR/LAI tables of one NUTS region.

    The two tables are joined on ``date``, keeping only the dates present in
    both. When `out_dir` is given, the result is written there as
    ``<nuts_id>.csv``.

    Returns:
        The merged DataFrame, or ``None`` (nothing is written) when the merged
        table contains any missing value.
    """
    combined = process_ndvi(nuts_id).merge(
        process_fpar(nuts_id), on="date", how="inner"
    )

    # Regions with gaps are dropped entirely rather than saved incomplete.
    if combined.isna().values.any():
        return None

    if out_dir:
        target_path = os.path.join(out_dir, f"{nuts_id}.csv")
        combined.to_csv(target_path, index=False)

    return combined


# Function to compile remote sensing data in parallel
def run_parallel(nuts_ids, out_dir, max_workers=8):
    """Run `process_remote_sensing` for many NUTS regions in worker processes.

    One line is printed per region as soon as its task ends: finished, skipped
    (the task returned ``None``, so no file was written) or failed. A failing
    region never stops the remaining ones.

    Args:
        nuts_ids: Iterable of NUTS identifiers to process.
        out_dir: Directory handed to `process_remote_sensing` for the output.
        max_workers: Number of worker processes.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(process_remote_sensing, nuts_id, out_dir): nuts_id
            for nuts_id in nuts_ids
        }

        for future in as_completed(futures):
            nuts_id = futures[future]
            try:
                result = future.result()
            except Exception as e:
                # Best-effort batch: report the failure and carry on.
                print(f"⚠️ Error processing {nuts_id}: {e}")
            else:
                if result is None:
                    # The worker found missing values and saved nothing.
                    print(f"⚠️ Skipped {nuts_id}: missing values, no file written")
                else:
                    print(f"✅ Finished {nuts_id}")

In [None]:
# Folder receiving the combined per-region tables of the selected crop.
out_dir = os.path.join(WORK_DIR, "data", "interim", "remote_sensing", CROP, "combined")
if not os.path.exists(out_dir):
    os.makedirs(out_dir, exist_ok=True)
    print("Directory successfully created!")
else:
    print("Directory already exists!")

# Post-process every NUTS3 region in parallel.
all_nuts_ids = de_nuts3_gdf["NUTS_ID"].unique()
run_parallel(nuts_ids=all_nuts_ids, out_dir=out_dir, max_workers=70)

In [None]:
import shutil


# Replace the error NUTS with the nearest one
def find_nearest_nuts(nuts_id, nuts_gdf, available_ids=None):
    """Return the ID of the closest region (by centroid) that has a data file.

    Args:
        nuts_id: Identifier of the region lacking a combined file.
        nuts_gdf: GeoDataFrame with a ``NUTS_ID`` column and the geometries.
        available_ids: Optional collection of IDs allowed as donors. Defaults
            to the IDs of the ``.csv`` files currently present in `out_dir`.

    Returns:
        The ``NUTS_ID`` of the nearest donor, or ``None`` when `nuts_id` is not
        in `nuts_gdf` or no other donor region exists.
    """
    if available_ids is None:
        available_ids = {
            os.path.splitext(name)[0]
            for name in os.listdir(out_dir)
            if name.endswith(".csv")
        }

    # Distances are measured between centroids in a projected CRS.
    # NOTE(review): EPSG:3857 stretches distances with latitude; an equal-area
    # or UTM CRS may rank neighbours more faithfully - confirm before changing.
    projected = nuts_gdf.to_crs("EPSG:3857")
    centroids = projected.geometry.centroid

    is_target = projected["NUTS_ID"] == nuts_id
    if not is_target.any():
        return None
    target_centroid = centroids[is_target].iloc[0]

    is_candidate = projected["NUTS_ID"].isin(available_ids) & ~is_target
    if not is_candidate.any():
        return None

    candidates = projected.loc[is_candidate, ["NUTS_ID"]].assign(
        dist=centroids[is_candidate].distance(target_centroid).to_numpy()
    )
    return candidates.sort_values("dist").iloc[0]["NUTS_ID"]


# Regions that produced a combined file of their own.
existed_nuts = [
    os.path.splitext(name)[0]
    for name in os.listdir(out_dir)
    if name.endswith(".csv")
]
# Frozen before any copy is made: a fallback copy written below must never be
# picked as the donor of another missing region.
donor_ids = frozenset(existed_nuts)

for nuts_id in de_nuts3_gdf["NUTS_ID"].unique():
    if nuts_id in donor_ids:
        continue

    nearest = find_nearest_nuts(nuts_id, de_nuts3_gdf, available_ids=donor_ids)
    if nearest is None:
        print(f"❌ No nearest NUTS found for {nuts_id}; skipping.")
        continue

    src = os.path.join(out_dir, f"{nearest}.csv")
    dst = os.path.join(out_dir, f"{nuts_id}.csv")
    if not os.path.exists(src):
        print(
            f"❌ Nearest file {nearest}.csv not available to copy for {nuts_id}; skipping."
        )
        continue

    try:
        shutil.copy(src, dst)
    except OSError as copy_err:
        print(f"❌ Failed to copy fallback file for {nuts_id}: {copy_err}")
    else:
        print(f"🔁 Replaced {nuts_id} with nearest NUTS file {nearest}.csv")

print("Number of final files:", len(os.listdir(out_dir)))