In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install torch torchvision einops tqdm pytorch-lightning scikit-learn
!git clone https://github.com/ucaswangls/EfficientSCI.git

import sys
sys.path.append("/content/EfficientSCI")

fatal: destination path 'EfficientSCI' already exists and is not an empty directory.


In [3]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64,expandable_segments:True"

import re
import glob
import torch
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
import gc
import random
from torch.utils.data._utils.collate import default_collate
import os
import imageio.v2 as imageio
from datetime import datetime, timedelta

from cacti.models.efficientsci import EfficientSCI

In [4]:
# Configuration
do_train = True
do_eval = True
retrain = False
band = 1
model_ckpt_path = f"/content/drive/MyDrive/efficientSCI_finetuned_multiband.pth"

train_locs = ['bos', 'bozeman', 'carlinNev', 'la', 'dallas']
test_loc = 'miami'
patch_size = 100
grid_size = 1
batch_size = 1
epochs = 15


In [5]:
class MultiBandCaseADataset(Dataset):
    def __init__(self, samples, patch_size=100, grid_size=1, band_transform=None, use_blocks_for_training=True):
        """
        samples: list of (loc, path)
        band_transform: optional callable(b_idx, x_norm_band) -> x_norm_band (for band-specific tweaks)
        """
        self.samples = samples
        self.patch_size = patch_size
        self.grid_size = grid_size
        self.patches_per_block = grid_size ** 2
        self.band_transform = band_transform
        self.use_blocks_for_training = use_blocks_for_training

    def __len__(self):
        return len(self.samples) * self.patches_per_block

    def __getitem__(self, idx):
        block_idx = idx // self.patches_per_block
        patch_idx = idx % self.patches_per_block

        _, path = self.samples[block_idx]
        d = np.load(path)

        y_full = d["compressed"].astype(np.float32)
        phi_s_f = d["counts"].astype(np.float32)
        B, H, W = y_full.shape
        T = int(len(d["target_times"]))


        if "masks" not in d: #make sure masks are there
            raise KeyError(f"{os.path.basename(path)} missing 'masks' ([B,T,H,W]). Re-save with save_masks=True.")
        phi_full = d["masks"].astype(np.float32)


        has_x = ("blocks" in d) and self.use_blocks_for_training
        x_full = d["blocks"].astype(np.float32) if has_x else np.zeros((B, T, H, W), dtype=np.float32)

        #Make sure same patch across bands
        ps  = self.patch_size
        row = (patch_idx // self.grid_size) * ps
        col = (patch_idx %  self.grid_size) * ps

        y_patch= y_full[:,row:row+ps,col:col+ps]
        phi_s= phi_s_f[:,row:row+ps,col:col+ps]
        phi_patch = phi_full[:,:,row:row+ps,col:col+ps].copy()
        x_patch= x_full[:,:,row:row+ps,col:col+ps]

        #Measurement as forward-SUM (no NaNs)
        y_patch_f = np.nan_to_num(y_patch, nan=0.0)
        y_sum = y_patch_f * phi_s

        #Reference from SUM (no nanmax on avg!)
        ref= np.max(y_sum, axis=(1, 2))
        ref= np.where(np.isfinite(ref) & (ref > 0), ref, 1e-8)
        ref_bc = ref[:, None, None, None]

        #Normalise
        x_norm = np.nan_to_num(x_patch/ref_bc, nan=0.0, posinf=0.0, neginf=0.0)
        x_norm = np.log1p(np.clip(x_norm, 0, None))
        y_norm = np.nan_to_num(y_sum / ref[:, None, None], nan=0.0)

        phi_raw   = phi_patch.astype(np.float32)
        phi_s_raw = phi_s.astype(np.float32)

        #Tensors
        y_t= torch.tensor(y_norm).unsqueeze(1).float()
        phi_t = torch.tensor(phi_raw).float()
        phi_s_t= torch.tensor(phi_s_raw).unsqueeze(1).float()
        x_t = torch.tensor(x_norm).float()
        max_val = torch.tensor(ref).float()

        return y_t, phi_t, phi_s_t, x_t, max_val


In [6]:
root = "/content/drive/MyDrive/Folder/multiband_compressed_blocks_12hrs"
train_locs = ['bos', 'bozeman', 'carlinNev', 'la', 'dallas']
test_loc   = 'miami'

def list_npz(loc, recursive=True):
    pattern = os.path.join(root, loc, "**", "*.npz") if recursive else os.path.join(root, loc, "*.npz")
    files = sorted(glob.glob(pattern, recursive=recursive))
    print(f"[scan] {loc}: {len(files)} npz files (pattern: {pattern})")
    return files

# Show subfolders present
print("Subfolders under root:", sorted([d for d in os.listdir(root) if os.path.isdir(os.path.join(root,d))]))

trainval_files = []
for loc in train_locs:
    trainval_files += [(loc, p) for p in list_npz(loc)]

test_files = [(test_loc, p) for p in list_npz(test_loc)]

train_files, val_files = train_test_split(trainval_files, test_size=0.2, random_state=42)
print(f"train:{len(train_files)} val:{len(val_files)} test:{len(test_files)}")

def skip_none_collate(batch):
    batch = [b for b in batch if b is not None]
    if not batch:
        return None
    return default_collate(batch)

train_dataset = MultiBandCaseADataset(train_files, patch_size=100, grid_size=1, band_transform=None)
val_dataset   = MultiBandCaseADataset(val_files,   patch_size=100, grid_size=1, band_transform=None)
test_dataset  = MultiBandCaseADataset(test_files,  patch_size=100, grid_size=1, band_transform=None)

train_loader = DataLoader(train_dataset, batch_size=1, num_workers=16, shuffle=True,  pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=1, num_workers=16, shuffle=False, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=1, num_workers=16, shuffle=False, pin_memory=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EfficientSCI().to(device)
model.eval()

Subfolders under root: ['bos', 'bozeman', 'carlinNev', 'dallas', 'la', 'miami']
[scan] bos: 200 npz files (pattern: /content/drive/MyDrive/Folder/multiband_compressed_blocks_12hrs/bos/**/*.npz)
[scan] bozeman: 200 npz files (pattern: /content/drive/MyDrive/Folder/multiband_compressed_blocks_12hrs/bozeman/**/*.npz)
[scan] carlinNev: 200 npz files (pattern: /content/drive/MyDrive/Folder/multiband_compressed_blocks_12hrs/carlinNev/**/*.npz)
[scan] la: 200 npz files (pattern: /content/drive/MyDrive/Folder/multiband_compressed_blocks_12hrs/la/**/*.npz)
[scan] dallas: 200 npz files (pattern: /content/drive/MyDrive/Folder/multiband_compressed_blocks_12hrs/dallas/**/*.npz)
[scan] miami: 200 npz files (pattern: /content/drive/MyDrive/Folder/multiband_compressed_blocks_12hrs/miami/**/*.npz)
train:800 val:200 test:200

FILE: bos_block_000.npz
  keys: ['compressed', 'band_ids', 'block_time', 'target_times', 'counts', 'masks', 'blocks']
  compressed      shape=(5, 50, 50) dtype=float32
  band_ids  



y any NaN?: False
phi_s min/max: 15.0 43.0
y min/max: 0.03003046 1.0
compressed min/max: 36.21957778930664 107.09478759765625
counts min/max: 15.0 43.0
MSE vs SUM: nan
MSE vs AVG: nan
compressed min/max: 9.679545402526855 75.48373413085938
counts min/max: 15.0 40.0
MSE vs SUM: nan
MSE vs AVG: nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [7]:
def normalized_mse(x_pred, x_true, phi_s):
    return torch.sum((x_pred - x_true)**2) / (torch.sum(x_true**2) + 1e-8)

def mixed_loss(pred, target, phi_s):
    nmse = normalized_mse(pred, target, phi_s)
    peak_weight = (target > 0.6).float() * 5.0 + 1.0  # Boost high-radiance pixels
    weighted_mse = torch.mean(peak_weight * (pred - target)**2)
    return nmse + weighted_mse  #weighted_mse for balance


if do_train:
  train_losses, val_losses = [], []
  scaler = torch.cuda.amp.GradScaler()

  def run_training(epochs, lr, start_epoch=0):
      best_val = float('inf')
      best_state = None
      lr_decay_factor = 0.5

      optimizer = torch.optim.Adam(model.parameters(), lr=lr)

      early_stop_patience = 30
      no_improve_epochs = 0


      for epoch in range(start_epoch, start_epoch + epochs):

          torch.cuda.empty_cache()
          gc.collect()
          model.train()
          total_train_loss = 0

          for y, phi, phi_s, x, max_val in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
            y, phi, phi_s, x, max_val = [t.squeeze(0).to(device) for t in (y, phi, phi_s, x, max_val)]
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                output = model(y, phi, phi_s)[0]
                output = torch.nan_to_num(output, nan=0.0)
                loss = mixed_loss(output, x, phi_s)

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item()

          train_losses.append(total_train_loss / len(train_loader))

          # Validation
          model.eval()
          total_val_loss = 0
          with torch.no_grad():
              for y, phi, phi_s, x, max_val in val_loader:
                  y, phi, phi_s, x, max_val = [t.squeeze(0).to(device) for t in (y, phi, phi_s, x, max_val)]
                  output = model(y, phi, phi_s)[0]
                  output = torch.nan_to_num(output, nan=0.0)
                  val_loss = mixed_loss(output, x, phi_s)
                  total_val_loss += val_loss.item()

                  if torch.isnan(output).any():
                      print("NaNs in model output!")
                  if torch.isnan(x).any():
                      print("NaNs in target x")
          val_losses.append(total_val_loss / len(val_loader))
          # Save losses
          np.save("train_losses.npy", np.array(train_losses))
          np.save("val_losses.npy", np.array(val_losses))

          if val_losses[-1] < best_val:
            best_val = val_losses[-1]
            best_state = model.state_dict().copy()
            print(f'New Best Model at Epoch {epoch+1}')
            torch.save(best_state, model_ckpt_path)
            no_improve_epochs = 0

          else:
            no_improve_epochs += 1
            print(f"No improvement for {no_improve_epochs} epochs.")

          if (no_improve_epochs +1)% 20 ==0:
            for param_goup in optimizer.param_groups:
              param_group['lr'] *= lr_decay_factor

          if no_improve_epochs >= early_stop_patience:
            print('Early stopping occurred')
            break

          for param_group in optimizer.param_groups:
            print(f"[{epoch+1}] Train: {train_losses[-1]:.4f} | Val: {val_losses[-1]:.4f}| Learning rate: {param_group['lr']:.2e}")

          torch.cuda.empty_cache()
          gc.collect()

      return best_state, train_losses, val_losses

  #Main training
  if retrain:
    model.load_state_dict(torch.load(model_ckpt_path))
  best_model_state, train_losses, val_losses = run_training(epochs=300, lr=1e-4)

  # Save best model
  if best_model_state is not None:
    torch.save(best_model_state, model_ckpt_path)
    np.save("train_losses.npy", np.array(train_losses))
    np.save("val_losses.npy", np.array(val_losses))



  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Train Epoch 1: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


New Best Model at Epoch 1
[1] Train: 0.4087 | Val: 0.1481| Learning rate: 1.00e-04


Train Epoch 2: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


New Best Model at Epoch 2
[2] Train: 0.1057 | Val: 0.0607| Learning rate: 1.00e-04


Train Epoch 3: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


New Best Model at Epoch 3
[3] Train: 0.0650 | Val: 0.0492| Learning rate: 1.00e-04


Train Epoch 4: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[4] Train: 0.0565 | Val: 0.0538| Learning rate: 1.00e-04


Train Epoch 5: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 2 epochs.
[5] Train: 0.0529 | Val: 0.0614| Learning rate: 1.00e-04


Train Epoch 6: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 3 epochs.
[6] Train: 0.0450 | Val: 0.0500| Learning rate: 1.00e-04


Train Epoch 7: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 7
[7] Train: 0.0458 | Val: 0.0417| Learning rate: 1.00e-04


Train Epoch 8: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


New Best Model at Epoch 8
[8] Train: 0.0432 | Val: 0.0414| Learning rate: 1.00e-04


Train Epoch 9: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 9
[9] Train: 0.0411 | Val: 0.0397| Learning rate: 1.00e-04


Train Epoch 10: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[10] Train: 0.0418 | Val: 0.0450| Learning rate: 1.00e-04


Train Epoch 11: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[11] Train: 0.0388 | Val: 0.0399| Learning rate: 1.00e-04


Train Epoch 12: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 3 epochs.
[12] Train: 0.0406 | Val: 0.0546| Learning rate: 1.00e-04


Train Epoch 13: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 4 epochs.
[13] Train: 0.0443 | Val: 0.0565| Learning rate: 1.00e-04


Train Epoch 14: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 5 epochs.
[14] Train: 0.0431 | Val: 0.0415| Learning rate: 1.00e-04


Train Epoch 15: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


New Best Model at Epoch 15
[15] Train: 0.0406 | Val: 0.0384| Learning rate: 1.00e-04


Train Epoch 16: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[16] Train: 0.0420 | Val: 0.0387| Learning rate: 1.00e-04


Train Epoch 17: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


New Best Model at Epoch 17
[17] Train: 0.0399 | Val: 0.0373| Learning rate: 1.00e-04


Train Epoch 18: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 18
[18] Train: 0.0354 | Val: 0.0361| Learning rate: 1.00e-04


Train Epoch 19: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 19
[19] Train: 0.0358 | Val: 0.0355| Learning rate: 1.00e-04


Train Epoch 20: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[20] Train: 0.0393 | Val: 0.0396| Learning rate: 1.00e-04


Train Epoch 21: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 21
[21] Train: 0.0391 | Val: 0.0354| Learning rate: 1.00e-04


Train Epoch 22: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


New Best Model at Epoch 22
[22] Train: 0.0371 | Val: 0.0347| Learning rate: 1.00e-04


Train Epoch 23: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[23] Train: 0.0340 | Val: 0.0371| Learning rate: 1.00e-04


Train Epoch 24: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[24] Train: 0.0358 | Val: 0.0383| Learning rate: 1.00e-04


Train Epoch 25: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 3 epochs.
[25] Train: 0.0388 | Val: 0.0362| Learning rate: 1.00e-04


Train Epoch 26: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 4 epochs.
[26] Train: 0.0361 | Val: 0.0363| Learning rate: 1.00e-04


Train Epoch 27: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 5 epochs.
[27] Train: 0.0337 | Val: 0.0368| Learning rate: 1.00e-04


Train Epoch 28: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 6 epochs.
[28] Train: 0.0331 | Val: 0.0374| Learning rate: 1.00e-04


Train Epoch 29: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 29
[29] Train: 0.0328 | Val: 0.0341| Learning rate: 1.00e-04


Train Epoch 30: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[30] Train: 0.0366 | Val: 0.0360| Learning rate: 1.00e-04


Train Epoch 31: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[31] Train: 0.0353 | Val: 0.0352| Learning rate: 1.00e-04


Train Epoch 32: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 32
[32] Train: 0.0342 | Val: 0.0325| Learning rate: 1.00e-04


Train Epoch 33: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[33] Train: 0.0320 | Val: 0.0386| Learning rate: 1.00e-04


Train Epoch 34: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[34] Train: 0.0322 | Val: 0.0327| Learning rate: 1.00e-04


Train Epoch 35: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 35
[35] Train: 0.0311 | Val: 0.0325| Learning rate: 1.00e-04


Train Epoch 36: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 36
[36] Train: 0.0327 | Val: 0.0319| Learning rate: 1.00e-04


Train Epoch 37: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[37] Train: 0.0366 | Val: 0.0328| Learning rate: 1.00e-04


Train Epoch 38: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 2 epochs.
[38] Train: 0.0344 | Val: 0.0346| Learning rate: 1.00e-04


Train Epoch 39: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 3 epochs.
[39] Train: 0.0326 | Val: 0.0324| Learning rate: 1.00e-04


Train Epoch 40: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 4 epochs.
[40] Train: 0.0309 | Val: 0.0333| Learning rate: 1.00e-04


Train Epoch 41: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 41
[41] Train: 0.0310 | Val: 0.0316| Learning rate: 1.00e-04


Train Epoch 42: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


New Best Model at Epoch 42
[42] Train: 0.0302 | Val: 0.0312| Learning rate: 1.00e-04


Train Epoch 43: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[43] Train: 0.0324 | Val: 0.0372| Learning rate: 1.00e-04


Train Epoch 44: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[44] Train: 0.0337 | Val: 0.0365| Learning rate: 1.00e-04


Train Epoch 45: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 3 epochs.
[45] Train: 0.0327 | Val: 0.0329| Learning rate: 1.00e-04


Train Epoch 46: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 4 epochs.
[46] Train: 0.0309 | Val: 0.0324| Learning rate: 1.00e-04


Train Epoch 47: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 5 epochs.
[47] Train: 0.0305 | Val: 0.0315| Learning rate: 1.00e-04


Train Epoch 48: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 48
[48] Train: 0.0305 | Val: 0.0307| Learning rate: 1.00e-04


Train Epoch 49: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[49] Train: 0.0299 | Val: 0.0357| Learning rate: 1.00e-04


Train Epoch 50: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[50] Train: 0.0351 | Val: 0.0316| Learning rate: 1.00e-04


Train Epoch 51: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 3 epochs.
[51] Train: 0.0328 | Val: 0.0315| Learning rate: 1.00e-04


Train Epoch 52: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 4 epochs.
[52] Train: 0.0307 | Val: 0.0312| Learning rate: 1.00e-04


Train Epoch 53: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 5 epochs.
[53] Train: 0.0297 | Val: 0.0321| Learning rate: 1.00e-04


Train Epoch 54: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 6 epochs.
[54] Train: 0.0332 | Val: 0.0320| Learning rate: 1.00e-04


Train Epoch 55: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 7 epochs.
[55] Train: 0.0328 | Val: 0.0337| Learning rate: 1.00e-04


Train Epoch 56: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 8 epochs.
[56] Train: 0.0316 | Val: 0.0308| Learning rate: 1.00e-04


Train Epoch 57: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 57
[57] Train: 0.0294 | Val: 0.0303| Learning rate: 1.00e-04


Train Epoch 58: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[58] Train: 0.0294 | Val: 0.0307| Learning rate: 1.00e-04


Train Epoch 59: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[59] Train: 0.0297 | Val: 0.0304| Learning rate: 1.00e-04


Train Epoch 60: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 3 epochs.
[60] Train: 0.0298 | Val: 0.0311| Learning rate: 1.00e-04


Train Epoch 61: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 61
[61] Train: 0.0298 | Val: 0.0299| Learning rate: 1.00e-04


Train Epoch 62: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[62] Train: 0.0293 | Val: 0.0304| Learning rate: 1.00e-04


Train Epoch 63: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 2 epochs.
[63] Train: 0.0321 | Val: 0.0325| Learning rate: 1.00e-04


Train Epoch 64: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 3 epochs.
[64] Train: 0.0318 | Val: 0.0355| Learning rate: 1.00e-04


Train Epoch 65: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 4 epochs.
[65] Train: 0.0310 | Val: 0.0300| Learning rate: 1.00e-04


Train Epoch 66: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 66
[66] Train: 0.0287 | Val: 0.0295| Learning rate: 1.00e-04


Train Epoch 67: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[67] Train: 0.0286 | Val: 0.0299| Learning rate: 1.00e-04


Train Epoch 68: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 2 epochs.
[68] Train: 0.0282 | Val: 0.0296| Learning rate: 1.00e-04


Train Epoch 69: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 3 epochs.
[69] Train: 0.0288 | Val: 0.0299| Learning rate: 1.00e-04


Train Epoch 70: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 4 epochs.
[70] Train: 0.0296 | Val: 0.0298| Learning rate: 1.00e-04


Train Epoch 71: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 5 epochs.
[71] Train: 0.0291 | Val: 0.0310| Learning rate: 1.00e-04


Train Epoch 72: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 6 epochs.
[72] Train: 0.0296 | Val: 0.0321| Learning rate: 1.00e-04


Train Epoch 73: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 7 epochs.
[73] Train: 0.0308 | Val: 0.0340| Learning rate: 1.00e-04


Train Epoch 74: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 8 epochs.
[74] Train: 0.0297 | Val: 0.0309| Learning rate: 1.00e-04


Train Epoch 75: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 9 epochs.
[75] Train: 0.0290 | Val: 0.0295| Learning rate: 1.00e-04


Train Epoch 76: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 10 epochs.
[76] Train: 0.0277 | Val: 0.0310| Learning rate: 1.00e-04


Train Epoch 77: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 77
[77] Train: 0.0279 | Val: 0.0286| Learning rate: 1.00e-04


Train Epoch 78: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 78
[78] Train: 0.0269 | Val: 0.0286| Learning rate: 1.00e-04


Train Epoch 79: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[79] Train: 0.0269 | Val: 0.0290| Learning rate: 1.00e-04


Train Epoch 80: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[80] Train: 0.0268 | Val: 0.0322| Learning rate: 1.00e-04


Train Epoch 81: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 3 epochs.
[81] Train: 0.0282 | Val: 0.0308| Learning rate: 1.00e-04


Train Epoch 82: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 4 epochs.
[82] Train: 0.0314 | Val: 0.0549| Learning rate: 1.00e-04


Train Epoch 83: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 5 epochs.
[83] Train: 0.0303 | Val: 0.0291| Learning rate: 1.00e-04


Train Epoch 84: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 6 epochs.
[84] Train: 0.0281 | Val: 0.0286| Learning rate: 1.00e-04


Train Epoch 85: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 7 epochs.
[85] Train: 0.0272 | Val: 0.0295| Learning rate: 1.00e-04


Train Epoch 86: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 8 epochs.
[86] Train: 0.0272 | Val: 0.0287| Learning rate: 1.00e-04


Train Epoch 87: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 9 epochs.
[87] Train: 0.0270 | Val: 0.0309| Learning rate: 1.00e-04


Train Epoch 88: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 10 epochs.
[88] Train: 0.0278 | Val: 0.0289| Learning rate: 1.00e-04


Train Epoch 89: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 11 epochs.
[89] Train: 0.0275 | Val: 0.0304| Learning rate: 1.00e-04


Train Epoch 90: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 12 epochs.
[90] Train: 0.0270 | Val: 0.0301| Learning rate: 1.00e-04


Train Epoch 91: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 13 epochs.
[91] Train: 0.0273 | Val: 0.0290| Learning rate: 1.00e-04


Train Epoch 92: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 92
[92] Train: 0.0268 | Val: 0.0281| Learning rate: 1.00e-04


Train Epoch 93: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 1 epochs.
[93] Train: 0.0261 | Val: 0.0311| Learning rate: 1.00e-04


Train Epoch 94: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 2 epochs.
[94] Train: 0.0276 | Val: 0.0299| Learning rate: 1.00e-04


Train Epoch 95: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 3 epochs.
[95] Train: 0.0272 | Val: 0.0285| Learning rate: 1.00e-04


Train Epoch 96: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


New Best Model at Epoch 96
[96] Train: 0.0266 | Val: 0.0281| Learning rate: 1.00e-04


Train Epoch 97: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 1 epochs.
[97] Train: 0.0262 | Val: 0.0297| Learning rate: 1.00e-04


Train Epoch 98: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 2 epochs.
[98] Train: 0.0274 | Val: 0.0294| Learning rate: 1.00e-04


Train Epoch 99: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 3 epochs.
[99] Train: 0.0280 | Val: 0.0426| Learning rate: 1.00e-04


Train Epoch 100: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 4 epochs.
[100] Train: 0.0287 | Val: 0.0303| Learning rate: 1.00e-04


Train Epoch 101: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 5 epochs.
[101] Train: 0.0279 | Val: 0.0298| Learning rate: 1.00e-04


Train Epoch 102: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 6 epochs.
[102] Train: 0.0264 | Val: 0.0291| Learning rate: 1.00e-04


Train Epoch 103: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


New Best Model at Epoch 103
[103] Train: 0.0259 | Val: 0.0277| Learning rate: 1.00e-04


Train Epoch 104: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[104] Train: 0.0258 | Val: 0.0279| Learning rate: 1.00e-04


Train Epoch 105: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[105] Train: 0.0260 | Val: 0.0282| Learning rate: 1.00e-04


Train Epoch 106: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 3 epochs.
[106] Train: 0.0262 | Val: 0.0321| Learning rate: 1.00e-04


Train Epoch 107: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 4 epochs.
[107] Train: 0.0264 | Val: 0.0294| Learning rate: 1.00e-04


Train Epoch 108: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 5 epochs.
[108] Train: 0.0263 | Val: 0.0293| Learning rate: 1.00e-04


Train Epoch 109: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 6 epochs.
[109] Train: 0.0264 | Val: 0.0285| Learning rate: 1.00e-04


Train Epoch 110: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 110
[110] Train: 0.0251 | Val: 0.0276| Learning rate: 1.00e-04


Train Epoch 111: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[111] Train: 0.0249 | Val: 0.0278| Learning rate: 1.00e-04


Train Epoch 112: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[112] Train: 0.0270 | Val: 0.0305| Learning rate: 1.00e-04


Train Epoch 113: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 3 epochs.
[113] Train: 0.0284 | Val: 0.0296| Learning rate: 1.00e-04


Train Epoch 114: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 4 epochs.
[114] Train: 0.0273 | Val: 0.0304| Learning rate: 1.00e-04


Train Epoch 115: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 5 epochs.
[115] Train: 0.0256 | Val: 0.0287| Learning rate: 1.00e-04


Train Epoch 116: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 6 epochs.
[116] Train: 0.0253 | Val: 0.0281| Learning rate: 1.00e-04


Train Epoch 117: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 117
[117] Train: 0.0250 | Val: 0.0275| Learning rate: 1.00e-04


Train Epoch 118: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[118] Train: 0.0244 | Val: 0.0275| Learning rate: 1.00e-04


Train Epoch 119: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 2 epochs.
[119] Train: 0.0253 | Val: 0.0286| Learning rate: 1.00e-04


Train Epoch 120: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 3 epochs.
[120] Train: 0.0259 | Val: 0.0279| Learning rate: 1.00e-04


Train Epoch 121: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 4 epochs.
[121] Train: 0.0255 | Val: 0.0284| Learning rate: 1.00e-04


Train Epoch 122: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 5 epochs.
[122] Train: 0.0244 | Val: 0.0278| Learning rate: 1.00e-04


Train Epoch 123: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 6 epochs.
[123] Train: 0.0244 | Val: 0.0275| Learning rate: 1.00e-04


Train Epoch 124: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 124
[124] Train: 0.0240 | Val: 0.0273| Learning rate: 1.00e-04


Train Epoch 125: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 1 epochs.
[125] Train: 0.0249 | Val: 0.0283| Learning rate: 1.00e-04


Train Epoch 126: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 2 epochs.
[126] Train: 0.0258 | Val: 0.0300| Learning rate: 1.00e-04


Train Epoch 127: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 3 epochs.
[127] Train: 0.0282 | Val: 0.0290| Learning rate: 1.00e-04


Train Epoch 128: 100%|██████████| 800/800 [06:29<00:00,  2.05it/s]


No improvement for 4 epochs.
[128] Train: 0.0266 | Val: 0.0318| Learning rate: 1.00e-04


Train Epoch 129: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 5 epochs.
[129] Train: 0.0258 | Val: 0.0276| Learning rate: 1.00e-04


Train Epoch 130: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 6 epochs.
[130] Train: 0.0244 | Val: 0.0289| Learning rate: 1.00e-04


Train Epoch 131: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 7 epochs.
[131] Train: 0.0245 | Val: 0.0278| Learning rate: 1.00e-04


Train Epoch 132: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


New Best Model at Epoch 132
[132] Train: 0.0238 | Val: 0.0273| Learning rate: 1.00e-04


Train Epoch 133: 100%|██████████| 800/800 [06:29<00:00,  2.06it/s]


No improvement for 1 epochs.
[133] Train: 0.0243 | Val: 0.0303| Learning rate: 1.00e-04


Train Epoch 134: 100%|██████████| 800/800 [06:28<00:00,  2.06it/s]


No improvement for 2 epochs.
[134] Train: 0.0269 | Val: 0.0286| Learning rate: 1.00e-04


Train Epoch 135:  20%|██        | 161/800 [01:19<05:15,  2.03it/s]


KeyboardInterrupt: 

In [None]:
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE)')
plt.title('Training and Validation Loss per Epoch')
plt.grid(False)
plt.legend()
plt.tight_layout()
plt.show()

np.save("/content/drive/MyDrive/train_losses.npy", np.array(train_losses))
np.save("/content/drive/MyDrive/val_losses.npy", np.array(val_losses))

In [None]:
if do_eval:
    assert getattr(test_loader, 'batch_size', 1) == 1, "Use batch_size=1 for path to batch zip."
    SAVE_DIR = "./recons"; MAKE_GIFS = False
    os.makedirs(SAVE_DIR, exist_ok=True)

    print(model_ckpt_path)
    model.load_state_dict(torch.load(model_ckpt_path, map_location="cpu"))
    model.to(device).eval()

    true_ts_all  = recon_ts_all = mse_lists = psnr_lists = None

    with torch.no_grad():
        for (y, phi, phi_s, x_true, max_val), (loc, path) in zip(test_loader, test_files):
            y, phi, phi_s, x_true, max_val = [t.squeeze(0).to(device) for t in (y, phi, phi_s, x_true, max_val)]

            with np.load(path, allow_pickle=True) as dmeta:
              band_ids_raw = dmeta["band_ids"]
              band_ids = [str(b) for b in band_ids_raw.ravel()]
              target_times = dmeta["target_times"]
            src_name = os.path.splitext(os.path.basename(path))[0]

            # forward
            x_pred = model(y, phi, phi_s)[0]

            # invert log1p + per-band scaling
            mv = max_val.view(-1, 1, 1, 1)
            print("x_true raw range:", float(x_true.min()), float(x_true.max()))
            x_pred = torch.expm1(x_pred) * mv
            x_true = torch.expm1(x_true) * mv

            B, T, H, W = x_true.shape
            h, w = H // 2, W // 2

            if true_ts_all is None:
                true_ts_all  = [[] for _ in range(B)]
                recon_ts_all = [[] for _ in range(B)]
                mse_lists = [[] for _ in range(B)]
                psnr_lists = [[] for _ in range(B)]

            mse_b = torch.mean((x_pred - x_true) ** 2, dim=(1, 2, 3))
            max_true_b = x_true.view(B, -1).max(dim=1).values
            psnr_b = 10.0 * torch.log10((max_true_b ** 2) / (mse_b + 1e-8))

            for b in range(B):
                mse_lists[b].append(mse_b[b].item())
                psnr_lists[b].append(psnr_b[b].item())
                true_ts_all[b].append(x_true[b, :, h, w].detach().cpu().numpy())
                recon_ts_all[b].append(x_pred[b, :, h, w].detach().cpu().numpy())

            # save recon + gt for this block
            out_npz = os.path.join(SAVE_DIR, f"{src_name}_recon.npz")
            np.savez_compressed(out_npz,x_pred=x_pred.detach().cpu().numpy().astype(np.float32),x_true=x_true.detach().cpu().numpy().astype(np.float32),band_ids=np.array(band_ids, dtype="U16"),target_times=target_times)

            torch.cuda.empty_cache(); gc.collect()

    # concat series & plot
    true_ts_concat  = [np.concatenate(chunks) for chunks in true_ts_all]
    recon_ts_concat = [np.concatenate(chunks) for chunks in recon_ts_all]
    start_time = datetime(2024, 4, 1)
    n_frames = len(true_ts_concat[0])
    frame_times = [start_time + timedelta(minutes=10 * i) for i in range(n_frames)]

    # stacked per-band
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(len(true_ts_concat), 1, figsize=(14, 2.8 * len(true_ts_concat)), sharex=True)
    if len(true_ts_concat) == 1: axes = [axes]
    for b, ax in enumerate(axes):
        ax.plot(frame_times, true_ts_concat[b], label=f"Band {b+1} • True")
        ax.plot(frame_times, recon_ts_concat[b], '--', label="Recon")
        ax.grid(True, alpha=0.4); ax.legend(loc="upper right"); ax.set_ylabel("Radiance")
    axes[-1].set_xlabel("Time")
    fig.suptitle(f"Center-Pixel Time Series — stacked by band • Location: {test_loc}")
    fig.autofmt_xdate(); plt.tight_layout(); plt.show()

    # single-axis “everything” plot
    plt.figure(figsize=(14, 5))
    for b in range(len(true_ts_concat)):
        plt.plot(frame_times, true_ts_concat[b],  label=f"B{b+1} True")
        plt.plot(frame_times, recon_ts_concat[b], '--', label=f"B{b+1} Recon")
    plt.grid(True, alpha=0.4); plt.xlabel("Time"); plt.ylabel("Radiance")
    plt.title(f"All bands — True (solid) vs Recon (dashed) • {test_loc}")
    plt.gcf().autofmt_xdate(); plt.legend(ncol=2); plt.tight_layout(); plt.show()

    # metrics
    per_band_mse  = [float(np.mean(m)) for m in mse_lists]
    per_band_psnr = [float(np.mean(p)) for p in psnr_lists]
    print("\nPer-band metrics:")
    for b in range(len(per_band_mse)):
        print(f"Band {b+1}: MSE={per_band_mse[b]:.6f} | PSNR={per_band_psnr[b]:.2f} dB")
    print(f"Overall: MSE={np.mean(per_band_mse):.6f} | PSNR={np.mean(per_band_psnr):.2f} dB")