# In [None]:
import sys
sys.path.append("/content/PatchTST/PatchTST_supervised")

import os
import os.path as osp
import uuid
import random
import argparse
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
import psutil
from typing import Optional, Tuple, Dict, List
from models.PatchTST import Model as PatchTST
from datetime import datetime

# Seed Python, NumPy and PyTorch (CPU, plus every CUDA device when one is
# present) with one shared value so repeated runs draw the same random numbers.
_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(_SEED)

# Run identifier embedded in output file names (e.g. the .npz archive and the
# combined results CSV written further down).
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

class Config:
    """
    Base container for forecasting-model hyperparameters.

    Subclasses provide a dictionary of defaults via ``default_params``. A
    keyword argument whose name matches one of those defaults replaces it;
    keyword arguments that do not name a default are ignored. Every resolved
    parameter becomes an instance attribute.

    Attributes:
        seq_len (int): Length of the input (look-back) window.
        pred_len (int): Length of the forecast horizon.
        enc_in (int): Number of input channels/features.
        device (torch.device): ``cuda`` when available, otherwise ``cpu``.

    Methods:
        default_params():
            Return the mapping of default parameters; subclasses must override it.
    """

    def __init__(self, seq_len, pred_len, enc_in, **kwargs):
        self.seq_len, self.pred_len, self.enc_in = seq_len, pred_len, enc_in

        # default_params may read the three attributes above, so they are set first.
        resolved = {
            name: kwargs.get(name, default)
            for name, default in self.default_params().items()
        }
        for name, value in resolved.items():
            setattr(self, name, value)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def default_params(self):
        """Return the default parameter mapping; must be overridden by subclasses."""
        raise NotImplementedError

class HourlyConfig(Config):
    """
    Default hyperparameters for forecasting hourly series with PatchTST.

    Any default may be overridden by passing it as a keyword argument to the
    constructor. The defaults cover encoder size, patching, dropout,
    Reversible Instance Normalization / decomposition switches, and the
    training loop (batch size, learning rate, weight decay, warm-up, patience).

    Attributes inherited from Config:
        seq_len (int): Length of the input window.
        pred_len (int): Forecast horizon.
        enc_in (int): Number of input channels.
        device (torch.device): CUDA when available, otherwise CPU.
    """

    def default_params(self):
        """
        Return the hourly defaults as a dict.

        Notes:
            - patch_len and stride are both min(16, seq_len), so a patch never
              exceeds the input window.
            - patience (10) is larger than epochs (4); run_experiment increments
              its patience counter at most once per epoch, so early stopping
              cannot fire with these defaults.
            - optimizer, scheduler, T_0 and T_mult are not read by
              run_experiment in this file, which always builds AdamW with a
              warm-up/cosine LambdaLR.
        """
        patch = min(16, self.seq_len)
        return dict(
            # Transformer encoder size
            e_layers=4, d_model=256, n_heads=4, d_ff=128,
            # Patching, capped at the input window length
            patch_len=patch, stride=patch,
            # Dropout rates
            dropout=0.1, attn_dropout=0.1, fc_dropout=0.1, head_dropout=0.0,
            kernel_size=25, padding_patch='end',
            # Head and normalisation switches
            individual=False, head_type='prediction', revin=True, affine=True,
            subtract_last=False, decomposition=False,
            # Training loop
            batch_size=16, epochs=4, lr=1e-3, weight_decay=1e-2,
            optimizer='adamw', scheduler='cosine', T_0=5, T_mult=2,
            grad_clip=0.1, patience=10, warmup_epochs=0,
        )

class DailyConfig(Config):
    """
    Default hyperparameters for forecasting daily series with PatchTST.

    Compared with the hourly defaults this uses 7-step patches, a 7-wide
    kernel, somewhat higher dropout, more epochs, a shorter patience and a
    3-epoch learning-rate warm-up. Any default may be overridden by keyword
    argument.

    Attributes inherited from Config:
        seq_len (int): Length of the input window.
        pred_len (int): Forecast horizon.
        enc_in (int): Number of input channels.
        device (torch.device): CUDA when available, otherwise CPU.
    """

    def default_params(self):
        """
        Return the daily defaults as a dict.

        Notes:
            - patch_len and stride are both min(7, seq_len), so a patch never
              exceeds the input window.
            - warmup_epochs=3 is turned into warm-up scheduler steps by
              run_experiment (warmup_epochs * batches per epoch).
            - optimizer, scheduler, T_0 and T_mult are not read by
              run_experiment in this file, which always builds AdamW with a
              warm-up/cosine LambdaLR.
        """
        patch = min(7, self.seq_len)
        return dict(
            # Transformer encoder size
            e_layers=4, d_model=256, n_heads=4, d_ff=128,
            # Patching, capped at the input window length
            patch_len=patch, stride=patch,
            # Dropout rates
            dropout=0.2, attn_dropout=0.1, fc_dropout=0.1, head_dropout=0.1,
            kernel_size=7, padding_patch='end',
            # Head and normalisation switches
            individual=False, head_type='prediction', revin=True, affine=True,
            subtract_last=False, decomposition=False,
            # Training loop
            batch_size=16, epochs=10, lr=1e-3, weight_decay=1e-3,
            optimizer='adamw', scheduler='cosine', T_0=3, T_mult=1,
            grad_clip=0.1, patience=5, warmup_epochs=3,
        )

