# 6) Spatio-Temporal Transformer (Multi-Head)

In [None]:


from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt


# Publication-style matplotlib defaults: serif fonts, high-resolution figures,
# a faint grid, and no top/right spines.
PLOT_STYLE = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
    "font.size": 16,
    "axes.titlesize": 16,
    "axes.labelsize": 14,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "legend.fontsize": 16,
    "figure.dpi": 350,
    "savefig.dpi": 350,
    "savefig.bbox": "tight",
    "axes.grid": True,
    "grid.alpha": 0.25,
    "lines.linewidth": 1.6,
    "axes.spines.top": False,
    "axes.spines.right": False,
}
mpl.rcParams.update(PLOT_STYLE)

# Fix the NumPy seed and widen the pandas display limits used by the printouts.
np.random.seed(42)
for _display_option in ("display.max_rows", "display.max_columns"):
    pd.set_option(_display_option, 200)


def add_value_labels(ax, fmt="{:.3g}", vpad=0.01):
    """Annotate every bar of a bar chart with its height.

    Parameters
    ----------
    ax : matplotlib Axes (or any object exposing ``get_ylim``, ``patches``
        and ``text``) that holds the bars.
    fmt : format string applied to each bar height.
    vpad : gap between bar tip and label, as a fraction of the y-axis range.

    Bars with a non-finite height are skipped. Labels of non-negative bars are
    placed just above the bar; labels of negative bars are placed just below
    the bar tip so they do not overlap the bar itself.
    """
    ymin, ymax = ax.get_ylim()
    # Fall back to a unit range when the axis is degenerate (ymax <= ymin).
    vrange = ymax - ymin if ymax > ymin else 1.0
    offset = vpad * vrange
    for rect in ax.patches:
        h = rect.get_height()
        if not np.isfinite(h):
            continue
        x = rect.get_x() + rect.get_width() / 2.0
        if h < 0:
            # Negative bar: its tip is at the bottom, so hang the label under it.
            ax.text(x, h - offset, fmt.format(h), ha="center", va="top")
        else:
            ax.text(x, h + offset, fmt.format(h), ha="center", va="bottom")

def show_hist(series: pd.Series, title: str, xlabel: str, bins: int = 60,
              figsize=(9, 6), dpi=350):
    """Display a histogram of the numeric values in ``series`` (nothing is saved).

    Values are coerced with ``pd.to_numeric``; entries that cannot be parsed
    and missing entries are dropped. If nothing remains, no figure is created.
    Always returns ``None``.
    """
    values = pd.to_numeric(series, errors="coerce").dropna()
    if not values.empty:
        plt.figure(figsize=figsize, dpi=dpi)
        plt.hist(values, bins=bins)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel("Count")
        plt.tight_layout()
        plt.show()

def heatmap_with_ann(mat: np.ndarray, xlabels, ylabels, title: str,
                     figsize=(10, 8), dpi=350, fmt="{:.2f}"):
    """Show ``mat`` as a heat map with every finite value written inside its cell.

    ``xlabels`` / ``ylabels`` become the tick labels of the columns / rows.
    The annotation colour is black on light cells and white on dark cells,
    decided by the luminance of the colormap colour for that value.
    """
    plt.figure(figsize=figsize, dpi=dpi)
    ax = plt.gca()
    image = ax.imshow(mat, aspect="auto")
    plt.colorbar(image)

    ax.set_xticks(range(len(xlabels)))
    ax.set_xticklabels(list(xlabels), rotation=45, ha="right")
    ax.set_yticks(range(len(ylabels)))
    ax.set_yticklabels(list(ylabels))
    ax.set_title(title)
    ax.grid(False)

    n_rows, n_cols = mat.shape
    # Same min/max scaling as the data range, ignoring NaNs.
    scale = plt.Normalize(vmin=np.nanmin(mat), vmax=np.nanmax(mat))
    palette = image.get_cmap()
    for row, col in np.ndindex(n_rows, n_cols):
        value = mat[row, col]
        if not np.isfinite(value):
            continue
        red, green, blue = palette(scale(value))[:3]
        luminance = 0.299*red + 0.587*green + 0.114*blue
        text_color = "black" if luminance > 0.5 else "white"
        ax.text(col, row, fmt.format(value),
                ha="center", va="center", color=text_color)
    plt.tight_layout()
    plt.show()


In [None]:
# Candidate CSV files, tried in order; the first one that loads is used.
CANDIDATES = [
    "original_dataset_china/renewables_combined_FULL.csv",
    "original_dataset_china/renewables_combined_CLEAN.csv",
]
DATA_PATH = None
_load_errors = {}
for p in CANDIDATES:
    try:
        _df = pd.read_csv(p, low_memory=False)
    except Exception as exc:
        # Remember why this candidate failed so the final error is actionable.
        _load_errors[p] = exc
        continue
    DATA_PATH = p
    break

if DATA_PATH is None:
    # Fail loudly here instead of hitting a NameError on `_df` below.
    _details = "; ".join(f"{path}: {err!r}" for path, err in _load_errors.items())
    raise FileNotFoundError(
        f"None of the candidate datasets could be loaded ({_details})"
    )


df = _df.copy()
print("Loaded:", DATA_PATH)
print("Shape:", df.shape)
print("Columns:", list(df.columns)[:20], "...")


# Parse the timestamp column (unparseable values become NaT) and derive
# calendar features from it.
df["timestamp"] = pd.to_datetime(df.get("timestamp"), errors="coerce")
_ts_parts = df["timestamp"].dt
df["year"] = _ts_parts.year
df["month"] = _ts_parts.month
df["day"] = _ts_parts.day
df["hour"] = _ts_parts.hour
df["doy"] = _ts_parts.dayofyear
df["week"] = _ts_parts.isocalendar().week.astype("Int64")


if "season_derived" not in df.columns:

    def _season(m: int) -> str:
        """Map a month number to a season label (Dec-Feb winter, Mar-May spring,
        Jun-Aug summer, everything else autumn)."""
        if m in (12, 1, 2):
            return "winter"
        if m in (3, 4, 5):
            return "spring"
        if m in (6, 7, 8):
            return "summer"
        return "autumn"

    # na_action="ignore" keeps rows with a missing month (NaT timestamp) as NaN
    # instead of letting them fall through to "autumn".
    df["season_derived"] = df["month"].map(_season, na_action="ignore")

print("Time range:", df["timestamp"].min(), "→", df["timestamp"].max())

