In [0]:
import os
import re
import time

import pandas as pd
from entsoe import EntsoePandasClient
from pyspark.sql import functions as F

# ============================================================================
# GLOBAL CONFIGURATION
# ============================================================================

# ENTSO-E Transparency Platform security token.
# Prefer supplying it through the ENTSOE_API_KEY environment variable (or a
# secret store) instead of source code; the literal below is only a fallback
# so existing runs keep working.
# NOTE(review): a real token is committed here — rotate it and remove the
# fallback once the environment variable is configured everywhere.
API_KEY = os.environ.get("ENTSOE_API_KEY") or '7b785108-53d7-42f8-931e-3d28c4323c68'

# 2-letter area codes passed to the entsoe client -> human-readable name used
# only in log output. Every key is collected by EuropeanGridDataCollector.
COUNTRIES = {
    'ES': 'Spain', 'PT': 'Portugal', 'FR': 'France', 'DE': 'Germany',
    'IT': 'Italy', 'GB': 'Great Britain', 'NL': 'Netherlands',
    'BE': 'Belgium', 'AT': 'Austria', 'CH': 'Switzerland', 'PL': 'Poland',
    'CZ': 'Czechia', 'DK': 'Denmark', 'SE': 'Sweden', 'NO': 'Norway',
    'FI': 'Finland', 'GR': 'Greece', 'IE': 'Ireland', 'RO': 'Romania',
    'BG': 'Bulgaria', 'HU': 'Hungary', 'SK': 'Slovakia', 'SI': 'Slovenia',
    'HR': 'Croatia', 'EE': 'Estonia', 'LT': 'Lithuania', 'LV': 'Latvia'
}

# Directed (from_country, to_country) pairs queried for cross-border flows.
# Keep this exact list unchanged (explicitly requested).
# NOTE(review): some borders appear in both directions (e.g. NL->GB and GB->NL)
# while others appear in one direction only (e.g. ES->PT but not PT->ES).
VALID_BORDERS = {
    ('ES', 'PT'), ('ES', 'FR'),
    ('FR', 'BE'), ('FR', 'CH'), ('FR', 'DE'), ('FR', 'IT'),
    ('BE', 'NL'), ('BE', 'DE'),
    ('NL', 'DE'), ('NL', 'GB'),
    ('GB', 'NL'), ('GB', 'FR'), ('GB', 'IE'),
    ('DE', 'CZ'), ('DE', 'PL'), ('DE', 'CH'), ('DE', 'DK'), ('DE', 'AT'),
    ('DK', 'DE'), ('DK', 'NO'), ('DK', 'SE'),
    ('SE', 'NO'), ('SE', 'FI'), ('SE', 'DK'),
    ('NO', 'NL'), ('NO', 'GB'), ('NO', 'SE'), ('NO', 'DK'),
    ('FI', 'EE'), ('FI', 'SE'),
    ('EE', 'LV'),
    ('LV', 'LT'),
    ('LT', 'PL'),
    ('PL', 'SK'), ('PL', 'CZ'),
    ('CZ', 'AT'), ('CZ', 'SK'),
    ('AT', 'SI'), ('AT', 'IT'), ('AT', 'CH'), ('AT', 'CZ'), ('AT', 'DE'),
    ('SI', 'HR'), ('SI', 'IT'), ('SI', 'AT'),
    ('HR', 'HU'), ('HR', 'SI'),
    ('HU', 'SK'), ('HU', 'RO'), ('HU', 'HR'), ('HU', 'AT'),
    ('SK', 'HU'), ('SK', 'CZ'), ('SK', 'PL'),
    ('RO', 'BG'), ('RO', 'HU'),
    ('BG', 'GR'), ('BG', 'RO'),
    ('GR', 'BG')
}

# Query window. The collector parses both as midnight UTC timestamps.
# NOTE(review): if the API treats `end` as exclusive, data for 2025-10-31
# itself is not fetched — confirm this is intended.
START_DATE = '2023-01-01'
END_DATE   = '2025-10-31'

# Databricks Delta database name
DATABASE = "european_grid_raw"

# Delta table names (under DATABASE) truncated at the start of every run.
# These must match the table names the writers append to.
DATASETS = [
    "load_actual",
    "load_forecast",
    "generation",
    "wind_forecast",
    "solar_forecast",
    "installed_capacity",
    "crossborder_flows",
]

# ============================================================================
# COLUMN SANITIZER
# ============================================================================

# Any run of characters outside [0-9A-Za-z_] (compiled once, used per column).
INVALID_CHARS_PATTERN = re.compile(r"[^0-9A-Za-z_]+")

