# Cohesive Weather Forecasting & Satellite Analysis Framework

This notebook implements a comprehensive framework for processing satellite imagery for various weather forecasting tasks. It includes:

1.  **Unified Configuration**: Centralized management of hyperparameters and paths, including a **Smart Product Catalog**.
2.  **Flexible Data Pipeline**: 
    *   **Smart Parsing**: Selects semantically relevant channels (e.g., just Visual + Lightning).
    *   **All-Channel Mode**: Supports consumption of the full 162-channel tensor.
3.  **Model Zoo**: Implementation of 6 specific deep learning architectures from academic literature.
4.  **Robust Training**: Unified trainer with **Verbose Logging** and strict GPU utilization checks.

In [None]:
# Mount Google Drive when running inside Colab; Config later probes
# /content/drive/MyDrive/DL/ to decide where the dataset lives.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted successfully.")
except ImportError:
    # An ImportError here normally means google.colab is unavailable, i.e. the
    # notebook is running outside Colab; the mount is simply skipped.
    print("Not running in Google Colab, skipping Drive mount.")

In [None]:
import torch
import os
import sys

# === GPU/Performance Optimization ===
# Selects the compute device once (CUDA > MPS > CPU) and prints diagnostics.
# Later cells read the module-level DEVICE.
REQUIRE_GPU = False  # Set to True to raise error if no GPU found

# torch.backends.mps is missing on PyTorch builds that predate the MPS
# backend, so look it up defensively instead of assuming it exists.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    # Optimizations for fixed-size inputs
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print("=== SYSTEM DIAGNOSTICS ===")
print(f"Python Version: {sys.version.split()[0]}")
print(f"PyTorch Version: {torch.__version__}")
print(f"Selected Device: {DEVICE}")

if DEVICE.type == "cuda":
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Capability: {torch.cuda.get_device_capability(0)}")
    # Byte counts are converted to GiB (1024**3) for readability.
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"Memory Cached: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
    torch.cuda.empty_cache()
    print("Cache cleared.")
elif DEVICE.type == "mps":
    print("Using Apple Metal Performance Shaders (MPS).")
else:
    print("WARNING: Running on CPU. Training will be slow.")

# Only a pure-CPU fallback counts as "no GPU"; an MPS device passes this check.
if REQUIRE_GPU and DEVICE.type == "cpu":
    raise RuntimeError("No GPU found! Check Runtime > Change runtime type in Colab.")

In [None]:
import glob
import random
from datetime import datetime, timedelta
from pathlib import Path
from collections import defaultdict
import math
import time

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Ensure reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(42)

In [None]:
class Config:
    """Central configuration: paths, data/training hyperparameters and the
    product catalog used to locate channels inside a stacked tensor.

    The catalog order defines the channel layout: product ``i`` occupies
    channels ``[i * CHANNELS_PER_PRODUCT, (i + 1) * CHANNELS_PER_PRODUCT)``.
    With 54 products and 3 channels each this gives 162 channels.
    """

    # Paths
    # Prefer the mounted Google Drive folder when present, else a local ./data.
    if Path("/content/drive/MyDrive/DL/").exists():
        DATA_ROOT = Path("/content/drive/MyDrive/DL/")
    else:
        DATA_ROOT = Path("./data")

    DAILIES_PATH = DATA_ROOT / "dailies.csv"
    PROCESSED_ROOT = DATA_ROOT / "processed_tensors"

    # Data Parameters
    LOCATION = "new_york_ny"
    HISTORY_DAYS = 30
    IMAGE_SIZE = (128, 128)
    TARGET_COLUMN = "target_tmax_next_day"

    # Training Hyperparameters
    BATCH_SIZE = 32
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 10

    # Number of consecutive channels each catalog product occupies.
    CHANNELS_PER_PRODUCT = 3

    # --- SMART PRODUCT CATALOG ---
    PRODUCT_CATALOG = [
        "AMSRU2_Sea_Ice_Brightness_Temp_6km_89H",
        "AMSRU2_Sea_Ice_Brightness_Temp_6km_89V",
        "AMSRU2_Sea_Ice_Concentration_12km",
        "AMSUA_NOAA15_Brightness_Temp_Channel_1",
        "AMSUA_NOAA15_Brightness_Temp_Channel_10",
        "AMSUA_NOAA15_Brightness_Temp_Channel_12",
        "AMSUA_NOAA15_Brightness_Temp_Channel_13",
        "AMSUA_NOAA15_Brightness_Temp_Channel_15",
        "AMSUA_NOAA15_Brightness_Temp_Channel_2",
        "AMSUA_NOAA15_Brightness_Temp_Channel_3",
        "AMSUA_NOAA15_Brightness_Temp_Channel_4",
        "AMSUA_NOAA15_Brightness_Temp_Channel_5",
        "AMSUA_NOAA15_Brightness_Temp_Channel_6",
        "AMSUA_NOAA15_Brightness_Temp_Channel_7",
        "AMSUA_NOAA15_Brightness_Temp_Channel_8",
        "AMSUA_NOAA15_Brightness_Temp_Channel_9",
        "GHRSST_L4_AVHRR-OI_Sea_Surface_Temperature",
        "GHRSST_L4_MUR_Sea_Ice_Concentration",
        "GOES-East_ABI_GeoColor",
        "GOES-West_ABI_GeoColor",
        "IMERG_Precipitation_Rate_30min",
        "LIS_TRMM_Flash_Radiance",
        "MERRA2_ISCCP_Cloud_Albedo_Monthly",
        "MODIS_Aqua_Cloud_Optical_Thickness_16",
        "MODIS_Aqua_Cloud_Top_Height_Day",
        "MODIS_Aqua_Cloud_Top_Temp_Day",
        "MODIS_Aqua_Cloud_Top_Temp_Night",
        "MODIS_Aqua_CorrectedReflectance_Bands721",
        "MODIS_Aqua_CorrectedReflectance_TrueColor",
        "MODIS_Aqua_L2_Chlorophyll_A",
        "MODIS_Aqua_L3_NDSI_Snow_Cover_Daily",
        "MODIS_Aqua_SurfaceReflectance_Bands121",
        "MODIS_Terra_Cloud_Optical_Thickness_16",
        "MODIS_Terra_Cloud_Top_Height_Day",
        "MODIS_Terra_Cloud_Top_Temp_Day",
        "MODIS_Terra_Cloud_Top_Temp_Night",
        "MODIS_Terra_CorrectedReflectance_Bands367",
        "MODIS_Terra_CorrectedReflectance_Bands721",
        "MODIS_Terra_CorrectedReflectance_TrueColor",
        "MODIS_Terra_L2_Chlorophyll_A",
        "MODIS_Terra_L3_NDSI_Snow_Cover_Daily",
        "MODIS_Terra_SurfaceReflectance_Bands121",
        "OSCAR_Sea_Surface_Currents_Meridional",
        "OSCAR_Sea_Surface_Currents_Zonal",
        "SSMI_DMSP_F11_Cloud_Liquid_Water_Over_Oceans_Ascending",
        "SSMI_DMSP_F11_Cloud_Liquid_Water_Over_Oceans_Descending",
        "TOPEX-Poseidon_JASON_Sea_Surface_Height_Anomalies_GDR_Cycles",
        "TRMM_Brightness_Temp_Asc",
        "TRMM_Brightness_Temp_Dsc",
        "TRMM_Precipitation_Rate_Asc",
        "TRMM_Precipitation_Rate_Dsc",
        "VIIRS_NOAA20_CorrectedReflectance_TrueColor_Granule",
        "VIIRS_SNPP_CorrectedReflectance_TrueColor_Granule",
        "VIIRS_SNPP_L2_Chlorophyll_A"
    ]

    @staticmethod
    def get_indices(product_name_substring):
        """Return the ``(start, end)`` channel range of a catalog product.

        Args:
            product_name_substring: Full product name or a substring of one.

        Returns:
            Half-open channel range ``(start, end)``. An exact catalog name is
            preferred; otherwise the first entry containing the substring is
            used. If nothing matches, a warning is printed and the range of
            the first product is returned.
        """
        catalog = Config.PRODUCT_CATALOG
        width = Config.CHANNELS_PER_PRODUCT

        # Exact name wins, so "..._Channel_1" can never be shadowed by an
        # earlier "..._Channel_10"-style entry that merely contains it.
        if product_name_substring in catalog:
            position = catalog.index(product_name_substring)
            return position * width, position * width + width

        for position, name in enumerate(catalog):
            if product_name_substring in name:
                return position * width, position * width + width

        print(f"WARNING: Product matching '{product_name_substring}' not found. Using indices 0-3.")
        return 0, width

    @staticmethod
    def select_channels(tensor, product_names):
        """Gather the channels of the named products from ``tensor``.

        Args:
            tensor: Tensor whose first dimension indexes channels in catalog
                order.
            product_names: Iterable of product names (or substrings) to pick.

        Returns:
            Tensor with the requested products concatenated along dim 0, in
            the order given. Every product contributes exactly
            ``CHANNELS_PER_PRODUCT`` channels: channels that lie beyond the
            end of ``tensor`` are filled with zeros matching the dtype and
            device of ``tensor``.
        """
        chunks = []
        for name in product_names:
            start, end = Config.get_indices(name)
            # Slicing past the end yields a short (possibly empty) chunk
            # rather than raising, so pad whatever is missing with zeros.
            chunk = tensor[start:end]
            missing = (end - start) - chunk.shape[0]
            if missing > 0:
                filler = tensor.new_zeros((missing, *tensor.shape[1:]))
                chunk = torch.cat([chunk, filler], dim=0)
            chunks.append(chunk)
        return torch.cat(chunks, dim=0)

config = Config()
# Make sure the processed-tensor directory exists. The previous guard only
# called mkdir when the directory was already there, so it never created it.
config.PROCESSED_ROOT.mkdir(parents=True, exist_ok=True)

## Smart Data Loading with All-Channel Support

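The per-model datasets (the `BaseSlicingDataset` subclasses and `SIANetDataset`) accept `use_all_channels=True/False`; `GIBSSequenceDataset` always uses every channel.
*   `False`: Use semantic selection (e.g., 3 channels).
*   `True`: Use the full tensor (162 channels = 54 catalog products × 3 channels each).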

In [None]:
class BaseSlicingDataset(Dataset):
    """Base class for datasets that read one saved tensor file per sample.

    Subclasses implement ``__getitem__`` and use ``_get_input`` to switch
    between the full tensor and a task-specific channel selection.
    """

    def __init__(self, file_paths, use_all_channels=False):
        """Store the sorted file list and the channel-selection mode."""
        self.files = sorted(file_paths)
        self.use_all_channels = use_all_channels

    def __len__(self):
        """Number of files, i.e. number of samples."""
        return len(self.files)

    def load_tensor(self, path):
        """Load a saved tensor from ``path`` with ``weights_only=True``."""
        return torch.load(path, weights_only=True)

    def _get_input(self, data, semantic_selection_func):
        """Return ``data`` unchanged in all-channel mode, otherwise the
        result of ``semantic_selection_func(data)``."""
        return data if self.use_all_channels else semantic_selection_func(data)

    def __getitem__(self, idx):
        """Must be provided by subclasses."""
        raise NotImplementedError()

# --- 1. Existing TMAX Forecasting Dataset ---
class GIBSSequenceDataset(Dataset):
    """Sequence dataset for next-day TMAX forecasting.

    Each sample is a window of ``history_days`` consecutive daily tensors
    ending on one of ``end_dates``, paired with the target value stored for
    that end date in ``target_df``.
    """

    def __init__(self, end_dates, processed_dir, target_df, history_days=30, num_channels_expected=162):
        """
        Args:
            end_dates: Sequence of dates, one per sample (last day of window).
            processed_dir: Directory holding ``<date>.pt`` tensor files.
            target_df: Frame indexed by timestamp with the target column.
            history_days: Number of days in each window.
            num_channels_expected: Channel count used for placeholder frames.
        """
        self.end_dates = end_dates
        self.processed_dir = processed_dir
        self.target_df = target_df
        self.history_days = history_days
        self.num_channels_expected = num_channels_expected

    def __len__(self):
        """One sample per end date."""
        return len(self.end_dates)

    def _load_day(self, day):
        """Load the tensor for ``day``, or an all-zero frame if it is missing."""
        day_file = self.processed_dir / f"{day}.pt"
        if not day_file.exists():
            blank_shape = (self.num_channels_expected, config.IMAGE_SIZE[0], config.IMAGE_SIZE[1])
            return torch.zeros(blank_shape)
        return torch.load(day_file, weights_only=True)

    def __getitem__(self, idx):
        """Return ``(x, y)``: frames stacked oldest-first along dim 0 and a
        float32 scalar target (0.0 when the end date has no target row)."""
        last_day = self.end_dates[idx]
        # Oldest day first, the end date itself last.
        window = [last_day - timedelta(days=back) for back in reversed(range(self.history_days))]
        x = torch.stack([self._load_day(day) for day in window], dim=0)
        try:
            label = self.target_df.loc[pd.Timestamp(last_day), config.TARGET_COLUMN]
            y = torch.tensor(label, dtype=torch.float32)
        except KeyError:
            y = torch.tensor(0.0, dtype=torch.float32)
        return x, y

# --- 2. Model 1: Cloud Cover Dataset ---
class CloudCoverDataset(BaseSlicingDataset):
    """Dataset for the cloud-cover model: input channels plus a placeholder
    6-channel target built from the input itself."""

    def __getitem__(self, idx):
        """Return ``(x, y)`` for file ``idx``.

        In smart mode ``x`` is the GOES-East GeoColor selection with its first
        channel appended again (4 channels when the selection yields 3); in
        all-channel mode ``x`` is the loaded tensor as-is. ``y`` is a dummy
        target: ``x`` repeated along dim 0 and cut to at most 6 channels.
        """
        data = self.load_tensor(self.files[idx])

        def geocolor_plus_repeat(d):
            rgb = Config.select_channels(d, ["GOES-East_ABI_GeoColor"])
            # Pad to 4 by repeating the first selected channel.
            return torch.cat([rgb, rgb[:1]], dim=0)

        x = self._get_input(data, geocolor_plus_repeat)

        # Dummy target of size 6, doubled once more if still short.
        doubled = torch.cat([x, x], dim=0)
        y = doubled[:6]
        if y.shape[0] < 6:
            y = torch.cat([y, y], dim=0)[:6]
        return x, y

# --- 3. Model 2: SIANet Dataset ---
class SIANetDataset(Dataset):
    """Dataset of 4-day frame sequences for SIANet, with an all-zero target."""

    def __init__(self, valid_dates, processed_dir, use_all_channels=False):
        """
        Args:
            valid_dates: Dates usable as the last day of a sequence (sorted here).
            processed_dir: Directory holding ``<date>.pt`` tensor files.
            use_all_channels: Keep every channel instead of the 4-channel selection.
        """
        self.dates = sorted(valid_dates)
        self.dir = processed_dir
        self.use_all_channels = use_all_channels

    def __len__(self):
        """One sample per date."""
        return len(self.dates)

    def _frame_for(self, day):
        """Build the frame for one day, or zeros when its file is missing."""
        tensor_file = self.dir / f"{day}.pt"
        if not tensor_file.exists():
            n_channels = 162 if self.use_all_channels else 4
            return torch.zeros((n_channels, config.IMAGE_SIZE[0], config.IMAGE_SIZE[1]))

        full = torch.load(tensor_file, weights_only=True)
        if self.use_all_channels:
            return full

        # Smart mode: GOES-East GeoColor plus the first cloud-liquid-water channel.
        goes = Config.select_channels(full, ["GOES-East_ABI_GeoColor"])
        water = Config.select_channels(full, ["SSMI_DMSP_F11_Cloud_Liquid_Water_Over_Oceans_Descending"])
        return torch.cat([goes, water[:1]], dim=0)

    def __getitem__(self, idx):
        """Return ``(x_seq, y)``: the four frames stacked along dim 1 (oldest
        first) and a zero target of shape ``(1, 32, H, W)`` taken from the
        last two dimensions of ``x_seq``."""
        last_day = self.dates[idx]
        window = [last_day - timedelta(days=back) for back in (3, 2, 1, 0)]
        frames = [self._frame_for(day) for day in window]

        x_seq = torch.stack(frames, dim=1)
        height, width = x_seq.shape[-2], x_seq.shape[-1]
        y = torch.zeros((1, 32, height, width))
        return x_seq, y

# --- 4. Model 3: Pix2Pix Dataset ---
class Pix2PixDataset(BaseSlicingDataset):
    """Dataset for the Pix2Pix generator: GeoColor (or all channels) in,
    two cloud-top-temperature channels out."""

    def __getitem__(self, idx):
        """Return ``(x, y)`` for file ``idx``.

        ``x`` is the GOES-East GeoColor selection in smart mode, otherwise the
        full tensor. ``y`` is independent of the mode: the first two channels
        of the MODIS Aqua cloud-top-temperature Day + Night selection.
        """
        data = self.load_tensor(self.files[idx])
        x = self._get_input(
            data,
            lambda d: Config.select_channels(d, ["GOES-East_ABI_GeoColor"]),
        )

        # Target remains fixed regardless of the input mode.
        # NOTE(review): the Day product comes first in the selection, so when
        # it contributes at least two channels both target channels come from
        # Day and Night is unused. Confirm this is intended.
        target_products = ["MODIS_Aqua_Cloud_Top_Temp_Day", "MODIS_Aqua_Cloud_Top_Temp_Night"]
        therm = Config.select_channels(data, target_products)
        y = therm[:2]
        return x, y

# --- 5. Model 4: Microwave Dataset ---
class MicrowaveDataset(BaseSlicingDataset):
    """Dataset for the microwave model: AMSU-A brightness temperatures in,
    one TRMM brightness-temperature channel out."""

    def __getitem__(self, idx):
        """Return ``(x, y)`` for file ``idx``.

        In smart mode ``x`` is built from AMSU-A channels 1-4 and forced to
        exactly 10 channels (truncated, or zero-padded at the end of dim 0).
        ``y`` is the first channel of the TRMM ascending brightness
        temperature selection.
        """
        data = self.load_tensor(self.files[idx])

        def amsua_ten_channels(d):
            products = [
                "AMSUA_NOAA15_Brightness_Temp_Channel_1",
                "AMSUA_NOAA15_Brightness_Temp_Channel_2",
                "AMSUA_NOAA15_Brightness_Temp_Channel_3",
                "AMSUA_NOAA15_Brightness_Temp_Channel_4",
            ]
            raw = Config.select_channels(d, products)
            if raw.shape[0] >= 10:
                return raw[:10]
            # F.pad lists pads from the last dim backwards; the final pair
            # appends zeros at the end of dim 0.
            shortfall = 10 - raw.shape[0]
            return F.pad(raw, (0, 0, 0, 0, 0, shortfall))

        x = self._get_input(data, amsua_ten_channels)
        trmm = Config.select_channels(data, ["TRMM_Brightness_Temp_Asc"])
        y = trmm[:1]
        return x, y

# --- 6. Model 5: Convection Dataset ---
class ConvectionDataset(BaseSlicingDataset):
    """Dataset for the convection classifier with a constant 0.0 label."""

    def __getitem__(self, idx):
        """Return ``(x, y)`` for file ``idx``.

        In smart mode ``x`` stacks the first channel of three products, in
        this order: GOES-East GeoColor, MODIS Aqua cloud-top temperature
        (day), LIS/TRMM flash radiance. ``y`` is always ``tensor([0.0])``.
        """
        data = self.load_tensor(self.files[idx])

        def one_channel_per_product(d):
            products = (
                "GOES-East_ABI_GeoColor",
                "MODIS_Aqua_Cloud_Top_Temp_Day",
                "LIS_TRMM_Flash_Radiance",
            )
            singles = [Config.select_channels(d, [product])[:1] for product in products]
            return torch.cat(singles, dim=0)

        x = self._get_input(data, one_channel_per_product)
        y = torch.tensor([0.0])
        return x, y

In [None]:
def discover_files(config):
    """Shared discovery of .pt files and dates.

    Args:
        config: Object exposing ``PROCESSED_ROOT`` (a path-like directory).

    Returns:
        ``(files, dates)``. ``files`` is every ``*.pt`` file in the directory,
        sorted. ``dates`` holds the date parsed from each stem in
        ``YYYY-MM-DD`` form, in the same order; files whose stem does not
        parse stay in ``files`` but add no date, so the two lists can differ
        in length. Both are empty when the directory does not exist.
    """
    root = config.PROCESSED_ROOT
    if not root.exists():
        return [], []

    files = sorted(root.glob("*.pt"))
    dates = []
    for tensor_file in files:
        try:
            parsed = datetime.strptime(tensor_file.stem, "%Y-%m-%d")
        except ValueError:
            # Not a date-named file; keep it in `files` only.
            continue
        dates.append(parsed.date())
    return files, dates

def prepare_tmax_data(config):
    """Build the train/test datasets for next-day TMAX forecasting.

    Args:
        config: Object exposing ``PROCESSED_ROOT``, ``DAILIES_PATH``,
            ``TARGET_COLUMN`` and ``HISTORY_DAYS``.

    Returns:
        ``(train_ds, test_ds, num_channels)``. The datasets are ``None`` when
        no tensor files or no targets CSV are found; ``num_channels`` falls
        back to 162 whenever it cannot be read from the first tensor file.
    """
    files, dates = discover_files(config)
    if not files:
        print("No files found for TMAX training.")
        return None, None, 162

    # Load targets
    if not config.DAILIES_PATH.exists():
        print(f"Targets file not found at {config.DAILIES_PATH}; skipping TMAX training.")
        return None, None, 162

    df = pd.read_csv(config.DAILIES_PATH)
    df["DATE"] = pd.to_datetime(df["DATE"])
    df = df.sort_values("DATE").set_index("DATE")
    target = df[["TMAX"]].rename(columns={"TMAX": config.TARGET_COLUMN})
    # shift(-1) moves each row's TMAX onto the previous row, so a row holds the
    # value of the row that follows it; the last row has none and is dropped.
    target = target.shift(-1).dropna()

    # Filter valid dates: keep only tensor dates that have a target row.
    target_dates = set(target.index.date)
    valid_dates = [d for d in dates if d in target_dates]

    # Check num channels from the first file. weights_only=True matches how the
    # datasets load these files and avoids unpickling arbitrary objects.
    try:
        sample = torch.load(files[0], map_location="cpu", weights_only=True)
        num_channels = sample.shape[0]
    except Exception as exc:  # best-effort probe: fall back, but say why
        print(f"WARNING: Could not read channel count from {files[0]} ({exc}); assuming 162.")
        num_channels = 162

    # First 80% of the valid dates for training, the remainder for testing.
    split_idx = int(len(valid_dates) * 0.8)
    train_ds = GIBSSequenceDataset(valid_dates[:split_idx], config.PROCESSED_ROOT, target, config.HISTORY_DAYS, num_channels)
    test_ds = GIBSSequenceDataset(valid_dates[split_idx:], config.PROCESSED_ROOT, target, config.HISTORY_DAYS, num_channels)
    return train_ds, test_ds, num_channels

# Model Zoo
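All model classes are defined in the cell below. Each takes `in_channels` as a constructor argument, so the same architecture can run on either the smart channel selection or the full tensor.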

In [None]:
# Model classes, written out in full here so the notebook is self-contained.
# Each one accepts a variable number of input channels via `in_channels`.

class DayEncoder(nn.Module):
    """Encode one image into a ``hidden_dim`` feature vector using three
    conv blocks, a global average pool and a linear projection."""

    def __init__(self, in_channels, hidden_dim=128):
        """
        Args:
            in_channels: Number of channels in the input image.
            hidden_dim: Size of the output feature vector.
        """
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.net = nn.Sequential(*layers)
        self.proj = nn.Linear(128, hidden_dim)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, hidden_dim)."""
        pooled = self.net(x)
        flat = pooled.view(x.size(0), -1)
        return self.proj(flat)

class GIBSForecaster(nn.Module):
    """Per-day CNN encoder followed by a GRU and a regression head that
    outputs one value per input sequence."""

    def __init__(self, in_channels, history_days, hidden_dim=128, rnn_hidden=256):
        """
        Args:
            in_channels: Channels per daily frame.
            history_days: Accepted for interface compatibility; not used by
                the layers built here.
            hidden_dim: Size of the per-day feature vector.
            rnn_hidden: Hidden size of the GRU.
        """
        super().__init__()
        self.day_encoder = DayEncoder(in_channels, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, rnn_hidden, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(rnn_hidden, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Map ``x`` of shape (B, T, C, H, W) to predictions of shape (B,)."""
        batch, steps, chans, height, width = x.shape
        # Fold time into the batch so every frame is encoded independently.
        per_day = self.day_encoder(x.view(batch * steps, chans, height, width))
        sequence = per_day.view(batch, steps, -1)
        hidden_states, _ = self.rnn(sequence)
        # Only the last time step feeds the regression head.
        last_state = hidden_states[:, -1, :]
        return self.head(last_state).squeeze(1)

class CloudCoverNowcaster(nn.Module):
    """U-Net style encoder/decoder that maps an image to ``out_timesteps``
    sigmoid-activated maps at the input resolution.

    The input height and width must be divisible by 8: the encoder pools by 2
    three times and the decoder upsamples by 2 three times, and each decoder
    stage is concatenated with the encoder feature map of matching size.
    """

    def __init__(self, in_channels=4, out_timesteps=6):
        """
        Args:
            in_channels: Number of channels in the input image.
            out_timesteps: Number of output maps (channels of the result).
        """
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(in_channels, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU())
        self.enc3 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU())
        self.pool = nn.MaxPool2d(2, 2)
        self.bottleneck = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU())
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = nn.Sequential(nn.Conv2d(512 + 256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU())
        self.dec2 = nn.Sequential(nn.Conv2d(256 + 128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU())
        self.dec1 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.head = nn.Conv2d(64, out_timesteps, 1)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_timesteps, H, W)."""
        # Keep the pre-pool activations as skip connections so that each one
        # has the same spatial size as the upsampled decoder input it joins.
        s1 = self.enc1(x)                 # H     x W,     64 ch
        s2 = self.enc2(self.pool(s1))     # H/2   x W/2,  128 ch
        s3 = self.enc3(self.pool(s2))     # H/4   x W/4,  256 ch
        b = self.bottleneck(self.pool(s3))  # H/8 x W/8,  512 ch

        d3 = self.dec3(torch.cat([self.up(b), s3], dim=1))
        d2 = self.dec2(torch.cat([self.up(d3), s2], dim=1))
        # Reuse s1 rather than running enc1 a second time, which would also
        # update its BatchNorm statistics twice per step in training mode.
        d1 = self.dec1(torch.cat([self.up(d2), s1], dim=1))
        return torch.sigmoid(self.head(d1))

class SIANet(nn.Module):
    """3D convolutional network that expands a short frame sequence along
    time and emits a single-channel output volume."""

    def __init__(self, in_channels=4, input_frames=4, output_frames=32):
        """
        Args:
            in_channels: Channels per input frame.
            input_frames: Number of input time steps.
            output_frames: Number of output time steps; the temporal expansion
                factor is ``output_frames // input_frames``.
        """
        super().__init__()
        # Adjust to 3D
        self.conv1 = nn.Conv3d(in_channels, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn1 = nn.BatchNorm3d(32)
        self.context = nn.Conv3d(32, 64, kernel_size=(3, 5, 5), padding=(1, 2, 2))
        self.bn2 = nn.BatchNorm3d(64)
        self.res_conv = nn.Conv3d(64, 64, 3, padding=1)
        # Kernel and stride share the same temporal factor; space is untouched.
        expand = (output_frames // input_frames, 1, 1)
        self.temporal_expand = nn.ConvTranspose3d(64, 32, expand, stride=expand)
        self.out = nn.Conv3d(32, 1, 1)

    def forward(self, x):
        """Map ``x`` of shape (N, C, T, H, W) to (N, 1, T * factor, H, W)."""
        stem = F.relu(self.bn1(self.conv1(x)))
        context = F.relu(self.bn2(self.context(stem)))
        # Residual block: add the conv branch back onto its own input.
        fused = F.relu(context + self.res_conv(context))
        expanded = self.temporal_expand(fused)
        return self.out(expanded)

class Pix2PixGenerator(nn.Module):
    """Encoder/decoder generator with skip connections and a tanh output.

    Four stride-2 convolutions downsample the input and four stride-2
    transposed convolutions bring it back to the input resolution.
    """

    def __init__(self, in_channels=3, out_channels=2):
        """
        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the generated image.
        """
        super().__init__()
        conv_args = dict(kernel_size=4, stride=2, padding=1)
        self.down1 = nn.Conv2d(in_channels, 64, **conv_args)
        self.down2 = nn.Conv2d(64, 128, **conv_args)
        self.down3 = nn.Conv2d(128, 256, **conv_args)
        self.bottleneck = nn.Conv2d(256, 256, **conv_args)
        # Decoder inputs after up1 are doubled by the concatenated skips.
        self.up1 = nn.ConvTranspose2d(256, 256, **conv_args)
        self.up2 = nn.ConvTranspose2d(512, 128, **conv_args)
        self.up3 = nn.ConvTranspose2d(256, 64, **conv_args)
        self.final = nn.ConvTranspose2d(128, out_channels, **conv_args)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)
        with values in [-1, 1]."""
        skip1 = F.leaky_relu(self.down1(x), 0.2)
        skip2 = F.leaky_relu(self.down2(skip1), 0.2)
        skip3 = F.leaky_relu(self.down3(skip2), 0.2)
        core = F.relu(self.bottleneck(skip3))

        dec = F.relu(self.up1(core))
        dec = F.relu(self.up2(torch.cat([dec, skip3], dim=1)))
        dec = F.relu(self.up3(torch.cat([dec, skip2], dim=1)))
        merged = torch.cat([dec, skip1], dim=1)
        return torch.tanh(self.final(merged))

class MicrowaveLSTNet(nn.Module):
    """Small fully convolutional regressor producing one output channel at
    the input resolution."""

    def __init__(self, in_channels=10):
        """
        Args:
            in_channels: Channels of the input image.
        """
        super().__init__()
        widths = [in_channels, 32, 64, 32]
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Conv2d(c_in, c_out, 3, 1, 1))
            blocks.append(nn.ReLU())
        self.features = nn.Sequential(*blocks)
        self.regressor = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, 1, H, W)."""
        feats = self.features(x)
        return self.regressor(feats)

class ConvectionCNN(nn.Module):
    """CNN classifier that returns one probability per image.

    The forward pass already applies a sigmoid, so the output is a
    probability in [0, 1] rather than a raw logit; pair it with a loss that
    expects probabilities.
    """

    def __init__(self, in_channels=3):
        """
        Args:
            in_channels: Channels of the input image.
        """
        super().__init__()
        layers = [
            # Unpadded 3x3 convolutions: each one trims 2 pixels per side pair.
            nn.Conv2d(in_channels, 32, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to probabilities (N, 1)."""
        score = self.net(x)
        return torch.sigmoid(score)

In [None]:
class UnifiedTrainer:
    """Runs single training epochs for any model/optimizer/criterion triple,
    with verbose progress logging."""

    def __init__(self, model, device, optimizer, criterion, task_type="regression"):
        """
        Args:
            model: Module to train; expected to already live on ``device``.
            device: Device every batch is moved to.
            optimizer: Optimizer over the model parameters.
            criterion: Callable ``criterion(pred, target)`` returning the loss.
            task_type: Stored for callers; not used by ``train_epoch``.
        """
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.task_type = task_type

    def train_epoch(self, loader, desc="Task"):
        """Train for one pass over ``loader``.

        Args:
            loader: Sized iterable of ``(x, y)`` batches.
            desc: Label used as a prefix in log lines.

        Returns:
            Mean batch loss as a float, or NaN when ``loader`` is empty
            (no step is taken in that case).
        """
        self.model.train()
        total_loss = 0.0
        start_time = time.time()
        num_batches = len(loader)

        print(f"[{desc}] Starting Epoch on {self.device}. Batch Count: {num_batches}")

        # An empty loader would otherwise end in a division by zero below.
        if num_batches == 0:
            print(f"[{desc}] WARNING: loader is empty; nothing to train.")
            return float("nan")

        # Log every 10% progress (at least every step for tiny loaders).
        log_every = max(1, num_batches // 10)

        for i, (x, y) in enumerate(loader):
            # === CRITICAL: Ensure tensors are on device ===
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            if i == 0:
                print(f"VERBOSE DEBUG: Batch 0 | Input Shape: {x.shape} | Target Shape: {y.shape} | Device: {x.device}")

            self.optimizer.zero_grad()
            pred = self.model(x)
            loss = self.criterion(pred, y)
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss

            if i > 0 and i % log_every == 0:
                print(f"[{desc}] Step {i}/{num_batches} | Current Loss: {batch_loss:.4f}")

        avg_loss = total_loss / num_batches
        duration = time.time() - start_time
        print(f"[{desc}] Epoch Complete in {duration:.2f}s | Avg Loss: {avg_loss:.4f}")
        return avg_loss

In [None]:
# === MAIN EXECUTION ===
# Trains every model for a single epoch: TMAX once, and each of the other
# models twice (smart channel selection vs. all 162 channels).

files, dates = discover_files(config)
print(f"Found {len(files)} total tensor files in {config.PROCESSED_ROOT}")

if len(files) > 0:
    # 1. TMAX Training (Still runs once)
    print("\n=== TMAX FORECAST ===")
    train_ds, test_ds, num_channels = prepare_tmax_data(config)
    if train_ds:
        loader_kwargs = {'num_workers': 2, 'pin_memory': True} if DEVICE.type == 'cuda' else {}
        dl = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True, **loader_kwargs)
        model = GIBSForecaster(in_channels=num_channels, history_days=config.HISTORY_DAYS).to(DEVICE)
        opt = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
        trainer = UnifiedTrainer(model, DEVICE, opt, nn.L1Loss())
        trainer.train_epoch(dl, desc="TMAX")

    # Function to run both variants
    def run_variants(name, dataset_cls, model_cls, in_ch_smart, in_ch_all, **ds_kwargs):
        """Train ``model_cls`` for one epoch on the smart-selection dataset and
        one epoch on the all-channel dataset.

        Args:
            name: Label for logs; "Convection" selects the binary loss.
            dataset_cls: Dataset class taking ``(files, use_all_channels=...)``.
            model_cls: Model class taking ``in_channels``.
            in_ch_smart: Input channels in smart-selection mode.
            in_ch_all: Input channels in all-channel mode.
            **ds_kwargs: Optional ``files`` override for the file list.
        """
        print(f"\n=== {name} COMPARISON ===")
        loader_kwargs = {'num_workers': 2, 'pin_memory': True} if DEVICE.type == 'cuda' else {}

        # 1. Smart Selection
        print(f"--- {name} (Smart Selection: {in_ch_smart} ch) ---")
        ds_smart = dataset_cls(files if 'files' not in ds_kwargs else ds_kwargs['files'], use_all_channels=False)
        dl_smart = DataLoader(ds_smart, batch_size=config.BATCH_SIZE, shuffle=True, **loader_kwargs)
        model_smart = model_cls(in_channels=in_ch_smart).to(DEVICE)
        opt = torch.optim.Adam(model_smart.parameters(), lr=config.LEARNING_RATE)
        # ConvectionCNN already applies a sigmoid in forward(), so it needs
        # BCELoss (probabilities); BCEWithLogitsLoss would apply a second
        # sigmoid on top of the model's own.
        crit = nn.BCELoss() if name == "Convection" else nn.MSELoss()
        trainer = UnifiedTrainer(model_smart, DEVICE, opt, crit)
        trainer.train_epoch(dl_smart, desc=f"{name}-Smart")

        # 2. All Channels
        print(f"--- {name} (All Channels: {in_ch_all} ch) ---")
        ds_all = dataset_cls(files if 'files' not in ds_kwargs else ds_kwargs['files'], use_all_channels=True)
        dl_all = DataLoader(ds_all, batch_size=config.BATCH_SIZE, shuffle=True, **loader_kwargs)
        model_all = model_cls(in_channels=in_ch_all).to(DEVICE)
        opt = torch.optim.Adam(model_all.parameters(), lr=config.LEARNING_RATE)
        trainer = UnifiedTrainer(model_all, DEVICE, opt, crit)
        trainer.train_epoch(dl_all, desc=f"{name}-All162")

    # Run all models

    # Cloud Cover: Smart=4, All=162
    run_variants("CloudCover", CloudCoverDataset, CloudCoverNowcaster, 4, 162)

    # SIANet: Smart=4, All=162 (Pass specialized dates)
    # Note: SIANetDataset constructor args differ slightly, so we adapt manual call
    print("\n=== SIANet COMPARISON ===")
    sianet_ds_smart = SIANetDataset(dates, config.PROCESSED_ROOT, use_all_channels=False)
    if len(sianet_ds_smart) > 0:
        # Smaller batch than config.BATCH_SIZE is used for the 5D SIANet inputs.
        dl_s = DataLoader(sianet_ds_smart, batch_size=8, shuffle=True)
        m_s = SIANet(in_channels=4).to(DEVICE)
        t_s = UnifiedTrainer(m_s, DEVICE, torch.optim.Adam(m_s.parameters(), lr=config.LEARNING_RATE), nn.MSELoss())
        t_s.train_epoch(dl_s, desc="SIANet-Smart")

        sianet_ds_all = SIANetDataset(dates, config.PROCESSED_ROOT, use_all_channels=True)
        dl_a = DataLoader(sianet_ds_all, batch_size=8, shuffle=True)
        m_a = SIANet(in_channels=162).to(DEVICE)
        t_a = UnifiedTrainer(m_a, DEVICE, torch.optim.Adam(m_a.parameters(), lr=config.LEARNING_RATE), nn.MSELoss())
        t_a.train_epoch(dl_a, desc="SIANet-All162")

    # Pix2Pix: Smart=3, All=162
    run_variants("Pix2Pix", Pix2PixDataset, Pix2PixGenerator, 3, 162)

    # Microwave: Smart=10, All=162
    run_variants("Microwave", MicrowaveDataset, MicrowaveLSTNet, 10, 162)

    # Convection: Smart=3, All=162
    run_variants("Convection", ConvectionDataset, ConvectionCNN, 3, 162)

else:
    print("No .pt files found.")