In [None]:
import pandas as pd
import numpy as np
import pycountry_convert as pc
import kgcpy
from sklearn.metrics.pairwise import haversine_distances
from sklearn.cluster import AgglomerativeClustering
import geopandas as gpd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from datetime import datetime
from pyproj import Transformer
from scipy.spatial import cKDTree
import rioxarray
import seaborn as sns

In [None]:
# Never truncate columns when a DataFrame is rendered in the notebook.
pd.options.display.max_columns = None

# Read the comma-separated input table and preview its first two rows.
df = pd.read_csv("data/1_max_precipitation_grid_cells.csv", sep=",")
df.head(2)

In [None]:
# 1. Preparation
# Mapping of numeric label IDs (1-8, in the order listed) to season names
# (see https://zenodo.org/records/13165034).
SEASON_LABELS = dict(
    enumerate(
        ("Winter", "Spring", "Summer", "Autumn", "Hotter", "Cooler", "Dry", "Wet"),
        start=1,
    )
)

# Input locations: the season raster (opened below only to read its CRS)
# and the accompanying table that is loaded into df_pheno.
season_tif = "data/WorldSeasons/season.tif"
meta_txt = "data/WorldSeasons/seasons_data.txt"

# 2. Load and prepare WorldSeasons data
df_pheno = pd.read_csv(meta_txt, quotechar='"', skipinitialspace=True)
print(df_pheno.columns)

# Build a transformer from geographic lon/lat (EPSG:4326) into the raster's
# CRS. always_xy=True fixes the argument order to (x, y) = (lon, lat).
season_raster = rioxarray.open_rasterio(season_tif, mask_and_scale=True)
raster_wkt = season_raster.rio.crs
transformer = Transformer.from_crs(
    crs_from="EPSG:4326",
    crs_to=raster_wkt,
    always_xy=True,
)

# Spatial index over the table's x/y columns for nearest-neighbour queries.
coords = df_pheno[['x', 'y']].to_numpy()
tree = cKDTree(coords)

# 3. Define function to retrieve the phenological season from coordinates and date
def get_pheno_info(row, transformer, tree, df_pheno):
    """
    Look up the phenological season label ID and season name for one row.

    The row's lon/lat is projected with ``transformer``, the nearest entry of
    ``df_pheno`` is located through ``tree``, and the label stored in that
    entry's column for the month of ``end_date`` is mapped to a name via the
    module-level ``SEASON_LABELS``.

    Parameters
    ----------
    row : mapping-like (e.g. pandas.Series)
        Must provide 'longitude', 'latitude' and 'end_date'. 'end_date' may be
        a 'YYYY-MM-DD' string or an already parsed date/datetime/Timestamp.
    transformer : object
        Must offer ``transform(lon, lat) -> (x, y)``.
    tree : object
        Must offer ``query([x, y]) -> (distance, index)``; ``index`` is used
        as a positional index into ``df_pheno``.
    df_pheno : pandas.DataFrame
        Must contain a column named after the English month name of
        ``end_date`` (e.g. "July") holding the numeric label.

    Returns
    -------
    pandas.Series
        ``[label_id, season_name]``. ``season_name`` is "Unknown" when the ID
        is not a key of ``SEASON_LABELS``.

    Raises
    ------
    ValueError
        If a string ``end_date`` does not match '%Y-%m-%d', or if the looked-up
        label cannot be converted to ``int`` (e.g. it is NaN).
    """
    lon = row['longitude']
    lat = row['latitude']
    end_date = row['end_date']

    # A. Project lon/lat into the transformer's target coordinates
    x_m, y_m = transformer.transform(lon, lat)

    # B. Search for the closest entry (the distance itself is not needed)
    _, idx = tree.query([x_m, y_m])
    pheno_row = df_pheno.iloc[idx]

    # C. Get the month name from the date.
    # Strings keep the strict '%Y-%m-%d' parsing; parsed dates are accepted
    # as-is. Timestamp.month_name() returns the English name by default, so
    # the column lookup does not depend on the process locale (strftime("%B")
    # would).
    if isinstance(end_date, str):
        end_date = datetime.strptime(end_date, '%Y-%m-%d')
    month_name = pd.Timestamp(end_date).month_name()

    # D. Get the label ID stored for that month
    label_id = int(pheno_row[month_name])

    # E. Apply the ID -> name mapping
    season_name = SEASON_LABELS.get(label_id, "Unknown")

    return pd.Series([label_id, season_name])

