#initaling things up

In [1]:
!pip install pytorch-lightning==2.2.0
!pip install rasterio
!pip install lightning



In [2]:
# --- 1. environment flags (must be set *before* importing PL) -----------
import os
os.environ["PL_DISABLE_MIXED_IMPORTS"] = "1"   # use *only* pytorch_lightning
os.environ["TORCH_NAN_INF_CHECK"]    = "1"     # raise if any NaN/Inf in fwd/bwd

# --- 2. standard & utility packages ------------------------------------
import sys, shutil, zipfile, csv
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import rasterio
from skimage.metrics import peak_signal_noise_ratio as psnr

# --- 3. PyTorch core ----------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# --- 4. PyTorch Lightning (legacy namespace only) ----------------------
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

# --- 5. Colab conveniences (only if you’re in Colab) -------------------
try:
    from google.colab import drive, files
except ImportError:
    drive = files = None        # not running in Colab → ignore


In [2]:


# Mount Google Drive
drive.mount('/content/drive')

# Paths
drive_tif_path = "/content/drive/MyDrive/final_project/ONLY_TIF"
local_tif_path = "/content/ONLY_TIF"

# If the folder already exists in local runtime, remove it first
if os.path.exists(local_tif_path):
    shutil.rmtree(local_tif_path)

# Copy entire folder from Drive to local runtime
shutil.copytree(drive_tif_path, local_tif_path)

print(f"Copied entire ONLY_TIF folder to: {local_tif_path}")


Mounted at /content/drive
Copied entire ONLY_TIF folder to: /content/ONLY_TIF


In [3]:

def make_triplet_csv(source_dir, output_csv_path):
    data = []
    id_counter = 1

    for root, _, files in os.walk(source_dir):
        files = [f for f in files if f.lower().endswith('.tif')]
        if not files:
            continue

        goes1 = goes2 = viirs = None
        for f in files:
            f_lower = f.lower()
            full_path = os.path.join(root, f)
            rel_path = os.path.relpath(full_path, source_dir)

            if 'geo16' in f_lower:
                goes1 = rel_path
            elif 'geo17' in f_lower:
                goes2 = rel_path
            elif 'geo18' in f_lower and goes2 is None:
                goes2 = rel_path
            elif 'combined' in f_lower:
                viirs = rel_path

        if goes1 and goes2 and viirs:
            data.append({
                'id': id_counter,
                'goes1_path': os.path.join(source_dir, goes1),
                'goes2_path': os.path.join(source_dir, goes2),
                'viirs_path': os.path.join(source_dir, viirs),
            })
            id_counter += 1

    df = pd.DataFrame(data)
    df.to_csv(output_csv_path, index=False)
    print(f"✅ CSV saved to {output_csv_path} with {len(df)} records.")

# Define paths
source_dir = "/content/ONLY_TIF"
output_csv_path = "/content/superres_triplets.csv"

make_triplet_csv(source_dir, output_csv_path)


✅ CSV saved to /content/superres_triplets.csv with 1260 records.


# cloning the repo and downloding the model

clones

In [5]:
# sys.path.append("C:/Users/97254/OneDrive - post.bgu.ac.il/Desktop/code4finalproj/SwinIR-main")
# Clone the repo
!git clone https://github.com/JingyunLiang/SwinIR.git

# Manually install necessary dependencies
!pip install timm einops

fatal: destination path 'SwinIR' already exists and is not an empty directory.


zips

In [6]:
# prompt: i need you to zip SwinIR folder and download it to my pc

!zip -r /content/SwinIR.zip /content/SwinIR

files.download("/content/SwinIR.zip")


  adding: content/SwinIR/ (stored 0%)
  adding: content/SwinIR/models/ (stored 0%)
  adding: content/SwinIR/models/network_swinir.py (deflated 80%)
  adding: content/SwinIR/.git/ (stored 0%)
  adding: content/SwinIR/.git/info/ (stored 0%)
  adding: content/SwinIR/.git/info/exclude (deflated 28%)
  adding: content/SwinIR/.git/config (deflated 30%)
  adding: content/SwinIR/.git/description (deflated 14%)
  adding: content/SwinIR/.git/refs/ (stored 0%)
  adding: content/SwinIR/.git/refs/remotes/ (stored 0%)
  adding: content/SwinIR/.git/refs/remotes/origin/ (stored 0%)
  adding: content/SwinIR/.git/refs/remotes/origin/HEAD (stored 0%)
  adding: content/SwinIR/.git/refs/tags/ (stored 0%)
  adding: content/SwinIR/.git/refs/heads/ (stored 0%)
  adding: content/SwinIR/.git/refs/heads/main (stored 0%)
  adding: content/SwinIR/.git/HEAD (stored 0%)
  adding: content/SwinIR/.git/objects/ (stored 0%)
  adding: content/SwinIR/.git/objects/info/ (stored 0%)
  adding: content/SwinIR/.git/objects/pac

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