def scale_data(
    df: pd.DataFrame,
    fit_rows: Optional[int] = None,
) -> Tuple[pd.DataFrame, Dict[str, StandardScaler]]:
    """
    Standardise numeric columns (zero mean, unit variance), skipping calendar features.

    Every numeric column that is not one of the known time/cyclic features gets
    its own StandardScaler. ``df`` is modified in place and also returned.

    By default each scaler is fitted on *all* rows. With a chronological
    train/val/test split this lets validation and test statistics leak into
    the normalisation. Pass ``fit_rows`` (e.g. the number of training rows) to
    fit on ``df.iloc[:fit_rows]`` only, while still transforming every row.

    Args:
        df (pd.DataFrame): Input frame with numeric and possibly non-numeric columns.
        fit_rows (Optional[int]): Number of leading rows used to fit each scaler,
            or ``None`` (default) to fit on the whole frame.

    Returns:
        Tuple[pd.DataFrame, Dict[str, StandardScaler]]:
            - The scaled DataFrame (same object as ``df``).
            - Mapping from column name to its fitted StandardScaler, e.g. for a
              later ``inverse_transform``.

    Raises:
        ValueError: If ``fit_rows`` is given but is not positive.

    Notes:
        The following time-based or cyclic features are excluded from scaling:
        'dayofweek', 'month', 'is_weekend', 'dayofyear',
        'dayofweek_sin', 'dayofweek_cos',
        'month_sin', 'month_cos',
        'dayofyear_sin', 'dayofyear_cos'.
    """
    # Time-related features to exclude from scaling
    skip_cols = {
        'dayofweek', 'month', 'is_weekend', 'dayofyear',
        'dayofweek_sin', 'dayofweek_cos',
        'month_sin', 'month_cos',
        'dayofyear_sin', 'dayofyear_cos'
    }
    if fit_rows is not None and fit_rows <= 0:
        raise ValueError(f"fit_rows must be a positive row count, got {fit_rows}")

    scalers = {}
    for col in df.select_dtypes('number').columns:
        if col in skip_cols:
            continue
        sc = StandardScaler()
        # Fit on the requested prefix only (or all rows), then transform every row.
        fit_frame = df[[col]] if fit_rows is None else df[[col]].iloc[:fit_rows]
        sc.fit(fit_frame)
        df[[col]] = sc.transform(df[[col]])
        scalers[col] = sc
    return df, scalers

def split_and_window(df: pd.DataFrame, seq_len, pred_len) -> Tuple:
    """
    Split a time-ordered DataFrame 60/20/20 into train/val/test and slice each
    part into overlapping (input, target) windows.

    The 'date' column is dropped; the remaining columns become channels. Each
    split is windowed independently with step 1: window ``i`` uses rows
    ``[i, i+seq_len)`` as input and ``[i+seq_len, i+seq_len+pred_len)`` as
    target, so a segment of ``L`` rows yields ``L - seq_len - pred_len + 1``
    windows (zero when the segment is too short).

    Args:
        df (pd.DataFrame): Input DataFrame with a 'date' column and numerical features.
        seq_len (int): Length of input sequences (historical window).
        pred_len (int): Length of prediction sequences (forecast horizon).

    Returns:
        Tuple:
            - (X_tr, Y_tr): Training pairs, shapes (N, seq_len, C) and (N, pred_len, C).
            - (X_va, Y_va): Validation pairs, same layout.
            - (X_te, Y_te): Test pairs, same layout.
            - List[str]: Names of feature columns used (excluding 'date').

    Notes:
        Data is split by row position, without shuffling. Arrays are float32.
        An empty split keeps its 3-D shape, e.g. (0, seq_len, C).
    """
    features = df.drop(columns=['date'])
    columns = features.columns.tolist()
    arr = features.values
    n = len(arr)
    n_features = arr.shape[1]
    t1, t2 = int(0.6 * n), int(0.8 * n)

    print(f"[Split] Total: {n}, Train: {t1}, Val: {t2 - t1}, Test: {n - t2}")

    def window(data):
        # +1: the last window's target may end exactly on the segment's final row.
        n_windows = max(0, len(data) - seq_len - pred_len + 1)
        X = np.empty((n_windows, seq_len, n_features), np.float32)
        Y = np.empty((n_windows, pred_len, n_features), np.float32)
        for i in range(n_windows):
            X[i] = data[i:i + seq_len]
            Y[i] = data[i + seq_len:i + seq_len + pred_len]
        return X, Y

    X_tr, Y_tr = window(arr[:t1])
    X_va, Y_va = window(arr[t1:t2])
    X_te, Y_te = window(arr[t2:])

    print(f"[Windowed] Train X: {X_tr.shape}, Y: {Y_tr.shape}")
    print(f"[Windowed] Val   X: {X_va.shape}, Y: {Y_va.shape}")
    print(f"[Windowed] Test  X: {X_te.shape}, Y: {Y_te.shape}")

    return (X_tr, Y_tr), (X_va, Y_va), (X_te, Y_te), columns


