# In [None]:  (notebook cell marker left over from export; commented out so the file parses as Python)
# dlinear_pipeline.py
import numpy as np
import pandas as pd
import argparse
from typing import Tuple, Dict, Optional, List
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_absolute_error
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import os.path as osp
from datetime import datetime
from dataclasses import dataclass, field
from enum import Enum

class DatasetType(str, Enum):
    """Dataset granularities for which Config ships built-in defaults."""
    # Members mix in ``str``, so they compare equal to the plain strings
    # ("hourly" / "daily") that Config.dataset is checked against.
    HOURLY = "hourly"
    DAILY = "daily"

@dataclass
class Config:
    """Settings for one training/evaluation run.

    ``data_path``, ``lookback`` and ``horizon`` fall back to built-in defaults
    when ``dataset`` is ``"hourly"`` or ``"daily"``; explicitly supplied
    (truthy) values always win. Any other ``dataset`` value is accepted as long
    as ``data_path`` is supplied.

    ``__post_init__`` also derives ``basename`` and the ``log_path``,
    ``metrics_path``, ``plot_path``, ``preds_path`` and ``trues_path``
    attributes from ``data_path``, ``timestamp`` and ``output_dir``. It may be
    called again after mutating fields to refresh those derived paths.

    Raises:
        ValueError: if ``data_path`` is missing and ``dataset`` has no
            built-in default path.
    """

    dataset: str = "daily"
    data_path: Optional[str] = None
    lookback: Optional[int] = None
    horizon: Optional[int] = None
    epochs: int = 20
    lr: float = 1e-3
    batch_size: int = 32
    output_dir: str = "logs"
    timestamp: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S"))
    # Optional per-channel RMSE cut-off; None means "no filtering".
    threshold: Optional[float] = None

    def _fill_missing(self, data_path, lookback, horizon):
        """Apply dataset defaults to fields the caller left unset (falsy)."""
        self.data_path = self.data_path or data_path
        self.lookback = self.lookback or lookback
        self.horizon = self.horizon or horizon

    def __post_init__(self):
        if self.dataset == "hourly":
            self._fill_missing(
                "/content/PatchTST/PatchTST_supervised/Dataset/electricity_hourly_transformed_5.csv",
                256,
                24,
            )
        elif self.dataset == "daily":
            self._fill_missing(
                "/content/PatchTST/PatchTST_supervised/Dataset/electricity_daily_transformed_5.csv",
                120,
                7,
            )

        # Fail with a readable message instead of a TypeError from
        # osp.basename(None) further down.
        if self.data_path is None:
            raise ValueError(
                f"No data_path given and dataset {self.dataset!r} has no built-in default path"
            )

        # All output artefacts share one "<csv stem>_<timestamp>" prefix.
        self.basename = f"{osp.splitext(osp.basename(self.data_path))[0]}_{self.timestamp}"
        self.log_path = osp.join(self.output_dir, f"{self.basename}_train_log.txt")
        self.metrics_path = osp.join(self.output_dir, f"{self.basename}_metrics.txt")
        self.plot_path = osp.join(self.output_dir, f"{self.basename}_forecast.png")
        self.preds_path = osp.join(self.output_dir, f"{self.basename}_preds.npy")
        self.trues_path = osp.join(self.output_dir, f"{self.basename}_trues.npy")



# Dataset
class TimeSeriesDataset(Dataset):
    """Sliding-window view over a sequence.

    Window ``i`` pairs ``data[i:i + lookback]`` (input) with the following
    ``horizon`` rows (target). All windows are materialised up front as
    float32 tensors in ``self.X`` and ``self.y``.
    """

    def __init__(self, data, lookback, horizon):
        # A non-positive count yields an empty range, i.e. an empty dataset.
        window_starts = range(len(data) - lookback - horizon + 1)
        inputs = [data[start:start + lookback] for start in window_starts]
        targets = [
            data[start + lookback:start + lookback + horizon]
            for start in window_starts
        ]
        self.X = torch.tensor(np.array(inputs), dtype=torch.float32)
        self.y = torch.tensor(np.array(targets), dtype=torch.float32)
        print(f"[Dataset Init] X: {self.X.shape}, Y: {self.y.shape}")

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


