In [None]:
# ==============================================================================
#  STEP 1: V_mpp Predictor Artifact Generator
# ==============================================================================
#  This script's ONLY purpose is to train the best V_mpp predictor and the
#  associated feature transformer, then save them to disk for use by gold.py.
#  RUN THIS SCRIPT ONCE to create the necessary .joblib files.
# ==============================================================================
import numpy as np
import pandas as pd
import joblib
from scipy.interpolate import PchipInterpolator
from sklearn.preprocessing import RobustScaler, FunctionTransformer, MinMaxScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
import lightgbm as lgb
from tqdm.auto import tqdm
from pathlib import Path

print("--- Starting V_mpp Predictor Artifact Generation ---")

# --- Configuration ---
# Raw simulation data (presumably a Google Drive mount inside Colab).
DRIVE_PATH = "/content/drive/MyDrive/Colab Notebooks/Data_100k"
INPUT_FILE_IV = Path(DRIVE_PATH, "iV_m.txt")
INPUT_FILE_PARAMS = Path(DRIVE_PATH, "LHS_parameters_m.txt")

# --- Output artifacts (loaded later by gold.py / the training notebook) ---
OUTPUT_DIR = Path("./")  # current working directory
PREDICTOR_MODEL_PATH = OUTPUT_DIR.joinpath("v_mpp_predictor.joblib")
FEATURE_TRANSFORMER_PATH = OUTPUT_DIR.joinpath("feature_transformer.joblib")

# Column names of the header-less device-parameter file, in file order.
COLNAMES = [
    'lH','lP','lE', 'muHh','muPh','muPe','muEe','NvH','NcH','NvE','NcE','NvP','NcP',
    'chiHh','chiHe','chiPh','chiPe','chiEh','chiEe', 'Wlm','Whm', 'epsH','epsP','epsE',
    'Gavg','Aug','Brad','Taue','Tauh','vII','vIII'
]
params_df_raw = pd.read_csv(INPUT_FILE_PARAMS, header=None, names=COLNAMES)
# --- Preprocessing Functions ---
def get_param_transformer(colnames: list[str]) -> ColumnTransformer:
    """Build an unfitted ColumnTransformer that scales device parameters by group.

    Every group is scaled with RobustScaler followed by MinMaxScaler to [-1, 1];
    the material-property group is log1p-compressed first. The log1p step
    declares ``feature_names_out='one-to-one'`` so that
    ``get_feature_names_out()`` works on the fitted transformer. Groups with no
    column present in ``colnames`` are omitted, and any column not listed in a
    group is passed through unchanged (``remainder='passthrough'``).
    """
    group_columns = (
        ('layer_thickness', ['lH', 'lP', 'lE']),
        ('material_properties', ['muHh', 'muPh', 'muPe', 'muEe', 'NvH', 'NcH', 'NvE', 'NcE', 'NvP', 'NcP',
                                 'chiHh', 'chiHe', 'chiPh', 'chiPe', 'chiEh', 'chiEe', 'epsH', 'epsP', 'epsE',
                                 'Gavg']),
        ('contacts', ['Wlm', 'Whm']),
        ('recombination_gen', ['Aug', 'Brad', 'Taue', 'Tauh', 'vII', 'vIII']),
    )

    def build_pipeline(group_name: str) -> Pipeline:
        # Same step names and order as before: [log1p,] robust, minmax.
        stages = []
        if group_name == 'material_properties':
            stages.append(('log1p', FunctionTransformer(func=np.log1p, feature_names_out='one-to-one')))
        stages.append(('robust', RobustScaler()))
        stages.append(('minmax', MinMaxScaler(feature_range=(-1, 1))))
        return Pipeline(stages)

    transformers = []
    for group_name, wanted in group_columns:
        present = [name for name in wanted if name in colnames]
        if present:
            transformers.append((group_name, build_pipeline(group_name), present))
    return ColumnTransformer(transformers, remainder='passthrough')

def calculate_pv_params(voltage_grid: np.ndarray, current_curve: np.ndarray) -> dict:
    try:
        if np.count_nonzero(~np.isnan(current_curve)) < 4: return None
        isc = current_curve[0]
        if isc <= 1e-9: return None
        interpolator = PchipInterpolator(voltage_grid, current_curve, extrapolate=False)
        v_fine = np.linspace(voltage_grid[0], voltage_grid[-1], 2000)
        i_fine = interpolator(v_fine)
        zero_cross_indices = np.where(i_fine <= 0)[0]
        if len(zero_cross_indices) == 0: return None
        voc = v_fine[zero_cross_indices[0]]
        search_mask = v_fine <= voc
        v_search, i_search = v_fine[search_mask], i_fine[search_mask]
        if len(v_search) == 0: return None
        power = v_search * i_search
        mpp_index = np.argmax(power)
        v_mpp, i_mpp = v_search[mpp_index], i_search[mpp_index]
        if voc * isc < 1e-9: return None
        fill_factor = (v_mpp * i_mpp) / (voc * isc)
        return {"Isc": isc, "Voc": voc, "Vmpp": v_mpp, "FF": fill_factor}
    except (ValueError, IndexError):
        return None

# --- Data Preparation ---
print("\n--- Preparing Full Dataset for Predictor Training ---")
# Voltages at which every raw I-V curve is sampled: 0.1 V steps up to 0.4 V,
# then 0.025 V steps up to 1.4 V (the +1e-8 keeps each stop value inclusive).
full_v_grid = np.concatenate([np.arange(0, 0.4 + 1e-8, 0.1), np.arange(0.425, 1.4 + 1e-8, 0.025)]).astype(np.float32)
iv_data_raw = np.loadtxt(INPUT_FILE_IV, delimiter=',')
# I-V rows and parameter rows are paired purely by position below; flag a mismatch.
if len(iv_data_raw) != len(params_df_raw):
    print(f"WARNING: {len(iv_data_raw)} I-V curves but {len(params_df_raw)} parameter rows; rows are paired by position.")

results, valid_indices = [], []
for i, curve in enumerate(tqdm(iv_data_raw, desc="Processing Curves")):
    params = calculate_pv_params(full_v_grid, curve)
    if params:
        results.append(params)
        valid_indices.append(i)
if not results:
    raise RuntimeError("No valid I-V curves found; cannot train the V_mpp predictor.")

iv_summary_df = pd.DataFrame(results)
params_df_aligned = params_df_raw.iloc[valid_indices].reset_index(drop=True)

# --- Feature Engineering & Assembly ---
# Hybrid design matrix: scaled device parameters followed by curve summary features.
iv_summary_df['Voc_x_FF'] = iv_summary_df['Voc'] * iv_summary_df['FF']
y_target = iv_summary_df['Vmpp']
iv_features = iv_summary_df[['Isc', 'Voc', 'FF', 'Voc_x_FF']]
param_transformer = get_param_transformer(COLNAMES)
# Columns are named by get_feature_names_out() and appear in transformer-group order
# (not COLNAMES order); anything that rebuilds this matrix must use the same names/order.
params_scaled = pd.DataFrame(param_transformer.fit_transform(params_df_aligned), columns=param_transformer.get_feature_names_out())
X_hybrid = pd.concat([params_scaled, iv_features.reset_index(drop=True)], axis=1)

