In [None]:
import os
import time
import copy
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from scipy.stats import pearsonr, zscore
from sklearn.model_selection import train_test_split

# --- Reproducibility --------------------------------------------------------
# Seed the torch and NumPy global generators with the same fixed value.
torch.manual_seed(42)
np.random.seed(42)

# --- Backend tuning ---------------------------------------------------------
# Ask cuDNN to benchmark convolution algorithms.
cudnn.benchmark = True

# Best effort: the call below may not exist on every torch build, so any
# failure is deliberately ignored rather than aborting the script.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

if torch.cuda.is_available():
    # Seed every visible CUDA device as well.
    torch.cuda.manual_seed_all(42)
    # Best effort again: enable TF32 for matmul and cuDNN where supported.
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        pass

def compute_pearson_r2(y_true, y_pred):
    """Return the squared Pearson correlation between targets and predictions.

    Args:
        y_true: 1-D sequence of reference values.
        y_pred: 1-D sequence of predicted values, same length as ``y_true``.

    Returns:
        ``r * r`` where ``r`` is the Pearson coefficient, or ``np.nan`` when
        there are two or fewer samples or when either input has a standard
        deviation below ``1e-7``.
    """
    if len(y_true) <= 2:
        return np.nan

    # Checked in the same order as before: predictions first, then targets.
    if any(np.std(values) < 1e-7 for values in (y_pred, y_true)):
        return np.nan

    corr = pearsonr(y_true, y_pred)[0]
    return corr * corr

def create_mask(X):
    """Build a float32 validity mask from NaN gaps along axis 3.

    Args:
        X: array with at least four dimensions; axis 3 is reduced.

    Returns:
        float32 array with axis 3 collapsed to length 1: ``1.0`` where every
        entry along axis 3 is non-NaN, ``0.0`` where at least one is NaN.
    """
    has_gap = np.isnan(X).any(axis=3, keepdims=True)
    return np.logical_not(has_gap).astype(np.float32)

def normalize_data(X_train, X_val, X_test):
    """Standardise three arrays with statistics taken from the training array.

    The mean and standard deviation are computed over axes ``(0, 1, 2)`` of
    ``X_train`` while ignoring NaNs, so one pair of statistics is produced per
    entry of the remaining trailing axis.

    Args:
        X_train: array the statistics are estimated from.
        X_val: array normalised with the training statistics.
        X_test: array normalised with the training statistics.

    Returns:
        ``(X_train_norm, X_val_norm, X_test_norm, band_mean, band_std)``.
        NaN entries stay NaN in the normalised arrays.
    """
    stat_axes = (0, 1, 2)
    band_mean = np.nanmean(X_train, axis=stat_axes)
    band_std = np.nanstd(X_train, axis=stat_axes)
    # Guard against division by (near) zero for constant bands.
    band_std[band_std < 1e-7] = 1.0

    normalised = [
        np.where(np.isnan(arr), np.nan, (arr - band_mean) / band_std)
        for arr in (X_train, X_val, X_test)
    ]
    return normalised[0], normalised[1], normalised[2], band_mean, band_std

def time_split(X, y, meta, cutoff_date="20240101"):
    """Split samples into a pre-cutoff set and a held-out set by date.

    Args:
        X: array indexable by a list of integer positions along axis 0.
        y: array of targets, indexable the same way as ``X``.
        meta: sequence of per-sample mappings; each must provide a ``"date"``
            string in ``%Y%m%d`` format.
        cutoff_date: ``%Y%m%d`` string. Samples dated on or after it go to the
            held-out set, earlier ones to the "before" set.

    Returns:
        ``(X_before, y_before, meta_before, X_test, y_test, meta_test)``.
        If either side of the cutoff would be empty, a seeded random 80/20
        split (``random_state=42``) is returned instead.
    """
    cutoff = datetime.datetime.strptime(cutoff_date, "%Y%m%d")

    # Parse every date once and route its index to the matching side.
    before_idx, test_idx = [], []
    for i, m in enumerate(meta):
        sample_date = datetime.datetime.strptime(m["date"], "%Y%m%d")
        if sample_date < cutoff:
            before_idx.append(i)
        else:
            test_idx.append(i)

    if not before_idx or not test_idx:
        # Split all three containers in a single call so X, y and meta are
        # shuffled by the very same permutation (and length mismatches are
        # rejected) instead of relying on two separately seeded calls.
        X_tr, X_te, y_tr, y_te, meta_tr, meta_te = train_test_split(
            X, y, meta, test_size=0.2, random_state=42
        )
        return X_tr, y_tr, meta_tr, X_te, y_te, meta_te

    return (X[before_idx], y[before_idx], [meta[i] for i in before_idx],
            X[test_idx], y[test_idx], [meta[i] for i in test_idx])

