
# 🧪 Time-Series Forecasting Demo — Dashboard Style

This notebook lets you:
1. 🔼 Upload **or** automatically load the dataset  
2. 📌 Select a **model** via dropdown  
3. 📊 Re-run **forecasting**  
4. 📈 See results (**Actual vs Predicted**, **Error plot**, **Model metrics**)  
5. 🔍 Add **Explainable AI (SHAP + LIME)**  
6. 🏆 Final output summary with model & top features  


In [None]:

# === Setup & Imports ===
import sys, io, os, json, warnings
warnings.filterwarnings('ignore')

# Ensure our modular package is importable
sys.path.append('/mnt/data')
sys.path.append('/mnt/data/xai_pipeline')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from IPython.display import display, Markdown, HTML, clear_output
import ipywidgets as widgets

# Pipeline modules
from xai_pipeline.data_loader import load_covid_aggregated
from xai_pipeline.preprocessing import scale_features, create_sequences, train_test_sequence_split
from xai_pipeline.metrics import regression_metrics
from xai_pipeline.plots import plot_predictions
from xai_pipeline.shap import explain_model as shap_explain

# Models
from xai_pipeline.models.lstm import build_lstm
from xai_pipeline.models.bilstm import build_bilstm
from xai_pipeline.models.cnn import build_cnn
from xai_pipeline.models.hybrid_cnn_lstm import build_hybrid_cnn_lstm
from xai_pipeline.models.rnn import build_rnn
from xai_pipeline.models.mlp import build_mlp
from xai_pipeline.models.hybrid_mlp_cnn_lstm import build_hybrid_mlp_cnn_lstm
from xai_pipeline.models.hybrid_cnn_bilstm import build_hybrid_cnn_bilstm
from xai_pipeline.models.hybrid_cnn_dense import build_hybrid_cnn_dense

# TF callbacks
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Optional: LIME (graceful fallback if not installed)
try:
    from xai_pipeline.lime import explain_instance as lime_explain
    _lime_ok = True
except Exception as _e:
    _lime_ok = False


In [None]:

# === UI Widgets ===

# --- 1) Data source: upload a CSV or use the cached sample file ---
uploader = widgets.FileUpload(
    accept='.csv',
    multiple=False,
    description='Upload CSV',
)
auto_load_btn = widgets.Button(
    description='Use Sample COVID CSV',
    tooltip='Auto-load cached Kaggle COVID dataset (if present)',
)

# --- 2) Model selection: dropdown label -> key consumed by _build_model ---
model_options = {
    "LSTM": "lstm",
    "BiLSTM": "bilstm",
    "CNN (1D)": "cnn",
    "Hybrid: CNN + LSTM": "hybrid_cnn_lstm",
    "RNN (SimpleRNN)": "rnn",
    "MLP (Flattened T*F)": "mlp",
    "Hybrid: Conv -> LSTM -> MLP": "hybrid_mlp_cnn_lstm",
    "Hybrid: Conv -> BiLSTM": "hybrid_cnn_bilstm",
    "Hybrid: CNN + Dense (baseline)": "hybrid_cnn_dense",
}
model_dd = widgets.Dropdown(
    options=list(model_options),
    value="LSTM",
    description='Model',
)

# --- 3) Training controls and action buttons ---
time_steps = widgets.IntSlider(
    value=14, min=3, max=60, step=1, description='Time steps',
)
epochs = widgets.IntSlider(
    value=30, min=5, max=200, step=5, description='Epochs',
)
batch = widgets.IntSlider(
    value=32, min=8, max=256, step=8, description='Batch',
)
run_btn = widgets.Button(description='Run Forecast', button_style='success')
explain_btn = widgets.Button(
    description='Explain (SHAP + LIME)',
    button_style='info',
)

# Shared output area that every button callback renders into
out = widgets.Output()

# Stack the three sections vertically, then show them above the output area
_data_row = widgets.HBox([uploader, auto_load_btn])
_settings_row = widgets.HBox([time_steps, epochs, batch])
_actions_row = widgets.HBox([run_btn, explain_btn])
controls = widgets.VBox([
    widgets.HTML("<h3>1) Data</h3>"),
    _data_row,
    widgets.HTML("<h3>2) Model</h3>"),
    model_dd,
    widgets.HTML("<h3>3) Train Settings</h3>"),
    _settings_row,
    _actions_row,
])
display(controls, out)


In [None]:

# === Helpers ===