# --- Train the Final Model ---
print("\n--- Training the Definitive V_mpp Predictor ---")
# Use the best hyperparameters found from the Bayesian search
best_params = {
    'n_estimators': 634, 'num_leaves': 100, 'max_depth': 11,
    'learning_rate': 0.0407, 'subsample': 1.0, 'colsample_bytree': 1.0,
    'min_child_samples': 20, 'reg_alpha': 1e-09, 'reg_lambda': 1e-09
}
# NOTE(review): fitted on every valid curve (no hold-out), so any later split of
# the same curves will see in-sample V_mpp predictions.
model = lgb.LGBMRegressor(random_state=42, n_jobs=-1, verbose=-1, **best_params)
model.fit(X_hybrid, y_target)
print("Model training complete.")

# --- Save Artifacts to Disk ---
joblib.dump(model, PREDICTOR_MODEL_PATH)
joblib.dump(param_transformer, FEATURE_TRANSFORMER_PATH)
print("\nSUCCESS! Artifacts saved:")
print(f"  - Predictor Model: {PREDICTOR_MODEL_PATH}")
print(f"  - Feature Transformer: {FEATURE_TRANSFORMER_PATH}")

--- Starting V_mpp Predictor Artifact Generation ---

--- Preparing Full Dataset for Predictor Training ---


Processing Curves:   0%|          | 0/66026 [00:00<?, ?it/s]


--- Training the Definitive V_mpp Predictor ---
Model training complete.

SUCCESS! Artifacts saved:
  - Predictor Model: v_mpp_predictor.joblib
  - Feature Transformer: feature_transformer.joblib


In [None]:
# ==============================================================================
#  Golden Child v2.4 - Dilated TCN Architecture
# ==============================================================================
#
#  This version replaces the previous TCN+Attention hybrid model with a pure
#  Temporal Convolutional Network (TCN) using dilated convolutions.
#
#  --- KEY ARCHITECTURE CHANGES (v2.4) ---
#   1. Attention Block Removed: The `SelfAttentionBlock` has been completely
#      removed from the model.
#   2. Dilated Convolutions: The `TemporalBlock` now accepts a `dilation`
#      parameter, which is passed to its internal Conv1d layers.
#   3. TCN Stack: The `PhysicsIVSystem` now builds a sequential stack of
#      TemporalBlocks with exponentially increasing dilation rates (e.g., 1, 2,
#      4, 8), allowing for a large receptive field efficiently.
#
# ==============================================================================
!pip install pytorch_lightning -q
import os
import logging
from pathlib import Path
from datetime import datetime
import math
import typing

import numpy as np
import pandas as pd
import joblib
import lightgbm as lgb

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.utils.data import Dataset, DataLoader
from scipy.interpolate import PchipInterpolator
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, FunctionTransformer, MinMaxScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import RichProgressBar
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from PIL import Image


# ──────────────────────────────────────────────────────────────────────────────
#   CONFIGURATIONS & CONSTANTS
# ──────────────────────────────────────────────────────────────────────────────

# Pre-trained V_mpp predictor artifacts written by the "STEP 1: V_mpp Predictor
# Artifact Generator" cell.
VMPP_PREDICTOR_CONFIG = {
    "model_path": "./v_mpp_predictor.joblib",
    "feature_transformer_path": "./feature_transformer.joblib"
}

INPUT_FILE_PARAMS = "/content/drive/MyDrive/Colab Notebooks/Data_100k/LHS_parameters_m.txt"
INPUT_FILE_IV = "/content/drive/MyDrive/Colab Notebooks/Data_100k/iV_m.txt"
OUTPUT_DIR = Path("./lightning_output")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
Path("./data/processed").mkdir(parents=True, exist_ok=True)

