In [1]:
import os
# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
import time
import threading
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# from unetlite import UNetLite
import fucs  # 你原来的工具函数库（cal_acc 等）

In [2]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.GroupNorm(num_groups=8, num_channels=out_channels))
    
    def forward(self, x):
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return F.relu(out)

class ClimateEncoder(nn.Module):
    def __init__(self, input_channels=10):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5)
        )
        self.res_blocks = nn.Sequential(
            ResBlock(64, 64),
            ResBlock(64, 64)
        )
        self.upsample = nn.Sequential(
            nn.Upsample(size=(60, 70), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5),
            nn.Upsample(size=(120, 140), mode='bilinear', align_corners=True)
        )
    
    def forward(self, x):
        x = self.initial(x)
        x = self.down1(x)
        x = self.res_blocks(x)
        return self.upsample(x)

class SSTEncoder(nn.Module):
    def __init__(self, input_channels=3):
        super().__init__()
        self.upsample_init = nn.Upsample(size=(100, 180), mode='bilinear', align_corners=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5)
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5)
        )
        self.res_block = ResBlock(64, 64)
        self.upsample = nn.Sequential(
            nn.Upsample(size=(60, 70), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5),
            nn.Upsample(size=(120, 140), mode='bilinear', align_corners=True))
    
    def forward(self, x):
        x = self.upsample_init(x)
        x = self.conv1(x)
        x = self.down1(x)
        x = self.res_block(x)
        return self.upsample(x)





In [3]:
class ClimateDataset(Dataset):
    def __init__(self, climate_path, sst_path, precip_path, transform=None):
        # ---- Load with mmap to avoid full copy first ----
        climate_raw = np.load(climate_path, mmap_mode="r")    # [time, vars, H, W]
        sst_raw = np.load(sst_path, mmap_mode="r")            # [time, months, H, W]
        precip_raw = np.load(precip_path, mmap_mode="r")      # [time, H, W]

        # ---- Standardize climate (per-variable channel) ----
        climate_mean = np.nanmean(climate_raw, axis=(0,2,3), keepdims=True)
        climate_std  = np.nanstd(climate_raw, axis=(0,2,3), keepdims=True) + 1e-6
        climate_stdzd = (climate_raw - climate_mean) / climate_std
        # print("Standardized climate:", climate_stdzd.shape, climate_stdzd.dtype, climate_stdzd.nbytes/1e9, "GB")


        # ---- Standardize sst (per-month channel) ----
        sst_filled = np.nan_to_num(sst_raw, nan=0.0).astype(np.float64)
        T, M, H, W = sst_filled.shape

        sst_mean = np.zeros((1, M, 1, 1), dtype=np.float64)
        sst_std  = np.zeros((1, M, 1, 1), dtype=np.float64)

        for m in range(M):
            # print(f"Processing month {m}/{M} ...", flush=True)
            channel = sst_filled[:, m, :, :]  # [T,H,W]

            # --- 检查异常值 ---
            if not np.isfinite(channel).all():
                # print(f"Warning: channel {m} contains NaN/Inf")
                channel = np.nan_to_num(channel, nan=0.0, posinf=0.0, neginf=0.0)

            # --- 手写 mean/std 避免 MKL bug ---
            count = channel.size
            mean_val = channel.sum() / count
            var_val = ((channel - mean_val) ** 2).sum() / count
            std_val = np.sqrt(var_val) + 1e-6

            sst_mean[0, m, 0, 0] = mean_val
            sst_std[0, m, 0, 0]  = std_val

            # print(f"month {m}: mean={mean_val:.4f}, std={std_val:.4f}", flush=True)

        # 标准化
        sst_stdzd = (sst_filled - sst_mean) / sst_std
        # print("Standardized sst:", sst_stdzd.shape, sst_stdzd.dtype, sst_stdzd.nbytes/1e9, "GB")




        # ---- Standardize precip (global mean/std) ----
        precip_mean = np.nanmean(precip_raw)
        precip_std  = np.nanstd(precip_raw) + 1e-6
        precip_stdzd = (precip_raw - precip_mean) / precip_std
        precip_stdzd = np.clip(precip_stdzd, -3.0, 3.0)

        # ---- Convert to torch tensors (float32) ----
        self.climate_arr = torch.from_numpy(climate_stdzd.astype(np.float32))   # [T, vars, H, W]
        self.sst_arr     = torch.from_numpy(sst_stdzd.astype(np.float32))       # [T, months, H, W]
        self.precip_arr  = torch.from_numpy(precip_stdzd.astype(np.float32))    # [T, H, W]

        # ---- Mask from original precip NaN ----
        self.data_mask = torch.from_numpy(np.isnan(precip_raw)).unsqueeze(1)    # [T,1,H,W]
        print("valid_pixels per sample:", (~self.data_mask).sum(dim=[1,2,3]).cpu().numpy())

        self.transform = transform

    def __len__(self):
        return self.precip_arr.shape[0]

    def __getitem__(self, idx):
        climate_t = self.climate_arr[idx]
        sst_t     = self.sst_arr[idx]
        precip_t  = self.precip_arr[idx].reshape(1, *self.precip_arr.shape[1:])
        mask_t    = self.data_mask[idx]
        if self.transform:
            climate_t = self.transform(climate_t)
            sst_t     = self.transform(sst_t)
        return climate_t, sst_t, precip_t, mask_t

