# init

In [1]:
!pip install pytorch-lightning
!pip install rasterio

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.7.2-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.1.0->pytorch-lightning)
  Dow

In [2]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import numpy as np
import torch
import torch.nn.functional as F
import rasterio
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import Callback
from google.colab import drive
import pandas as pd
from torch.utils.data import Dataset
from skimage.metrics import peak_signal_noise_ratio as psnr
import json
from datetime import datetime
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
import sys
import torch.nn as nn
from torch.optim import Adam
import torchvision.transforms.functional as TF

In [3]:
drive.mount('/content/drive', force_remount=True)
base_folder = '/content/drive/My Drive/ONLY_TIF'
csv_path = '/content/drive/My Drive/superres_triplets.csv'

Mounted at /content/drive


In [4]:
df = pd.read_csv(csv_path)
for col in ['goes1_path', 'goes2_path', 'viirs_path']:
    df[col] = df[col].apply(lambda x: os.path.join(base_folder, x))

df.to_csv(csv_path, index=False)

# dataset+loader

In [5]:
class SatelliteTripletDataset(Dataset):
    def __init__(self, csv_file, json_range_path="radiance_visualization_ranges.json"):
        self.data_info = pd.read_csv(csv_file)

        with open(json_range_path, "r") as f:
            self.global_ranges = json.load(f)

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        record = self.data_info.iloc[idx]
        goes1 = self.load_raw_radiance(record['goes1_path'])
        goes2 = self.load_raw_radiance(record['goes2_path'])
        viirs = self.load_raw_radiance(record['viirs_path'])
        return (goes1, goes2), viirs

    def load_raw_radiance(self, path):
        filename = os.path.basename(path).lower()
        is_viirs = "viirs" in filename or "combined_clip" in filename
        sensor_type = "VIIRS" if is_viirs else "GOES"
        band_index = 1 if is_viirs else 7

        with rasterio.open(path) as src:
            if src.count < band_index:
                raise ValueError(f"{path} does not contain band {band_index}")
            image = src.read(band_index).astype(np.float32)

        # Handle NaN/Inf values
        mask = ~(np.isnan(image) | np.isinf(image))
        if not np.any(mask):
            image = np.zeros_like(image)
        else:
            mean_val = image[mask].mean()
            image = np.where(mask, image, mean_val)

        # 🔁 Use global clipping range from JSON
        p_low = self.global_ranges[sensor_type]["p2"]
        p_high = self.global_ranges[sensor_type]["p98"]
        image = np.clip(image, p_low, p_high)

        return torch.from_numpy(image).unsqueeze(0)

In [6]:
class SatelliteDataModule(pl.LightningDataModule):
    def __init__(self, csv_file, batch_size=4, num_workers=0, percentile_range=(0.5, 99.5)):
        super().__init__()
        self.val_split  = 0.1
        self.csv_file = csv_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.percentile_range = percentile_range
        self.seed = 42

    def setup(self, stage=None):
        # self.train_dataset = SatelliteTripletDataset(
        #     csv_file=self.csv_file,
        #     json_range_path="/content/radiance_visualization_ranges.json"
        # )
        full = SatelliteTripletDataset(self.csv_file,
                                       json_range_path="radiance_visualization_ranges.json")
        val_size   = int(len(full) * self.val_split)
        train_size = len(full) - val_size

        self.train_dataset, self.val_dataset = random_split(
            full,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(self.seed)
        )


    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=True)



In [7]:
import torch
torch.cuda.empty_cache()

# external liberery

In [8]:
# sys.path.append("C:/Users/97254/OneDrive - post.bgu.ac.il/Desktop/code4finalproj/SwinIR-main")
# Clone the repo
!git clone https://github.com/JingyunLiang/SwinIR.git
sys.path.append('/content/SwinIR')

# Manually install necessary dependencies
!pip install timm einops