CONFIG = {
    "train": {
        "seed": 42,
        "run_name": f"DilatedTCN-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    },
    "model": {
        # 31 device parameters + 2 scalar features (I_ref, Predicted_V_mpp);
        # overwritten with the measured width when preprocessing runs.
        "param_dim": 33,
        "dense_units": [256, 128, 128],
        "filters": [128], # Defines the channel count for the TCN stack
        "kernel": 5,
        "dilations": [1, 2, 4, 8], # Dilation rates for the TCN stack
        "dropout": 0.0363, "embedding_type": 'gaussian', "gaussian_bands": 18,
        "gaussian_sigma": 0.0775,
        "loss_weights": {
            "mse": 0.98, "mono": 0.005, "excurv": 0.01, "excess_threshold": 0.8,
        },
    },
    "optimizer": {
        "lr": 0.0055, "weight_decay": 5.4e-05,
        "final_lr_ratio": 0.0066, "warmup_epochs": 7,
    },
    "dataset": {
        "paths": {
            "params_csv": INPUT_FILE_PARAMS, "iv_raw_txt": INPUT_FILE_IV,
            "output_dir": "./data/processed", "preprocessed_npz": "./data/processed/preprocessed_data.npz",
            "param_transformer": "./data/processed/param_transformer.joblib",
            "scalar_transformer": "./data/processed/scalar_transformer.joblib",
            "v_fine_memmap": "./data/processed/v_fine_curves.mmap",
            "i_fine_memmap": "./data/processed/i_fine_curves.mmap",
        },
        "pchip": {
            "v_max": 1.4, "n_fine": 2000, "n_pre_mpp": 3, "n_post_mpp": 4, "seq_len": 8,
        },
        "dataloader": {
            # os.cpu_count() returns None when the count cannot be determined;
            # fall back to 0 workers (main-process loading) instead of a TypeError.
            "batch_size": 128, "num_workers": (os.cpu_count() or 1) // 2, "pin_memory": True,
        },
        "curvature_weighting": {"alpha": 4.0, "power": 1.5,},
    },
    "trainer": {
        "max_epochs": 12, "accelerator": "auto", "devices": "auto", "precision": "32-true",
        "gradient_clip_val": 1.0, "log_every_n_steps": 25,
    },
}

# Column names of the header-less device-parameter file, in file order.
COLNAMES = [
    'lH','lP','lE', 'muHh','muPh','muPe','muEe','NvH','NcH','NvE','NcE','NvP','NcP',
    'chiHh','chiHe','chiPh','chiPe','chiEh','chiEe', 'Wlm','Whm', 'epsH','epsP','epsE',
    'Gavg','Aug','Brad','Taue','Tauh','vII','vIII'
]
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger(__name__)

# ──────────────────────────────────────────────────────────────────────────────
#   UTILITY & PREPROCESSING FUNCTIONS
# ──────────────────────────────────────────────────────────────────────────────

def seed_everything(seed: int):
    """Seed every global RNG through Lightning's ``pl.seed_everything`` (``workers=True``)."""
    pl.seed_everything(seed=seed, workers=True)

def calculate_predictor_features(iv_raw, full_v_grid, v_max: float = 1.4, n_fine: int = 2000):
    """Compute the V_mpp-predictor input features for one raw I-V curve.

    The curve is PCHIP-interpolated onto ``n_fine`` evenly spaced voltages in
    ``[0, v_max]``, and NaN interpolant samples are discarded. Voc is the first
    remaining fine voltage with current <= 0, or the last remaining fine voltage
    when the curve never crosses zero. V_mpp maximises V*I on ``[0, Voc]``.

    Args:
        iv_raw: Current samples on ``full_v_grid`` (may contain NaNs); index 0 is Isc.
        full_v_grid: Increasing voltages at which ``iv_raw`` is sampled.
        v_max: Upper end of the fine evaluation grid (default 1.4, as before).
        n_fine: Number of fine-grid points (default 2000, as before).

    Returns:
        ``{"Isc", "Voc", "FF", "Vmpp", "pi"}``, where ``pi`` is the fitted
        ``PchipInterpolator``, or ``None`` when the curve is unusable. That covers
        fewer than 4 non-NaN samples, non-finite or non-positive Isc, fewer than 2
        finite fine points, or a degenerate ``Voc * Isc``. Interpolation errors
        (``ValueError``/``IndexError``) also yield ``None``.
    """
    try:
        if np.count_nonzero(~np.isnan(iv_raw)) < 4:
            return None
        isc = iv_raw[0]
        # Reject NaN/inf explicitly: NaN would slip through "isc <= 1e-9" and give NaN FF.
        if not np.isfinite(isc) or isc <= 1e-9:
            return None
        pi = PchipInterpolator(full_v_grid, iv_raw, extrapolate=False)
        v_fine = np.linspace(0, v_max, n_fine)
        i_fine = pi(v_fine)
        valid_mask = ~np.isnan(i_fine)
        v_fine, i_fine = v_fine[valid_mask], i_fine[valid_mask]
        if v_fine.size < 2:
            return None
        zero_cross_idx = np.where(i_fine <= 0)[0]
        # Curves that never reach zero current fall back to the last finite voltage as Voc.
        voc = v_fine[zero_cross_idx[0]] if len(zero_cross_idx) > 0 else v_fine[-1]
        search_mask = v_fine <= voc
        v_search, i_search = v_fine[search_mask], i_fine[search_mask]
        if len(v_search) == 0:
            return None
        mpp_index = np.argmax(v_search * i_search)
        v_mpp, i_mpp = v_search[mpp_index], i_search[mpp_index]
        if voc * isc < 1e-9:
            return None
        fill_factor = (v_mpp * i_mpp) / (voc * isc)
        return {"Isc": isc, "Voc": voc, "FF": fill_factor, "Vmpp": v_mpp, "pi": pi}
    except (ValueError, IndexError):
        return None

def create_v_slice_from_prediction(pi: PchipInterpolator, voc_v: float, predicted_v_mpp: float, seq_len: int, n_pre: int, n_post: int) -> typing.Optional[tuple]:
    """Sample a ``seq_len``-point voltage/current slice around the predicted V_mpp.

    The predicted V_mpp is clipped into ``[1e-6, 0.99 * voc_v]`` and used as an
    anchor. The knots are ``n_pre + 1`` evenly spaced voltages on ``[0, anchor)``
    plus ``n_post + 1`` on ``(anchor, voc_v]``. After de-duplication, the knots
    are linearly resampled by index to exactly ``seq_len`` voltages, and the
    currents are read from ``pi``.

    NOTE(review): the anchor itself is excluded from both halves (``[:-1]`` and
    ``[1:]``), so the MPP is only sampled if the resampling happens to land on it.
    Confirm this is intended.

    Returns:
        ``(v_slice, i_slice)`` as float32 arrays, or ``(None, None)`` if any
        interpolated current is NaN (e.g. a voltage outside ``pi``'s domain).
    """
    anchor = np.clip(predicted_v_mpp, 1e-6, voc_v * 0.99)
    below = np.linspace(0, anchor, n_pre + 2, endpoint=True)
    above = np.linspace(anchor, voc_v, n_post + 2, endpoint=True)
    knots = np.unique(np.concatenate([below[:-1], above[1:]]))
    knot_pos = np.linspace(0, 1, len(knots))
    target_pos = np.linspace(0, 1, seq_len)
    v_slice = np.interp(target_pos, knot_pos, knots)
    i_slice = pi(v_slice)
    has_gap = np.isnan(i_slice).any() or i_slice.shape[0] != seq_len
    if has_gap:
        return None, None
    return v_slice.astype(np.float32), i_slice.astype(np.float32)

def normalize_and_scale_by_isc(curve: np.ndarray) -> tuple[float, np.ndarray]:
    """Divide a current curve by its first sample and map the result to [-1, 1].

    Returns:
        ``(isc, scaled)``, where ``isc = float(curve[0])`` and
        ``scaled = 2 * curve / isc - 1`` as float32. A value equal to ``isc``
        maps to 1 and zero current maps to -1.
    """
    reference = float(curve[0])
    ratio = curve / reference
    scaled = 2.0 * ratio - 1.0
    return reference, scaled.astype(np.float32)

def compute_curvature_weights(y_curves: np.ndarray, alpha: float, power: float) -> np.ndarray:
    """Per-point loss weights that emphasise high-curvature regions of each curve.

    The curvature is the absolute discrete second difference along axis 1, with
    edge padding. Each row is normalised by its maximum (rows whose maximum is
    below 1e-9 use 1.0), and the weight is ``1 + alpha * (kappa / max) ** power``.
    Returns a float32 array with the same shape as ``y_curves``.
    """
    ext = np.pad(y_curves, ((0, 0), (1, 1)), mode='edge')
    ahead, here, behind = ext[:, 2:], ext[:, 1:-1], ext[:, :-2]
    kappa = np.abs(ahead - 2 * here + behind)
    peak = kappa.max(axis=1, keepdims=True)
    peak[peak < 1e-9] = 1.0  # flat rows: avoid 0/0 and give uniform weight 1
    relative = kappa / peak
    weights = 1.0 + alpha * np.power(relative, power)
    return weights.astype(np.float32)

def get_param_transformer(colnames: list[str]) -> ColumnTransformer:
    """Build an unfitted ColumnTransformer that scales device parameters by group.

    Every group is scaled with RobustScaler followed by MinMaxScaler to [-1, 1];
    the material-property group is log1p-compressed first. Groups with no column
    present in ``colnames`` are omitted, and any column not listed in a group is
    passed through unchanged (``remainder='passthrough'``).
    """
    param_defs = {
        'layer_thickness': ['lH', 'lP', 'lE'],
        'material_properties': ['muHh', 'muPh', 'muPe', 'muEe', 'NvH', 'NcH', 'NvE', 'NcE', 'NvP', 'NcP','chiHh', 'chiHe', 'chiPh', 'chiPe', 'chiEh', 'chiEe', 'epsH', 'epsP', 'epsE','Gavg'],
        'contacts': ['Wlm', 'Whm'],
        'recombination_gen': ['Aug', 'Brad', 'Taue', 'Tauh', 'vII', 'vIII']
    }
    transformers = []
    for group, cols in param_defs.items():
        actual_cols = [c for c in cols if c in colnames]
        if not actual_cols:
            continue
        steps = [('robust', RobustScaler()), ('minmax', MinMaxScaler(feature_range=(-1, 1)))]
        if group == 'material_properties':
            # 'one-to-one' makes get_feature_names_out() available on the fitted
            # transformer, matching the predictor-generation variant of this builder.
            steps.insert(0, ('log1p', FunctionTransformer(func=np.log1p, feature_names_out='one-to-one')))
        transformers.append((group, Pipeline(steps), actual_cols))
    return ColumnTransformer(transformers, remainder='passthrough')

def denormalize(scaled_current, isc):
    """Invert the [-1, 1] Isc scaling: ``(scaled + 1) / 2 * isc``, one Isc per row.

    Works for both torch tensors and numpy arrays. ``isc`` holds one value per row
    of ``scaled_current`` and is broadcast across the second axis.
    """
    if isinstance(scaled_current, torch.Tensor):
        isc_column = isc.unsqueeze(1)
    else:
        isc_column = isc[:, np.newaxis]
    return (scaled_current + 1.0) / 2.0 * isc_column

def physics_loss(y_pred, y_true, sample_w, loss_w):
    """Weighted sum of a sample-weighted MSE, a monotonicity penalty and an excess-curvature penalty.

    - ``mse``: mean of ``sample_w * (y_true - y_pred)**2``.
    - ``mono``: mean squared positive step ``y[t+1] - y[t]`` along axis 1, which
      penalises any increase of the prediction along the sequence.
    - ``excurv``: mean squared amount by which ``|second difference|`` exceeds
      ``loss_w['excess_threshold']``.

    Returns:
        ``(total_loss, {'mse', 'mono', 'excurv'})``, where the total uses the
        ``loss_w`` weights of the same names.
    """
    squared_error = (y_true - y_pred) ** 2
    mse_loss = (squared_error * sample_w).mean()

    rises = y_pred[:, 1:] - y_pred[:, :-1]
    mono_loss = torch.relu(rises).pow(2).mean()

    second_diff = y_pred[:, :-2] - 2 * y_pred[:, 1:-1] + y_pred[:, 2:]
    overshoot = torch.relu(second_diff.abs() - loss_w['excess_threshold'])
    excurv_loss = overshoot.pow(2).mean()

    components = {'mse': mse_loss, 'mono': mono_loss, 'excurv': excurv_loss}
    total_loss = (loss_w['mse'] * mse_loss +
                  loss_w['mono'] * mono_loss +
                  loss_w['excurv'] * excurv_loss)
    return total_loss, components

class GaussianRBFFeatures(nn.Module):
    """Expand each voltage into ``num_bands`` Gaussian radial-basis features.

    Voltages are divided by ``v_max`` and compared with ``num_bands`` centres
    evenly spaced on [0, 1]. Each feature is ``exp(-0.5 * ((v/v_max - mu) / sigma)**2)``.
    The centres are a non-persistent buffer, so they are not saved in the state_dict.
    """

    def __init__(self, num_bands: int, sigma: float, v_max: float):
        super().__init__()
        self.v_max = v_max
        self.sigma = sigma
        self.register_buffer('mu', torch.linspace(0, 1, num_bands), persistent=False)
        self.out_dim = num_bands

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        # Adds a trailing feature axis of size num_bands to v's shape.
        z = (v / self.v_max).unsqueeze(-1).sub(self.mu).div(self.sigma)
        return torch.exp(-0.5 * z.pow(2))

def make_positional_embedding(cfg: dict) -> nn.Module:
    """Build the voltage embedding from ``cfg['model']`` bands/sigma and ``cfg['dataset']['pchip']['v_max']``."""
    model_cfg = cfg['model']
    return GaussianRBFFeatures(
        num_bands=model_cfg['gaussian_bands'],
        sigma=model_cfg['gaussian_sigma'],
        v_max=cfg['dataset']['pchip']['v_max'],
    )

class ChannelLayerNorm(nn.Module):
    """LayerNorm over axis 1 (channels) of a channels-first tensor such as Conv1d output."""

    def __init__(self, num_channels):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LayerNorm normalises the last axis, so move channels there and back.
        channels_last = torch.transpose(x, 1, 2)
        normalized = self.norm(channels_last)
        return torch.transpose(normalized, 1, 2)

class TemporalBlock(nn.Module):
    """Residual block of two causal, dilated 1-D convolutions.

    Each branch stage is: left-pad -> Conv1d -> GELU -> LayerNorm over channels -> Dropout.
    The residual path is a 1x1 Conv1d when ``in_ch != out_ch``, otherwise identity.
    Input and output are channels-first (batch, channels, length), as required by
    nn.Conv1d. The sequence length is preserved because each conv receives exactly
    ``(kernel_size - 1) * dilation`` samples of left padding.
    """
    def __init__(self, in_ch: int, out_ch: int, kernel_size: int, dropout: float, dilation: int):
        super().__init__()
        # Causal padding is calculated based on kernel size and dilation
        # (left-only padding, so position t only sees positions <= t).
        self.padding = (kernel_size - 1) * dilation
        # Construction order and attribute names fix parameter-init order and state_dict keys.
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation)
        self.act1, self.norm1, self.drop1 = nn.GELU(), ChannelLayerNorm(out_ch), nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size, dilation=dilation)
        self.act2, self.norm2, self.drop2 = nn.GELU(), ChannelLayerNorm(out_ch), nn.Dropout(dropout)
        # 1x1 projection so the residual matches out_ch channels.
        self.downsample = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two causal conv stages and add the (projected) residual."""
        res = self.downsample(x)
        out = F.pad(x, (self.padding, 0)) # Apply causal padding
        out = self.drop1(self.norm1(self.act1(self.conv1(out))))
        out = F.pad(out, (self.padding, 0)) # Apply causal padding
        out = self.drop2(self.norm2(self.act2(self.conv2(out))))
        return out + res

# ──────────────────────────────────────────────────────────────────────────────
#   PYTORCH LIGHTNING DATA & MODEL MODULES
# ──────────────────────────────────────────────────────────────────────────────

class IVDataModule(pl.LightningDataModule):
    """Lightning data module for the V_mpp-sliced I-V curve dataset.

    On first use, ``prepare_data`` builds the preprocessed ``.npz`` archive, the
    fitted parameter/scalar transformers and the fine-curve memmaps (see
    ``_preprocess_and_save``). ``setup`` loads those transformers and builds one
    ``IVDataset`` per requested split.

    Args:
        cfg: Nested configuration dict (same layout as ``CONFIG``).
            ``cfg['model']['param_dim']`` is updated in place to the width of the
            combined conditioning vector.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        self.param_tf, self.scalar_tf = None, None

    def prepare_data(self):
        """Run preprocessing unless the preprocessed archive already exists."""
        if not Path(self.cfg['dataset']['paths']['preprocessed_npz']).exists():
            log.info("Preprocessed data not found. Running preprocessing...")
            self._preprocess_and_save()
        else:
            log.info("Found preprocessed data. Skipping preprocessing.")

    def setup(self, stage: str | None = None):
        """Load the fitted transformers (once) and build the datasets for ``stage``.

        Also syncs ``cfg['model']['param_dim']`` with the cached artifacts.
        ``_preprocess_and_save``, which otherwise sets it, is skipped whenever the
        preprocessed archive already exists.
        """
        paths = self.cfg['dataset']['paths']
        if self.param_tf is None:
            self.param_tf = joblib.load(paths['param_transformer'])
            self.scalar_tf = joblib.load(paths['scalar_transformer'])
        if stage in ("fit", None):
            self.train_dataset = IVDataset(self.cfg, 'train', self.param_tf, self.scalar_tf)
            self.val_dataset = IVDataset(self.cfg, 'val', self.param_tf, self.scalar_tf)
        if stage in ("test", None):
            self.test_dataset = IVDataset(self.cfg, 'test', self.param_tf, self.scalar_tf)
        for attr in ("train_dataset", "val_dataset", "test_dataset"):
            dataset = getattr(self, attr, None)
            if dataset is not None:
                self.cfg['model']['param_dim'] = int(dataset.X.shape[1])
                break

    def train_dataloader(self): return DataLoader(self.train_dataset, **self.cfg['dataset']['dataloader'], shuffle=True)
    def val_dataloader(self): return DataLoader(self.val_dataset, **self.cfg['dataset']['dataloader'])
    def test_dataloader(self): return DataLoader(self.test_dataset, **self.cfg['dataset']['dataloader'])

    def _preprocess_and_save(self):
        """Build the V_mpp-sliced training archive, transformers and fine-curve memmaps.

        Steps:
          1. Compute Isc/Voc/FF and a PCHIP interpolator for every raw I-V curve.
          2. Predict V_mpp with the pre-trained LightGBM predictor.
          3. Build a ``seq_len``-point voltage slice anchored at the predicted V_mpp
             for each curve, and store the full fine curve in the memmaps.
          4. Split into train/val/test, fit the parameter and scalar transformers on
             the training rows only, and save everything to disk.

        Raises:
            FileNotFoundError: If the V_mpp predictor artifacts are missing.
            RuntimeError: If no curve survives feature extraction or slicing.
        """
        log.info("--- Starting V_mpp-PREDICTED Data Preprocessing ---")
        cfg = self.cfg
        paths, pchip_cfg = cfg['dataset']['paths'], cfg['dataset']['pchip']
        seed = cfg['train']['seed']

        log.info("Loading V_mpp predictor artifacts...")
        try:
            v_mpp_predictor = joblib.load(VMPP_PREDICTOR_CONFIG['model_path'])
            feature_transformer_pred = joblib.load(VMPP_PREDICTOR_CONFIG['feature_transformer_path'])
        except FileNotFoundError as e:
            log.error(f"Error loading artifact: {e}")
            log.error("Please run the 'VMPP Predictor Generation Script' first.")
            raise

        params_df = pd.read_csv(paths['params_csv'], header=None, names=COLNAMES)
        iv_data_raw = np.loadtxt(paths['iv_raw_txt'], delimiter=',')
        # Voltages at which the raw current samples are given (x-coordinates for PCHIP).
        full_v_grid = np.concatenate([np.arange(0, 0.4 + 1e-8, 0.1), np.arange(0.425, 1.4 + 1e-8, 0.025)]).astype(np.float32)

        N_raw, n_fine = len(iv_data_raw), pchip_cfg['n_fine']
        log.info(f"Opening memmaps for {N_raw} full fine curves…")
        v_fine_mm = np.memmap(paths['v_fine_memmap'], dtype=np.float16, mode='w+', shape=(N_raw, n_fine))
        i_fine_mm = np.memmap(paths['i_fine_memmap'], dtype=np.float16, mode='w+', shape=(N_raw, n_fine))
        v_fine_mm[:], i_fine_mm[:] = np.nan, np.nan

        log.info("Pass 1: Calculating predictor features for all curves...")
        predictor_features, valid_indices, interpolators = [], [], []
        for i in tqdm(range(len(iv_data_raw)), desc="Feature Calculation"):
            res = calculate_predictor_features(iv_data_raw[i], full_v_grid)
            if res:
                interpolators.append(res.pop("pi"))
                predictor_features.append(res)
                valid_indices.append(i)
        if not valid_indices:
            raise RuntimeError("No valid I-V curves found; cannot run V_mpp prediction.")

        iv_summary_df = pd.DataFrame(predictor_features)
        params_df_aligned = params_df.iloc[valid_indices].reset_index(drop=True)
        log.info(f"Found {len(valid_indices)} valid curves for prediction.")

        log.info("Assembling hybrid feature matrix for V_mpp prediction...")
        iv_summary_df['Voc_x_FF'] = iv_summary_df['Voc'] * iv_summary_df['FF']
        iv_features = iv_summary_df[['Isc', 'Voc', 'FF', 'Voc_x_FF']]
        # The ColumnTransformer emits columns in transformer-group order, not COLNAMES
        # order, and the predictor was fit on get_feature_names_out() names. Reuse them
        # so every column is labelled correctly and matches the fitted feature names.
        params_scaled = pd.DataFrame(
            feature_transformer_pred.transform(params_df_aligned),
            columns=feature_transformer_pred.get_feature_names_out(),
        )
        X_hybrid = pd.concat([params_scaled, iv_features.reset_index(drop=True)], axis=1)

        log.info("Predicting V_mpp for all curves...")
        predicted_vmpps = v_mpp_predictor.predict(X_hybrid)

        log.info("Pass 2: Generating training slices and populating memmaps...")
        v_full = np.linspace(0, pchip_cfg['v_max'], n_fine)  # loop-invariant fine grid
        v_slices, i_slices, final_valid_indices_mask = [], [], []
        for i, raw_idx in enumerate(tqdm(valid_indices, desc="Slice Generation")):
            v_slice, i_slice = create_v_slice_from_prediction(
                interpolators[i], iv_summary_df.loc[i, 'Voc'], predicted_vmpps[i],
                pchip_cfg['seq_len'], pchip_cfg['n_pre_mpp'], pchip_cfg['n_post_mpp']
            )
            keep = v_slice is not None
            final_valid_indices_mask.append(keep)
            if not keep:
                continue
            v_slices.append(v_slice)
            i_slices.append(i_slice)

            # Store the finite part of the full fine curve, left-aligned; the rest stays NaN.
            i_full = interpolators[i](v_full)
            finite = ~np.isnan(i_full)
            n_finite = int(finite.sum())
            v_fine_mm[raw_idx, :n_finite] = v_full[finite]
            i_fine_mm[raw_idx, :n_finite] = i_full[finite]

        v_fine_mm.flush(); i_fine_mm.flush()
        del v_fine_mm, i_fine_mm
        if not v_slices:
            raise RuntimeError("No curve produced a valid V_mpp-anchored slice.")

        final_valid_indices_mask = np.asarray(final_valid_indices_mask, dtype=bool)
        v_slices, i_slices = np.array(v_slices), np.array(i_slices)
        final_valid_indices = np.array(valid_indices)[final_valid_indices_mask]
        predicted_vmpps = predicted_vmpps[final_valid_indices_mask]
        params_df_final = params_df.iloc[final_valid_indices].reset_index(drop=True)

        log.info(f"Retained {len(v_slices)} final valid curves after prediction and slicing.")
        isc_vals, i_slices_scaled = zip(*[normalize_and_scale_by_isc(c) for c in i_slices])
        isc_vals, i_slices_scaled = np.array(isc_vals), np.array(i_slices_scaled)
        sample_weights = compute_curvature_weights(i_slices_scaled, **cfg['dataset']['curvature_weighting'])

        # Split before fitting any transformer so val/test statistics do not leak into
        # the scalers. NOTE(review): the V_mpp predictor itself was fit on every curve
        # by the generator script, so its predictions on val/test rows are in-sample.
        all_indices = np.arange(len(final_valid_indices))
        train_val_idx, test_idx = train_test_split(all_indices, test_size=0.2, random_state=seed)
        train_idx, val_idx = train_test_split(train_val_idx, test_size=0.15, random_state=seed)
        split_labels = np.array([''] * len(all_indices), dtype=object)
        split_labels[train_idx], split_labels[val_idx], split_labels[test_idx] = 'train', 'val', 'test'

        params_train = params_df_final.iloc[train_idx]
        param_transformer_tcn = get_param_transformer(COLNAMES)
        param_transformer_tcn.fit(params_train)
        joblib.dump(param_transformer_tcn, paths['param_transformer'])

        scalar_df = pd.DataFrame({'I_ref': i_slices[:, 0], 'Predicted_V_mpp': predicted_vmpps})
        scalar_train = scalar_df.iloc[train_idx]
        scalar_transformer = Pipeline([('scaler', MinMaxScaler(feature_range=(-1, 1)))])
        scalar_transformer.fit(scalar_train)
        joblib.dump(scalar_transformer, paths['scalar_transformer'])

        param_dim = param_transformer_tcn.transform(params_train).shape[1]
        scalar_dim = scalar_transformer.transform(scalar_train).shape[1]
        self.cfg['model']['param_dim'] = param_dim + scalar_dim
        log.info(f"TCN Model total parameter dimension calculated: {self.cfg['model']['param_dim']}")

        np.savez(paths['preprocessed_npz'],
                 v_slices=v_slices, i_slices=i_slices, i_slices_scaled=i_slices_scaled,
                 sample_weights=sample_weights, isc_vals=isc_vals,
                 valid_indices=final_valid_indices, split_labels=split_labels,
                 predicted_vmpps=predicted_vmpps)
        log.info(f"Saved final preprocessed data to {paths['preprocessed_npz']}")

