## Import libraries


In [None]:
import os
from glob import glob

import ee
import geemap
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylandstats as pls
import rasterio as rio
import seaborn as sns
from rasterio.warp import Resampling, calculate_default_transform, reproject
from rasterstats import zonal_stats
from shapely.geometry import Point
from tqdm.auto import tqdm

# Global matplotlib font configuration for the figures drawn in this notebook.
# NOTE(review): matplotlib consults `font.serif` only when `font.family`
# resolves to the generic "serif" family; with a concrete family name set
# here, the "Times New Roman" entry presumably has no effect — confirm which
# font is actually intended.
plt.rcParams["font.family"] = "DeJavu Serif"
plt.rcParams["font.serif"] = "Times New Roman"

import warnings
# NOTE(review): `glob` is already imported at the top of this cell; this
# second import is redundant.
from glob import glob

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Project root. The working directory is switched to it, so relative paths
# used afterwards resolve against this folder.
WORK_DIR = "/beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/"
os.chdir(WORK_DIR)
# Root folder of the input datasets read below (NUTS shapefile, soil raster,
# irrigation rasters).
MAIN_DATA_DIR = "/beegfs/halder/DATA/"
# Output folder path under the project root (only the path is built here; the
# directory is not created).
OUT_DIR = os.path.join(WORK_DIR, "output")

## Read the dataset


In [None]:
# Read the NUTS boundaries for DE once and split them by level.
# The single "DE_NUTS_3.shp" file carries a LEVL_CODE column, so the NUTS1 and
# NUTS3 layers are both filtered from the same read instead of loading the
# shapefile from disk twice.
de_nuts_all_gdf = gpd.read_file(os.path.join(MAIN_DATA_DIR, "DE_NUTS", "DE_NUTS_3.shp"))

# NUTS1 layer (LEVL_CODE == 1) with state-specific column names.
# `rename` and `to_crs` are used in their non-mutating form: each returns a new
# GeoDataFrame, so no in-place modification happens on a filtered slice of the
# source frame.
de_nuts1_gdf = (
    de_nuts_all_gdf[de_nuts_all_gdf["LEVL_CODE"] == 1]
    .rename(columns={"NUTS_ID": "STATE_ID", "NUTS_NAME": "STATE_NAME"})
    .to_crs(crs="EPSG:25832")
)

# NUTS3 layer (LEVL_CODE == 3); this is the zone layer used by every
# extraction step further down. Both layers share EPSG:25832.
de_nuts3_gdf = de_nuts_all_gdf[de_nuts_all_gdf["LEVL_CODE"] == 3].to_crs(
    crs="EPSG:25832"
)

# Overview map: NUTS3 polygons coloured by name, NUTS1 borders drawn on top.
fig, ax = plt.subplots(figsize=(8, 8))
de_nuts3_gdf.plot(
    ax=ax,
    column="NUTS_NAME",
    cmap="Set3",
    edgecolor="grey",
    linewidth=0.5,
    label="NUTS3",
)
de_nuts1_gdf.plot(ax=ax, facecolor="none", edgecolor="k", linewidth=1, label="NUTS1")
plt.show()

print(de_nuts1_gdf.shape, de_nuts3_gdf.shape)
# Bare expression so the notebook renders a preview of the NUTS3 table.
de_nuts3_gdf.head()

## Extract soil quality rating using rasterstats


In [None]:
# Path to soil quality raster (250 m resolution)
SOIL_DATA_PATH = os.path.join(
    MAIN_DATA_DIR, "DE_Soil_Quality_Rating_250m", "sqr1000_250_v10.tif"
)
# Reprojected copy of the raster, stored with the other interim products.
reprojected_raster_path = os.path.join(
    WORK_DIR, "data", "interim", "sqr1000_250_v10.tif"
)

# Target CRS (Coordinate Reference System). This is the CRS the NUTS layers
# were reprojected to, so raster and polygons line up for masking and zonal
# statistics.
dst_crs = "EPSG:25832"

# The CSV doubles as a cache marker: when it exists the whole step is skipped.
output_path = os.path.join(WORK_DIR, "data", "interim", "static_soil_stats.csv")

if os.path.exists(output_path):
    print("Data has been already extacted.")

