In [1]:
pip install jdatetime

Collecting jdatetime
  Downloading jdatetime-5.2.0-py3-none-any.whl.metadata (5.6 kB)
Collecting jalali-core>=1.0 (from jdatetime)
  Downloading jalali_core-1.0.0-py3-none-any.whl.metadata (738 bytes)
Downloading jdatetime-5.2.0-py3-none-any.whl (12 kB)
Downloading jalali_core-1.0.0-py3-none-any.whl (3.6 kB)
Installing collected packages: jalali-core, jdatetime
Successfully installed jalali-core-1.0.0 jdatetime-5.2.0


In [4]:
import pandas as pd
import numpy as np
import jdatetime
from sklearn.preprocessing import StandardScaler, LabelEncoder
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from typing import List, Dict, Tuple, Optional, Union

# ---------------------------------------------------------------------------
# Pipeline configuration. The column names below are the literal (Persian)
# headers of the input CSV; the English glosses in the comments are
# translations of those header strings.
# ---------------------------------------------------------------------------

# Input CSV, resolved relative to the working directory
# (file name reads roughly "final daily provincial statistics-11-02-1403").
DATA_FILE_PATH = "آمار نهایی روزانه استانی-11-02-1403 (csv).csv"
# Station identifier column ("station"); label-encoded in preprocess_data.
STATION_ID_COL = 'ایستگاه'
# Forecast target column ("maximum temperature").
TARGET_COL = 'ماكزيمم دما'

# Jalali (Solar Hijri) calendar columns: year, month, day.
S_YEAR_COL = 'سال شمسی'
S_MONTH_COL = 'ماه شمسی'
S_DAY_COL = 'روز شمسی'
# "Julian" day column for the Jalali calendar; preprocess_data treats it as a
# day-of-year value when building the sin/cos seasonal features.
S_JULIAN_DAY_COL = 'ژولیوسی شمسی'

# Wind direction column; preprocess_data converts it to sin/cos assuming degrees.
WIND_DIR_COL = 'سمت باد'

# Columns removed right after loading. In order: evaporation, mean QFF pressure,
# Gregorian date, Gregorian year, Gregorian month, Gregorian day,
# Gregorian Julian day, Jalali date string, days.
COLS_TO_DROP = [
    'تبخير',
    'ميانگين فشار QFF',
    'تاریخ میلادی',
    'سال میلادی',
    'ماه میلادی',
    'روز میلادی',
    'ژولیوسی میلادی',
    'تاریخ شمسی',
    'روزها'
]

# Number of past rows fed to the model per sample.
LOOKBACK_WINDOW = 30
# Number of future rows predicted per sample.
FORECAST_HORIZON = 7

# Chronological per-station split fractions; whatever remains after the
# training and validation portions becomes the test portion.
TRAIN_RATIO = 0.8
VALIDATION_RATIO = 0.15

def is_valid_jalali_date(year, month, day):
    """Return True when (year, month, day) denotes a real Jalali (Solar Hijri) date.

    Args:
        year, month, day: Python or NumPy integers/floats. Floats are truncated
            toward zero with ``int()`` before validation (e.g. ``1402.0`` -> 1402).

    Returns:
        bool: False for non-numeric, NaN or infinite components, for a year
        outside the supported range, for a month outside 1..12, or for a day
        that does not exist in that month; True otherwise.

    Month lengths: months 1-6 have 31 days, months 7-11 have 30 days, and
    month 12 has 30 days in a leap year and 29 otherwise (leap status is
    obtained from ``jdatetime``).
    """
    # np.integer / np.floating are included because NumPy integer scalars (and
    # float32) are not subclasses of the built-in int/float.
    numeric_types = (int, float, np.integer, np.floating)
    parts = (year, month, day)
    if not all(isinstance(part, numeric_types) for part in parts):
        return False
    # NaN fails both comparisons and +/-inf fails one of them, so this single
    # test rejects every non-finite value before int() could raise.
    if not all(-np.inf < part < np.inf for part in parts):
        return False

    year, month, day = int(year), int(month), int(day)

    # Year bounds mirror jdatetime's MINYEAR/MAXYEAR (1 and 9377) so that a date
    # accepted here can always be constructed with jdatetime.date().
    # NOTE(review): confirm these bounds if the jdatetime version changes.
    min_year, max_year = 1, 9377
    if not (min_year <= year <= max_year):
        return False
    if not (1 <= month <= 12 and 1 <= day <= 31):
        return False
    if 7 <= month <= 11 and day > 30:
        return False
    if month == 12:
        try:
            is_leap = jdatetime.date(year, 1, 1).isleap()
        except ValueError:
            return False
        if day > (30 if is_leap else 29):
            return False
    return True