def _read_uploaded_csv(uploader):
    """Return the first uploaded CSV as a DataFrame, or None if nothing was uploaded.

    Parameters
    ----------
    uploader
        Any object with a ``FileUpload``-style ``value`` attribute. Two
        layouts are accepted:

        * a dict mapping file name -> ``{"content": bytes, ...}``
          (ipywidgets 7), and
        * a sequence of dicts, each holding ``"content"`` as a bytes-like
          object such as a ``memoryview`` (ipywidgets 8).

    Returns
    -------
    pandas.DataFrame or None
        The parsed CSV, or ``None`` when ``uploader.value`` is empty.
    """
    uploads = uploader.value
    if not uploads:
        return None

    # Only the first file is used; the widget is created with multiple=False.
    if isinstance(uploads, dict):
        first = next(iter(uploads.values()))
    else:
        first = uploads[0]

    # bytes() normalises bytes / bytearray / memoryview before parsing.
    return pd.read_csv(io.BytesIO(bytes(first['content'])))

def _auto_sample_csv_df(path=None):
    """Load the sample COVID CSV if it is present on disk.

    Parameters
    ----------
    path : str, optional
        CSV file to read. When omitted, the kagglehub cache location used by
        this notebook is tried.

    Returns
    -------
    pandas.DataFrame or None
        The parsed CSV, or ``None`` when ``path`` is not an existing file.
    """
    if path is None:
        # Default path used in your notebook
        path = "/root/.cache/kagglehub/datasets/sudalairajkumar/novel-corona-virus-2019-dataset/versions/151/covid_19_data.csv"
    # isfile (rather than exists) so a directory is reported as "not found"
    # instead of failing inside read_csv.
    if not os.path.isfile(path):
        return None
    return pd.read_csv(path)

def _prepare_aggregated_df(df):
    """Aggregate a raw case table into one row per calendar day.

    The input must contain a date column named ``ObservationDate``, ``Date``
    or ``date`` (first match wins) and Confirmed/Deaths/Recovered columns
    (capitalised or lower-case). The input frame is left unmodified.

    Returns
    -------
    pandas.DataFrame
        Columns ``ObservationDate`` (datetime), ``Confirmed``, ``Deaths``,
        ``Recovered`` (per-day sums) and ``Active`` (Confirmed - Deaths -
        Recovered), sorted by date with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If no date column or not all three count columns are found.
    """
    def _first_present(candidates):
        return next((c for c in candidates if c in df.columns), None)

    date_col = _first_present(("ObservationDate", "Date", "date"))
    if date_col is None:
        raise ValueError("No date-like column found. Expected 'ObservationDate', 'Date', or 'date'.")

    # Parse into a local Series instead of assigning back to df[date_col],
    # so the caller's DataFrame is not changed as a side effect.
    day_keys = pd.to_datetime(df[date_col]).dt.date

    confirmed_col = _first_present(("Confirmed", "confirmed"))
    deaths_col = _first_present(("Deaths", "deaths"))
    recovered_col = _first_present(("Recovered", "recovered"))
    if not all([confirmed_col, deaths_col, recovered_col]):
        raise ValueError("Dataset must include Confirmed/Deaths/Recovered columns.")

    # day_keys keeps the name of the date column, so reset_index() yields a
    # column called date_col, which is then renamed to the canonical name.
    grouped = (
        df.groupby(day_keys)[[confirmed_col, deaths_col, recovered_col]]
        .sum()
        .reset_index()
        .rename(columns={
            confirmed_col: "Confirmed",
            deaths_col: "Deaths",
            recovered_col: "Recovered",
            date_col: "ObservationDate",
        })
    )
    grouped["ObservationDate"] = pd.to_datetime(grouped["ObservationDate"])
    grouped["Active"] = grouped["Confirmed"] - grouped["Deaths"] - grouped["Recovered"]
    return grouped.sort_values("ObservationDate").reset_index(drop=True)

