# starting up

In [1]:
!pip install pytorch-lightning==2.2.0
!pip install rasterio
!pip install lightning

Collecting pytorch-lightning==2.2.0
  Downloading pytorch_lightning-2.2.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning==2.2.0)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from pytorch-lightning==2.2.0)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.13.0->pytorch-lightning==2.2.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.13.0->pytorch-lightning==2.2.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.13.0->pytorch-lightning==2.2.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from 

In [2]:
# --- 1. environment flags (must be set *before* importing PL) -----------
import os
os.environ["PL_DISABLE_MIXED_IMPORTS"] = "1"   # use *only* pytorch_lightning
os.environ["TORCH_NAN_INF_CHECK"]    = "1"     # raise if any NaN/Inf in fwd/bwd

# --- 2. standard & utility packages ------------------------------------
import sys, shutil, zipfile, csv
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import rasterio
from skimage.metrics import peak_signal_noise_ratio as psnr

# --- 3. PyTorch core ----------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# --- 4. PyTorch Lightning (legacy namespace only) ----------------------
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

# --- 5. Colab conveniences (only if you’re in Colab) -------------------
try:
    from google.colab import drive, files
except ImportError:
    drive = files = None        # not running in Colab → ignore


In [3]:
# Mount Google Drive
drive.mount('/content/drive')

# Paths
drive_tif_path = "/content/drive/MyDrive/final_project/ONLY_TIF"
local_tif_path = "/content/ONLY_TIF"

# If the folder already exists in local runtime, remove it first
if os.path.exists(local_tif_path):
    shutil.rmtree(local_tif_path)

# Copy entire folder from Drive to local runtime
shutil.copytree(drive_tif_path, local_tif_path)

print(f"Copied entire ONLY_TIF folder to: {local_tif_path}")

Mounted at /content/drive
Copied entire ONLY_TIF folder to: /content/ONLY_TIF


In [4]:
def make_triplet_csv(source_dir, output_csv_path):
    data = []
    id_counter = 1

    for root, _, files in os.walk(source_dir):
        files = [f for f in files if f.lower().endswith('.tif')]
        if not files:
            continue

        goes1 = goes2 = viirs = None
        for f in files:
            f_lower = f.lower()
            full_path = os.path.join(root, f)
            rel_path = os.path.relpath(full_path, source_dir)

            if 'geo16' in f_lower:
                goes1 = rel_path
            elif 'geo17' in f_lower:
                goes2 = rel_path
            elif 'geo18' in f_lower and goes2 is None:
                goes2 = rel_path
            elif 'combined' in f_lower:
                viirs = rel_path

        if goes1 and goes2 and viirs:
            data.append({
                'id': id_counter,
                'goes1_path': os.path.join(source_dir, goes1),
                'goes2_path': os.path.join(source_dir, goes2),
                'viirs_path': os.path.join(source_dir, viirs),
            })
            id_counter += 1

    df = pd.DataFrame(data)
    df.to_csv(output_csv_path, index=False)
    print(f"✅ CSV saved to {output_csv_path} with {len(df)} records.")

# Define paths
source_dir = "/content/ONLY_TIF"
output_csv_path = "/content/superres_triplets.csv"

make_triplet_csv(source_dir, output_csv_path)


✅ CSV saved to /content/superres_triplets.csv with 1260 records.


# clone

In [5]:
# sys.path.append("C:/Users/97254/OneDrive - post.bgu.ac.il/Desktop/code4finalproj/SwinIR-main")
# Clone the repo
!git clone https://github.com/JingyunLiang/SwinIR.git

# Manually install necessary dependencies
!pip install timm einops

