!pip install lightning pytorch-forecasting --quiet

# ==============================
# 1) Imports and seed
# ==============================
import os, glob, math
import numpy as np
import pandas as pd
import torch
from datetime import datetime, timedelta

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

# Fix the global random seed (42) before any dataset or model object is created.
# NOTE(review): workers=True is forwarded to Lightning; presumably it also seeds
# dataloader worker processes — confirm against the Lightning documentation.
pl.seed_everything(42, workers=True)

# ==============================
# 2) Locate data and model files
# ==============================
# Root of the attached dataset; every CSV found beneath it is a candidate input.
base_dir = "/kaggle/input/iess-demand-2002-2025"
csv_candidates = []
for ext in ("*.csv", "*.CSV"):
    csv_candidates.extend(glob.glob(os.path.join(base_dir, "**", ext), recursive=True))
# Raise explicitly rather than assert: asserts are stripped under `python -O`,
# which would let execution fall through to max() on an empty list.
if not csv_candidates:
    raise FileNotFoundError(f"No CSV files found under {base_dir}")
# Heuristic: the largest CSV on disk is taken to be the historical demand file.
csv_path = max(csv_candidates, key=os.path.getsize)
print(f"Using historical file: {csv_path}")

# The trained checkpoint is expected at /kaggle/input/<slug>/tft_best_model.ckpt.
model_input_slug = "results2"
model_dir = f"/kaggle/input/{model_input_slug}"
best_checkpoint = os.path.join(model_dir, "tft_best_model.ckpt")

if not os.path.exists(best_checkpoint):
    raise FileNotFoundError(f"Checkpoint not found at {best_checkpoint}")
print(f"Using checkpoint: {best_checkpoint}")

# ==============================
# 3) Load & preprocess historical data
# ==============================
df = pd.read_csv(csv_path)
# Header cells may carry stray whitespace; normalise before any column lookup.
df.columns = [c.strip() for c in df.columns]

# Validate the whole schema up front so a wrong file fails with one clear
# message instead of a bare KeyError several lines further down.
required_columns = ("Date", "Hour", "Ontario Demand")
missing_columns = [c for c in required_columns if c not in df.columns]
if missing_columns:
    raise ValueError(f"Expected column(s) missing from {csv_path}: {missing_columns}")
df = df.drop(columns=["Market Demand"], errors="ignore")

# Unparseable Hour/Date values become NA/NaT and are dropped.
df["Hour"] = pd.to_numeric(df["Hour"], errors="coerce").astype("Int64")
df["Date"] = pd.to_datetime(df["Date"], errors="coerce")
df = df.dropna(subset=["Date", "Hour"]).copy()
# Hour values are clamped into 1..24; Hour 1 maps to 00:00 on the same date below.
df["Hour"] = df["Hour"].astype(int).clip(1, 24)

df["time"] = df["Date"] + pd.to_timedelta(df["Hour"] - 1, unit="h")
df["Ontario Demand"] = pd.to_numeric(df["Ontario Demand"], errors="coerce")

df = df.dropna(subset=["Ontario Demand", "time"]).sort_values("time").reset_index(drop=True)
# When several rows share a timestamp, only the first one is kept.
df = df.drop_duplicates(subset=["time"], keep="first").sort_values("time")

# Without this guard an empty frame fails later inside date_range with NaT bounds.
if df.empty:
    raise ValueError(f"No usable rows (valid Date, Hour and Ontario Demand) found in {csv_path}")

# Re-index onto a gap-free hourly grid; hours absent from the file become NaN rows.
full_range = pd.date_range(df["time"].min(), df["time"].max(), freq="h")
df = df.set_index("time").reindex(full_range).rename_axis("time").reset_index()

# Forward-fill only (match training preprocessing)
df["Ontario Demand"] = df["Ontario Demand"].astype(float).ffill()