unzips

In [None]:
# prompt: i need a code that unzips the SwinIR zip



# Specify the path to the zip file
zip_file_path = "/content/SwinIR.zip"  # Replace with the actual path

# Specify the directory to extract the contents to
extract_dir = "/content/SwinIR_extracted"  # Replace with your desired directory


try:
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    print(f"Successfully extracted '{zip_file_path}' to '{extract_dir}'")
except FileNotFoundError:
    print(f"Error: Zip file '{zip_file_path}' not found.")
except zipfile.BadZipFile:
    print(f"Error: Invalid zip file '{zip_file_path}'.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")


In [5]:
%cd SwinIR

# Download the pre-trained weights
!mkdir -p experiments/pretrained_models
!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth -P experiments/pretrained_models

/content/SwinIR
--2025-04-22 13:26:50--  https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/396770997/f3c0fbd1-d787-49f1-924a-8939e9a6707c?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250422%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250422T132650Z&X-Amz-Expires=300&X-Amz-Signature=d92f7624602d2fe9777a776c799998ed47fa1affcd00e0be3d03bd2cff4b5d83&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3D003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth&response-content-type=application%2Foctet-stream [following]
--2025-04-22 13:26:50--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/396770997/

# dataset and loaders

In [6]:


class SatelliteTripletDataset(Dataset):
    def __init__(self, csv_file, json_range_path="radiance_visualization_ranges.json"):
        self.data_info = pd.read_csv(csv_file)

        with open(json_range_path, "r") as f:
            self.global_ranges = json.load(f)

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        record = self.data_info.iloc[idx]
        goes1 = self.load_raw_radiance(record['goes1_path'])
        goes2 = self.load_raw_radiance(record['goes2_path'])
        viirs = self.load_raw_radiance(record['viirs_path'])
        return (goes1, goes2), viirs

    def load_raw_radiance(self, path):
        filename = os.path.basename(path).lower()
        is_viirs = "viirs" in filename or "combined_clip" in filename
        sensor_type = "VIIRS" if is_viirs else "GOES"
        band_index = 1 if is_viirs else 7

        with rasterio.open(path) as src:
            if src.count < band_index:
                raise ValueError(f"{path} does not contain band {band_index}")
            image = src.read(band_index).astype(np.float32)

        # Handle NaN/Inf values
        mask = ~(np.isnan(image) | np.isinf(image))
        if not np.any(mask):
            image = np.zeros_like(image)
        else:
            mean_val = image[mask].mean()
            image = np.where(mask, image, mean_val)

        # 🔁 Use global clipping range from JSON
        p_low = self.global_ranges[sensor_type]["p2"]
        p_high = self.global_ranges[sensor_type]["p98"]
        image = np.clip(image, p_low, p_high)

        return torch.from_numpy(image).unsqueeze(0)


In [7]:


class SatelliteDataModule(pl.LightningDataModule):
    def __init__(self, csv_file, batch_size=4, num_workers=0, percentile_range=(2, 98)):
        super().__init__()  # Correct super() call
        self.csv_file = csv_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.percentile_range = percentile_range

    def setup(self, stage=None):
        self.train_dataset = SatelliteTripletDataset(
            csv_file=self.csv_file,
            json_range_path="/content/radiance_visualization_ranges.json"
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

# callbacks

In [8]:



class VisualizePredictionCallback(Callback):
    def __init__(self, goes1_path, goes2_path, viirs_path, every_n_epochs=1):
        super().__init__()
        self.goes1_path = goes1_path
        self.goes2_path = goes2_path
        self.viirs_path = viirs_path
        self.every_n_epochs = every_n_epochs

        # Load visualization scaling values from JSON
        with open("/content/radiance_visualization_ranges.json", "r") as f:
            self.ranges = json.load(f)

        # Make output directory with timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.output_dir = os.path.join("/content/checkpoints", f"visual_{timestamp}")
        os.makedirs(self.output_dir, exist_ok=True)

    def scale_for_visualization(self, image, sensor_type=None):
        p1 = np.percentile(image, 2)
        p99 = np.percentile(image, 98)
        return np.clip((image - p1) / (p99 - p1), 0, 1.0)

    def load_image(self, path, band=1):
        with rasterio.open(path) as src:
            image = src.read(band).astype(np.float32)
            mask = ~(np.isnan(image) | np.isinf(image))
            if np.any(mask):
                mean_val = image[mask].mean()
                image = np.where(mask, image, mean_val)
            else:
                image = np.zeros_like(image)
            return image

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch % self.every_n_epochs != 0:
            return

        # Load GOES band 7 and VIIRS band 1
        goes1_img = self.load_image(self.goes1_path, band=7)
        goes2_img = self.load_image(self.goes2_path, band=7)
        viirs_img = self.load_image(self.viirs_path, band=1)

        # Convert to tensors
        goes1_tensor = torch.from_numpy(goes1_img).unsqueeze(0).unsqueeze(0).to(pl_module.device)
        goes2_tensor = torch.from_numpy(goes2_img).unsqueeze(0).unsqueeze(0).to(pl_module.device)

        # Run prediction
        with torch.no_grad():
            predicted = pl_module(goes1_tensor, goes2_tensor).squeeze().cpu().numpy()

        # Scale all images for visualization
        g1_viz = self.scale_for_visualization(goes1_img, "GOES")
        g2_viz = self.scale_for_visualization(goes2_img, "GOES")
        viirs_viz = self.scale_for_visualization(viirs_img, "VIIRS")
        pred_viz = self.scale_for_visualization(predicted, "VIIRS")

        # Plot all 4 images side by side
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        titles = ["GOES-1", "GOES-2", "VIIRS (GT)", "Predicted"]
        images = [g1_viz, g2_viz, viirs_viz, pred_viz]

        for ax, img, title in zip(axs, images, titles):
            im = ax.imshow(img, cmap="gray", vmin=0, vmax=1.0)  # ⬅️ Changed to grayscale
            ax.set_title(title)
            ax.axis("off")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        plt.suptitle(f"Epoch {epoch}", fontsize=16)
        plt.tight_layout()
        save_path = os.path.join(self.output_dir, f"epoch_{epoch:03d}.png")
        plt.savefig(save_path)
        plt.close()
        print(f"✅ Saved visualization to {save_path}")



In [9]:
class PSNRValidationCallback(Callback):
    """
    Compute corrected PSNR (Kelvin metric) on a fixed month of triplets.
    * uses the same helpers as the baseline script              (already defined)
    * model gets raw GOES‑16 / GOES‑17(18) images (no bicubic up‑scale)
    * VIIRS GT and model prediction are both high‑res → same shape
    """

    # ------------------------------------------------------------------ #
    def __init__(self, vis_callback, val_month_dir, every_n_epochs=1):
        super().__init__()
        self.vis_callback   = vis_callback
        self.val_month_dir  = val_month_dir
        self.every_n_epochs = every_n_epochs

        # read radiometric ranges once
        with open("/content/radiance_visualization_ranges.json") as f:
            rng = json.load(f)
        self.vi_min = rng["VIIRS"]["p2"]
        self.vi_rng = rng["VIIRS"]["p98"] - self.vi_min

        # collect (goes1, goes2, viirs) triplets once
        self.triplets = self._collect_triplets(val_month_dir)

    # ------------------------------------------------------------------ #
    @staticmethod
    def _collect_triplets(root):
        out = []
        for cur, _, files in os.walk(root):
            files = [f for f in files if f.lower().endswith(".tif")]
            if not files:
                continue
            g1 = g2 = v = None
            for f in files:
                p  = os.path.join(cur, f)
                lf = f.lower()
                if "geo16" in lf:
                    g1 = p
                elif "geo17" in lf or ("geo18" in lf and g2 is None):
                    g2 = p
                elif "combined" in lf or "viirs" in lf:
                    v  = p
            if g1 and g2 and v:
                out.append((g1, g2, v))
        return out

    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_band(path):
        is_viirs   = "viirs" in path.lower() or "combined" in path.lower()
        band_index = 1 if is_viirs else 7           # VIIRS‑I4 or GOES‑7
        with rasterio.open(path) as src:
            img = src.read(band_index).astype(np.float32)
        m = ~(np.isnan(img) | np.isinf(img))
        return np.where(m, img, img[m].mean() if m.any() else 0.0)

    @staticmethod
    def cpsnr(gt: np.ndarray, pred: np.ndarray, mask: np.ndarray) -> float:
        """Corrected PSNR (Kelvin): brightness‑bias + clear‑pixel mask."""
        diff = (gt - pred) * mask
        b    = diff.sum() / (mask.sum() + 1e-8)                 # brightness bias
        cmse = ((gt - pred + b) ** 2 * mask).sum() / (mask.sum() + 1e-8)
        return -10.0 * np.log10(cmse + 1e-8)

    # ------------------------------------------------------------------ #
    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs:
            return

        psnrs = []

        for g1_path, g2_path, v_path in self.triplets:
            g1   = self._load_band(g1_path)
            g2   = self._load_band(g2_path)
            vi   = self._load_band(v_path)              # H×W high‑res

            # forward pass – model expects raw GOES (low‑res) inputs
            pred = pl_module(torch.from_numpy(g1)[None, None].to(pl_module.device),
                             torch.from_numpy(g2)[None, None].to(pl_module.device)
                            ).squeeze().detach().cpu().numpy()   # H×W

            # ------------- normalise to [0,1] with VIIRS range -------------
            vi_n   = np.clip((vi   - self.vi_min) / self.vi_rng, 0, 1)
            pred_n = np.clip((pred - self.vi_min) / self.vi_rng, 0, 1)

            # ------------- clear‑pixel mask & Kelvin cPSNR -----------------
            psnrs.append(self.cpsnr(vi_n, pred_n, np.ones_like(vi_n)))

        mean_psnr = float(np.mean(psnrs))
        if not hasattr(pl_module, "psnr_scores"):
            pl_module.psnr_scores = []
        pl_module.psnr_scores.append(mean_psnr)

        # save a small curve in the same folder as visualisations
        plot_path = os.path.join(self.vis_callback.output_dir,
                                 f"psnr_curve_epoch_{epoch:03d}.png")
        plt.figure(); plt.plot(pl_module.psnr_scores, marker='o')
        plt.title("Validation cPSNR"); plt.xlabel("epoch"); plt.ylabel("dB")
        plt.grid(True); plt.tight_layout(); plt.savefig(plot_path); plt.close()

        print(f"📈  epoch {epoch:03d}  mean cPSNR: {mean_psnr:.2f} dB")

In [16]:
class MetricsAndCheckpointCallback(Callback):
    """tracks mean train‑loss & cPSNR, writes metrics.csv + loss plot,
       saves weights when cPSNR > threshold"""

    def __init__(self, vis_callback, cpsnr_threshold=10.5):
        super().__init__()
        self.dir = vis_callback.output_dir
        self.cpsnr_threshold = cpsnr_threshold
        self.loss_per_epoch = []
        self.cpsnr_per_epoch = []
        self._batch_losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        try:
            # Handle different output types
            if isinstance(outputs, torch.Tensor):
                loss_val = outputs.item()
            elif isinstance(outputs, dict) and 'loss' in outputs:
                loss_val = outputs['loss'].item()
            else:
                loss_val = float(outputs)  # Assume it's already a scalar

            self._batch_losses.append(loss_val)
        except Exception as e:
            print(f"❌ Error in on_train_batch_end: {e}")

    def on_train_epoch_end(self, trainer, pl_module):
        try:
            # Calculate mean loss
            if self._batch_losses:
                loss_mean = float(np.mean(self._batch_losses))
                self.loss_per_epoch.append(loss_mean)
                self._batch_losses.clear()
            else:
                print("⚠️ No batch losses recorded this epoch")
                loss_mean = 0.0
                self.loss_per_epoch.append(loss_mean)

            # Get cPSNR if available
            if hasattr(pl_module, 'psnr_scores') and pl_module.psnr_scores:
                cpsnr_mean = float(pl_module.psnr_scores[-1])
            else:
                print("⚠️ No psnr_scores available - initializing")
                if not hasattr(pl_module, 'psnr_scores'):
                    pl_module.psnr_scores = []
                cpsnr_mean = 0.0
                pl_module.psnr_scores.append(cpsnr_mean)

            self.cpsnr_per_epoch.append(cpsnr_mean)

            # Save checkpoint if cPSNR is above threshold
            if cpsnr_mean > self.cpsnr_threshold:
                fname = f"best_cpsnr_epoch_{trainer.current_epoch:03d}.pth"
                torch.save(pl_module.state_dict(), os.path.join(self.dir, fname))
                print(f"💾 Saved checkpoint – cPSNR {cpsnr_mean:.2f} dB")
        except Exception as e:
            print(f"❌ Error in on_train_epoch_end: {e}")
            import traceback
            traceback.print_exc()

    def on_fit_end(self, trainer, pl_module):
        try:
            # Save metrics to CSV
            csv_path = os.path.join(self.dir, "metrics.csv")
            with open(csv_path, "w", newline="") as f:
                w = csv.writer(f)
                w.writerow(["epoch", "train_loss", "cpsnr_db"])
                for e, l, p in zip(range(len(self.loss_per_epoch)),
                                  self.loss_per_epoch,
                                  self.cpsnr_per_epoch):
                    w.writerow([e, f"{l:.6f}", f"{p:.4f}"])
            print(f"✅ Metrics written → {csv_path}")

            # Plot loss curve
            plt.figure()
            plt.plot(self.loss_per_epoch, marker="o")
            plt.title("Mean Training Loss per Epoch")
            plt.xlabel("epoch"); plt.ylabel("loss"); plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(self.dir, "loss_curve.png"))
            plt.close()
        except Exception as e:
            print(f"❌ Error in on_fit_end: {e}")
            import traceback
            traceback.print_exc()

