In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import tensorflow as tf

# 1. Load and preprocess
def load_and_prepare(path, suffix):
    """Load one regional CSV and tag its weather variables with a region suffix.

    Args:
        path: CSV path (or buffer) with at least the columns
            valid_time, u10, v10, t2m, sp, tp.
        suffix: region name appended as ``_<suffix>`` to every variable column.

    Returns:
        DataFrame sorted by valid_time with columns
        ``valid_time, u10_<suffix>, v10_<suffix>, t2m_<suffix>, sp_<suffix>, tp_<suffix>``;
        tp is converted from metres to millimetres.
    """
    raw = pd.read_csv(path, parse_dates=['valid_time']).sort_values('valid_time')
    raw['tp'] *= 1000  # metres -> millimetres
    kept = raw[['valid_time', 'u10', 'v10', 't2m', 'sp', 'tp']]
    renamed = {name: f'{name}_{suffix}' for name in kept.columns if name != 'valid_time'}
    return kept.rename(columns=renamed)

# NOTE: do NOT apply df['tp'] = np.log1p(df['tp']) -- a log transform here is a big no.
# (The model would learn in a log-transformed AND scaled space, which performs worse once reverted; rely on the StandardScaler later instead.)


# Load all 9 regions: one CSV per region, every variable suffixed with its region.
df_center = load_and_prepare('data/brazil.csv', 'center')
df_north = load_and_prepare('data/brazil_north.csv', 'north')
df_south = load_and_prepare('data/brazil_south.csv', 'south')
df_east = load_and_prepare('data/brazil_east.csv', 'east')
df_west = load_and_prepare('data/brazil_west.csv', 'west')
df_ne = load_and_prepare('data/brazil_ne.csv', 'north-east')
df_nw = load_and_prepare('data/brazil_nw.csv', 'north-west')
df_se = load_and_prepare('data/brazil_se.csv', 'south-east')
df_sw = load_and_prepare('data/brazil_sw.csv', 'south-west')

# Inner-join the eight surrounding regions onto the centre frame by timestamp
# (the default merge is inner, so timestamps missing from any region drop out).
surrounding_frames = (df_north, df_south, df_east, df_west, df_ne, df_nw, df_se, df_sw)
merged = df_center
for frame in surrounding_frames:
    merged = pd.merge(merged, frame, on='valid_time')

df = merged.sort_values('valid_time').reset_index(drop=True)

# 2. Seasonal cycle on the unit circle, so late December sits next to early January.
df['dayofyear'] = df['valid_time'].dt.dayofyear
season_angle = 2 * np.pi * df['dayofyear'] / 365
df['dayofyear_sin'] = np.sin(season_angle)
df['dayofyear_cos'] = np.cos(season_angle)

# 3. Region suffixes, in the order used for every per-region column list.
regions = ['center', 'north', 'south', 'east', 'west',
           'north-east', 'north-west', 'south-east', 'south-west']

# Five raw variables per region, grouped region by region (45 names in total).
all_meteorological_features = [
    f'{var}_{region}'
    for region in regions
    for var in ['u10', 'v10', 't2m', 'sp', 'tp']
]

# 4. Train-test split index (time based).
# Both scalers below are fitted on this leading, time-ordered slice only, so the
# statistics of the later validation/test rows never leak into training.
train_size = int(0.6 * len(df))

# 5a. Scale the regression target from the RAW tp_center (mm, converted in
#     load_and_prepare) *before* step 5b overwrites tp_center in place.
#     Fitting on the raw column is what makes tp_scaler.inverse_transform
#     return millimetres; fitting it after 5b (the previous order) made the
#     "unscaled" predictions still z-scores.
tp_scaler = StandardScaler()
tp_scaler.fit(df[['tp_center']].iloc[:train_size])
df['tp_center_scaled'] = tp_scaler.transform(df[['tp_center']])[:, 0]

# 5b. Scale meteorological features (all regions), statistics from train rows only.
scaler = StandardScaler()
scaler.fit(df[all_meteorological_features].iloc[:train_size])
df[all_meteorological_features] = scaler.transform(df[all_meteorological_features])

# 6. Lag, rolling, diff on tp_center (standardised at this point, like every input).
lags = [1, 3, 6, 12, 24]
rolling_windows = [3, 6, 12]

for lag in lags:
    df[f'tp_lag_{lag}h'] = df['tp_center'].shift(lag)

for window in rolling_windows:
    df[f'tp_roll_mean_{window}h'] = df['tp_center'].rolling(window).mean()

df['tp_diff_1h'] = df['tp_center'].diff()


# 7. Wind features
def _wrap_degrees(delta):
    """Map an angular difference in degrees onto the interval [-180, 180).

    A raw ``diff()`` of directions in [0, 360) jumps by about 360 whenever the
    wind veers across north (e.g. 359 -> 1 gives -358 instead of +2). NaN
    values (the first diff row) stay NaN.
    """
    return (delta + 180) % 360 - 180


# NOTE(review): u10/v10 were standardised per region in step 5, so speed and
# direction below are derived from z-scores rather than physical wind components.
wind_feature_dict = {}
for region in regions:
    u = df[f'u10_{region}']
    v = df[f'v10_{region}']
    wind_speed = np.sqrt(u**2 + v**2)
    wind_dir = (np.arctan2(v, u) * 180 / np.pi + 360) % 360  # degrees in [0, 360)

    wind_feature_dict[f'wind_speed_{region}'] = wind_speed
    wind_feature_dict[f'wind_direction_{region}'] = wind_dir
    wind_feature_dict[f'wind_speed_change_{region}'] = wind_speed.diff()
    wind_feature_dict[f'wind_direction_change_{region}'] = _wrap_degrees(wind_dir.diff())