# Single-series setup: one constant group id plus an integer hour counter
# measured from the first timestamp, and calendar features derived from "time".
df["series"] = "ON"
df["time_idx"] = ((df["time"] - df["time"].min()).dt.total_seconds() // 3600).astype(int)
df["hour"] = df["time"].dt.hour.astype("int16")
df["day_of_week"] = df["time"].dt.dayofweek.astype("int8")
df["month"] = df["time"].dt.month.astype("int8")

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
pd.set_option('display.float_format', lambda x: '%.3f' % x)

print("\nHistorical data loaded and preprocessed.")
print(f"Data shape: {df.shape}")
print(f"Last timestamp in database: {df['time'].max()}")
print(f"Last known demand: {df['Ontario Demand'].iloc[-1]:.1f} MW")
print(f"Ontario Demand range: {df['Ontario Demand'].min():.1f} to {df['Ontario Demand'].max():.1f} MW")

# ==============================
# 4) Checkpoint diagnostics
# ==============================
print("\n=== CHECKPOINT DIAGNOSTICS ===")
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary objects.
# Only run this against a checkpoint whose origin is trusted.
checkpoint = torch.load(best_checkpoint, map_location='cpu', weights_only=False)

if 'hyper_parameters' in checkpoint:
    hparams = checkpoint['hyper_parameters']

    # The model-loading section below reads the normalizer from the model's
    # `dataset_parameters`, so look there as well instead of reporting "N/A"
    # when a setting is not a top-level hyper-parameter.
    # NOTE(review): the exact location of `dataset_parameters` inside the
    # checkpoint depends on the pytorch-forecasting version — confirm.
    dataset_params = hparams.get('dataset_parameters') or checkpoint.get('dataset_parameters')
    if not isinstance(dataset_params, dict):
        dataset_params = {}

    def _checkpoint_setting(name, default='N/A'):
        """Return `name` from the top-level hparams, else from dataset_parameters."""
        if name in hparams:
            return hparams[name]
        return dataset_params.get(name, default)

    print(f"Model was trained with:")
    print(f"  - max_encoder_length: {_checkpoint_setting('max_encoder_length')}")
    print(f"  - max_prediction_length: {_checkpoint_setting('max_prediction_length')}")

    if 'target_normalizer' in hparams or 'target_normalizer' in dataset_params:
        normalizer = _checkpoint_setting('target_normalizer', None)
        print(f"  - target_normalizer: {type(normalizer).__name__}")
        # These attributes are optional; print only the ones the object exposes.
        if hasattr(normalizer, 'transformation'):
            print(f"    transformation: {normalizer.transformation}")
        if hasattr(normalizer, 'center'):
            print(f"    center: {normalizer.center}")

# ==============================
# 5) Create reference training dataset - MATCH TRAINING EXACTLY
# ==============================
# Window sizes in hours: 168 h of encoder history, 24 h of forecast horizon.
max_encoder_length = 168
max_prediction_length = 24

# All dataset settings gathered in one mapping so they can be compared against
# the training notebook at a glance. The target normalizer is a GroupNormalizer
# that receives only `groups` (no transformation / center arguments are passed).
reference_dataset_kwargs = dict(
    time_idx="time_idx",
    target="Ontario Demand",
    group_ids=["series"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_known_reals=["time_idx", "hour", "day_of_week", "month"],
    time_varying_unknown_reals=["Ontario Demand"],
    target_normalizer=GroupNormalizer(groups=["series"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)
training = TimeSeriesDataSet(df, **reference_dataset_kwargs)

print(f"\nReference training dataset created with {len(training)} samples")

# ==============================
# 6) Prepare data for NEXT DAY prediction
# ==============================
# Snapshot of where the history ends.
last_time, last_time_idx = df["time"].max(), df["time_idx"].max()
last_demand = df["Ontario Demand"].iloc[-1]

# Hourly timestamps for the forecast horizon, starting one hour after the history.
forecast_start = last_time + pd.Timedelta(hours=1)
future_times = pd.date_range(start=forecast_start, periods=max_prediction_length, freq="h")

print(f"\n=== PREDICTION SETUP ===")
print(f"Last database entry: {last_time} ({last_demand:.1f} MW)")
print(f"Prediction period: {future_times[0]} to {future_times[-1]}")
print(f"Prediction day: {future_times[0].strftime('%A, %B %d, %Y')}")

# Future rows carry the calendar features; the target column is filled with the
# last observed demand purely as a placeholder value.
future_rows = pd.DataFrame({"time": future_times})
future_rows["series"] = "ON"
future_rows["Ontario Demand"] = last_demand
future_rows["hour"] = future_times.hour.astype("int16")
future_rows["day_of_week"] = future_times.dayofweek.astype("int8")
future_rows["month"] = future_times.month.astype("int8")

# time_idx continues the hour counter measured from the first historical timestamp.
history_start = df["time"].min()
hours_since_start = (future_rows["time"] - history_start).dt.total_seconds() // 3600
future_rows["time_idx"] = hours_since_start.astype(int)

# History followed by the placeholder future rows.
prediction_df = pd.concat([df, future_rows], ignore_index=True)

# Report the most recent `max_encoder_length` historical rows for a visual check.
is_history = prediction_df["time_idx"] <= last_time_idx
encoder_window = prediction_df.loc[is_history].tail(max_encoder_length)
window_demand = encoder_window["Ontario Demand"]
print(f"\nEncoder data verification:")
print(f"  Length: {len(encoder_window)} hours")
print(f"  Time range: {encoder_window['time'].min()} to {encoder_window['time'].max()}")
print(f"  Demand range: {window_demand.min():.1f} to {window_demand.max():.1f} MW")
print(f"  Last encoder demand: {window_demand.iloc[-1]:.1f} MW")

# ==============================
# 7) Create prediction dataset using from_dataset
# ==============================
# Derive the inference dataset from the reference dataset so both share the
# same configuration; predict=True and stop_randomization=True are passed through.
prediction_dataset = TimeSeriesDataSet.from_dataset(
    training, prediction_df, predict=True, stop_randomization=True
)

# Non-training loader: batch size 1, no worker processes.
loader_options = {"train": False, "batch_size": 1, "num_workers": 0}
predict_loader = prediction_dataset.to_dataloader(**loader_options)

print(f"\nPrediction dataset created with {len(prediction_dataset)} samples")

# ==============================
# 8) Load model from checkpoint
# ==============================
print("\n=== LOADING MODEL ===")

try:
    tft = TemporalFusionTransformer.load_from_checkpoint(best_checkpoint)
    print("✓ Model loaded using load_from_checkpoint()")
except Exception as e:
    # Deliberate best-effort fallback: rebuild the network from the reference
    # dataset with the hyper-parameters below and load the raw weights into it.
    print(f"Primary loading failed: {e}")
    print("Attempting alternative method...")

    tft = TemporalFusionTransformer.from_dataset(
        training,
        learning_rate=1e-3,
        hidden_size=32,
        attention_head_size=4,
        dropout=0.1,
        hidden_continuous_size=16,
        loss=QuantileLoss(),
        optimizer="Adam",
    )

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; trusted files only.
    checkpoint = torch.load(best_checkpoint, map_location='cpu', weights_only=False)
    # strict=False skips parameter names that do not line up. Report them so a
    # partially initialised model is not mistaken for a fully restored one.
    load_result = tft.load_state_dict(checkpoint['state_dict'], strict=False)
    if load_result.missing_keys:
        print(f"  ⚠️ {len(load_result.missing_keys)} parameter(s) not found in checkpoint "
              f"(left at initial values), e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f"  ⚠️ {len(load_result.unexpected_keys)} checkpoint entr(ies) not used by the model, "
              f"e.g. {load_result.unexpected_keys[:5]}")
    print("✓ Model loaded using manual state_dict")

tft.eval()
print("✓ Model set to evaluation mode")

# Verify model normalizer matches (access from dataset instead)
try:
    # Prefer the parameters stored on the model; fall back to the reference
    # dataset when the model carries none (attribute missing or not a dict).
    model_dataset_params = getattr(tft, 'dataset_parameters', None)
    if isinstance(model_dataset_params, dict):
        normalizer = model_dataset_params.get('target_normalizer')
    elif hasattr(training, 'target_normalizer'):
        normalizer = training.target_normalizer
    else:
        normalizer = None

    # Explicit None check: do not depend on the normalizer object's truthiness.
    if normalizer is not None:
        print(f"\nDataset normalizer type: {type(normalizer).__name__}")
        if hasattr(normalizer, 'transformation'):
            print(f"  transformation: {normalizer.transformation}")
except Exception as e:
    # Inspection is purely informational; never let it stop the forecast.
    print(f"\nNote: Could not inspect normalizer details: {e}")

# ==============================
# 9) Generate predictions
# ==============================
print("\n=== GENERATING PREDICTIONS ===")

with torch.no_grad():
    raw_predictions = tft.predict(
        predict_loader,
        mode="prediction",
        return_x=False,
        return_index=False
    )

# Convert whatever predict() returned into a NumPy array.
if isinstance(raw_predictions, torch.Tensor):
    predictions = raw_predictions.detach().cpu().numpy()
else:
    # NOTE(review): some pytorch-forecasting versions may wrap the result in an
    # object exposing `.output` — unwrap it when present, otherwise use as-is.
    wrapped_output = getattr(raw_predictions, "output", raw_predictions)
    if isinstance(wrapped_output, torch.Tensor):
        wrapped_output = wrapped_output.detach().cpu().numpy()
    predictions = np.asarray(wrapped_output)

if predictions.size == 0:
    raise RuntimeError("Model returned no predictions")

# Handle different output shapes (first sample only in every case).
if predictions.ndim == 3:
    # Take the central entry of the last axis rather than index 0.
    # Assumes the last axis holds quantiles ordered around the median, in which
    # case index 0 would be the lowest quantile — TODO confirm against the loss.
    central_idx = predictions.shape[2] // 2
    predictions = predictions[0, :, central_idx]
elif predictions.ndim == 2:
    predictions = predictions[0, :]
else:
    predictions = predictions.flatten()

# Ensure we have at most `max_prediction_length` (24) predictions
predictions = predictions[:max_prediction_length]

print(f"✓ Predictions generated")
print(f"  Shape: {predictions.shape}")
print(f"  Range: {predictions.min():.1f} to {predictions.max():.1f} MW")
print(f"  Mean: {predictions.mean():.1f} MW")

# Sanity check: compare the forecast mean with the mean of the last 168 hours.
historical_mean = df["Ontario Demand"].tail(168).mean()
print(f"\nScale validation:")
print(f"  Historical mean (last 7 days): {historical_mean:.1f} MW")
print(f"  Prediction mean: {predictions.mean():.1f} MW")
print(f"  Difference: {abs(predictions.mean() - historical_mean):.1f} MW")

# A gap above 5000 MW is reported as a warning; it does not stop the run.
if abs(predictions.mean() - historical_mean) > 5000:
    print("  ⚠️ WARNING: Predictions significantly different from historical scale!")
else:
    print("  ✓ Predictions within expected scale")

# ==============================
# 10) Create output dataframe
# ==============================
# Timestamps trimmed to the number of predictions actually available.
forecast_index = future_times[:len(predictions)]
output_df = pd.DataFrame({
    "time": forecast_index,
    "predicted_ontario_demand": predictions,
    "horizon_hour_ahead": np.arange(1, len(predictions) + 1),
    "hour_of_day": forecast_index.hour,
    "day_of_week": forecast_index.dayofweek,
})


def _day_part(h):
    """Map a clock hour to a label: Night (<6 or >=22), Morning (<12), Afternoon (<18), Evening."""
    if h < 6 or h >= 22:
        return 'Night'
    if h < 12:
        return 'Morning'
    if h < 18:
        return 'Afternoon'
    return 'Evening'


# Add day part labels
output_df['day_part'] = output_df['hour_of_day'].map(_day_part)

banner = "=" * 80
print("\n" + banner)
print("ONTARIO DEMAND FORECAST: NEXT 24 HOURS")
print(banner)
print(f"Forecast Date: {future_times[0].strftime('%A, %B %d, %Y')}")
print(f"Last Known: {last_demand:.1f} MW at {last_time.strftime('%Y-%m-%d %H:%M')}")
print(banner)
print("\nHour-by-Hour Predictions:")
hourly_table = output_df[['time', 'predicted_ontario_demand', 'hour_of_day', 'day_part']]
print(hourly_table.to_string(
    index=False,
    float_format='%.1f',
    formatters={'time': lambda x: x.strftime('%Y-%m-%d %H:%M')}
))

# ==============================
# 11) Statistical analysis
# ==============================
# Split the forecast into clock hours 6..19 ("daytime") and everything else.
daytime_mask = output_df['hour_of_day'].isin(range(6, 20))
daytime_hours = output_df[daytime_mask]
night_hours = output_df[~daytime_mask]
day_demand = daytime_hours['predicted_ontario_demand']
night_demand = night_hours['predicted_ontario_demand']

print("\n" + "=" * 80)
print("STATISTICAL SUMMARY")
print("=" * 80)
print(f"Overall Statistics:")
print(f"  Mean:   {predictions.mean():.1f} MW")
# The "Hour" shown here is the 1-based position within the forecast horizon.
print(f"  Min:    {predictions.min():.1f} MW (Hour {predictions.argmin()+1})")
print(f"  Max:    {predictions.max():.1f} MW (Hour {predictions.argmax()+1})")
print(f"  Std:    {predictions.std():.1f} MW")

print(f"\nDaytime (6am-8pm):")
print(f"  Average: {day_demand.mean():.1f} MW")
print(f"  Range:   {day_demand.min():.1f} - {day_demand.max():.1f} MW")

print(f"\nNighttime (8pm-6am):")
print(f"  Average: {night_demand.mean():.1f} MW")
print(f"  Range:   {night_demand.min():.1f} - {night_demand.max():.1f} MW")

# Three highest and three lowest forecast hours, shown with the same columns.
extreme_columns = ['time', 'predicted_ontario_demand', 'hour_of_day']

print(f"\nPeak Hours:")
peak_3 = output_df.nlargest(3, 'predicted_ontario_demand')[extreme_columns]
print(peak_3.to_string(index=False, float_format='%.1f'))

print(f"\nLowest Hours:")
low_3 = output_df.nsmallest(3, 'predicted_ontario_demand')[extreme_columns]
print(low_3.to_string(index=False, float_format='%.1f'))

# ==============================
# 12) Quality validation
# ==============================
print("\n" + "="*80)
print("QUALITY CHECKS")
print("="*80)

# Check for NaN/inf first: every comparison against NaN is False and pandas
# mean() skips NaN, so without this a forecast containing NaN would pass all
# of the range checks below.
has_nonfinite = not np.isfinite(predictions).all()
print(f"Non-finite values (NaN/inf): {'✗ FOUND' if has_nonfinite else '✓ None'}")
if has_nonfinite:
    bad_hours = np.where(~np.isfinite(predictions))[0] + 1
    print(f"  Non-finite at hours: {bad_hours.tolist()}")

# Check for negative values
has_negative = (predictions < 0).any()
print(f"Negative values: {'✗ FOUND' if has_negative else '✓ None'}")
if has_negative:
    # 1-based positions within the forecast horizon.
    neg_hours = np.where(predictions < 0)[0] + 1
    print(f"  Negative at hours: {neg_hours.tolist()}")

# Check for unrealistic values
has_unrealistic = ((predictions < 5000) | (predictions > 30000)).any()
print(f"Unrealistic values (< 5000 or > 30000 MW): {'✗ FOUND' if has_unrealistic else '✓ None'}")

# Check daytime/nighttime patterns (mean of each subset against a fixed band)
daytime_ok = 15000 <= daytime_hours['predicted_ontario_demand'].mean() <= 28000
nighttime_ok = 8000 <= night_hours['predicted_ontario_demand'].mean() <= 18000

print(f"Daytime pattern (expect 15,000-28,000 MW): {'✓ Pass' if daytime_ok else '✗ Fail'}")
print(f"Nighttime pattern (expect 8,000-18,000 MW): {'✓ Pass' if nighttime_ok else '✗ Fail'}")

# Overall assessment
all_checks = (
    not has_nonfinite
    and not has_negative
    and not has_unrealistic
    and daytime_ok
    and nighttime_ok
)
print(f"\n{'✓ ALL CHECKS PASSED' if all_checks else '✗ SOME CHECKS FAILED'}")

# ==============================
# 13) Save outputs
# ==============================
# Write the forecast table (without the index column) to the working directory.
forecast_csv_path = "/kaggle/working/next_day_ontario_forecast.csv"
output_df.to_csv(forecast_csv_path, index=False)
print(f"\n✓ Predictions saved to {forecast_csv_path}")

# ==============================
# 14) Visualization
# ==============================
import matplotlib.pyplot as plt

# Last 168 hourly rows of history for context in the top panel.
recent_history = df.tail(168)[["time", "Ontario Demand"]]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 10))

# Top plot: Last 7 days + forecast
ax1.plot(recent_history["time"], recent_history["Ontario Demand"], 
         label="Historical Demand (Last 7 Days)", color='green', linewidth=2, alpha=0.8)
ax1.plot(output_df["time"], output_df["predicted_ontario_demand"], 
         marker='o', label=f"24h Forecast ({future_times[0].strftime('%b %d')})", 
         color='blue', linewidth=2.5, markersize=5)
# Vertical marker at the last historical timestamp.
ax1.axvline(x=last_time, color='red', linestyle='--', alpha=0.8, linewidth=2,
            label='Forecast Start')
ax1.set_xlabel("Time", fontsize=12)
ax1.set_ylabel("Ontario Demand (MW)", fontsize=12)
ax1.set_title("Ontario Energy Demand: Historical + Next Day Forecast (TFT Model)", fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)
ax1.tick_params(axis='x', rotation=45)

# Bottom plot: Forecast only with hour labels.
# Order the points by clock hour before drawing: when the forecast does not
# begin at hour 0 the x values wrap (e.g. 15..23, 0..14) and an unsorted line
# and fill would cut back across the axes. For a forecast starting at hour 0
# the sort leaves the order unchanged.
hourly_view = output_df.sort_values("hour_of_day")
ax2.plot(hourly_view["hour_of_day"], hourly_view["predicted_ontario_demand"], 
         marker='o', color='blue', linewidth=2.5, markersize=6)
ax2.fill_between(hourly_view["hour_of_day"], hourly_view["predicted_ontario_demand"], 
                  alpha=0.3, color='blue')
ax2.set_xlabel("Hour of Day", fontsize=12)
ax2.set_ylabel("Predicted Demand (MW)", fontsize=12)
ax2.set_title(f"24-Hour Forecast by Hour ({future_times[0].strftime('%A, %B %d, %Y')})", 
              fontsize=14, fontweight='bold')
ax2.set_xticks(range(0, 24, 2))
ax2.grid(True, alpha=0.3)

# Add shading for day/night (the 20-22 span is left unshaded by these calls)
ax2.axvspan(0, 6, alpha=0.1, color='gray', label='Night')
ax2.axvspan(22, 24, alpha=0.1, color='gray')
ax2.axvspan(6, 20, alpha=0.1, color='yellow', label='Daytime')
ax2.legend(fontsize=10)

plt.tight_layout()
plt.savefig("/kaggle/working/next_day_demand_forecast.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Visualization saved to /kaggle/working/next_day_demand_forecast.png")

print("\n" + "="*80)
print("FORECAST COMPLETE")
print("="*80)
print(f"Forecasted {len(predictions)} hours starting {future_times[0].strftime('%Y-%m-%d %H:%M')}")
print(f"Output files saved to /kaggle/working/")