else:
    # `rasterio.mask` is a submodule that a plain `import rasterio` is not
    # guaranteed to load. Bind the function explicitly so a missing attribute
    # cannot be swallowed by the per-region error handler below.
    from rasterio.mask import mask as mask_raster

    # Both the reprojected raster and the CSV live in the interim folder.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Open source raster and derive the output grid in the target CRS
    with rio.open(SOIL_DATA_PATH) as src:
        transform, width, height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds
        )

        kwargs = src.meta.copy()
        kwargs.update(
            {"crs": dst_crs, "transform": transform, "width": width, "height": height}
        )

        # Write reprojected raster band by band (nearest neighbour keeps the
        # original cell values instead of interpolating new ones)
        with rio.open(reprojected_raster_path, "w", **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rio.band(src, i),
                    destination=rio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )

    print("Raster Saved Successfully!")

    # Collect the NUTS3 regions whose clipped raster holds more than one
    # distinct value, i.e. something besides background / no-data.
    valid_nuts = []

    with rio.open(reprojected_raster_path) as src:
        for _, row in tqdm(de_nuts3_gdf.iterrows(), total=len(de_nuts3_gdf)):
            try:
                clipped, _ = mask_raster(src, [row["geometry"]], crop=True)
            except Exception as e:
                # Deliberate best-effort: a region that cannot be clipped
                # (e.g. it does not overlap the raster) is reported and skipped.
                print(f"NUTS ID  {row['NUTS_ID']} caused error: {e}")
                continue

            unique_vals = np.unique(clipped)
            if len(unique_vals) <= 1:  # Only background or no-data
                print(f"NUTS ID {row['NUTS_ID']} skipped: {unique_vals}")
                continue
            valid_nuts.append(row["NUTS_ID"])

    # Fail with a clear message instead of an IndexError further down when
    # every region was skipped.
    if not valid_nuts:
        raise RuntimeError(
            "No NUTS3 region contains valid soil raster data; "
            f"check {reprojected_raster_path} and the CRS of the inputs."
        )

    # Filter the GeoDataFrame
    de_nuts3_filtered = de_nuts3_gdf[de_nuts3_gdf["NUTS_ID"].isin(valid_nuts)].copy()

    # Calculate zonal stats (one dict per region, in the same row order)
    stats = zonal_stats(
        de_nuts3_filtered, reprojected_raster_path, stats=["mean", "std"]
    )

    # Build the result table: region id plus the two statistics
    soil_stats = de_nuts3_filtered[["NUTS_ID"]].copy()
    soil_stats["soil_quality_mean"] = [s["mean"] for s in stats]
    soil_stats["soil_quality_stdDev"] = [s["std"] for s in stats]

    # Save the data
    soil_stats.to_csv(output_path, index=False)
    print(f"Data saved at {output_path}")
    print(soil_stats.shape)

## Extract topographical data from GEE


In [None]:
# Create the interactive geemap map that the following cells add layers to.
# NOTE(review): no explicit Earth Engine initialization is visible in this
# file — presumably geemap.Map() takes care of it; confirm.
Map = geemap.Map()
# Bare expression so the notebook renders the map widget as the cell output.
Map

In [None]:
# Load the NUTS3 regions for DE from the Earth Engine asset.
de_nuts3_ee = ee.FeatureCollection("projects/ee-geonextgis/assets/DE/DE_NUTS_3")

# Outline-only styling for the region layer (fill colour with zero alpha,
# brown border of width 1).
roi_style = dict(fillColor="00000000", color="brown", width=1)

# Add the styled outlines as the "Germany" layer, then centre the map on the
# collection's geometry at zoom level 6.
Map.addLayer(de_nuts3_ee.style(**roi_style), {}, "Germany")
Map.centerObject(de_nuts3_ee.geometry(), 6)

In [None]:
# Elevation: mosaic of the GLO-30 images that intersect the NUTS3 collection.
# NOTE(review): the slope below is derived from this mosaic without setting a
# default projection, so it is presumably evaluated at the scale requested in
# reduceRegions (100) rather than on the source grid — confirm this is intended.
glo30_collection = ee.ImageCollection("projects/sat-io/open-datasets/GLO-30")
elev = glo30_collection.filterBounds(de_nuts3_ee).mosaic()

# Slope derived from the elevation mosaic.
slope = ee.Terrain.slope(elev)

# Two-band image ("elevation", "slope") so one reduction covers both variables.
elevation_band = elev.rename("elevation")
slope_band = slope.rename("slope")
topo_image = elevation_band.addBands(slope_band)

# Show the elevation band on the map with a 1-3000 colour stretch.
elevationPalette = ["006600", "002200", "fff700", "ab7634", "c4d0ff", "ffffff"]
visParams = dict(min=1, max=3000, palette=elevationPalette)
Map.addLayer(topo_image.select("elevation"), visParams, "Elevation")