In [None]:
def get_continent_from_iso(iso_code):
    """
    Return the continent name for an ISO alpha-3 country code.

    The code is converted alpha-3 -> alpha-2 -> continent code -> continent
    name using ``pycountry_convert`` (imported as ``pc``).

    Parameters
    ----------
    iso_code : str
        ISO alpha-3 country code (e.g. "DEU").

    Returns
    -------
    str
        The continent name, or "Unknown" if any conversion step fails.
    """
    try:
        iso_alpha2 = pc.country_alpha3_to_country_alpha2(iso_code)
        continent_code = pc.country_alpha2_to_continent_code(iso_alpha2)
        return pc.convert_continent_code_to_continent_name(continent_code)
    except Exception:
        # Best-effort lookup: any conversion failure maps to "Unknown".
        # `Exception` (not a bare `except`) lets KeyboardInterrupt/SystemExit
        # propagate so a long-running apply can still be interrupted.
        return "Unknown"

def get_country_from_iso(iso_code):
    """
    Return the country name for an ISO alpha-3 country code.

    The code is converted alpha-3 -> alpha-2 -> country name using
    ``pycountry_convert`` (imported as ``pc``).

    Parameters
    ----------
    iso_code : str
        ISO alpha-3 country code (e.g. "DEU").

    Returns
    -------
    str
        The country name, or "Unknown" if any conversion step fails.
    """
    try:
        iso_alpha2 = pc.country_alpha3_to_country_alpha2(iso_code)
        return pc.country_alpha2_to_country_name(iso_alpha2)
    except Exception:
        # Best-effort lookup: any conversion failure maps to "Unknown".
        # `Exception` (not a bare `except`) lets KeyboardInterrupt/SystemExit
        # propagate so a long-running apply can still be interrupted.
        return "Unknown"

def _lookup_koppen_geiger(latitude, longitude):
    """Return the kgcpy climate zone for a coordinate, or "Unknown" if the lookup fails."""
    try:
        return kgcpy.lookupCZ(latitude, longitude)
    except Exception:
        # Best-effort: a failed lookup must not abort the whole enrichment.
        return "Unknown"


def add_metadata_to_cubes(df):
    """
    Return a copy of ``df`` enriched with geographic and climatic metadata.

    Added columns:
      * 'continent', 'country'                 - derived from 'iso'
      * 'pheno_label_id', 'pheno_season_name'  - from ``get_pheno_info`` using
        the module-level ``transformer``, ``tree`` and ``df_pheno``
      * 'koppen_geiger'                        - ``kgcpy.lookupCZ(lat, lon)``,
        "Unknown" when the lookup raises
      * 'climate_class'                        - first character of
        'koppen_geiger' (so failed lookups yield "U")

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'DisNo.', 'iso', 'latitude', 'longitude', 'start_date'
        and 'end_date'. The input frame is not modified.

    Returns
    -------
    pandas.DataFrame
        The enriched frame with the identifying/metadata columns first and
        all remaining columns after them in their original order.
    """
    print("üöÄ Start Metadata extrakcion...")

    # Work on a copy so the caller's frame is left untouched.
    df = df.copy()

    # 1. Add continent, country and phenological season
    df['continent'] = df['iso'].apply(get_continent_from_iso)
    df['country'] = df['iso'].apply(get_country_from_iso)
    df[['pheno_label_id', 'pheno_season_name']] = df.apply(
        lambda row: get_pheno_info(row, transformer, tree, df_pheno),
        axis=1
    )

    # 2. Determine the Koeppen-Geiger climate zone for every row from lat/lon
    print("üåç Bestimme Klimazonen (K√∂ppen-Geiger)...")
    df['koppen_geiger'] = [
        _lookup_koppen_geiger(lat, lon)
        for lat, lon in zip(df['latitude'], df['longitude'])
    ]

    # 3. Derive a coarser climate class (first letter of the Koeppen-Geiger zone)
    # (A=Tropical, B=Arid, C=Warm-temperate, D=Cold, E=Polar)
    df['climate_class'] = df['koppen_geiger'].str[0]

    # Put the identifying and metadata columns first, keep the rest in order.
    start_cols = ['DisNo.', 'iso', 'country', 'continent', 'climate_class', 'koppen_geiger', 'pheno_label_id', 'pheno_season_name',  'latitude', 'longitude', 'start_date', 'end_date']

    return df[start_cols + [c for c in df.columns if c not in start_cols]]