def preprocess_data(file_path,
                    station_id_col,
                    s_year_col, s_month_col, s_day_col, s_julian_day_col,
                    wind_dir_col,
                    target_col,
                    cols_to_drop,
                    train_ratio, val_ratio,
                    lookback_window, forecast_horizon):
    """Load the daily station CSV and build windowed train/validation/test arrays.

    Stages, in order: read the CSV, drop unused columns, coerce dtypes, drop rows
    missing an essential value, derive the Jalali weekday and cyclical features,
    label-encode stations, sort each station chronologically, impute remaining
    numeric gaps per station, split each station chronologically, standard-scale
    the continuous columns (fitted on the training split only) and cut sliding
    windows of ``lookback_window`` past rows followed by ``forecast_horizon``
    future rows.

    Args:
        file_path: Path of the CSV file to read.
        station_id_col: Name of the station identifier column.
        s_year_col, s_month_col, s_day_col: Jalali year / month / day columns.
        s_julian_day_col: Day-of-year column (optional in the file).
        wind_dir_col: Wind direction column, treated as degrees (optional).
        target_col: Column to forecast.
        cols_to_drop: Column names removed right after loading, when present.
        train_ratio, val_ratio: Fractions of each station's rows used for the
            training and validation splits; the remainder is the test split.
        lookback_window: Number of past rows per sample.
        forecast_horizon: Number of future rows per sample.

    Returns:
        ``None`` when the file is missing, an essential column is absent, or no
        usable data remains. Otherwise the tuple
        ``(train_sequences, val_sequences, test_sequences, scalers,
        station_encoder, final_params)`` where each ``*_sequences`` is a dict of
        NumPy arrays keyed by ``x_static_cat``, ``x_past_target``,
        ``x_past_cov_cont``, ``x_future_known_cat``, ``x_future_known_cont`` and
        ``y_target``; ``scalers`` maps column name -> fitted ``StandardScaler``;
        ``final_params`` holds the input dimensions for the model.

    Raises:
        ValueError: if the collected windows cannot be stacked into a regular
            array (this would indicate inconsistent window shapes).

    NOTE(review): imputation runs before the chronological split, so values
    interpolated near a split boundary can depend on later observations.
    NOTE(review): windows are cut by row position; rows removed earlier (missing
    target, invalid date) leave calendar gaps inside a window.
    """
    try:
        na_values_list = ['***', '---', '#N/A', 'N/A', 'NULL', 'nan', 'NaN', 'None', '', ' ']
        df = pd.read_csv(file_path, low_memory=False, na_values=na_values_list)
    except FileNotFoundError:
        print(f"Error: File {file_path} not found.")
        return None
    print(f"Initial data loaded. Shape: {df.shape}\nInitial columns: {df.columns.tolist()}")

    actual_cols_to_drop = [col for col in cols_to_drop if col in df.columns]
    if actual_cols_to_drop:
        df = df.drop(columns=actual_cols_to_drop)
    print(f"Columns after initial drop ({len(actual_cols_to_drop)} columns): {df.columns.tolist()}")

    # ------------------------------------------------------------------
    # Dtype coercion
    # ------------------------------------------------------------------
    critical_cols_map = {
        station_id_col: str, s_year_col: float, s_month_col: float, s_day_col: float,
        s_julian_day_col: float, target_col: float
    }
    if wind_dir_col in df.columns:
        critical_cols_map[wind_dir_col] = float

    other_numeric_covariates = [
        'مينيمم دما', 'ميانگين دما', 'ماكزيمم رطوبت', 'مينيمم رطوبت',
        'ميانگين رطوبت', 'بارندگي', 'ساعات آفتابي', 'حداكثر سرعت باد'
    ]
    for col in other_numeric_covariates:
        if col in df.columns:
            critical_cols_map[col] = float

    essential_cols = [target_col, s_year_col, s_month_col, s_day_col, station_id_col]
    for col, col_type in critical_cols_map.items():
        if col in df.columns:
            if col_type == str:
                # Keep missing ids missing: a plain astype(str) would turn NaN into
                # the literal string 'nan', which dropna() below could not remove
                # and which would then be encoded as if it were a real station.
                missing_mask = df[col].isna()
                df[col] = df[col].astype(str).mask(missing_mask)
            else:
                df[col] = pd.to_numeric(df[col], errors='coerce')
        elif col in essential_cols:
            print(f"Error: Critical and essential column '{col}' not found. Program will stop.")
            return None
        else:
            print(f"Warning: Expected column '{col}' not found, but continuing.")

    essential_dropna_subset = [station_id_col, s_year_col, s_month_col, s_day_col, target_col]
    if s_julian_day_col in df.columns:
        essential_dropna_subset.append(s_julian_day_col)
    df.dropna(subset=essential_dropna_subset, inplace=True)

    if wind_dir_col in df.columns and df[wind_dir_col].isnull().any():
        print(f"NaN values in column '{wind_dir_col}' were replaced with 0 (before sin/cos).")
        df[wind_dir_col] = df[wind_dir_col].fillna(0)

    print(f"Data shape after removing initial NaNs in critical columns: {df.shape}")
    if df.empty:
        print("Error: No data remaining after removing initial NaNs.")
        return None

    # ------------------------------------------------------------------
    # Calendar features
    # ------------------------------------------------------------------
    def get_jalali_weekday(row):
        """Weekday index from jdatetime for one row, or NaN when the date is unusable."""
        year, month, day = row[s_year_col], row[s_month_col], row[s_day_col]
        if pd.isna(year) or pd.isna(month) or pd.isna(day):
            return np.nan
        if not is_valid_jalali_date(year, month, day):
            return np.nan
        try:
            return jdatetime.date(int(year), int(month), int(day)).weekday()
        except ValueError:
            # The validator accepted something jdatetime cannot build; treat the
            # row as invalid instead of aborting the whole run.
            return np.nan

    # Only the three date columns are needed; casting to float makes every cell a
    # float subclass regardless of how pandas inferred the column dtype.
    date_parts = df[[s_year_col, s_month_col, s_day_col]].astype(float)
    df['derived_day_of_week'] = date_parts.apply(get_jalali_weekday, axis=1)
    df.dropna(subset=['derived_day_of_week'], inplace=True)
    df['derived_day_of_week'] = df['derived_day_of_week'].astype(int)
    print(f"Data shape after creating and cleaning derived_day_of_week: {df.shape}")
    if df.empty:
        print("Error: No data remaining after processing derived_day_of_week.")
        return None

    if s_julian_day_col in df.columns:
        df['day_of_year_sin'] = np.sin(2 * np.pi * df[s_julian_day_col] / 365.25)
        df['day_of_year_cos'] = np.cos(2 * np.pi * df[s_julian_day_col] / 365.25)
    else:
        print(f"Warning: Column '{s_julian_day_col}' (day of year) not found.")

    # ------------------------------------------------------------------
    # Station encoding and wind direction
    # ------------------------------------------------------------------
    # Rows without a station id were dropped above, so every remaining id is a string.
    station_encoder = LabelEncoder()
    encoded_station_id_col_name = f'{station_id_col}_encoded'
    df[encoded_station_id_col_name] = station_encoder.fit_transform(df[station_id_col])
    num_unique_stations = df[encoded_station_id_col_name].nunique()
    print(f"Number of unique stations encoded: {num_unique_stations}")

    if wind_dir_col in df.columns:
        df[f'{wind_dir_col}_sin'] = np.sin(2 * np.pi * df[wind_dir_col] / 360.0)
        df[f'{wind_dir_col}_cos'] = np.cos(2 * np.pi * df[wind_dir_col] / 360.0)
        df = df.drop(columns=[wind_dir_col])
        print(f"Column '{wind_dir_col}' was cyclically transformed.")
    else:
        print(f"Warning: Column '{wind_dir_col}' not found.")

    # ------------------------------------------------------------------
    # Chronological ordering per station
    # ------------------------------------------------------------------
    # Every remaining row already has a validated date, so the YYYYMMDD sort key
    # can be computed arithmetically instead of with a second row-wise apply.
    df['temp_jalali_int_date'] = (
        df[s_year_col].astype(int) * 10000
        + df[s_month_col].astype(int) * 100
        + df[s_day_col].astype(int)
    )
    df = (
        df.sort_values(by=[encoded_station_id_col_name, 'temp_jalali_int_date'])
        .drop(columns=['temp_jalali_int_date'])
        .reset_index(drop=True)
    )
    print(f"Data shape after final sorting and date cleaning: {df.shape}")
    if df.empty:
        print("Error: No data remaining after date sorting.")
        return None

    # ------------------------------------------------------------------
    # Imputation of the remaining numeric gaps
    # ------------------------------------------------------------------
    cols_not_to_impute = [encoded_station_id_col_name, s_year_col, s_month_col, s_day_col, 'derived_day_of_week']
    if s_julian_day_col in df.columns:
        cols_not_to_impute.append(s_julian_day_col)

    cols_for_imputation = df.select_dtypes(include=np.number).columns.difference(cols_not_to_impute).tolist()

    for col in cols_for_imputation:
        if not df[col].isnull().any():
            continue
        df[col] = df.groupby(encoded_station_id_col_name)[col].transform(
            lambda s: s.interpolate(method='linear', limit_direction='both'))
        # Both fills stay inside the station group; an ungrouped bfill would copy
        # the next station's values into a station that has none of its own.
        df[col] = df.groupby(encoded_station_id_col_name)[col].ffill()
        df[col] = df.groupby(encoded_station_id_col_name)[col].bfill()

        # Stations with no observation at all for this column fall back to the
        # column-wide median (or 0 when the whole column is empty).
        current_col_median = df[col].median()
        if pd.isna(current_col_median):
            print(f"Warning: Median for column '{col}' (from overall df) could not be calculated. Filling with 0.")
            df[col] = df[col].fillna(0)
        else:
            df[col] = df[col].fillna(current_col_median)

    for col in [s_month_col, s_day_col, 'derived_day_of_week']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)

    nan_counts = df.isnull().sum()
    print(f"NaN counts after all filling stages:\n{nan_counts[nan_counts > 0]}")

    # ------------------------------------------------------------------
    # Chronological split per station
    # ------------------------------------------------------------------
    train_dfs, val_dfs, test_dfs = [], [], []
    min_len_for_a_sequence = lookback_window + forecast_horizon
    for station_code_encoded, group_df in df.groupby(encoded_station_id_col_name):
        n = len(group_df)
        if n < min_len_for_a_sequence + 1:
            print(f"Station {station_encoder.inverse_transform([station_code_encoded])[0]} has very little data ({n} rows) and will be ignored.")
            continue
        train_end_idx = int(n * train_ratio)
        val_end_idx = train_end_idx + int(n * val_ratio)

        split_lengths = (train_end_idx, val_end_idx - train_end_idx, n - val_end_idx)
        if min(split_lengths) < min_len_for_a_sequence:
            # At least one split could not hold a single window: keep the whole
            # station for training instead of producing unusable fragments.
            print(f"Station {station_encoder.inverse_transform([station_code_encoded])[0]} ({n} rows) was added only to training.")
            train_dfs.append(group_df)
            continue
        train_dfs.append(group_df.iloc[:train_end_idx])
        val_dfs.append(group_df.iloc[train_end_idx:val_end_idx])
        test_dfs.append(group_df.iloc[val_end_idx:])

    if not train_dfs:
        print("Error: No data left for training after splitting.")
        return None

    df_train = pd.concat(train_dfs).reset_index(drop=True)
    df_val = pd.concat(val_dfs).reset_index(drop=True) if val_dfs else pd.DataFrame(columns=df_train.columns)
    df_test = pd.concat(test_dfs).reset_index(drop=True) if test_dfs else pd.DataFrame(columns=df_train.columns)

    print(f"Training set size: {df_train.shape}, Validation: {df_val.shape}, Test: {df_test.shape}")
    if df_train.empty:
        print("Error: Training set is empty.")
        return None

    # ------------------------------------------------------------------
    # Feature groups
    # ------------------------------------------------------------------
    past_cont_cov_cols_actual = ['مينيمم دما', 'ميانگين دما', 'ماكزيمم رطوبت', 'مينيمم رطوبت',
                                'بارندگي', 'ساعات آفتابي', 'حداكثر سرعت باد']
    if f'{wind_dir_col}_sin' in df.columns:
        past_cont_cov_cols_actual.extend([f'{wind_dir_col}_sin', f'{wind_dir_col}_cos'])

    final_past_cont_cov_cols = [col for col in past_cont_cov_cols_actual if col in df_train.columns]
    print(f"Final list of past continuous covariates for scaling and sequence: {final_past_cont_cov_cols}, Count: {len(final_past_cont_cov_cols)}")

    future_cont_cov_cols_actual = []
    if 'day_of_year_sin' in df.columns:
        future_cont_cov_cols_actual.append('day_of_year_sin')
    if 'day_of_year_cos' in df.columns:
        future_cont_cov_cols_actual.append('day_of_year_cos')
    if s_year_col in df.columns:
        future_cont_cov_cols_actual.append(s_year_col)
    final_future_cont_cov_cols = [col for col in future_cont_cov_cols_actual if col in df_train.columns]

    final_future_cat_cov_cols = [col for col in [s_month_col, s_day_col, 'derived_day_of_week'] if col in df_train.columns]

    cols_to_scale = [target_col] + final_past_cont_cov_cols + final_future_cont_cov_cols
    cols_to_scale = sorted(list(set(col for col in cols_to_scale if col in df_train.columns)))

    # ------------------------------------------------------------------
    # Scaling (statistics come from the training split only)
    # ------------------------------------------------------------------
    scalers = {}
    print(f"Columns selected for scaling: {cols_to_scale}")
    for col in cols_to_scale:
        scaler = StandardScaler()
        col_median_train = df_train[col].median()
        if pd.isna(col_median_train):
            col_median_train = 0

        current_train_col_data = pd.to_numeric(df_train[col], errors='coerce').fillna(col_median_train)
        df_train[col] = scaler.fit_transform(current_train_col_data.values.reshape(-1, 1)).ravel()

        for df_eval in (df_val, df_test):
            if col in df_eval.columns and not df_eval.empty:
                current_eval_col_data = pd.to_numeric(df_eval[col], errors='coerce').fillna(col_median_train)
                df_eval[col] = scaler.transform(current_eval_col_data.values.reshape(-1, 1)).ravel()
        scalers[col] = scaler

    for df_set_name, df_set in zip(["Training", "Validation", "Test"], [df_train, df_val, df_test]):
        if not df_set.empty and target_col not in df_set.columns:
            print(f"Critical Error: Target column '{target_col}' not found in {df_set_name} set. Available columns: {df_set.columns.tolist()}")
            return None

    # ------------------------------------------------------------------
    # Sliding-window extraction
    # ------------------------------------------------------------------
    def create_sequences_fn(data, station_id_col_encoded_arg, target_col_arg,
                            future_cat_cols_arg, future_cont_cols_arg,
                            past_cont_cov_cols_arg,
                            lookback, horizon):
        """Cut every station in ``data`` into (lookback, horizon) windows.

        Returns a dict of arrays with shapes (N = number of windows):
        x_static_cat (N, 1), x_past_target (N, lookback, 1),
        x_past_cov_cont (N, lookback, P), x_future_known_cat (N, horizon, C),
        x_future_known_cont (N, horizon, F), y_target (N, horizon, 1).
        With no windows, N is 0 and the remaining dimensions are kept.
        """
        past_cols = list(past_cont_cov_cols_arg or [])
        cat_cols = list(future_cat_cols_arg or [])
        cont_cols = list(future_cont_cols_arg or [])
        window = lookback + horizon

        def _column_matrix(frame, cols, dtype, kind_label):
            """Stack ``cols`` of ``frame`` as (rows, len(cols)); absent columns become zeros."""
            parts = []
            for col in cols:
                if col in frame.columns:
                    parts.append(frame[col].to_numpy().astype(dtype).reshape(-1, 1))
                else:
                    print(f"Warning in create_sequences: Future {kind_label} column '{col}' not found in group.")
                    parts.append(np.zeros((len(frame), 1), dtype=dtype))
            if parts:
                return np.concatenate(parts, axis=1)
            return np.empty((len(frame), 0), dtype=dtype)

        sequences = {
            'x_static_cat': [], 'x_past_target': [], 'x_past_cov_cont': [],
            'x_future_known_cat': [], 'x_future_known_cont': [], 'y_target': []
        }
        if not data.empty:
            for _, group in data.groupby(station_id_col_encoded_arg):
                len_ts = len(group)
                if len_ts < window:
                    continue

                # Convert each station to NumPy once; the windows below are plain
                # array slices instead of one DataFrame lookup per window.
                station_code = group[station_id_col_encoded_arg].iloc[0]
                target_matrix = group[target_col_arg].to_numpy().astype(float).reshape(-1, 1)
                if past_cols:
                    past_matrix = group[past_cols].to_numpy().astype(float)
                else:
                    past_matrix = np.empty((len_ts, 0), dtype=float)
                future_cat_matrix = _column_matrix(group, cat_cols, int, 'categorical')
                future_cont_matrix = _column_matrix(group, cont_cols, float, 'continuous')

                for start in range(len_ts - window + 1):
                    past_end = start + lookback
                    future_end = past_end + horizon
                    sequences['x_static_cat'].append(station_code)
                    sequences['x_past_target'].append(target_matrix[start:past_end])
                    sequences['x_past_cov_cont'].append(past_matrix[start:past_end])
                    sequences['x_future_known_cat'].append(future_cat_matrix[past_end:future_end])
                    sequences['x_future_known_cont'].append(future_cont_matrix[past_end:future_end])
                    sequences['y_target'].append(target_matrix[past_end:future_end])

        # Per-sample shapes, used to give empty results well-formed dimensions.
        sample_shapes = {
            'x_static_cat': (1,),
            'x_past_target': (lookback, 1),
            'x_past_cov_cont': (lookback, len(past_cols)),
            'x_future_known_cat': (horizon, len(cat_cols)),
            'x_future_known_cont': (horizon, len(cont_cols)),
            'y_target': (horizon, 1),
        }
        arrays = {}
        for key, samples in sequences.items():
            dtype_expected = int if key in ('x_static_cat', 'x_future_known_cat') else float
            if samples:
                try:
                    stacked = np.array(samples, dtype=dtype_expected)
                except ValueError as exc:
                    # Substituting a placeholder array here would hand
                    # uninitialised memory to the model; fail loudly instead.
                    raise ValueError(
                        f"Could not stack '{key}' windows into an array with dtype {dtype_expected}: {exc}"
                    ) from exc
            else:
                stacked = np.empty((0,) + sample_shapes[key], dtype=dtype_expected)

            if key == 'x_static_cat' and stacked.ndim == 1:
                stacked = stacked.reshape(-1, 1)
            arrays[key] = stacked
        return arrays

    print("\nCreating sequences for training set...")
    train_sequences = create_sequences_fn(df_train, encoded_station_id_col_name, target_col,
                                        final_future_cat_cov_cols, final_future_cont_cov_cols,
                                        final_past_cont_cov_cols,
                                        lookback_window, forecast_horizon)
    print("Creating sequences for validation set...")
    val_sequences = create_sequences_fn(df_val, encoded_station_id_col_name, target_col,
                                        final_future_cat_cov_cols, final_future_cont_cov_cols,
                                        final_past_cont_cov_cols,
                                        lookback_window, forecast_horizon)
    print("Creating sequences for test set...")
    test_sequences = create_sequences_fn(df_test, encoded_station_id_col_name, target_col,
                                        final_future_cat_cov_cols, final_future_cont_cov_cols,
                                        final_past_cont_cov_cols,
                                        lookback_window, forecast_horizon)

    # ------------------------------------------------------------------
    # Model input dimensions
    # ------------------------------------------------------------------
    def _embedding_cardinality(col):
        """Embedding table size able to index every value of ``col``.

        The categorical columns hold raw integers (month 1..12, day 1..31,
        weekday 0..6). An embedding with K rows accepts indices 0..K-1, so the
        table needs ``max value + 1`` rows; the distinct-value count alone is
        one short for the 1-based month and day columns.
        """
        distinct_in_train = df_train[col].nunique(dropna=False)
        largest_value = int(df[col].max())
        return max(1, distinct_in_train, largest_value + 1)

    final_params = {
        'target_dim': 1,
        'static_categorical_input_dims': [num_unique_stations] if num_unique_stations > 0 else [],
        'static_continuous_input_dim': 0,
        'obs_continuous_input_dim': len(final_past_cont_cov_cols),
        'obs_categorical_input_dims': [],
        'known_categorical_input_dims': [_embedding_cardinality(col) for col in final_future_cat_cov_cols],
        'known_continuous_input_dim': len(final_future_cont_cov_cols)
    }

    print("\nCalculated parameters for TFT model:")
    for key, value in final_params.items():
        print(f"  {key}: {value}")

    return train_sequences, val_sequences, test_sequences, scalers, station_encoder, final_params