# The CSV doubles as a cache marker: when it exists the extraction is skipped.
output_path = os.path.join(WORK_DIR, "data", "interim", "static_topo_stats.csv")

if not os.path.exists(output_path):
    # Mean and standard deviation computed together on the same inputs; the
    # result columns are named "<band>_mean" / "<band>_stdDev" (see the column
    # selection below).
    mean_and_std = ee.Reducer.mean().combine(
        reducer2=ee.Reducer.stdDev(), sharedInputs=True
    )
    topo_fc = topo_image.reduceRegions(
        collection=de_nuts3_ee,
        reducer=mean_and_std,
        scale=100,
    )

    # Pull the reduced features into pandas and keep only the id + statistics.
    wanted_columns = [
        "NUTS_ID",
        "elevation_mean",
        "elevation_stdDev",
        "slope_mean",
        "slope_stdDev",
    ]
    topo_stats = geemap.ee_to_df(topo_fc)
    topo_stats = topo_stats[wanted_columns]

    # Save the data
    topo_stats.to_csv(output_path, index=False)
    print(f"Data saved at {output_path}")
    print(topo_stats.shape)

else:
    print("Data has been already extacted.")

## Extract irrigated area data


In [None]:
# Crops for which a static irrigation table is produced.
crops = [
    "winter_wheat",
    "winter_barley",
    "winter_rapeseed",
    "winter_rye",
    "silage_maize",
]

# Maps each crop to the irrigation raster class used as the file-name prefix
# of the yearly rasters ("<class>_IR_A_<year>.tif").
crop_irri_dict = {
    "winter_wheat": "CERE",
    "winter_barley": "CERE",
    "winter_rapeseed": "LRAPE",
    "winter_rye": "CERE",
    "silage_maize": "LMAIZ",
}

# Values that do not change between crops or years.
irri_folder_path = os.path.join(MAIN_DATA_DIR, "Crop_IR")  # yearly irrigation rasters
interim_dir = os.path.join(WORK_DIR, "data", "interim")
dst_crs = "EPSG:25832"  # same CRS as the NUTS3 zone layer

for crop in crops:
    print("*" * 25, crop, "*" * 25)
    out_path = os.path.join(
        interim_dir, "irrigation", f"static_irrigated_area_{crop}.csv"
    )

    # The CSV doubles as a cache marker: skip crops that were already processed.
    if os.path.exists(out_path):
        continue

    # Create the output folder up front so the final to_csv cannot fail after
    # all rasters have been processed.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    irri_class = crop_irri_dict[crop]

    # --- 1. Reproject every yearly raster (2010-2020) to the target CRS ------
    all_rasters = []
    for year in tqdm(range(2010, 2021)):
        irri_file_path = os.path.join(
            irri_folder_path, str(year), f"{irri_class}_IR_A_{year}.tif"
        )
        reprojected_raster_path = os.path.join(
            interim_dir, f"{irri_class}_IR_A_{year}_EPSG:25832.tif"
        )

        with rio.open(irri_file_path) as src:
            transform, width, height = calculate_default_transform(
                src.crs, dst_crs, src.width, src.height, *src.bounds
            )
            kwargs = src.meta.copy()
            kwargs.update(
                {
                    "crs": dst_crs,
                    "transform": transform,
                    "width": width,
                    "height": height,
                }
            )

            # Nearest neighbour keeps the original cell values instead of
            # interpolating new ones.
            with rio.open(reprojected_raster_path, "w", **kwargs) as dst:
                for i in range(1, src.count + 1):
                    reproject(
                        source=rio.band(src, i),
                        destination=rio.band(dst, i),
                        src_transform=src.transform,
                        src_crs=src.crs,
                        dst_transform=transform,
                        dst_crs=dst_crs,
                        resampling=Resampling.nearest,
                    )
        all_rasters.append(reprojected_raster_path)

    print("\nAll yearly rasters reprojected successfully!")

    # --- 2. Per-pixel mean over the years, ignoring masked (nodata) cells ----
    arrays = []
    for fp in all_rasters:
        with rio.open(fp) as src:
            arrays.append(src.read(1, masked=True))

    mean_arr = np.ma.stack(arrays).mean(axis=0)

    # Fill masked cells with NaN explicitly so the written file matches the
    # declared nodata value and does not depend on how the installed rasterio
    # version serializes masked arrays.
    mean_data = np.ma.filled(mean_arr.astype(np.float32), np.nan)

    # --- 3. Save the mean raster on the grid of the first reprojected year ---
    mean_raster_path = os.path.join(
        interim_dir, f"{crop}_IRRIGATION_2010_2020_mean.tif"
    )
    with rio.open(all_rasters[0]) as src_ref:
        profile = src_ref.profile
    profile.update(dtype=rio.float32, count=1, nodata=np.nan)

    with rio.open(mean_raster_path, "w", **profile) as dst:
        dst.write(mean_data, 1)

    print(f"Mean raster saved at:\n{mean_raster_path}")

    # --- 4. Zonal statistics: summed raster value per NUTS3 region -----------
    # Cells equal to 0 are treated as nodata for the sum.
    stats = zonal_stats(de_nuts3_gdf, mean_raster_path, stats=["sum"], nodata=0)

    zones_stats = de_nuts3_gdf.copy()
    zones_stats["total_irrigated_area"] = [s["sum"] for s in stats]
    zones_stats["area_ha"] = zones_stats.geometry.area / 10000  # 1 ha = 10,000 m²

    # Calculate fraction of irrigated area.
    # NOTE(review): this assumes the raster values are irrigated hectares per
    # cell — confirm against the dataset documentation.
    zones_stats["irrigated_fraction"] = (
        zones_stats["total_irrigated_area"] / zones_stats["area_ha"]
    )

    # Regions without any valid cell get a missing sum; report them as zero.
    filled_cols = ["total_irrigated_area", "irrigated_fraction"]
    zones_stats[filled_cols] = zones_stats[filled_cols].fillna(0)

    # Save the data
    zones_stats[
        ["NUTS_ID", "total_irrigated_area", "area_ha", "irrigated_fraction"]
    ].to_csv(out_path, index=False)

    print("\n✅ Irrigation area merged and zonal stats completed successfully!")