# 将 AugmentedSubset 移到模块顶层，避免 Windows multiprocess 无法 pickle 局部类的问题
class AugmentedSubset(Dataset):
    """
    对已有 Dataset 的 subset 进行包装，在 __getitem__ 时对 climate/sst 应用 transform（训练增强）。
    定义在模块顶层以便 DataLoader 的 worker 可序列化（Windows spawn 模式）。
    """
    def __init__(self, base_ds, indices, transform=None):
        super().__init__()
        self.base = base_ds
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        idx = self.indices[i]
        climate_t, sst_t, precip_t, mask_t = self.base[idx]
        if self.transform is not None:
            climate_t = self.transform(climate_t)
            sst_t = self.transform(sst_t)
        return climate_t, sst_t, precip_t, mask_t

In [4]:
climate_path = "E:/D1/diffusion/my_models/mulNet_data/lr.npy"
sst_path = "E:/D1/diffusion/my_models/mulNet_data/sst.npy"
precip_path = "E:/D1/diffusion/my_models/mulNet_data/hr.npy"
print("Loading data...")

# 先创建一个不带增强的基础数据集（用于验证 & 统计）
base_dataset = ClimateDataset(climate_path, sst_path, precip_path, transform=None)
n_total = len(base_dataset)
n_val = 70
n_train = n_total - n_val

# 使用模块顶层的 AugmentedSubset（避免在主函数内定义类导致 worker 无法序列化）
# 训练集使用增强（GaussianNoise），验证集不使用增强
train_dataset = AugmentedSubset(base_dataset, range(0, n_train))
val_dataset = torch.utils.data.Subset(base_dataset, range(n_train, n_total))

# # 统计仍然基于基础数据集（标准化后的目标）
# target_stats(base_dataset, list(range(0, n_train)))
# inspect_target_distribution(base_dataset, list(range(0, n_train)), "train")
# inspect_target_distribution(base_dataset, list(range(n_train, len(base_dataset))), "val")

config = {
    'climate_channels': 10,
    'sst_channels': 3,
    'latent_channels': 16,
    'batch_size': 8,
    'epochs': 500,
    'lr': 5e-4,
    'weight_decay': 1e-3,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'resume_checkpoint': None,
    'train_num_workers': 4,
    'val_num_workers': 2,
    'prefetch_factor': 2,
    'acc_frequency': 10,
    'ckpt_dir': './weights'
    }