# 8. Wind-rain interactions
#   wind_rain_local_<r>     : wind speed in r x rain in r at the same row
#   wind_to_center_rain_<r> : wind speed in outer region r x centre rain one row earlier
interaction_feature_dict = {}
for region in regions:
    interaction_feature_dict[f'wind_rain_local_{region}'] = (
        wind_feature_dict[f'wind_speed_{region}'] * df[f'tp_{region}']
    )
    if region != 'center':
        interaction_feature_dict[f'wind_to_center_rain_{region}'] = (
            wind_feature_dict[f'wind_speed_{region}'] * df['tp_center'].shift(1)
        )

# 9. Synoptic-scale proxies from differences between neighbouring regions.
sp_gradient_ns = df['sp_north'] - df['sp_south']
sp_gradient_ew = df['sp_east'] - df['sp_west']
t2m_gradient_ns = df['t2m_north'] - df['t2m_south']
t2m_gradient_ew = df['t2m_east'] - df['t2m_west']
u10_divergence = df['u10_east'] - df['u10_west']
v10_divergence = df['v10_north'] - df['v10_south']

# Key order matters: it fixes the column order of the concatenated frame.
grad_features = {
    # Gradients
    'sp_gradient_ns': sp_gradient_ns,
    'sp_gradient_ew': sp_gradient_ew,
    't2m_gradient_ns': t2m_gradient_ns,
    't2m_gradient_ew': t2m_gradient_ew,
    # Magnitudes
    'sp_gradient_mag': np.sqrt(sp_gradient_ns**2 + sp_gradient_ew**2),
    't2m_gradient_mag': np.sqrt(t2m_gradient_ns**2 + t2m_gradient_ew**2),
    # Divergence
    'u10_divergence': u10_divergence,
    'v10_divergence': v10_divergence,
    'wind_divergence': u10_divergence + v10_divergence,
    # Trends
    'sp_center_trend': df['sp_center'].diff(),
    't2m_center_trend': df['t2m_center'].diff(),
    'sp_center_trend_6h': df['sp_center'].rolling(6).mean().diff(),
    # Slopes (centre minus each cardinal neighbour)
    **{f'sp_slope_{side}': df['sp_center'] - df[f'sp_{side}']
       for side in ('north', 'south', 'east', 'west')},
    # Wind at the centre (direction in radians)
    'wind_speed_center': np.sqrt(df['u10_center']**2 + df['v10_center']**2),
    'wind_dir_center': np.arctan2(df['v10_center'], df['u10_center']),
}

# 10. Attach all engineered feature groups in a single concat.
# NOTE: wind_feature_dict and grad_features both contain 'wind_speed_center',
# so df ends up with two columns carrying that label.
engineered_parts = (wind_feature_dict, interaction_feature_dict, grad_features)
df = pd.concat(
    [df, *(pd.DataFrame(part, index=df.index) for part in engineered_parts)],
    axis=1,
)

# 11-12. Drop warm-up rows left NaN by lag/rolling/diff, then take a contiguous
# copy. The original row labels are kept (no reset_index).
df = df.dropna().copy()


lag_features   = [f'tp_lag_{lag}h' for lag in lags]
roll_features  = [f'tp_roll_mean_{w}h' for w in rolling_windows]
diff_features  = ['tp_diff_1h']
cyclical_features = ['dayofyear_sin', 'dayofyear_cos']

# Wind features (four per region, region-major order)
wind_features = []
for region in regions:
    wind_features += [
        f'wind_speed_{region}',
        f'wind_direction_{region}',
        f'wind_speed_change_{region}',
        f'wind_direction_change_{region}'
    ]

# Wind-rain interaction features
interaction_features = []
for region in regions:
    interaction_features.append(f'wind_rain_local_{region}')
    if region != 'center':
        interaction_features.append(f'wind_to_center_rain_{region}')

# Front/synoptic features (defined once; these are the keys of grad_features)
front_features = [
    'sp_gradient_ns', 'sp_gradient_ew',
    't2m_gradient_ns', 't2m_gradient_ew',
    'sp_gradient_mag', 't2m_gradient_mag',
    'u10_divergence', 'v10_divergence', 'wind_divergence',
    'sp_center_trend', 't2m_center_trend', 'sp_center_trend_6h',
    'sp_slope_north', 'sp_slope_south', 'sp_slope_east', 'sp_slope_west',
    'wind_speed_center', 'wind_dir_center'
]

# Final derived features
derived_features = (
    wind_features +
    lag_features +
    roll_features +
    diff_features +
    interaction_features +
    cyclical_features +
    front_features
)


# Scale all derived features. The data is time-ordered, so statistics come from
# the leading 60% of rows only; later (validation/test) rows do not leak in.
derived_fit_rows = int(0.6 * len(df))
scaler_derived = StandardScaler()
scaler_derived.fit(df[derived_features].iloc[:derived_fit_rows])
df[derived_features] = scaler_derived.transform(df[derived_features])

