# init

In [1]:
!pip install pytorch-lightning
!pip install rasterio



In [2]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import numpy as np
import torch
import torch.nn.functional as F
import rasterio
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import Callback
from google.colab import drive
import pandas as pd
from torch.utils.data import Dataset
from skimage.metrics import peak_signal_noise_ratio as psnr
import json
from datetime import datetime
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
import sys
import torch.nn as nn
from torch.optim import Adam
import torchvision.transforms.functional as TF

In [3]:
drive.mount('/content/drive', force_remount=True)
base_folder = '/content/drive/My Drive/ONLY_TIF'
csv_path = '/content/drive/My Drive/superres_triplets.csv'

Mounted at /content/drive


In [4]:
df = pd.read_csv(csv_path)
for col in ['goes1_path', 'goes2_path', 'viirs_path']:
    df[col] = df[col].apply(lambda x: os.path.join(base_folder, x))

df.to_csv(csv_path, index=False)

# dataset+loader

In [5]:
class SatelliteTripletDataset(Dataset):
    def __init__(self, csv_file, json_range_path="radiance_visualization_ranges.json"):
        self.data_info = pd.read_csv(csv_file)

        with open(json_range_path, "r") as f:
            self.global_ranges = json.load(f)

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        record = self.data_info.iloc[idx]
        goes1 = self.load_raw_radiance(record['goes1_path'])
        goes2 = self.load_raw_radiance(record['goes2_path'])
        viirs = self.load_raw_radiance(record['viirs_path'])
        return (goes1, goes2), viirs

    def load_raw_radiance(self, path):
        filename = os.path.basename(path).lower()
        is_viirs = "viirs" in filename or "combined_clip" in filename
        sensor_type = "VIIRS" if is_viirs else "GOES"
        band_index = 1 if is_viirs else 7

        with rasterio.open(path) as src:
            if src.count < band_index:
                raise ValueError(f"{path} does not contain band {band_index}")
            image = src.read(band_index).astype(np.float32)

        # Handle NaN/Inf values
        mask = ~(np.isnan(image) | np.isinf(image))
        if not np.any(mask):
            image = np.zeros_like(image)
        else:
            mean_val = image[mask].mean()
            image = np.where(mask, image, mean_val)

        # 🔁 Use global clipping range from JSON
        p_low = self.global_ranges[sensor_type]["p2"]
        p_high = self.global_ranges[sensor_type]["p98"]
        image = np.clip(image, p_low, p_high)
        image = (image - p_low) / (p_high - p_low)

        return torch.from_numpy(image).unsqueeze(0)

In [6]:
class SatelliteDataModule(pl.LightningDataModule):
    def __init__(self, csv_file, batch_size=4, num_workers=0, percentile_range=(0.5, 99.5)):
        super().__init__()
        self.csv_file = csv_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.percentile_range = percentile_range

    def setup(self, stage=None):
        self.train_dataset = SatelliteTripletDataset(
            csv_file=self.csv_file,
            json_range_path="/content/radiance_visualization_ranges.json"
        )


    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

In [7]:
import torch
torch.cuda.empty_cache()

# external liberery

In [8]:
# sys.path.append("C:/Users/97254/OneDrive - post.bgu.ac.il/Desktop/code4finalproj/SwinIR-main")
# Clone the repo
!git clone https://github.com/JingyunLiang/SwinIR.git
sys.path.append('/content/SwinIR')

# Manually install necessary dependencies
!pip install timm einops

fatal: destination path 'SwinIR' already exists and is not an empty directory.


In [9]:
# Now you're safe to import
from models.network_swinir import SwinIR