train_loader_kwargs = {
    'batch_size': config['batch_size'],
    'shuffle': True,
    'pin_memory': True,
    'num_workers': config['train_num_workers'],
    'prefetch_factor': config['prefetch_factor']
}
val_loader_kwargs = {
    'batch_size': config['batch_size'],
    'shuffle': False,
    'pin_memory': True,
    'num_workers': config['val_num_workers'],
    'prefetch_factor': config['prefetch_factor']
}
if config['train_num_workers'] > 0:
    train_loader_kwargs['persistent_workers'] = True
if config['val_num_workers'] > 0:
    val_loader_kwargs['persistent_workers'] = True

train_loader = DataLoader(train_dataset, **train_loader_kwargs)
val_loader = DataLoader(val_dataset, **val_loader_kwargs)

Loading data...
valid_pixels per sample: [5301 5301 5301 5301 5301 5301 5301 5301 5301    0    0 5301 5295 5295
 5295 5295 5295 5295 5295 5295 5295    0    0 5295 5295 5295 5295 5295
 5295 5295 5295 5295 5295    0    0 5295 5237 5237 5237 5237 5237 5237
 5237 5237 5237 5237    0 5237 5237 5237 5237 5237 5237 5237 5237 5237
 5237 5237    0 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237
 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237
 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5237 5295 5295
 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295
 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295
 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295
 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295
 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295
 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295 5295
 5295 5295 5295 5295 5295 5295 5295 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np

# -----------------------------
# 1. ResBlock (保持不变)
# -----------------------------
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.GroupNorm(num_groups=8, num_channels=out_channels))
    
    def forward(self, x):
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return F.relu(out)

# -----------------------------
# 2. Encoder & Decoder
# -----------------------------
class ClimateEncoder(nn.Module):
    def __init__(self, input_channels=10, latent_dim=64):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=4, stride=2, padding=1),  # 120x140 -> 60x70
            nn.ReLU()
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(32, latent_dim, kernel_size=4, stride=2, padding=1),  # 60x70 -> 30x35
            nn.ReLU(),
            nn.Dropout2d(p=0.5)
        )
        self.res_blocks = nn.Sequential(
            ResBlock(latent_dim, latent_dim),
            ResBlock(latent_dim, latent_dim)
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.down1(x)
        x = self.res_blocks(x)
        return x  # [B, latent_dim, 30, 35]


class ClimateDecoder(nn.Module):
    def __init__(self, output_channels=10, latent_dim=64):
        super().__init__()
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 32, kernel_size=4, stride=2, padding=1),  # 30x35 -> 60x70
            nn.ReLU()
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(32, output_channels, kernel_size=4, stride=2, padding=1),  # 60x70 -> 120x140
        )

    def forward(self, x):
        x = self.up1(x)
        x = self.up2(x)
        return x


class SSTEncoder(nn.Module):
    def __init__(self, input_channels=3):
        super().__init__()
        self.upsample_init = nn.Upsample(size=(100, 180), mode='bilinear', align_corners=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5)
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5)
        )
        self.res_block = ResBlock(64, 64)
        self.upsample = nn.Sequential(
            nn.Upsample(size=(60, 70), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.5),
            nn.Upsample(size=(120, 140), mode='bilinear', align_corners=True))
    
    def forward(self, x):
        x = self.upsample_init(x)
        x = self.conv1(x)
        x = self.down1(x)
        x = self.res_block(x)
        return self.upsample(x)