def prepare_loaders(Xs, Ys, cfg) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Wrap the train/val/test NumPy arrays in PyTorch DataLoaders.

    Args:
        Xs (Tuple[np.ndarray, np.ndarray, np.ndarray]): Input arrays for train, val, and test.
        Ys (Tuple[np.ndarray, np.ndarray, np.ndarray]): Target arrays for train, val, and test.
        cfg (Config): Configuration object providing ``batch_size``.

    Returns:
        Tuple[DataLoader, DataLoader, DataLoader]:
            - train_loader: shuffled every epoch.
            - val_loader: fixed order (not shuffled).
            - test_loader: fixed order (not shuffled).

    Notes:
        Arrays are wrapped with ``torch.as_tensor``, which shares memory with
        NumPy arrays instead of copying them (relevant for large windowed data).
        Validation is only evaluated, never trained on; its loss is an average
        of per-batch means, so shuffling would change which samples fall into
        the short final batch and make the value fluctuate between epochs.
    """
    def to_loader(x, y, shuffle):
        dataset = TensorDataset(torch.as_tensor(x), torch.as_tensor(y))
        return DataLoader(dataset, batch_size=cfg.batch_size, shuffle=shuffle)

    train_loader = to_loader(Xs[0], Ys[0], shuffle=True)
    val_loader = to_loader(Xs[1], Ys[1], shuffle=False)
    test_loader = to_loader(Xs[2], Ys[2], shuffle=False)

    print(f"[Loaders] Train batches: {len(train_loader)}, Batch size: {cfg.batch_size}")
    print(f"[Loaders] Val   batches: {len(val_loader)}")
    print(f"[Loaders] Test  batches: {len(test_loader)}")

    return train_loader, val_loader, test_loader

def compute_channel_weights_from_val(model, val_loader, device) -> torch.Tensor:
    """
    Derive per-channel loss weights from the model's validation RMSE.

    The model is run (in eval mode, without gradients) over every validation
    batch; the squared error is averaged over the first two axes, leaving one
    RMSE per channel on the last axis. Each channel's weight is the inverse of
    its RMSE, and the weights are rescaled to have mean 1, so better-predicted
    channels weigh more.

    Args:
        model (torch.nn.Module): Model to evaluate; left in eval mode on return.
        val_loader (DataLoader): Yields (inputs, targets) batches.
        device (torch.device): Where inputs are sent before the forward pass.

    Returns:
        torch.Tensor: float32 tensor of shape (num_channels,), mean 1.

    Notes:
        ``1e-6`` is added to the RMSE so a perfectly predicted channel does not
        cause a division by zero.
    """
    model.eval()
    pred_chunks, true_chunks = [], []
    with torch.no_grad():
        for inputs, targets in val_loader:
            pred_chunks.append(model(inputs.to(device)).cpu().numpy())
            true_chunks.append(targets.numpy())

    predictions = np.concatenate(pred_chunks, axis=0)
    ground_truth = np.concatenate(true_chunks, axis=0)

    per_channel_rmse = np.sqrt(np.mean(np.square(ground_truth - predictions), axis=(0, 1)))
    inverse = 1.0 / (per_channel_rmse + 1e-6)
    return torch.tensor(inverse / inverse.mean(), dtype=torch.float32)


def train_one_epoch(model, loader, opt, sched, channel_weights, cfg):
    """
    Run one training epoch with a channel-weighted Smooth L1 (Huber) loss.

    For every batch: forward pass, element-wise Smooth L1 loss (``beta=1.0``)
    multiplied by the per-channel weights (broadcast over the last axis) and
    averaged, backward pass, gradient-norm clipping to ``cfg.grad_clip``, an
    optimizer step and a scheduler step. The scheduler is therefore stepped
    once per batch, not once per epoch.

    Args:
        model (torch.nn.Module): Model to train (switched to train mode).
        loader (DataLoader): Yields (inputs, targets) batches.
        opt (torch.optim.Optimizer): Optimizer updating ``model``.
        sched: Learning-rate scheduler stepped after every optimizer step.
        channel_weights (torch.Tensor): 1-D per-channel loss weights.
        cfg (Config): Provides ``device`` and ``grad_clip``.

    Returns:
        float: Mean of the per-batch losses over the epoch.
    """
    model.train()
    # Invariant across batches: weights shaped to broadcast over (batch, horizon, channel).
    weight_view = channel_weights.to(cfg.device).view(1, 1, -1)
    running = 0
    for step, (inputs, targets) in enumerate(loader):
        if step == 0:
            print(f"[Train Batch] xb: {inputs.shape}, yb: {targets.shape}")
        inputs = inputs.to(cfg.device)
        targets = targets.to(cfg.device)

        opt.zero_grad()
        elementwise = torch.nn.functional.smooth_l1_loss(
            model(inputs), targets, reduction='none', beta=1.0
        )
        batch_loss = (elementwise * weight_view).mean()
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()
        sched.step()

        running += batch_loss.item()
    return running / len(loader)


def validate(model, loader, channel_weights, cfg):
    """
    Return the channel-weighted Smooth L1 loss averaged over the batches of ``loader``.

    The model is put in eval mode and run without gradients. For each batch the
    element-wise Smooth L1 loss (``beta=1.0``) is multiplied by the per-channel
    weights (broadcast over the last axis) and averaged; the result is the
    mean of those per-batch values.

    Args:
        model (torch.nn.Module): Model to evaluate.
        loader (DataLoader): Yields (inputs, targets) batches.
        channel_weights (torch.Tensor): 1-D per-channel loss weights.
        cfg (Config): Provides ``device``.

    Returns:
        float: Average weighted loss per batch.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(cfg.device), targets.to(cfg.device)
            elementwise = torch.nn.functional.smooth_l1_loss(
                model(inputs), targets, reduction='none', beta=1.0
            )
            weighted = elementwise * channel_weights.to(cfg.device).view(1, 1, -1)
            batch_losses.append(weighted.mean().item())
    return sum(batch_losses) / len(loader)