class IVDataset(Dataset):
    """One split (train/val/test) of the preprocessed V_mpp-sliced I-V dataset.

    Args:
        cfg: Configuration dict. ``cfg['dataset']['paths']`` must point at the
            archive written by ``IVDataModule._preprocess_and_save`` and at the raw
            parameter CSV.
        split: Split label to select (``'train'``, ``'val'`` or ``'test'``).
        param_tf: Fitted transformer for the device-parameter columns.
        scalar_tf: Fitted transformer for the ``I_ref``/``Predicted_V_mpp`` scalars.

    Each item is a dict containing the combined conditioning vector, the voltage
    slice, the Isc-scaled current slice, the per-point curvature weights and Isc.
    """

    def __init__(self, cfg: dict, split: str, param_tf, scalar_tf):
        super().__init__()
        paths = cfg['dataset']['paths']
        # allow_pickle is required because split_labels is stored as an object array.
        # The context manager closes the .npz handle once every array below is loaded.
        with np.load(paths['preprocessed_npz'], allow_pickle=True) as data:
            indices = np.where(data['split_labels'] == split)[0]
            self.v_slices = torch.from_numpy(data['v_slices'][indices])
            self.i_slices_scaled = torch.from_numpy(data['i_slices_scaled'][indices])
            self.sample_weights = torch.from_numpy(data['sample_weights'][indices])
            self.isc_vals = torch.from_numpy(data['isc_vals'][indices])
            valid_indices = data['valid_indices']
            scalar_df = pd.DataFrame({
                'I_ref': data['i_slices'][:, 0],
                'Predicted_V_mpp': data['predicted_vmpps'],
            })

        params_df_all = pd.read_csv(paths['params_csv'], header=None, names=COLNAMES)
        # Archive row j corresponds to raw parameter row valid_indices[j].
        params_df_valid = params_df_all.iloc[valid_indices].reset_index(drop=True)

        # Transform every valid row, then keep this split's rows.
        X_params_full = param_tf.transform(params_df_valid).astype(np.float32)
        X_scalar_full = scalar_tf.transform(scalar_df).astype(np.float32)
        X_combined = np.concatenate([X_params_full, X_scalar_full], axis=1)
        self.X = torch.from_numpy(X_combined[indices])

    def __len__(self):
        return len(self.v_slices)

    def __getitem__(self, idx):
        return {
            'X_combined': self.X[idx],
            'voltage': self.v_slices[idx],
            'current_scaled': self.i_slices_scaled[idx],
            'sample_w': self.sample_weights[idx],
            'isc': self.isc_vals[idx],
        }