Loaded: original_dataset_china/renewables_combined_FULL.csv
Shape: (8753, 18)
Columns: ['datetime_solar', 'hour_index_solar', 'temperature_solar', 'humidity_solar', 'surface_irradiance_solar', 'toa_irradiance_solar', 'kWh_solar_power_solar', 'sheet_solar', 'timestamp', 'specific_humidity_solar', 'relative_humidity_solar', 'datetime_wind', 'hour_index_wind', 'air_density_wind', 'wind_speed_wind', 'kWh_wind_power_wind', 'sheet_wind', 'season_derived'] ...
Time range: 2019-01-01 08:00:00 → 2019-12-31 23:00:00


In [None]:

# Snapshot of the frame's column names; find_power_target and the feature
# lists built later in this cell read this module-level list.
cols = df.columns.tolist()

def find_power_target(suffix: str, columns=None):
    """Pick the power-target column whose name ends with ``suffix``.

    Parameters
    ----------
    suffix : required ending of the column name (e.g. ``"_solar"``).
    columns : optional iterable of column names to search. When omitted, the
        module-level ``cols`` list is used (the original behaviour).

    Returns
    -------
    The first matching column whose lower-cased name starts with ``"kwt_"``;
    otherwise the first matching column containing ``"power"``; otherwise
    ``None``.
    """
    names = cols if columns is None else list(columns)
    suffixed = [c for c in names if c.endswith(suffix)]

    # Preferred naming scheme: an explicit "kwt_" prefix.
    for c in suffixed:
        if c.lower().startswith("kwt_"):
            return c

    # Fallback: any column with the suffix that mentions "power".
    for c in suffixed:
        if "power" in c.lower():
            return c
    return None

# Resolve the target column of each technology.
solar_target = find_power_target("_solar")
wind_target = find_power_target("_wind")


# Known weather-feature column names; only those present in the frame are kept,
# in the order in which they appear in `cols`.
solar_feature_keys = [
    "irradiance_solar", "surface_irradiance_solar", "toa_irradiance_solar",
    "temperature_solar", "module_temperature_solar", "humidity_solar",
    "relative_humidity_solar", "specific_humidity_solar",
]
wind_feature_keys = ["wind_speed_wind", "wind_direction_wind", "air_density_wind"]

solar_feats = [c for c in cols if c in solar_feature_keys]
wind_feats = [c for c in cols if c in wind_feature_keys]

for _label, _value in (
    ("Solar target:", solar_target),
    ("Wind  target:", wind_target),
    ("Solar features:", solar_feats),
    ("Wind  features:", wind_feats),
):
    print(_label, _value)


Solar target: kWh_solar_power_solar
Wind  target: kWh_wind_power_wind
Solar features: ['temperature_solar', 'humidity_solar', 'surface_irradiance_solar', 'toa_irradiance_solar', 'specific_humidity_solar', 'relative_humidity_solar']
Wind  features: ['air_density_wind', 'wind_speed_wind']


In [None]:

# Structural overview of the loaded frame.
print("Rows, Cols:", df.shape)
_stamps = df["timestamp"]
print("Time coverage:", _stamps.min(), "→", _stamps.max())
print("Season distribution:", df["season_derived"].value_counts(dropna=False).to_dict())


if "timestamp" in df.columns:
    dup_count = int(_stamps.duplicated().sum())
    print("Duplicate timestamp rows:", dup_count)


# Per-column missing-value counts and percentages, most-missing first.
missing = df.isna().sum().sort_values(ascending=False).to_frame("missing_count")
missing["missing_pct"] = (missing["missing_count"] / df.shape[0] * 100).round(3)
print(missing.head(20))


# Descriptive statistics of the numeric columns (one row per column).
desc = df.select_dtypes(include=[np.number]).describe().transpose()
print(desc.head(20))


print(df.head(5))


Rows, Cols: (8753, 24)
Time coverage: 2019-01-01 08:00:00 → 2019-12-31 23:00:00
Season distribution: {'spring': 2209, 'summer': 2208, 'autumn': 2184, 'winter': 2152}
Duplicate timestamp rows: 1
                         missing_count  missing_pct
datetime_solar                       0          0.0
hour_index_solar                     0          0.0
doy                                  0          0.0
hour                                 0          0.0
day                                  0          0.0
month                                0          0.0
year                                 0          0.0
season_derived                       0          0.0
sheet_wind                           0          0.0
kWh_wind_power_wind                  0          0.0
wind_speed_wind                      0          0.0
air_density_wind                     0          0.0
hour_index_wind                      0          0.0
datetime_wind                        0          0.0
relative_humidity_solar   

In [None]:


# Make sure the "hour" and "month" helper columns exist (derived from the timestamp).
if "timestamp" in df.columns:
    for _part in ("hour", "month"):
        if _part not in df.columns:
            df[_part] = getattr(pd.to_datetime(df["timestamp"], errors="coerce").dt, _part)

def _hourly_means(series: pd.Series):
    """Mean of ``series`` for each value of the module-level ``df["hour"]``.

    ``series`` is coerced to numeric; rows where either the value or the hour
    is missing are dropped. Returns ``None`` when nothing remains.
    """
    paired = pd.DataFrame(
        {"val": pd.to_numeric(series, errors="coerce"), "hour": df["hour"]}
    ).dropna()
    return None if paired.empty else paired.groupby("hour")["val"].mean()

def _monthly_means(series: pd.Series):
    """Mean of ``series`` for each month of the module-level ``df["month"]``.

    ``series`` is coerced to numeric; rows where either the value or the month
    is missing are dropped. Returns ``None`` when nothing remains; otherwise
    the result is reindexed to months 1..12 (months without data are NaN).
    """
    paired = pd.DataFrame(
        {"val": pd.to_numeric(series, errors="coerce"), "month": df["month"]}
    ).dropna()
    if paired.empty:
        return None
    per_month = paired.groupby("month")["val"].mean()
    return per_month.reindex(range(1, 13))


In [None]:

def report_count(name, cond):
    """Print ``"<name>: <n>"`` where ``n`` is ``int(cond.sum())``."""
    print(f"{name}: {int(cond.sum())}")


# Physical-plausibility checks: night-time solar readings and negative values.
_has_hour = "hour" in df.columns
if _has_hour:
    # Hours 20:00-05:00 are treated as night.
    night = df["hour"].isin([0, 1, 2, 3, 4, 5, 20, 21, 22, 23])

if _has_hour and "surface_irradiance_solar" in df.columns:
    irr = pd.to_numeric(df["surface_irradiance_solar"], errors="coerce")
    report_count("Night irradiance > 0", night & (irr > 0))