# modeling SwinIR

In [17]:

sys.path.append('/content/SwinIR')  # Add the repository to path
from models.network_swinir import SwinIR

def create_modified_swinir(
    pretrained_path,
    in_chans=2,   # GOES-16 + GOES-17/18
    out_chans=1,  # VIIRS
    img_size=100  # Your image size
):
    # Find the largest window size that evenly divides the image size
    def find_divisible_window_size(img_size, max_window_size=16):
        for window_size in range(max_window_size, 0, -1):
            if img_size % window_size == 0:
                return window_size
        return 1  # Fallback to 1 if no divisible size found

    window_size = find_divisible_window_size(img_size)

    print(f"Adjusted window size to: {window_size}")

    model = SwinIR(
        upscale=4,            # 4x upscaling
        in_chans=in_chans,    # Input channels
        out_chans=out_chans,  # Output channels
        img_size=img_size,    # Set to your image size (100x100)
        window_size=window_size,  # Dynamically calculated window size
        img_range=1.,         # Normalized range [0,1]
        depths=[6, 6, 6, 6, 6, 6],  # Keep original depth
        embed_dim=180,        # Keep original embedding dimension
        num_heads=[6, 6, 6, 6, 6, 6],  # Keep original number of heads
        mlp_ratio=2,          # Keep original MLP ratio
        upsampler='nearest+conv',  # Use nearest+conv for real-world SR
        resi_connection='1conv'  # Keep original residual connection
    )

    # Rest of the weight loading code remains the same
    pretrained = torch.load(pretrained_path, map_location='cpu')

    if 'params_ema' in pretrained:
        pretrained = pretrained['params_ema']
    elif 'params' in pretrained:
        pretrained = pretrained['params']

    model_dict = model.state_dict()

    pretrained_dict = {k: v for k, v in pretrained.items()
                       if k in model_dict and v.shape == model_dict[k].shape}

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)

    # Adapt input layer if needed
    if in_chans != 3:
        if 'conv_first.weight' in pretrained:
            with torch.no_grad():
                original_weight = pretrained['conv_first.weight']
                if in_chans == 2:
                    new_weight = original_weight[:, :2, :, :]
                    model.conv_first.weight.data = new_weight

    # ‑‑‑ adapt conv_last if you load RGB weights ‑‑‑
    if out_chans == 1 and 'conv_last.weight' in pretrained:
        with torch.no_grad():
            w_rgb = pretrained['conv_last.weight']      # shape [3, C, 3, 3]
            model.conv_last.weight.data = w_rgb.mean(dim=0, keepdim=True)

    print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} parameters from pre-trained model")

    return model

