# 🌍 SSE Renewables — BDH Climate Risk Pipeline  (Final Version)

## Architecture
```
NASA POWER CSV (hourly, 10 years)
        │
        ▼
[CELL 3] Feature Engineering ──► 84 features (wind, solar, meteo, temporal)
        │
        ▼
[CELL 4] BDH Model Training ───► Trained on 8 years of hourly data
        │                         Saves checkpoint: /content/bdh_checkpoint.pt
        ▼
[CELL 5] RAG Vector Index ──────► SSE PDFs embedded → ChromaDB
        │
        ▼
[CELL 6] CORE PIPELINE LOOP
  Every HOUR:
    └─► BDH inference(window) ──► predicted + actual per feature
    └─► Accumulate: wind stats, power proxy, memory_norm, per-feature MAE
  Every MONTH (end of month):
    └─► Build monthly_summary   (wind stats from BDH actuals)
    └─► Build raw_predictions   (predicted vs actual, MAE, bias per feature)
    └─► Build financials        (derived from BDH wind output)
    └─► Package bdh_data_for_llm
    └─► ask_analyst(bdh_data=bdh_data_for_llm)
          ├─► RAG retrieval (SSE docs) ─────────────────────────────┐
          ├─► BDH context built from bdh_data (real values, no N/A) │
          └─► Groq LLM (llama-3.3-70b) ◄────────────────────────────┘
                └─► Structured 6-section monthly report
    └─► JSON record saved (key_numbers + llm_conclusion sections)
        │
        ▼
[CELL 7] REST API (optional)
  POST /ask              → query LLM analyst with any question
  GET  /monthly-reports  → download full 2-year JSON
  GET  /live-state       → current BDH state
        │
        ▼
[CELL 8] Interactive Analyst (CLI menu in Colab)
```
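
The "every hour accumulate, every month flush" control flow shown above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the notebook's actual loop: `iter_monthly_batches` and the toy timestamps are hypothetical, and the real pipeline accumulates BDH outputs rather than bare row indices.

```python
import pandas as pd

def iter_monthly_batches(datetimes):
    """Yield (period, row_indices) once per calendar month, in order.

    Hypothetical helper: rows are accumulated hour by hour and the batch is
    flushed as soon as the next timestamp falls in a different month.
    """
    current, batch = None, []
    for i, ts in enumerate(datetimes):
        period = pd.Timestamp(ts).to_period("M")
        if current is not None and period != current:
            yield current, batch      # month ended: flush the accumulated rows
            batch = []
        current = period
        batch.append(i)
    if batch:
        yield current, batch          # flush the final (possibly partial) month

# 52 hourly timestamps that straddle the January/February boundary.
hours = pd.date_range("2024-01-30 22:00", periods=52, freq=pd.Timedelta(hours=1))
batches = list(iter_monthly_batches(hours))
for period, rows in batches:
    print(period, len(rows))
```

In the pipeline, the point where a batch is yielded corresponds to the step where the monthly summary is built and `ask_analyst` is called.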

## Execution Order  ← run Cells 1–6 in order; Cells 7–8 are optional
```
Cell 1 → Install dependencies
Cell 2 → Global config & API key
Cell 3 → Feature engineering (NASA CSV → 84 features)
Cell 4 → Train BDH model (8 years, 10 epochs)
Cell 5 → Upload SSE PDFs → build RAG index
Cell 6 → Run full 2-year pipeline → monthly JSON output  ← CORE
Cell 7 → (Optional) Launch REST API
Cell 8 → (Optional) Interactive CLI analyst
```

## Key Fixes in This Version
- ✅ BDH output (all 84 features) accumulated per-hour and passed directly into LLM
- ✅ Per-feature predicted vs actual means, MAE, bias sent to LLM context
- ✅ LLM grounding instruction: every number must trace back to BDH output
- ✅ JSON output: key_numbers + LLM narrative split into 6 structured sections
- ✅ ask_analyst() accepts bdh_data= directly (no more N/A from LIVE_STATE mismatch)
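
As a rough illustration of the per-feature statistics mentioned above (predicted vs actual means, MAE, bias), the sketch below computes them from two arrays of hourly values. The function name `per_feature_errors`, the dictionary keys, and the sign convention (bias = mean of predicted minus actual) are assumptions made for this example; the notebook's own record layout may differ.

```python
import numpy as np

def per_feature_errors(predicted, actual, feature_names):
    """Return {feature: stats} with mean predicted, mean actual, MAE and bias.

    predicted, actual: array-likes of shape (n_hours, n_features).
    Bias is mean(predicted - actual), so a positive value means over-prediction.
    """
    predicted = np.asarray(predicted, dtype=float)
    actual = np.asarray(actual, dtype=float)
    err = predicted - actual
    return {
        name: {
            "pred_mean": float(predicted[:, j].mean()),
            "actual_mean": float(actual[:, j].mean()),
            "mae": float(np.abs(err[:, j]).mean()),
            "bias": float(err[:, j].mean()),
        }
        for j, name in enumerate(feature_names)
    }

# Toy month with 3 hours and 2 features (values are made up for illustration).
pred = [[5.0, 1.0], [7.0, 2.0], [6.0, 3.0]]
act = [[6.0, 1.0], [6.0, 2.0], [6.0, 3.0]]
stats = per_feature_errors(pred, act, ["WS50M", "T2M"])
print(stats["WS50M"])
```

Reporting MAE and bias together is useful because errors of opposite sign cancel in the bias but not in the MAE, as the `WS50M` column of the toy data shows.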


In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 1 — Install All Dependencies                          ║
# ║  Run this first. Runtime restart may be required after.     ║
# ╚══════════════════════════════════════════════════════════════╝

!pip install -q pathway
!pip install -q groq
!pip install -q -U langchain langchain-community langchain-text-splitters langchain-huggingface
!pip install -q chromadb
!pip install -q sentence-transformers
!pip install -q pypdf
!pip install -q windpowerlib pvlib

print("✅ All dependencies installed. Proceed to Cell 2.")

In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 2 — Global Config & Constants                         ║
# ║  All settings in one place. Edit here, not in later cells.  ║
# ╚══════════════════════════════════════════════════════════════╝

import os
import getpass
import torch

# ── API Key (secure — never hardcode the actual key) ──────────────────────────
if "GROQ_API_KEY" not in os.environ:
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ_API_KEY: ")

# ── LLM settings ──────────────────────────────────────────────────────────────
LLM_MODEL  = "llama-3.3-70b-versatile"
MAX_TOKENS = 1024

# ── RAG settings ──────────────────────────────────────────────────────────────
EMBEDDING_MODEL  = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE       = 1000
CHUNK_OVERLAP    = 150
RETRIEVAL_TOP_K  = 5
REPORTS_DIR      = "/content/reports"
CHROMA_DIR       = "/content/chroma_db"

# ── BDH / Training settings ───────────────────────────────────────────────────
FEATURE_DIM = 84      # number of input features after engineering
SEQ_LEN     = 32      # sliding-window length (hours)
BATCH_SIZE  = 64      # A100-safe
EPOCHS      = 10
LR          = 3e-4
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"

# ── LLM call frequency: once per month (after all hourly BDH rows processed) ──
# No number needed — LLM fires at the END of each monthly batch automatically

# ── Fallback demo docs if no PDFs are uploaded ────────────────────────────────
from langchain_core.documents import Document
DEMO_DOCS = [
    Document(
        page_content="SSE targets net zero by 2050 with an 80% emissions reduction by 2030. "
                     "The company has committed to investing £18bn in low-carbon infrastructure "
                     "over the next 5 years, focused on wind, solar and electricity networks.",
        metadata={"source": "demo", "source_file": "demo_sse_strategy.txt",
                  "report_year": "2023", "page": 1}
    ),
    Document(
        page_content="Physical climate risks for SSE include increased storm frequency, "
                     "wind variability across UK and Irish assets, and rising sea levels "
                     "affecting coastal infrastructure. These are classified as high-likelihood "
                     "medium-impact risks under the TCFD framework.",
        metadata={"source": "demo", "source_file": "demo_sse_tcfd.txt",
                  "report_year": "2023", "page": 2}
    ),
    Document(
        page_content="SSE ESG metrics include carbon intensity (tCO2e/GWh), renewable capacity "
                     "(GW), total energy generated (TWh), and percentage of capital expenditure "
                     "aligned to EU Taxonomy. Current renewable capacity stands at 4.7 GW.",
        metadata={"source": "demo", "source_file": "demo_sse_esg.txt",
                  "report_year": "2023", "page": 3}
    ),
    Document(
        page_content="Transition risks for SSE include carbon pricing mechanisms, policy changes "
                     "to Contracts for Difference (CfD) strike prices, and changing grid "
                     "balancing requirements as renewable penetration increases in GB and Ireland.",
        metadata={"source": "demo", "source_file": "demo_sse_transition.txt",
                  "report_year": "2023", "page": 4}
    ),
]

print("✅ Config loaded")
print(f"   Device   : {DEVICE}")
print(f"   LLM Model: {LLM_MODEL}")
print(f"   Embeddings: {EMBEDDING_MODEL}")
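
To make the `CHUNK_SIZE` / `CHUNK_OVERLAP` settings concrete, here is a deliberately naive fixed-width chunker. It is a simplification for illustration only: `chunk_text` is a hypothetical helper, and the splitter used in Cell 5 additionally tries to break on paragraph and sentence separators, so its chunk boundaries will generally differ.

```python
def chunk_text(text, chunk_size, chunk_overlap):
    """Naive fixed-width chunking.

    Each chunk starts (chunk_size - chunk_overlap) characters after the
    previous one, so consecutive chunks share chunk_overlap characters.
    """
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap
    # Stop once the remaining text is already covered by the previous overlap.
    last_start = max(len(text) - chunk_overlap, 1)
    return [text[i : i + chunk_size] for i in range(0, last_start, step)]

# 2,500 characters of synthetic text, chunked with the Cell 2 settings.
text = "".join(chr(97 + i % 26) for i in range(2500))
chunks = chunk_text(text, chunk_size=1000, chunk_overlap=150)
print([len(c) for c in chunks])
```

The overlap means a sentence that straddles a chunk boundary still appears whole in at least one chunk, at the cost of indexing some text twice.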

In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 3 — Feature Engineering (NASA Data → 84 Features)     ║
# ║  Loads NASA POWER CSV, cleans it, builds all feature groups  ║
# ╚══════════════════════════════════════════════════════════════╝

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# ── Load NASA POWER data ───────────────────────────────────────────────────────
# Replace path below with your actual file
df = pd.read_csv("/content/nasa_power_2015to2025_hourly_data_ENGLAND.csv")
print("Raw data shape:", df.shape)
print(df.head(3))

# Build datetime from YEAR, MO, DY, HR columns
df["datetime"] = pd.to_datetime(
    df[["YEAR","MO","DY","HR"]].rename(
        columns={"MO":"month","DY":"day","HR":"hour","YEAR":"year"}
    )
)

# Add spatial columns (replace with your actual site coordinates)
df["lat"]       = 51.5    # England site latitude
df["lon"]       = -1.8    # England site longitude
df["elevation"] = 75.0    # metres

df.drop(columns=["YEAR","MO","DY","HR"], inplace=True)

# ── 1. Data Cleaning ───────────────────────────────────────────────────────────
def clean_data(df):
    df = df.copy()
    df["datetime"] = pd.to_datetime(df["datetime"])
    df = df.sort_values("datetime").reset_index(drop=True)
    num_cols = df.select_dtypes(include=np.number).columns.tolist()
    missing = df[num_cols].isnull().sum()
    if missing.any():
        print("Missing values found — interpolating:\n", missing[missing > 0])
    df[num_cols] = df[num_cols].interpolate(method="linear", limit_direction="both")
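    for col in ["T2M", "RH2M", "PS", "WS10M", "WS50M", "ALLSKY_SFC_SW_DWN"]:
        if col in df.columns:
            # Clip extreme outliers: the fences sit 1.5x the 1st-99th percentile
            # spread beyond those percentiles (much wider than a classic IQR rule).
            p01, p99 = df[col].quantile(0.01), df[col].quantile(0.99)
            spread = p99 - p01
            df[col] = df[col].clip(p01 - 1.5*spread, p99 + 1.5*spread)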
    df["RH2M"]              = df["RH2M"].clip(0, 100)
    df["CLOUD_AMT"]         = df["CLOUD_AMT"].clip(0, 100)
    df["ALLSKY_SFC_SW_DWN"] = df["ALLSKY_SFC_SW_DWN"].clip(0)
    df["WS10M"]             = df["WS10M"].clip(0)
    df["WS50M"]             = df["WS50M"].clip(0)
    df["WD10M"]             = df["WD10M"].clip(0, 360)
    df["WD50M"]             = df["WD50M"].clip(0, 360)
    return df

df = clean_data(df)
print("After cleaning:", df.shape)

# ── 2. Temporal Features ───────────────────────────────────────────────────────
def add_temporal_features(df):
    df = df.copy()
    dt = df["datetime"]
    df["hour"]        = dt.dt.hour
    df["day_of_year"] = dt.dt.dayofyear
    df["month"]       = dt.dt.month
    df["weekday"]     = dt.dt.weekday
    df["is_weekend"]  = (df["weekday"] >= 5).astype(int)
    df["hour_sin"]    = np.sin(2*np.pi*df["hour"]/24)
    df["hour_cos"]    = np.cos(2*np.pi*df["hour"]/24)
    df["doy_sin"]     = np.sin(2*np.pi*df["day_of_year"]/365)
    df["doy_cos"]     = np.cos(2*np.pi*df["day_of_year"]/365)
    df["month_sin"]   = np.sin(2*np.pi*df["month"]/12)
    df["month_cos"]   = np.cos(2*np.pi*df["month"]/12)
    def get_season(m):
        if m in [12,1,2]:  return "winter"
        if m in [3,4,5]:   return "spring"
        if m in [6,7,8]:   return "summer"
        return "autumn"
    seasons = df["month"].map(get_season)
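    # dtype=int keeps the dummies numeric: pandas >= 2.0 defaults to bool,
    # which the numeric-only column selection later in this cell would drop.
    df = pd.get_dummies(df, columns=["month"], prefix="month", drop_first=False, dtype=int)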
    for s in ["winter","spring","summer","autumn"]:
        df[f"season_{s}"] = (seasons == s).astype(int)
    for lag in [1,3,6,24]:
        df[f"T2M_lag{lag}h"]               = df["T2M"].shift(lag)
        df[f"WS10M_lag{lag}h"]             = df["WS10M"].shift(lag)
        df[f"ALLSKY_SFC_SW_DWN_lag{lag}h"] = df["ALLSKY_SFC_SW_DWN"].shift(lag)
    for win in [3,6,24]:
        df[f"T2M_roll_mean{win}h"]               = df["T2M"].rolling(win, min_periods=1).mean()
        df[f"T2M_roll_std{win}h"]                = df["T2M"].rolling(win, min_periods=1).std().fillna(0)
        df[f"WS10M_roll_mean{win}h"]             = df["WS10M"].rolling(win, min_periods=1).mean()
        df[f"WS10M_roll_std{win}h"]              = df["WS10M"].rolling(win, min_periods=1).std().fillna(0)
        df[f"ALLSKY_SFC_SW_DWN_roll_mean{win}h"] = df["ALLSKY_SFC_SW_DWN"].rolling(win, min_periods=1).mean()
    return df

df = add_temporal_features(df)
print("After temporal features:", df.shape)

# ── 3. Spatial Features ────────────────────────────────────────────────────────
def add_spatial_features(df):
    df = df.copy()
    df["lat_norm"] = (df["lat"] - df["lat"].mean()) / (df["lat"].std() + 1e-9)
    df["lon_norm"] = (df["lon"] - df["lon"].mean()) / (df["lon"].std() + 1e-9)
    lat_rad = np.radians(df["lat"])
    dec_rad = np.radians(23.45 * np.sin(2*np.pi*(df["day_of_year"]-81)/365))
    df["solar_noon_angle"] = np.degrees(
        np.arcsin(np.sin(lat_rad)*np.sin(dec_rad) + np.cos(lat_rad)*np.cos(dec_rad))
    )
    return df

df = add_spatial_features(df)
print("After spatial features:", df.shape)

# ── 4. Solar Features ──────────────────────────────────────────────────────────
def add_solar_features(df):
    df = df.copy()
    df["date"] = df["datetime"].dt.date
    daily_irr = df.groupby("date")["ALLSKY_SFC_SW_DWN"].sum().rename("daily_irr_sum_Wh")
    df = df.merge(daily_irr, on="date", how="left")
    df["ym"] = df["datetime"].dt.to_period("M")
    monthly_irr = df.groupby("ym")["ALLSKY_SFC_SW_DWN"].mean().rename("monthly_irr_mean")
    df = df.merge(monthly_irr, on="ym", how="left")
    df.drop(columns=["date","ym"], inplace=True)
    CLEAR_SKY_MAX = 1000.0
    df["clearness_index"]  = (df["ALLSKY_SFC_SW_DWN"] / CLEAR_SKY_MAX).clip(0,1)
    lat_rad = np.radians(df["lat"])
    dec_rad = np.radians(23.45 * np.sin(2*np.pi*(df["day_of_year"]-81)/365))
    cos_ha  = (-np.tan(lat_rad) * np.tan(dec_rad)).clip(-1,1)
    df["daylight_hours"]   = (2/15) * np.degrees(np.arccos(cos_ha))
    df["daylight_fraction"]= df["daylight_hours"] / 24
    df["T2M_sq"]           = df["T2M"]**2
    df["T2M_x_cloud"]      = df["T2M"] * df["CLOUD_AMT"]
    df["irr_x_clearness"]  = df["ALLSKY_SFC_SW_DWN"] * df["clearness_index"]
    tilt_rad = lat_rad
    df["tilt_adjusted_irr"] = df["ALLSKY_SFC_SW_DWN"] * (
        np.sin(np.radians(df["solar_noon_angle"]) + tilt_rad) /
        np.maximum(np.sin(np.radians(df["solar_noon_angle"])), 0.01)
    ).clip(0, 2)
    return df

df = add_solar_features(df)
print("After solar features:", df.shape)

# ── 5. Wind Features ───────────────────────────────────────────────────────────
def add_wind_features(df):
    df = df.copy()
    df["wind_power_10m"]        = 0.5 * df["RHOA"] * df["WS10M"]**3
    df["wind_power_50m"]        = 0.5 * df["RHOA"] * df["WS50M"]**3
    df["wind_shear_ratio"]      = df["WS50M"] / (df["WS10M"] + 1e-9)
    df["wind_shear_diff"]       = df["WS50M"] - df["WS10M"]
    df["WD10M_sin"]             = np.sin(np.radians(df["WD10M"]))
    df["WD10M_cos"]             = np.cos(np.radians(df["WD10M"]))
    df["WD50M_sin"]             = np.sin(np.radians(df["WD50M"]))
    df["WD50M_cos"]             = np.cos(np.radians(df["WD50M"]))
    df["WS10M_std_24h"]         = df["WS10M"].rolling(24, min_periods=1).std().fillna(0)
    df["WS10M_above6_frac24"]   = (df["WS10M"] > 6).rolling(24, min_periods=1).mean()
    T  = df["T2M"]
    RH = df["RH2M"].clip(1,100)
    df["dew_point"]             = T - ((100 - RH) / 5.0)
    df["air_density_T_corrected"] = df["RHOA"] * (273.15 / (273.15 + df["T2M"]))
    return df

df = add_wind_features(df)
print("After wind features:", df.shape)

# ── 6. Meteo Features ──────────────────────────────────────────────────────────
def add_meteo_features(df):
    df = df.copy()
    df["PS_trend_1h"]      = df["PS"].diff(1).fillna(0)
    df["PS_trend_3h"]      = df["PS"].diff(3).fillna(0)
    df["QV10M_roll_mean6h"] = df["QV10M"].rolling(6, min_periods=1).mean()
    return df

df = add_meteo_features(df)
print("After meteo features:", df.shape)

# ── 7. Scaling ─────────────────────────────────────────────────────────────────
# NOTE: month_sin / month_cos are the actual column names (not month_x)
EXCLUDE = ["lat","lon","elevation","is_weekend",
           "hour","weekday","day_of_year","month_sin","month_cos"]

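def scale_features(df, exclude_cols=None, fit_rows=None):
    """Standardise numeric columns.

    The scaler is fitted on the first `fit_rows` rows only (all rows if None)
    so that test-period statistics do not leak into training.
    """
    if exclude_cols is None:
        exclude_cols = []
    dt_cols  = df.select_dtypes(include=["datetime64","object","bool"]).columns.tolist()
    skip     = list(set(dt_cols + exclude_cols))
    num_cols = [c for c in df.select_dtypes(include=np.number).columns if c not in skip]
    scaler   = StandardScaler()
    scaler.fit(df[num_cols].iloc[:fit_rows])
    df_scaled = df.copy()
    df_scaled[num_cols] = scaler.transform(df[num_cols])
    return df_scaled, scaler, num_cols

# Fit on the training portion only (same 80% cut as the train/test split below)
df_scaled, scaler, scaled_cols = scale_features(
    df, exclude_cols=EXCLUDE, fit_rows=int(len(df) * 0.8)
)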
print(f"Scaled {len(scaled_cols)} numeric columns.")

# ── 8. Prepare train/test splits for BDH ───────────────────────────────────────
# Drop non-numeric columns (datetime, object) before passing to BDH
feature_cols = df_scaled.select_dtypes(include=np.number).columns.tolist()
df_numeric   = df_scaled[feature_cols].fillna(0)

# Trim to exact FEATURE_DIM columns if needed
if len(feature_cols) > FEATURE_DIM:
    feature_cols = feature_cols[:FEATURE_DIM]
    df_numeric   = df_numeric[feature_cols]
    print(f"⚠️  Trimmed to first {FEATURE_DIM} features")
elif len(feature_cols) < FEATURE_DIM:
    # Pad with zeros if we have fewer features than expected
    pad_cols = [f"pad_{i}" for i in range(FEATURE_DIM - len(feature_cols))]
    for c in pad_cols:
        df_numeric[c] = 0.0
    feature_cols = list(df_numeric.columns)
    print(f"⚠️  Padded to {FEATURE_DIM} features with zeros")

# Chronological 80/20 split (roughly 8 years train / 2 years test)
split_idx  = int(len(df_numeric) * 0.8)
df_train   = df_numeric.iloc[:split_idx].reset_index(drop=True)
df_test    = df_numeric.iloc[split_idx:].reset_index(drop=True)

# Keep datetime index for the test set (used during streaming)
test_datetimes = df["datetime"].iloc[split_idx:].reset_index(drop=True)

print(f"\n✅ Feature engineering complete")
print(f"   Total features : {len(feature_cols)}")
print(f"   Train rows     : {len(df_train):,}")
print(f"   Test rows      : {len(df_test):,}")
print(f"   Feature names  : {feature_cols[:5]} ... (first 5)")
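
Cell 3 encodes hour, day of year and month as sine/cosine pairs. The short sketch below shows the motivation: on the raw scale 23:00 and 00:00 look 23 units apart, while on the unit circle they are as close as any other pair of adjacent hours. `encode_hour` is a hypothetical helper written for this example, not a function from the notebook.

```python
import numpy as np

def encode_hour(hour):
    """Map an hour of day (0-23) to a point on the unit circle."""
    angle = 2 * np.pi * hour / 24
    return np.array([np.sin(angle), np.cos(angle)])

# Distance between 23:00 and 00:00 in the raw and the cyclical representation.
raw_gap = abs(23 - 0)
cyc_gap = np.linalg.norm(encode_hour(23) - encode_hour(0))
adjacent_gap = np.linalg.norm(encode_hour(12) - encode_hour(11))

print(f"raw gap 23h -> 0h       : {raw_gap}")
print(f"cyclical gap 23h -> 0h  : {cyc_gap:.4f}")
print(f"cyclical gap 11h -> 12h : {adjacent_gap:.4f}")
```

Both sine and cosine are needed: either one alone maps two different hours to the same value.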

In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 4 — BDH Model Definition + Training on 8 Years        ║
# ║  MUST complete before Cell 5 or Cell 6.                     ║
# ╚══════════════════════════════════════════════════════════════╝

import math
import numpy as np  # used by train_bdh (also imported in Cell 3)
import dataclasses
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

# ── BDH Config ─────────────────────────────────────────────────────────────────
@dataclasses.dataclass
class BDHConfig:
    n_layer: int = 4
    n_embd:  int = 128
    dropout: float = 0.1
    n_head:  int = 4
    mlp_internal_dim_multiplier: int = 16
    vocab_size: int = 256

# ── Attention (RoPE-based) ─────────────────────────────────────────────────────
def get_freqs(n, theta, dtype):
    def quantize(t, q=2):
        return (t / q).floor() * q
    return (
        1.0 / (theta ** (quantize(torch.arange(0, n, 1, dtype=dtype)) / n))
        / (2 * math.pi)
    )

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        nh = config.n_head
        D  = config.n_embd
        N  = config.mlp_internal_dim_multiplier * D // nh
        self.freqs = torch.nn.Buffer(
            get_freqs(N, theta=2**16, dtype=torch.float32).view(1,1,1,N)
        )

    @staticmethod
    def phases_cos_sin(phases):
        phases = (phases % 1) * (2 * math.pi)
        return torch.cos(phases), torch.sin(phases)

    @staticmethod
    def rope(phases, v):
        v_rot = torch.stack((-v[...,1::2], v[...,::2]), dim=-1).view(*v.size())
        pc, ps = Attention.phases_cos_sin(phases)
        return (v*pc).to(v.dtype) + (v_rot*ps).to(v.dtype)

    def forward(self, Q, K, V):
        assert self.freqs.dtype == torch.float32
        assert K is Q
        _, _, T, _ = Q.size()
        r_phases = (
            torch.arange(0, T, device=self.freqs.device, dtype=self.freqs.dtype)
            .view(1,1,-1,1)
        ) * self.freqs
        QR     = self.rope(r_phases, Q)
        scores = (QR @ QR.mT).tril(diagonal=-1)
        return scores @ V

# ── BDH Model ──────────────────────────────────────────────────────────────────
class BDH(nn.Module):
    def __init__(self, config, input_dim, output_dim):
        super().__init__()
        self.config = config
        nh = config.n_head
        D  = config.n_embd
        N  = D * config.mlp_internal_dim_multiplier // nh
        self.input_proj = nn.Linear(input_dim, D)
        self.decoder    = nn.Parameter(torch.zeros(nh*N, D).normal_(std=0.02))
        self.encoder    = nn.Parameter(torch.zeros(nh, D, N).normal_(std=0.02))
        self.encoder_v  = nn.Parameter(torch.zeros(nh, D, N).normal_(std=0.02))
        self.attn = Attention(config)
        self.ln   = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.drop = nn.Dropout(config.dropout)
        self.head = nn.Linear(D, output_dim)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x, targets=None):
        B, T, _ = x.size()
        D  = self.config.n_embd
        nh = self.config.n_head
        N  = D * self.config.mlp_internal_dim_multiplier // nh
        x = self.input_proj(x)
        x = self.ln(x).unsqueeze(1)
        for _ in range(self.config.n_layer):
            x_res    = x
            x_latent = x @ self.encoder
            x_sparse = F.relu(x_latent)
            yKV      = self.attn(Q=x_sparse, K=x_sparse, V=x)
            yKV      = self.ln(yKV)
            y_latent  = yKV @ self.encoder_v
            y_sparse  = F.relu(y_latent)
            xy_sparse = self.drop(x_sparse * y_sparse)
            yMLP = (
                xy_sparse.transpose(1,2).reshape(B,1,T,N*nh) @ self.decoder
            )
            x = self.ln(x_res + self.ln(yMLP))
        out    = x.view(B, T, D)
        logits = self.head(out)
        loss   = None
        if targets is not None:
            loss = F.mse_loss(logits, targets)
        return logits, loss

# ── Dataset ────────────────────────────────────────────────────────────────────
class TimeSeriesDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data    = torch.tensor(data, dtype=torch.float32)
        self.seq_len = seq_len

    def __len__(self):
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1 : idx + self.seq_len + 1]
        return x, y

# ── Training function ──────────────────────────────────────────────────────────
def train_bdh(model, df_train, epochs=EPOCHS):
    data    = df_train.values.astype(np.float32)
    dataset = TimeSeriesDataset(data, SEQ_LEN)
    loader  = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=2, pin_memory=(DEVICE == "cuda"),  # pinning only helps GPU transfers
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            _, loss = model(x, targets=y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        avg = total_loss / len(loader)
        print(f"  Epoch {epoch+1:02d}/{epochs} | Loss: {avg:.6f}")
    print("✅ Training complete.")
    return model

# ── Inference function ─────────────────────────────────────────────────────────
@torch.no_grad()
def run_inference(model, window):
    """
    window : (SEQ_LEN, FEATURE_DIM) numpy array
    Returns predictions (SEQ_LEN, output_dim) and memory_norm (float)
    """
    model.eval()
    B  = 1
    x  = torch.tensor(window, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    T  = x.size(1)
    D  = model.config.n_embd
    nh = model.config.n_head
    N  = D * model.config.mlp_internal_dim_multiplier // nh
    h  = model.input_proj(x)
    h  = model.ln(h).unsqueeze(1)
    for _ in range(model.config.n_layer):
        h_res    = h
        x_latent = h @ model.encoder
        x_sparse = F.relu(x_latent)
        yKV      = model.attn(Q=x_sparse, K=x_sparse, V=h)
        yKV      = model.ln(yKV)
        y_sparse = F.relu(yKV @ model.encoder_v)
        xy       = model.drop(x_sparse * y_sparse)
        yMLP     = xy.transpose(1,2).reshape(B,1,T,N*nh) @ model.decoder
        h        = model.ln(h_res + model.ln(yMLP))
    out         = h.view(B, T, D)
    logits      = model.head(out)
    memory_norm = float(out[0,-1].norm().item())
    return logits.squeeze(0).cpu().numpy(), memory_norm

# ── Build & train the model ────────────────────────────────────────────────────
print(f"Building BDH model on {DEVICE}...")
bdh_config = BDHConfig()
bdh_model  = BDH(bdh_config, input_dim=FEATURE_DIM, output_dim=FEATURE_DIM).to(DEVICE)
print(f"BDH parameters: {sum(p.numel() for p in bdh_model.parameters()):,}")

print(f"\nTraining BDH on {len(df_train):,} rows ({EPOCHS} epochs)...")
bdh_model = train_bdh(bdh_model, df_train, epochs=EPOCHS)

# Save checkpoint
torch.save(bdh_model.state_dict(), "/content/bdh_checkpoint.pt")
print("\n✅ Checkpoint saved to /content/bdh_checkpoint.pt")
print("   Proceed to Cell 5 to build the RAG index.")
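
The next-step training targets built by `TimeSeriesDataset` can be pictured with a small NumPy sketch. `make_windows` is a hypothetical helper that mirrors the indexing in `__getitem__` (targets are the inputs shifted forward by one hour); it materialises all windows at once, which is only sensible for toy data.

```python
import numpy as np

def make_windows(data, seq_len):
    """Return (X, Y) where Y[i] is X[i] shifted forward by one time step."""
    n = len(data) - seq_len
    if n <= 0:
        raise ValueError("need more than seq_len rows")
    X = np.stack([data[i : i + seq_len] for i in range(n)])
    Y = np.stack([data[i + 1 : i + seq_len + 1] for i in range(n)])
    return X, Y

# 6 time steps, 2 features, window length 4 -> 2 training samples.
toy = np.arange(12, dtype=np.float32).reshape(6, 2)
X, Y = make_windows(toy, seq_len=4)
print(X.shape, Y.shape)
print(X[0, -1], "->", Y[0, -1])   # last input row and the row the model must predict
```

With `N` rows and window length `SEQ_LEN`, this indexing yields `N - SEQ_LEN` samples, which matches `TimeSeriesDataset.__len__`.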

In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 5 — Upload SSE PDFs & Build RAG Vector Index          ║
# ║  MUST run after Cell 4 and before Cell 6.                   ║
# ╚══════════════════════════════════════════════════════════════╝

import os
import re
import time
import shutil
from pathlib import Path

from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from google.colab import files

# ── Step 1: Upload PDFs ────────────────────────────────────────────────────────
os.makedirs(REPORTS_DIR, exist_ok=True)

print("📂 Upload your SSE PDF reports now...")
print("   (annual reports, TCFD reports, ESG disclosures, investor presentations)")
print("   Press Cancel/Skip if you want to use built-in demo documents instead.\n")

try:
    uploaded = files.upload()
    for filename in uploaded.keys():
        src = f"/content/{filename}"
        dst = f"{REPORTS_DIR}/{filename}"
        if os.path.exists(src):
            shutil.move(src, dst)
    report_files = os.listdir(REPORTS_DIR)
    print(f"\n✅ {len(report_files)} file(s) in reports directory:")
    for f in report_files:
        size_mb = os.path.getsize(f"{REPORTS_DIR}/{f}") / 1024 / 1024
        print(f"   📄 {f}  ({size_mb:.1f} MB)")
except Exception:
    print("⚠️  Upload skipped — will use demo documents.")

# ── Step 2: Load documents ─────────────────────────────────────────────────────
def load_documents(reports_dir):
    docs      = []
    pdf_files = list(Path(reports_dir).glob("**/*.pdf"))
    csv_files = list(Path(reports_dir).glob("**/*.csv"))
    all_files = pdf_files + csv_files

    if not all_files:
        print("⚠️  No files found — loading DEMO documents.")
        return DEMO_DOCS

    print(f"📂 Found {len(all_files)} file(s) ({len(pdf_files)} PDFs, {len(csv_files)} CSVs)")
    for file_path in all_files:
        try:
            if file_path.suffix.lower() == ".pdf":
                loader = PyPDFLoader(str(file_path))
            else:
                loader = TextLoader(str(file_path), encoding="utf-8")
            file_docs = loader.load()
            for doc in file_docs:
                doc.metadata["source_file"] = file_path.name
                doc.metadata["file_type"]   = file_path.suffix.lower()
                match = re.search(r"(20\d{2})", file_path.name)
                doc.metadata["report_year"] = match.group(1) if match else "unknown"
            docs.extend(file_docs)
            print(f"   ✅ {file_path.name} — {len(file_docs)} pages")
        except Exception as e:
            print(f"   ⚠️  Could not load {file_path.name}: {e}")
    print(f"\n📄 Total pages loaded: {len(docs)}")
    return docs

# ── Step 3: Build vector store ─────────────────────────────────────────────────
def build_vectorstore(reports_dir=REPORTS_DIR, db_dir=CHROMA_DIR, force_rebuild=False):
    print(f"\n🔧 Loading embedding model: {EMBEDDING_MODEL}")
    print("   (First run downloads ~90MB — takes 1-2 minutes)")
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
    print("   ✅ Embedding model ready")

    if os.path.exists(db_dir) and not force_rebuild:
        print(f"\n📦 Loading existing vector store from {db_dir}")
        vs    = Chroma(persist_directory=db_dir, embedding_function=embeddings)
        count = vs._collection.count()
        print(f"   ✅ Loaded {count} chunks")
        return vs

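    # A forced rebuild must start from an empty directory; otherwise
    # Chroma.from_documents appends to the existing collection and duplicates chunks.
    if force_rebuild and os.path.exists(db_dir):
        shutil.rmtree(db_dir)

    docs     = load_documents(reports_dir)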
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=["\n\n","\n",". "," ",""],
    )
    chunks = splitter.split_documents(docs)
    print(f"   ✅ Created {len(chunks)} chunks")

    print("\n🔢 Embedding and indexing...")
    start = time.time()
    vs = Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        persist_directory=db_dir,
    )
    # ChromaDB >= 0.4 persists automatically — no vs.persist() needed
    elapsed = time.time() - start
    print(f"   ✅ Indexed {len(chunks)} chunks in {elapsed:.1f}s")
    print(f"   📦 Saved to {db_dir}")
    return vs

print("🚀 Building RAG vector index...")
vectorstore = build_vectorstore()
retriever   = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={"k": RETRIEVAL_TOP_K, "fetch_k": 20},
)

print("\n✅ RAG index ready. Proceed to Cell 6 to start the live pipeline.")
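
The retriever above uses `search_type="mmr"` (maximal marginal relevance), which trades relevance to the query against redundancy among the chunks already picked. The sketch below is a generic greedy MMR over toy 2-D vectors, written for illustration only: `mmr_select` and the value `lambda_mult=0.5` are assumptions of this example and are not taken from the vector store's internals.

```python
import numpy as np

def mmr_select(query_vec, doc_vecs, k=2, lambda_mult=0.5):
    """Greedy maximal marginal relevance on cosine similarity.

    Returns the indices of the selected documents in selection order.
    """
    q = query_vec / np.linalg.norm(query_vec)
    D = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    relevance = D @ q                           # cosine similarity to the query
    selected = [int(np.argmax(relevance))]      # start with the most relevant doc
    while len(selected) < min(k, len(D)):
        redundancy = (D @ D[selected].T).max(axis=1)   # similarity to picked docs
        score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
        score[selected] = -np.inf               # never pick the same doc twice
        selected.append(int(np.argmax(score)))
    return selected

query = np.array([1.0, 0.0])
docs = np.array([
    [1.0, 0.05],    # 0: very relevant
    [1.0, 0.06],    # 1: near-duplicate of doc 0
    [0.8, -0.6],    # 2: less relevant, but points in a different direction
])
balanced = mmr_select(query, docs, k=2, lambda_mult=0.5)
relevance_only = mmr_select(query, docs, k=2, lambda_mult=1.0)
print(balanced, relevance_only)
```

With `lambda_mult=1.0` the selection degenerates to plain similarity ranking and returns the near-duplicate; lowering it favours the more distinct document.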

In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 6 — LLM Config + Core Pipeline                        ║
# ║  Defines prompts, ask_analyst(), and the BDH stream loop.   ║
# ║  This is the CORE cell that connects everything together:   ║
# ║    BDH output → bdh_data   → ask_analyst() → Groq LLM       ║
# ║                                          ↗                  ║
# ║                              retriever (RAG / SSE docs)     ║
# ╚══════════════════════════════════════════════════════════════╝

import json
from groq import Groq

# ── Groq client ────────────────────────────────────────────────────────────────
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])

# ── System prompt — SSE-specific, financial-aware ────────────────────────────
SYSTEM_PROMPT = """You are a senior climate, energy, and financial risk analyst
embedded within SSE Renewables (UK & Ireland), one of Europe's leading
renewable energy companies.

COMPANY CONTEXT — SSE Renewables:
• Listed on London Stock Exchange (SSE.L), FTSE 100 component
• Owns and operates wind, hydro and solar assets across UK and Ireland
• Flagship assets: Dogger Bank (world's largest offshore wind farm, 3.6GW),
  Seagreen (1.075GW), Viking (443MW onshore), Gordonbush, Bhlaraidh
• Net zero target: 2050 (80% reduction by 2030 vs 2018 baseline)
• Committed £18bn capital investment in low-carbon over 5 years
• Revenue streams: Contracts for Difference (CfD), Renewable Obligation
  Certificates (ROC), merchant power, capacity market payments
• Turbine fleet: primarily Siemens Gamesa SG 14-236 DD (offshore),
  Enercon E-126 / Vestas V136 (onshore)
• Regulatory exposure: Ofgem, NESO, UK CCC climate targets

FINANCIAL PARAMETERS (use these for all calculations):
• Nominal turbine capacity    : 4.2 MW (E-126 class, onshore reference)
• Assumed fleet size          : 50 turbines (210 MW total installed capacity)
• CfD strike price            : £98/MWh (2023 AR5 reference)
• Merchant power price        : £85/MWh (UK day-ahead average reference)
• Annual O&M cost             : £120,000 per turbine (£6M fleet total)
• Availability factor         : 97% (industry standard onshore)
• Transmission loss factor    : 2%
• Carbon intensity avoided    : 0.233 tCO2e/MWh (UK grid average)

BDH MODEL RULES:
• Wind speed is extrapolated to 135m hub height via log-law
• Power output is from physical turbine modelling — do NOT recalculate
• memory_norm = BDH latent stability (higher = more predictable conditions)
• All BDH numerical values are authoritative

Your role IS to:
• Write a professional monthly report with general weather/wind summary
• Calculate and present key financial metrics using the parameters above
• Assess physical and transition climate risks aligned with TCFD
• Reference SSE corporate strategy and targets from the RAG documents
• Provide actionable recommendations specific to SSE operations

NEVER invent data, recalculate BDH physics, or speculate beyond provided values.
Always show your financial calculations step by step.
"""

# ── Task prompts (ALL 5 modes defined) ────────────────────────────────────────
TASK_PROMPTS = {
    "qa": """Answer the question clearly and concisely using the SSE report context
and live BDH wind data provided. Cite specific sources where possible.""",

    "risk_analysis": """You are writing the MONTHLY OPERATIONAL REPORT for SSE Renewables.
Structure your response EXACTLY as follows:

## 1. MONTHLY WEATHER & WIND SUMMARY
- Summarise the month's wind conditions in plain English
- Comment on seasonal norms for UK/Ireland and how this month compares
- Note any notable weather events (storms, calm periods, high variability)
- Interpret BDH memory_norm: what does the stability score mean for this month

## 2. ENERGY GENERATION ESTIMATE
Using the provided avg wind speed and the fleet parameters below, calculate:
- Estimated capacity factor (%) = (avg power output / rated capacity) × 100
- Estimated monthly energy (MWh) = capacity factor × total capacity (MW) × hours
- Show your calculation steps clearly
- Compare to SSE's typical annual capacity factor targets (~35-40% onshore)

## 3. FINANCIAL PERFORMANCE ESTIMATE
Using CfD strike price £98/MWh and fleet parameters, calculate:
- Estimated monthly revenue (£) = MWh generated × £98
- Estimated monthly O&M cost (£) = annual O&M / 12
- Estimated monthly gross profit (£) = revenue − O&M
- Lost revenue from low-wind hours = low-wind hours × fleet capacity × £98
- Carbon avoided (tCO2e) = MWh × 0.233
- Show all calculations step by step

## 4. PHYSICAL CLIMATE RISK ASSESSMENT (TCFD)
- Assess risks from this month's BDH data: low-wind exposure, storm risk
- Reference SSE's TCFD disclosures from the provided documents
- Risk rating: HIGH / MEDIUM / LOW with specific justification from data

## 5. SSE STRATEGIC ALIGNMENT
- How does this month's performance align with SSE's net zero 2050 target?
- Reference specific SSE commitments from the RAG documents
- Comment on CfD exposure vs merchant price risk

## 6. RECOMMENDED ACTIONS
- 3 specific, actionable recommendations for SSE operations team
- Prioritise by financial impact

Use professional financial reporting language throughout.
Show ALL numerical calculations explicitly.""",

    "recommendation": """Based on the BDH predictions and SSE report context,
provide 3-5 actionable recommendations. Prioritise by impact and feasibility.
Reference specific SSE strategic targets where relevant.""",

    "scenario": """Analyse the climate scenario implications using the BDH data
and SSE report context. Consider 1.5°C, 2°C, and 3°C+ warming pathways.
Focus on wind resource changes, operational risks, and portfolio resilience.""",

    "esg": """Provide a structured ESG/TCFD analysis using the SSE report context
and live BDH wind data:
1. **Governance** — oversight structures for climate risk
2. **Strategy** — climate risk integration into business strategy
3. **Risk Management** — identification and management processes
4. **Metrics & Targets** — KPIs, net zero commitments, progress""",
}
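
# ── Illustrative sketch (not part of the original pipeline) ───────────────────
# The formulas that the "risk_analysis" prompt asks the LLM to show in sections
# 2 and 3 can be checked by hand. The helper name `example_monthly_financials`
# and its arguments are hypothetical; the default constants mirror the
# FINANCIAL PARAMETERS listed in SYSTEM_PROMPT (210 MW fleet, £98/MWh CfD,
# £6M annual fleet O&M, 0.233 tCO2e/MWh).
def example_monthly_financials(capacity_factor, hours, low_wind_hours,
                               total_capacity_mw=210.0, cfd_price=98.0,
                               om_annual_fleet=6_000_000.0, carbon_factor=0.233):
    """Return the section 2-3 figures for one month as a dict."""
    energy_mwh = capacity_factor * total_capacity_mw * hours   # section 2
    revenue    = energy_mwh * cfd_price                        # section 3
    om_monthly = om_annual_fleet / 12
    return {
        "energy_mwh":           energy_mwh,
        "revenue_gbp":          revenue,
        "om_monthly_gbp":       om_monthly,
        "gross_profit_gbp":     revenue - om_monthly,
        # Upper bound: assumes the fleet would otherwise run at full capacity
        "lost_revenue_gbp":     low_wind_hours * total_capacity_mw * cfd_price,
        "carbon_avoided_tco2e": energy_mwh * carbon_factor,
    }

# Worked example: a 30% capacity factor over 720 hours gives
# 0.30 × 210 MW × 720 h = 45,360 MWh, and 45,360 × £98 = £4,445,280 CfD revenue.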

# ── Global Live State — updated by the BDH stream loop ────────────────────────
# This dict is the bridge between BDH outputs and the LLM
LIVE_STATE = {
    "hour":            0,
    "timestamp":       None,
    "features":        {},      # {feature_name: {predicted, actual, error}}
    "memory_norm":     None,    # BDH latent stability proxy
    "wind_metrics":    {},      # rolling wind generation summary
    "recent_errors":   [],      # last N prediction errors
}
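
# ── Illustrative sketch (an assumption, not original code) ────────────────────
# LIVE_STATE is mutated in place every hour by the stream loop, so a reader
# such as the REST API or the CLI may prefer a detached copy that does not
# change while it is being serialised. `snapshot_state` is a hypothetical
# helper name; it only relies on the standard-library `copy` module.
import copy

def snapshot_state(state):
    """Return a deep copy so later in-place updates do not leak into the caller."""
    return copy.deepcopy(state)

# Possible usage: frozen = snapshot_state(LIVE_STATE)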

# ── Core analyst function — receives BDH state + queries RAG ──────────────────
def ask_analyst(question, task="qa", temperature=0.3, chat_history=None,
                bdh_data=None):
    """
    Unified analyst entry point. Combines:
      1. bdh_data    — direct BDH monthly output (predictions, actuals, stats)
                       passed explicitly from stream_and_infer(); falls back to
                       LIVE_STATE when called interactively (Cell 8 / API).
      2. retriever   — RAG over SSE corporate documents
    and sends both to the Groq LLM for analysis.

    Returns (full_response, sources): the complete LLM text and the list of
    RAG source dicts (file, year, page, excerpt).
    """

    # 1️⃣ Retrieve relevant SSE report chunks via RAG
    docs          = retriever.invoke(question)
    context_parts = []
    sources       = []
    for i, doc in enumerate(docs, 1):
        meta = doc.metadata
        src  = meta.get("source_file", meta.get("source", "Unknown"))
        page = meta.get("page", "?")
        year = meta.get("report_year", "?")
        context_parts.append(
            f"[Excerpt {i} | {src} | Year: {year} | Page: {page}]\n"
            f"{doc.page_content.strip()}"
        )
        sources.append({"file": src, "year": year, "page": page,
                        "excerpt": doc.page_content[:200]})
    report_context = "\n\n".join(context_parts)

    # 2️⃣ Resolve BDH data source
    # Priority: explicit bdh_data arg (from stream loop) > LIVE_STATE (interactive)
    if bdh_data is not None:
        # Direct pass-in from stream_and_infer() — all fields guaranteed present
        ms  = bdh_data["monthly_summary"]    # full monthly stats dict
        fin = bdh_data["financials"]          # pre-computed financials dict
        raw_preds = bdh_data.get("raw_predictions", {})   # per-feature monthly means (predicted, actual, MAE, bias)
    else:
        # Fallback for interactive / API use after pipeline has run
        ms  = LIVE_STATE.get("wind_metrics", {})
        fin = {}
        raw_preds = {}

    # Fleet constants (always needed)
    TURBINE_CAPACITY_MW = 4.2
    FLEET_SIZE          = 50
    TOTAL_CAPACITY_MW   = TURBINE_CAPACITY_MW * FLEET_SIZE
    CfD_PRICE           = 98.0
    MERCHANT_PRICE      = 85.0
    OM_ANNUAL_PER_TURB  = 120000
    OM_MONTHLY          = (OM_ANNUAL_PER_TURB * FLEET_SIZE) / 12
    AVAILABILITY        = 0.97
    CARBON_FACTOR       = 0.233

    # Pull values — works whether ms came from bdh_data or LIVE_STATE
    hours      = int(ms.get("total_hours_processed", 720) or 720)   # default: 30-day month
    avg_ws     = float(ms.get("wind_speed_avg_ms",   ms.get("avg_ws_24h", 0)) or 0)
    max_ws     = float(ms.get("wind_speed_max_ms",   0) or 0)
    min_ws     = float(ms.get("wind_speed_min_ms",   0) or 0)
    std_ws     = float(ms.get("wind_speed_std_ms",   0) or 0)
    wp_proxy   = float(ms.get("wind_power_proxy_avg", ms.get("avg_power_proxy_24h", 0)) or 0)
    hi_wind    = int(ms.get("high_wind_hours_gt12ms", ms.get("high_wind_hours_24h", 0)) or 0)
    lo_wind    = int(ms.get("low_wind_hours_lt4ms",   ms.get("low_wind_hours_24h", 0)) or 0)
    calm_frac  = float(ms.get("calm_fraction_pct",   0) or 0)
    mem_avg    = ms.get("memory_norm_avg",  LIVE_STATE.get("memory_norm", "N/A"))
    mem_std    = ms.get("memory_norm_std",  "N/A")
    pred_err   = ms.get("mean_bdh_prediction_error", ms.get("mean_pred_error", "N/A"))
    month_lbl  = ms.get("month", LIVE_STATE.get("timestamp", "N/A"))

    # Financials — use pre-computed if available, else re-derive
    rated_ws        = 12.0
    cf_approx       = min(0.45, 0.45 * (avg_ws / rated_ws) ** 3) * AVAILABILITY
    est_energy_mwh  = fin.get("est_energy_mwh",      round(cf_approx * TOTAL_CAPACITY_MW * hours, 1))
    est_revenue_cfd = fin.get("est_revenue_cfd_gbp",  round(est_energy_mwh * CfD_PRICE, 0))
    est_gross_profit= fin.get("est_gross_profit_gbp", round(est_revenue_cfd - OM_MONTHLY, 0))
    lost_revenue    = fin.get("lost_revenue_gbp",     round(lo_wind * TOTAL_CAPACITY_MW * CfD_PRICE, 0))
    carbon_avoided  = fin.get("carbon_avoided_tco2e", round(est_energy_mwh * CARBON_FACTOR, 1))

    # ── Raw BDH predictions block (actual vs predicted per feature) ────────────
    raw_pred_block = ""
    if raw_preds:
        # Top features most relevant to wind energy — show predicted vs actual
        key_features = [
            "WS50M", "WS10M", "wind_power_50m", "wind_power_10m",
            "T2M", "ALLSKY_SFC_SW_DWN", "RHOA", "RH2M", "PS",
            "wind_shear_ratio", "memory_norm_last_hour"
        ]
        lines = []
        for feat in key_features:
            if feat in raw_preds:
                p = raw_preds[feat]
                lines.append(
                    f"  {feat:<30} predicted={p['predicted_mean']:>8.4f}  "
                    f"actual={p['actual_mean']:>8.4f}  "
                    f"MAE={p['mae']:>8.4f}  "
                    f"bias={p['bias']:>+8.4f}"
                )
        if lines:
            raw_pred_block = (
                "\nBDH RAW PREDICTIONS vs ACTUALS (monthly means, key features):\n"
                + "\n".join(lines)
            )

    bdh_context = f"""
═══════════════════════════════════════════════════════════════
BDH PHYSICS MODEL OUTPUT — {month_lbl}
(All values are the BDH model's own outputs — treat as ground truth)
═══════════════════════════════════════════════════════════════

MODEL PERFORMANCE:
• Hours processed          : {hours}
• Mean prediction error    : {pred_err}   (lower = better BDH accuracy)
• BDH memory norm (avg)    : {mem_avg}    (higher = more stable/predictable regime)
• BDH memory norm (std)    : {mem_std}    (lower = consistent stability)

WIND RESOURCE (BDH actual outputs):
• Average wind speed       : {avg_ws:.3f} m/s
• Maximum wind speed       : {max_ws:.3f} m/s
• Minimum wind speed       : {min_ws:.3f} m/s
• Wind speed std deviation : {std_ws:.3f} m/s   (variability index)
• Wind power proxy (avg)   : {wp_proxy:.3f}       (proportional to wind³)
• High-wind hours >12 m/s  : {hi_wind}            (storm / high-output risk)
• Low-wind  hours  <4 m/s  : {lo_wind}            (cut-in threshold, near-zero output)
• Calm fraction            : {calm_frac:.1f} %     (proportion of month below cut-in)
{raw_pred_block}

═══════════════════════════════════════════════════════════════
FLEET PARAMETERS & PRE-COMPUTED FINANCIALS
(Derived from BDH wind output — use these exact values)
═══════════════════════════════════════════════════════════════
Fleet size           : {FLEET_SIZE} turbines × {TURBINE_CAPACITY_MW} MW = {TOTAL_CAPACITY_MW:.0f} MW total
CfD strike price     : £{CfD_PRICE}/MWh
Merchant power price : £{MERCHANT_PRICE}/MWh
Monthly O&M cost     : £{OM_MONTHLY:,.0f}
Availability factor  : {AVAILABILITY*100:.0f}%
Carbon intensity     : {CARBON_FACTOR} tCO2e/MWh

FINANCIAL OUTPUTS FROM BDH WIND DATA:
• Capacity factor (BDH)    : {cf_approx*100:.2f}%
• Est. energy generated    : {est_energy_mwh:,.1f} MWh
• Est. CfD revenue         : £{est_revenue_cfd:,.0f}
• Monthly O&M cost         : £{OM_MONTHLY:,.0f}
• Est. gross profit        : £{est_gross_profit:,.0f}
• Lost revenue (low-wind)  : £{lost_revenue:,.0f}   ({lo_wind} hrs × {TOTAL_CAPACITY_MW:.0f} MW × £{CfD_PRICE})
• Carbon avoided           : {carbon_avoided:,.1f} tCO2e

INSTRUCTION: Your entire analysis MUST be grounded in the BDH numbers above.
Do not assume or invent wind speeds, energy figures, or financial values.
Every number in your report must trace back to BDH output or the fleet parameters.
"""

    # 3️⃣ Build message list
    task_instruction = TASK_PROMPTS.get(task, TASK_PROMPTS["qa"])
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    if chat_history:
        messages.extend(chat_history)
    messages.append({
        "role": "user",
        "content": f"""
{task_instruction}

{'='*60}
SSE CORPORATE REPORT CONTEXT (RAG — retrieved documents)
{'='*60}
{report_context}

{'='*60}
LIVE BDH MODEL OUTPUT (continuous physics-based stream)
{'='*60}
{bdh_context}

QUESTION: {question}
"""
    })

    # 4️⃣ Call Groq with streaming
    stream        = groq_client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=MAX_TOKENS,
        temperature=temperature,
        stream=True,
    )
    full_response = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            full_response += delta

    return full_response, sources
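
# ── Illustrative sketch (not part of the original pipeline) ───────────────────
# The fallback capacity factor inside ask_analyst() is a cubic scaling of the
# mean wind speed, capped at 0.45 and then derated by availability. The helper
# name `approx_capacity_factor` is hypothetical; its defaults copy the
# constants used above (rated_ws = 12.0, cap = 0.45, AVAILABILITY = 0.97).
def approx_capacity_factor(avg_ws, rated_ws=12.0, cf_cap=0.45, availability=0.97):
    """Cubic capacity-factor approximation, capped at cf_cap before derating."""
    raw = cf_cap * (avg_ws / rated_ws) ** 3
    return min(cf_cap, raw) * availability

# Worked example: at 6 m/s, 0.45 × (6/12)³ = 0.05625, and × 0.97 ≈ 0.0546
# (about 5.5%). At or above 12 m/s the result saturates at 0.45 × 0.97 = 0.4365.
# Note this uses the cube of the monthly mean, not the mean of hourly cubes,
# so it is only a rough proxy for months with highly variable wind.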

# ── Pathway streaming simulation — monthly batches, LLM once per month ─────────
def stream_and_infer(model, df_test, feature_names):
    """
    Pathway streaming simulation:
    - BDH inference runs EVERY HOUR (hourly predictions + LIVE_STATE update)
    - All hourly BDH outputs are accumulated across the month
    - At the END of each month: LLM is called ONCE with the full monthly
      summary (BDH stats) + RAG (SSE docs) → produces one risk analysis per month
    """
    all_data    = df_test.values.astype(np.float32)
    n_total     = len(all_data)
    global_hour = 0

    # Build DatetimeIndex for monthly resampling
    dt_index   = pd.date_range(start=test_datetimes.iloc[0],
                               periods=len(df_test), freq="h")
    df_indexed = df_test.copy()
    df_indexed.index = dt_index

    # Feature index helpers
    ws_idx = feature_names.index("WS50M")         if "WS50M"         in feature_names else 0
    wp_idx = feature_names.index("wind_power_50m") if "wind_power_50m" in feature_names else 0

    print("\n" + "="*65)
    print("Pathway streaming — hourly BDH inference, monthly LLM call")
    print("="*65)

    # "ME" (month-end) requires pandas >= 2.2; older releases use "M" instead
    for month_end, month_df in df_indexed.resample("ME"):
        if month_df.empty:
            continue

        month_label     = month_end.strftime("%Y-%m")
        n_hours         = len(month_df)
        month_start_pos = df_indexed.index.get_indexer(
            [month_df.index[0]], method="nearest"
        )[0]

        print(f"\n[STREAM] Month: {month_label}  ({n_hours} hourly BDH rows)")

        # ── Accumulators for this month ────────────────────────────────────────
        month_ws_actuals   = []   # actual wind speeds each hour
        month_ws_preds     = []   # predicted wind speeds each hour
        month_wp_actuals   = []   # actual wind power proxy each hour
        month_memory_norms = []   # BDH memory_norm each hour
        month_errors       = []   # absolute wind-speed prediction error each hour
        last_features      = {}   # features dict from the last hour of the month
        # Per-feature accumulation for BDH predicted vs actual arrays
        feat_pred_sums     = {}   # sum of predicted values per feature
        feat_act_sums      = {}   # sum of actual values per feature
        feat_abs_err_sums  = {}   # sum of |pred - actual| per feature
        feat_bias_sums     = {}   # sum of (pred - actual) per feature (signed)
        feat_counts        = {}   # number of valid hours per feature

        # ── Hourly BDH loop (NO LLM call here) ────────────────────────────────
        for local_hour in range(n_hours):
            abs_pos = month_start_pos + local_hour
            if abs_pos >= n_total:
                break

            # Build SEQ_LEN sliding window
            win_start = max(0, abs_pos - SEQ_LEN + 1)
            window    = all_data[win_start : abs_pos + 1]
            if len(window) < SEQ_LEN:
                pad    = np.zeros((SEQ_LEN - len(window), window.shape[1]),
                                  dtype=np.float32)
                window = np.vstack([pad, window])
            window = window[-SEQ_LEN:]

            # BDH inference
            predictions, memory_norm = run_inference(model, window)
            pred_last   = predictions[-1]
            actual_last = all_data[abs_pos]

            # Accumulate monthly stats
            month_ws_actuals.append(float(actual_last[ws_idx]))
            month_ws_preds.append(float(pred_last[ws_idx]))
            month_wp_actuals.append(float(actual_last[wp_idx]))
            month_memory_norms.append(memory_norm)
            month_errors.append(abs(float(pred_last[ws_idx] - actual_last[ws_idx])))

            # Build per-feature dict (kept from last hour for LIVE_STATE)
            last_features = {
                name: {
                    "predicted": round(float(pred_last[i]),   4),
                    "actual":    round(float(actual_last[i]), 4),
                    "error":     round(float(pred_last[i] - actual_last[i]), 4),
                }
                for i, name in enumerate(feature_names)
            }

            # Accumulate predicted vs actual per feature across the month
            for i, name in enumerate(feature_names):
                p_val = float(pred_last[i])
                a_val = float(actual_last[i])
                feat_pred_sums[name]    = feat_pred_sums.get(name, 0.0)    + p_val
                feat_act_sums[name]     = feat_act_sums.get(name, 0.0)     + a_val
                feat_abs_err_sums[name] = feat_abs_err_sums.get(name, 0.0) + abs(p_val - a_val)
                feat_bias_sums[name]    = feat_bias_sums.get(name, 0.0)    + (p_val - a_val)
                feat_counts[name]       = feat_counts.get(name, 0)         + 1

            # Update LIVE_STATE every hour so UI stays current
            LIVE_STATE.update({
                "hour":        global_hour,
                "timestamp":   str(month_df.index[local_hour]),
                "features":    last_features,
                "memory_norm": round(memory_norm, 4),
                "recent_errors": month_errors[-24:],
                "wind_metrics": {
                    "avg_ws_24h":          round(float(np.mean(month_ws_actuals[-24:])), 4),
                    "avg_power_proxy_24h": round(float(np.mean(month_wp_actuals[-24:])), 4),
                    "high_wind_hours_24h": int(sum(w > 12 for w in month_ws_actuals[-24:])),
                    "low_wind_hours_24h":  int(sum(w <  4 for w in month_ws_actuals[-24:])),
                    "mean_pred_error":     round(float(np.mean(month_errors)), 6),
                },
            })

            global_hour += 1

        # ── END OF MONTH — build full monthly summary and call LLM ONCE ────────
        ws_arr = np.array(month_ws_actuals)
        wp_arr = np.array(month_wp_actuals)
        mn_arr = np.array(month_memory_norms)
        er_arr = np.array(month_errors)

        monthly_summary = {
            "month":                   month_label,
            "total_hours_processed":   len(month_ws_actuals),
            "wind_speed_avg_ms":       round(float(ws_arr.mean()), 3),
            "wind_speed_max_ms":       round(float(ws_arr.max()),  3),
            "wind_speed_min_ms":       round(float(ws_arr.min()),  3),
            "wind_speed_std_ms":       round(float(ws_arr.std()),  3),
            "wind_power_proxy_avg":    round(float(wp_arr.mean()), 3),
            "high_wind_hours_gt12ms":  int((ws_arr > 12).sum()),
            "low_wind_hours_lt4ms":    int((ws_arr <  4).sum()),
            "calm_fraction_pct":       round(float((ws_arr < 4).mean() * 100), 1),
            "memory_norm_avg":         round(float(mn_arr.mean()), 4),
            "memory_norm_std":         round(float(mn_arr.std()),  4),
            "mean_bdh_prediction_error": round(float(er_arr.mean()), 6),
            "last_hour_features_sample": {
                k: v for k, v in list(last_features.items())[:10]
            },
        }

        # ── Build raw BDH predictions summary (monthly means per feature) ────────
        raw_predictions = {}
        for name in feature_names:
            n = feat_counts.get(name, 0)
            if n > 0:
                raw_predictions[name] = {
                    "predicted_mean": round(feat_pred_sums[name] / n, 6),
                    "actual_mean":    round(feat_act_sums[name]  / n, 6),
                    "mae":            round(feat_abs_err_sums[name] / n, 6),
                    "bias":           round(feat_bias_sums[name]    / n, 6),
                    "n_hours":        n,
                }
        # Also inject memory_norm as a pseudo-feature for the LLM. Despite the key
        # name, the values are monthly means and the "mae" slot carries the std.
        raw_predictions["memory_norm_last_hour"] = {
            "predicted_mean": round(float(mn_arr.mean()), 6),
            "actual_mean":    round(float(mn_arr.mean()), 6),
            "mae":            round(float(mn_arr.std()),  6),
            "bias":           0.0,
            "n_hours":        len(mn_arr),
        }

        # Update LIVE_STATE with full monthly summary before LLM call
        LIVE_STATE["wind_metrics"] = monthly_summary

        # ── Pretty progress bar for the month ────────────────────────────────
        calm_bar_filled = int(monthly_summary['calm_fraction_pct'] / 5)
        calm_bar = '█' * calm_bar_filled + '░' * (20 - calm_bar_filled)

        print(f"\n  ╔{'═'*63}╗")
        print(f"  ║  📅  BDH MONTHLY SUMMARY — {month_label:<34}║")
        print(f"  ╠{'═'*63}╣")
        print(f"  ║  ⏱  Hours processed     : {len(month_ws_actuals):<36}║")
        print(f"  ║  💨  Avg wind speed      : {monthly_summary['wind_speed_avg_ms']:<33} m/s║")
        print(f"  ║  📈  Max wind speed      : {monthly_summary['wind_speed_max_ms']:<33} m/s║")
        print(f"  ║  📉  Min wind speed      : {monthly_summary['wind_speed_min_ms']:<33} m/s║")
        print(f"  ║  〰  Wind variability    : {monthly_summary['wind_speed_std_ms']:<33} m/s║")
        print(f"  ║  ⚡  Wind power proxy    : {monthly_summary['wind_power_proxy_avg']:<36}║")
        print(f"  ║  🔴  High-wind hrs >12ms : {monthly_summary['high_wind_hours_gt12ms']:<36}║")
        print(f"  ║  🔵  Low-wind  hrs <4ms  : {monthly_summary['low_wind_hours_lt4ms']:<36}║")
        print(f"  ║  😶  Calm fraction       : {monthly_summary['calm_fraction_pct']:<33} %  ║")
        print(f"  ║     [{calm_bar}] {monthly_summary['calm_fraction_pct']}%{'':>10}║")
        print(f"  ║  🧠  BDH memory norm     : {monthly_summary['memory_norm_avg']:<36}║")
        print(f"  ║  🎯  Mean predict error  : {monthly_summary['mean_bdh_prediction_error']:<36}║")
        print(f"  ╚{'═'*63}╝")
        print(f"\n  🤖 Sending to LLM (Groq · {LLM_MODEL})...\n")

        # ── Single LLM call for the whole month ───────────────────────────────
        # Financials are computed BEFORE packaging bdh_data_for_llm; otherwise the
        # package would reference values from the previous month (or names that
        # are still undefined on the first month).
        TURBINE_CAPACITY_MW = 4.2
        FLEET_SIZE          = 50
        TOTAL_CAPACITY_MW   = TURBINE_CAPACITY_MW * FLEET_SIZE
        CfD_PRICE           = 98.0
        OM_MONTHLY          = (120000 * FLEET_SIZE) / 12
        AVAILABILITY        = 0.97
        CARBON_FACTOR       = 0.233
        hours_proc          = int(monthly_summary['total_hours_processed'])
        avg_ws              = float(monthly_summary['wind_speed_avg_ms'])
        rated_ws            = 12.0
        # Cubic approximation of capacity factor, capped at 45% before availability
        cf_approx           = min(0.45, 0.45*(avg_ws/rated_ws)**3) * AVAILABILITY
        est_energy_mwh      = round(cf_approx * TOTAL_CAPACITY_MW * hours_proc, 1)
        est_revenue_cfd     = round(est_energy_mwh * CfD_PRICE, 0)
        est_gross_profit    = round(est_revenue_cfd - OM_MONTHLY, 0)
        low_wind_hrs        = int(monthly_summary['low_wind_hours_lt4ms'])
        lost_revenue        = round(low_wind_hrs * TOTAL_CAPACITY_MW * CfD_PRICE, 0)
        carbon_avoided      = round(est_energy_mwh * CARBON_FACTOR, 1)

        # Build the bdh_data package — everything the LLM needs, directly from BDH
        bdh_data_for_llm = {
            "monthly_summary":  monthly_summary,   # full wind stats from BDH actuals
            "raw_predictions":  raw_predictions,   # per-feature predicted vs actual means
            "financials": {                         # pre-computed from BDH wind output
                "fleet_capacity_mw":      TOTAL_CAPACITY_MW,
                "capacity_factor_pct":    round(cf_approx * 100, 2),
                "est_energy_mwh":         est_energy_mwh,
                "est_revenue_cfd_gbp":    est_revenue_cfd,
                "monthly_om_cost_gbp":    OM_MONTHLY,
                "est_gross_profit_gbp":   est_gross_profit,
                "lost_revenue_gbp":       lost_revenue,
                "carbon_avoided_tco2e":   carbon_avoided,
                "cfd_strike_price_gbp":   CfD_PRICE,
            },
        }

        monthly_question = (
            f"Provide a full climate and operational risk analysis for {month_label}. "
            f"Ground every finding in the BDH model output provided — "
            f"wind speeds, predictions, memory norm, and financial figures "
            f"must all come from the BDH data. "
            f"Supplement with SSE corporate strategy from the RAG documents."
        )
        answer, sources = ask_analyst(
            question    = monthly_question,
            task        = "risk_analysis",
            temperature = 0.2,
            bdh_data    = bdh_data_for_llm,   # direct BDH output → LLM
        )

        # ── Save result to JSON record ─────────────────────────────────────────
        # The financial values computed above are reused for the record below.

        # ── Parse LLM answer into structured sections ─────────────────────────
        def parse_llm_sections(text):
            """
            Splits the LLM risk_analysis response into its 6 named sections.
            Returns a dict with section names as keys and the text content as values.
            """
            import re
            section_map = {
                "weather_wind_summary":       r"##\s*1\.",
                "energy_generation":          r"##\s*2\.",
                "financial_performance":      r"##\s*3\.",
                "climate_risk_tcfd":          r"##\s*4\.",
                "strategic_alignment":        r"##\s*5\.",
                "recommended_actions":        r"##\s*6\.",
            }
            keys   = list(section_map.keys())
            pats   = list(section_map.values())
            splits = [re.search(p, text) for p in pats]

            sections = {}
            for i, key in enumerate(keys):
                if splits[i] is None:
                    sections[key] = ""
                    continue
                start = splits[i].start()
                # End = start of the next section that was actually found, or end of string
                later = [m.start() for m in splits[i+1:] if m is not None and m.start() > start]
                end   = min(later) if later else len(text)
                # Strip the heading line itself, keep the body
                body = text[start:end]
                body = re.sub(r"^##\s*\d+\.\s*[^\n]*\n", "", body, count=1)
                sections[key] = body.strip()
            return sections

        def extract_key_numbers(answer_text, fin):
            """
            Pull the most important numbers from the LLM's own text
            (to verify it used the BDH data correctly) plus the pre-computed
            financials so the JSON is self-contained.
            """
            import re

            def find_first_number(pattern, text, default=None):
                m = re.search(pattern, text, re.IGNORECASE)
                if m:
                    try:
                        return float(m.group(1).replace(",", ""))
                    except Exception:
                        return default
                return default

            return {
                # From pre-computed financials (authoritative)
                "capacity_factor_pct":       round(fin["capacity_factor_pct"], 2),
                "est_energy_mwh":            fin["est_energy_mwh"],
                "est_revenue_cfd_gbp":       fin["est_revenue_cfd_gbp"],
                "est_gross_profit_gbp":      fin["est_gross_profit_gbp"],
                "monthly_om_cost_gbp":       fin["monthly_om_cost_gbp"],
                "lost_revenue_low_wind_gbp": fin["lost_revenue_gbp"],
                "carbon_avoided_tco2e":      fin["carbon_avoided_tco2e"],
                # Extracted from LLM text (what the LLM actually stated)
                "llm_stated_capacity_factor_pct": find_first_number(
                    r"capacity factor[^\d]*?([\d]+\.?\d*)\s*%", answer_text),
                "llm_stated_energy_mwh": find_first_number(
                    r"([\d,]+)\s*MWh", answer_text),
                "llm_stated_revenue_gbp": find_first_number(
                    r"£\s*([\d,]+).*?revenue", answer_text),
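                # Store the matched label, not the re.Match object
                # (a Match object is not JSON-serialisable)
                "llm_stated_risk_rating": next(
                    (m.group(1).upper() for m in (
                        re.search(r"Risk rating[:\s]*(HIGH|MEDIUM|LOW)", answer_text, re.IGNORECASE),
                        re.search(r"(HIGH|MEDIUM|LOW).*?risk", answer_text, re.IGNORECASE),
                    ) if m),
                    None,
                ),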
            }

        llm_sections   = parse_llm_sections(answer)
        fin_snapshot   = {
            "fleet_capacity_mw":      TOTAL_CAPACITY_MW,
            "capacity_factor_pct":    round(cf_approx * 100, 2),
            "est_energy_mwh":         est_energy_mwh,
            "est_revenue_cfd_gbp":    est_revenue_cfd,
            "monthly_om_cost_gbp":    OM_MONTHLY,
            "est_gross_profit_gbp":   est_gross_profit,
            "lost_revenue_gbp":       lost_revenue,
            "carbon_avoided_tco2e":   carbon_avoided,
            "cfd_strike_price_gbp":   CfD_PRICE,
        }
        key_numbers    = extract_key_numbers(answer, fin_snapshot)

        # Risk rating — pull from LLM text
        import re as _re
        _risk_match = (
            _re.search(r"Risk rating[:\s]*(HIGH|MEDIUM|LOW)", answer, _re.IGNORECASE) or
            _re.search(r"\*\*(HIGH|MEDIUM|LOW)\*\*", answer, _re.IGNORECASE)
        )
        risk_rating = _risk_match.group(1).upper() if _risk_match else "UNKNOWN"

        monthly_record = {
            # ── Identity ────────────────────────────────────────────────────────
            "month":     month_label,
            "timestamp": str(LIVE_STATE.get("timestamp", "N/A")),
            "bdh_hour":  global_hour,

            # ── Key numbers at a glance ─────────────────────────────────────────
            # Most important figures from both the model and the LLM — designed
            # so you can read the JSON without opening the full analysis text.
            "key_numbers": {
                "wind_speed_avg_ms":         monthly_summary["wind_speed_avg_ms"],
                "wind_speed_max_ms":         monthly_summary["wind_speed_max_ms"],
                "calm_fraction_pct":         monthly_summary["calm_fraction_pct"],
                "high_wind_hours_gt12ms":    monthly_summary["high_wind_hours_gt12ms"],
                "low_wind_hours_lt4ms":      monthly_summary["low_wind_hours_lt4ms"],
                "bdh_memory_norm_avg":       monthly_summary["memory_norm_avg"],
                "capacity_factor_pct":       round(cf_approx * 100, 2),
                "est_energy_mwh":            est_energy_mwh,
                "est_revenue_cfd_gbp":       est_revenue_cfd,
                "est_gross_profit_gbp":      est_gross_profit,
                "lost_revenue_low_wind_gbp": lost_revenue,
                "carbon_avoided_tco2e":      carbon_avoided,
                "overall_risk_rating":       risk_rating,
            },

            # ── LLM conclusion & structured analysis ────────────────────────────
            # Each section is the LLM's own explanation, not just numbers.
            "llm_conclusion": {
                # One-sentence headline drawn from the LLM's weather summary
                "headline": (
                    llm_sections.get("weather_wind_summary", "")
                    .split("\n")[0].strip(" -•*") or "See full analysis below."
                ),
                # The LLM's overall risk verdict
                "overall_risk_rating": risk_rating,

                # Full narrative sections — each is the LLM's own explanation
                "sections": {
                    "1_weather_and_wind_summary":  llm_sections.get("weather_wind_summary", ""),
                    "2_energy_generation":         llm_sections.get("energy_generation", ""),
                    "3_financial_performance":     llm_sections.get("financial_performance", ""),
                    "4_climate_risk_tcfd":         llm_sections.get("climate_risk_tcfd", ""),
                    "5_strategic_alignment":       llm_sections.get("strategic_alignment", ""),
                    "6_recommended_actions":       llm_sections.get("recommended_actions", ""),
                },

                # The complete unmodified LLM response (fallback / audit trail)
                "full_analysis_text": answer,
            },

            # ── BDH physics summary ─────────────────────────────────────────────
            "bdh_summary": {
                "hours_processed":        hours_proc,
                "wind_speed_avg_ms":      monthly_summary["wind_speed_avg_ms"],
                "wind_speed_max_ms":      monthly_summary["wind_speed_max_ms"],
                "wind_speed_min_ms":      monthly_summary["wind_speed_min_ms"],
                "wind_speed_std_ms":      monthly_summary["wind_speed_std_ms"],
                "wind_power_proxy_avg":   monthly_summary["wind_power_proxy_avg"],
                "high_wind_hours_gt12ms": monthly_summary["high_wind_hours_gt12ms"],
                "low_wind_hours_lt4ms":   monthly_summary["low_wind_hours_lt4ms"],
                "calm_fraction_pct":      monthly_summary["calm_fraction_pct"],
                "memory_norm_avg":        monthly_summary["memory_norm_avg"],
                "memory_norm_std":        monthly_summary["memory_norm_std"],
                "mean_prediction_error":  monthly_summary["mean_bdh_prediction_error"],
            },

            # ── Financial snapshot ──────────────────────────────────────────────
            "financials": fin_snapshot,

            # ── RAG sources used ────────────────────────────────────────────────
            "sources": [
                {"file": s["file"], "year": s["year"], "page": s["page"]}
                for s in sources
            ],
        }
        all_monthly_records.append(monthly_record)

        # Save running JSON after every month (so nothing is lost mid-run)
        with open("/content/sse_monthly_analysis.json", "w") as jf:
            json.dump(all_monthly_records, jf, indent=2)

        # ── Print formatted output ─────────────────────────────────────────────
        print(f"  ╔{'═'*63}╗")
        print(f"  ║  🌍  MONTHLY RISK ANALYSIS — {month_label:<33}║")
        print(f"  ║  ⚠️   Mode: Risk Analysis · Model: {LLM_MODEL:<27}║")
        print(f"  ╠{'═'*63}╣")
        print(f"  ║  💰  FINANCIAL SNAPSHOT{'':>39}║")
        print(f"  ╠{'═'*63}╣")
        print(f"  ║  Capacity factor      : {cf_approx*100:>6.1f} %{'':>34}║")
        print(f"  ║  Est. energy generated: {est_energy_mwh:>10,.0f} MWh{'':>28}║")
        print(f"  ║  Est. CfD revenue     : £{est_revenue_cfd:>11,.0f}{'':>28}║")
        print(f"  ║  Monthly O&M cost     : £{OM_MONTHLY:>11,.0f}{'':>28}║")
        print(f"  ║  Est. gross profit    : £{est_gross_profit:>11,.0f}{'':>28}║")
        print(f"  ║  Lost rev (low wind)  : £{lost_revenue:>11,.0f}{'':>28}║")
        print(f"  ║  Carbon avoided       : {carbon_avoided:>10,.1f} tCO2e{'':>26}║")
        print(f"  ╠{'═'*63}╣")
        print(f"  ║  🤖 LLM ANALYSIS{'':>46}║")
        print(f"  ╠{'═'*63}╣")
        print()
        for line in answer.split('\n'):
            print(f"  {line}")
        print()
        print(f"  ╠{'═'*63}╣")
        print(f"  ║  📎  SSE REPORT SOURCES ({len(sources)} chunks retrieved){'':>24}║")
        print(f"  ╠{'═'*63}╣")
        if sources:
            for i, s in enumerate(sources, 1):
                src_line = f"{i}. {s['file']}  (Year: {s['year']}, Page: {s['page']})"
                print(f"  ║  {src_line:<61}║")
        else:
            print(f"  ║  No sources retrieved — using demo documents{'':>17}║")
        print(f"  ╠{'═'*63}╣")
        print(f"  ║  💾  Saved to /content/sse_monthly_analysis.json{'':>13}║")
        print(f"  ║  📡  BDH Hour: {global_hour:<8} Timestamp: {str(LIVE_STATE.get('timestamp','N/A'))[:19]:<19}  ║")
        print(f"  ╚{'═'*63}╝\n")

    # ── Final save + summary ──────────────────────────────────────────────────
    JSON_PATH = "/content/sse_monthly_analysis.json"
    with open(JSON_PATH, "w") as jf:
        json.dump(all_monthly_records, jf, indent=2)

    file_size_kb = os.path.getsize(JSON_PATH) / 1024
    total_energy   = sum(r['financials']['est_energy_mwh']      for r in all_monthly_records)
    total_revenue  = sum(r['financials']['est_revenue_cfd_gbp']  for r in all_monthly_records)
    total_profit   = sum(r['financials']['est_gross_profit_gbp'] for r in all_monthly_records)
    total_carbon   = sum(r['financials']['carbon_avoided_tco2e'] for r in all_monthly_records)
    total_lost_rev = sum(r['financials']['lost_revenue_gbp']     for r in all_monthly_records)

    print(f"  ╔{'═'*63}╗")
    print(f"  ║  ✅  PIPELINE COMPLETE — 2-YEAR SUMMARY{'':>22}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  📊  Months analysed      : {len(all_monthly_records):<35}║")
    print(f"  ║  ⏱   Hours streamed       : {global_hour:<35,}║")
    print(f"  ║  🤖  LLM calls made       : {len(all_monthly_records):<35}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  💰  2-YEAR FINANCIAL TOTALS{'':>34}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  Total energy generated   : {total_energy:>12,.0f} MWh{'':>22}║")
    print(f"  ║  Total CfD revenue        : £{total_revenue:>12,.0f}{'':>22}║")
    print(f"  ║  Total gross profit       : £{total_profit:>12,.0f}{'':>22}║")
    print(f"  ║  Total lost revenue       : £{total_lost_rev:>12,.0f}{'':>22}║")
    print(f"  ║  Total carbon avoided     : {total_carbon:>12,.1f} tCO2e{'':>19}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  💾  JSON saved → {JSON_PATH}{'':>14}║")
    print(f"  ║      {len(all_monthly_records)} monthly records  |  {file_size_kb:.1f} KB{'':>24}║")
    print(f"  ╚{'═'*63}╝\n")

    return all_monthly_records


# ── Run the full pipeline ──────────────────────────────────────────────────────
import os

JSON_OUTPUT_PATH = "/content/sse_monthly_analysis.json"

print("🚀 Starting BDH stream → monthly LLM + RAG pipeline...")
print(f"   BDH    : runs every hour")
print(f"   LLM    : called once per month with full BDH summary + RAG")
print(f"   Output : printed + saved incrementally to {JSON_OUTPUT_PATH}")
print(f"   Test set: {len(df_test):,} hours (~{len(df_test)/8760:.1f} years)")
print()

# Run pipeline — records saved incrementally inside the loop AND returned here
all_monthly_records = stream_and_infer(bdh_model, df_test, feature_names=feature_cols)

# ── Confirm JSON file on disk and trigger download ─────────────────────────────
print(f"\n{'='*60}")
print("💾 JSON FILE STATUS")
print(f"{'='*60}")

if os.path.exists(JSON_OUTPUT_PATH):
    file_size_kb = os.path.getsize(JSON_OUTPUT_PATH) / 1024

    # Verify it is valid JSON
    try:
        with open(JSON_OUTPUT_PATH) as jf:
            saved_records = json.load(jf)
        n = len(saved_records)
        months = [r['month'] for r in saved_records]
        print(f"✅ File exists    : {JSON_OUTPUT_PATH}")
        print(f"✅ File size      : {file_size_kb:.1f} KB")
        print(f"✅ Records inside : {n} monthly analyses")
        print(f"✅ Months covered : {months}")
        print(f"✅ JSON is valid")
    except Exception as e:
        print(f"⚠️  JSON file exists but may be corrupted: {e}")

    # ── Download (Colab) or print path (local) ─────────────────────────────────
    try:
        from google.colab import files
        print(f"\n⬇️  Downloading {JSON_OUTPUT_PATH} ...")
        files.download(JSON_OUTPUT_PATH)
        print("✅ Download started — check your browser downloads folder.")
    except ImportError:
        print(f"\nℹ️  Running outside Colab.")
        print(f"   File is at: {JSON_OUTPUT_PATH}")
        print(f"   Copy it with:  cp {JSON_OUTPUT_PATH} /your/local/path/")
else:
    print(f"❌ JSON file NOT found at {JSON_OUTPUT_PATH}")
    print("   This means stream_and_infer() did not complete a single month.")
    print("   Check for errors above — likely causes:")
    print("   • GROQ_API_KEY missing or invalid")
    print("   • BDH inference error (feature dimension mismatch)")
    print("   • Runtime crashed mid-run (re-run Cell 6)")
    print()
    # Emergency save from in-memory records if any exist
    if all_monthly_records:
        print(f"   ⚠️  Found {len(all_monthly_records)} in-memory records — saving now...")
        with open(JSON_OUTPUT_PATH, "w") as jf:
            json.dump(all_monthly_records, jf, indent=2)
        print(f"   ✅ Emergency save complete: {JSON_OUTPUT_PATH}")
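
Cell 6 rewrites `/content/sse_monthly_analysis.json` after every month so that a crash loses at most the month in progress. If the runtime dies during the write itself, however, a plain `open(..., "w")` can leave a truncated file behind. The block below is a minimal sketch of an atomic variant using only the standard library. `save_records_atomic` is a hypothetical helper, not something the notebook defines.

```python
import json
import os
import tempfile


def save_records_atomic(records, path):
    """Write `records` as JSON to `path` without exposing a partial file.

    Hypothetical helper (not in the notebook): the data goes to a temporary
    file in the same directory, which is then renamed over the target.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as tmp:
            json.dump(records, tmp, indent=2)
            tmp.flush()
            os.fsync(tmp.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_path, path)  # atomic when both paths share a filesystem
    except BaseException:
        # Remove the temp file if anything failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

Under these assumptions, the per-month save inside `stream_and_infer` could call `save_records_atomic(all_monthly_records, "/content/sse_monthly_analysis.json")` in place of the `open`/`json.dump` pair; readers of the file would then see either the previous complete version or the new one.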

In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 7 — Launch REST API (Optional)                        ║
# ║  Run AFTER Cell 6 completes.                               ║
# ║    POST /ask              → query the LLM analyst          ║
# ║    GET  /monthly-reports  → download full 2-year JSON      ║
# ║    GET  /live-state       → current BDH state              ║
# ╚══════════════════════════════════════════════════════════════╝

# ── Step 1: Install ────────────────────────────────────────────────────────────
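import subprocess
import sys
# Use the kernel's own interpreter so the packages land in the running environment.
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "fastapi", "uvicorn",
                "nest_asyncio", "pyngrok"], check=True)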

# ── Step 2: Upload api.py to /content/ if not already there ───────────────────
# api.py should be in the same folder as this notebook.
# If it's not there yet, upload it via the Colab file panel (left sidebar).
import os
if not os.path.exists("/content/api.py"):
    # Stop here: the import in Step 3 would otherwise fail with a less helpful error.
    raise FileNotFoundError(
        "api.py not found at /content/api.py. "
        "Upload it via the Colab file panel (📁 icon), then re-run this cell."
    )
print("✅ api.py found at /content/api.py")

# ── Step 3: Launch API using notebook globals ──────────────────────────────────
import sys
sys.path.insert(0, "/content")

from api import launch_from_notebook

# Passes ask_analyst and LIVE_STATE directly from this kernel
# so /ask uses the real trained model and RAG, not a cold start
launch_from_notebook(
    ask_analyst_fn  = ask_analyst,    # defined in Cell 6
    live_state_dict = LIVE_STATE,     # defined in Cell 6
    port            = 8000,
)
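
Once Cell 7 is running, the three routes listed in its header can be exercised from any Python client. The sketch below uses only the standard library. The route names and port 8000 come from this notebook, but the JSON field names `"question"` and `"task"` are assumptions: the real request schema lives in `api.py`, which is not shown here. `build_ask_request` and `fetch_json` are hypothetical helper names.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # port 8000 matches the launch call in Cell 7


def build_ask_request(question, task="qa", base_url=BASE_URL):
    """Build (but do not send) a POST /ask request.

    ASSUMPTION: the body fields "question" and "task" are guesses;
    check api.py for the actual schema before relying on them.
    """
    body = json.dumps({"question": question, "task": task}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/ask",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def fetch_json(path, base_url=BASE_URL, timeout=30):
    """GET a JSON route such as /live-state or /monthly-reports."""
    with urllib.request.urlopen(f"{base_url}{path}", timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Sending the request is then `urllib.request.urlopen(build_ask_request("What is SSE's net zero target?"))`, and `fetch_json("/live-state")` would return the current BDH state, provided the server is reachable at `BASE_URL`.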


In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 8 — Interactive Analyst (No widgets needed)           ║
# ║  Works in ALL Colab environments using simple input()       ║
# ║  Type your question when prompted, press Enter to submit    ║
# ╚══════════════════════════════════════════════════════════════╝

chat_history = []

QUICK_QUESTIONS = {
    "1": "What are the main physical climate risks?",
    "2": "What is SSE's net zero target?",
    "3": "What transition risks affect SSE?",
    "4": "Summarise TCFD governance disclosures",
    "5": "What ESG metrics does SSE track?",
    "6": "What mitigation strategies are recommended?",
}

TASK_MAP = {
    "1": "qa",
    "2": "risk_analysis",
    "3": "recommendation",
    "4": "scenario",
    "5": "esg",
}

TASK_LABELS = {
    "qa":             "💬 Q&A",
    "risk_analysis":  "⚠️  Risk Analysis",
    "recommendation": "✅ Recommendations",
    "scenario":       "🌡️  Scenario Analysis",
    "esg":            "📊 ESG / TCFD",
}

def print_response(question, answer, sources, task):
    label = TASK_LABELS.get(task, task)
    wm    = LIVE_STATE.get('wind_metrics', {})
    print(f"\n  ╔{'═'*63}╗")
    print(f"  ║  {label:<61}║")
    print(f"  ║  🤖 Model : {LLM_MODEL:<51}║")
    print(f"  ║  📡 BDH   : Hour {str(LIVE_STATE.get('hour','N/A')):<7} | {str(LIVE_STATE.get('timestamp','N/A'))[:19]:<19}  ║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  ❓ QUESTION{'':>51}║")
    print(f"  ╠{'═'*63}╣")
    # Wrap question at 59 chars
    words = question.split()
    line  = ''
    for w in words:
        if line and len(line) + len(w) + 1 > 59:  # skip the flush when no text is buffered yet
            print(f"  ║  {line:<61}║")
            line = w
        else:
            line = (line + ' ' + w).strip()
    if line:
        print(f"  ║  {line:<61}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  💨 Wind Avg: {str(wm.get('wind_speed_avg_ms','N/A')):<6} m/s  "
          f"Low-wind hrs: {str(wm.get('low_wind_hours_lt4ms','N/A')):<5}  "
          f"High-wind hrs: {str(wm.get('high_wind_hours_gt12ms','N/A')):<4}║")
    print(f"  ║  🧠 BDH memory norm: {str(wm.get('memory_norm_avg','N/A')):<10}  "
          f"Predict error: {str(wm.get('mean_bdh_prediction_error','N/A')):<14}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  🤖 ANSWER{'':>52}║")
    print(f"  ╠{'═'*63}╣")
    print()
    for line in answer.split('\n'):
        print(f"  {line}")
    print()
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  📎 SOURCES ({len(sources)} chunks retrieved){'':>38}║")
    print(f"  ╠{'═'*63}╣")
    if sources:
        for i, s in enumerate(sources, 1):
            src_line = f"{i}. {s['file']}  (Year: {s['year']}, Page: {s['page']})"
            print(f"  ║  {src_line:<61}║")
    else:
        print(f"  ║  No sources retrieved{'':>41}║")
    print(f"  ╚{'═'*63}╝\n")

def show_menu():
    wm = LIVE_STATE.get('wind_metrics', {})
    print(f"\n  ╔{'═'*63}╗")
    print(f"  ║  🌍  SSE RENEWABLES — CLIMATE RISK ANALYST{'':>19}║")
    print(f"  ║  BDH Physics + RAG (SSE Docs) + Groq LLM{'':>20}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  📡 BDH Hour : {str(LIVE_STATE.get('hour','N/A')):<48}║")
    print(f"  ║  🕐 Timestamp: {str(LIVE_STATE.get('timestamp','N/A'))[:19]:<48}║")
    print(f"  ║  💨 Avg Wind : {str(wm.get('wind_speed_avg_ms','N/A')):<45} m/s║")
    print(f"  ║  🧠 Mem Norm : {str(wm.get('memory_norm_avg','N/A')):<48}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  ANALYSIS MODES{'':>47}║")
    print(f"  ╠{'═'*63}╣")
    for k, v in TASK_MAP.items():
        print(f"  ║  [{k}] {TASK_LABELS[v]:<57}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  QUICK QUESTIONS{'':>46}║")
    print(f"  ╠{'═'*63}╣")
    for k, v in QUICK_QUESTIONS.items():
        print(f"  ║  [q{k}] {v[:56]:<57}║")
    print(f"  ╠{'═'*63}╣")
    print(f"  ║  [cm] Toggle chat memory ON/OFF{'':>31}║")
    print(f"  ║  [c]  Clear chat history{'':>38}║")
    print(f"  ║  [x]  Exit{'':>52}║")
    print(f"  ╚{'═'*63}╝")

# ── Main interactive loop ──────────────────────────────────────────────────────
current_task = "qa"
chat_mode    = False

print("\n✅ Analyst ready. Starting interactive session...")

while True:
    show_menu()

    print(f"  Current mode : {TASK_LABELS[current_task]}")
    print(f"  Chat mode    : {'ON  (history kept)' if chat_mode else 'OFF (single query)'}")
    print()

    user_input = input("  ➤ Enter mode number, q+number for quick question, or type your question: ").strip()

    # ── Exit ──────────────────────────────────────────────────────────────────
    if user_input.lower() == "x":
        print("\n👋 Exiting analyst. Goodbye!")
        break

    # ── Clear chat ────────────────────────────────────────────────────────────
    elif user_input.lower() == "c":
        chat_history.clear()
        print("\n🗑  Chat history cleared.")
        continue

    # ── Toggle chat mode ──────────────────────────────────────────────────────
    elif user_input.lower() == "cm":
        chat_mode = not chat_mode
        print(f"\n💬 Chat mode {'ON' if chat_mode else 'OFF'}")
        continue

    # ── Select analysis mode ──────────────────────────────────────────────────
    elif user_input in TASK_MAP:
        current_task = TASK_MAP[user_input]
        print(f"\n✅ Mode set to: {TASK_LABELS[current_task]}")
        continue

    # ── Quick question ────────────────────────────────────────────────────────
    elif user_input.lower().startswith("q") and user_input[1:] in QUICK_QUESTIONS:
        question = QUICK_QUESTIONS[user_input[1:]]
        print(f"\n  Using quick question: {question}")

    # ── Custom question ───────────────────────────────────────────────────────
    elif len(user_input) > 3:
        question = user_input

    else:
        print("\n⚠️  Not recognised. Type your question, or use the menu options above.")
        continue

    # ── Call analyst ──────────────────────────────────────────────────────────
    print(f"\n  ⏳ Analysing with mode '{TASK_LABELS[current_task]}'...")
    try:
        history = chat_history if chat_mode else None
        answer, sources = ask_analyst(
            question,
            task        = current_task,
            temperature = 0.3,
            chat_history= history,
        )

        if chat_mode:
            chat_history.append({"role": "user",      "content": question})
            chat_history.append({"role": "assistant", "content": answer})
            if len(chat_history) > 20:
                chat_history[:] = chat_history[-20:]

        print_response(question, answer, sources, current_task)

    except Exception as e:
        print(f"\n❌ Error: {e}")
        print("   Check your GROQ_API_KEY and that Cell 5 (RAG) ran successfully.")
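
The input handling in the loop above checks the commands in a fixed order: exit, clear, chat toggle, mode number, quick question, then free-text question. As a sketch of how that dispatch could be isolated for testing without calling `input()` or the LLM, the same order can be written as a pure function. `parse_command` and its return labels are hypothetical names introduced here for illustration.

```python
def parse_command(user_input, task_map, quick_questions):
    """Classify one line of menu input, mirroring the order used in Cell 8.

    Returns a (kind, value) tuple where kind is one of
    "exit", "clear", "toggle_chat", "set_task", "ask", or "unknown".
    """
    text = user_input.strip()
    lowered = text.lower()
    if lowered == "x":
        return ("exit", None)
    if lowered == "c":
        return ("clear", None)
    if lowered == "cm":
        return ("toggle_chat", None)
    if text in task_map:
        return ("set_task", task_map[text])
    if lowered.startswith("q") and text[1:] in quick_questions:
        return ("ask", quick_questions[text[1:]])
    if len(text) > 3:
        # Anything longer than three characters is treated as a custom question.
        return ("ask", text)
    return ("unknown", None)
```

With `TASK_MAP` and `QUICK_QUESTIONS` as defined in Cell 8, `parse_command("2", TASK_MAP, QUICK_QUESTIONS)` yields `("set_task", "risk_analysis")`, and short unrecognised input such as `"hi"` yields `("unknown", None)`.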