# Final feature columns for model input: meteorological + derived, with repeated
# names removed in first-seen order. Duplicates are collected BEFORE removal;
# checking the de-duplicated list could never find any.
combined_cols = all_meteorological_features + derived_features
dupes = list(dict.fromkeys(col for col in combined_cols if combined_cols.count(col) > 1))
feature_cols = list(dict.fromkeys(combined_cols))  # Keeps order, removes duplicates

if dupes:
    print(f"Duplicate feature columns removed: {dupes}")
else:
    print("No duplicate feature columns.")
# NOTE: this removes duplicate *names* only. df itself still holds two
# 'wind_speed_center' columns, so df[feature_cols] returns one extra column.

print(f"Final feature columns count: {len(feature_cols)}")
print(f"Length of df after processing: {len(df)}")


# Original sequence creation (short sequences)
# NOTE: every array this cell builds (X, y, X_long, X_wind, y_rain_class and
# their splits) is rebuilt with consistent alignment in the
# "Sequence generation & splitting" section below; this cell mainly defines
# the create_* helpers.
def create_sequences(df, seq_len=5, horizon=7):
    """Slide over df and return (X, y) as float32 arrays.

    For each anchor i in [seq_len, len(df) - horizon): X holds the global
    feature_cols for rows [i - seq_len, i); y holds tp_center_scaled for
    rows [i, i + horizon).
    """
    X, y = [], []
    for i in range(seq_len, len(df) - horizon):
        X.append(df.iloc[i - seq_len:i][feature_cols].values.astype(np.float32))
        y.append(df.iloc[i:i + horizon]['tp_center_scaled'].values.astype(np.float32))
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)


sequence_length = 5
forecast_horizon = 7
X, y = create_sequences(df, sequence_length, forecast_horizon)
print(f"X shape: {X.shape}, y shape: {y.shape}")

# Rain classification (multi-class)
# y is the *standardised* target while the bins are millimetre thresholds, so
# undo the target scaling before binning (binning y directly compared z-scores
# with mm). NOTE(review): this yields mm only if tp_scaler was fitted on the
# raw (mm) tp_center column -- confirm.
y_mm = tp_scaler.inverse_transform(y.reshape(-1, 1)).reshape(y.shape)
y_rain_class = np.digitize(y_mm, bins=[0.1, 0.5, 2.0, 10.0])  # shape: (samples, horizon)

# Prepare long-input sequences
def create_long_sequences(df, long_seq_len=168, horizon=7):
    """Return float32 windows of feature_cols rows [i - long_seq_len, i)."""
    X_long = []
    for i in range(long_seq_len, len(df) - horizon):
        X_long.append(df.iloc[i - long_seq_len:i][feature_cols].values.astype(np.float32))
    return np.array(X_long, dtype=np.float32)

X_long = create_long_sequences(df, long_seq_len=168, horizon=forecast_horizon)
# Both generators stop at the same anchor, so trimming the front of the short
# arrays aligns them with X_long.
offset = len(X) - len(X_long)
X = X[offset:]
y = y[offset:]
y_rain_class = y_rain_class[offset:]

# Interaction inputs
def create_interaction_inputs(df, interaction_features, seq_len=5, horizon=7):
    """Return interaction features from the last observed row (i - 1) of each window.

    Row i is the first forecast step: its wind_rain_local_* values multiply by
    that step's rain, so reading row i would leak the target into the inputs.
    """
    X_wind = []
    for i in range(seq_len, len(df) - horizon):
        X_wind.append(df.iloc[i - 1][interaction_features].values.astype(np.float32))
    return np.array(X_wind, dtype=np.float32)

X_wind = create_interaction_inputs(df, interaction_features, sequence_length, forecast_horizon)


feature_dim = len(feature_cols) + len(interaction_features)

# feature_cols = feature_cols + interaction_features
feature_cols = list(dict.fromkeys(feature_cols + interaction_features))

# NOTE(review): X (anchors start at sequence_length) and X_long (anchors start
# at 96) begin at different rows here, so the splits below are misaligned; the
# next section rebuilds everything from one shared anchor list.
X, y = create_sequences(df, sequence_length, forecast_horizon)
X_long = create_long_sequences(df, long_seq_len=96, horizon=forecast_horizon)

# Splitting into train/val/test
train_size = int(0.6 * len(X))
val_size   = int(0.2 * len(X))

X_train = X[:train_size]
X_val   = X[train_size:train_size + val_size]
X_test  = X[train_size + val_size:]

y_train = y[:train_size]
y_val   = y[train_size:train_size + val_size]
y_test  = y[train_size + val_size:]

y_rain_train = y_rain_class[:train_size]
y_rain_val   = y_rain_class[train_size:train_size + val_size]
y_rain_test  = y_rain_class[train_size + val_size:]

X_long_train = X_long[:train_size]
X_long_val   = X_long[train_size:train_size + val_size]
X_long_test  = X_long[train_size + val_size:]

# Align X_wind with the final X before splitting
X_wind = X_wind[:len(X)]  # Force same number of samples

# Then do the splits again
X_wind_train = X_wind[:train_size]
X_wind_val   = X_wind[train_size:train_size + val_size]
X_wind_test  = X_wind[train_size + val_size:]

# Sequence generation & splitting
# Every model input/target is built from the SAME list of anchor positions i:
#   X      : rows [i - sequence_length, i)       (short history)
#   X_long : rows [i - long_sequence_length, i)  (long history)
#   X_wind : row  i - 1                          (last observed step)
#   X_time : row  i                              (calendar of first forecast step)
#   y      : rows [i, i + forecast_horizon)      (targets)
# All indexing is positional. df kept its original row labels through dropna(),
# so label-based .loc lookups would be shifted relative to these positions.