# experiment with M

In [18]:
# Create the modified SwinIR model
# Usage
model = create_modified_swinir(
    pretrained_path='experiments/pretrained_models/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth',
    in_chans=2,   # GOES-16 + GOES-17/18
    out_chans=1,  # VIIRS
    img_size=100  # Your image size
)

# Convert to Lightning module for training
class SwinIRLightningModule(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.criterion = nn.L1Loss()
        self.lr = lr

    def forward(self, goes1, goes2):
        x = torch.cat((goes1, goes2), dim=1)  # Combine channels: [B, 2, H, W]
        return self.model(x)  # [B, 1, H*4, W*4]

    def training_step(self, batch, batch_idx):
        (goes1, goes2), viirs = batch
        output = self(goes1, goes2)

        # Ensure same spatial dimensions for loss calculation
        if output.shape[2:] != viirs.shape[2:]:
            output = F.interpolate(output, size=viirs.shape[2:],
                                  mode='bilinear', align_corners=False)

        loss = self.criterion(output, viirs)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Create Lightning module
pl_model = SwinIRLightningModule(model)

Adjusted window size to: 10
Loaded 459/552 parameters from pre-trained model


In [20]:


# Create visualization callback
vis_callback = VisualizePredictionCallback(
    goes1_path="/content/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo16.tif",
    goes2_path="/content/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo17.tif",
    viirs_path="/content/ONLY_TIF/2020-11/2020-11-01_20-12/combined_clip.tif",
    every_n_epochs=1  # Visualize every epoch
)

# Create PSNR validation callback
psnr_callback = PSNRValidationCallback(
    vis_callback=vis_callback,
    val_month_dir="/content/ONLY_TIF/2023-02",  # Directory containing validation month data
    every_n_epochs=1  # Compute PSNR every epoch
)

metrics_ckpt_cb = MetricsAndCheckpointCallback(vis_callback)

# Data module
datamodule = SatelliteDataModule(
    csv_file="/content/superres_triplets.csv",
    batch_size=3
)


# Create trainer with callbacks
trainer = Trainer(
    max_epochs=15,
    accelerator='gpu',
    devices=1,
    precision=32,  # Use mixed precision for efficiency
    log_every_n_steps=10,
    callbacks=[psnr_callback, vis_callback, metrics_ckpt_cb]  # Add callbacks here
)

# Train model
trainer.fit(pl_model, datamodule=datamodule)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type   | Params
-------------------------------------
0 | model     | SwinIR | 11.7 M
1 | criterion | L1Loss | 0     
-------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.969    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

📈  epoch 000  mean cPSNR: 11.17 dB
✅ Saved visualization to /content/checkpoints/visual_2025-04-22_14-04-18/epoch_000.png
💾 Saved checkpoint – cPSNR 11.17 dB
📈  epoch 001  mean cPSNR: 10.50 dB
✅ Saved visualization to /content/checkpoints/visual_2025-04-22_14-04-18/epoch_001.png
💾 Saved checkpoint – cPSNR 10.50 dB
📈  epoch 002  mean cPSNR: 8.48 dB
✅ Saved visualization to /content/checkpoints/visual_2025-04-22_14-04-18/epoch_002.png
📈  epoch 003  mean cPSNR: 10.77 dB
✅ Saved visualization to /content/checkpoints/visual_2025-04-22_14-04-18/epoch_003.png
💾 Saved checkpoint – cPSNR 10.77 dB
📈  epoch 004  mean cPSNR: 9.27 dB
✅ Saved visualization to /content/checkpoints/visual_2025-04-22_14-04-18/epoch_004.png
📈  epoch 005  mean cPSNR: 9.40 dB
✅ Saved visualization to /content/checkpoints/visual_2025-04-22_14-04-18/epoch_005.png
📈  epoch 006  mean cPSNR: 10.97 dB
✅ Saved visualization to /content/checkpoints/visual_2025-04-22_14-04-18/epoch_006.png
💾 Saved checkpoint – cPSNR 10.97 dB
📈  ep

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.


✅ Metrics written → /content/checkpoints/visual_2025-04-22_14-04-18/metrics.csv


In [21]:
from google.colab import files
import shutil

# Zip the folder
shutil.make_archive('/content/checkpoints', 'zip', '/content/checkpoints')

# Download the zip file
files.download('/content/checkpoints.zip')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>