if _has_hour and solar_target:
    sp = pd.to_numeric(df[solar_target], errors="coerce")
    report_count("Night solar power > 0", night & (sp > 0))


if "wind_speed_wind" in df.columns:
    ws = pd.to_numeric(df["wind_speed_wind"], errors="coerce")
    report_count("Negative wind speed", ws.lt(0))

if solar_target:
    sp = pd.to_numeric(df[solar_target], errors="coerce")
    report_count("Negative solar power", sp.lt(0))

if wind_target:
    wp = pd.to_numeric(df[wind_target], errors="coerce")
    report_count("Negative wind power", wp.lt(0))


Night irradiance > 0: 579
Night solar power > 0: 354
Negative wind speed: 0
Negative solar power: 0
Negative wind power: 0


In [None]:




# Candidate model inputs: solar + wind weather features, then the calendar
# columns that exist in the frame.
feature_candidates = [*solar_feats, *wind_feats]
feature_candidates.extend(c for c in ["month", "hour", "doy"] if c in df.columns)


# Modelling frame: rows with a valid timestamp, one row per timestamp.
m = (
    df.dropna(subset=["timestamp"])
      .copy()
      .drop_duplicates(subset=["timestamp"])
)


def make_Xy(data: pd.DataFrame, target_col: str, features: list):
    """Build a feature frame and target series without missing values.

    The target is coerced to numeric (unparseable entries become NaN). Rows
    where the target or any feature is missing are removed from both outputs.

    Returns ``(X, y)`` sharing the same index.
    """
    X = data[features].copy()
    y = pd.to_numeric(data[target_col], errors="coerce")

    complete = y.notna() & X.notna().all(axis=1)
    return X[complete], y[complete]

# Per-target design matrices; they stay None when the target column was not found.
X_solar = y_solar = None
X_wind = y_wind = None

if solar_target:
    _solar_inputs = [c for c in feature_candidates if c in m.columns]
    X_solar, y_solar = make_Xy(m, solar_target, _solar_inputs)
    print("Solar X/y:", X_solar.shape, y_solar.shape)

if wind_target:
    _wind_inputs = [c for c in feature_candidates if c in m.columns]
    X_wind, y_wind = make_Xy(m, wind_target, _wind_inputs)
    print("Wind  X/y:", X_wind.shape, y_wind.shape)


def time_split(df_like: pd.DataFrame, frac_test=0.2):
    """Chronological split by position: the last ``frac_test`` share is the test set.

    Returns ``(train_positions, test_positions)`` as integer arrays covering
    ``0 .. len(df_like) - 1`` without shuffling.
    """
    positions = np.arange(len(df_like))
    boundary = int(np.floor(len(df_like) * (1 - frac_test)))
    return positions[:boundary], positions[boundary:]

# Report the 80/20 split sizes for each available design matrix.
for _split_label, _matrix in (("Solar split:", X_solar), ("Wind split:", X_wind)):
    if _matrix is not None:
        idx_tr, idx_te = time_split(_matrix, 0.2)
        print(_split_label, len(idx_tr), "train /", len(idx_te), "test")


Solar X/y: (8752, 11) (8752,)
Wind  X/y: (8752, 11) (8752,)
Solar split: 7001 train / 1751 test
Wind split: 7001 train / 1751 test


In [None]:
# ===== Fragment 13: Tables for the paper =====
def seasonal_stats(series: pd.Series, name: str):
    """Print count/mean/std/min/max of ``series`` grouped by ``df["season_derived"]``.

    ``series`` is coerced to numeric and rows with a missing value or season
    are dropped; nothing is printed when no rows remain.
    """
    table = pd.DataFrame(
        {"val": pd.to_numeric(series, errors="coerce"), "season": df["season_derived"]}
    ).dropna()
    if not table.empty:
        summary = table.groupby("season")["val"].agg(["count", "mean", "std", "min", "max"])
        print(f"\nSeasonal stats — {name}")
        print(summary)

def monthly_stats(series: pd.Series, name: str):
    """Print count/mean/std/min/max of ``series`` grouped by ``df["month"]``.

    ``series`` is coerced to numeric and rows with a missing value or month
    are dropped; nothing is printed when no rows remain.
    """
    table = pd.DataFrame(
        {"val": pd.to_numeric(series, errors="coerce"), "month": df["month"]}
    ).dropna()
    if not table.empty:
        summary = table.groupby("month")["val"].agg(["count", "mean", "std", "min", "max"])
        print(f"\nMonthly stats — {name}")
        print(summary)

# Seasonal and monthly tables for each target that was found.
for _target, _title in (
    (solar_target, "Solar Power (kW)"),
    (wind_target, "Wind Power (kW)"),
):
    if _target:
        seasonal_stats(df[_target], _title)
        monthly_stats(df[_target], _title)


def table_top_corr(target, k=10, method="pearson"):
    """Print the ``k`` numeric columns of ``df`` most positively correlated with ``target``.

    Does nothing when ``target`` is ``None`` or not a column of ``df``.
    ``method`` is forwarded to ``DataFrame.corr``.
    """
    if target is None or target not in df.columns:
        return
    numeric = df.select_dtypes(include=[np.number])
    ranked = (
        numeric.corr(method=method)[target]
        .drop(target)
        .sort_values(ascending=False)
    )
    print(ranked.head(k).to_frame(f"corr_with_{target}_{method}"))

# Top-10 correlation tables: both targets, Pearson then Spearman.
for _target in (solar_target, wind_target):
    for _method in ("pearson", "spearman"):
        table_top_corr(_target, 10, _method)



Seasonal stats — Solar Power (kW)
        count         mean          std  min       max
season                                                
autumn   2184   942.717727  1377.896308  0.0  4238.193
spring   2209  1161.535002  1462.703420  0.0  4500.000
summer   2208   966.957184  1200.696901  0.0  4036.684
winter   2152   721.235507  1237.513954  0.0  4500.000

Monthly stats — Solar Power (kW)
       count         mean          std  min       max
month                                                
1        736   685.842629  1177.053979  0.0  4005.425
2        672  1061.194115  1543.353973  0.0  4500.000
3        744  1050.507254  1467.837257  0.0  4500.000
4        720  1318.606440  1544.096468  0.0  4500.000
5        745  1120.613136  1362.291924  0.0  4300.442
6        720  1144.471549  1320.815891  0.0  4036.684
7        744  1000.190301  1185.015244  0.0  3742.922
8        744   761.935972  1056.473145  0.0  3886.848
9        720  1196.589754  1485.146683  0.0  4197.958
10     