def _build_model(kind, input_shape, X_seq=None):
    """Build the model identified by ``kind``.

    Parameters
    ----------
    kind : str
        One of the values of the ``model_options`` map.
    input_shape
        Passed unchanged to every builder except the MLP one.
    X_seq : optional
        Sequence array; required only for ``kind == "mlp"``, where the
        flattened input dimension is ``X_seq.shape[1] * X_seq.shape[2]``.

    Returns
    -------
    Whatever the selected ``build_*`` function returns.

    Raises
    ------
    ValueError
        If ``kind`` is unknown, or ``kind == "mlp"`` and ``X_seq`` is None.
    """
    if kind == "mlp":
        # Explicit raise instead of `assert`: assertions vanish under `python -O`.
        if X_seq is None:
            raise ValueError("X_seq is required for MLP to determine input dim.")
        input_dim = X_seq.shape[1] * X_seq.shape[2]
        return build_mlp(input_dim=input_dim)

    # Lambdas keep the lookup lazy: only the selected builder is touched.
    builders = {
        "lstm": lambda: build_lstm(input_shape),
        "bilstm": lambda: build_bilstm(input_shape),
        "cnn": lambda: build_cnn(input_shape),
        "hybrid_cnn_lstm": lambda: build_hybrid_cnn_lstm(input_shape),
        "rnn": lambda: build_rnn(input_shape),
        "hybrid_mlp_cnn_lstm": lambda: build_hybrid_mlp_cnn_lstm(input_shape),
        "hybrid_cnn_bilstm": lambda: build_hybrid_cnn_bilstm(input_shape),
        "hybrid_cnn_dense": lambda: build_hybrid_cnn_dense(input_shape),
    }
    try:
        make = builders[kind]
    except (KeyError, TypeError):
        # TypeError covers unhashable `kind` values; both mean "unknown".
        raise ValueError(f"Unknown model kind: {kind}") from None
    return make()


In [None]:

# Mutable state shared by the button callbacks (filled in by Run, read by Explain)
STATE = {
    # Aggregated per-day DataFrame from the last load
    'df': None,
    # Full sequence arrays and their train/test partitions
    'X_seq': None,
    'y_seq': None,
    'X_train': None,
    'X_test': None,
    'y_train': None,
    'y_test': None,
    # Last trained model and the dropdown label it was chosen with
    'model': None,
    'model_name': None,
    # Predictions for X_test
    'y_pred': None,
    # Input feature columns and the window length of the last run
    'features': ['Confirmed', 'Deaths', 'Recovered', 'Active'],
    'tsteps': 14,
}

def _load_data():
    """Return the aggregated DataFrame for the current data source.

    An uploaded CSV takes precedence; otherwise the sample CSV is used.

    Raises
    ------
    RuntimeError
        If nothing was uploaded and the sample CSV is unavailable.
    """
    if len(uploader.value) > 0:
        raw_frame = _read_uploaded_csv(uploader)
    else:
        raw_frame = _auto_sample_csv_df()
        if raw_frame is None:
            raise RuntimeError("No dataset uploaded and sample path not found.")
    return _prepare_aggregated_df(raw_frame)