Cloning into 'SwinIR'...
remote: Enumerating objects: 333, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 333 (delta 6), reused 2 (delta 2), pack-reused 323 (from 2)[K
Receiving objects: 100% (333/333), 29.84 MiB | 22.43 MiB/s, done.
Resolving deltas: 100% (119/119), done.


In [6]:
%cd SwinIR

# Download the grayscale denoising pre-trained weights
!mkdir -p experiments/pretrained_models
!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth -P experiments/pretrained_models

/content/SwinIR
--2025-05-16 09:24:54--  https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/396770997/44b18cfe-3817-49c6-aed0-9f8912acb152?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250516%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250516T092454Z&X-Amz-Expires=300&X-Amz-Signature=ede8d4d67f1255b53b400af12cd5c60357ad7be8afdb1cbea8393d1ff4cabcba&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3D004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth&response-content-type=application%2Foctet-stream [following]
--2025-05-16 09:24:54--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/396770997/44b18cfe

# dataset

In [7]:
class SatelliteImageDataset(Dataset):
    def __init__(self, csv_path, json_path, transform=None):
        """
        Dataset for satellite image super-resolution

        Args:
            csv_path: Path to CSV file with image triplets
            json_path: Path to JSON file with normalization parameters
            transform: Optional transforms to apply to images
        """
        self.df = pd.read_csv(csv_path)

        # Load normalization parameters from JSON
        with open(json_path, 'r') as f:
            self.norm_params = json.load(f)

        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get paths for this sample
        sample = self.df.iloc[idx]
        goes1_path = sample['goes1_path']
        goes2_path = sample['goes2_path']
        viirs_path = sample['viirs_path']

        # Load images using rasterio
        goes1 = self._load_and_process_image(goes1_path, 'goes')
        goes2 = self._load_and_process_image(goes2_path, 'goes')
        viirs = self._load_and_process_image(viirs_path, 'viirs')

        # Apply additional transforms if specified
        if self.transform:
            goes1 = self.transform(goes1)
            goes2 = self.transform(goes2)
            viirs = self.transform(viirs)

        # Return (goes1, goes2), viirs format
        return (goes1, goes2), viirs

    def _load_and_process_image(self, path, img_type):
        """Load, clip and normalize an image"""
        with rasterio.open(path) as src:
            # Use band 7 for GOES and band 1 for VIIRS
            band_idx = 7 if img_type == 'goes' else 1
            img = src.read(band_idx)

        # Handle NaN and Inf values
        mask = ~(np.isnan(img) | np.isinf(img))
        if np.any(mask):
            mean_val = img[mask].mean()
            img = np.where(mask, img, mean_val)
        else:
            img = np.zeros_like(img)

        # Get correct normalization values from JSON
        # JSON has percentile values for GOES and VIIRS
        sat_type = "GOES" if img_type == "goes" else "VIIRS"
        min_val = self.norm_params[sat_type]["p2"]
        max_val = self.norm_params[sat_type]["p98"]

        # Clip values based on image type
        img = np.clip(img, min_val, max_val)

        # # Normalize to [0, 1]
        img = (img - min_val) / (max_val - min_val)

        # Convert to PyTorch tensor and add channel dimension
        img_tensor = torch.from_numpy(img).float().unsqueeze(0)

        return img_tensor



class SatelliteDataModule(pl.LightningDataModule):
    def __init__(
        self,
        csv_path,
        json_path,
        batch_size=8,
        num_workers=0,
        transform=None
    ):
        """
        PyTorch Lightning DataModule for satellite super-resolution
        Uses all available data for training only.

        Args:
            csv_path: Path to CSV file with image triplets
            json_path: Path to JSON file with normalization parameters
            batch_size: Batch size for dataloader
            num_workers: Number of worker processes for data loading
            transform: Optional transforms to apply
        """
        super().__init__()
        self.csv_path = csv_path
        self.json_path = json_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.save_hyperparameters(ignore=['transform'])

    def setup(self, stage=None):
        # Create dataset with all available images
        self.train_dataset = SatelliteImageDataset(
            csv_path=self.csv_path,
            json_path=self.json_path,
            transform=self.transform
        )

        print(f"Training dataset ready with {len(self.train_dataset)} samples")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        # No validation data
        return None

    def test_dataloader(self):
        # No test data
        return None


# callbacks

In [27]:
class VisualizePredictionCallback(Callback):
    def __init__(self, goes1_path, goes2_path, viirs_path, every_n_epochs=1):
        super().__init__()
        self.goes1_path = goes1_path
        self.goes2_path = goes2_path
        self.viirs_path = viirs_path
        self.every_n_epochs = every_n_epochs

        # Load visualization scaling values from JSON
        with open("/content/radiance_visualization_ranges.json", "r") as f:
            self.ranges = json.load(f)

        # Make output directory with timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.output_dir = os.path.join("/content/checkpoints", f"visual_{timestamp}")
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_and_normalize_image(self, path, band=1, is_viirs=False):
        """Load and normalize an image using same method as dataset"""
        with rasterio.open(path) as src:
            image = src.read(band).astype(np.float32)

        # Handle NaN/Inf values
        mask = ~(np.isnan(image) | np.isinf(image))
        if np.any(mask):
            mean_val = image[mask].mean()
            image = np.where(mask, image, mean_val)
        else:
            image = np.zeros_like(image)

        # Normalize using same method as SatelliteImageDataset
        sat_type = "VIIRS" if is_viirs else "GOES"
        min_val = self.ranges[sat_type]["p2"]
        max_val = self.ranges[sat_type]["p98"]

        # Clip and normalize to [0,1]
        image = np.clip(image, min_val, max_val)
        image = (image - min_val) / (max_val - min_val)

        return image

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch % self.every_n_epochs != 0:
            return

        # Load and normalize images the same way as dataset
        goes1_img = self._load_and_normalize_image(self.goes1_path, band=7)
        goes2_img = self._load_and_normalize_image(self.goes2_path, band=7)
        viirs_img = self._load_and_normalize_image(self.viirs_path, band=1, is_viirs=True)

        # Convert to tensors
        goes1_tensor = torch.from_numpy(goes1_img).unsqueeze(0).unsqueeze(0).to(pl_module.device)
        goes2_tensor = torch.from_numpy(goes2_img).unsqueeze(0).unsqueeze(0).to(pl_module.device)

        # Run prediction
        with torch.no_grad():
            predicted = pl_module(goes1_tensor, goes2_tensor).squeeze().cpu().numpy()

            # Debug information
            print(f"Prediction stats - min: {predicted.min():.6f}, max: {predicted.max():.6f}, mean: {predicted.mean():.6f}")
            print(f"Prediction shape: {predicted.shape}, VIIRS shape: {viirs_img.shape}")

            # If prediction is all near zero, amplify for visualization
            if predicted.max() < 0.1:
                print("Warning: Prediction values are very small - amplifying for visualization")
                # Try to amplify signal for visualization without changing actual model output
                viz_predicted = predicted.copy()
                if viz_predicted.max() > 0:
                    viz_predicted = viz_predicted / viz_predicted.max()  # Normalize to [0,1] for visibility
                else:
                    viz_predicted = predicted  # If all zeros, don't change
            else:
                viz_predicted = predicted

            # Handle NaN in prediction
            viz_predicted = np.nan_to_num(viz_predicted, nan=0.0, posinf=1.0, neginf=0.0)
            # Ensure prediction is in [0,1] range
            viz_predicted = np.clip(viz_predicted, 0, 1)

        # Plot all 4 images side by side
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        titles = ["GOES-1", "GOES-2", "VIIRS (GT)", "Predicted"]
        images = [goes1_img, goes2_img, viirs_img, viz_predicted]

        for ax, img, title in zip(axs, images, titles):
            im = ax.imshow(img, cmap="gray", vmin=0, vmax=1.0)
            ax.set_title(title)
            ax.axis("off")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        plt.suptitle(f"Epoch {epoch}", fontsize=16)
        plt.tight_layout()
        save_path = os.path.join(self.output_dir, f"epoch_{epoch:03d}.png")
        plt.savefig(save_path)
        plt.close()
        print(f"✅ Saved visualization to {save_path}")






In [28]:
class PSNRValidationCallback(Callback):
    """
    Compute corrected PSNR (Kelvin metric) on a fixed month of triplets.
    Logs cPSNR scores per epoch into a CSV in vis_callback.output_dir
    """

    def __init__(self, vis_callback, val_month_dir, every_n_epochs=1):
        super().__init__()
        self.vis_callback   = vis_callback
        self.val_month_dir  = val_month_dir
        self.every_n_epochs = every_n_epochs

        # Read radiometric range for VIIRS
        with open("/content/radiance_visualization_ranges.json") as f:
            rng = json.load(f)
        self.vi_min = rng["VIIRS"]["p2"]
        self.vi_rng = rng["VIIRS"]["p98"] - self.vi_min

        self.goes_min = rng["GOES"]["p2"]
        self.goes_rng = rng["GOES"]["p98"] - self.goes_min

        # Collect validation triplets
        self.triplets = self._collect_triplets(val_month_dir)
        if not self.triplets:
            print("[WARNING] No validation triplets found — skipping PSNR computation.")

        # CSV output path
        self.csv_path = os.path.join(self.vis_callback.output_dir, "cpsnr_log.csv")
        # Write header if file doesn't exist yet
        if not os.path.exists(self.csv_path):
            with open(self.csv_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["epoch", "cpsnr"])

    @staticmethod
    def _collect_triplets(root):
        out = []
        for cur, _, files in os.walk(root):
            files = [f for f in files if f.lower().endswith(".tif")]
            if not files:
                continue
            g1 = g2 = v = None
            for f in files:
                p = os.path.join(cur, f)
                lf = f.lower()
                if "geo16" in lf:
                    g1 = p
                elif "geo17" in lf or ("geo18" in lf and g2 is None):
                    g2 = p
                elif "viirs" in lf or "combined" in lf:
                    v = p
            if g1 and g2 and v:
                out.append((g1, g2, v))
            else:
                print(f"[WARNING] Incomplete triplet in {cur} → g1: {bool(g1)}, g2: {bool(g2)}, v: {bool(v)}")
        return out

    @staticmethod
    def _load_band(path):
        is_viirs   = "viirs" in path.lower() or "combined" in path.lower()
        band_index = 1 if is_viirs else 7
        with rasterio.open(path) as src:
            img = src.read(band_index).astype(np.float32)
        m = ~(np.isnan(img) | np.isinf(img))
        return np.where(m, img, img[m].mean() if m.any() else 0.0)

    @staticmethod
    def cpsnr(gt: np.ndarray, pred: np.ndarray, mask: np.ndarray) -> float:
        diff = (gt - pred) * mask
        b    = diff.sum() / (mask.sum() + 1e-8)
        cmse = ((gt - pred + b) ** 2 * mask).sum() / (mask.sum() + 1e-8)
        # print(f"cMSE: {cmse:.4f}, bias: {b:.4f}, cPSNR: {-10.0 * np.log10(cmse + 1e-8):.2f}")
        return -10.0 * np.log10(cmse + 1e-8)

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs:
            return

        psnrs = []

        for g1_path, g2_path, v_path in self.triplets:
            g1 = self._load_band(g1_path)
            g2 = self._load_band(g2_path)
            vi = self._load_band(v_path)

            # Normalize ground truth VIIRS and geo pictures
            vi_scaled = np.clip((vi - self.vi_min) / self.vi_rng, 0, 1)
            g1 = np.clip((g1 - self.goes_min) / self.goes_rng, 0, 1)
            g2 = np.clip((g2 - self.goes_min) / self.goes_rng, 0, 1)


            # Predict using scaled GOES inputs
            pred = pl_module(
                torch.from_numpy(g1)[None, None].to(pl_module.device),
                torch.from_numpy(g2)[None, None].to(pl_module.device)
            ).squeeze().detach().cpu().numpy()

            # print("GT min/max:", vi_scaled.min(), vi_scaled.max())
            # print("Pred min/max:", pred.min(), pred.max())
            # print("GOES1 min/max:", g1.min(), g1.max())
            # print("GOES2 min/max:", g2.min(), g2.max())


            psnrs.append(self.cpsnr(vi_scaled, pred, np.ones_like(vi_scaled)))

        mean_psnr = float(np.mean(psnrs))
        if not hasattr(pl_module, "psnr_scores"):
            pl_module.psnr_scores = []
        pl_module.psnr_scores.append(mean_psnr)

        # Save PSNR plot
        plot_path = os.path.join(self.vis_callback.output_dir, f"psnr_curve_epoch_{epoch:03d}.png")
        plt.figure()
        plt.plot(pl_module.psnr_scores, marker='o')
        plt.title("Validation cPSNR")
        plt.xlabel("epoch")
        plt.ylabel("dB")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(plot_path)
        plt.close()

        # Append to CSV log
        with open(self.csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, round(mean_psnr, 4)])

        print(f"📈  epoch {epoch:03d}  mean cPSNR: {mean_psnr:.2f} dB")