def filter_by_zscore_and_ratio(X, y, meta, z_threshold=2.5, min_valid_ratio=0.5):
    """Drop samples with too many NaNs, in absolute or relative terms.

    A sample is kept when both conditions hold:

    * its fraction of non-NaN entries is at least ``min_valid_ratio``;
    * the z-score of its NaN fraction, relative to all samples, is at most
      ``z_threshold``.

    Args:
        X: 4-D array; the NaN fraction is computed over axes ``(1, 2, 3)``.
        y: array of targets aligned with axis 0 of ``X``.
        meta: sequence of per-sample records aligned with axis 0 of ``X``.
        z_threshold: upper bound on the NaN-fraction z-score.
        min_valid_ratio: lower bound on the non-NaN fraction.

    Returns:
        ``(X_kept, y_kept, meta_kept)`` with the surviving samples in their
        original order.
    """
    nan_ratios = np.mean(np.isnan(X), axis=(1, 2, 3))
    valid_ratios = 1 - nan_ratios

    has_spread = (
        nan_ratios.size > 1
        and (nan_ratios.max() - nan_ratios.min()) > 1e-12
    )
    if has_spread:
        z_ok = zscore(nan_ratios) <= z_threshold
    else:
        # With identical NaN fractions (e.g. no gaps anywhere, or a single
        # sample) the standard deviation is zero and zscore() yields NaN,
        # which would fail the comparison and discard every sample. Treat
        # each sample as sitting exactly on the mean (z = 0) instead.
        z_ok = 0.0 <= z_threshold

    keep_mask = z_ok & (valid_ratios >= min_valid_ratio)
    idx = np.where(keep_mask)[0]
    return X[idx], y[idx], [meta[i] for i in idx]

# Memo of all-ones kernels, keyed by (device, dtype, kernel size).
_ONES_CACHE = {}


def _get_ones(device, dtype, k):
    """Return a cached ``(1, 1, k, k)`` tensor of ones.

    Args:
        device: device the tensor should live on.
        dtype: element type of the tensor.
        k: side length of the square kernel.

    Returns:
        The same tensor object for repeated calls with equal arguments.
    """
    cache_key = (device, dtype, k)
    try:
        return _ONES_CACHE[cache_key]
    except KeyError:
        kernel = torch.ones(1, 1, k, k, device=device, dtype=dtype)
        _ONES_CACHE[cache_key] = kernel
        return kernel

def pconv3x3_like(x, mask, conv3x3: nn.Conv2d):
    """Apply a convolution that is re-normalised by the local valid-pixel count.

    The input is zeroed where ``mask`` is 0, convolved with ``conv3x3``, and
    the result is multiplied by ``k*k / count`` where ``count`` is the number
    of valid mask entries under each kernel window (clamped to at least 1).
    Output positions whose window contains no valid entry are set to 0.

    Args:
        x: input feature map for ``conv3x3``.
        mask: validity mask multiplied element-wise with ``x``. It is convolved
            with a ``(1, 1, k, k)`` kernel, so it must have exactly one channel.
        conv3x3: convolution module; its stride, padding and dilation are
            reused for the count, and ``kernel_size[0]`` is taken as ``k``
            (i.e. the kernel is treated as square).

    Returns:
        ``(y, new_mask)`` where ``new_mask`` is 1 wherever at least one valid
        entry fell under the window and 0 elsewhere, cast to ``x.dtype``.

    NOTE(review): the module's bias (if any) is part of the convolution output
    and is therefore rescaled together with the weights' response — confirm
    this is intended.
    """
    side = conv3x3.kernel_size[0]
    window = _get_ones(x.device, x.dtype, side)

    response = conv3x3(x * mask)

    # Count of valid mask entries under every kernel window.
    valid_count = F.conv2d(
        mask,
        window,
        bias=None,
        stride=conv3x3.stride,
        padding=conv3x3.padding,
        dilation=conv3x3.dilation,
    )
    renorm = (side * side) / torch.clamp(valid_count, min=1.0)
    response = response * renorm

    supported = (valid_count > 0).to(x.dtype)
    response = response * supported
    return response, supported

def _kernel_area(k):
    """Return the number of cells in a pooling window.

    Args:
        k: either a tuple whose first two items are the window height and
            width, or a single value used for both sides.

    Returns:
        The product of the two sides as an ``int``.
    """
    height, width = (k[0], k[1]) if isinstance(k, tuple) else (k, k)
    return int(height) * int(width)