# 1. Set sequence lengths
sequence_length = 5
long_sequence_length = 96  # shorter than the 168 tried earlier, to save memory
forecast_horizon = 7

# 2. Get valid sequence indices (row positions, not labels)
max_seq_start = max(sequence_length, long_sequence_length)
end = len(df) - forecast_horizon
valid_indices = np.arange(max_seq_start, end)

# 3. Create aligned sequences: one fancy-indexing gather per array instead of
#    one df.iloc slice per sample.
feature_matrix = df[feature_cols].to_numpy(dtype=np.float32)
target_vector = df['tp_center_scaled'].to_numpy(dtype=np.float32)
anchors = valid_indices[:, None]

X = feature_matrix[anchors + np.arange(-sequence_length, 0)]
X_long = feature_matrix[anchors + np.arange(-long_sequence_length, 0)]
y = target_vector[anchors + np.arange(forecast_horizon)]

# 4. Rain classification targets. y is standardised and no log transform is
#    applied to tp anywhere in this pipeline, so np.expm1 was the wrong inverse;
#    undo the target scaling so the mm bins apply.
#    NOTE(review): this yields mm only if tp_scaler was fitted on the raw (mm)
#    tp_center column -- confirm.
y_mm = tp_scaler.inverse_transform(y.reshape(-1, 1)).reshape(y.shape)
y_rain_class = np.digitize(y_mm, bins=[0.1, 0.5, 2.0, 10.0])

# 5. Interaction features from the last observed row (i - 1). Row i carries the
#    first forecast step's rain via wind_rain_local_* and would leak the target.
X_wind = df[interaction_features].to_numpy(dtype=np.float32)[valid_indices - 1]

# 6. Time features of the first forecast step (positional lookup).
X_time = df[['dayofyear_sin', 'dayofyear_cos']].to_numpy(dtype=np.float32)[valid_indices]

# 7. Split into train/val/test
total = len(valid_indices)
train_size = int(0.6 * total)
val_size   = int(0.2 * total)

def split(arr):
    """Chronological 60/20/20 split along the first axis."""
    return arr[:train_size], arr[train_size:train_size + val_size], arr[train_size + val_size:]

X_train, X_val, X_test = split(X)
X_long_train, X_long_val, X_long_test = split(X_long)
y_train, y_val, y_test = split(y)
y_rain_train, y_rain_val, y_rain_test = split(y_rain_class)
X_wind_train, X_wind_val, X_wind_test = split(X_wind)
X_time_train, X_time_val, X_time_test = split(X_time)

# 8. Final shape checks
print("Shapes:")
print("X_train:", X_train.shape)
print("X_long_train:", X_long_train.shape)
print("y_train:", y_train.shape)
print("y_rain_train:", y_rain_train.shape)
print("X_wind_train:", X_wind_train.shape)
print("X_time_train:", X_time_train.shape)

# 9. Final assertions
assert X_train.shape[0] == X_long_train.shape[0] == y_train.shape[0] == y_rain_train.shape[0] == X_wind_train.shape[0] == X_time_train.shape[0]
assert X_val.shape[0]   == X_long_val.shape[0]   == y_val.shape[0]   == y_rain_val.shape[0]   == X_wind_val.shape[0]   == X_time_val.shape[0]
assert X_test.shape[0]  == X_long_test.shape[0]  == y_test.shape[0]  == y_rain_test.shape[0]  == X_wind_test.shape[0]  == X_time_test.shape[0]
# NOTE: be VERY careful when modifying this; check that no data leaks between splits. Library splitters are easier to use and more reproducible, but less robust than this manual time-ordered split.


In [None]:
# feature splits before regional splitting
# Use the real region suffixes the columns were built with. The former
# placeholder names r0..r8 matched no column, so every feature was reported as
# non-region, and reassigning `regions` clobbered the real region list.
region_suffixes = tuple(f"_{region}" for region in regions)

# Helper function to check if a feature name ends with a region suffix
def is_per_region(feature_name):
    """Return True when *feature_name* ends with ``_<region>`` for a known region.

    Purely name-based: e.g. 'sp_slope_north' and 'wind_speed_center' count as
    per-region too.
    """
    return feature_name.endswith(region_suffixes)

# Split features into per-region and non-per-region
per_region_features = [f for f in feature_cols if is_per_region(f)]
non_region_features = [f for f in feature_cols if not is_per_region(f)]

# Check the numbers
len_per_region = len(per_region_features)
extra_features = len(feature_cols) - len_per_region

print(f"Per-region features: {len_per_region}")
print(f"Extra (non-region) features: {extra_features}")
print(f"Non-region features: {non_region_features}")

# Remapping for the classifier.
# np.digitize(y_mm, bins=[0.1, 0.5, 2.0, 10.0]) yields bin indices 0..4:
#   0: < 0.1 mm, 1: [0.1, 0.5), 2: [0.5, 2.0), 3: [2.0, 10.0), 4: >= 10 mm.
# The classifier has four classes, so bins 1 and 2 merge into "light rain".
_DIGITIZE_BIN_TO_RAIN_CLASS = np.array([0, 1, 1, 2, 3])