model = SwinIR(
    upscale=4,            # 100x4 = 400 → desired upscale
    in_chans=1,           # GOES1 + GOES2 = 2 input channels
    img_size=100,         # patch input size
    window_size=25,        # usually 8 or 7 works fine
    img_range=1.,
    depths=[3,3,3,3,3,3],
    # depths=[2,2,2,2,2,2],
    embed_dim=180,
    # embed_dim=64,
    # num_heads=[8,8,8,8,8,8],
    num_heads=[6,6,6,6,6,6],
    mlp_ratio=2,
    # upsampler='nearest+conv',
    upsampler='convtranspose',
    resi_connection='3conv'
    # out_chans=1
    # <<< set to 1 for single-channel output
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


# model

In [10]:
class SuperResolutionModule(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.criterion = nn.L1Loss()
        self.lr = lr

        with open("/content/radiance_visualization_ranges.json", "r") as f:
            self.ranges = json.load(f)

        self.goes1_path = "/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo16.tif"
        self.goes2_path = "/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo17.tif"
        self.viirs_path = "/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/combined_clip.tif"

    def forward(self, goes1, goes2):
        out1 = self.model(goes1)  # Each has shape [B, 1, H, W]
        out2 = self.model(goes2)
        return out1, out2

    def training_step(self, batch, batch_idx):
        (goes1, goes2), viirs = batch
        # print("in", goes1.min(), goes1.max())
        viirs = viirs.float()
        # print("VIIRS min/max:", viirs.min().item(), viirs.max().item())

        out1, out2 = self(goes1, goes2)
        # print("out" , out1.min(), out1.max())

        out1 = out1.clamp(0., 1.)
        out2 = out2.clamp(0., 1.)

        # print("out after norm" , out1.min(), out1.max())


        loss1 = self.criterion(out1, viirs)
        loss2 = self.criterion(out2, viirs)
        loss = (loss1 + loss2) / 2

        psnr1 = self.compute_psnr(out1, viirs)
        psnr2 = self.compute_psnr(out2, viirs)
        avg_psnr = (psnr1 + psnr2) / 2

        self.log('train_loss', loss, prog_bar=True)
        self.log('train_psnr', avg_psnr, prog_bar=True)
        return loss



    def validation_step(self, batch, batch_idx):
        (goes1, goes2), viirs = batch
        viirs = viirs.float()

        out1, out2 = self(goes1, goes2)
        loss1 = self.criterion(out1, viirs)
        loss2 = self.criterion(out2, viirs)
        val_loss = (loss1 + loss2) / 2

        psnr1 = self.compute_psnr(out1, viirs)
        psnr2 = self.compute_psnr(out2, viirs)
        avg_psnr = (psnr1 + psnr2) / 2

        self.log('val_loss', val_loss, prog_bar=True)
        self.log('val_psnr', avg_psnr, prog_bar=True)
        return val_loss


    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)


    # def on_train_epoch_end(self):
    #   (goes1, goes2), viirs = self.example_batch

    #   self.eval()
    #   with torch.no_grad():
    #       # inputs = torch.cat((goes1, goes2), dim=1).float().to(self.device)
    #       # output = self(inputs)
    #       output = self(goes1.to(self.device), goes2.to(self.device))
    #       output = F.interpolate(output, size=viirs.shape[2:], mode='bilinear', align_corners=False)

    #   inp1 = goes1[0][0].cpu().numpy()
    #   inp2 = goes2[0][0].cpu().numpy()
    #   tgt = viirs[0][0].cpu().numpy()
    #   out = output[0][0].cpu().numpy()

    #   fig, axs = plt.subplots(1, 4, figsize=(14, 4))
    #   axs[0].imshow(inp1, cmap='gray'); axs[0].set_title('GOES 1')
    #   axs[1].imshow(inp2, cmap='gray'); axs[1].set_title('GOES 2')
    #   axs[2].imshow(tgt, cmap='gray'); axs[2].set_title('VIIRS Target')
    #   axs[3].imshow(out, cmap='gray'); axs[3].set_title('Model Output')
    #   fig.suptitle(f'Epoch {self.current_epoch+1} - Triplet Visualization', fontsize=14)

    #   for ax in axs: ax.axis('off')
    #   plt.tight_layout()
    #   plt.show()


    def on_train_start(self):
      goes1 = self.load_radiance(self.goes1_path)
      goes2 = self.load_radiance(self.goes2_path)
      viirs  = self.load_radiance(self.viirs_path)
      # Add batch dimension [1, 1, H, W]
      self.example_batch = ((goes1.unsqueeze(0), goes2.unsqueeze(0)), viirs.unsqueeze(0))


    def load_radiance(self, path, band=1):
        with rasterio.open(path) as src:
            img = src.read(band).astype(np.float32)
            mask = ~(np.isnan(img) | np.isinf(img))
            if mask.any():
                img = np.where(mask, img, img[mask].mean())
            else:
                img = np.zeros_like(img)
            if "viirs" in path.lower():
                p2, p98 = self.ranges["VIIRS"]["p2"], self.ranges["VIIRS"]["p98"]
            else:
                p2, p98 = self.ranges["GOES"]["p2"], self.ranges["GOES"]["p98"]

            img = np.clip(img, p2, p98)
            img = (img - p2) / (p98 - p2)
            return torch.from_numpy(img).unsqueeze(0)



    def compute_psnr(self, output, target, max_val=1.0):
      mse = F.mse_loss(output, target)
      if mse == 0:
          return torch.tensor(100.0)  # Perfect match
      psnr = 20 * torch.log10(max_val / torch.sqrt(mse))
      return psnr


#callbacks

In [11]:
class PSNRValidationCallback(Callback):
    """
    Compute corrected PSNR (Kelvin metric) on a fixed month of triplets.
    * uses the same helpers as the baseline script              (already defined)
    * model gets raw GOES‑16 / GOES‑17(18) images (no bicubic up‑scale)
    * VIIRS GT and model prediction are both high‑res → same shape
    """

    # ------------------------------------------------------------------ #
    def __init__(self, vis_callback, val_month_dir, every_n_epochs=1):
        super().__init__()
        self.vis_callback   = vis_callback
        self.val_month_dir  = val_month_dir
        self.every_n_epochs = every_n_epochs

        # read radiometric ranges once
        with open("/content/radiance_visualization_ranges.json") as f:
            rng = json.load(f)
        self.vi_min = rng["VIIRS"]["p2"]
        self.vi_rng = rng["VIIRS"]["p98"] - self.vi_min

        # collect (goes1, goes2, viirs) triplets once
        self.triplets = self._collect_triplets(val_month_dir)

    # ------------------------------------------------------------------ #
    @staticmethod
    def _collect_triplets(root):
        out = []
        for cur, _, files in os.walk(root):
            files = [f for f in files if f.lower().endswith(".tif")]
            if not files:
                continue
            g1 = g2 = v = None
            for f in files:
                p  = os.path.join(cur, f)
                lf = f.lower()
                if "geo16" in lf:
                    g1 = p
                elif "geo17" in lf or ("geo18" in lf and g2 is None):
                    g2 = p
                elif "combined" in lf or "viirs" in lf:
                    v  = p
            if g1 and g2 and v:
                out.append((g1, g2, v))
        return out

    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_band(path):
        with rasterio.open(path) as src:
            img = src.read(1).astype(np.float32)
        m = ~(np.isnan(img) | np.isinf(img))
        img = np.where(m, img, img[m].mean() if m.any() else 0.0)

        # נירמול ל־[0,1]
        is_viirs = "viirs" in path.lower() or "combined" in path.lower()
        with open("/content/radiance_visualization_ranges.json","r") as f:
            ranges = json.load(f)
        rng = ranges["VIIRS"] if is_viirs else ranges["GOES"]

        img = np.clip(img, rng["p2"], rng["p98"])
        img = (img - rng["p2"]) / (rng["p98"] - rng["p2"])
        return img


    @staticmethod
    def cpsnr(gt: np.ndarray, pred: np.ndarray, mask: np.ndarray) -> float:
        """Corrected PSNR (Kelvin): brightness‑bias + clear‑pixel mask."""
        diff = (gt - pred) * mask
        b    = diff.sum() / (mask.sum() + 1e-8)                 # brightness bias
        cmse = ((gt - pred + b) ** 2 * mask).sum() / (mask.sum() + 1e-8)
        return -10.0 * np.log10(cmse + 1e-8)


    # ------------------------------------------------------------------ #
    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs:
            return
        psnrs1 = []
        psnrs2 = []

        for i, (g1_path, g2_path, v_path) in enumerate(self.triplets):
            print(f"→ [{i}] {os.path.basename(g1_path)}, {os.path.basename(g2_path)}, {os.path.basename(v_path)}")

            g1 = self._load_band(g1_path)
            g2 = self._load_band(g2_path)
            vi = self._load_band(v_path)


            # forward pass – model expects raw GOES (low-res) inputs
            with torch.no_grad():
                pred1, pred2 = pl_module(
                    torch.from_numpy(g1)[None, None].to(pl_module.device),
                    torch.from_numpy(g2)[None, None].to(pl_module.device)
                )
            pred1 = pred1.squeeze().detach().cpu().numpy()
            pred2 = pred2.squeeze().detach().cpu().numpy()

            # Normalize ground truth and predictions to [0,1]
            vi_n     = np.clip((vi    - self.vi_min) / self.vi_rng, 0, 1)
            pred1_n  = np.clip((pred1 - self.vi_min) / self.vi_rng, 0, 1)
            pred2_n  = np.clip((pred2 - self.vi_min) / self.vi_rng, 0, 1)

            # Compute PSNR for each prediction
            psnr1 = self.cpsnr(vi_n, pred1_n, np.ones_like(vi_n))
            psnr2 = self.cpsnr(vi_n, pred2_n, np.ones_like(vi_n))

            # Store results
            psnrs1.append(psnr1)
            psnrs2.append(psnr2)


        mean_psnr1 = float(np.mean(psnrs1))
        mean_psnr2 = float(np.mean(psnrs2))
        mean_psnr_avg = (mean_psnr1 + mean_psnr2) / 2

        if not hasattr(pl_module, "psnr_scores_goes1"):
            pl_module.psnr_scores_goes1 = []
            pl_module.psnr_scores_goes2 = []
            pl_module.psnr_scores_avg   = []

        pl_module.psnr_scores_goes1.append(mean_psnr1)
        pl_module.psnr_scores_goes2.append(mean_psnr2)
        pl_module.psnr_scores_avg.append(mean_psnr_avg)

        print(f"\n📈 Epoch {epoch:03d} — GOES-1 PSNR: {mean_psnr1:.2f} | GOES-2 PSNR: {mean_psnr2:.2f} | Avg: {mean_psnr_avg:.2f} dB\n")

        # save curve
        plot_path = os.path.join(self.vis_callback.output_dir,
                                f"psnr_curve_epoch_{epoch:03d}.png")
        plt.figure(figsize=(10, 5))
        plt.plot(pl_module.psnr_scores_goes1, label="GOES-1", marker='o')
        plt.plot(pl_module.psnr_scores_goes2, label="GOES-2", marker='o')
        plt.plot(pl_module.psnr_scores_avg,   label="Avg", marker='o')
        plt.title("Validation cPSNR over Epochs")
        plt.xlabel("Epoch")
        plt.ylabel("cPSNR (dB)")
        plt.legend()
        plt.grid(True); plt.tight_layout(); plt.savefig(plot_path); plt.close()


In [12]:
class VisualizePredictionCallback(Callback):
    def __init__(self, goes1_path, goes2_path, viirs_path, every_n_epochs=1):
        super().__init__()
        self.goes1_path = goes1_path
        self.goes2_path = goes2_path
        self.viirs_path = viirs_path
        self.every_n_epochs = every_n_epochs

        # Load visualization scaling values from JSON
        with open("/content/radiance_visualization_ranges.json", "r") as f:
            self.ranges = json.load(f)

        # Make output directory with timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.output_dir = os.path.join("/content/checkpoints", f"visual_{timestamp}")
        os.makedirs(self.output_dir, exist_ok=True)

    def scale_for_visualization(self, image, sensor_type=None):
        p1 = np.percentile(image, 2)
        p99 = np.percentile(image, 98)
        return np.clip((image - p1) / (p99 - p1), 0, 1.0)


    def load_image(self, path, band=1):
        with rasterio.open(path) as src:
            image = src.read(band).astype(np.float32)
            mask = ~(np.isnan(image) | np.isinf(image))
            if np.any(mask):
                mean_val = image[mask].mean()
                image = np.where(mask, image, mean_val)
            else:
                image = np.zeros_like(image)
            return image

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch % self.every_n_epochs != 0:
            return

        # Load GOES band 7 and VIIRS band 1
        goes1_img = self.load_image(self.goes1_path, band=7)
        goes2_img = self.load_image(self.goes2_path, band=7)
        viirs_img = self.load_image(self.viirs_path, band=1)

        # Convert to tensors
        goes1_tensor = torch.from_numpy(goes1_img).unsqueeze(0).unsqueeze(0).to(pl_module.device)
        goes2_tensor = torch.from_numpy(goes2_img).unsqueeze(0).unsqueeze(0).to(pl_module.device)

        # Run prediction
        with torch.no_grad():
            pred1, pred2 = pl_module(goes1_tensor, goes2_tensor)
            pred1 = pred1.squeeze().cpu().numpy()
            pred2 = pred2.squeeze().cpu().numpy()

        # Scale all images for visualization
        g1_viz = self.scale_for_visualization(goes1_img, "GOES")
        g2_viz = self.scale_for_visualization(goes2_img, "GOES")
        viirs_viz = self.scale_for_visualization(viirs_img, "VIIRS")
        pred1_viz = self.scale_for_visualization(pred1)
        pred2_viz = self.scale_for_visualization(pred2)

        # Plot all 4 images side by side
        titles = ["GOES-1", "GOES-2", "VIIRS (GT)", "Pred from G1", "Pred from G2"]
        images = [g1_viz, g2_viz, viirs_viz, pred1_viz, pred2_viz]
        fig, axs = plt.subplots(1, 5, figsize=(24, 5))

        for ax, img, title in zip(axs, images, titles):
            im = ax.imshow(img, cmap="gray", vmin=0, vmax=1.0)  # ⬅️ Changed to grayscale
            ax.set_title(title)
            ax.axis("off")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        plt.suptitle(f"Epoch {epoch}", fontsize=16)
        plt.tight_layout()
        save_path = os.path.join(self.output_dir, f"epoch_{epoch:03d}.png")
        plt.savefig(save_path)
        plt.close()
        print(f"✅ Saved visualization to {save_path}")

# training

In [13]:
import torch
import gc
gc.collect()
torch.cuda.empty_cache()

pl_model = SuperResolutionModule(model)

In [14]:
# === Create dummy inputs ===
goes1 = torch.randn(1, 1, 100, 100)  # GOES-1 image
goes2 = torch.randn(1, 1, 100, 100)  # GOES-2 image

# === Forward pass through the updated SwinIR wrapper ===
pl_model.eval()
with torch.no_grad():
    out1, out2 = pl_model(goes1, goes2)

# === Print shapes ===
print("GOES-1 shape:", goes1.shape)
print("GOES-2 shape:", goes2.shape)
print("Output from GOES-1 shape:", out1.shape)
print("Output from GOES-2 shape:", out2.shape)


GOES-1 shape: torch.Size([1, 1, 100, 100])
GOES-2 shape: torch.Size([1, 1, 100, 100])
Output from GOES-1 shape: torch.Size([1, 1, 400, 400])
Output from GOES-2 shape: torch.Size([1, 1, 400, 400])


In [15]:
# === DataModule ===
datamodule = SatelliteDataModule(
    csv_file="/content/drive/My Drive/superres_triplets.csv",
    batch_size=1,
    num_workers=2
)

# === Visualization Callback ===
vis_callback = VisualizePredictionCallback(
    goes1_path="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo16.tif",
    goes2_path="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/clipped_geo17.tif",
    viirs_path="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2020-11/2020-11-01_20-12/combined_clip.tif",
    every_n_epochs=1
)


# === PSNR Callback ===
psnr_callback = PSNRValidationCallback(
    vis_callback=vis_callback,
    val_month_dir="/content/drive/My Drive/ONLY_TIF/ONLY_TIF/2023-02",  # Folder with validation triplets
    every_n_epochs=1
)

# === Logger ===
logger = CSVLogger("logs", name="swinir_superres")

# === Trainer ===
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",
    devices=1,
    precision=16,                # mixed precision
    logger=logger,
    gradient_clip_val=1.0,
    callbacks=[psnr_callback, vis_callback],
    log_every_n_steps=10
)

# Wrap your SwinIR model
pl_model = SuperResolutionModule(model)

import gc
gc.collect()
torch.cuda.empty_cache()

# === Fit the model ===
trainer.fit(pl_model, datamodule=datamodule)

/usr/local/lib/python3.11/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INF

Training: |          | 0/? [00:00<?, ?it/s]

→ [0] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [1] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [2] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [3] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [4] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [5] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [6] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [7] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [8] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [9] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [10] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [11] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [12] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [13] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [14] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ [15] clipped_geo16.tif, clipped_geo18.tif, combined_clip.tif
→ 

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