## Finalize the static data


In [None]:
# Build one static-feature table per crop by joining the soil, topography and
# irrigation statistics onto the NUTS3 regions.
for crop in crops:
    print("*" * 25, crop, "*" * 25)
    out_path = os.path.join(WORK_DIR, "data", "processed", crop, f"{crop}_static.csv")

    # Existing outputs are kept as they are.
    if os.path.exists(out_path):
        print("File already existed!")
        continue

    # Region ids plus geometry (the geometry is only needed for the
    # nearest-neighbour gap filling below and is dropped before saving).
    region = de_nuts3_gdf[["NUTS_ID", "geometry"]].copy()

    soil_stat = pd.read_csv(
        os.path.join(WORK_DIR, "data", "interim", "static_soil_stats.csv")
    )
    topo_stat = pd.read_csv(
        os.path.join(WORK_DIR, "data", "interim", "static_topo_stats.csv")
    )
    irrigation_stat = pd.read_csv(
        os.path.join(
            WORK_DIR,
            "data",
            "interim",
            "irrigation",
            f"static_irrigated_area_{crop}.csv",
        )
    )

    # Merge all the variables; left joins keep every NUTS3 region even when a
    # source table has no row for it.
    static_data = pd.merge(left=region, right=soil_stat, on="NUTS_ID", how="left")
    static_data = pd.merge(
        left=static_data, right=topo_stat, on="NUTS_ID", how="left"
    )
    static_data = pd.merge(
        left=static_data,
        right=irrigation_stat[["NUTS_ID", "irrigated_fraction"]],
        on="NUTS_ID",
        how="left",
    )

    # Fill NaN values with the value of the region whose centroid is nearest
    static_data["centroid"] = static_data.geometry.centroid

    cols_to_fill = [
        col
        for col in static_data.columns
        if col not in ["NUTS_ID", "geometry", "centroid"]
    ]

    for col in cols_to_fill:
        missing_idx = static_data[static_data[col].isna()].index
        for idx in missing_idx:
            this_centroid = static_data.at[idx, "centroid"]
            # Recomputed per row on purpose: a value filled in an earlier
            # iteration counts as a valid donor for the next one.
            valid = static_data[static_data[col].notna()]
            if valid.empty:
                continue
            distances = valid["centroid"].distance(this_centroid)
            nearest_idx = distances.idxmin()
            static_data.at[idx, col] = static_data.at[nearest_idx, col]

        static_data[col] = static_data[col].round(3)

    static_data = static_data.drop(columns="centroid")
    static_data = static_data.sort_values(by="NUTS_ID")

    # Save the data to the per-crop path. This must be `out_path`: the
    # similarly named `output_path` still points at the topography CSV from an
    # earlier cell, and writing there would overwrite that file.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    static_data.drop(columns=["geometry"]).to_csv(out_path, index=False)
    print(f"Data saved at {out_path}")
    print(static_data.shape)