def remap_rain_class(y_rain):
    """Collapse np.digitize rain-bin indices (0..4) into the 4 classifier classes.

    Args:
        y_rain: array-like of integer bin indices from
            ``np.digitize(y_mm, bins=[0.1, 0.5, 2.0, 10.0])``; any shape.

    Returns:
        Integer ndarray of the same shape: 0 no rain (< 0.1 mm),
        1 light (0.1-2 mm), 2 moderate (2-10 mm), 3 heavy (>= 10 mm).

    Raises:
        ValueError: if any value lies outside 0..4.

    The previous version compared these bin *indices* against millimetre
    thresholds, so bin 4 (>= 10 mm) landed in "moderate" and the heavy-rain
    class was never produced.
    """
    bins = np.asarray(y_rain)
    n_bins = len(_DIGITIZE_BIN_TO_RAIN_CLASS)
    if bins.size and (bins.min() < 0 or bins.max() >= n_bins):
        raise ValueError(f"rain bin indices must lie in 0..{n_bins - 1}")
    return _DIGITIZE_BIN_TO_RAIN_CLASS[bins.astype(int)]

# Collapse each split's raw np.digitize bins into the 4 classifier classes.
y_rain_train_new, y_rain_val_new, y_rain_test_new = (
    remap_rain_class(labels) for labels in (y_rain_train, y_rain_val, y_rain_test)
)

# One-hot encode for the softmax classification head.
from tensorflow.keras.utils import to_categorical
num_classes = 4
y_rain_train_cat, y_rain_val_cat, y_rain_test_cat = (
    to_categorical(labels, num_classes=num_classes)
    for labels in (y_rain_train_new, y_rain_val_new, y_rain_test_new)
)

In [None]:
# Fixing duplicate feature before fit

# NOTE: the hybrid model takes several inputs, and shape mismatches are common
# when features change. Check shapes, debug, and remove duplicates CAREFULLY.

X_train, X_val, X_test = split(X)
X_long_train, X_long_val, X_long_test = split(X_long)
y_train, y_val, y_test = split(y)
y_rain_train, y_rain_val, y_rain_test = split(y_rain_class)
X_wind_train, X_wind_val, X_wind_test = split(X_wind)
X_time_train, X_time_val, X_time_test = split(X_time)

# df holds the label 'wind_speed_center' twice (per-region wind features and
# synoptic front features), and df[feature_cols] returns every column carrying
# a selected label, so the sequences have one column more than len(feature_cols).
# Drop the repeated column(s) by label. Blindly dropping the LAST column
# discarded a genuine feature (currently 'wind_dir_center') and kept the duplicate.
selected_columns = df.iloc[:0][feature_cols].columns
keep_columns = ~selected_columns.duplicated()
if X_train.shape[-1] != len(selected_columns):
    raise ValueError(
        f"Expected {len(selected_columns)} selected columns, got {X_train.shape[-1]}"
    )

X_train = X_train[:, :, keep_columns]
X_val = X_val[:, :, keep_columns]
X_test = X_test[:, :, keep_columns]

X_long_train = X_long_train[:, :, keep_columns]
X_long_val = X_long_val[:, :, keep_columns]
X_long_test = X_long_test[:, :, keep_columns]

# Class-weighted loss for the rain classifier

import tensorflow.keras.backend as K

def weighted_categorical_crossentropy(alpha):
    """Build a categorical cross-entropy loss with per-class weights.

    Args:
        alpha: per-class weights, one entry per class (last axis of y_true/y_pred).

    Returns:
        loss(y_true, y_pred) -> sum over the last axis of -y_true * log(y_pred) * alpha,
        with y_pred clipped away from 0 and 1 first.
    """
    class_weights = K.constant(alpha)

    def loss(y_true, y_pred):
        eps = K.epsilon()
        probs = K.clip(y_pred, eps, 1 - eps)  # keep log() finite
        per_class = -y_true * K.log(probs) * class_weights
        return K.sum(per_class, axis=-1)

    return loss

# Per-class weights in class order (no / light / moderate / heavy rain),
# rescaled to have mean 1.
raw_weights = np.array([0.247, 3.436, 3.188, 6.593])
alpha_vec = raw_weights / raw_weights.mean()  # Normalized weights

In [None]:

import tensorflow as tf
from tensorflow.keras.models import Model 
from tensorflow.keras.layers import Layer, Embedding, Add, LayerNormalization, MultiHeadAttention, Dropout, Dense, Reshape, Input, GlobalAveragePooling2D, Concatenate, Lambda, TimeDistributed, Softmax, Conv2D

# Model-side aliases for the pipeline configuration.
num_regions = 9
base_feature_cols = feature_cols  # same list object, not a copy
# Raw variables per region: all_meteorological_features holds 5 variables
# (u10, v10, t2m, sp, tp) for each of the 9 regions. This value is overwritten
# in the "Example build" section further down.
features_per_region = len(all_meteorological_features) // num_regions



# Model definitions