class PhysicsIVSystem(pl.LightningModule):
    """Dilated-TCN regressor for Isc-scaled I-V slices, trained with ``physics_loss``.

    A parameter MLP embeds the combined conditioning vector. The embedding is
    broadcast along the voltage sequence, concatenated with the positional
    embedding of each voltage, and passed through a stack of causal dilated
    ``TemporalBlock``s. A linear head maps each sequence position to one scaled
    current value.

    Args:
        cfg: Nested config dict (layout of ``CONFIG``), stored via ``save_hyperparameters``.
        warmup_steps: Warm-up length; ``warmup_steps / total_steps`` is OneCycleLR's ``pct_start``.
        total_steps: Total optimiser steps for the OneCycleLR schedule.
    """
    def __init__(self, cfg: dict, warmup_steps: int, total_steps: int):
        super().__init__()
        # cfg becomes self.hparams; the schedule lengths are added for configure_optimizers.
        self.save_hyperparameters(cfg)
        self.hparams.update({'warmup_steps': warmup_steps, 'total_steps': total_steps})
        mcfg = self.hparams

        # --- Parametric MLP: [Linear -> BatchNorm1d -> GELU -> Dropout] per entry of dense_units ---
        mlp_layers = []
        in_dim = mcfg['model']['param_dim']
        for units in mcfg['model']['dense_units']:
            mlp_layers.extend([nn.Linear(in_dim, units), nn.BatchNorm1d(units), nn.GELU(), nn.Dropout(mcfg['model']['dropout'])])
            in_dim = units
        self.param_mlp = nn.Sequential(*mlp_layers)

        # --- Positional embedding of each voltage value ---
        self.pos_embed = make_positional_embedding(mcfg)

        # --- Dilated TCN stack: one TemporalBlock per dilation rate ---
        tcn_layers = []
        # Per-position input channels = broadcast MLP embedding + voltage embedding.
        seq_input_dim = mcfg['model']['dense_units'][-1] + self.pos_embed.out_dim
        # Only filters[0] is used; every block outputs this many channels.
        num_filters = mcfg['model']['filters'][0]
        kernel_size = mcfg['model']['kernel']
        dropout = mcfg['model']['dropout']

        for i, dilation in enumerate(mcfg['model']['dilations']):
            input_channels = seq_input_dim if i == 0 else num_filters
            tcn_layers.append(
                TemporalBlock(input_channels, num_filters, kernel_size, dropout, dilation)
            )
        self.tcn_stack = nn.Sequential(*tcn_layers)

        # --- Output head: one scalar (scaled current) per sequence position ---
        self.out_head = nn.Linear(num_filters, 1)

        self.apply(self._init_weights)
        # Per-batch denormalised test outputs, concatenated in on_test_epoch_end.
        self.test_preds, self.test_trues = [], []
        self.all_test_preds_np, self.all_test_trues_np = None, None

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every Linear and Conv1d (applied via self.apply)."""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None: nn.init.zeros_(module.bias)

    def forward(self, X_combined: torch.Tensor, voltage: torch.Tensor) -> torch.Tensor:
        """Predict the scaled current at every voltage of the slice.

        Args:
            X_combined: 2-D (batch, param_dim) conditioning vectors; 2-D is required by
                the ``unsqueeze(1).expand(-1, L, -1)`` broadcast below.
            voltage: Voltage slice whose ``shape[1]`` is the sequence length
                (assumed (batch, seq_len) — TODO confirm).

        Returns:
            Tensor of shape (batch, seq_len): one scaled-current prediction per voltage.
        """
        p = self.param_mlp(X_combined).unsqueeze(1).expand(-1, voltage.shape[1], -1)
        v_emb = self.pos_embed(voltage)
        # Concatenate features, then switch to channels-first for Conv1d.
        x = torch.cat([p, v_emb], dim=-1).transpose(1, 2)
        x = self.tcn_stack(x)
        x = x.transpose(1, 2)
        return self.out_head(x).squeeze(-1)

    def _step(self, batch, stage: str):
        """Shared train/val step: compute physics_loss and log each component plus the total per epoch."""
        y_pred = self(batch['X_combined'], batch['voltage'])
        loss, comps = physics_loss(y_pred, batch['current_scaled'], batch['sample_w'], self.hparams['model']['loss_weights'])
        self.log_dict({f'{stage}_{k}': v for k, v in comps.items()}, on_step=False, on_epoch=True, batch_size=len(batch['voltage']))
        self.log(f'{stage}_loss', loss, prog_bar=(stage == 'val'), on_step=False, on_epoch=True, batch_size=len(batch['voltage']))
        return loss

    def training_step(self, batch, batch_idx): return self._step(batch, 'train')
    def validation_step(self, batch, batch_idx): return self._step(batch, 'val')
    def test_step(self, batch, batch_idx):
        """Collect denormalised predictions and targets (scaled back by Isc) on CPU."""
        pred_scaled = self(batch['X_combined'], batch['voltage'])
        self.test_preds.append(denormalize(pred_scaled.cpu(), batch['isc'].cpu()))
        self.test_trues.append(denormalize(batch['current_scaled'].cpu(), batch['isc'].cpu()))

    # Reset the per-batch buffers at the start of each test run.
    def on_test_epoch_start(self): self.test_preds.clear(); self.test_trues.clear()
    def on_test_epoch_end(self):
        """Concatenate test batches; log flattened MAE/RMSE and the mean per-curve R²."""
        if not self.test_preds: return
        self.all_test_preds_np = torch.cat(self.test_preds, dim=0).numpy()
        self.all_test_trues_np = torch.cat(self.test_trues, dim=0).numpy()
        preds, trues = self.all_test_preds_np, self.all_test_trues_np
        self.log("test/MAE_denorm", mean_absolute_error(trues.ravel(), preds.ravel()), prog_bar=True)
        self.log("test/RMSE_denorm", np.sqrt(mean_squared_error(trues.ravel(), preds.ravel())), prog_bar=True)
        self.log("test/avg_R2", np.mean([r2_score(trues[i], preds[i]) for i in range(len(trues)) if trues[i].size > 1]), prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a per-step OneCycleLR schedule peaking at ``optimizer.lr``.

        NOTE(review): OneCycleLR's final LR is ``max_lr / div_factor / final_div_factor``.
        With the default ``div_factor=25`` this is ``lr * final_lr_ratio / 25``, not
        ``lr * final_lr_ratio``. Confirm which one was intended.
        """
        opt_cfg = self.hparams['optimizer']
        optimizer = torch.optim.AdamW(self.parameters(), lr=opt_cfg['lr'], weight_decay=opt_cfg['weight_decay'])
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt_cfg['lr'], total_steps=self.hparams.total_steps, pct_start=self.hparams.warmup_steps/self.hparams.total_steps, final_div_factor=1/opt_cfg['final_lr_ratio'])
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