# Enrich the loaded frame with the metadata columns, then show 10 random rows.
df = df.pipe(add_metadata_to_cubes)
df.sample(n=10)

In [None]:
# Train-Test-Split
# Shuffle the row labels and hold out 20 % of them as test set (fixed seed).
train_idx, test_idx = train_test_split(
    df.index,
    test_size=0.20,
    random_state=42,
    shuffle=True,
)

# Copy of df with a 'split' column: every row starts out as "train" ...
df_final = df.assign(split="train")

# ... and the held-out labels are then flipped to "test".
df_final.loc[test_idx, 'split'] = 'test'

# Check the resulting group sizes
split_counts = df_final['split'].value_counts()
print(split_counts)

In [None]:
# Select the rows of each split by their label.
ds_train = df_final.loc[df_final["split"].eq("train")]
ds_test = df_final.loc[df_final["split"].eq("test")]

In [None]:
# Country polygons used as basemap; read once and reused across calls.
_WORLD_MAP_URL = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
_WORLD_MAP_CACHE = {}


def _load_world_map(url=_WORLD_MAP_URL):
    """Return the basemap read from ``url``, fetching it only on first use."""
    if url not in _WORLD_MAP_CACHE:
        _WORLD_MAP_CACHE[url] = gpd.read_file(url)
    return _WORLD_MAP_CACHE[url]


def _category_colors(categories):
    """Map every category to its own RGBA colour."""
    n_categories = len(categories)
    if n_categories <= 20:
        # tab20 holds exactly 20 colours; evenly spaced samples stay distinct.
        colors = plt.cm.tab20(np.linspace(0, 1, n_categories))
    else:
        # With more than 20 categories tab20 would hand the same colour to
        # several categories, so sample a continuous colormap instead.
        # endpoint=False because hsv is cyclic (0 and 1 are both red).
        colors = plt.cm.hsv(np.linspace(0, 1, n_categories, endpoint=False))
    return dict(zip(categories, colors))