In [None]:

def _find_target(cols, suffix):
    cands = [c for c in cols if c.lower().endswith(suffix) and ("kw" in c.lower()) and ("power" in c.lower())]
    if cands: return cands[0]
    # запасной вариант: любое "power_*suffix"
    cands = [c for c in cols if c.lower().endswith(suffix) and ("power" in c.lower())]
    return cands[0] if cands else None

# Locate both target columns; abort early if either one is missing.
cols = df.columns.tolist()
target_solar, target_wind = _find_target(cols, "_solar"), _find_target(cols, "_wind")

assert target_solar is not None, "Не найден столбец таргета для солнца"
assert target_wind  is not None, "Не найден столбец таргета для ветра"

# Shared input features: solar weather, wind weather, then calendar columns.
# Only names that exist in the frame are kept, in this listed order.
base_feature_keys = [
    "surface_irradiance_solar", "toa_irradiance_solar", "irradiance_solar",
    "temperature_solar", "module_temperature_solar",
    "humidity_solar", "relative_humidity_solar", "specific_humidity_solar",
    "wind_speed_wind", "wind_direction_wind", "air_density_wind",
    "hour", "month", "doy",
]
features_all = [name for name in base_feature_keys if name in df.columns]

# Keep only rows where every feature and both targets are present.
m = df.dropna(subset=features_all + [target_solar, target_wind]).copy()


X = m[features_all].copy()
y_solar, y_wind = (m[_col].astype(float) for _col in (target_solar, target_wind))

print("Targets:", target_solar, "|", target_wind)
print("Shared features:", features_all)
print("Shapes:", X.shape, y_solar.shape, y_wind.shape)


def time_split(n, frac_test=0.2):
    """Chronological split by position: the last ``frac_test`` share is the test set.

    Parameters
    ----------
    n : number of samples, or any sized object (DataFrame, array, list) whose
        length is used. Accepting a sized object keeps calls written for the
        earlier ``time_split(df_like, ...)`` definition working after this
        redefinition.
    frac_test : fraction of samples assigned to the test part.

    Returns
    -------
    ``(idx_tr, idx_te)`` integer position arrays: ``0 .. cut-1`` and ``cut .. n-1``.
    """
    if not isinstance(n, (int, np.integer)):
        n = len(n)
    cut = int(n * (1 - frac_test))
    return np.arange(cut), np.arange(cut, n)

# Apply the same positional 80/20 split to the features and both targets.
idx_tr, idx_te = time_split(len(X), 0.2)
Xtr, ys_tr, yw_tr = (_part.iloc[idx_tr] for _part in (X, y_solar, y_wind))
Xte, ys_te, yw_te = (_part.iloc[idx_te] for _part in (X, y_solar, y_wind))


Targets: kWh_solar_power_solar | kWh_wind_power_wind
Shared features: ['surface_irradiance_solar', 'toa_irradiance_solar', 'temperature_solar', 'humidity_solar', 'relative_humidity_solar', 'specific_humidity_solar', 'wind_speed_wind', 'air_density_wind', 'hour', 'month', 'doy']
Shapes: (8753, 11) (8753,) (8753,)


In [None]:
# ===== 0. Common setup (metrics, splits, CV & report) =====
import numpy as np, pandas as pd
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, explained_variance_score
from sklearn.model_selection import TimeSeriesSplit
from sklearn.preprocessing import StandardScaler

def _mape(y_true, y_pred, eps=1e-8):
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    m = np.abs((y_true - y_pred) / np.clip(np.abs(y_true), eps, None)).mean()
    return m * 100.0

def compute_metrics(y_true, y_pred):
    """Return a dict with MSE, RMSE, MAE, MAPE%, R2 and explained variance."""
    mse = mean_squared_error(y_true, y_pred)
    return {
        "MSE": mse,
        "RMSE": np.sqrt(mse),
        "MAE": mean_absolute_error(y_true, y_pred),
        "MAPE%": _mape(y_true, y_pred),
        "R2": r2_score(y_true, y_pred),
        "EVS": explained_variance_score(y_true, y_pred),
    }


def summarize_cv(cv_metrics_list):
    """Aggregate per-fold metric dicts into mean, std and a two-sided 90% t-interval.

    ``cv_metrics_list`` is a non-empty list of dicts sharing the keys of its
    first element. Returns a DataFrame indexed by metric name with columns
    ``mean``, ``std``, ``CI90_low`` and ``CI90_high``. With a single fold the
    std and the interval half-width are 0.
    """
    metric_names = list(cv_metrics_list[0].keys())

    from math import sqrt
    from scipy.stats import t

    n_folds = len(cv_metrics_list)
    alpha = 0.10
    # Student-t critical value; identical for every metric.
    tcrit = t.ppf(1 - alpha/2, df=max(n_folds-1, 1))

    summary = {}
    for name in metric_names:
        fold_values = np.array([fold[name] for fold in cv_metrics_list], float)
        mean = fold_values.mean()
        if n_folds > 1:
            std = fold_values.std(ddof=1)
            margin = tcrit * std / sqrt(n_folds)
        else:
            std = 0.0
            margin = 0.0
        summary[name] = {"mean": mean, "std": std,
                         "CI90_low": mean-margin, "CI90_high": mean+margin}
    return pd.DataFrame(summary).T

# ---- 80/20 chronological train/test split ----
def train_test_time_split(y, frac_test=0.2, X_exog=None):
    """Split ``y`` (and optionally ``X_exog``) chronologically by position.

    The first ``floor(n * (1 - frac_test))`` samples form the training part,
    the remainder the test part.

    Returns ``(y_train, y_test, X_train, X_test)``; the last two are ``None``
    when ``X_exog`` is not given. A DataFrame ``X_exog`` is sliced with
    ``.iloc``; anything else is sliced with plain indexing.
    """
    n = len(y)
    boundary = int(np.floor(n*(1-frac_test)))
    head, tail = slice(0, boundary), slice(boundary, n)
    y_train, y_test = y.iloc[head], y.iloc[tail]

    if X_exog is None:
        return (y_train, y_test, None, None)
    if isinstance(X_exog, pd.DataFrame):
        return (y_train, y_test, X_exog.iloc[head], X_exog.iloc[tail])
    return (y_train, y_test, X_exog[head], X_exog[tail])