class ExamplePlotsCallback(pl.Callback):
    """Plot reconstructed I-V curves for random, best and worst test samples.

    When testing ends, this reads the denormalised test predictions and targets
    that the LightningModule stored in ``all_test_preds_np`` and
    ``all_test_trues_np``, then ranks samples by per-sample R². For each group
    ("Random", "Best", "Worst") it saves a grid figure as
    ``test_plots_<group>.png`` in the logger's ``log_dir`` and logs it through
    ``trainer.logger.experiment.add_image``.

    Args:
        num_samples: maximum number of samples plotted per group.
    """

    def __init__(self, num_samples: int = 8):
        super().__init__()
        self.num_samples = num_samples

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        log.info("--- Generating fully reconstructed plots ---")
        # getattr: tolerate a module that never stored test outputs (e.g. no test batches ran).
        preds = getattr(pl_module, 'all_test_preds_np', None)
        trues = getattr(pl_module, 'all_test_trues_np', None)
        if preds is None or trues is None:
            return
        metrics_df = pd.DataFrame({'r2': [r2_score(t, p) for t, p in zip(trues, preds)]})
        n_samples = min(self.num_samples, len(trues))
        plot_groups = {
            "Random": np.random.choice(metrics_df.index, n_samples, replace=False),
            "Best": metrics_df.nlargest(n_samples, 'r2').index.values,
            "Worst": metrics_df.nsmallest(n_samples, 'r2').index.values,
        }
        for name, indices in plot_groups.items():
            if not indices.size:
                continue
            self._generate_and_log_plot(trainer, pl_module, name, indices, preds, trues, metrics_df)

    def _generate_and_log_plot(self, trainer, pl_module, title, indices, preds, trues, metrics_df):
        """Draw one grid of reconstructed curves for *indices*, save it as PNG, and log it.

        ``indices`` are row positions in ``preds``/``trues`` (test-set order).
        Each index is mapped in two steps:
        - to a row of the full valid set, via the ``'test'`` entries of ``split_labels``;
        - to a raw-data row, via ``valid_indices``. That raw-data row indexes the
          fine-grid voltage/current memmaps.

        NOTE(review): this mapping assumes the test dataloader preserves the order
        of the ``'test'`` split (no shuffling); confirm.
        """
        hparams = pl_module.hparams
        paths = hparams['dataset']['paths']
        n_fine = hparams['dataset']['pchip']['n_fine']
        filename = Path(trainer.logger.log_dir) / f"test_plots_{title.lower()}.png"

        # NpzFile reads and decompresses an array on every key access. Extract each
        # needed array exactly once, outside the per-sample loop, and close the
        # archive when done.
        # NOTE(review): allow_pickle=True can execute code embedded in the file;
        # only load archives produced by this project's own preprocessing.
        with np.load(paths['preprocessed_npz'], allow_pickle=True) as data:
            split_labels = data['split_labels']
            valid_indices = data['valid_indices']
            v_slices = data['v_slices']
        test_indices_in_full_valid_set = np.where(split_labels == 'test')[0]

        v_fine_mm = np.memmap(paths['v_fine_memmap'], dtype=np.float16, mode='r').reshape(-1, n_fine)
        i_fine_mm = np.memmap(paths['i_fine_memmap'], dtype=np.float16, mode='r').reshape(-1, n_fine)

        n_samples = len(indices)
        n_rows = (n_samples + 3) // 4  # four panels per row, rounded up
        fig, axes = plt.subplots(n_rows, 4, figsize=(20, 5 * n_rows), squeeze=False, constrained_layout=True)
        try:
            axes = axes.flatten()
            fig.suptitle(f"{title} Samples", fontsize=20, weight='bold')
            for ax, test_set_idx in zip(axes, indices):
                valid_set_idx = test_indices_in_full_valid_set[test_set_idx]
                raw_data_idx = valid_indices[valid_set_idx]
                v_slice = v_slices[valid_set_idx]
                i_true, i_pred = trues[test_set_idx], preds[test_set_idx]
                v_fine = v_fine_mm[raw_data_idx].astype(np.float32)
                i_fine = i_fine_mm[raw_data_idx].astype(np.float32)
                # Unfilled fine-grid entries are NaN; draw the reference curve from finite points only.
                mask = ~np.isnan(v_fine)
                ax.plot(v_fine[mask], i_fine[mask], 'k-', alpha=0.7, lw=2, label='Actual (Fine Grid)')
                # Dense predicted curve from the sparse predicted points. With extrapolate=False,
                # voltages outside the v_slice range come back as NaN and are not drawn.
                i_pred_full = PchipInterpolator(v_slice, i_pred, extrapolate=False)(v_fine)
                ax.plot(v_fine, i_pred_full, 'r--', lw=2, label='Predicted (Reconstructed)')
                ax.plot(v_slice, i_true, 'bo', ms=6, label='Actual Points')
                ax.plot(v_slice, i_pred, 'rx', ms=6, mew=2, label='Predicted Points')
                r2 = metrics_df.loc[test_set_idx, 'r2']
                ax.set(title=f"Test Sample #{test_set_idx} (R²={r2:.4f})", xlabel="Voltage (V)", ylabel="Current (mA/cm²)")
                ax.grid(True, linestyle='--', alpha=0.6)
                ax.legend()
                if mask.any():
                    ax.set(xlim=(-0.05, max(v_fine[mask].max() * 1.05, 0.1)),
                           ylim=(-max(i_fine[mask].max() * 0.05, 1), None))
            # Remove the unused panels in the last row.
            for ax in axes[n_samples:]:
                fig.delaxes(ax)
            fig.savefig(filename, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure even if plotting fails, so repeated calls do not leak figures.
            plt.close(fig)
            del v_fine_mm, i_fine_mm

        # Close the PNG file handle promptly; np.array forces the pixel data to load first.
        with Image.open(filename) as img:
            image = np.array(img)
        trainer.logger.experiment.add_image(title, image, 0, dataformats='HWC')
        log.info(f"Saved and logged reconstructed plot: {filename}")

def run_experiment(cfg: dict):
    """Train a ``PhysicsIVSystem`` on an ``IVDataModule``, then test the best checkpoint.

    Steps:
    - seed the RNGs;
    - prepare and set up the datamodule to measure the training-loader length;
    - size the warm-up and total LR-schedule steps from that length;
    - build the Trainer (best-``val_loss`` checkpointing, LR monitor, early
      stopping, progress bar, example plots);
    - fit, then test with ``ckpt_path="best"``.

    Schedule lengths are counted in *optimizer* steps, because the model's
    scheduler is stepped once per optimizer step. With
    ``accumulate_grad_batches = k``, Lightning takes ``ceil(batches / k)``
    optimizer steps per epoch; for the default ``k = 1`` this equals the number
    of batches.

    NOTE(review): Trainer.fit invokes the datamodule's prepare_data/setup hooks
    again, so the eager calls below may run twice; confirm they are cheap or cached.
    NOTE(review): with multiple devices, the loader length measured here may not
    match the per-rank sampler the Trainer injects later; confirm if distributed.
    """
    import math  # function-local so this block does not depend on a file-level import

    log.info(f"Starting run '{cfg['train']['run_name']}'")
    seed_everything(cfg['train']['seed'])
    datamodule = IVDataModule(cfg)
    # Set up eagerly: the batches-per-epoch count is needed before the model
    # (which receives the schedule lengths) can be constructed.
    datamodule.prepare_data()
    datamodule.setup(stage='fit')
    batches_per_epoch = len(datamodule.train_dataloader())
    accumulate = cfg['trainer'].get('accumulate_grad_batches', 1)
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate)
    total_steps = cfg['trainer']['max_epochs'] * steps_per_epoch
    model = PhysicsIVSystem(cfg, warmup_steps=cfg['optimizer']['warmup_epochs'] * steps_per_epoch, total_steps=total_steps)
    log.info(f"Model instantiated with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters.")
    trainer = pl.Trainer(
        **cfg['trainer'],
        default_root_dir=OUTPUT_DIR,
        logger=TensorBoardLogger(str(OUTPUT_DIR / "tb_logs"), name=cfg['train']['run_name']),
        callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
                   LearningRateMonitor(logging_interval="step"),
                   EarlyStopping(monitor="val_loss", patience=20, mode="min"),
                   RichProgressBar(),
                   ExamplePlotsCallback(num_samples=8)],
    )
    log.info("--- Starting Training ---")
    trainer.fit(model, datamodule=datamodule)
    log.info("--- Starting Final Testing on Best Checkpoint ---")
    trainer.test(datamodule=datamodule, ckpt_path="best")
    log.info(f"Experiment Finished. Full results in: {trainer.logger.log_dir}")

if __name__ == "__main__":
    # Training depends on the V_mpp predictor and its feature transformer being on disk.
    # If either artifact is missing, log an error banner instead of starting the run.
    # all() over a generator keeps the original short-circuit: the second path is only
    # looked up when the first artifact exists.
    artifacts_ready = all(
        Path(VMPP_PREDICTOR_CONFIG[key]).exists()
        for key in ('model_path', 'feature_transformer_path')
    )
    if artifacts_ready:
        run_experiment(CONFIG)
    else:
        log.error("="*80 + "\n! V_mpp predictor artifacts not found! Please run the 'VMPP Predictor Generation Script' first.\n" + "="*80)

INFO:lightning_fabric.utilities.seed:Seed set to 42


Feature Calculation:   0%|          | 0/66026 [00:00<?, ?it/s]

Slice Generation:   0%|          | 0/66026 [00:00<?, ?it/s]

  i_fine_mm[raw_idx, :mask.sum()] = i_full[mask]
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=12` reached.


INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at lightning_output/tb_logs/DilatedTCN-20250727_000258/version_0/checkpoints/epoch=11-step=4212.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at lightning_output/tb_logs/DilatedTCN-20250727_000258/version_0/checkpoints/epoch=11-step=4212.ckpt


Output()