def _train_and_forecast():
    """Load data, train the selected model and store everything in STATE.

    Reads the window length, epoch count and batch size from the sliders and
    the model choice from the dropdown. The target is the unscaled
    ``Confirmed`` column; inputs are the min-max scaled STATE['features'].
    Results (arrays, model, predictions) are written to STATE; nothing is
    returned.
    """
    frame = _load_data()
    STATE['df'] = frame
    feature_cols = STATE['features']
    window = time_steps.value

    # The scaler returned by the helper is not used afterwards.
    scaled_inputs, _unused_scaler = scale_features(frame, feature_cols, scaler_type='minmax')
    target = frame['Confirmed'].values
    X_seq, y_seq = create_sequences(scaled_inputs, target, time_steps=window)
    X_train, X_test, y_train, y_test = train_test_sequence_split(X_seq, y_seq, test_size=0.2)

    # NOTE(review): for "mlp" the model is built from a flattened dimension
    # while X_train stays 3-D here -- confirm build_mlp reshapes internally.
    model = _build_model(model_options[model_dd.value], X_train.shape[1:], X_seq=X_seq)

    early_stop = EarlyStopping(patience=10, restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(patience=5)
    model.fit(
        X_train,
        y_train,
        validation_split=0.1,
        epochs=epochs.value,
        batch_size=batch.value,
        verbose=0,
        callbacks=[early_stop, reduce_lr],
    )
    y_pred = model.predict(X_test)

    STATE.update({
        'X_seq': X_seq,
        'y_seq': y_seq,
        'X_train': X_train,
        'X_test': X_test,
        'y_train': y_train,
        'y_test': y_test,
        'model': model,
        'model_name': model_dd.value,
        'y_pred': y_pred,
        'tsteps': window,
    })

def _make_plots_and_metrics():
    """Plot the latest forecast, show its metrics table and return the metrics.

    Reads ``y_test``, ``y_pred`` and ``model_name`` from STATE, draws an
    actual-vs-predicted plot and an error plot, then displays the result of
    ``regression_metrics`` extended with a ``'mape'`` entry (in percent).

    Returns
    -------
    The metrics mapping returned by ``regression_metrics`` with ``'mape'`` added.
    """
    y_true = STATE['y_test'].reshape(-1)
    y_pred = STATE['y_pred'].reshape(-1)

    # 1) Actual vs Predicted
    plt.figure(figsize=(10, 5))
    plt.plot(y_true, label='Actual')
    plt.plot(y_pred, label='Predicted')
    plt.title(f"Actual vs Predicted — {STATE['model_name']}")
    plt.xlabel("Samples")
    plt.ylabel("Value")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # 2) Error plot
    err = y_true - y_pred
    plt.figure(figsize=(10, 4))
    plt.plot(err, label='Error')
    plt.title("Prediction Error")
    plt.xlabel("Samples")
    plt.ylabel("Error")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # 3) Metrics
    mets = regression_metrics(y_true, y_pred)
    # MAPE divides by |actual|; the 1e-8 floor avoids division by zero.
    # Taking the absolute value first keeps negative actuals from being
    # replaced by the floor, which would blow the percentage up.
    denom = np.maximum(1e-8, np.abs(y_true))
    mets['mape'] = float(np.mean(np.abs(err) / denom) * 100.0)

    display(Markdown("### Model Metrics"))
    display(pd.DataFrame([mets]))
    return mets

def _explain_models():
    """Render SHAP and, when available, LIME explanations for the trained model.

    Uses ``model``, ``X_test``, ``features`` and ``model_name`` from STATE.
    The SHAP summary is delegated to the pipeline's ``shap_explain`` helper.
    LIME runs on sample index 0 only; a LIME failure is shown as a warning
    rather than raised.

    Raises
    ------
    RuntimeError
        If no model has been trained yet.
    """
    model = STATE['model']
    X = STATE['X_test']
    features = STATE['features']
    if model is None or X is None:
        # Fail with a clear message instead of an obscure error from SHAP.
        raise RuntimeError("No trained model available. Run Forecast before Explain.")

    shap_explain(model, X, feature_names=features, title=f"SHAP Summary — {STATE['model_name']}")

    if not _lime_ok:
        display(Markdown("**LIME not available.** Install with `pip install lime` to enable tabular explanations."))
        return

    # LIME is best-effort: report problems but keep the SHAP output.
    try:
        exp = lime_explain(model, X, feature_names=None, class_names=None, num_features=10, sample_index=0)
        display(Markdown("### LIME Explanation (sample index 0)"))
        display(exp.as_list())
    except Exception as e:
        display(Markdown(f"**LIME warning:** {e}"))

def _final_summary(mets, top_pos, top_neg):
    """Display the final Markdown summary for the last trained model.

    Parameters
    ----------
    mets
        Mapping with at least ``'rmse'``, ``'mae'`` and ``'mape'``.
    top_pos, top_neg
        Iterables of ``(feature_name, mean_shap_value)`` pairs.
    """
    def _fmt(pairs):
        # An empty list is shown as "n/a" rather than a blank entry.
        return ', '.join(f'{k} ({v:+.3f})' for k, v in pairs) or 'n/a'

    name = STATE['model_name']
    lines = [
        f"✔ **Model:** {name}",
        f"✔ **RMSE:** {mets['rmse']:.4f}",
        f"✔ **MAE:** {mets['mae']:.4f}",
        f"✔ **MAPE:** {mets['mape']:.2f}%",
        f"✔ **Top positive features (SHAP):** {_fmt(top_pos)}",
        f"✔ **Top negative features (SHAP):** {_fmt(top_neg)}",
    ]
    display(Markdown("## 🏆 Final Output"))
    display(Markdown("<br>".join(lines)))

def _compute_shap_top_features():
    """Rank features by mean SHAP value for the trained model.

    Uses ``X_test``, ``model`` and ``features`` from STATE. 3-D inputs
    ``(n, t, f)`` are flattened to ``(n, t*f)`` and each flat column is named
    ``<feature>_t-<lag>``, where the first time step gets the largest lag.
    Other inputs get generic ``Feature_<i>`` names.

    Returns
    -------
    (top_pos, top_neg)
        Two lists of ``(feature_name, mean_shap_value)`` pairs: up to five
        features with the largest positive means and up to five with the
        most negative means. A list is shorter (possibly empty) when fewer
        features have that sign.

    Raises
    ------
    RuntimeError
        If no model has been trained yet.
    """
    import shap  # local import: only needed when explanations are requested

    X = STATE['X_test']
    model = STATE['model']
    features = STATE['features']
    if model is None or X is None:
        raise RuntimeError("No trained model available. Run Forecast before Explain.")

    is_sequence = X.ndim == 3
    if is_sequence:
        n, t, f = X.shape
        X_flat = X.reshape(n, t * f)
        feat_names = [f"{name}_t-{lag}" for lag in range(t - 1, -1, -1) for name in features]
    else:
        X_flat = X
        feat_names = [f"Feature_{i}" for i in range(X.shape[1])]

    def model_wrap(x):
        if is_sequence:
            x = x.reshape(-1, t, f)
        # Return one scalar per row. With (n, 1) predictions SHAP would
        # presumably add a trailing output axis and break the 1-D ranking below.
        return np.asarray(model.predict(x)).reshape(-1)

    background = X_flat[:100]
    explainer = shap.Explainer(model_wrap, background)
    explanation = explainer(X_flat[:200])

    # Signed mean over the explained samples: one value per flat feature
    mean_vals = np.asarray(explanation.values).mean(axis=0)

    # Keep only genuinely positive / negative means so a feature of the
    # wrong sign never appears under the wrong heading.
    top_pos = [
        (feat_names[i], float(mean_vals[i]))
        for i in np.argsort(-mean_vals)[:5]
        if mean_vals[i] > 0
    ]
    top_neg = [
        (feat_names[i], float(mean_vals[i]))
        for i in np.argsort(mean_vals)[:5]
        if mean_vals[i] < 0
    ]
    return top_pos, top_neg


In [None]:

# === Button Callbacks ===

def on_auto_load_clicked(b):
    """Report in the output area whether the sample CSV could be loaded.

    The loaded data is not stored; this callback only shows a status message.
    """
    with out:
        clear_output()
        try:
            sample_found = _auto_sample_csv_df() is not None
            if sample_found:
                status = "**Loaded sample COVID dataset.**"
            else:
                status = "**Sample CSV not found. Please upload a CSV.**"
            display(Markdown(status))
        except Exception as e:
            display(Markdown(f"**Error:** {e}"))

def on_run_clicked(b):
    """Train the selected model, then show plots and metrics in the output area.

    Any exception is caught and displayed as an error message.
    """
    with out:
        clear_output()
        # Discard results of a previous run first, so the final summary can
        # never pair this run's metrics with SHAP rankings of an older model.
        for stale_key in ('mets', 'top_pos', 'top_neg'):
            STATE.pop(stale_key, None)
        try:
            _train_and_forecast()
            STATE['mets'] = _make_plots_and_metrics()
            display(Markdown("> ✅ Training & forecasting complete."))
        except Exception as e:
            display(Markdown(f"**Error:** {e}"))

def on_explain_clicked(b):
    """Show SHAP/LIME explanations and store the top SHAP features in STATE.

    Output is appended below the existing plots (the area is not cleared).
    Any exception is caught and displayed as an error message.
    """
    with out:
        # Explain needs a trained model; say so plainly instead of surfacing
        # whatever error the explainers raise on missing inputs.
        if STATE.get('model') is None:
            display(Markdown("**Explain Error:** No trained model yet. Click **Run Forecast** first."))
            return
        try:
            _explain_models()
            # Compute top +/- for final output
            top_pos, top_neg = _compute_shap_top_features()
            STATE['top_pos'] = top_pos
            STATE['top_neg'] = top_neg
            display(Markdown("> 🔍 Explainability complete."))
        except Exception as e:
            display(Markdown(f"**Explain Error:** {e}"))

# Register each click handler on its button
auto_load_btn.on_click(on_auto_load_clicked)
run_btn.on_click(on_run_clicked)
explain_btn.on_click(on_explain_clicked)


In [None]:

# === 6) Final Output Cell ===
# The summary needs metrics (stored by Run) and SHAP rankings (stored by Explain)
if 'mets' not in STATE or 'top_pos' not in STATE or 'top_neg' not in STATE:
    display(Markdown("Run **Forecast** and then **Explain** to populate the final summary here."))
else:
    _final_summary(STATE['mets'], STATE['top_pos'], STATE['top_neg'])