if __name__ == "__main__":
    # Run the whole preprocessing pipeline with the module-level configuration.
    # Arguments are passed in the order of preprocess_data's signature.
    processed_output = preprocess_data(
        DATA_FILE_PATH,
        STATION_ID_COL,
        S_YEAR_COL, S_MONTH_COL, S_DAY_COL, S_JULIAN_DAY_COL,
        WIND_DIR_COL,
        TARGET_COL,
        COLS_TO_DROP,
        TRAIN_RATIO, VALIDATION_RATIO,
        LOOKBACK_WINDOW, FORECAST_HORIZON,
    )

    # preprocess_data returns None on failure, so unpack only on success.
    if processed_output:
        (train_data, val_data, test_data,
         data_scalers, st_encoder, model_params_from_preprocessing) = processed_output
        print("\nSample shape of training data (if created):")
        for key, arr in train_data.items():
            if not isinstance(arr, np.ndarray):
                print(f"  {key}: is not a numpy array (type: {type(arr)})")
                continue
            print(f"  {key}: {arr.shape}")

# Small positive constant; its use is not visible in this part of the file
# (presumably a numerical-stability term — confirm at the call sites).
EPSILON = 1e-8

class GatedLinearUnit(nn.Module):
    """Gated linear unit: ``sigmoid(linear1(x)) * linear2(x)``.

    Both linear layers map ``input_size`` features to ``output_size`` features;
    ``linear1`` produces the gate and ``linear2`` the values being gated.
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        # Gate branch first, value branch second (keeps parameter creation order).
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(input_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gate element-wise to the linearly projected input."""
        gate = self.sigmoid(self.linear1(x))
        values = self.linear2(x)
        return gate * values