class DLinear(nn.Module):
    """Two-layer MLP applied along the time axis, shared across features.

    The network maps ``seq_len`` time steps to ``pred_len`` time steps through
    one hidden layer (ReLU + dropout 0.2). ``input_dim`` is accepted for
    interface compatibility but is not used by the layers.
    """

    def __init__(self, seq_len, pred_len, input_dim, hidden_dim=64):
        super().__init__()
        layers = [
            nn.Linear(seq_len, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, pred_len),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # (B, T, F) -> (B, F, T) so the linear layers act on the time axis.
        time_last = x.permute(0, 2, 1)
        projected = self.model(time_last)
        # Back to (B, pred_len, F).
        return projected.permute(0, 2, 1)


# Scaling
def scale_data(df: pd.DataFrame) -> Tuple[np.ndarray, Dict[str, StandardScaler], list]:
    """Standardise the numeric columns of ``df`` and return them as float32.

    Every numeric column that is not in the calendar-feature skip list gets
    its own ``StandardScaler`` (fitted on the full column). Skip-listed numeric
    columns are passed through unscaled.

    The input frame is NOT modified. The returned array lists the scaled
    columns first (in ``scaled_columns`` order) followed by the pass-through
    columns, so that ``scaled_columns[i]`` always describes ``array[:, i]``.
    Code in this file that indexes features by position in ``scaled_columns``
    relies on that alignment.

    Args:
        df: Source frame; non-numeric columns are ignored.

    Returns:
        ``(values, scalers, scaled_columns)`` where ``values`` is a float32
        array, ``scalers`` maps column name to its fitted scaler, and
        ``scaled_columns`` lists the scaled column names in array order.
    """
    skip_cols = {
        'dayofweek', 'month', 'is_weekend', 'dayofyear',
        'dayofweek_sin', 'dayofweek_cos',
        'month_sin', 'month_cos',
        'dayofyear_sin', 'dayofyear_cos'
    }

    numeric_cols = list(df.select_dtypes(include=['number']).columns)
    scaled_columns = [col for col in numeric_cols if col not in skip_cols]
    passthrough_columns = [col for col in numeric_cols if col in skip_cols]

    # astype() returns a copy, so the caller's frame is left untouched, and
    # float storage avoids writing scaled floats into integer columns.
    work = df[numeric_cols].astype(np.float64)

    scalers = {}
    for col in scaled_columns:
        sc = StandardScaler()
        work[[col]] = sc.fit_transform(work[[col]])
        scalers[col] = sc

    ordered = scaled_columns + passthrough_columns
    return work[ordered].astype(np.float32).values, scalers, scaled_columns


# Utils
def load_data(filepath):
    """Read the CSV at ``filepath`` and return its rows ordered by the ``date`` column."""
    return pd.read_csv(filepath).sort_values("date")


def split_data(data, lookback, horizon):
    """Split ``data`` chronologically into 60% train, 20% validation and the rest as test.

    ``lookback`` and ``horizon`` are accepted but do not influence the split.
    """
    total = len(data)
    train_end = int(total * 0.6)
    val_end = train_end + int(total * 0.2)
    return data[:train_end], data[train_end:val_end], data[val_end:]


def create_dataloaders(train, val, test, lookback, horizon, batch_size):
    """Wrap the three splits in windowed datasets and return (train, val, test) loaders.

    Only the training loader shuffles; validation and test keep window order.
    """
    # Build the datasets first, in train/val/test order.
    datasets = [TimeSeriesDataset(split, lookback, horizon) for split in (train, val, test)]
    shuffle_flags = (True, False, False)
    return tuple(
        DataLoader(dataset, batch_size=batch_size, shuffle=True)
        if shuffle
        else DataLoader(dataset, batch_size=batch_size)
        for dataset, shuffle in zip(datasets, shuffle_flags)
    )

def compute_channel_weights_from_val(model, val_loader, device, threshold=1.0) -> torch.Tensor:
    """Derive per-channel loss weights from validation RMSE.

    Channels whose validation RMSE is below ``threshold`` get a weight
    proportional to ``1 / rmse``, normalised so those weights average 1; all
    other channels get weight 0.

    If no channel is below the threshold, uniform weights (all ones) are
    returned. An all-zero weight vector would make the weighted loss
    identically zero, so the fallback keeps it equivalent to an unweighted
    mean of squared errors.

    Args:
        model: Module mapping an input batch to predictions; put in eval mode.
        val_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs are moved to before the forward pass.
        threshold: RMSE cut-off (in the units of the targets).

    Returns:
        1-D float32 CPU tensor with one weight per channel (last axis).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    preds_val, trues_val = [], []

    with torch.no_grad():
        for xb, yb in val_loader:
            preds_val.append(model(xb.to(device)).cpu().numpy())
            trues_val.append(yb.cpu().numpy())

    if not preds_val:
        raise ValueError("val_loader yielded no batches; cannot compute channel weights")

    preds_val = np.concatenate(preds_val, axis=0)
    trues_val = np.concatenate(trues_val, axis=0)

    # Reduce over the batch and time axes, leaving one RMSE per channel.
    rmse = np.sqrt(np.mean((trues_val - preds_val) ** 2, axis=(0, 1)))
    mask = rmse < threshold

    if not mask.any():
        # Avoid the mean of an empty slice (NaN) and an all-zero weighting.
        return torch.tensor(np.ones_like(rmse), dtype=torch.float32)

    weights = np.zeros_like(rmse)
    weights[mask] = 1.0 / (rmse[mask] + 1e-6)

    mean_weight = weights[mask].mean()
    if mean_weight > 0:
        weights[mask] /= mean_weight

    return torch.tensor(weights, dtype=torch.float32)

def weighted_mse_loss(pred, target, weights):
    """Mean of the squared errors after scaling each by ``weights`` (broadcast against the error)."""
    squared_error = (pred - target) ** 2
    weighted = weights * squared_error
    return weighted.mean()

def smape(y_true, y_pred):
    """Symmetric MAPE in percent: mean of 2*|pred - true| / (|true| + |pred| + 1e-8), times 100."""
    abs_error = np.abs(y_pred - y_true)
    # The small constant keeps the ratio finite when both values are zero.
    scale = (np.abs(y_true) + np.abs(y_pred)) + 1e-8
    return 100 * np.mean(2 * abs_error / scale)


def _batch_loss(output, target, weights):
    """Return plain MSE while ``weights`` is None, otherwise the channel-weighted MSE."""
    if weights is None:
        return nn.functional.mse_loss(output, target)
    return weighted_mse_loss(output, target, weights)


def train_model(model, train_loader, val_loader, epochs=20, lr=1e-3, log_path="train_log.txt", device='cpu'):
    """Train ``model`` with Adam, logging train/validation loss once per epoch.

    The first epoch uses plain MSE. After each epoch (except the last, whose
    weights would never be used) per-channel weights are recomputed from the
    validation set and applied to both the training and validation loss of the
    next epoch.

    Args:
        model: Module to train; moved to ``device`` in place.
        train_loader: Sized iterable of ``(inputs, targets)`` training batches.
        val_loader: Sized iterable of ``(inputs, targets)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        log_path: Text file that receives one line per epoch; its parent
            directory is created if missing.
        device: Device used for the model and batches.

    Raises:
        ValueError: if either loader has no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            "train_loader and val_loader must each contain at least one batch "
            "(is the split shorter than lookback + horizon?)"
        )

    # open(..., "w") does not create missing parent directories.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    weights = None

    with open(log_path, "w") as f:
        for epoch in range(epochs):
            model.train()
            train_loss_total = 0
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                optimizer.zero_grad()
                loss = _batch_loss(model(x_batch), y_batch, weights)
                loss.backward()
                optimizer.step()
                train_loss_total += loss.item()

            train_loss_avg = train_loss_total / len(train_loader)

            model.eval()
            val_loss_total = 0
            with torch.no_grad():
                for x_val, y_val in val_loader:
                    x_val = x_val.to(device)
                    y_val = y_val.to(device)
                    val_loss_total += _batch_loss(model(x_val), y_val, weights).item()

            val_loss_avg = val_loss_total / len(val_loader)

            log_msg = f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss_avg:.4f} | Val Loss: {val_loss_avg:.4f}"
            print(log_msg)
            f.write(log_msg + "\n")

            if epoch + 1 < epochs:
                # Move the weights to the device once instead of per batch.
                weights = compute_channel_weights_from_val(model, val_loader, device).to(device)