def sanitize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Flatten MultiIndex columns and make every name a valid, unique Delta column.

    - MultiIndex ('Biomass', 'Actual Aggregated') -> 'Biomass__Actual_Aggregated'
      (None / empty tuple parts are skipped)
    - Spaces and any other run of chars outside [0-9A-Za-z_] become '_'
    - If name starts with a digit, prefix with '_'
    - A name that ends up empty (scalar or tuple) becomes 'col'
    - Names that collide case-insensitively get a numeric suffix
      ('_2', '_3', ...); Spark resolves column names case-insensitively by
      default, so such duplicates would make the Delta write fail.

    The column labels of ``df`` are replaced in place and ``df`` is returned.
    """
    new_cols = []
    seen = set()  # lower-cased names already assigned

    for col in df.columns:
        # Flatten MultiIndex tuples
        if isinstance(col, tuple):
            parts = [str(p) for p in col if p is not None and str(p) != ""]
            name = "__".join(parts)
        else:
            name = str(col)

        name = name.strip().replace(" ", "_")

        # Replace any remaining bad chars (., /, -, quotes, parens, etc.)
        name = INVALID_CHARS_PATTERN.sub("_", name)

        if not name:
            # Same fallback for scalar and tuple labels.
            name = "col"
        elif name[0].isdigit():
            # Avoid starting with a digit
            name = "_" + name

        # Disambiguate collisions produced by sanitization.
        base, suffix = name, 1
        while name.lower() in seen:
            suffix += 1
            name = f"{base}_{suffix}"
        seen.add(name.lower())

        new_cols.append(name)

    df.columns = new_cols
    return df

# ============================================================================
# TABLE MANAGEMENT
# ============================================================================

def truncate_table(dataset_name: str):
    """
    Empty the Delta table ``DATABASE.<dataset_name>`` if it is registered.

    TRUNCATE deletes every row but leaves the table definition (schema) in
    place. A table that does not exist yet is only reported, because the
    first append will create it.
    """
    full_name = f"{DATABASE}.{dataset_name}"

    if not spark.catalog.tableExists(full_name):
        print(f"  → Table {full_name} does not exist yet (will be created on write)")
        return

    print(f"  → Truncating table {full_name}")
    spark.sql(f"TRUNCATE TABLE {full_name}")

# ============================================================================
# DELTA WRITER
# ============================================================================

def write_dataset(dataset_name: str, country_code: str, df: pd.DataFrame):
    """
    Append one country's slice of a dataset to a shared Delta table.

    - Table name: european_grid_raw.<dataset_name>
      (i.e. ``f"{DATABASE}.{dataset_name}"``), e.g. european_grid_raw.load_actual
    - The index (e.g. the timestamp) is turned into a regular column.
    - Adds column ``country`` holding ``country_code`` (e.g. 'DE').
    - Column names are sanitized for Delta (MultiIndex flattening etc.).
    - ``None`` or empty input is a no-op.

    ``mergeSchema`` is enabled because every country appends into the same
    table and countries can return different column sets. Without it, an
    append that introduces a column missing from the existing table schema
    is rejected.
    """
    if df is None or len(df) == 0:
        return

    # bring index into a column (e.g. time) and add country
    pdf = df.reset_index()
    pdf["country"] = country_code

    # sanitize columns (handles MultiIndex + bad characters)
    pdf = sanitize_columns(pdf)

    full_name = f"{DATABASE}.{dataset_name}"

    (
        spark.createDataFrame(pdf)
        .write
        .format("delta")
        .mode("append")
        .option("mergeSchema", "true")
        # optional, may help query performance:
        # .partitionBy("country")
        .saveAsTable(full_name)
    )

    # Count from pandas: spark_df.count() would re-execute the whole plan.
    print(f"  → Saved to Delta: {full_name} ({len(pdf)} rows)")

# ============================================================================
# DATA COLLECTOR
# ============================================================================

class EuropeanGridDataCollector:
    """
    Download ENTSO-E Transparency data and append it to Delta tables.

    Per-country datasets are fetched for every code in COUNTRIES, and directed
    cross-border flows for every pair in VALID_BORDERS. All queries use the
    START_DATE..END_DATE window parsed as UTC timestamps.
    """

    def __init__(self, api_key):
        self.client = EntsoePandasClient(api_key=api_key)
        self.countries = COUNTRIES

        self.start = pd.Timestamp(START_DATE, tz="UTC")
        self.end   = pd.Timestamp(END_DATE,   tz="UTC")

    # -------------------------------
    # SINGLE COUNTRY DATA
    # -------------------------------
    def _country_queries(self, c):
        """Return ``(dataset_name, fetch)`` pairs for country ``c`` in request order."""
        client, start, end = self.client, self.start, self.end
        return [
            ("load_actual",
             lambda: client.query_load(c, start=start, end=end)),
            ("load_forecast",
             lambda: client.query_load_forecast(c, start=start, end=end)),
            ("generation",
             lambda: client.query_generation(c, start=start, end=end)),
            # ENTSO-E PSR types: B19 = Wind Onshore, B16 = Solar.
            # NOTE(review): offshore wind (B18) is not requested.
            ("wind_forecast",
             lambda: client.query_wind_and_solar_forecast(
                 c, start=start, end=end, psr_type='B19')),
            ("solar_forecast",
             lambda: client.query_wind_and_solar_forecast(
                 c, start=start, end=end, psr_type='B16')),
            ("installed_capacity",
             lambda: client.query_installed_generation_capacity(
                 c, start=start, end=end)),
        ]

    def collect_country_data(self, country_code):
        """
        Fetch every per-country dataset for ``country_code`` and append each one
        to its Delta table via write_dataset().

        Datasets are attempted independently: a failure is printed and the
        remaining datasets are still requested. Consecutive requests are
        separated by a 1-second pause (presumably to respect API rate limits).
        """
        c = country_code
        print(f"\n==== Collecting for {c} ({self.countries[c]}) ====")

        queries = self._country_queries(c)
        for i, (name, fetch) in enumerate(queries):
            try:
                print(f"    → {name}...")
                write_dataset(name, c, fetch())
            except Exception as e:
                print(f"    ✗ {name}: {e}")
            # Pause between requests, but not after the last one.
            if i < len(queries) - 1:
                time.sleep(1)

    # -------------------------------
    # CROSS-BORDER FLOWS
    # -------------------------------
    def collect_crossborder_flows(self):
        """
        Fetch directed physical flows for every (from, to) pair in VALID_BORDERS
        and append them, with timestamps, to ``DATABASE.crossborder_flows``.

        Failed or empty borders are reported and skipped.
        """
        print("\n=== Collecting Cross-Border Flows ===")

        flows_list = []

        # Sorted so the request order is reproducible: iteration order of a
        # set of str tuples changes between interpreter runs (hash randomization).
        for from_c, to_c in sorted(VALID_BORDERS):
            print(f"  → {from_c} ↔ {to_c}...", end="")

            try:
                flow = self.client.query_crossborder_flows(
                    from_c, to_c, start=self.start, end=self.end
                )
                if flow is not None and len(flow) > 0:
                    # Series for different borders may be localised to
                    # different time zones (assumption about entsoe-py —
                    # confirm); normalise to UTC so the concatenated time
                    # column has one tz-aware dtype.
                    if isinstance(flow.index, pd.DatetimeIndex) and flow.index.tz is not None:
                        flow = flow.tz_convert("UTC")
                    # reset_index() keeps the timestamps as a column; without
                    # it, concat(ignore_index=True) below discards them.
                    df = pd.DataFrame(flow).reset_index()
                    df["from_country"] = from_c
                    df["to_country"]   = to_c
                    flows_list.append(df)
                    print(" ✓")
                else:
                    print(" ✗ No data")
            except Exception as e:
                print(f" ✗ Failed: {e}")

            time.sleep(0.5)

        if not flows_list:
            return

        df_all = pd.concat(flows_list, ignore_index=True)
        df_all = sanitize_columns(df_all)

        full_name = f"{DATABASE}.crossborder_flows"

        (
            spark.createDataFrame(df_all)
            .write
            .format("delta")
            .mode("append")
            # The table may already exist (truncated, schema kept) with a
            # different column set, e.g. from before the time column was kept.
            .option("mergeSchema", "true")
            .saveAsTable(full_name)
        )

        # Count from pandas: spark_df.count() would re-execute the whole plan.
        print(f"  → Saved cross-border flows table {full_name} ({len(df_all)} rows)")

    # -------------------------------
    # MAIN COLLECTOR
    # -------------------------------
    def collect_all(self):
        """
        Run a full refresh: ensure DATABASE exists, truncate every table in
        DATASETS, collect per-country datasets, then cross-border flows.
        """
        # saveAsTable cannot create a table inside a database that is missing.
        if not spark.catalog.databaseExists(DATABASE):
            spark.sql(f"CREATE DATABASE IF NOT EXISTS {DATABASE}")

        print("=== Truncating tables for fresh run ===")
        for ds in DATASETS:
            truncate_table(ds)

        for c in self.countries:
            self.collect_country_data(c)

        self.collect_crossborder_flows()

# ============================================================================
# RUN PIPELINE
# ============================================================================

# Entry point: build the collector with the configured API key and run the
# full ingest (truncate tables, per-country datasets, then cross-border flows).
collector = EuropeanGridDataCollector(api_key=API_KEY)
collector.collect_all()
print("\nCOMPLETE.")