def plot_prediction_sequence(
    raw_preds, raw_trues, columns,
    channel_idx: int, t0: int,
    fname: Optional[str] = None,
    align_future: bool = False,
    x_unit: str = 'Days',
):
    """
    Plot one test window (truth vs. prediction), followed by the next horizon's forecast.

    Window ``t0`` is drawn over x = 0..H-1, where H = ``raw_preds.shape[1]``.
    The dotted "future" curve is the prediction of window ``t0 + H``, drawn
    over x = H..2H-1. This relies on consecutive windows starting one time
    step apart and being in chronological order (as built by
    ``split_and_window`` and read through an unshuffled test loader). Window
    ``t0 + 1`` would be shifted by a single step only and overlap almost
    entirely with window ``t0``. If window ``t0 + H`` does not exist, no
    future curve is drawn.

    Parameters:
        raw_preds: np.ndarray, shape (T, H, C) - model predictions
        raw_trues: np.ndarray, shape (T, H, C) - ground truth values
        columns: sequence of channel names aligned with the last axis, shown
            in the title (``None`` falls back to the channel index)
        channel_idx: int - index of the channel to plot
        t0: int - index of the window to plot
        fname: Optional[str] - if given, the figure is also saved to this path
        align_future: bool - shift the future forecast so it starts where the
            window-``t0`` prediction ends (visual continuity only)
        x_unit: str - unit shown in the x-axis label (e.g. 'Days', 'Hours')
    """
    H = raw_preds.shape[1]

    y_pred_test = raw_preds[t0, :, channel_idx]
    y_true_test = raw_trues[t0, :, channel_idx]
    x_test = np.arange(0, H)

    # Next non-overlapping horizon: window t0 + H starts exactly H steps after window t0.
    future_idx = t0 + H
    y_pred_future = None
    if future_idx < raw_preds.shape[0]:
        y_pred_future = raw_preds[future_idx, :, channel_idx]
        if align_future:
            y_pred_future = y_pred_future + (y_pred_test[-1] - y_pred_future[0])
    x_future = np.arange(H, 2 * H)

    # Metrics for the plotted window
    rmse = np.sqrt(np.mean((y_pred_test - y_true_test) ** 2))
    mean_val = np.mean(y_true_test)
    channel_name = columns[channel_idx] if columns is not None else channel_idx

    plt.figure(figsize=(12, 4))
    plt.plot(x_test, y_true_test, label='True (Test)', color='blue')
    plt.plot(x_test, y_pred_test, label='Predicted (Test)', color='orange')
    if y_pred_future is not None:
        plt.plot(x_future, y_pred_future, label='Forecast (Future)', color='green', linestyle='dotted')

    plt.title(f"{channel_name} — RMSE: {rmse:.2f}, Mean: {mean_val:.2f}")
    plt.xlabel(f'Time Index ({x_unit})')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    if fname:
        plt.savefig(fname)
    plt.show()