# ---- 5-fold TimeSeriesSplit ----
def folds_time_series(n_splits=5):
    """Return a scikit-learn ``TimeSeriesSplit`` configured with ``n_splits`` folds."""
    splitter = TimeSeriesSplit(n_splits=n_splits)
    return splitter


# Accumulator for result rows; starts empty.
report_rows = list()


In [None]:
# ===== ONE-FRAGMENT — Spatio-Temporal Transformer (STT) [VERBOSE] =====
# Anti-leakage: train-only winsorize/scale; train-only extreme thresholds
# Causality: flat_in = last step of the window (i-1)
# Stability: fixed seeds; safe AUC; gradient clipping
# Visibility: verbose Keras logs (Epoch/steps, per-output losses) + printed summaries

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"   

import numpy as np, pandas as pd
from scipy.stats import t
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import (mean_squared_error, mean_absolute_error, r2_score,
                             explained_variance_score, roc_auc_score)

import tensorflow as tf
from tensorflow.keras.layers import (Input, Dense, BatchNormalization, Dropout, Add,
                                     LayerNormalization, GlobalAveragePooling1D,
                                     MultiHeadAttention)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K


tf.keras.utils.set_random_seed(42)


def _sign_log1p_df(df_):
    A = df_.astype(float).copy()
    return np.sign(A) * np.log1p(np.abs(A))

def _train_only_preprocess(Xtr_raw: pd.DataFrame, Xte_raw: pd.DataFrame):
    """Preprocess features using statistics from the training part only.

    Steps: winsorize both parts at the train 1%/99% quantiles, apply the
    signed log1p transform, then scale with a ``MinMaxScaler`` fitted on the
    transformed training data.

    Returns ``(Xtr, Xte)`` as NumPy arrays.
    """
    lower, upper = Xtr_raw.quantile(0.01), Xtr_raw.quantile(0.99)

    train_arr, test_arr = (
        _sign_log1p_df(part.clip(lower=lower, upper=upper, axis=1)).values
        for part in (Xtr_raw, Xte_raw)
    )

    scaler = MinMaxScaler().fit(train_arr)
    return scaler.transform(train_arr), scaler.transform(test_arr)

def preprocess(df, features_all, y_solar, y_wind):
    """Return a placeholder feature matrix and the signed-log1p targets.

    Kept for API compatibility: the first return value is an all-zero array
    with the shape of ``df[features_all]``; callers in this notebook ignore it
    and scale features separately.

    Returns
    -------
    ``(X_scaled, ys, yw)`` where ``ys`` / ``yw`` are 1-D float arrays holding
    ``sign(y) * log1p(|y|)`` of the solar / wind targets.
    """
    X = df[features_all].astype(float)
    X_scaled = np.zeros_like(X.values, dtype=float)

    def _log_tr(y):
        y = np.asarray(y).astype(float)
        # sign * log1p(|y|) never evaluates log1p on a negative argument, so no
        # invalid-value warnings are raised for negative targets.
        return np.sign(y) * np.log1p(np.abs(y))

    return X_scaled, _log_tr(y_solar).ravel(), _log_tr(y_wind).ravel()

def make_seq(X, L):
    """Build, for every row ``i``, the window of the ``L`` rows preceding it.

    The window for row ``i`` is ``X[i-L:i]`` (row ``i`` itself is excluded).
    Windows near the start are left-padded with zero rows up to length ``L``.
    Returns an array with one window per row of ``X``.
    """
    def _window(end):
        rows = X[max(0, end - L):end]
        missing = L - len(rows)
        if missing > 0:
            rows = np.pad(rows, ((missing, 0), (0, 0)), 'constant')
        return rows

    return np.stack([_window(end) for end in range(len(X))])

def safe_mape(y_true, y_pred, eps=1e-8):
    """MAPE in percent over the samples whose true value exceeds ``eps`` in magnitude.

    Returns ``nan`` when no sample qualifies.
    """
    actual, forecast = np.asarray(y_true), np.asarray(y_pred)
    usable = np.abs(actual) > eps
    if not usable.any():
        return np.nan
    rel_err = np.abs((actual[usable] - forecast[usable]) / actual[usable])
    return np.mean(rel_err) * 100.0

def safe_auc(y_true_bin, y_score):
    """ROC AUC that returns ``nan`` when only one class is present in ``y_true_bin``."""
    labels = np.asarray(y_true_bin).astype(int)
    positives = int(labels.sum())
    negatives = int(len(labels) - positives)
    if positives == 0 or negatives == 0:
        return np.nan
    return roc_auc_score(labels, y_score)