In [29]:
class LossTrackingCallback(pl.Callback):
    def __init__(self, vis_callback):
        super().__init__()
        self.vis_callback = vis_callback
        self.losses = []
        self.csv_path = os.path.join(vis_callback.output_dir, "loss_log.csv")

        # Write CSV header once
        if not os.path.exists(self.csv_path):
            with open(self.csv_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["epoch", "loss"])

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        loss = float(trainer.callback_metrics["train_loss"])

        self.losses.append(loss)

        # Save curve
        plt.figure()
        plt.plot(self.losses, marker='o')
        plt.title("Training Loss per Epoch")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.grid(True)
        plt.tight_layout()
        plot_path = os.path.join(self.vis_callback.output_dir, f"loss_curve_epoch_{epoch:03d}.png")
        plt.savefig(plot_path)
        plt.close()

        # Save to CSV
        with open(self.csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, round(loss, 6)])

        print(f"📉  epoch {epoch:03d}  loss: {loss:.6f}")


# modeling

In [30]:

sys.path.append('/content/SwinIR')  # Add the repository to path
from models.network_swinir import SwinIR

def create_swinir_grayscale(
    pretrained_path,
    img_size=100,
    upscale=4
):
    # Find suitable window size
    window_size = 10  # Use fixed window size for now

    # Create model specifically for grayscale SR
    model = SwinIR(
        upscale=upscale,      # 4x upscaling
        in_chans=1,           # 1 channel input
        img_size=img_size,    # 100x100 input
        window_size=window_size,
        img_range=1.,         # Normalized range [0,1]
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upsampler='nearest+conv',  # Critical for SR
        resi_connection='1conv'
    )

    # Load pretrained weights but DON'T try strict loading
    pretrained = torch.load(pretrained_path, map_location='cpu')
    if 'params' in pretrained:
        pretrained = pretrained['params']

    # Only load compatible parameters
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in pretrained.items()
                      if k in model_dict and v.shape == model_dict[k].shape}

    # Load matching parameters
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)

    print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} parameters from {pretrained_path}")

    return model