class SSTDecoder(nn.Module):
    def __init__(self, output_channels=3, latent_dim=64):
        super().__init__()
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 32, kernel_size=4, stride=2, padding=1),  # 50x90 -> 100x180
            nn.ReLU()
        )
        self.out = nn.Conv2d(32, output_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.up1(x)
        x = self.out(x)
        return x

# -----------------------------
# 3. Autoencoder Wrappers
# -----------------------------
class ClimateAE(nn.Module):
    def __init__(self, input_channels=10, latent_dim=64):
        super().__init__()
        self.encoder = ClimateEncoder(input_channels, latent_dim)
        self.decoder = ClimateDecoder(input_channels, latent_dim)

    def forward(self, x):
        z = self.encoder(x)
        out = self.decoder(z)
        return out


class SSTAe(nn.Module):
    def __init__(self, input_channels=3, latent_dim=64):
        super().__init__()
        self.encoder = SSTEncoder(input_channels)
        self.decoder = SSTDecoder(input_channels, latent_dim)

    def forward(self, x):
        z = self.encoder(x)
        out = self.decoder(z)
        return out

# -----------------------------
# 4. Dataset & DataLoader
# -----------------------------
class ClimateDataset(Dataset):
    def __init__(self, climate_path, sst_path, mode="climate"):
        self.mode = mode
        if mode == "climate":
            arr = np.load(climate_path, mmap_mode="r")  # [T, vars, H, W]
        elif mode == "sst":
            arr = np.load(sst_path, mmap_mode="r")  # [T, months, H, W]
        else:
            raise ValueError("mode must be 'climate' or 'sst'")
        arr = np.nan_to_num(arr).astype(np.float32)
        self.arr = torch.from_numpy(arr)

    def __len__(self):
        return self.arr.shape[0]

    def __getitem__(self, idx):
        x = self.arr[idx]
        return x, x  # AE: input = target

# -----------------------------
# 5. Training Loop
# -----------------------------
def train_autoencoder(model, dataloader, device, epochs=5):
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = criterion(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss={total_loss/len(dataloader):.4f}")

# -----------------------------
# 6. Example Usage
# -----------------------------
if __name__ == "__main__":
    climate_path = "E:/D1/diffusion/my_models/mulNet_data/lr.npy"
    sst_path = "E:/D1/diffusion/my_models/mulNet_data/sst.npy"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Train Climate AE
    ds = ClimateDataset(climate_path, sst_path, mode="climate")
    dl = DataLoader(ds, batch_size=8, shuffle=True)
    model = ClimateAE(input_channels=10)
    train_autoencoder(model, dl, device, epochs=10)

    # Train SST AE
    ds_sst = ClimateDataset(climate_path, sst_path, mode="sst")
    dl_sst = DataLoader(ds_sst, batch_size=8, shuffle=True)
    model_sst = SSTAe(input_channels=3)
    train_autoencoder(model_sst, dl_sst, device, epochs=10)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# ==== 你给的 ResBlock / ClimateEncoder 代码在这里 ====
# (略去重复，假设你已经定义好 ClimateEncoder 和 SSTEncoder)

class ClimateModel(nn.Module):
    """ 在 ClimateEncoder 后加一个 1x1 卷积作为预测头 """
    def __init__(self, input_channels=10):
        super().__init__()
        self.encoder = ClimateEncoder(input_channels)
        self.head = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        feat = self.encoder(x)  # [B,64,120,140]
        out = self.head(feat)   # [B,1,120,140]
        return out

# ==== 假数据 (替换成你的 DataLoader) ====
# B, T, H, W = 16, 10, 120, 140
# climate_x = torch.randn(100, T, H, W)
# target_y = torch.randn(100, 1, H, W)

train_loader = DataLoader(train_dataset, **train_loader_kwargs)

# ==== 训练 ====
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ClimateModel(input_channels=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(5):
    model.train()
    total_loss = 0.0
    for climate, sst, precip, mask in train_loader:
        climate, sst, precip, mask = climate.to(device), sst.to(device), precip.to(device), mask.to(device)
        pred = model(climate)
        loss = criterion(pred, precip)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"[ClimateEncoder] Epoch {epoch+1}, Loss={total_loss/len(train_loader):.4f}")


RuntimeError: DataLoader worker (pid(s) 25356, 45244, 40456, 41192) exited unexpectedly