class LossScaleLayer(tf.keras.layers.Layer):
    """Identity layer that owns two trainable scalar log-variance weights.

    The weights (``log_var_solar_reg`` and ``log_var_wind_reg``) are not used
    in ``call``; they exist so that loss functions built elsewhere can read
    them from a layer that is part of the model.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keyword arguments only: the first positional parameter of
        # ``add_weight`` is ``name`` in Keras 2 but ``shape`` in Keras 3.
        self.log_var_solar_reg = self.add_weight(
            name="log_var_solar_reg", shape=(), initializer="zeros", trainable=True)
        self.log_var_wind_reg = self.add_weight(
            name="log_var_wind_reg", shape=(), initializer="zeros", trainable=True)

    def call(self, z):
        """Return the input unchanged."""
        return z

def _u_mse(log_var, l2=1e-4, clip_min=-3.0, clip_max=3.0):
    """Build an MSE loss weighted by a learnable log-variance.

    The returned loss computes ``exp(-s) * mse + s + l2 * s**2`` where ``s``
    is ``log_var`` clipped to ``[clip_min, clip_max]``.
    """
    def loss(y_true, y_pred):
        s = tf.clip_by_value(log_var, clip_min, clip_max)
        squared_err = tf.square(y_true - y_pred)
        weighted = tf.exp(-s) * tf.reduce_mean(squared_err)
        return weighted + s + l2*tf.square(s)
    return loss

# Loss for the binary "extreme event" heads. Focal cross-entropy is preferred;
# plain binary cross-entropy is used only when this Keras version does not
# provide BinaryFocalCrossentropy, and that fallback is reported explicitly.
try:
    FocalBCE = tf.keras.losses.BinaryFocalCrossentropy
except AttributeError:
    print("[warn] tf.keras.losses.BinaryFocalCrossentropy is unavailable; "
          "falling back to BinaryCrossentropy for the extreme-event heads.")
    focal_bce = tf.keras.losses.BinaryCrossentropy()
else:
    # The string form of the reduction is used so that the absence of the
    # ``tf.keras.losses.Reduction`` enum cannot trigger a silent fallback.
    focal_bce = FocalBCE(gamma=2.0, reduction="sum_over_batch_size")


class TimePositionalEncoding(tf.keras.layers.Layer):
    """Add a sinusoidal positional encoding along axis 1 of the input.

    For position ``p`` and channel ``i`` the angle is
    ``p / 10000 ** (2 * floor(i / 2) / d_model)``; even channels receive the
    sine of the angle and odd channels the cosine.
    """
    def __init__(self, d_model: int, **kwargs):
        """Store ``d_model``, the number of encoding channels that is generated."""
        super().__init__(**kwargs)
        self.d_model = int(d_model)
    def build(self, input_shape):
        """Precompute the channel indices, the even-channel mask and constants."""
        d = self.d_model
        i = tf.range(d)[tf.newaxis, :]  # (1,d) channel indices
        self.i_float = tf.cast(i, tf.float32)
        self.even_mask = tf.cast((i % 2) == 0, tf.bool)  # True where the channel index is even
        self.d_float = tf.cast(d, tf.float32)
        self.const_10000 = tf.constant(10000.0, tf.float32)
        super().build(input_shape)
    def _den(self, i_float):
        """Return the per-channel denominator ``10000 ** (2 * floor(i / 2) / d)``."""
        return tf.pow(self.const_10000, (2.0 * tf.math.floor(i_float / 2.0)) / self.d_float)
    def call(self, x):
        """Return ``x`` plus the encoding, tiled over the batch axis.

        The addition requires the last axis of ``x`` to be compatible with
        ``d_model`` channels.
        """

        B = tf.shape(x)[0]; L = tf.shape(x)[1]  # dynamic batch size and sequence length
        pos = tf.cast(tf.range(L)[:, tf.newaxis], tf.float32)      # (L,1)
        angles = pos / self._den(self.i_float)                      # (L,d)
        sin_t, cos_t = tf.sin(angles), tf.cos(angles)
        pe = tf.where(self.even_mask, sin_t, cos_t)                 # (L,d)
        pe = tf.tile(pe[tf.newaxis, ...], [B, 1, 1])                # (B,L,d)
        return x + pe


def transformer_encoder_block(x, num_heads=4, d_model=32,
                              ff_dim1=64, ff_dim2=32, dropout_rate=0.1, eps=1e-5, causal=True):
    """One encoder block: pre-norm self-attention and feed-forward, each with a residual.

    ``ff_dim2`` is the width of the feed-forward output that is added back to
    the residual stream. ``causal`` is forwarded to the attention call as
    ``use_causal_mask``. A final ``LayerNormalization`` is applied to the output.
    """
    head_width = max(1, d_model // num_heads)

    # Self-attention sub-layer.
    normed = LayerNormalization(epsilon=eps)(x)
    attention = MultiHeadAttention(num_heads=num_heads, key_dim=head_width,
                                   dropout=dropout_rate)
    attended = attention(normed, normed, use_causal_mask=causal)
    residual = Add()([x, attended])

    # Position-wise feed-forward sub-layer.
    normed = LayerNormalization(epsilon=eps)(residual)
    hidden = Dense(ff_dim1, activation="relu")(normed)
    hidden = Dense(ff_dim2, activation=None)(hidden)
    hidden = Dropout(dropout_rate)(hidden)
    combined = Add()([residual, hidden])

    return LayerNormalization(epsilon=eps)(combined)


def build_st_transformer(input_shapes, d_model=32, num_heads=4, n_blocks=4,
                         dropout_rate=0.1, lr=1e-3):
    """
    Build and compile the four-output transformer model.

    input_shapes: {'seq': (L, D), 'flat': (D,)} — 'flat' is kept for API compatibility;
    only the 'seq' entry is read here.

    The sequence input is projected to ``d_model`` channels, given a positional
    encoding, passed through ``n_blocks`` causal encoder blocks, average-pooled
    over time and fed to a shared dense trunk. Outputs: ``solar_reg`` and
    ``wind_reg`` (linear) plus ``solar_ext`` and ``wind_ext`` (sigmoid).
    Returns the compiled Keras model.
    """
    L, D = input_shapes['seq']
    seq_in  = Input(shape=(L, D), name="seq_in")
    # NOTE(review): flat_in is declared as a model input but is not connected
    # to any layer in this function.
    flat_in = Input(shape=(D,),   name="flat_in")
    z = Dense(d_model, name="stt_proj")(seq_in)
    z = TimePositionalEncoding(d_model=d_model, name="TimePE")(z)
    for _ in range(n_blocks):
        # ff_dim2 equals d_model so the feed-forward output can be added to the residual.
        z = transformer_encoder_block(z, num_heads=num_heads, d_model=d_model,
                                      ff_dim1=64, ff_dim2=d_model,
                                      dropout_rate=dropout_rate, eps=1e-5, causal=True)
    z = GlobalAveragePooling1D(name="stt_gap")(z)
    z = BatchNormalization(name="bn_after_gap")(z)
    z = Dense(64, activation="relu")(z)
    z = Dropout(0.2)(z)

    # Identity layer whose trainable log-variance weights are read by the losses below.
    z = LossScaleLayer(name="loss_scales")(z)

    solar_reg = Dense(1, name="solar_reg")(z)
    wind_reg  = Dense(1, name="wind_reg")(z)
    solar_ext = Dense(1, activation="sigmoid", name="solar_ext")(z)
    wind_ext  = Dense(1, activation="sigmoid", name="wind_ext")(z)

    model = Model(inputs=[seq_in, flat_in],
                  outputs=[solar_reg, wind_reg, solar_ext, wind_ext],
                  name="SpatioTemporalTransformer")

    lw = model.get_layer("loss_scales")
    # Regression heads use the log-variance-weighted MSE; the binary heads use focal_bce.
    losses = {
        "solar_reg": _u_mse(lw.log_var_solar_reg),
        "wind_reg":  _u_mse(lw.log_var_wind_reg),
        "solar_ext": focal_bce,
        "wind_ext":  focal_bce,
    }

    # clipnorm=1.0 enables gradient-norm clipping in the optimizer.
    opt = Adam(learning_rate=lr, clipnorm=1.0)
    model.compile(optimizer=opt, loss=losses)
    return model


def run_stt_holdout(df, features_all, y_solar, y_wind,
                    n_steps=24, epochs=20, batch=64, lr=1e-3, verbose=1):
    """Train the model on the first 80% of rows and score it on the last 20%.

    Features are preprocessed with training-part statistics only. Targets are
    the arrays returned by ``preprocess`` and the metrics are computed on those
    values as-is. "Extreme" labels mark targets above the training-part 95th
    percentile. The test part is also passed to ``fit`` as validation data.

    Prints the metric table and returns it as
    ``{"Solar": {...}, "Wind": {...}}``.
    """
    def _head_scores(y_true, y_hat, flags, flag_prob):
        # Regression metrics for one target plus the AUC of its extreme-event head.
        return {"RMSE": np.sqrt(mean_squared_error(y_true, y_hat)),
                "MAE":  mean_absolute_error(y_true, y_hat),
                "MAPE": safe_mape(y_true, y_hat),
                "R2":   r2_score(y_true, y_hat),
                "EVS":  explained_variance_score(y_true, y_hat),
                "AUC":  safe_auc(flags, flag_prob)}

    _, solar_all, wind_all = preprocess(df, features_all, y_solar, y_wind)

    split_at = int(len(df)*0.8)
    train_rows, test_rows = df.iloc[:split_at], df.iloc[split_at:]
    Xtr, Xte = _train_only_preprocess(train_rows[features_all].copy(),
                                      test_rows[features_all].copy())

    solar_tr, solar_te = solar_all[:split_at], solar_all[split_at:]
    wind_tr, wind_te = wind_all[:split_at], wind_all[split_at:]

    window = min(n_steps, len(Xtr))
    seq_tr, seq_te = make_seq(Xtr, window), make_seq(Xte, window)
    # "Flat" input = last step of each window.
    flat_tr, flat_te = seq_tr[:, -1, :], seq_te[:, -1, :]

    # Extreme-event thresholds come from the training targets only.
    solar_thr = np.percentile(solar_tr, 95)
    wind_thr = np.percentile(wind_tr, 95)
    solar_tr_ext = (solar_tr > solar_thr).astype(int)
    solar_te_ext = (solar_te > solar_thr).astype(int)
    wind_tr_ext = (wind_tr > wind_thr).astype(int)
    wind_te_ext = (wind_te > wind_thr).astype(int)

    n_features = Xtr.shape[1]
    model = build_st_transformer({'seq': (window, n_features), 'flat': (n_features,)}, lr=lr)

    train_targets = {"solar_reg": solar_tr, "wind_reg": wind_tr,
                     "solar_ext": solar_tr_ext, "wind_ext": wind_tr_ext}
    val_targets = {"solar_reg": solar_te, "wind_reg": wind_te,
                   "solar_ext": solar_te_ext, "wind_ext": wind_te_ext}

    # Keras prints epochs/steps according to `verbose`.
    model.fit([seq_tr, flat_tr], train_targets,
              validation_data=([seq_te, flat_te], val_targets),
              epochs=epochs, batch_size=batch, verbose=verbose)

    outputs = model.predict([seq_te, flat_te], verbose=1)
    solar_hat, wind_hat = outputs[0].ravel(), outputs[1].ravel()
    solar_prob, wind_prob = outputs[2].ravel(), outputs[3].ravel()

    metrics = {
        "Solar": _head_scores(solar_te, solar_hat, solar_te_ext, solar_prob),
        "Wind":  _head_scores(wind_te, wind_hat, wind_te_ext, wind_prob),
    }
    print("\n=== Holdout 80/20 ===")
    print(pd.DataFrame(metrics).T)
    return metrics


def run_stt_cv(df, features_all, y_solar, y_wind,
               n_steps=24, epochs=20, batch=64, lr=1e-3, n_splits=5, verbose=1):
    """Evaluate the model with an expanding-window ``TimeSeriesSplit``.

    For every fold the Keras session is cleared and re-seeded (``42 + fold``),
    features are preprocessed with that fold's training statistics, extreme
    thresholds are the training 95th percentiles, and a fresh model is trained
    and scored on the fold's test part. Targets are the arrays returned by
    ``preprocess`` and the metrics are computed on those values as-is.

    Prints per-fold metrics and summaries (mean, std, two-sided 90% t-interval).

    Returns
    -------
    ``(df_s, df_w, df_auc)``: per-fold solar metrics, wind metrics and AUCs.
    """
    _, ys_all, yw_all = preprocess(df, features_all, y_solar, y_wind)
    tscv = TimeSeriesSplit(n_splits=n_splits)

    scores_s, scores_w, aucs = [], [], []
    for f, (tr, te) in enumerate(tscv.split(df), 1):
        print(f"\nFOLD {f}")
        K.clear_session()
        tf.keras.utils.set_random_seed(42 + f)

        Xtr_raw, Xte_raw = df.iloc[tr][features_all].copy(), df.iloc[te][features_all].copy()
        Xtr, Xte = _train_only_preprocess(Xtr_raw, Xte_raw)

        ys_tr, ys_te = ys_all[tr], ys_all[te]
        yw_tr, yw_te = yw_all[tr], yw_all[te]

        L = min(n_steps, len(Xtr))
        Xtr_seq, Xte_seq = make_seq(Xtr, L), make_seq(Xte, L)
        # "Flat" input = last step of each window.
        Xtr_flat, Xte_flat = Xtr_seq[:, -1, :], Xte_seq[:, -1, :]

        # Extreme-event thresholds come from the training targets only.
        thr_s = np.percentile(ys_tr, 95); thr_w = np.percentile(yw_tr, 95)
        ys_tr_ext = (ys_tr > thr_s).astype(int); ys_te_ext = (ys_te > thr_s).astype(int)
        yw_tr_ext = (yw_tr > thr_w).astype(int); yw_te_ext = (yw_te > thr_w).astype(int)

        model = build_st_transformer({'seq': (L, Xtr.shape[1]), 'flat': (Xtr.shape[1],)}, lr=lr)

        model.fit([Xtr_seq, Xtr_flat],
                  {"solar_reg": ys_tr, "wind_reg": yw_tr,
                   "solar_ext": ys_tr_ext, "wind_ext": yw_tr_ext},
                  epochs=epochs, batch_size=batch, verbose=verbose)

        # One prediction pass per fold (progress bar kept via verbose=1).
        pr = model.predict([Xte_seq, Xte_flat], verbose=1)
        s_pred, w_pred = pr[0].ravel(), pr[1].ravel()
        s_ext,  w_ext  = pr[2].ravel(), pr[3].ravel()

        ms = {"MSE": mean_squared_error(ys_te, s_pred),
              "RMSE": np.sqrt(mean_squared_error(ys_te, s_pred)),
              "MAE": mean_absolute_error(ys_te, s_pred),
              "MAPE": safe_mape(ys_te, s_pred),
              "R2": r2_score(ys_te, s_pred),
              "EVS": explained_variance_score(ys_te, s_pred)}
        mw = {"MSE": mean_squared_error(yw_te, w_pred),
              "RMSE": np.sqrt(mean_squared_error(yw_te, w_pred)),
              "MAE": mean_absolute_error(yw_te, w_pred),
              "MAPE": safe_mape(yw_te, w_pred),
              "R2": r2_score(yw_te, w_pred),
              "EVS": explained_variance_score(yw_te, w_pred)}
        scores_s.append(ms); scores_w.append(mw)
        aucs.append({"AUC_s": safe_auc(ys_te_ext, s_ext),
                     "AUC_w": safe_auc(yw_te_ext, w_ext)})

        print(f"Solar fold {f}: {ms}")
        print(f"Wind  fold {f}: {mw}")

    df_s = pd.DataFrame(scores_s); df_w = pd.DataFrame(scores_w); df_auc = pd.DataFrame(aucs)

    def summarize(dfm):
        # Mean, sample std and two-sided 90% Student-t interval per metric column.
        n = len(dfm); mean = dfm.mean(); sd = dfm.std(ddof=1)
        tcrit = t.ppf(1-0.05, df=max(n-1, 1))
        half_width = tcrit*sd/np.sqrt(max(n, 1))
        low = mean - half_width; high = mean + half_width
        return pd.concat([mean.rename('mean'), sd.rename('std'),
                          low.rename('CI90_low'), high.rename('CI90_high')], axis=1)

    print("\n=== CV Summary Solar ==="); print(summarize(df_s))
    print("\n=== CV Summary Wind ===");  print(summarize(df_w))
    print("\n=== CV Summary AUCs ===");  print(df_auc.mean(numeric_only=True))
    return df_s, df_w, df_auc


# The run_* helpers split features and targets by position, so the feature
# frame must contain exactly the rows the targets were built from. y_solar and
# y_wind come from the NaN-filtered frame; select the same rows by index
# instead of passing the unfiltered `df`.
stt_frame = df.loc[y_solar.index]
assert len(stt_frame) == len(y_solar) == len(y_wind), "features/targets are misaligned"

hold = run_stt_holdout(stt_frame, features_all, y_solar, y_wind,
                       n_steps=24, epochs=20, batch=64, lr=1e-3, verbose=1)
cv_s, cv_w, cv_auc = run_stt_cv(stt_frame, features_all, y_solar, y_wind,
                                n_steps=24, epochs=20, batch=64, lr=1e-3, n_splits=5, verbose=1)


Epoch 1/20


2025-09-02 12:23:28.893721: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20

=== Holdout 80/20 ===
           RMSE       MAE        MAPE        R2       EVS       AUC
Solar  2.149901  2.044666   32.315917  0.594692  0.610020  0.858377
Wind   4.715794  4.523226  121.947291 -8.269877 -0.005441  0.621057

FOLD 1
Epoch 1/20


2025-09-02 12:28:04.996500: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Solar fold 1: {'MSE': 17.18682822609355, 'RMSE': 4.145699968171063, 'MAE': 3.608615426328767, 'MAPE': 71.5001998949776, 'R2': -0.23251485315375242, 'EVS': 0.08392783055970143}
Wind  fold 1: {'MSE': 11.814726423333386, 'RMSE': 3.4372556528913276, 'MAE': 3.234811006436605, 'MAPE': 180.17723886231542, 'R2': -2.91428686304319, 'EVS': 0.01680953969677712}

FOLD 2
Epoch 1/20


2025-09-02 12:29:14.576390: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Solar fold 2: {'MSE': 1.687293460037563, 'RMSE': 1.2989586059754032, 'MAE': 0.9223538958256406, 'MAPE': 14.064814284579679, 'R2': 0.8627009559876531, 'EVS': 0.8892721936535035}
Wind  fold 2: {'MSE': 4.332456517806232, 'RMSE': 2.081455384534156, 'MAE': 1.873383102724139, 'MAPE': 81.34855176424843, 'R2': -0.6506155700224772, 'EVS': 0.10553352993628651}

FOLD 3
Epoch 1/20


2025-09-02 12:31:19.932126: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Solar fold 3: {'MSE': 167.88649723579402, 'RMSE': 12.957102192843662, 'MAE': 12.906900374809277, 'MAPE': 254.7736811045811, 'R2': -12.952949026229163, 'EVS': 0.8920890739705555}
Wind  fold 3: {'MSE': 4.119963612932588, 'RMSE': 2.0297693496879363, 'MAE': 1.8447287321811043, 'MAPE': 480.3036469197862, 'R2': -0.15177439418332006, 'EVS': 0.24858613217854497}

FOLD 4
Epoch 1/20


2025-09-02 12:34:18.119133: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Solar fold 4: {'MSE': 4.514918724503366, 'RMSE': 2.124833811031669, 'MAE': 1.9893107533557206, 'MAPE': 31.90471214372322, 'R2': 0.6747427775205236, 'EVS': 0.676655471751574}
Wind  fold 4: {'MSE': 3.0425333675014206, 'RMSE': 1.7442859190801892, 'MAE': 1.1904333835212122, 'MAPE': 201.12367372120383, 'R2': -0.0693453068357408, 'EVS': -0.04758735933796787}

FOLD 5
Epoch 1/20


2025-09-02 12:37:59.355439: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Solar fold 5: {'MSE': 63.96678458675802, 'RMSE': 7.997923767250974, 'MAE': 7.915762819864198, 'MAPE': 124.15908611567663, 'R2': -4.699893154753834, 'EVS': 0.8818550529738665}
Wind  fold 5: {'MSE': 123.64979152575809, 'RMSE': 11.119792782500854, 'MAE': 10.99771862629951, 'MAPE': 595.3843027076605, 'R2': -48.95246838510115, 'EVS': -0.09074581700199791}

=== CV Summary Solar ===
           mean        std   CI90_low   CI90_high
MSE   51.048464  69.940180 -15.631865  117.728794
RMSE   5.704904   4.809176   1.119879   10.289928
MAE    5.468589   4.938377   0.760385   10.176792
MAPE  99.280499  96.656645   7.128936  191.432062
R2    -3.269583   5.866656  -8.862799    2.323634
EVS    0.684760   0.348110   0.352875    1.016645

=== CV Summary Wind ===
            mean         s