In [31]:
class SwinIRGrayscaleLightningModule(pl.LightningModule):
    def __init__(self, lr=1e-4):
        super().__init__()

        # Create a single SwinIR model with upscaling (upscale=4)
        self.swinir_model = create_swinir_grayscale(
            pretrained_path='experiments/pretrained_models/004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth',
            img_size=100,
            upscale=4  # Keep upscaling in the model
        )

        # Create a deeper fusion network that works on upscaled inputs (400x400)
        self.fusion = nn.Sequential(
            # Initial fusion of both high-res inputs
            nn.Conv2d(2, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            # Deeper processing
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            # Final output layers
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        self.criterion1 = nn.MSELoss()
        self.criterion2 = nn.L1Loss()
        self.lr = lr

    def forward(self, goes1, goes2):
        # Process both inputs through the same SwinIR model (with upscaling)
        goes1_upscaled = self.swinir_model(goes1)  # [B, 1, 400, 400]
        goes2_upscaled = self.swinir_model(goes2)  # [B, 1, 400, 400]

        # Concatenate the upscaled features and apply fusion
        x = torch.cat([goes1_upscaled, goes2_upscaled], dim=1)  # [B, 2, 400, 400]
        output = self.fusion(x)  # [B, 1, 400, 400]

        return output

    def training_step(self, batch, batch_idx):
        (goes1, goes2), viirs = batch

        # Forward pass
        output = self(goes1, goes2)

        # Ensure dimensions match
        if output.shape[2:] != viirs.shape[2:]:
            output = F.interpolate(output, size=viirs.shape[2:],
                                   mode='bilinear', align_corners=False)

        loss = self.criterion1(output, viirs)
        loss += self.criterion2(output, viirs)
        # Add out-of-range penalty
        # penalty = torch.mean(torch.relu(output - 1.0) ** 2 + torch.relu(-output) ** 2)
        # loss += 0.1 * penalty  # Tune the 0.1 weight
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# experiment

In [32]:
# Create visualization callback
vis_callback = VisualizePredictionCallback(
    goes1_path="/content/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo16.tif",
    goes2_path="/content/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo17.tif",
    viirs_path="/content/ONLY_TIF/2020-11/2020-11-01_20-12/combined_clip.tif",
    every_n_epochs=1
)

# Create PSNR validation callback
psnr_callback = PSNRValidationCallback(
    vis_callback=vis_callback,
    val_month_dir="/content/ONLY_TIF/2023-02",
    every_n_epochs=1
)

loss_cb = LossTrackingCallback(vis_callback)



# Data module
datamodule = SatelliteDataModule(
    csv_path="/content/superres_triplets.csv",
        json_path="/content/radiance_visualization_ranges.json",
        batch_size=2,
        num_workers=3
)


In [33]:
# Create the model - using grayscale denoising weights with correct dimensions
pl_model = SwinIRGrayscaleLightningModule()

# Create trainer with callbacks
trainer = Trainer(
    max_epochs=15,
    accelerator='gpu',
    devices=1,
    precision=32,
    log_every_n_steps=10,
    callbacks=[psnr_callback, vis_callback, loss_cb]
)

# Train model
trainer.fit(pl_model, datamodule=datamodule)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loaded 453/552 parameters from experiments/pretrained_models/004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth
Training dataset ready with 1260 samples


INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type       | Params
--------------------------------------------
0 | swinir_model | SwinIR     | 11.7 M
1 | fusion       | Sequential | 333 K 
2 | criterion1   | MSELoss    | 0     
3 | criterion2   | L1Loss     | 0     
--------------------------------------------
12.1 M    Trainable params
0         Non-trainable params
12.1 M    Total params
48.298    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

📈  epoch 000  mean cPSNR: 7.38 dB
Prediction stats - min: 0.441323, max: 0.677713, mean: 0.631059
Prediction shape: (400, 400), VIIRS shape: (400, 400)
✅ Saved visualization to /content/checkpoints/visual_2025-05-16_09-51-35/epoch_000.png
📉  epoch 000  loss: 0.243662
📈  epoch 001  mean cPSNR: 9.46 dB
Prediction stats - min: 0.419201, max: 0.709854, mean: 0.556045
Prediction shape: (400, 400), VIIRS shape: (400, 400)
✅ Saved visualization to /content/checkpoints/visual_2025-05-16_09-51-35/epoch_001.png
📉  epoch 001  loss: 0.299726
📈  epoch 002  mean cPSNR: 10.67 dB
Prediction stats - min: 0.366161, max: 0.760957, mean: 0.612440
Prediction shape: (400, 400), VIIRS shape: (400, 400)
✅ Saved visualization to /content/checkpoints/visual_2025-05-16_09-51-35/epoch_002.png
📉  epoch 002  loss: 0.174907
📈  epoch 003  mean cPSNR: 11.19 dB
Prediction stats - min: 0.420831, max: 0.671676, mean: 0.564205
Prediction shape: (400, 400), VIIRS shape: (400, 400)
✅ Saved visualization to /content/checkpoi

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.


In [34]:
# prompt: make a script that will zip the checkpoints folder and then download the folder

# Zip the checkpoints folder
zip_filename = "checkpoints.zip"
!zip -r {zip_filename} /content/checkpoints

# Download the zip file
files.download(zip_filename)


updating: content/checkpoints/ (stored 0%)
updating: content/checkpoints/visual_2025-05-16_09-25-01/ (stored 0%)
updating: content/checkpoints/visual_2025-05-16_09-25-01/loss_log.csv (stored 0%)
updating: content/checkpoints/visual_2025-05-16_09-25-01/psnr_curve_epoch_000.png (deflated 24%)
updating: content/checkpoints/visual_2025-05-16_09-25-01/loss_curve_epoch_000.png (deflated 23%)
updating: content/checkpoints/visual_2025-05-16_09-25-01/cpsnr_log.csv (stored 0%)
updating: content/checkpoints/visual_2025-05-16_09-25-01/epoch_000.png (deflated 4%)
  adding: content/checkpoints/visual_2025-05-16_09-51-35/ (stored 0%)
  adding: content/checkpoints/visual_2025-05-16_09-51-35/loss_curve_epoch_006.png (deflated 8%)
  adding: content/checkpoints/visual_2025-05-16_09-51-35/psnr_curve_epoch_006.png (deflated 10%)
  adding: content/checkpoints/visual_2025-05-16_09-51-35/loss_log.csv (deflated 39%)
  adding: content/checkpoints/visual_2025-05-16_09-51-35/epoch_001.png (deflated 4%)
  adding: 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>