class SpatialTemporalTransformerBlock(Layer):
    """Factorised spatial-temporal transformer encoder block.

    Input and output shape: (batch, seq_len, num_regions, feature_dim).

    Steps: add learned temporal and spatial positional embeddings; run a
    pre-norm self-attention + feed-forward sub-block along the time axis for
    every (sample, region); then run the same along the region axis for every
    (sample, time step).
    """

    def __init__(self, feature_dim, num_heads, ff_dim, dropout=0.1, seq_len=5, num_regions=9, **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.num_regions = num_regions
        self.feature_dim = feature_dim
        # Stored so get_config() can rebuild the layer (model saving / cloning).
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

        # Positional embeddings
        self.temporal_pos_emb = Embedding(input_dim=seq_len, output_dim=feature_dim)
        self.spatial_pos_emb = Embedding(input_dim=num_regions, output_dim=feature_dim)

        # Temporal attention components
        self.temporal_norm1 = LayerNormalization(epsilon=1e-6)
        self.temporal_mha = MultiHeadAttention(num_heads=num_heads, key_dim=feature_dim, dropout=dropout)
        self.temporal_dropout1 = Dropout(dropout)
        self.temporal_norm2 = LayerNormalization(epsilon=1e-6)
        self.temporal_dense1 = Dense(ff_dim, activation='relu')
        self.temporal_dense2 = Dense(feature_dim)
        self.temporal_dropout2 = Dropout(dropout)

        # Spatial attention components
        self.spatial_norm1 = LayerNormalization(epsilon=1e-6)
        self.spatial_mha = MultiHeadAttention(num_heads=num_heads, key_dim=feature_dim, dropout=dropout)
        self.spatial_dropout1 = Dropout(dropout)
        self.spatial_norm2 = LayerNormalization(epsilon=1e-6)
        self.spatial_dense1 = Dense(ff_dim, activation='relu')
        self.spatial_dense2 = Dense(feature_dim)
        self.spatial_dropout2 = Dropout(dropout)

    def call(self, inputs):
        # inputs shape: (batch, seq_len, num_regions, feature_dim)
        batch_size = tf.shape(inputs)[0]

        # Add temporal positional embeddings (broadcast over regions)
        temporal_pos_emb = self.temporal_pos_emb(tf.range(self.seq_len))  # (seq_len, feature_dim)
        temporal_pos_emb = tf.reshape(temporal_pos_emb, (1, self.seq_len, 1, self.feature_dim))
        x = inputs + temporal_pos_emb

        # Add spatial positional embeddings (broadcast over time)
        spatial_pos_emb = self.spatial_pos_emb(tf.range(self.num_regions))  # (num_regions, feature_dim)
        spatial_pos_emb = tf.reshape(spatial_pos_emb, (1, 1, self.num_regions, self.feature_dim))
        x = x + spatial_pos_emb

        # Temporal attention (across time, per region). Regions must sit next to
        # the batch axis before flattening: reshaping (batch, seq_len, regions, F)
        # straight to (batch * regions, seq_len, F) mixes time steps and regions
        # inside each "sequence".
        x_temporal = tf.transpose(x, perm=[0, 2, 1, 3])  # (batch, num_regions, seq_len, F)
        x_temporal = tf.reshape(x_temporal, (batch_size * self.num_regions, self.seq_len, self.feature_dim))
        attn_input = self.temporal_norm1(x_temporal)
        attn_output = self.temporal_mha(attn_input, attn_input)
        attn_output = self.temporal_dropout1(attn_output)
        out1 = attn_output + x_temporal

        ffn_output = self.temporal_dense1(self.temporal_norm2(out1))
        ffn_output = self.temporal_dense2(ffn_output)
        ffn_output = self.temporal_dropout2(ffn_output)
        temporal_out = ffn_output + out1

        # Back to (batch, num_regions, seq_len, F), then to (batch, seq_len, num_regions, F)
        temporal_out = tf.reshape(temporal_out, (batch_size, self.num_regions, self.seq_len, self.feature_dim))
        temporal_out = tf.transpose(temporal_out, perm=[0, 2, 1, 3])

        # Spatial attention (across regions, per time step); regions are already
        # the axis next to seq_len, so a plain reshape is correct here.
        x_spatial = tf.reshape(temporal_out, (batch_size * self.seq_len, self.num_regions, self.feature_dim))
        attn_input2 = self.spatial_norm1(x_spatial)
        attn_output2 = self.spatial_mha(attn_input2, attn_input2)
        attn_output2 = self.spatial_dropout1(attn_output2)
        out2 = attn_output2 + x_spatial

        ffn_output2 = self.spatial_dense1(self.spatial_norm2(out2))
        ffn_output2 = self.spatial_dense2(ffn_output2)
        ffn_output2 = self.spatial_dropout2(ffn_output2)
        spatial_out = ffn_output2 + out2

        # Reshape back to (batch, seq_len, num_regions, feature_dim)
        spatial_out = tf.reshape(spatial_out, (batch_size, self.seq_len, self.num_regions, self.feature_dim))

        return spatial_out

    def get_config(self):
        """Return the constructor arguments so Keras can serialise or clone the layer."""
        config = super().get_config()
        config.update({
            'feature_dim': self.feature_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'dropout': self.dropout_rate,
            'seq_len': self.seq_len,
            'num_regions': self.num_regions,
        })
        return config




def _forecast_step(tensor, step):
    """Slice forecast step *step* from a (batch, horizon, channels) tensor, keeping the axis.

    *step* is passed through Lambda's ``arguments`` so each layer stores its own
    value. A bare ``lambda x: x[:, day:day+1, :]`` closes over the loop variable,
    and whenever Keras re-runs the layer (every model call) it slices the final
    ``day`` for all steps.
    """
    return Lambda(lambda t, step: t[:, step:step + 1, :], arguments={'step': step})(tensor)


def build_spatial_temporal_model(seq_len, num_regions, features_per_region, interaction_dim, forecast_horizon,
                                  extra_front_dim=0, extra_wind_dim=0, long_seq_len=96):
    """Build the two-headed (regression + classification) spatial-temporal model.

    Args:
        seq_len: length of the short input history ('short_seq').
        num_regions: number of spatial slots the feature axis is reshaped into.
        features_per_region: channels per slot; the short and long inputs carry
            num_regions * features_per_region columns.
        interaction_dim: width of the flat 'wind_input' vector.
        forecast_horizon: number of forecast steps produced by each head.
        extra_front_dim, extra_wind_dim: optional extra per-region channels for
            the short branch; 0 leaves the corresponding input out of the model.
        long_seq_len: length of the long input history ('long_input'); defaults
            to 96, the previously hard-coded value.

    Returns:
        A Keras Model with outputs 'tp_amount', shape (batch, forecast_horizon, 1),
        and 'rain_class', softmax probabilities of shape (batch, forecast_horizon, 4).
    """
    feature_dim = features_per_region + extra_front_dim + extra_wind_dim

    # Inputs
    short_input = Input(shape=(seq_len, num_regions * features_per_region), name='short_seq')
    long_input = Input(shape=(long_seq_len, num_regions * features_per_region), name='long_input')
    wind_input = Input(shape=(interaction_dim,), name='wind_input')

    front_input = Input(shape=(seq_len, num_regions * extra_front_dim), name='front_input')
    wind_seq_input = Input(shape=(seq_len, num_regions * extra_wind_dim), name='wind_seq_input')

    # Reshape Inputs to 4D: (batch, time, regions, channels)
    x_short = Reshape((seq_len, num_regions, features_per_region))(short_input)
    x_long = Reshape((long_seq_len, num_regions, features_per_region))(long_input)

    if extra_front_dim > 0:
        x_front = Reshape((seq_len, num_regions, extra_front_dim))(front_input)
        x_short = Concatenate(axis=-1)([x_short, x_front])

    if extra_wind_dim > 0:
        x_wind_seq = Reshape((seq_len, num_regions, extra_wind_dim))(wind_seq_input)
        x_short = Concatenate(axis=-1)([x_short, x_wind_seq])

    # Spatial-Temporal Encoding
    spatial_temporal_block_short = SpatialTemporalTransformerBlock(feature_dim, num_heads=4, ff_dim=64, dropout=0.1,
                                                                    seq_len=seq_len, num_regions=num_regions)
    spatial_temporal_block_long = SpatialTemporalTransformerBlock(features_per_region, num_heads=4, ff_dim=64,
                                                                   dropout=0.1, seq_len=long_seq_len,
                                                                   num_regions=num_regions)

    x_short_encoded = spatial_temporal_block_short(x_short)
    x_long_encoded = spatial_temporal_block_long(x_long)

    # Global Pooling over (time, regions)
    x_short_pooled = GlobalAveragePooling2D()(x_short_encoded)
    x_long_pooled = GlobalAveragePooling2D()(x_long_encoded)

    # Two decoders: regression and classification

    ## Regression Decoder (short history + interaction vector)
    reg_input = Concatenate()([x_short_pooled, wind_input])
    reg_input = Dense(128, activation='relu')(reg_input)
    reg_input = Dropout(0.2)(reg_input)
    reg_x = Dense(forecast_horizon * 64, activation='relu')(reg_input)
    reg_x = Reshape((forecast_horizon, 64))(reg_x)

    # One dedicated head per forecast step
    tp_outputs = []
    for day in range(forecast_horizon):
        day_i = _forecast_step(reg_x, day)
        day_dense = TimeDistributed(Dense(64, activation='relu'))(day_i)
        day_out = TimeDistributed(Dense(1))(day_dense)
        tp_outputs.append(day_out)

    tp_output = Concatenate(axis=1, name='tp_amount')(tp_outputs)

    ## Classification Decoder (short + long history + interaction vector)
    cls_input = Concatenate()([x_short_pooled, x_long_pooled, wind_input])
    cls_x = Dense(128, activation='relu')(cls_input)
    cls_x = Dropout(0.3)(cls_x)
    cls_x = Dense(forecast_horizon * 64, activation='relu')(cls_x)
    cls_x = Reshape((forecast_horizon, 64))(cls_x)

    cls_outputs = []
    for day in range(forecast_horizon):
        day_i = _forecast_step(cls_x, day)
        dense = TimeDistributed(Dense(64, activation='relu'))(day_i)
        logits = TimeDistributed(Dense(4))(dense)
        cls_outputs.append(logits)

    classifier_logits = Concatenate(axis=1)(cls_outputs)
    rain_output = Softmax(name='rain_class')(classifier_logits)

    # Define Model (optional inputs only when their branch is enabled)
    inputs = [short_input, long_input, wind_input]
    if extra_front_dim > 0:
        inputs.append(front_input)
    if extra_wind_dim > 0:
        inputs.append(wind_seq_input)

    model = Model(inputs=inputs, outputs=[tp_output, rain_output])
    return model



# Example build configuration
forecast_horizon = 7
sequence_length = 5
num_regions = 9
interaction_dim = len(interaction_features)
# Spread the full feature list evenly over the 9 region slots of the model.
# NOTE(review): feature_cols is not ordered region by region (meteorological
# block first, then the derived blocks), so this grouping is a reshape
# convenience rather than geography.
features_per_region = len(base_feature_cols) // num_regions

In [None]:
# Model build
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam

# Reuse the configuration defined alongside the data instead of repeating
# literals, so the model cannot silently drift from the arrays it is fed.
model = build_spatial_temporal_model(
    seq_len=sequence_length,
    num_regions=num_regions,
    features_per_region=features_per_region,
    interaction_dim=interaction_dim,
    forecast_horizon=forecast_horizon,
    extra_front_dim=0,
    extra_wind_dim=0
)

model.compile(
    optimizer=Adam(learning_rate=1e-3),
    loss={
        'tp_amount': 'mse',
        'rain_class': weighted_categorical_crossentropy(alpha_vec)
    },
    loss_weights={
        'tp_amount': 1.0,
        'rain_class': 1.0
    },
    metrics={'rain_class': 'accuracy'}
)

# Fail fast with a readable message if the feature width no longer matches
# the num_regions x features_per_region layout the model reshapes into.
expected_width = num_regions * features_per_region
for name, arr in (("X_train", X_train), ("X_long_train", X_long_train)):
    if arr.shape[-1] != expected_width:
        raise ValueError(f"{name} has {arr.shape[-1]} features, model expects {expected_width}")


def _model_inputs(short_seq, long_seq, wind):
    """Map arrays onto the model's named inputs."""
    return {"short_seq": short_seq, "long_input": long_seq, "wind_input": wind}


def _model_targets(tp_amount, rain_class):
    """Map arrays onto the model's named outputs."""
    return {"tp_amount": tp_amount, "rain_class": rain_class}


# Callbacks
callbacks = [
    EarlyStopping(patience=5, restore_best_weights=True),
    ReduceLROnPlateau(patience=3, factor=0.5)
]

# Training
history = model.fit(
    x=_model_inputs(X_train, X_long_train, X_wind_train),
    y=_model_targets(y_train, y_rain_train_cat),
    validation_data=(
        _model_inputs(X_val, X_long_val, X_wind_val),
        _model_targets(y_val, y_rain_val_cat),
    ),
    epochs=20,
    batch_size=64,
    callbacks=callbacks
)

In [None]:
# Evaluation
from sklearn.metrics import r2_score, classification_report
import matplotlib.pyplot as plt

# Predict
y_pred_reg, y_pred_class = model.predict({
    "short_seq": X_test,
    "long_input": X_long_test,
    "wind_input": X_wind_test
})

# The regression head emits (samples, horizon, 1); align it with y_test's
# (samples, horizon) so per-day slices are 1-D vectors.
y_pred_reg = y_pred_reg.reshape(y_test.shape)

# unscale
y_pred_reg_unscaled = tp_scaler.inverse_transform(y_pred_reg.reshape(-1, 1)).reshape(y_test.shape)
y_true_reg_unscaled = tp_scaler.inverse_transform(y_test.reshape(-1, 1)).reshape(y_test.shape)

# R² score (flattened over all samples and days)
r2 = r2_score(y_true_reg_unscaled.flatten(), y_pred_reg_unscaled.flatten())
print(f"R² score (unscaled regression): {r2:.3f}")

# The classifier was trained on the 4 remapped classes (y_rain_*_new);
# y_rain_test holds the raw 5-bin np.digitize indices, so compare against the
# remapped labels. Class probabilities live on the LAST axis.
y_true_classes = y_rain_test_new                     # (samples, horizon)
y_pred_classes = np.argmax(y_pred_class, axis=-1)    # (samples, horizon)

# Debug output in case of data-splitting errors or shape mismatch (frequent)
print("True labels shape:", y_true_classes.shape)
print("Predicted labels shape:", y_pred_classes.shape)
print("Unique true labels:", np.unique(y_true_classes))
print("Unique predicted classes:", np.unique(y_pred_classes))

class_names = [
    'No Rain (≤0.1mm)',
    'Light Rain (0.1–2mm)',
    'Moderate Rain (2–10mm)',
    'Heavy Rain (>10mm)'
]
for day in range(forecast_horizon):
    print(f"\n📅 Day {day+1} Classification Report:")
    print(classification_report(
        y_true_classes[:, day],
        y_pred_classes[:, day],
        labels=[0, 1, 2, 3],
        target_names=class_names,
        zero_division=0,
    ))

print("y_true_reg_unscaled shape:", y_true_reg_unscaled.shape)
print("y_pred_reg_unscaled shape:", y_pred_reg_unscaled.shape)

# Per-day R², computed once and reused for the printout and the plot.
r2_day_scores = [
    r2_score(y_true_reg_unscaled[:, day], y_pred_reg_unscaled[:, day])
    for day in range(forecast_horizon)
]

print("R² scores for each forecast day (unscaled regression):\n")
for day, score in enumerate(r2_day_scores, start=1):
    print(f"Day {day}: R² = {score:.3f}")

forecast_days = range(1, forecast_horizon + 1)
plt.figure(figsize=(8, 4))
plt.plot(forecast_days, r2_day_scores, marker='o')
plt.title("Day-by-Day R² Scores for Rainfall Prediction")
plt.xlabel("Forecast Day")
plt.ylabel("R² Score")
plt.grid(True)
plt.xticks(forecast_days)
# R² is negative for forecasts worse than the mean; keep those points visible.
plt.ylim(min(0.0, min(r2_day_scores)), 1)
plt.show()