def plot_train_test_split(ds_train, ds_test, column_to_plot):
    """
    Plot the train and test sets side by side on a world map.

    Points are coloured by ``column_to_plot``; both panels share one
    category -> colour mapping so the same value has the same colour in each.

    Parameters
    ----------
    ds_train, ds_test : pandas.DataFrame
        Must contain 'longitude', 'latitude' and ``column_to_plot``.
    column_to_plot : str
        Name of the categorical column used for colouring and the legend.
        Missing values are not plotted.

    Returns
    -------
    None
        The figure is shown via ``plt.show()``.
    """
    # 1. Load world map (cached after the first call)
    world = _load_world_map()

    # 2. Create a consistent colour map for all unique values in the column
    all_values = pd.concat([ds_train[column_to_plot], ds_test[column_to_plot]]).unique()
    unique_values = sorted(v for v in all_values if pd.notna(v))
    color_dict = _category_colors(unique_values)

    # 3. Convert to GeoDataFrames (lon/lat points in EPSG:4326)
    gdf_train = gpd.GeoDataFrame(
        ds_train, geometry=gpd.points_from_xy(ds_train.longitude, ds_train.latitude), crs="EPSG:4326")
    gdf_test = gpd.GeoDataFrame(
        ds_test, geometry=gpd.points_from_xy(ds_test.longitude, ds_test.latitude), crs="EPSG:4326")

    # 4. Setup plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(25, 10), sharex=True, sharey=True)
    bg_settings = {'color': '#eeeeee', 'edgecolor': '#bcbcbc'}
    marker_size = 50

    # Plot one category at a time so each gets its own legend entry and colour
    def plot_subset(ax, gdf, title):
        world.plot(ax=ax, **bg_settings)
        for val, color in color_dict.items():
            mask = gdf[column_to_plot] == val
            if mask.any():
                gdf[mask].plot(ax=ax, color=color, label=val, markersize=marker_size)
        ax.legend(title=column_to_plot, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_title(f"{title} (n={len(gdf)})", fontsize=16)
        ax.set_aspect("equal")
        ax.grid(alpha=0.2)

    # Execute plotting
    plot_subset(ax1, gdf_train, f"Training Set: {column_to_plot}")
    plot_subset(ax2, gdf_test, f"Test Set: {column_to_plot}")

    fig.supxlabel('Longitude')
    fig.supylabel('Latitude')
    plt.tight_layout()
    plt.show()

# --- Example Usage ---
# One train/test map per categorical metadata column.
for column_name in ('koppen_geiger', 'climate_class', 'pheno_season_name'):
    plot_train_test_split(ds_train, ds_test, column_name)


In [None]:
# Columns plotted one by one in the land-cover histograms further down.
lc_vars = ['BARE', 'BUILT', 'GRASS-MAN', 'GRASS-NAT', 'SHRUBS', 'TREES', 'WATER']

# Make sure end_date is a datetime column
end_dates = pd.to_datetime(df_final['end_date'])
df_final['end_date'] = end_dates

# Extract year and month from the parsed dates
df_final['year'] = end_dates.dt.year
df_final['month'] = end_dates.dt.month

# Use one clean seaborn theme for all figures below
sns.set_theme(style="whitegrid")

# 1. Bar chart of climate classes per split
fig, ax = plt.subplots(figsize=(10, 6))
# Sort the classes (A, B, C, D, E) for readability. dropna() keeps sorted()
# from failing on a mix of str and float if a class is missing.
order = sorted(df_final['climate_class'].dropna().unique())
sns.countplot(data=df_final, x='climate_class', hue='split', order=order, palette='viridis', ax=ax)
ax.set_title('Distribution of climate classes over train and test')
ax.set_ylabel('Number of Cubes')
ax.legend(title='Split Set')

# 2. Bar chart of continents per split
fig, ax = plt.subplots(figsize=(12, 6))
# Order the bars by frequency in the full data set
continent_order = df_final['continent'].value_counts().index
sns.countplot(data=df_final, x='continent', hue='split', order=continent_order, palette='magma', ax=ax)
ax.set_title('Distribution of continents over train and test')
ax.set_ylabel('Number of Cubes')
ax.tick_params(axis='x', rotation=45)
ax.legend(title='Split Set')

# 3. Histogram of tp_rollingmax (numeric)
fig, ax = plt.subplots(figsize=(10, 6))
# Histogram with KDE (density estimate) to compare the two distributions;
# common_norm=False normalises each split on its own.
sns.histplot(data=df_final, x='tp_rollingmax', hue='split', kde=True, element="step", common_norm=False, palette='rocket', ax=ax)
ax.set_title('Distribution of rolling max (max. precipitation) over  train and test')
ax.set_xlabel('Max Precipitation')
ax.set_ylabel('Number of Cubes')

# 4. Histogram of the event year (numeric)
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data=df_final, x='year', hue='split', kde=True, element="step", common_norm=False, palette='rocket', ax=ax)
ax.set_title('Temporal Distribution')
ax.set_xlabel('Year')
ax.set_ylabel('Number of Cubes')

# 5. Bar chart of phenological seasons per split
fig, ax = plt.subplots(figsize=(10, 6))
sns.countplot(data=df_final, x='pheno_season_name', hue='split', palette='magma', ax=ax)
ax.set_title('Pheno seasonal distribution')
ax.set_xlabel('Season')
ax.set_ylabel('Number of Cubes')
ax.legend(title='Split Set')

# One histogram per land-cover variable, comparing train and test.
for var in lc_vars:
    fig, ax = plt.subplots(figsize=(10, 6))

    # These are numeric values within the 1000x1000 cubes, so a histplot
    # shows the distribution of the area shares best.
    hist_options = dict(
        hue='split',
        kde=True,
        element="step",
        common_norm=False,
        palette='viridis',
    )
    sns.histplot(data=df_final, x=var, ax=ax, **hist_options)

    ax.set_title(f'Verteilung der Landbedeckung: {var}')
    ax.set_xlabel(f'{var} Anteil/Fl√§che')
    ax.set_ylabel('Anzahl der Cubes')

In [None]:
# Put the identifying and metadata columns (now including 'split') first,
# keep all remaining columns in their current order, and write the result.
start_cols = [
    'DisNo.', 'split', 'iso', 'country', 'continent',
    'climate_class', 'koppen_geiger',
    'pheno_label_id', 'pheno_season_name',
    'latitude', 'longitude',
    'start_date', 'end_date',
]
remaining_cols = [col for col in df_final.columns if col not in start_cols]
df_final = df_final[start_cols + remaining_cols]
df_final.to_csv("data/train_test_split.csv", index=False)