def filter_high_rmse_channels(raw_preds, raw_trues, columns, threshold=None):
    """
    Optionally drop channels whose RMSE exceeds ``threshold``, then report global metrics.

    Args:
        raw_preds: Predictions indexed (window, horizon, channel).
        raw_trues: Ground truth with the same shape as ``raw_preds``.
        columns: Channel names aligned with the last axis.
        threshold: Per-channel RMSE cut-off (pooled over windows and horizon).
            ``None`` keeps every channel.

    Returns:
        Tuple ``(raw_preds, raw_trues, columns, mae, rmse, r2)`` restricted to
        the kept channels. SMAPE is printed but not returned.

    Raises:
        ValueError: If ``threshold`` removes every channel, leaving nothing to score.
    """
    if threshold is not None:
        per_series_rmse = np.sqrt(((raw_trues - raw_preds) ** 2).mean(axis=(0, 1)))
        bad_idx = [i for i, rmse in enumerate(per_series_rmse) if rmse > threshold]

        if bad_idx:
            print(f"Dropping {len(bad_idx)} channels with RMSE > {threshold}")
            print(f"Dropped columns: {[columns[i] for i in bad_idx]}")

            bad_set = set(bad_idx)  # O(1) membership instead of rescanning the list per channel
            keep_idx = [i for i in range(len(columns)) if i not in bad_set]
            if not keep_idx:
                raise ValueError(
                    f"All {len(columns)} channels have RMSE > {threshold}; nothing left to evaluate"
                )
            raw_preds = raw_preds[:, :, keep_idx]
            raw_trues = raw_trues[:, :, keep_idx]
            columns = [columns[i] for i in keep_idx]

    flat_df = pd.DataFrame(raw_trues.reshape(-1, len(columns)), columns=columns)
    print(f"\nFiltered dataframe shape: {flat_df.shape} (rows x columns)")
    print("Column means after filtering:")
    print(flat_df.mean())

    # Reuse the shared metric helper so this report matches the other global metrics.
    mae, rmse, r2, smape_score = compute_global_metrics(raw_preds, raw_trues)
    print(f"Global → MAE={mae:.3f}, RMSE={rmse:.3f}, R²={r2:.4f}, SMAPE={smape_score:.2f}%")

    return raw_preds, raw_trues, columns, mae, rmse, r2

# --- Multi-sample forecast plot (first windows, one figure per feature) ---
def plot_forecast(preds, trues, feature_idxs=(2,), save_path="/content/logs", num_samples=3):
    """
    Plot true vs. predicted horizons for the first windows, one figure per feature.

    Args:
        preds: Predictions indexed (window, horizon, feature).
        trues: Ground truth with the same layout.
        feature_idxs: Feature indices to plot, one figure each. Defaults to a
            tuple to avoid a shared mutable default; any iterable works.
        save_path: If truthy, each figure is saved to
            ``f"{osp.splitext(save_path)[0]}_feature{idx}_multi.png"`` and then
            closed instead of shown. With the default ``"/content/logs"`` this
            yields e.g. ``/content/logs_feature2_multi.png``.
        num_samples: Number of leading windows stacked as subplots, capped at
            the number of windows available. Nothing is plotted if it is <= 0.
    """
    num_samples = min(num_samples, preds.shape[0])
    if num_samples <= 0:
        return

    for feature_idx in feature_idxs:
        fig = plt.figure(figsize=(10, 6))

        for i in range(num_samples):
            true_series = trues[i, :, feature_idx]
            pred_series = preds[i, :, feature_idx]

            rmse = np.sqrt(np.mean((pred_series - true_series) ** 2))
            mean_val = np.mean(true_series)

            plt.subplot(num_samples, 1, i + 1)
            plt.plot(true_series, label="True")
            plt.plot(pred_series, label="Pred")
            plt.title(f"Sample {i} | Feature {feature_idx} | RMSE: {rmse:.2f}, Mean: {mean_val:.2f}")
            plt.legend(loc='upper right')

        plt.tight_layout()

        if save_path:
            fname = f"{osp.splitext(save_path)[0]}_feature{feature_idx}_multi.png"
            plt.savefig(fname)
            # Saved figures are never shown; close them so they do not pile up in memory.
            plt.close(fig)
        else:
            plt.show()