Cloning into 'SwinIR'...
remote: Enumerating objects: 333, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 333 (delta 6), reused 2 (delta 2), pack-reused 323 (from 2)[K
Receiving objects: 100% (333/333), 29.84 MiB | 20.15 MiB/s, done.
Resolving deltas: 100% (119/119), done.


In [9]:
# Now you're safe to import
from models.network_swinir import SwinIR

model = SwinIR(
    upscale=4,            # 100x4 = 400 → desired upscale
    in_chans=1,           # GOES1 + GOES2 = 2 input channels
    img_size=100,         # patch input size
    window_size=25,        # usually 8 or 7 works fine
    img_range=1.,
    depths=[3,3,3,3,3,3],
    # depths=[2,2,2,2,2,2],
    embed_dim=180,
    # embed_dim=64,
    # num_heads=[8,8,8,8,8,8],
    num_heads=[6,6,6,6,6,6],
    mlp_ratio=2,
    upsampler='nearest+conv',
    # upsampler='convtranspose',
    resi_connection='3conv'
    # out_chans=1
    # <<< set to 1 for single-channel output
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


# model

In [10]:
class SuperResolutionModule(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.criterion = nn.L1Loss()
        self.lr = lr

        # for per‐epoch accumulation
        self._current_train_losses = []
        self._current_val_losses   = []

        # for storing one value *per epoch*
        self.train_losses = []
        self.val_losses   = []


        with open("/content/radiance_visualization_ranges.json", "r") as f:
            self.ranges = json.load(f)

        self.goes1_path = "/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo16.tif"
        self.goes2_path = "/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo17.tif"
        self.viirs_path = "/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/combined_clip.tif"

    def forward(self, goes1, goes2):
        out1 = self.model(goes1)  # Each has shape [B, 1, H, W]
        out2 = self.model(goes2)
        return out1, out2

    def training_step(self, batch, batch_idx):
        (goes1, goes2), viirs = batch
        viirs = viirs.float()
        out1, out2 = self(goes1, goes2)
        loss1 = self.criterion(out1, viirs)
        loss2 = self.criterion(out2, viirs)
        loss  = (loss1 + loss2) * 0.5

        self._current_train_losses.append(loss.item())

        psnr1 = self.compute_psnr(out1, viirs)
        psnr2 = self.compute_psnr(out2, viirs)
        avg_psnr = (psnr1 + psnr2) / 2

        self.log('train_loss', loss, prog_bar=True)
        return loss



    def validation_step(self, batch, batch_idx):
        (goes1, goes2), viirs = batch
        viirs = viirs.float()

        out1, out2 = self(goes1, goes2)
        loss1 = self.criterion(out1, viirs)
        loss2 = self.criterion(out2, viirs)
        val_loss = (loss1 + loss2) * 0.5

        psnr1 = self.compute_psnr(out1, viirs)
        psnr2 = self.compute_psnr(out2, viirs)
        avg_psnr = (psnr1 + psnr2) / 2

        self._current_val_losses.append(val_loss.item())

        self.log('val_loss', val_loss, prog_bar=True)
        self.log('val_psnr', avg_psnr, prog_bar=True)
        return val_loss


    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)

    def on_train_epoch_end(self) -> None:
        # average all batch losses for this epoch
        avg_train = sum(self._current_train_losses) / len(self._current_train_losses)
        self.train_losses.append(avg_train)
        self.log('epoch_train_loss', avg_train, prog_bar=True)
        # reset for next epoch
        self._current_train_losses.clear()


    def on_validation_epoch_end(self) -> None:
        if not self._current_val_losses:
            return

        avg_val = sum(self._current_val_losses) / len(self._current_val_losses)
        self.val_losses.append(avg_val)
        self.log('epoch_val_loss', avg_val, prog_bar=True)
        self._current_val_losses.clear()

    # def on_train_epoch_end(self):
    #   (goes1, goes2), viirs = self.example_batch

    #   self.eval()
    #   with torch.no_grad():
    #       # inputs = torch.cat((goes1, goes2), dim=1).float().to(self.device)
    #       # output = self(inputs)
    #       output = self(goes1.to(self.device), goes2.to(self.device))
    #       output = F.interpolate(output, size=viirs.shape[2:], mode='bilinear', align_corners=False)

    #   inp1 = goes1[0][0].cpu().numpy()
    #   inp2 = goes2[0][0].cpu().numpy()
    #   tgt = viirs[0][0].cpu().numpy()
    #   out = output[0][0].cpu().numpy()

    #   fig, axs = plt.subplots(1, 4, figsize=(14, 4))
    #   axs[0].imshow(inp1, cmap='gray'); axs[0].set_title('GOES 1')
    #   axs[1].imshow(inp2, cmap='gray'); axs[1].set_title('GOES 2')
    #   axs[2].imshow(tgt, cmap='gray'); axs[2].set_title('VIIRS Target')
    #   axs[3].imshow(out, cmap='gray'); axs[3].set_title('Model Output')
    #   fig.suptitle(f'Epoch {self.current_epoch+1} - Triplet Visualization', fontsize=14)

    #   for ax in axs: ax.axis('off')
    #   plt.tight_layout()
    #   plt.show()



    def on_train_start(self):
      goes1 = self.load_radiance(self.goes1_path)
      goes2 = self.load_radiance(self.goes2_path)
      viirs  = self.load_radiance(self.viirs_path)
      # Add batch dimension [1, 1, H, W]
      self.example_batch = ((goes1.unsqueeze(0), goes2.unsqueeze(0)), viirs.unsqueeze(0))


    def load_radiance(self, path, band=1):
        with rasterio.open(path) as src:
            img = src.read(band).astype(np.float32)
            mask = ~(np.isnan(img) | np.isinf(img))
            if mask.any():
                img = np.where(mask, img, img[mask].mean())
            else:
                img = np.zeros_like(img)
            if "viirs" in path.lower():
                p2, p98 = self.ranges["VIIRS"]["p2"], self.ranges["VIIRS"]["p98"]
            else:
                p2, p98 = self.ranges["GOES"]["p2"], self.ranges["GOES"]["p98"]

            img = np.clip(img, p2, p98)
            # img = (img - p2) / (p98 - p2)
            return torch.from_numpy(img).unsqueeze(0)



    def compute_psnr(self, output, target):
      mse = F.mse_loss(output, target)
      if mse == 0:
          return torch.tensor(100.0)  # Perfect match
      max_val = self.ranges["VIIRS"]["p98"]
      return 20 * torch.log10(max_val / torch.sqrt(mse))


#callbacks

In [11]:
class PSNRValidationCallback(Callback):
    """
    Compute corrected PSNR (Kelvin metric) on a fixed month of triplets.
    * uses the same helpers as the baseline script              (already defined)
    * model gets raw GOES‑16 / GOES‑17(18) images (no bicubic up‑scale)
    * VIIRS GT and model prediction are both high‑res → same shape
    """

    # ------------------------------------------------------------------ #
    def __init__(self, vis_callback, val_month_dir, every_n_epochs=1):
        super().__init__()
        self.vis_callback   = vis_callback
        self.val_month_dir  = val_month_dir
        self.every_n_epochs = every_n_epochs

        # read radiometric ranges once
        with open("/content/radiance_visualization_ranges.json") as f:
            rng = json.load(f)
        self.vi_min = rng["VIIRS"]["p2"]
        self.vi_rng = rng["VIIRS"]["p98"] - self.vi_min

        # collect (goes1, goes2, viirs) triplets once
        self.triplets = self._collect_triplets(val_month_dir)

    # ------------------------------------------------------------------ #
    @staticmethod
    def _collect_triplets(root):
        out = []
        for cur, _, files in os.walk(root):
            files = [f for f in files if f.lower().endswith(".tif")]
            if not files:
                continue
            g1 = g2 = v = None
            for f in files:
                p  = os.path.join(cur, f)
                lf = f.lower()
                if "geo16" in lf:
                    g1 = p
                elif "geo17" in lf or ("geo18" in lf and g2 is None):
                    g2 = p
                elif "combined" in lf or "viirs" in lf:
                    v  = p
            if g1 and g2 and v:
                out.append((g1, g2, v))
        return out

    # ------------------------------------------------------------------ #

    @staticmethod
    def _load_band(path):
        is_viirs   = "viirs" in path.lower() or "combined" in path.lower()
        band_index = 1 if is_viirs else 7           # VIIRS‑I4 or GOES‑7
        with rasterio.open(path) as src:
            img = src.read(band_index).astype(np.float32)
        m = ~(np.isnan(img) | np.isinf(img))
        return np.where(m, img, img[m].mean() if m.any() else 0.0)


    @staticmethod
    def cpsnr(gt: np.ndarray, pred: np.ndarray, mask: np.ndarray = None) -> float:
        """Corrected PSNR (Kelvin): brightness‑bias + clear‑pixel mask."""
        if mask is None:
          mask = np.ones_like(gt)

        diff = (gt - pred) * mask
        b    = diff.sum() / (mask.sum() + 1e-8)                 # brightness bias
        cmse = ((gt - pred + b) ** 2 * mask).sum() / (mask.sum() + 1e-8)
        return -10.0 * np.log10(cmse + 1e-8)

    # ------------------------------------------------------------------ #
    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs:
            return

        psnrs = []

        for g1_path, g2_path, v_path in self.triplets:
            g1   = self._load_band(g1_path)
            g2   = self._load_band(g2_path)
            vi   = self._load_band(v_path)              # H×W high‑res

            with torch.no_grad():
                p1_t, p2_t = pl_module(
                    torch.from_numpy(g1)[None,None].to(pl_module.device),
                    torch.from_numpy(g2)[None,None].to(pl_module.device)
                )
            p1 = p1_t.squeeze().cpu().numpy()
            p2 = p2_t.squeeze().cpu().numpy()

            # normalize to [0,1]
            vi_n = np.clip((vi - self.vi_min) / self.vi_rng,  0, 1)
            p1_n = np.clip((p1 - self.vi_min) / self.vi_rng,  0, 1)
            p2_n = np.clip((p2 - self.vi_min) / self.vi_rng,  0, 1)

            psnrs.append(self.cpsnr(vi_n, p1_n))
            psnrs.append(self.cpsnr(vi_n, p2_n))

        mean_psnr = float(np.mean(psnrs))
        if not hasattr(pl_module, "psnr_scores"):
            pl_module.psnr_scores = []
        pl_module.psnr_scores.append(mean_psnr)
        pl_module.log('val_cpsnr', mean_psnr, prog_bar=True)
        print(f"📈 Epoch {epoch+1:03d} — cPSNR: {mean_psnr:.2f} dB")

        # save the cPSNR curve
        save_dir = self.vis_callback.output_dir
        os.makedirs(save_dir, exist_ok=True)
        plt.figure()
        plt.plot(pl_module.psnr_scores, marker='o')
        plt.title("Validation cPSNR"); plt.xlabel("Epoch"); plt.ylabel("dB")
        plt.grid(True); plt.tight_layout()
        curve_path = os.path.join(save_dir, f"psnr_curve_epoch_{epoch:03d}.png")
        plt.savefig(curve_path); plt.close()
        print(f"✅ Saved cPSNR curve to {curve_path}")


In [12]:
from scipy.ndimage import gaussian_filter
from PIL import Image
import matplotlib.pyplot as plt

class VisualizePredictionCallback(Callback):
    def __init__(self, goes1_path, goes2_path, viirs_path, every_n_epochs=1):
        super().__init__()
        self.goes1_path   = goes1_path
        self.goes2_path   = goes2_path
        self.viirs_path   = viirs_path
        self.every_n_epochs = every_n_epochs

        # Load radiance ranges for GOES/VIIRS
        with open("/content/radiance_visualization_ranges.json", "r") as f:
            self.ranges = json.load(f)


        # Make output directory with timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.output_dir = os.path.join("/content/checkpoints", f"visual_{timestamp}")
        os.makedirs(self.output_dir, exist_ok=True)


    def scale_for_visualization(self, image, sensor_type=None):
        p1 = np.percentile(image, 2)
        p99 = np.percentile(image, 98)
        return np.clip((image - p1) / (p99 - p1), 0, 1.0)


    def load_image(self, path, band=1):
        with rasterio.open(path) as src:
            image = src.read(band).astype(np.float32)
            mask = ~(np.isnan(image) | np.isinf(image))
            if np.any(mask):
                mean_val = image[mask].mean()
                image = np.where(mask, image, mean_val)
            else:
                image = np.zeros_like(image)
            return image


    def on_train_epoch_end(self, trainer, pl_module):

        # only every N epochs
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return

        # load & normalize inputs
        g1 = self.load_image(self.goes1_path,  7); g1_v = self.scale_for_visualization(g1)
        g2 = self.load_image(self.goes2_path,  7); g2_v = self.scale_for_visualization(g2)
        vi = self.load_image(self.viirs_path, 1); vi_v = self.scale_for_visualization(vi)

        # forward – unpack two outputs
        t1 = torch.from_numpy(g1).unsqueeze(0).unsqueeze(0).to(pl_module.device)
        t2 = torch.from_numpy(g2).unsqueeze(0).unsqueeze(0).to(pl_module.device)
        with torch.no_grad():
            pred1_t, pred2_t = pl_module(t1, t2)

        # bring back to NumPy + normalize
        pred1 = pred1_t.squeeze().cpu().numpy(); p1_v = self.scale_for_visualization(pred1)
        pred2 = pred2_t.squeeze().cpu().numpy(); p2_v = self.scale_for_visualization(pred2)

        # compute local-variance fusion
        sigma = 3.0
        eps   = 1e-8
        from scipy.ndimage import gaussian_filter

        var1 = gaussian_filter(pred1**2, sigma=sigma) - gaussian_filter(pred1, sigma=sigma)**2
        var2 = gaussian_filter(pred2**2, sigma=sigma) - gaussian_filter(pred2, sigma=sigma)**2
        w1   = var1 / (var1 + var2 + eps)
        fused = w1 * pred1 + (1 - w1) * pred2

        # normalize fused for viz
        fused_v = self.scale_for_visualization(fused, "VIIRS")

        # build a 1×4 grid: GOES1, GOES2, VIIRS GT, FUSED
        imgs   = [g1_v,   g2_v,   vi_v,   fused_v]
        titles = ["GOES-1","GOES-2","VIIRS","Fused"]

        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        for ax, im_data, ttl in zip(axs, imgs, titles):
            im = ax.imshow(im_data, cmap="gray", vmin=0, vmax=1)
            ax.set_title(ttl)
            ax.axis("off")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        out = f"epoch_{trainer.current_epoch+1:03d}_fused_grid.png"
        save_path = os.path.join(self.output_dir, out)
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()
        print(f"✅ Saved fused visualization to {save_path}")

        if len(pl_module.train_losses) > 0 and len(pl_module.val_losses) > 0:
            # separate x-axes so shapes always match
            x_train = list(range(1, len(pl_module.train_losses) + 1))
            x_val   = list(range(1, len(pl_module.val_losses)   + 1))

            plt.figure(figsize=(6,4))
            plt.plot(x_train, pl_module.train_losses, marker='o', label='Train L1')
            plt.plot(x_val,   pl_module.val_losses,   marker='o', label='Val L1')
            plt.title("L1 Loss per Epoch")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.legend()
            plt.tight_layout()

            loss_path = os.path.join(
                self.output_dir,
                f"epoch_{trainer.current_epoch+1:03d}_loss_curve.png"
            )
            plt.savefig(loss_path)
            plt.close()
            print(f"✅ Saved loss curve to {loss_path}")
        else:
            print("ℹ️ Skipping loss-curve plot (no data yet)")


# training

In [13]:
import torch
import gc
gc.collect()
torch.cuda.empty_cache()

pl_model = SuperResolutionModule(model)

In [14]:
# === Create dummy inputs ===
goes1 = torch.randn(1, 1, 100, 100)  # GOES-1 image
goes2 = torch.randn(1, 1, 100, 100)  # GOES-2 image

# === Forward pass through the updated SwinIR wrapper ===
pl_model.eval()
with torch.no_grad():
    out1, out2 = pl_model(goes1, goes2)

pred1 = out1.squeeze().cpu().numpy()
pred2 = out2.squeeze().cpu().numpy()

# === Print shapes ===
print("GOES-1 shape:", goes1.shape)
print("GOES-2 shape:", goes2.shape)
print("Output from GOES-1 shape:", out1.shape)
print("Output from GOES-2 shape:", out2.shape)


GOES-1 shape: torch.Size([1, 1, 100, 100])
GOES-2 shape: torch.Size([1, 1, 100, 100])
Output from GOES-1 shape: torch.Size([1, 1, 400, 400])
Output from GOES-2 shape: torch.Size([1, 1, 400, 400])


In [16]:
# === DataModule ===
datamodule = SatelliteDataModule(
    csv_file="/content/drive/My Drive/superres_triplets.csv",
    batch_size=1,
    num_workers=2
)

# === Visualization Callback ===
vis_callback = VisualizePredictionCallback(
    goes1_path="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo16.tif",
    goes2_path="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo17.tif",
    viirs_path="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/combined_clip.tif",
    every_n_epochs=1
)


# === PSNR Callback ===
psnr_callback = PSNRValidationCallback(
    vis_callback=vis_callback,
    val_month_dir="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2023-02",  # Folder with validation triplets
    every_n_epochs=1
)

# === Logger ===
logger = CSVLogger("logs", name="swinir_superres")

# === Trainer ===
trainer = Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    precision=16,                # mixed precision
    logger=logger,
    gradient_clip_val=1.0,
    callbacks=[psnr_callback, vis_callback],
    log_every_n_steps=10
)

import gc
gc.collect()
torch.cuda.empty_cache()

# === Fit the model ===
trainer.fit(pl_model, datamodule=datamodule)

/usr/local/lib/python3.11/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type   | Params | Mode
--------------------------------------------
0 | model     | SwinIR | 6.2 M  | eval
1 | criter

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 001 — cPSNR: 10.65 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_000.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_001_fused_grid.png
ℹ️ Skipping loss-curve plot (no data yet)


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 002 — cPSNR: 11.25 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_001.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_002_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_002_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 003 — cPSNR: 4.17 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_002.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_003_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_003_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 004 — cPSNR: 4.02 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_003.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_004_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_004_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 005 — cPSNR: 8.96 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_004.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_005_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_005_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 006 — cPSNR: 2.50 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_005.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_006_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_006_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 007 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_006.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_007_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_007_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 008 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_007.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_008_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_008_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 009 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_008.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_009_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_009_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 010 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_009.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_010_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_010_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 011 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_010.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_011_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_011_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 012 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_011.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_012_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_012_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 013 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_012.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_013_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_013_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 014 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_013.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_014_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_014_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 015 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_014.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_015_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_015_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 016 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_015.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_016_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_016_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 017 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_016.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_017_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_017_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 018 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_017.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_018_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_018_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 019 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_018.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_019_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_019_loss_curve.png


Validation: |          | 0/? [00:00<?, ?it/s]

📈 Epoch 020 — cPSNR: 11.28 dB
✅ Saved cPSNR curve to /content/checkpoints/visual_2025-06-02_08-20-35/psnr_curve_epoch_019.png
✅ Saved fused visualization to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_020_fused_grid.png
✅ Saved loss curve to /content/checkpoints/visual_2025-06-02_08-20-35/epoch_020_loss_curve.png


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


In [19]:
import subprocess

folder_to_zip = "/content/checkpoints/visual_2025-06-02_08-20-35"
output_zip    = "/content/visual_results_2025-06-02.zip"

# Create the zip file
subprocess.run(["zip", "-r", output_zip, folder_to_zip])

CompletedProcess(args=['zip', '-r', '/content/visual_results_2025-06-02.zip', '/content/checkpoints/visual_2025-06-02_08-20-35'], returncode=0)

In [None]:
# 2) Download it to your local machine
from google.colab import files
files.download('/content/visual_results.zip')