class GatedResidualNetwork(nn.Module):
    """Gated residual block with an optional additive context input.

    Computation in ``forward``:
        hidden = dropout(ELU(W2_a(a) [+ W3_c(c)]))
        gated  = dropout(glu_or_linear(W1_eta2(hidden)))
        output = LayerNorm(skip(a) + gated)

    ``glu_or_linear`` is a plain ``nn.Linear`` when ``is_output_layer`` is True
    and a ``GatedLinearUnit`` otherwise. ``skip`` is the identity when
    ``input_size == output_size`` and a linear projection otherwise.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                dropout_rate: float, context_size: Optional[int] = None,
                is_output_layer: bool = False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.context_size = context_size
        self.dropout_rate = dropout_rate
        self.is_output_layer = is_output_layer

        # Sub-modules are created in a fixed order so that seeded initialisation
        # yields the same parameters.
        self.W2_a = nn.Linear(self.input_size, self.hidden_size)
        if self.context_size is not None:
            # The context projection carries no bias.
            self.W3_c = nn.Linear(self.context_size, self.hidden_size, bias=False)
        else:
            self.W3_c = None
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(self.dropout_rate)
        self.W1_eta2 = nn.Linear(self.hidden_size, self.hidden_size)
        if self.is_output_layer:
            self.glu_or_linear = nn.Linear(self.hidden_size, self.output_size)
        else:
            self.glu_or_linear = GatedLinearUnit(self.hidden_size, self.output_size)
        if self.input_size != self.output_size:
            self.skip_connection_projector = nn.Linear(self.input_size, self.output_size)
        else:
            self.skip_connection_projector = None
        self.layer_norm = nn.LayerNorm(self.output_size)

    def forward(self, a: torch.Tensor, c: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the block on ``a``; ``c`` is used only if a context layer exists."""
        hidden = self.W2_a(a)
        # The context is ignored when either the tensor or the layer is absent.
        if c is not None and self.W3_c is not None:
            hidden = hidden + self.W3_c(c)
        hidden = self.dropout(self.elu(hidden))

        gated = self.dropout(self.glu_or_linear(self.W1_eta2(hidden)))

        if self.skip_connection_projector is None:
            residual = a
        else:
            residual = self.skip_connection_projector(a)
        return self.layer_norm(residual + gated)