def create_scheduler_with_warmup(optimizer, total_steps, warmup_steps):
    """
    Build a LambdaLR with linear warm-up followed by cosine decay.

    The LR multiplier rises linearly from 0 to 1 over ``warmup_steps`` scheduler
    steps, then follows a half cosine from 1 down to 0 at ``total_steps``. The
    decay progress is clamped at 1, so stepping past ``total_steps`` keeps the
    multiplier at 0 instead of letting the cosine climb back up.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose base LR is scaled.
        total_steps (int): Step at which the cosine reaches 0.
        warmup_steps (int): Number of warm-up steps (0 disables warm-up).

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        progress = min(progress, 1.0)
        return 0.5 * (1.0 + np.cos(np.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def smape(y_true, y_pred):
    """
    Symmetric mean absolute percentage error, in percent.

    Each term is ``|y_pred - y_true| / ((|y_true| + |y_pred|) / 2)``. The
    denominator is floored at 1e-8, so a pair where both values are 0
    contributes 0 instead of NaN. Inputs must support element-wise arithmetic
    (e.g. NumPy arrays).
    """
    abs_error = np.abs(y_pred - y_true)
    half_sum = np.maximum((np.abs(y_true) + np.abs(y_pred)) / 2.0, 1e-8)
    return 100 * np.mean(abs_error / half_sum)

def compute_global_metrics(raw_preds, raw_trues):
    """
    Pool every element of the two arrays and score them together.

    Returns:
        Tuple (mae, rmse, r2, smape_percent) computed over all windows,
        horizon steps and channels at once.
    """
    y_true = raw_trues.flatten()
    y_pred = raw_preds.flatten()
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(((raw_trues - raw_preds) ** 2).mean())
    r2 = r2_score(y_true, y_pred)
    smape_score = smape(y_true, y_pred)
    return mae, rmse, r2, smape_score

def compute_per_channel_metrics(raw_preds, raw_trues):
    """
    Per-channel mean, RMSE and normalised RMSE, pooling the first two axes.

    NRMSE is ``RMSE / |mean of truth|``. Channels whose true mean is exactly 0
    get ``inf`` rather than triggering a division by zero.

    Returns:
        Tuple (mean_true, rmse, nrmse), each a 1-D array over the last axis.
    """
    pooled_axes = (0, 1)
    channel_mean = raw_trues.mean(axis=pooled_axes)
    channel_rmse = np.sqrt(np.square(raw_preds - raw_trues).mean(axis=pooled_axes))

    channel_nrmse = np.full_like(channel_rmse, np.inf)
    np.divide(channel_rmse, np.abs(channel_mean), out=channel_nrmse, where=channel_mean != 0)
    return channel_mean, channel_rmse, channel_nrmse

def categorize_nrmse(nrmse):
    """
    Map a normalised RMSE to a quality label.

    Bands (upper bounds exclusive): < 0.1 'Excellent', < 0.3 'Good',
    < 0.5 'Fair', < 1.0 'Poor'. Anything else, including inf and NaN, is 'Bad'.
    """
    bands = ((0.1, 'Excellent'), (0.3, 'Good'), (0.5, 'Fair'), (1.0, 'Poor'))
    for upper, label in bands:
        if nrmse < upper:
            return label
    return 'Bad'

def generate_stats_df(columns, mean_true, rmse, nrmse):
    """
    Tabulate per-channel statistics with a quality label derived from NRMSE.

    Returns:
        pd.DataFrame with columns 'column', 'mean', 'rmse', 'nrmse', 'category'
        (one row per channel, in input order).
    """
    table = {
        "column": columns,
        "mean": mean_true,
        "rmse": rmse,
        "nrmse": nrmse,
    }
    table["category"] = list(map(categorize_nrmse, nrmse))
    return pd.DataFrame(table)



def _inverse_scale(arr, scalers, columns):
    """Undo per-column standard scaling on an array indexed (window, horizon, channel)."""
    restored = np.zeros_like(arr)
    n_windows, horizon = arr.shape[0], arr.shape[1]
    for i, col in enumerate(columns):
        if col in scalers:
            flat = arr[:, :, i].reshape(-1, 1)
            restored[:, :, i] = scalers[col].inverse_transform(flat).reshape(n_windows, horizon)
        else:
            # Unscaled channels (e.g. calendar features) are copied unchanged.
            restored[:, :, i] = arr[:, :, i]
    return restored


def evaluate_and_plot(model, test_loader, scalers, columns, cfg, basename, rmse_threshold=None):
    """
    Evaluate ``model`` on the test loader in original units, report metrics and plot.

    Predictions and targets are inverse-scaled with ``scalers``. Engineered
    calendar features are excluded from all metrics. Global metrics use every
    remaining channel. If ``rmse_threshold`` is set, channels above it are
    dropped before the per-channel statistics, which are written to
    ``logs/{basename}_per_channel_stats_<timestamp>.csv``.

    Args:
        model: Trained forecaster; its outputs are compared element-wise with
            the loader's targets.
        test_loader: Yields (inputs, targets) batches in chronological order.
        scalers: Column name -> fitted scaler (as returned by scale_data).
        columns: Channel names aligned with the last axis.
        cfg: Provides ``device``.
        basename: Prefix for the stats CSV.
        rmse_threshold: Optional per-channel RMSE cut-off for the per-channel report.

    Returns:
        (raw_preds, raw_trues, columns, result): inverse-scaled test arrays over
        all channels, the channel names, and a dict with mae, rmse, r2, smape,
        nrmse_per_channel and the matching ``columns``.
    """
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for xb, yb in test_loader:
            preds.append(model(xb.to(cfg.device)).cpu().numpy())
            trues.append(yb.numpy())
    preds, trues = np.concatenate(preds), np.concatenate(trues)

    print(f"[Evaluation] Preds shape: {preds.shape}")
    print(f"[Evaluation] Trues shape: {trues.shape}")

    # --- Inverse Transform ---
    raw_preds = _inverse_scale(preds, scalers, columns)
    raw_trues = _inverse_scale(trues, scalers, columns)

    # --- Exclude engineered features from metrics ---
    engineered_cols = {
        'dayofweek', 'month', 'is_weekend', 'dayofyear',
        'dayofweek_sin', 'dayofweek_cos',
        'month_sin', 'month_cos',
        'dayofyear_sin', 'dayofyear_cos'
    }
    metric_mask = [i for i, col in enumerate(columns) if col not in engineered_cols]
    raw_preds_metrics = raw_preds[:, :, metric_mask]
    raw_trues_metrics = raw_trues[:, :, metric_mask]
    metric_columns = [columns[i] for i in metric_mask]

    # --- Global metrics on all non-engineered channels (before any filtering) ---
    mae, rmse, r2, smape_score = compute_global_metrics(raw_preds_metrics, raw_trues_metrics)
    print(f"[Global Metrics] MAE={mae:.4f}, RMSE={rmse:.4f}, R²={r2:.4f}, SMAPE={smape_score:.2f}%")

    # --- Optional filtering of high-RMSE channels (per-channel report only) ---
    filtered_preds, filtered_trues, filtered_columns = raw_preds_metrics, raw_trues_metrics, metric_columns
    if rmse_threshold is not None:
        # filter_high_rmse_channels returns (preds, trues, columns, mae, rmse, r2);
        # only the filtered arrays and names are needed here.
        filtered_preds, filtered_trues, filtered_columns = filter_high_rmse_channels(
            raw_preds_metrics, raw_trues_metrics, metric_columns, threshold=rmse_threshold
        )[:3]

    # --- Per-channel metrics ---
    mean_vals, rmse_vals, nrmse_vals = compute_per_channel_metrics(filtered_preds, filtered_trues)
    stats_df = generate_stats_df(filtered_columns, mean_vals, rmse_vals, nrmse_vals)
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("logs", exist_ok=True)
    stats_df.to_csv(f"logs/{basename}_per_channel_stats_{run_stamp}.csv", index=False)

    print("Mean RMSE (filtered):", stats_df['rmse'].mean())
    print("Mean NRMSE (filtered):", stats_df['nrmse'].mean())
    print("Summary by Category:")
    print(stats_df['category'].value_counts())

    for cat in ['Poor', 'Fair', 'Bad']:
        filtered = stats_df[stats_df['category'] == cat]
        if not filtered.empty:
            print(f"\n{cat} Channels:")
            print(filtered[['column', 'mean', 'rmse', 'nrmse']].sort_values(by='nrmse').to_string(index=False))

    # --- Plots ---
    # Window 8 / channel 1 as before, clamped so small test sets or single-channel
    # data do not raise IndexError.
    plot_prediction_sequence(
        raw_preds, raw_trues, columns,
        channel_idx=min(1, raw_preds.shape[2] - 1),
        t0=min(8, raw_preds.shape[0] - 1),
        align_future=True,
    )

    # NOTE(review): find_best_prediction_plots is not defined in the visible source;
    # confirm it is provided by another notebook cell.
    find_best_prediction_plots(raw_preds_metrics, raw_trues_metrics, metric_columns)

    return raw_preds, raw_trues, columns, {
        'mae': mae,
        'rmse': rmse,
        'r2': r2,
        'smape': smape_score,
        'nrmse_per_channel': nrmse_vals.tolist(),
        'columns': filtered_columns
    }



def test_loss(model, loader, channel_weights, cfg):
    """
    Return the channel-weighted Smooth L1 loss (beta=1.0) averaged over the batches of ``loader``.

    The computation is exactly the one performed by ``validate``; this thin
    wrapper exists so that test-set evaluation reads explicitly at call sites
    and cannot drift from the validation loss definition.
    """
    return validate(model, loader, channel_weights, cfg)

# --- Training ---
def _scale_with_train_statistics(df, train_frac=0.6):
    """Standardise ``df`` in place with scalers fitted on its first ``train_frac`` rows only."""
    n_train = int(train_frac * len(df))
    # scale_data fits on whatever it receives: hand it a copy of the training
    # prefix, then apply the fitted scalers to every row of the full frame.
    _, scalers = scale_data(df.iloc[:n_train].copy())
    for col, sc in scalers.items():
        df[[col]] = sc.transform(df[[col]])
    return df, scalers


def run_experiment(csv_path, seq_len, pred_len, config_class):
    """
    Train and evaluate PatchTST on one CSV file.

    Pipeline: load; standardise with scalers fitted on the training rows only
    (no val/test leakage); chronological 60/20/20 windowing; train with a
    warm-up/cosine LR and patience-based early stopping; restore the
    best-validation weights; compute the test loss, metrics and plots; save the
    raw arrays under ``logs/``.

    Args:
        csv_path: CSV with a 'date' column plus feature columns.
        seq_len: Input window length.
        pred_len: Forecast horizon length.
        config_class: Config subclass (e.g. HourlyConfig or DailyConfig).

    Returns:
        dict: Metrics from ``evaluate_and_plot`` plus ``test_loss``.
    """
    basename = os.path.splitext(os.path.basename(csv_path))[0]
    os.makedirs("logs", exist_ok=True)

    df = pd.read_csv(csv_path)
    # 0.6 must match the training fraction hard-coded in split_and_window.
    df, scalers = _scale_with_train_statistics(df, train_frac=0.6)
    (X_tr, Y_tr), (X_va, Y_va), (X_te, Y_te), columns = split_and_window(df, seq_len, pred_len)
    cfg = config_class(seq_len, pred_len, enc_in=X_tr.shape[-1])
    train_loader, val_loader, test_loader = prepare_loaders((X_tr, X_va, X_te), (Y_tr, Y_va, Y_te), cfg)

    model = PatchTST(cfg).to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # The scheduler is stepped once per batch, so epochs are converted to steps.
    total_steps = cfg.epochs * len(train_loader)
    warmup_steps = cfg.warmup_epochs * len(train_loader)
    sched = create_scheduler_with_warmup(opt, total_steps, warmup_steps)

    channel_weights = torch.ones(cfg.enc_in)
    best_val, trigger, best_state = float('inf'), 0, None
    train_losses, val_losses = [], []

    for epoch in range(cfg.epochs):
        if epoch == 1:
            channel_weights = compute_channel_weights_from_val(model, val_loader, cfg.device)
            # Validation losses are weighted differently from epoch 0 onward, so
            # restart best-model tracking to compare like with like.
            best_val, trigger, best_state = float('inf'), 0, None
        train_loss = train_one_epoch(model, train_loader, opt, sched, channel_weights, cfg)
        val_loss = validate(model, val_loader, channel_weights, cfg)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"[PatchTST][{basename}] Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        if val_loss < best_val:
            best_val, trigger = val_loss, 0
            # CPU snapshot of the best weights; later epochs overwrite the live ones.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            trigger += 1
            if trigger >= cfg.patience:
                print(f"[EarlyStopping] Epoch {epoch+1}")
                break

    # Evaluate the best-validation weights, not whatever the last epoch produced.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"[PatchTST][{basename}] Restored best weights (val loss {best_val:.4f})")

    test_l = test_loss(model, test_loader, channel_weights, cfg)
    raw_preds, raw_trues, columns, result = evaluate_and_plot(model, test_loader, scalers, columns, cfg, basename)
    np.savez_compressed(f"logs/patchtst_outputs_{timestamp}_{basename}.npz",
                        preds=raw_preds, trues=raw_trues, columns=np.array(columns))
    result['test_loss'] = test_l
    print(f"[PatchTST][{basename}] Final Test Loss: {test_l:.4f}")
    return result

if __name__ == "__main__":
    # --- COLAB FALLBACK ---
    csv_path_for_hourly = "/content/PatchTST/PatchTST_supervised/Dataset/electricity_hourly_transformed_3.csv"
    csv_path_for_daily = "/content/PatchTST/PatchTST_supervised/Dataset/electricity_daily_transformed_5.csv"
    config_type = 'daily'  # 'hourly' or 'daily'

    # config_type -> (csv path, seq_len, pred_len, config class)
    experiments = {
        'hourly': (csv_path_for_hourly, 512, 24, HourlyConfig),
        'daily': (csv_path_for_daily, 120, 7, DailyConfig),
    }
    if config_type not in experiments:
        raise ValueError(f"config_type must be one of {sorted(experiments)}, got {config_type!r}")
    csv_path, seq_len, pred_len, config_class = experiments[config_type]

    result = run_experiment(csv_path, seq_len, pred_len, config_class)
    os.makedirs("logs", exist_ok=True)
    pd.DataFrame([result]).to_csv(f"logs/combined_results_{timestamp}.csv", index=False)