def evaluate_model(model, test_loader, result_dir="results", basename="output", scalers=None, scaled_columns=None):
    """Run ``model`` over ``test_loader``, report metrics and save predictions.

    Predictions and targets are concatenated over batches. When both
    ``scalers`` and ``scaled_columns`` are given, entry ``i`` of
    ``scaled_columns`` is treated as feature index ``i`` of the arrays and is
    inverse-transformed with its scaler. MAE, RMSE and SMAPE are computed over
    every feature; R² only over features whose targets are not constant.

    Side effects: creates ``result_dir`` and writes
    ``<basename>_metrics.txt``, ``<basename>_preds.npy`` and
    ``<basename>_trues.npy`` into it.

    Args:
        model: Trained module; switched to eval mode.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        result_dir: Output directory.
        basename: Filename prefix for the saved artefacts.
        scalers: Optional mapping of column name to fitted scaler.
        scaled_columns: Optional list of column names, positionally aligned
            with the leading features of the model output.

    Returns:
        ``(preds, trues)`` arrays of shape (samples, horizon, features).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    os.makedirs(result_dir, exist_ok=True)
    model.eval()

    # Run on whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    preds, trues = [], []
    with torch.no_grad():
        for x, y in test_loader:
            preds.append(model(x.to(device)).cpu().numpy())
            trues.append(y.cpu().numpy())

    if not preds:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    preds_2d = np.concatenate(preds, axis=0)
    trues_2d = np.concatenate(trues, axis=0)
    n_samples, n_steps, n_features = preds_2d.shape

    # Inverse transform each feature
    if scalers and scaled_columns:
        for i, col in enumerate(scaled_columns):
            scaler = scalers[col]
            preds_2d[:, :, i] = scaler.inverse_transform(
                preds_2d[:, :, i].reshape(-1, 1)
            ).reshape(n_samples, n_steps)
            trues_2d[:, :, i] = scaler.inverse_transform(
                trues_2d[:, :, i].reshape(-1, 1)
            ).reshape(n_samples, n_steps)

    # Metric calculations
    mse = np.mean((preds_2d - trues_2d) ** 2)
    mae = np.mean(np.abs(preds_2d - trues_2d))
    rmse = np.sqrt(mse)
    smape_score = smape(trues_2d, preds_2d)

    # Without column names, fall back to every feature instead of len(None).
    candidate_idx = range(len(scaled_columns)) if scaled_columns else range(n_features)
    valid_idx = [
        i for i in candidate_idx
        if not np.isclose(np.std(trues_2d[:, :, i]), 0)
    ]
    if valid_idx:
        r2 = r2_score(
            trues_2d[:, :, valid_idx].reshape(-1, len(valid_idx)),
            preds_2d[:, :, valid_idx].reshape(-1, len(valid_idx))
        )
    else:
        # R² is undefined when every target channel is constant.
        r2 = float("nan")

    # Save metrics and arrays
    with open(os.path.join(result_dir, f"{basename}_metrics.txt"), "w") as f:
        f.write(f"MAE: {mae:.4f}\nRMSE: {rmse:.4f}\nR2: {r2:.4f}\nSMAPE:{smape_score:.4f}")

    print(f"MAE: {mae:.4f}, RMSE: {rmse:.4f}, R2: {r2:.4f}, SMAPE: {smape_score:.2f}%")

    np.save(os.path.join(result_dir, f"{basename}_preds.npy"), preds_2d)
    np.save(os.path.join(result_dir, f"{basename}_trues.npy"), trues_2d)

    return preds_2d, trues_2d

def plot_random_sample_forecasts(preds, trues, feature_idx=6, num_samples=3, save_path=None):
    """
    Draw true vs. predicted curves for a few randomly chosen samples of one feature.

    One subplot per chosen sample; at most ``num_samples`` samples are drawn.
    The figure is saved to ``save_path`` when given, otherwise shown.
    """
    import random

    sample_count, _, _ = preds.shape
    chosen = random.sample(range(sample_count), min(num_samples, sample_count))

    plt.figure(figsize=(14, num_samples * 2.5))
    for row, sample_idx in enumerate(chosen, start=1):
        plt.subplot(num_samples, 1, row)
        plt.plot(trues[sample_idx, :, feature_idx], label="True")
        plt.plot(preds[sample_idx, :, feature_idx], label="Pred")
        plt.title(f"Random Sample {sample_idx} - Feature {feature_idx}")
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path)

def plot_prediction_sequence(
    raw_preds, raw_trues, columns,
    channel_idx: int, t0: int,
    fname: Optional[str] = "/content/logs/forecast.png",
    align_future: bool = False
):
    """
    Plot prediction vs. truth for window ``t0`` of one channel, followed by the
    prediction of window ``t0 + 1`` drawn as a dotted "future" line.

    Args:
        raw_preds: Predictions indexed as [window, step, channel].
        raw_trues: Targets with the same layout as ``raw_preds``.
        columns: Channel names; ``columns[channel_idx]`` is used in the title
            when available, otherwise the bare index is shown.
        channel_idx: Channel to plot.
        t0: Window index to plot.
        fname: Where to save the figure; its directory is created if missing.
            Pass ``None`` (or an empty string) to skip saving.
        align_future: Shift the future curve so its first point equals the
            last predicted point of window ``t0``.
    """
    H = raw_preds.shape[1]

    y_pred_test = raw_preds[t0, :, channel_idx]
    y_true_test = raw_trues[t0, :, channel_idx]
    x_test = np.arange(0, H)

    # The next window's forecast only exists if t0 is not the last window.
    y_pred_future = None
    x_future = None
    if t0 + 1 < raw_preds.shape[0]:
        y_pred_future = raw_preds[t0 + 1, :, channel_idx]
        x_future = np.arange(H, H * 2)
        if align_future:
            offset = y_pred_test[-1] - y_pred_future[0]
            y_pred_future = y_pred_future + offset

    rmse = np.sqrt(np.mean((y_pred_test - y_true_test) ** 2))
    mean_val = np.mean(y_true_test)

    # Prefer the channel's name over its index in the title.
    if columns is not None and 0 <= channel_idx < len(columns):
        channel_label = columns[channel_idx]
    else:
        channel_label = channel_idx

    plt.figure(figsize=(12, 4))
    plt.plot(x_test, y_true_test, label='True (Test)', color='blue')
    plt.plot(x_test, y_pred_test, label='Predicted (Test)', color='orange')
    if y_pred_future is not None:
        plt.plot(x_future, y_pred_future, label='Forecast (Future)', color='green', linestyle='dotted')

    plt.title(f"{channel_label} — RMSE: {rmse:.2f}, Mean: {mean_val:.2f}")
    plt.xlabel('Time Index (Days)')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    if fname:
        # savefig does not create missing directories.
        out_dir = os.path.dirname(fname)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(fname)
    plt.show()


def _overall_metrics(raw_preds, raw_trues):
    """Return (MAE, RMSE, R²) computed over all elements of the two arrays."""
    mae = mean_absolute_error(raw_trues.flatten(), raw_preds.flatten())
    rmse = np.sqrt(((raw_trues - raw_preds) ** 2).mean())
    r2 = r2_score(raw_trues.flatten(), raw_preds.flatten())
    return mae, rmse, r2


def filter_high_rmse_channels(raw_preds, raw_trues, columns, threshold=2500, debug=False,
                              rmse_csv_path: Optional[str] = None):
    """Drop channels whose RMSE exceeds ``threshold`` and recompute metrics.

    Only the first ``len(columns)`` channels are candidates for keeping, so
    channels beyond the named ones are always removed when filtering is active.

    Args:
        raw_preds: Predictions indexed as [window, step, channel].
        raw_trues: Targets with the same layout.
        columns: Channel names, positionally aligned with the last axis.
        threshold: RMSE cut-off; ``None`` disables filtering and returns the
            inputs unchanged together with their overall metrics.
        debug: Print the filtered shapes.
        rmse_csv_path: Where to save the per-channel RMSE table. When omitted,
            the path is derived from a module-level ``config`` object if one
            exists; otherwise the export is skipped.

    Returns:
        ``(preds, trues, columns, mae, rmse, r2)``. The three metrics are
        ``None`` if no channel survives filtering.
    """
    per_series_rmse = np.sqrt(((raw_trues - raw_preds) ** 2).mean(axis=(0, 1)))
    mean_rmse = np.mean(per_series_rmse)
    print(f"\nMean per-channel RMSE: {mean_rmse:.4f}")

    if threshold is None:
        print("No threshold provided. Skipping channel filtering.")
        mae, rmse, r2 = _overall_metrics(raw_preds, raw_trues)
        return raw_preds, raw_trues, columns, mae, rmse, r2

    bad_idx = [i for i, channel_rmse in enumerate(per_series_rmse) if channel_rmse > threshold]

    if bad_idx:
        # Channels past the end of `columns` have no name; label them by index.
        dropped = [columns[i] if i < len(columns) else f"feat_{i}" for i in bad_idx]
        print(f"Dropping {len(bad_idx)} channels with RMSE > {threshold}")
        print(f"Dropped columns: {dropped}")

    bad_set = set(bad_idx)
    keep_idx = [i for i in range(len(columns)) if i not in bad_set]
    keep_arr = np.array(keep_idx, dtype=int)
    raw_preds = raw_preds[:, :, keep_arr]
    raw_trues = raw_trues[:, :, keep_arr]
    columns = [columns[i] for i in keep_idx]
    per_series_rmse = per_series_rmse[keep_arr]  # filtered

    if debug:
        print(f"raw_preds shape: {raw_preds.shape}")
        print(f"raw_trues shape: {raw_trues.shape}")
        print(f"Filtered columns count: {len(columns)}")

    try:
        flat_df = pd.DataFrame(raw_trues.reshape(-1, len(columns)), columns=columns)
        print(f"\nFiltered dataframe shape: {flat_df.shape} (rows x columns)")
        print("Column means after filtering:")
        print(flat_df.mean())
    except ValueError as e:
        # Reached when every channel was dropped (reshape to zero columns).
        print(f"Reshape failed: {e}")
        return raw_preds, raw_trues, columns, None, None, None

    mae, rmse, r2 = _overall_metrics(raw_preds, raw_trues)
    print(f"Filtered → MAE={mae:.3f}, RMSE={rmse:.3f}, R²={r2:.4f}")

    rmse_df = pd.DataFrame({
        "column": columns,
        "rmse": per_series_rmse
    })

    if rmse_csv_path is None:
        # Backward-compatible fallback: use the script's global run config if
        # present, without raising NameError when this is used as a library.
        run_config = globals().get("config")
        if run_config is not None:
            rmse_csv_path = os.path.join(
                getattr(run_config, "output_dir", "logs"),
                f"{osp.splitext(run_config.basename)[0]}_rmse_per_channel.csv",
            )

    if rmse_csv_path is None:
        print("No output path available; skipping per-channel RMSE export.")
    else:
        csv_dir = os.path.dirname(rmse_csv_path)
        if csv_dir:
            os.makedirs(csv_dir, exist_ok=True)
        rmse_df.to_csv(rmse_csv_path, index=False)
        print(f"Saved per-channel RMSE to: {rmse_csv_path}")

    return raw_preds, raw_trues, columns, mae, rmse, r2

def _nrmse_category(nrmse) -> str:
    """Map an NRMSE value to its quality label (NaN falls through to "Poor")."""
    if nrmse < 0.1:
        return "Excellent"
    if nrmse < 0.2:
        return "Good"
    if nrmse < 0.3:
        return "Fair"
    return "Poor"


def compute_nrmse_per_series(raw_preds: np.ndarray, raw_trues: np.ndarray, columns: list) -> pd.DataFrame:
    """Compute RMSE, mean and mean-normalised RMSE for every channel.

    NRMSE is ``rmse / (|mean| + 1e-8)``. The absolute value keeps the ratio
    non-negative, so a series with a negative mean is not mislabelled as
    "Excellent" because of a negative score.

    Args:
        raw_preds: Predictions indexed as [window, step, channel].
        raw_trues: Targets with the same shape as ``raw_preds``.
        columns: One name per channel.

    Returns:
        DataFrame with columns ``column``, ``rmse``, ``mean``, ``nrmse`` and
        ``category``, sorted by ascending ``nrmse``.
    """
    assert raw_preds.shape == raw_trues.shape, "Prediction and ground truth shapes must match"

    # Reduce over windows and steps in one vectorised pass per statistic.
    rmse = np.sqrt(np.mean((raw_trues - raw_preds) ** 2, axis=(0, 1)))
    mean = np.mean(raw_trues, axis=(0, 1))
    nrmse = rmse / (np.abs(mean) + 1e-8)

    df = pd.DataFrame({
        "column": columns,
        "rmse": rmse,
        "mean": mean,
        "nrmse": nrmse,
        "category": [_nrmse_category(value) for value in nrmse]
    })

    return df.sort_values("nrmse")


# Entry point
if __name__ == '__main__':
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--data', type=str)
    parser.add_argument('--lookback', type=int)
    parser.add_argument('--horizon', type=int)
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--output', type=str)
    parser.add_argument('--rmse_threshold', type=float)
    # When __file__ is undefined (e.g. an interactive session) ignore sys.argv.
    args = parser.parse_args([] if "__file__" not in globals() else sys.argv[1:])

    # CLI flags whose names differ from the Config field they set. Without
    # this mapping --output and --rmse_threshold were silently dropped.
    cli_to_field = {
        'data': 'data_path',
        'output': 'output_dir',
        'rmse_threshold': 'threshold',
    }
    overrides = {
        cli_to_field.get(key, key): value
        for key, value in vars(args).items()
        if value is not None
    }
    # Build the config once so dataset defaults and derived paths are
    # resolved from the final values.
    config = Config(**overrides)

    # The training log is written before any other step creates this directory.
    os.makedirs(config.output_dir, exist_ok=True)

    df = load_data(config.data_path)
    print(f"[Raw DF] Shape: {df.shape}")

    data, scalers, scaled_columns = scale_data(df)
    print(f"[Scaled Data] Shape: {data.shape}")

    train_data, val_data, test_data = split_data(data, config.lookback, config.horizon)
    print(f"[Split Data] Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

    train_loader, val_loader, test_loader = create_dataloaders(
        train_data, val_data, test_data,
        config.lookback, config.horizon, config.batch_size
    )

    model = DLinear(
        seq_len=config.lookback,
        pred_len=config.horizon,
        input_dim=data.shape[1],
        hidden_dim=128  # or higher, e.g., 128, 256
    )

    train_model(model, train_loader, val_loader, epochs=config.epochs, lr=config.lr, log_path=config.log_path)

    preds, trues = evaluate_model(model, test_loader, result_dir=config.output_dir, basename=config.basename,
                                  scalers=scalers, scaled_columns=scaled_columns)
    print(f"[Eval Output] preds: {preds.shape}, trues: {trues.shape}")

    preds, trues, filtered_columns, mae, rmse, r2 = filter_high_rmse_channels(
        preds, trues, scaled_columns, threshold=config.threshold, debug=True
    )
    # Ensure columns align with pred/trues third dimension
    aligned_columns = filtered_columns
    if preds.shape[2] != len(filtered_columns):
        print(f"[Warning] Column count mismatch: {preds.shape[2]} ≠ {len(filtered_columns)}. Using fallback.")
        if len(scaled_columns) != preds.shape[2]:
            print(f"[Fixing] scaled_columns length ({len(scaled_columns)}) ≠ preds.shape[2] ({preds.shape[2]})")
            aligned_columns = [f"feat_{i}" for i in range(preds.shape[2])]
        else:
            aligned_columns = scaled_columns

    nrmse_df = compute_nrmse_per_series(preds, trues, aligned_columns)
    print("\nPer-Series NRMSE Summary:")
    print(nrmse_df.to_string(index=False))

    nrmse_report_path = os.path.join(config.output_dir, f"{config.basename}_nrmse_report.csv")
    nrmse_df.to_csv(nrmse_report_path, index=False)
    print(f"\nNRMSE report saved to: {nrmse_report_path}")

    if preds.shape[0] > 0 and preds.shape[2] > 0:
        # Clamp the sample/channel picks so short test sets or heavily
        # filtered outputs do not raise IndexError.
        plot_prediction_sequence(
            raw_preds=preds,
            raw_trues=trues,
            columns=aligned_columns,
            channel_idx=min(1, preds.shape[2] - 1),
            t0=min(4, preds.shape[0] - 1),
            fname=os.path.join(config.output_dir, "sample_forecast_seq.png"),
            align_future=True
        )
    else:
        print("[Warning] No predictions left to plot; skipping sample forecast plot.")

    # find_best_prediction_plots is not defined in this file; only call it
    # when it has been provided elsewhere (e.g. another notebook cell).
    if "find_best_prediction_plots" in globals():
        top_k_scores = find_best_prediction_plots(
            raw_preds=preds,
            raw_trues=trues,
            columns=aligned_columns,
            save_dir=os.path.join(config.output_dir, "best_plots"),
            top_k=5,
            align_future=True
        )
    else:
        top_k_scores = None
        print("[Warning] find_best_prediction_plots is not defined; skipping best-plot export.")