class VariableSelectionNetwork(nn.Module):
    """Weights a set of already-projected variables and returns their weighted sum.

    Each variable has its own ``GatedResidualNetwork``. A separate GRN reads the
    concatenation of all (unprocessed) variable vectors and emits one logit per
    variable; a softmax over those logits gives the selection weights.

    The weight GRN receives a context input only when ``is_static`` is False and
    ``context_size`` is given.
    """

    def __init__(self, num_total_inputs: int, hidden_size: int, dropout_rate: float,
                context_size: Optional[int] = None, is_static: bool = False):
        super().__init__()
        if num_total_inputs <= 0:
            raise ValueError("num_total_inputs for VSN must be positive.")
        self.num_total_inputs = num_total_inputs
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.context_size = context_size
        self.is_static = is_static

        # One hidden->hidden GRN per input variable.
        variable_grns = []
        for _ in range(num_total_inputs):
            variable_grns.append(
                GatedResidualNetwork(hidden_size, hidden_size, hidden_size, dropout_rate))
        self.per_variable_grns = nn.ModuleList(variable_grns)

        # The weight GRN sees all variables flattened side by side.
        uses_context = (not self.is_static) and (self.context_size is not None)
        self.weight_calculator_grn = GatedResidualNetwork(
            num_total_inputs * hidden_size, hidden_size, num_total_inputs, dropout_rate,
            context_size=self.context_size if uses_context else None,
            is_output_layer=False)
        self.softmax = nn.Softmax(dim=-1)
        # Most recent selection weights, kept for inspection.
        self.attention_weights: Optional[torch.Tensor] = None

    def forward(self, all_variables_projected: torch.Tensor,
                context_for_weights: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select and combine variables.

        Args:
            all_variables_projected: 4-D tensor unpacked as
                (batch, seq_len, num_variables, hidden).
            context_for_weights: optional per-batch context; it is repeated along
                the sequence axis and is ignored for static networks.

        Returns:
            ``(combined, weights)``: ``combined`` is (batch, seq_len, hidden) and
            ``weights`` is (batch, seq_len, num_variables).

        Raises:
            ValueError: if the variable axis does not match ``num_total_inputs``.
        """
        batch_size, seq_len, num_vars_check, var_hidden_size = all_variables_projected.shape
        if num_vars_check != self.num_total_inputs:
            raise ValueError(f"VSN expected {self.num_total_inputs} vars, got {num_vars_check} (shape: {all_variables_projected.shape})")

        flat_rows = batch_size * seq_len

        # Each variable goes through its own GRN on a (batch*seq, hidden) view.
        processed_per_variable = []
        for var_idx, variable_grn in enumerate(self.per_variable_grns):
            variable_flat = all_variables_projected[:, :, var_idx, :].reshape(flat_rows, var_hidden_size)
            processed = variable_grn(variable_flat)
            processed_per_variable.append(processed.reshape(batch_size, seq_len, var_hidden_size))
        processed_stacked = torch.stack(processed_per_variable, dim=2)

        # Selection weights are computed from the unprocessed variables.
        flat_inputs = all_variables_projected.reshape(flat_rows, self.num_total_inputs * var_hidden_size)
        if context_for_weights is not None and not self.is_static:
            flat_context = (
                context_for_weights.unsqueeze(1)
                .expand(-1, seq_len, -1)
                .reshape(flat_rows, -1)
            )
        else:
            flat_context = None
        weight_logits = self.weight_calculator_grn(flat_inputs, flat_context)
        selection_weights = self.softmax(weight_logits).reshape(batch_size, seq_len, self.num_total_inputs)
        self.attention_weights = selection_weights

        combined = (processed_stacked * selection_weights.unsqueeze(-1)).sum(dim=2)
        return combined, self.attention_weights

class InterpretableMultiHeadAttention(nn.Module):
    """Multi-head attention with per-head queries/keys and one shared value projection.

    Every head has its own query and key projection (``d_model -> d_k``) but all
    heads attend over the same projected values. The per-head outputs are summed
    and mapped back to ``d_model`` by ``W_out``.

    NOTE(review): head outputs are summed, not averaged, so the pre-``W_out``
    activation scales with ``num_heads``. Kept as is so existing weights behave
    the same; confirm this is intended.

    Attributes:
        attention_scores: attention probabilities of the latest forward call,
            shape (batch, num_heads, q_len, k_len), taken after dropout.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_rate: float):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v_shared = d_model // num_heads
        self.W_q_list = nn.ModuleList([nn.Linear(d_model, self.d_k, bias=False) for _ in range(num_heads)])
        self.W_k_list = nn.ModuleList([nn.Linear(d_model, self.d_k, bias=False) for _ in range(num_heads)])
        self.W_v_shared = nn.Linear(d_model, self.d_v_shared, bias=False)
        self.W_out = nn.Linear(self.d_v_shared, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        # Standard 1/sqrt(d_k) scaling of the dot products.
        self.scale_factor = self.d_k**-0.5
        self.attention_scores: Optional[torch.Tensor] = None

    def forward(self, q_input: torch.Tensor, k_input: torch.Tensor, v_input: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend from ``q_input`` over ``k_input`` / ``v_input``.

        Args:
            q_input, k_input, v_input: tensors whose last dimension is ``d_model``.
            mask: optional tensor; positions where ``mask == 0`` are excluded.
                It may be anything broadcastable to (batch, q_len, k_len), or
                4-D with a head axis at dim 1 of size ``num_heads`` (one mask
                per head) or size 1 (same mask for every head).

        Returns:
            Tensor with last dimension ``d_model``.
        """
        shared_v_projected = self.W_v_shared(v_input)
        head_outputs = []
        head_attentions = []
        for head_idx in range(self.num_heads):
            query = self.W_q_list[head_idx](q_input)
            key = self.W_k_list[head_idx](k_input)
            scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale_factor

            if mask is not None:
                head_mask = mask
                if mask.ndim == 4:
                    # A head axis of size 1 means "broadcast over heads"; indexing
                    # it with head_idx would fail for every head but the first.
                    mask_head_idx = head_idx if mask.size(1) > 1 else 0
                    head_mask = mask[:, mask_head_idx, :, :]
                scores = scores.masked_fill(head_mask == 0, -float('inf'))

            probs = self.dropout(F.softmax(scores, dim=-1))
            head_attentions.append(probs.unsqueeze(1))
            head_outputs.append(torch.matmul(probs, shared_v_projected))

        self.attention_scores = torch.cat(head_attentions, dim=1)
        summed_heads_output = torch.stack(head_outputs, dim=0).sum(dim=0)
        return self.W_out(summed_heads_output)

class TemporalFusionTransformer(nn.Module):
    """Sequence model combining static, past-observed and future-known inputs.

    The forward pass runs, in order: per-variable projection to ``d_model``,
    variable selection (static / past / future), four static-context GRNs, an
    LSTM over the past steps whose final state seeds an LSTM over the future
    steps, a gated residual with LayerNorm, static enrichment, causally masked
    interpretable multi-head attention, a position-wise GRN, a second gated
    residual with LayerNorm, and a linear layer with one output per quantile.

    For inspection, a forward call stores ``past_vsn_weights`` and
    ``attention_matrices``; ``static_vsn_weights`` and ``future_vsn_weights``
    are stored only when the corresponding selection network exists.

    NOTE(review): ``VariableSelectionNetwork``, ``GatedResidualNetwork`` and
    ``GatedLinearUnit`` are defined outside this block; their input/output
    contracts are assumed from the way they are called here - confirm.
    """
    def __init__(self,
                target_dim: int, static_categorical_input_dims: List[int], static_continuous_input_dim: int,
                obs_categorical_input_dims: List[int], obs_continuous_input_dim: int,
                known_categorical_input_dims: List[int], known_continuous_input_dim: int,
                d_model: int, num_attention_heads: int, lstm_hidden_layers: int,
                dropout_rate: float, lookback_window_size: int, forecast_horizon_size: int,
                output_quantiles: List[float]):
        """Build all sub-modules.

        Args:
            target_dim: Input width of the linear projection applied to the past target.
            static_categorical_input_dims: Cardinalities of the static categorical
                variables; one embedding is created per entry greater than 0.
            static_continuous_input_dim: If greater than 0, a single linear layer
                maps all static continuous features to ``d_model``.
            obs_categorical_input_dims: Cardinalities of the past-observed
                categorical covariates; one embedding per entry greater than 0.
            obs_continuous_input_dim: Number of past-observed continuous
                covariates; each gets its own ``Linear(1, d_model)``.
            known_categorical_input_dims: Cardinalities of the future-known
                categorical covariates; one embedding per entry greater than 0.
            known_continuous_input_dim: Number of future-known continuous
                covariates; each gets its own ``Linear(1, d_model)``.
            d_model: Hidden width used by every sub-module.
            num_attention_heads: Head count passed to the attention layer.
            lstm_hidden_layers: Number of stacked layers in both LSTMs.
            dropout_rate: Dropout passed to the VSNs, GRNs and attention; the
                LSTMs use it only when ``lstm_hidden_layers > 1``.
            lookback_window_size: Stored on the instance; ``forward`` reads the
                lookback length from ``p_target.shape`` instead.
            forecast_horizon_size: Number of future steps ``forward`` assumes.
            output_quantiles: Quantile levels; the list length sets the width
                of the output projection.

        Raises:
            ValueError: If the past variable count is not positive.
        """
        super().__init__()
        self.d_model = d_model; self.num_attention_heads = num_attention_heads
        self.lstm_hidden_layers = lstm_hidden_layers; self.dropout_rate = dropout_rate
        self.lookback_window_size = lookback_window_size; self.forecast_horizon_size = forecast_horizon_size
        self.output_quantiles = output_quantiles; self.num_quantiles = len(output_quantiles)

        # Per-variable input projections to d_model.
        # NOTE(review): cardinalities <= 0 are skipped here, but
        # _project_and_stack_inputs_for_vsn indexes input columns by position in
        # these lists, so a skipped entry would shift the column alignment -
        # confirm that no non-positive cardinality is ever passed.
        self.static_cat_embed_layers = nn.ModuleList([nn.Embedding(card, d_model) for card in static_categorical_input_dims if card > 0])
        self.static_cont_linear_layer = nn.Linear(static_continuous_input_dim, d_model) if static_continuous_input_dim > 0 else None
        self.past_target_projection = nn.Linear(target_dim, d_model)
        self.past_cat_cov_embed_layers = nn.ModuleList([nn.Embedding(card, d_model) for card in obs_categorical_input_dims if card > 0])
        self.past_cont_cov_projection_layers = nn.ModuleList([nn.Linear(1, d_model) for _ in range(obs_continuous_input_dim)])
        self.future_cat_cov_embed_layers = nn.ModuleList([nn.Embedding(card, d_model) for card in known_categorical_input_dims if card > 0])
        self.future_cont_cov_projection_layers = nn.ModuleList([nn.Linear(1, d_model) for _ in range(known_continuous_input_dim)])

        # Variable counts seen by each selection network. All static continuous
        # features count as one variable; the past target always counts as one.
        self.num_static_vars_for_vsn = len(self.static_cat_embed_layers) + (1 if self.static_cont_linear_layer is not None else 0)
        self.num_past_vars_for_vsn = 1 + len(self.past_cat_cov_embed_layers) + len(self.past_cont_cov_projection_layers)
        self.num_future_vars_for_vsn = len(self.future_cat_cov_embed_layers) + len(self.future_cont_cov_projection_layers)

        # Static and future selection networks are optional (None when there are no such variables).
        self.static_vsn = VariableSelectionNetwork(self.num_static_vars_for_vsn, d_model, dropout_rate, is_static=True) if self.num_static_vars_for_vsn > 0 else None
        if self.num_past_vars_for_vsn <=0: raise ValueError("Past VSN must have inputs.")
        self.past_inputs_vsn = VariableSelectionNetwork(self.num_past_vars_for_vsn, d_model, dropout_rate, context_size=d_model)
        self.future_inputs_vsn = VariableSelectionNetwork(self.num_future_vars_for_vsn, d_model, dropout_rate, context_size=d_model) if self.num_future_vars_for_vsn > 0 else None

        # Four context vectors derived from the selected static features in forward():
        # c_s feeds the temporal VSNs, c_c / c_h initialise the LSTM cell / hidden
        # state, c_e is the static-enrichment context.
        self.grn_c_s = GatedResidualNetwork(d_model,d_model,d_model,dropout_rate); self.grn_c_c = GatedResidualNetwork(d_model,d_model,d_model,dropout_rate)
        self.grn_c_h = GatedResidualNetwork(d_model,d_model,d_model,dropout_rate); self.grn_c_e = GatedResidualNetwork(d_model,d_model,d_model,dropout_rate)
        # nn.LSTM applies dropout only between stacked layers, hence the guard on the layer count.
        self.past_lstm = nn.LSTM(d_model, d_model, lstm_hidden_layers, batch_first=True, dropout=dropout_rate if lstm_hidden_layers > 1 else 0)
        self.future_lstm = nn.LSTM(d_model, d_model, lstm_hidden_layers, batch_first=True, dropout=dropout_rate if lstm_hidden_layers > 1 else 0) if self.num_future_vars_for_vsn > 0 else None
        self.locality_enhancement_glu = GatedLinearUnit(d_model,d_model); self.locality_enhancement_norm = nn.LayerNorm(d_model)
        self.static_enrichment_grn = GatedResidualNetwork(d_model,d_model,d_model,dropout_rate,context_size=d_model)
        self.multihead_attention = InterpretableMultiHeadAttention(d_model,num_attention_heads,dropout_rate)
        self.attention_gated_skip = GatedLinearUnit(d_model,d_model); self.attention_norm = nn.LayerNorm(d_model)
        self.position_wise_ff_grn = GatedResidualNetwork(d_model,d_model,d_model,dropout_rate)
        self.decoder_block_glu = GatedLinearUnit(d_model,d_model); self.decoder_block_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, self.num_quantiles)
        # Interpretability buffers, overwritten by forward().
        self.static_vsn_weights: Optional[torch.Tensor]=None; self.past_vsn_weights: Optional[torch.Tensor]=None
        self.future_vsn_weights: Optional[torch.Tensor]=None; self.attention_matrices: Optional[torch.Tensor]=None

    def _project_and_stack_inputs_for_vsn(self,
                                        target_tensor: Optional[torch.Tensor], cat_cov_tensor: Optional[torch.Tensor],
                                        cont_cov_tensor: Optional[torch.Tensor], target_proj_layer: Optional[nn.Linear],
                                        cat_embed_layers_list: nn.ModuleList, cont_proj_layers_list: nn.ModuleList,
                                        batch_size: int, seq_len: int,
                                        is_static: bool = False
                                        ) -> Optional[torch.Tensor]:
        """Project every input variable to ``d_model`` and concatenate them along dim 2.

        Args:
            target_tensor: Optional target series, projected with ``target_proj_layer``.
            cat_cov_tensor: Optional integer covariates; column ``i`` is fed to
                ``cat_embed_layers_list[i]``.
            cont_cov_tensor: Optional continuous covariates; column ``i`` is fed
                to ``cont_proj_layers_list[i]``, except in the static
                single-layer case, where the whole tensor goes through
                ``self.static_cont_linear_layer``.
            target_proj_layer: Linear layer for the target, or None.
            cat_embed_layers_list: Embedding layers, one per categorical variable.
            cont_proj_layers_list: Linear layers, one per continuous variable.
            batch_size: Not used by the current body.
            seq_len: Not used by the current body.
            is_static: When True, inputs have no time axis (columns are indexed
                on dim 1) and a length-1 axis is inserted at dim 1 after projection.

        Returns:
            The per-variable projections concatenated along dim 2, or None when
            no variable was projected.
        """
        projected_vars_list = []
        if target_tensor is not None and target_proj_layer is not None:
            projected_vars_list.append(target_proj_layer(target_tensor).unsqueeze(2))
        if cat_cov_tensor is not None and cat_embed_layers_list:
            for i, embed_layer in enumerate(cat_embed_layers_list):
                current_cat_input = cat_cov_tensor[:, i] if is_static else cat_cov_tensor[:, :, i]
                embedded_var = embed_layer(current_cat_input)
                if is_static: embedded_var = embedded_var.unsqueeze(1)
                projected_vars_list.append(embedded_var.unsqueeze(2))
        if cont_cov_tensor is not None and cont_proj_layers_list:
            # Static continuous features share one linear layer over the full
            # feature vector, so they enter the VSN as a single variable.
            if is_static and len(cont_proj_layers_list) == 1 and self.static_cont_linear_layer is not None :
                    projected_vars_list.append(self.static_cont_linear_layer(cont_cov_tensor).unsqueeze(1).unsqueeze(2))
            else:
                for i, proj_layer in enumerate(cont_proj_layers_list):
                    # i:i+1 keeps a trailing feature axis of size 1 for Linear(1, d_model).
                    current_cont_input = cont_cov_tensor[:, i:i+1] if is_static else cont_cov_tensor[:, :, i:i+1]
                    projected_var = proj_layer(current_cont_input)
                    if is_static: projected_var = projected_var.unsqueeze(1)
                    projected_vars_list.append(projected_var.unsqueeze(2))
        if not projected_vars_list: return None
        return torch.cat(projected_vars_list, dim=2)

    def forward(self,
                s_cat: Optional[torch.Tensor], s_cont: Optional[torch.Tensor],
                p_target: torch.Tensor, p_cat_cov: Optional[torch.Tensor], p_cont_cov: Optional[torch.Tensor],
                f_cat_cov: Optional[torch.Tensor], f_cont_cov: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the network and return quantile outputs for the future steps.

        Args:
            s_cat: Static categorical indices, indexed as ``[:, i]``, or None.
            s_cont: Static continuous features, or None.
            p_target: Past target; must be 3-D, its first two dimensions are
                read as batch size and lookback length.
            p_cat_cov: Past categorical covariates, indexed as ``[:, :, i]``, or None.
            p_cont_cov: Past continuous covariates, indexed as ``[:, :, i:i+1]``, or None.
            f_cat_cov: Future-known categorical covariates, indexed as ``[:, :, i]``, or None.
            f_cont_cov: Future-known continuous covariates, indexed as ``[:, :, i:i+1]``, or None.

        Returns:
            ``output_projection`` applied to the decoder states from time index
            ``k_lookback`` onward; the last dimension has ``num_quantiles`` entries.

        Raises:
            ValueError: If a selection network exists but none of its inputs
                could be projected (all corresponding tensors were None).
        """
        batch_size, k_lookback, _ = p_target.shape; h_forecast = self.forecast_horizon_size; device = p_target.device
        selected_static_context: torch.Tensor
        if self.static_vsn is not None:
            all_static_vars_projected = self._project_and_stack_inputs_for_vsn(
                None, s_cat, s_cont, None, self.static_cat_embed_layers,
                nn.ModuleList([self.static_cont_linear_layer]) if self.static_cont_linear_layer else nn.ModuleList(),
                batch_size=batch_size, seq_len=1, is_static=True)
            if all_static_vars_projected is None: raise ValueError("Static VSN error.")
            selected_static_features_context, static_vsn_w = self.static_vsn(all_static_vars_projected)
            # Drop the length-1 axis at dim 1 that the static projection inserted.
            self.static_vsn_weights = static_vsn_w.squeeze(1); selected_static_context = selected_static_features_context.squeeze(1)
        # Without static inputs, an all-zero static context is used.
        else: selected_static_context = torch.zeros((batch_size, self.d_model), device=device)
        # Static context vectors: selection (c_s), LSTM cell (c_c), LSTM hidden (c_h), enrichment (c_e).
        c_s=self.grn_c_s(selected_static_context); c_c=self.grn_c_c(selected_static_context); c_h=self.grn_c_h(selected_static_context); c_e=self.grn_c_e(selected_static_context)
        all_past_vars_projected = self._project_and_stack_inputs_for_vsn(
            p_target, p_cat_cov, p_cont_cov, self.past_target_projection,
            self.past_cat_cov_embed_layers, self.past_cont_cov_projection_layers,
            batch_size=batch_size, seq_len=k_lookback, is_static=False)
        if all_past_vars_projected is None: raise ValueError("Past VSN error.")
        selected_past_features, past_vsn_w = self.past_inputs_vsn(all_past_vars_projected, c_s)
        self.past_vsn_weights = past_vsn_w
        selected_future_features: torch.Tensor
        if self.future_inputs_vsn is not None:
            all_future_vars_projected = self._project_and_stack_inputs_for_vsn(
                None, f_cat_cov, f_cont_cov, None, self.future_cat_cov_embed_layers,
                self.future_cont_cov_projection_layers, batch_size=batch_size, seq_len=h_forecast, is_static=False)
            if all_future_vars_projected is None: raise ValueError("Future VSN error.")
            selected_future_features, future_vsn_w = self.future_inputs_vsn(all_future_vars_projected, c_s)
            self.future_vsn_weights = future_vsn_w
        # Without future-known inputs, the decoder receives zeros for h_forecast steps.
        else: selected_future_features = torch.zeros((batch_size, h_forecast, self.d_model), device=device)
        # The same static context initialises every LSTM layer.
        h_0_past = c_h.unsqueeze(0).repeat(self.lstm_hidden_layers,1,1); c_0_past = c_c.unsqueeze(0).repeat(self.lstm_hidden_layers,1,1)
        past_lstm_out, (h_n_past, c_n_past) = self.past_lstm(selected_past_features, (h_0_past, c_0_past))
        future_lstm_out: torch.Tensor
        # The future LSTM continues from the final state of the past LSTM.
        if self.future_lstm is not None: future_lstm_out, _ = self.future_lstm(selected_future_features, (h_n_past, c_n_past))
        else: future_lstm_out = torch.zeros_like(selected_future_features)
        # Gated skip connection around the LSTMs; the same GLU and LayerNorm serve past and future.
        phi_past = self.locality_enhancement_norm(selected_past_features + self.locality_enhancement_glu(past_lstm_out))
        phi_future = self.locality_enhancement_norm(selected_future_features + self.locality_enhancement_glu(future_lstm_out))
        phi_combined = torch.cat([phi_past, phi_future], dim=1)
        # Broadcast the enrichment context over every time step.
        expanded_c_e = c_e.unsqueeze(1).expand(-1, k_lookback + h_forecast, -1)
        theta = self.static_enrichment_grn(phi_combined, expanded_c_e)
        total_seq_len = k_lookback + h_forecast
        # Lower-triangular (causal) mask: a step may attend to itself and earlier steps only.
        attention_mask_tril = torch.tril(torch.ones(total_seq_len, total_seq_len, device=device, dtype=torch.bool))
        decoder_mask = attention_mask_tril.unsqueeze(0)
        beta_intermediate = self.multihead_attention(theta, theta, theta, mask=decoder_mask)
        self.attention_matrices = self.multihead_attention.attention_scores
        beta = self.attention_norm(theta + self.attention_gated_skip(beta_intermediate))
        psi_tilde_intermediate = self.position_wise_ff_grn(beta)
        # Second gated skip connection, back to the pre-attention sequence features.
        psi_tilde_final = self.decoder_block_norm(phi_combined + self.decoder_block_glu(psi_tilde_intermediate))
        # Only the steps after the lookback window are projected to quantile outputs.
        psi_tilde_final_future = psi_tilde_final[:, k_lookback:, :]
        return self.output_projection(psi_tilde_final_future)

def quantile_loss(predictions: torch.Tensor, targets: torch.Tensor, quantiles: List[float]) -> torch.Tensor:
    """Mean pinball (quantile) loss, averaged over the requested quantile levels.

    Args:
        predictions: Tensor whose last dimension holds one prediction per
            entry of ``quantiles``, in the same order.
        targets: Ground truth; a 2-D tensor gets a trailing singleton axis, and
            the result is then expanded to the shape of ``predictions``.
        quantiles: Quantile levels, one per slice of the last prediction axis.

    Returns:
        Scalar tensor: the per-quantile mean pinball losses, averaged.
    """
    if targets.ndim == 2:
        targets = targets.unsqueeze(-1)
    residuals = targets.expand_as(predictions) - predictions

    target_device = predictions.device
    accumulated = torch.tensor(0.0, device=target_device)
    for position, level in enumerate(quantiles):
        tau = torch.tensor(level, device=target_device)
        residual_at_level = residuals[..., position]
        # Pinball loss: under-prediction costs tau, over-prediction costs (1 - tau).
        pinball = torch.max((tau - 1) * residual_at_level, tau * residual_at_level)
        accumulated = accumulated + pinball.mean()
    return accumulated / len(quantiles)

def calculate_rmse(predictions: torch.Tensor, targets: torch.Tensor, scaler: Optional[StandardScaler] = None) -> float:
    """Root-mean-square error between predictions and targets.

    Args:
        predictions: Predicted values.
        targets: Reference values. If they have exactly one extra trailing
            axis of size 1 compared to ``predictions``, that axis is dropped.
        scaler: Optional object with ``inverse_transform``. When truthy, both
            inputs are flattened to a column, inverse-transformed and reshaped
            back before the error is computed.

    Returns:
        The RMSE as a Python float.

    Raises:
        ValueError: If the shapes still differ after the optional squeeze.
    """
    if predictions.shape != targets.shape:
        extra_singleton_axis = targets.ndim == predictions.ndim + 1 and targets.shape[-1] == 1
        if extra_singleton_axis:
            targets = targets.squeeze(-1)
        if predictions.shape != targets.shape:
            raise ValueError(f"Shape mismatch: preds {predictions.shape}, targets {targets.shape}")

    def _as_original_units(tensor: torch.Tensor) -> np.ndarray:
        # Detach from autograd, move to host memory and undo the scaling if requested.
        values = tensor.detach().cpu().numpy()
        if scaler:
            values = scaler.inverse_transform(values.reshape(-1, 1)).reshape(values.shape)
        return values

    difference = _as_original_units(predictions) - _as_original_units(targets)
    return float(np.sqrt(np.mean(difference ** 2)))

def create_dataloader_from_sequences(sequences_dict: Dict[str, np.ndarray],
                                    batch_size: int,
                                    expected_params: Dict,
                                    shuffle: bool = False) -> Optional[DataLoader]:
    """Wrap a dict of NumPy sequence arrays in a ``DataLoader``.

    Every batch is an 8-tuple in the argument order of the model's forward
    pass, followed by the target::

        (s_cat, s_cont, p_target, p_cat_cov, p_cont_cov,
         f_cat_cov, f_cont_cov, y_target)

    Inputs that are missing or empty are replaced by zero-width placeholder
    tensors so the tuple layout stays fixed. The time length of the
    placeholders is taken from the data itself (``x_past_target`` for past
    inputs, ``y_target`` for future inputs), not from module-level globals.

    Args:
        sequences_dict: Arrays keyed by ``x_static_cat``, ``x_past_target``,
            ``x_past_cov_cat``, ``x_past_cov_cont``, ``x_future_known_cat``,
            ``x_future_known_cont`` and ``y_target``. Only ``x_past_target``
            and ``y_target`` are mandatory.
        batch_size: Batch size handed to the ``DataLoader``.
        expected_params: Accepted for interface compatibility; not used.
        shuffle: Whether the ``DataLoader`` shuffles the samples.

    Returns:
        The ``DataLoader``, or ``None`` when the dict holds no data or the
        tensors disagree on the number of samples.

    Raises:
        ValueError: If ``x_past_target`` or ``y_target`` is missing/empty or
            lacks a time axis.
    """
    def _non_empty(key: str) -> Optional[np.ndarray]:
        # Treat a missing key and a zero-size array the same way.
        array = sequences_dict.get(key)
        return array if array is not None and array.size > 0 else None

    lookup_order = (
        'x_static_cat', 's_cont_placeholder', 'x_past_target', 'x_past_cov_cat',
        'x_past_cov_cont', 'x_future_known_cat', 'x_future_known_cont', 'y_target',
    )
    first_available = next(
        (array for array in map(_non_empty, lookup_order) if array is not None), None)
    if first_available is None:
        print("Warning: No data found to create DataLoader.")
        # A bare None (not a tuple) so that callers' `is None` checks work.
        return None
    num_samples = first_available.shape[0]

    p_target_data = _non_empty('x_past_target')
    if p_target_data is None:
        raise ValueError("x_past_target (p_target) must not be empty.")
    y_target_data = _non_empty('y_target')
    if y_target_data is None:
        raise ValueError("y_target must not be empty.")
    if p_target_data.ndim < 2 or y_target_data.ndim < 2:
        raise ValueError("x_past_target and y_target must have a time axis at dimension 1.")

    # Placeholder lengths come from the data rather than from external globals.
    lookback_len = p_target_data.shape[1]
    horizon_len = y_target_data.shape[1]

    def _tensor_or_placeholder(key: Optional[str], dtype: torch.dtype,
                               placeholder_shape: Tuple[int, ...]) -> torch.Tensor:
        # Real data when available, otherwise a zero-width tensor of the right rank.
        array = _non_empty(key) if key is not None else None
        if array is None:
            return torch.empty(*placeholder_shape, dtype=dtype)
        return torch.from_numpy(array).to(dtype)

    dataset_tensors = [
        _tensor_or_placeholder('x_static_cat', torch.long, (num_samples, 0)),
        # No static continuous inputs are produced upstream: always a placeholder.
        _tensor_or_placeholder(None, torch.float, (num_samples, 0)),
        torch.from_numpy(p_target_data).float(),
        _tensor_or_placeholder('x_past_cov_cat', torch.long, (num_samples, lookback_len, 0)),
        _tensor_or_placeholder('x_past_cov_cont', torch.float, (num_samples, lookback_len, 0)),
        _tensor_or_placeholder('x_future_known_cat', torch.long, (num_samples, horizon_len, 0)),
        _tensor_or_placeholder('x_future_known_cont', torch.float, (num_samples, horizon_len, 0)),
        torch.from_numpy(y_target_data).float(),
    ]

    # Every tensor must agree on the sample count, otherwise TensorDataset would fail.
    if not all(tensor.shape[0] == num_samples for tensor in dataset_tensors):
        print(f"Error: Sample count mismatch or no valid tensor. Num_samples: {num_samples}")
        return None

    dataset = TensorDataset(*dataset_tensors)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

if __name__ == '__main__':
    # Smoke-test driver: builds random dummy sequences with the real feature
    # layout, instantiates the TFT and runs a short train/validation loop.
    # The numbers it prints say nothing about real forecasting quality.

    # Kept as module-level names for any code that looks them up as globals.
    _LW = LOOKBACK_WINDOW
    _FH = FORECAST_HORIZON

    # Seed NumPy (dummy data) and PyTorch (weights, shuffling, dropout) so a
    # fresh Restart & Run All reproduces the same output.
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    model_params = {
        'target_dim': 1, 'static_categorical_input_dims': [13], 'static_continuous_input_dim': 0,
        'obs_categorical_input_dims': [], 'obs_continuous_input_dim': 9,
        'known_categorical_input_dims': [12, 31, 7], 'known_continuous_input_dim': 3,
    }
    num_train_samples = 73761
    num_val_samples = 10000

    def create_dummy_sequences(num_samples, params, lw, fh):
        """Return a dict of random arrays shaped like the real model inputs.

        Args:
            num_samples: Number of sequences to generate.
            params: Dict with the ``*_input_dims`` / ``*_input_dim`` entries of ``model_params``.
            lw: Lookback length (time steps of the past inputs).
            fh: Forecast horizon (time steps of the future inputs and the target).
        """
        seq = {}
        if params['static_categorical_input_dims']:
            seq['x_static_cat'] = np.random.randint(0, params['static_categorical_input_dims'][0], (num_samples, 1)).astype(np.int64)
        seq['x_past_target'] = np.random.randn(num_samples, lw, params['target_dim']).astype(np.float32)
        if params['obs_continuous_input_dim'] > 0:
            seq['x_past_cov_cont'] = np.random.randn(num_samples, lw, params['obs_continuous_input_dim']).astype(np.float32)
        else:
            seq['x_past_cov_cont'] = np.empty((num_samples, lw, 0), dtype=np.float32)

        if params['known_categorical_input_dims']:
            # One integer column per categorical variable, each within its own cardinality.
            fcc_list = [
                np.random.randint(0, card, (num_samples, fh, 1)).astype(np.int64)
                for card in params['known_categorical_input_dims']
            ]
            seq['x_future_known_cat'] = np.concatenate(fcc_list, axis=2)
        else:
            seq['x_future_known_cat'] = np.empty((num_samples, fh, 0), dtype=np.int64)

        if params['known_continuous_input_dim'] > 0:
            seq['x_future_known_cont'] = np.random.randn(num_samples, fh, params['known_continuous_input_dim']).astype(np.float32)
        else:
            seq['x_future_known_cont'] = np.empty((num_samples, fh, 0), dtype=np.float32)

        seq['y_target'] = np.random.randn(num_samples, fh, params['target_dim']).astype(np.float32)
        return seq

    train_sequences = create_dummy_sequences(num_train_samples, model_params, _LW, _FH)
    val_sequences = create_dummy_sequences(num_val_samples, model_params, _LW, _FH)

    # Stand-in for the scaler fitted on the real target column.
    dummy_target_data_for_scaler = np.random.randn(100,1)
    target_scaler = StandardScaler().fit(dummy_target_data_for_scaler)

    d_model_hyperparam = 64
    num_heads_hyperparam = 4
    lstm_layers_hyperparam = 1
    dropout_rate_hyperparam = 0.1
    output_quantiles_param = [0.1, 0.5, 0.9]

    print("--- Starting TFT model instantiation ---")
    tft_model_instance = TemporalFusionTransformer(
        target_dim=model_params['target_dim'],
        static_categorical_input_dims=model_params['static_categorical_input_dims'],
        static_continuous_input_dim=model_params['static_continuous_input_dim'],
        obs_categorical_input_dims=model_params['obs_categorical_input_dims'],
        obs_continuous_input_dim=model_params['obs_continuous_input_dim'],
        known_categorical_input_dims=model_params['known_categorical_input_dims'],
        known_continuous_input_dim=model_params['known_continuous_input_dim'],
        d_model=d_model_hyperparam, num_attention_heads=num_heads_hyperparam,
        lstm_hidden_layers=lstm_layers_hyperparam, dropout_rate=dropout_rate_hyperparam,
        lookback_window_size=_LW, forecast_horizon_size=_FH,
        output_quantiles=output_quantiles_param
    )
    print("\nTFT model instance with your parameters created successfully.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tft_model_instance.to(device)

    batch_size_training = 256
    print("\nCreating DataLoader for training set...")
    train_dataloader = create_dataloader_from_sequences(train_sequences, batch_size_training, model_params, shuffle=True)
    print("Creating DataLoader for validation set...")
    val_dataloader = create_dataloader_from_sequences(val_sequences, batch_size_training, model_params, shuffle=False)

    # isinstance (rather than `is None`) also rejects any non-None failure value.
    if not isinstance(train_dataloader, DataLoader) or not isinstance(val_dataloader, DataLoader):
        print("Error: Training or validation DataLoader not created. Program will stop.")
        # SystemExit stops the cell/script; exit() would shut down a notebook kernel.
        raise SystemExit(1)

    num_epochs = 10
    learning_rate = 0.0001
    optimizer = optim.Adam(tft_model_instance.parameters(), lr=learning_rate)

    # RMSE is computed on the median forecast, so it needs the 0.5 quantile.
    median_quantile_idx = -1
    if 0.5 in output_quantiles_param:
        median_quantile_idx = output_quantiles_param.index(0.5)
    else:
        print("Warning: Quantile 0.5 for RMSE calculation is not in the output quantiles list.")

    def run_model_on_batch(batch_tensors):
        """Move one batch to `device`, run the model and return (predictions, y_true)."""
        # Zero-width placeholder tensors (numel == 0) stand for "input not present".
        s_cat_b, s_cont_b, p_target_b, p_cat_cov_b, p_cont_cov_b, \
        f_cat_cov_b, f_cont_cov_b, y_true_b = [
            t.to(device) if t is not None and t.numel() > 0 else None for t in batch_tensors]
        batch_predictions = tft_model_instance(
            s_cat=s_cat_b if model_params['static_categorical_input_dims'] else None,
            s_cont=s_cont_b if model_params['static_continuous_input_dim'] > 0 else None,
            p_target=p_target_b,
            p_cat_cov=p_cat_cov_b if model_params['obs_categorical_input_dims'] else None,
            p_cont_cov=p_cont_cov_b if model_params['obs_continuous_input_dim'] > 0 else None,
            f_cat_cov=f_cat_cov_b if model_params['known_categorical_input_dims'] else None,
            f_cont_cov=f_cont_cov_b if model_params['known_continuous_input_dim'] > 0 else None
        )
        return batch_predictions, y_true_b

    print(f"\n--- Starting training for {num_epochs} epochs on device {device} ---")
    best_val_rmse = float('inf')
    # About five progress lines per epoch; max() avoids a zero modulus when the
    # loader has fewer than five batches.
    log_every_n_batches = max(1, len(train_dataloader) // 5)

    for epoch in range(num_epochs):
        tft_model_instance.train()
        total_train_loss = 0
        for batch_idx, batch_tensors_list in enumerate(train_dataloader):
            optimizer.zero_grad()
            predictions, y_true_b = run_model_on_batch(batch_tensors_list)

            loss = quantile_loss(predictions, y_true_b, output_quantiles_param)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

            if batch_idx > 0 and batch_idx % log_every_n_batches == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_dataloader)}, Train Loss: {loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Epoch {epoch+1} >> Average training loss: {avg_train_loss:.4f}")

        if val_dataloader and median_quantile_idx != -1:
            tft_model_instance.eval()
            total_val_loss = 0
            all_val_preds_median_list = []
            all_val_targets_list = []
            with torch.no_grad():
                for batch_tensors_val in val_dataloader:
                    val_predictions, y_true_v = run_model_on_batch(batch_tensors_val)
                    val_loss = quantile_loss(val_predictions, y_true_v, output_quantiles_param)
                    total_val_loss += val_loss.item()

                    all_val_preds_median_list.append(val_predictions[:, :, median_quantile_idx].cpu())
                    all_val_targets_list.append(y_true_v.cpu())

            avg_val_loss = total_val_loss / len(val_dataloader)
            val_preds_median_all = torch.cat(all_val_preds_median_list, dim=0)
            val_targets_all = torch.cat(all_val_targets_list, dim=0)

            # RMSE is reported in original target units via the scaler's inverse transform.
            current_val_rmse = calculate_rmse(val_preds_median_all, val_targets_all, scaler=target_scaler)

            print(f"Epoch {epoch+1} >> Average validation loss: {avg_val_loss:.4f}, Validation RMSE: {current_val_rmse:.4f}")

            if current_val_rmse < best_val_rmse:
                best_val_rmse = current_val_rmse
                # Checkpointing is intentionally disabled for this dummy-data run:
                # torch.save(tft_model_instance.state_dict(), "best_tft_model_SAVES.pth")
                print(f"*** Better model with RMSE {best_val_rmse:.4f} saved at epoch {epoch+1} (hypothetical). ***")

    print("--- Training finished ---")

Initial data loaded. Shape: (105979, 24)
Initial columns: ['ایستگاه', 'روزها', 'ژولیوسی شمسی', 'تاریخ شمسی', 'سال شمسی', 'ماه شمسی', 'روز شمسی', 'ژولیوسی میلادی', 'تاریخ میلادی', 'سال میلادی', 'ماه میلادی', 'روز میلادی', 'ماكزيمم دما', 'مينيمم دما', 'ميانگين دما', 'ماكزيمم رطوبت', 'مينيمم رطوبت', 'ميانگين رطوبت', 'بارندگي', 'تبخير', 'ساعات آفتابي', 'سمت باد', 'حداكثر سرعت باد', 'ميانگين فشار QFF']
Columns after initial drop (9 columns): ['ایستگاه', 'ژولیوسی شمسی', 'سال شمسی', 'ماه شمسی', 'روز شمسی', 'ماكزيمم دما', 'مينيمم دما', 'ميانگين دما', 'ماكزيمم رطوبت', 'مينيمم رطوبت', 'ميانگين رطوبت', 'بارندگي', 'ساعات آفتابي', 'سمت باد', 'حداكثر سرعت باد']
NaN values in column 'سمت باد' were replaced with 0 (before sin/cos).
Data shape after removing initial NaNs in critical columns: (105979, 15)
Data shape after creating and cleaning derived_day_of_week: (105975, 16)
Number of unique stations encoded: 13
Column 'سمت باد' was cyclically transformed.
Data shape after final sorting and date clean