def masked_avgpool2d(x, mask, k=2, s=2):
    """Average-pool ``x`` over the valid (mask == 1) entries of each window.

    Args:
        x: input feature map.
        mask: validity mask multiplied element-wise with ``x``.
        k: pooling window size, an int or an ``(h, w)`` tuple.
        s: pooling stride, passed straight to ``F.avg_pool2d``.

    Returns:
        ``(out, new_mask)``. ``out`` is the sum of valid entries divided by
        the number of valid entries in the window (clamped to at least 1), and
        is 0 where the window holds no valid entry. ``new_mask`` is 1 where the
        window holds at least one valid entry, cast to ``x.dtype``.
    """
    window_area = _kernel_area(k)

    # avg_pool2d * area turns the mean back into a plain windowed sum.
    masked_sum = F.avg_pool2d(x * mask, k, s) * window_area
    valid_count = F.avg_pool2d(mask, k, s) * window_area

    pooled = masked_sum / torch.clamp(valid_count, min=1.0)
    pooled_mask = (valid_count > 0).to(x.dtype)
    return pooled * pooled_mask, pooled_mask

class EnhancedCNNModel(nn.Module):
    """Three-stage CNN regressor whose convolutions and pooling use a mask.

    Each convolutional stage is an ``nn.Sequential`` laid out as
    ``[conv, bn, relu, conv, bn, relu, pool, dropout]`` (indices 0-7). In
    ``forward`` the convolutions run through ``pconv3x3_like`` and the pooling
    through ``masked_avgpool2d``; the pooling module stored at index 6 is
    never called, it only keeps the layer indices stable.

    Args:
        input_channels: number of channels of the input feature map.
    """

    def __init__(self, input_channels=41):
        super().__init__()
        self.conv_block1 = self._make_stage(input_channels, 64, nn.AvgPool2d(2), 0.2)
        self.conv_block2 = self._make_stage(64, 128, nn.AvgPool2d(2), 0.3)
        self.conv_block3 = self._make_stage(128, 256, nn.AdaptiveAvgPool2d(1), 0.4)
        self.fc_block = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
        )

    @staticmethod
    def _make_stage(in_ch, out_ch, pool, drop_p):
        """Build one ``[conv, bn, relu, conv, bn, relu, pool, dropout]`` stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            pool,
            nn.Dropout(drop_p),
        )

    @staticmethod
    def _masked_conv_pair(stage, x, mask):
        """Run both conv -> bn -> relu triples of ``stage`` with mask-aware convs."""
        for conv_idx in (0, 3):
            x, mask = pconv3x3_like(x, mask, stage[conv_idx])
            x = stage[conv_idx + 1](x)  # batch norm
            x = stage[conv_idx + 2](x)  # ReLU
        return x, mask

    def forward(self, x, mask):
        """Predict one value per sample.

        Args:
            x: input feature map; NaNs are replaced by 0 before any layer runs.
            mask: validity mask passed to the mask-aware convolutions and
                pooling alongside ``x``.

        Returns:
            The output of the final linear layer with its last dimension
            squeezed.
        """
        x = torch.nan_to_num(x, nan=0.0)

        # Stages 1 and 2: two masked convs, 2x2 masked pooling, dropout.
        for stage in (self.conv_block1, self.conv_block2):
            x, mask = self._masked_conv_pair(stage, x, mask)
            x, mask = masked_avgpool2d(x, mask, k=2, s=2)
            x = stage[7](x)

        # Stage 3: two masked convs, then masked pooling over the whole map.
        x, mask = self._masked_conv_pair(self.conv_block3, x, mask)
        height, width = x.shape[-2:]
        x, mask = masked_avgpool2d(x, mask, k=(height, width), s=(height, width))
        x = self.conv_block3[7](x)

        flat = x.view(x.size(0), -1)
        return self.fc_block(flat).squeeze(-1)

def main():
    """Train, evaluate and persist the mask-aware CNN regressor.

    Reads the names ``X_final``, ``y_final`` and ``meta_final``, which are not
    defined in this function and must already exist at module level.

    Pipeline:
        1. Take the first target column, filter samples by NaN content, split
           by date, then split the pre-cutoff part 80/20 into train and
           validation.
        2. Log-transform the targets (clipped at ``1e-6``), standardise the
           inputs with training statistics and build validity masks.
        3. Train with SmoothL1 loss, AdamW, ``ReduceLROnPlateau`` and AMP, with
           early stopping on the validation loss (patience 10).
        4. Predict on train / validation / test with the best weights, compute
           metrics, and write ``tp_predictions_train_val_test.npz`` and
           ``tp_maskaware_checkpoint.pt`` to the working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------------ data
    y_final_tp = y_final[:, 0]
    X_filtered, y_filtered, meta_filtered = filter_by_zscore_and_ratio(X_final, y_final_tp, meta_final)
    X_before, y_before, meta_before, X_test, y_test, meta_test = time_split(X_filtered, y_filtered, meta_filtered)
    X_train, X_val, y_train, y_val, meta_train, meta_val = train_test_split(
        X_before, y_before, meta_before, test_size=0.2, random_state=42
    )

    # Targets are modelled in log space; clipping keeps the log finite.
    y_train_log = np.log(np.clip(y_train, 1e-6, np.inf))
    y_val_log = np.log(np.clip(y_val, 1e-6, np.inf))
    y_test_log = np.log(np.clip(y_test, 1e-6, np.inf))

    X_train_norm, X_val_norm, X_test_norm, band_mean, band_std = normalize_data(X_train, X_val, X_test)

    mask_train = create_mask(X_train)
    mask_val = create_mask(X_val)
    mask_test = create_mask(X_test)

    def to_nchw_tensor(A):
        # Move axis 3 to the channel position and store as channels_last.
        return torch.tensor(A, dtype=torch.float32).permute(0, 3, 1, 2).contiguous(memory_format=torch.channels_last)

    X_train_t = to_nchw_tensor(X_train_norm)
    X_val_t = to_nchw_tensor(X_val_norm)
    X_test_t = to_nchw_tensor(X_test_norm)
    M_train_t = to_nchw_tensor(mask_train)
    M_val_t = to_nchw_tensor(mask_val)
    M_test_t = to_nchw_tensor(mask_test)
    y_train_t = torch.tensor(y_train_log, dtype=torch.float32)
    y_val_t = torch.tensor(y_val_log, dtype=torch.float32)
    y_test_t = torch.tensor(y_test_log, dtype=torch.float32)

    # --------------------------------------------------------------- loaders
    batch_size = 128 if torch.cuda.is_available() else 32
    num_workers = min(8, os.cpu_count() or 1)
    persistent = num_workers > 0
    prefetch = 4 if num_workers > 0 else None

    def make_loader(dataset, shuffle):
        kwargs = dict(batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)
        if persistent:
            kwargs["persistent_workers"] = True
        if prefetch is not None:
            kwargs["prefetch_factor"] = prefetch
        return DataLoader(dataset, **kwargs)

    train_loader = make_loader(TensorDataset(X_train_t, M_train_t, y_train_t), True)
    val_loader = make_loader(TensorDataset(X_val_t, M_val_t, y_val_t), False)
    test_loader = make_loader(TensorDataset(X_test_t, M_test_t, y_test_t), False)

    # Unshuffled view of the training data so predictions line up with meta_train.
    train_eval_loader = make_loader(TensorDataset(X_train_t, M_train_t, y_train_t), False)

    # ----------------------------------------------------------------- model
    base_model = EnhancedCNNModel(input_channels=X_train_t.shape[1]).to(device)
    base_model = base_model.to(memory_format=torch.channels_last)

    model = base_model
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        model = nn.DataParallel(base_model).to(device)

    criterion = nn.SmoothL1Loss(beta=0.5)

    # Prefer the fused implementation, but fall back when it is unavailable:
    # older torch versions reject the keyword (TypeError), others reject
    # parameters on an unsupported device (RuntimeError).
    try:
        optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4, fused=True)
    except (TypeError, RuntimeError):
        optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
    scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

    BN_FREEZE_EPOCH = 5

    def _freeze_bn(m):
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    def apply_to_model(m, fn):
        if isinstance(m, nn.DataParallel):
            m.module.apply(fn)
        else:
            m.apply(fn)

    best_val_loss = float("inf")
    best_state = copy.deepcopy(model.state_dict())
    patience = 10
    patience_ctr = 0

    # -------------------------------------------------------------- training
    for epoch in range(1, 101):
        model.train()
        # model.train() puts every submodule, BatchNorm included, back into
        # training mode. The freeze therefore has to be re-applied after it,
        # on every epoch from BN_FREEZE_EPOCH onwards, to have any effect.
        if epoch >= BN_FREEZE_EPOCH:
            apply_to_model(model, _freeze_bn)

        for Xb, Mb, yb in train_loader:
            Xb = Xb.to(device, non_blocking=True)
            Mb = Mb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            with torch.autocast("cuda", enabled=torch.cuda.is_available()):
                out = model(Xb, Mb)
                loss = criterion(out, yb)

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

        model.eval()
        va_sum = 0.0
        with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
            for Xb, Mb, yb in val_loader:
                Xb = Xb.to(device, non_blocking=True)
                Mb = Mb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True)
                out = model(Xb, Mb)
                va_sum += criterion(out, yb).item() * Xb.size(0)

        val_loss = va_sum / len(val_loader.dataset)
        scheduler.step(val_loss)

        # Early stopping: keep the best weights, stop after `patience`
        # consecutive epochs without an improvement larger than 1e-6.
        if val_loss < best_val_loss - 1e-6:
            best_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            patience_ctr = 0
        else:
            patience_ctr += 1
            if patience_ctr >= patience:
                break

    model.load_state_dict(best_state)

    # ------------------------------------------------------------ evaluation
    def predict_on_loader(loader):
        model.eval()
        pred_log, true_log = [], []
        with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
            for Xb, Mb, yb_log in loader:
                Xb = Xb.to(device, non_blocking=True)
                Mb = Mb.to(device, non_blocking=True)
                yb_log = yb_log.to(device, non_blocking=True)
                # The model already returns one value per sample. Flatten
                # instead of squeezing again: a second squeeze would turn a
                # batch of one into a 0-d array that np.concatenate rejects.
                out_log = model(Xb, Mb).float().reshape(-1)
                pred_log.append(out_log.detach().cpu().numpy())
                true_log.append(yb_log.detach().cpu().numpy())
        y_pred_log = np.concatenate(pred_log)
        y_true_log = np.concatenate(true_log)
        # Exact inverse of the forward transform log(clip(y, 1e-6)); that
        # transform has no additive offset, so none is subtracted here.
        y_pred = np.exp(y_pred_log)
        y_true = np.exp(y_true_log)
        return y_true, y_pred

    y_train_true, y_train_pred = predict_on_loader(train_eval_loader)
    y_val_true, y_val_pred = predict_on_loader(val_loader)
    y_test_true, y_test_pred = predict_on_loader(test_loader)

    def metrics(y_true, y_pred):
        return {
            "PearsonR2": compute_pearson_r2(y_true, y_pred),
            "MAE": float(np.mean(np.abs(y_true - y_pred))),
            "RMSE": float(np.sqrt(np.mean((y_true - y_pred) ** 2))),
            "SD_pred": float(np.std(y_pred)),
            "N": int(len(y_true)),
        }

    m_train = metrics(y_train_true, y_train_pred)
    m_val = metrics(y_val_true, y_val_pred)
    m_test = metrics(y_test_true, y_test_pred)

    def _safe_site_id(m):
        # Accept either a mapping or an attribute-style record; the first of
        # "site_id", "index", "id" that is present wins.
        if isinstance(m, dict):
            return m.get("site_id", m.get("index", m.get("id", None)))
        for k in ("site_id", "index", "id"):
            if hasattr(m, k):
                return getattr(m, k)
        return None

    train_dates = np.array([m["date"] for m in meta_train])
    val_dates = np.array([m["date"] for m in meta_val])
    test_dates = np.array([m["date"] for m in meta_test])

    train_site_ids = np.array([_safe_site_id(m) for m in meta_train])
    val_site_ids = np.array([_safe_site_id(m) for m in meta_val])
    test_site_ids = np.array([_safe_site_id(m) for m in meta_test])

    # ----------------------------------------------------------- persistence
    np.savez(
        "tp_predictions_train_val_test.npz",
        y_train_true=y_train_true, y_train_pred=y_train_pred,
        y_val_true=y_val_true, y_val_pred=y_val_pred,
        y_test_true=y_test_true, y_test_pred=y_test_pred,
        train_dates=train_dates, train_site_ids=train_site_ids,
        val_dates=val_dates, val_site_ids=val_site_ids,
        test_dates=test_dates, test_site_ids=test_site_ids,
        meta_train=np.array(meta_train, dtype=object),
        meta_val=np.array(meta_val, dtype=object),
        meta_test=np.array(meta_test, dtype=object),
        metrics_train=m_train, metrics_val=m_val, metrics_test=m_test,
    )

    # Save the unwrapped weights so the checkpoint keys carry no "module." prefix.
    state_to_save = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
    ckpt = {
        "state_dict": state_to_save,
        "band_mean": band_mean.astype(np.float32).tolist(),
        "band_std": band_std.astype(np.float32).tolist(),
        "input_channels": int(X_train_t.shape[1]),
        "arch": "EnhancedCNNModel(mask-aware v1)",
    }
    torch.save(ckpt, "tp_maskaware_checkpoint.pt")

# Run the full pipeline only when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    main()