In [None]:
# Mount Drive and install dependencies.
# Colab-only cell: `google.colab` and the `!pip` shell escape exist only inside a
# Colab/IPython kernel; this line is not valid in a plain Python script.
from google.colab import drive
# force_remount=True remounts even if Drive is already mounted (e.g. when re-running the cell).
drive.mount('/content/drive', force_remount=True)
!pip install -q "pytorch_lightning>=2.0.0" segmentation-models-pytorch albumentations torchmetrics tifffile

import numpy as np, tifffile, torch
from pathlib import Path
from tqdm.auto import tqdm
from sklearn.model_selection import GroupShuffleSplit
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

# Restore the `np.float` alias (removed in NumPy 1.24) for older libraries that still
# reference it (per the original note: timm < 0.6).
np.float = np.float64

# Paths
S2_DIR = Path('/content/drive/MyDrive/STURM/data/raw/Dataset/Sentinel2/S2')
FLOOD_DIR = Path('/content/drive/MyDrive/STURM/data/raw/Dataset/Sentinel2/Floodmaps')
all_image_paths = sorted(S2_DIR.rglob('*.tif'))
if not all_image_paths:
    # Fail fast: with no tiles the statistics below would silently become 0/0 = NaN.
    raise FileNotFoundError(f"No .tif tiles found under {S2_DIR}")

# Per-band mean and std over every tile, ignoring NaN pixels.
# Sum, sum of squares and valid-pixel count are accumulated per band in float64,
# then Var = E[x^2] - E[x]^2.
n_bands = 9
sum_pixels    = np.zeros(n_bands)
sum_sq_pixels = np.zeros(n_bands)
count_pixels  = np.zeros(n_bands)
for p in tqdm(all_image_paths):
    img = tifffile.imread(p).astype(np.float64)
    if img.ndim == 3 and img.shape[-1] < img.shape[0]:
        img = img.transpose(2,0,1)  # HWC -> CHW when the last axis is the smallest
    if img.ndim != 3 or img.shape[0] != n_bands:
        # Clearer than the broadcast/axis error the accumulation would otherwise raise.
        raise ValueError(f"{p}: expected {n_bands} bands, got array of shape {img.shape}")
    sum_pixels    += np.nansum(img, axis=(1,2))
    sum_sq_pixels += np.nansum(img**2, axis=(1,2))
    count_pixels  += np.count_nonzero(~np.isnan(img), axis=(1,2))

# A band that is NaN everywhere gets mean 0 / std 0 instead of NaN, and tiny negative
# variances caused by floating-point cancellation are clamped to 0 before the sqrt.
valid_counts = np.maximum(count_pixels, 1)
dataset_mean = sum_pixels / valid_counts
dataset_std  = np.sqrt(np.maximum(sum_sq_pixels / valid_counts - dataset_mean**2, 0.0))

# Helper: extract event ID from a tile path
def get_event_id(p: Path) -> str:
    """Return the event ID encoded in a tile's file name.

    The ID is the part of the file name before the first underscore; a name
    without any underscore is returned whole (extension included).
    """
    event_id, _sep, _rest = p.name.partition('_')
    return event_id

# Event-disjoint split: whole events, smallest first, go to training while the
# running tile count stays within 10% of all tiles; every remaining event goes to
# validation, so no event contributes tiles to both sets.
event_to_paths = {}
for p in all_image_paths:
    ev = get_event_id(p)
    event_to_paths.setdefault(ev, []).append(p)
total_images = len(all_image_paths)
if total_images == 0:
    raise FileNotFoundError(f"No .tif tiles found under {S2_DIR}")
target_count = int(0.10 * total_images)
selected_events = []
count = 0
for ev in sorted(event_to_paths, key=lambda e: len(event_to_paths[e])):
    n_tiles = len(event_to_paths[ev])
    if count + n_tiles > target_count:
        # Events are visited in ascending size, so no later event can fit either.
        break
    selected_events.append(ev)
    count += n_tiles
if not selected_events:
    # An empty training set would otherwise fail later inside the DataLoader sampler.
    smallest = min(len(v) for v in event_to_paths.values())
    raise ValueError(
        f"No event fits the 10% training budget of {target_count} tiles "
        f"(smallest event has {smallest} tiles)"
    )

train_event_set = set(selected_events)  # O(1) membership test for the split below
train_paths = []
val_paths   = []
for ev, paths in event_to_paths.items():
    (train_paths if ev in train_event_set else val_paths).extend(paths)

print(f"Training tiles: {len(train_paths)} ({100*len(train_paths)/total_images:.1f}%); Validation tiles: {len(val_paths)}")

# Dataset definition (fixed transform handling)
class FloodDataset(Dataset):
    """Pairs of Sentinel-2 tiles and binary flood masks.

    Each item is ``(image, mask)``:
      * image: float tensor, normalised per band as ``(x - mean) / (std + 1e-6)``,
        with NaNs replaced by 0 after normalisation;
      * mask: long tensor, 1 where the flood map is > 0 and 0 elsewhere. The mask file
        is looked up in ``FLOOD_DIR`` under the same file name as the image.

    NOTE(review): confirm that "> 0" really means water in these flood maps; some
    flood-map rasters encode land/cloud as non-zero classes.

    Args:
        paths: image paths (``Path`` objects; ``.name`` is used to find the mask).
        mean, std: per-band statistics as 1-D NumPy arrays.
        transform: optional Albumentations pipeline applied to HWC image + mask;
            it is expected to end in ``ToTensorV2`` since its outputs are used as tensors.
    """

    def __init__(self, paths, mean, std, transform=None):
        self.paths = paths
        # Shape (C, 1, 1) so the statistics broadcast over a CHW tensor.
        self.mean = torch.from_numpy(mean).float().view(-1, 1, 1)
        self.std = torch.from_numpy(std + 1e-6).float().view(-1, 1, 1)  # epsilon avoids /0
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        path = self.paths[i]
        image = tifffile.imread(path).astype(np.float32)
        channels_last = image.ndim == 3 and image.shape[-1] < image.shape[0]
        if channels_last:
            image = image.transpose(2, 0, 1)  # HWC -> CHW
        flood = tifffile.imread(FLOOD_DIR / path.name)
        mask = (flood > 0).astype(np.uint8)

        normalised = (torch.from_numpy(image) - self.mean) / self.std
        x = torch.nan_to_num(normalised, nan=0.0)

        if not self.transform:
            return x.float(), torch.from_numpy(mask).long()

        # Albumentations works on HWC NumPy arrays; ToTensorV2 hands back CHW tensors.
        out = self.transform(image=x.permute(1, 2, 0).numpy(), mask=mask)
        return out['image'].float(), out['mask'].long()

# Transforms (resize to 128x128, the SatMAE input size) and data loaders
def _build_transform(with_flips):
    """Return the Albumentations pipeline used by the datasets.

    Steps: resize to 128x128, optionally random horizontal and vertical flips
    (p=0.5 each), then ToTensorV2 (HWC image -> CHW tensor). Image/mask shape
    checking is disabled via ``is_check_shapes=False``.
    """
    steps = [A.Resize(128, 128)]
    if with_flips:
        steps += [A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]
    steps.append(ToTensorV2(transpose_mask=True))
    return A.Compose(steps, is_check_shapes=False)


train_tf = _build_transform(with_flips=True)
val_tf   = _build_transform(with_flips=False)

# Both splits share the same normalisation statistics; only augmentation differs.
train_ds, val_ds = (
    FloodDataset(split_paths, dataset_mean, dataset_std, transform=tf)
    for split_paths, tf in ((train_paths, train_tf), (val_paths, val_tf))
)

loader_opts = {'batch_size': 16, 'num_workers': 2}
train_loader = DataLoader(train_ds, shuffle=True, **loader_opts)
val_loader   = DataLoader(val_ds, shuffle=False, **loader_opts)


Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m825.4/825.4 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.9/981.9 kB[0m [31m42.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m123.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m97.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m53.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90

  0%|          | 0/2675 [00:00<?, ?it/s]

Training tiles: 247 (9.2%); Validation tiles: 2428


In [None]:
import sys, os, subprocess
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import F1Score
import segmentation_models_pytorch as smp

# Clone SatMAE once and patch qk_scale.
# First run only: shallow-clone the repo, then strip every "qk_scale=None, " from
# models_mae_group_channels.py (presumably because the installed timm no longer
# accepts a qk_scale argument — TODO confirm). Re-runs skip both steps.
REPO_DIR = '/content/SatMAE'
if not os.path.isdir(REPO_DIR):
    setup_cmds = (
        ['git', 'clone', '--depth', '1', 'https://github.com/sustainlab-group/SatMAE.git', REPO_DIR],
        ['sed', '-i', "s/qk_scale=None, //g", f'{REPO_DIR}/models_mae_group_channels.py'],
    )
    for cmd in setup_cmds:
        subprocess.run(cmd, check=True)  # check=True: stop the cell on the first failure

# Put the repo root on sys.path (once). NOTE(review): the `SatMAE.` package-style import
# below presumably resolves via the working directory (/content) — confirm.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from SatMAE.models_mae_group_channels import mae_vit_base_patch16_dec512d8b as mae_factory

# ASPP decoder module
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (DeepLabV3-style) decoder.

    Branches, each Conv -> BatchNorm -> ReLU with ``out_ch`` channels:
      * a 1x1 conv,
      * one 3x3 conv per entry of ``rates`` (dilation = padding = rate),
      * global average pooling followed by a 1x1 conv.
    Every branch output is bilinearly resized to the input's spatial size; the
    outputs are concatenated and fused by 1x1 conv + BN + ReLU + Dropout(0.5).

    Note: when a rate is >= the feature map's height and width, every off-centre
    tap of that 3x3 kernel reads zero padding, so the branch acts like a 1x1 conv.
    """

    def __init__(self, in_ch, out_ch, rates):
        super().__init__()

        def conv_bn_relu(kernel_size, dilation=1):
            padding = dilation if kernel_size == 3 else 0
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Creation order (and therefore state_dict keys / init RNG order) matches
        # convs[0]=1x1, convs[1..len(rates)]=dilated, convs[-1]=pooling branch.
        branches = [nn.Sequential(*conv_bn_relu(1))]
        branches.extend(nn.Sequential(*conv_bn_relu(3, rate)) for rate in rates)
        branches.append(nn.Sequential(nn.AdaptiveAvgPool2d(1), *conv_bn_relu(1)))
        self.convs = nn.ModuleList(branches)

        self.project = nn.Sequential(
            nn.Conv2d(out_ch * len(branches), out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        target_size = x.shape[2:]
        pyramid = []
        for branch in self.convs:
            # Only the pooled branch changes size; resizing the others is a no-op.
            pyramid.append(F.interpolate(branch(x), size=target_size, mode='bilinear', align_corners=False))
        return self.project(torch.cat(pyramid, dim=1))

# Wrapper to load SatMAE core model
class SatMAE9_for_loading(pl.LightningModule):
    """Lightning shell holding a 9-band SatMAE ViT-Base MAE as ``self.model``.

    Used only as a ``load_from_checkpoint`` target; presumably it mirrors the
    pre-training module so checkpoint keys prefixed ``model.`` line up — TODO confirm.
    The nine bands are split into three groups of three consecutive bands.
    """

    def __init__(self):
        super().__init__()
        band_groups = [list(range(start, start + 3)) for start in range(0, 9, 3)]
        self.model = mae_factory(channel_groups=band_groups, in_chans=9, img_size=128)

# Fine‑tuning module (encoder frozen) with weighted focal loss
class SatMAE_ASPP_Finetuner(pl.LightningModule):
    """Frozen SatMAE encoder + trainable ASPP decoder for binary water segmentation.

    Patch tokens from the encoder are put back in spatial order, averaged over
    channel groups into a (B, C, H, W) feature map, decoded by ASPP + 1x1 conv and
    upsampled x16 to per-pixel logits.

    Args:
        mae_core: pretrained SatMAE model exposing ``forward_encoder`` and
            ``channel_groups``; it is set to eval mode and all its parameters are frozen.
        lr: Adam learning rate (only ASPP and head parameters are optimised).
        alpha: focal-loss ``alpha`` passed to ``smp.losses.FocalLoss``.
        gamma: focal-loss ``gamma``.
    """

    def __init__(self, mae_core, lr=3e-4, alpha=0.5, gamma=2.0):
        super().__init__()
        self.save_hyperparameters(ignore=['mae_core'])
        self.mae_model = mae_core.eval()
        for p in self.mae_model.parameters():
            p.requires_grad = False
        # Falls back to 768 (ViT-Base width) when the model has no embed_dim attribute.
        embed_dim = getattr(self.mae_model, "embed_dim", 768)
        # NOTE(review): for 128x128 inputs with 16x16 patches the token grid is 8x8, so
        # dilations 12/24/36 exceed it and those ASPP branches reduce to 1x1 convs.
        self.aspp = ASPP(embed_dim, 256, [12,24,36])
        # x16 upsampling undoes the patch tokenisation; assumes patch size 16.
        self.head = nn.Sequential(nn.Conv2d(256,1,1), nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False))
        self.loss_fn = smp.losses.FocalLoss(mode='binary', alpha=alpha, gamma=gamma)
        self.val_f1 = F1Score(task='binary', threshold=0.5)

    @staticmethod
    def _restore_token_order(tokens, ids_restore, num_groups):
        """Undo the token permutation applied by MAE-style ``random_masking``.

        MAE's ``random_masking`` orders tokens by ``argsort`` of random noise even when
        ``mask_ratio == 0`` and returns ``ids_restore``, the inverse permutation.
        Without undoing it, the tokens reshaped into a grid are randomly scrambled.

        Two layouts are handled:
          * ``ids_restore`` of shape (B, N): one shuffle over the flat token sequence;
          * ``ids_restore`` of shape (B, N // num_groups): one spatial shuffle shared
            by all groups, with tokens stored group-major.

        Args:
            tokens: (B, N, C) patch tokens without the cls token.
            ids_restore: inverse permutation returned by ``forward_encoder`` (or None).
            num_groups: number of channel groups G.
        Returns:
            (B, N, C) tokens in their original order.
        Raises:
            ValueError: if ``ids_restore`` fits neither layout.
        """
        if ids_restore is None:
            return tokens
        B, N, C = tokens.shape
        n_ids = ids_restore.shape[1]
        if n_ids == N:
            index = ids_restore.unsqueeze(-1).expand(B, N, C)
            return torch.gather(tokens, 1, index)
        if n_ids * num_groups == N:
            grouped = tokens.reshape(B, num_groups, n_ids, C)
            index = ids_restore.reshape(B, 1, n_ids, 1).expand(B, num_groups, n_ids, C)
            return torch.gather(grouped, 2, index).reshape(B, N, C)
        raise ValueError(
            f"ids_restore with {n_ids} entries matches neither {N} tokens "
            f"nor {N} tokens split over {num_groups} groups"
        )

    def forward(self,x):
        """Return per-pixel water logits for a batch of images."""
        with torch.no_grad():
            tokens, _, ids_restore = self.mae_model.forward_encoder(x, 0.0)
        G = len(self.mae_model.channel_groups)
        # Drop the cls token, then put patch tokens back in spatial order.
        tokens = self._restore_token_order(tokens[:, 1:, :], ids_restore, G)
        B, N, C = tokens.shape
        H = W = int((N // G) ** 0.5)  # square token grid per channel group
        # (B, G*H*W, C) -> (B, G, H, W, C) -> mean over groups -> (B, C, H, W)
        feats = tokens.view(B, G, H, W, C).mean(dim=1).permute(0, 3, 1, 2).contiguous()
        return self.head(self.aspp(feats))

    def training_step(self,b, _):
        x,y=b; y=y.unsqueeze(1).float()
        loss=self.loss_fn(self(x),y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self,b, _):
        x,y=b; y=y.unsqueeze(1).float()
        logits=self(x)
        # Pass probabilities explicitly: torchmetrics only auto-applies sigmoid when some
        # value lies outside [0, 1], so raw logits could otherwise be thresholded as-is.
        self.val_f1.update(torch.sigmoid(logits), y.int())
        self.log('val_loss', self.loss_fn(logits,y))

    def on_validation_epoch_end(self):
        f1=self.val_f1.compute()
        self.log('val_f1_water', f1, prog_bar=True)
        self.val_f1.reset()

    def configure_optimizers(self):
        # Encoder is frozen, so only decoder parameters are handed to the optimiser.
        return torch.optim.Adam(list(self.aspp.parameters())+list(self.head.parameters()), lr=self.hparams.lr)


In [None]:
# Focal-loss alpha = fraction of non-water pixels (mask == 0) over all training masks;
# a pixel counts as water when its mask value is > 0.
total_water = total_nonwater = 0
for p in train_paths:
    m = tifffile.imread(FLOOD_DIR / p.name)
    total_water += (m > 0).sum()
    total_nonwater += (m == 0).sum()
alpha = total_nonwater / (total_water + total_nonwater)

# Load the pretrained SatMAE encoder from its Lightning checkpoint.
# NOTE(review): strict=False silently skips state-dict keys that do not match; if the
# checkpoint's key prefixes differ from `model.*`, the encoder stays randomly initialised.
# Worth verifying once that the weights actually loaded.
PRETRAINED_CKPT = '/content/drive/MyDrive/satmae_pretrain_ckpts/mae9_epoch20.ckpt'
mae_wrapper = SatMAE9_for_loading.load_from_checkpoint(PRETRAINED_CKPT, strict=False)
mae_core = mae_wrapper.model

# Instantiate the fine-tuner (frozen encoder, ASPP decoder) with weighted focal loss
finetuner = SatMAE_ASPP_Finetuner(mae_core, lr=3e-4, alpha=alpha, gamma=2.0)

# Set up callbacks and trainer: keep only the checkpoint with the best validation F1,
# and stop after 50 validation epochs without improvement.
save_dir = Path('/content/drive/MyDrive/satmae_pretrain_ckpts/finetuned_eventwise_models')
save_dir.mkdir(parents=True, exist_ok=True)
early_stop = pl.callbacks.EarlyStopping('val_f1_water', patience=50, mode='max', verbose=True)
ckpt_cb = pl.callbacks.ModelCheckpoint(
    monitor='val_f1_water', mode='max', save_top_k=1,
    dirpath=save_dir, filename='satmae-ASPP-eventwise10p-best-{epoch}-{val_f1_water:.4f}'
)
# accumulate_grad_batches=2 with the loaders' batch_size=16 gives an effective batch of 32.
trainer = pl.Trainer(
    max_epochs=150,
    accelerator='auto',
    devices=1,
    accumulate_grad_batches=2,
    log_every_n_steps=1,
    callbacks=[early_stop, ckpt_cb]
)

# Train
print("🚀 Starting event-wise fine‑tuning on 10% labelled data...")
trainer.fit(finetuner, train_loader, val_loader)
print("🎉 Event-wise fine‑tuning complete!")


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


🚀 Starting event-wise fine‑tuning on 10% labelled data...


/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /content/drive/MyDrive/satmae_pretrain_ckpts/finetuned_eventwise_models exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                             | Params | Mode 
-----------------------------------------------------------------------
0 | mae_model | MaskedAutoencoderGroupChannelViT | 113 M  | eval 
1 | aspp      | ASPP                             | 6.0 M  | train
2 | head      | Sequential                       | 257    | train
3 | loss_fn   | FocalLoss                        | 0      | train
4 | val_f1    | BinaryF1Score                    | 0      | train
-----------------------------------------------------------------------
6.0 M     Trainable params
113 M     Non-trainable params
119 M     Total params
478.868   Total estimated model param

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved. New best score: 0.451


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.132 >= min_delta = 0.0. New best score: 0.583


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.042 >= min_delta = 0.0. New best score: 0.625


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.001 >= min_delta = 0.0. New best score: 0.626


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Monitored metric val_f1_water did not improve in the last 50 records. Best score: 0.626. Signaling Trainer to stop.


🎉 Event-wise fine‑tuning complete!


In [None]:
# ==============================================================================
# Cell: Event-wise fine-tuning with U-Net (ResNet backbone) on 10% data
# ==============================================================================

import segmentation_models_pytorch as smp
import pytorch_lightning as pl
import torch
from torchmetrics import F1Score
from pathlib import Path
import tifffile
import numpy as np

# --- 1. Focal-loss alpha: share of non-water pixels (mask == 0) across all training masks ---
# A pixel counts as water when its mask value is > 0.
total_water = 0
total_nonwater = 0
for p in train_paths:
    mask = tifffile.imread(FLOOD_DIR / p.name)
    water_px = mask > 0
    total_water += water_px.sum()
    total_nonwater += (mask == 0).sum()
all_px = total_water + total_nonwater
alpha = total_nonwater / all_px
print(f"Alpha for FocalLoss (U-Net): {alpha:.3f}")

# --- 2. Define a LightningModule for U-Net with a ResNet backbone ---
class UNetResNetFineTuner(pl.LightningModule):
    """Fully trainable U-Net (ResNet-34 encoder) baseline for binary water segmentation.

    Args:
        in_channels: number of input bands fed to the encoder's first conv.
        lr: Adam learning rate; all parameters are optimised.
        alpha: focal-loss ``alpha`` for ``smp.losses.FocalLoss``. Its default is the
            module-level ``alpha`` as it was when this class statement executed.
        gamma: focal-loss ``gamma``.

    Logs ``train_loss``, ``val_loss`` and, once per validation epoch, ``val_f1_water``.
    """

    def __init__(self, in_channels=9, lr=3e-4, alpha=alpha, gamma=2.0):
        super().__init__()
        self.save_hyperparameters()
        # Encoder is randomly initialised (encoder_weights=None). smp can also adapt
        # ImageNet weights to in_channels != 3 by re-using the first conv's weights,
        # so 9 input bands alone would not rule out a pretrained encoder.
        self.model = smp.Unet(
            encoder_name='resnet34',
            encoder_weights=None,
            in_channels=in_channels,
            classes=1
        )
        # Weighted focal loss to counter the water / non-water class imbalance
        self.loss_fn = smp.losses.FocalLoss(mode='binary', alpha=alpha, gamma=gamma)
        self.val_f1 = F1Score(task='binary', threshold=0.5)

    def forward(self, x):
        """Return single-channel water logits from the U-Net."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.unsqueeze(1).float()  # (B, H, W) -> (B, 1, H, W) to match the logits
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.unsqueeze(1).float()
        logits = self(x)
        # Pass probabilities explicitly: torchmetrics only auto-applies sigmoid when some
        # value lies outside [0, 1], so raw logits could otherwise be thresholded as-is.
        self.val_f1.update(torch.sigmoid(logits), y.int())
        self.log('val_loss', self.loss_fn(logits, y))

    def on_validation_epoch_end(self):
        f1 = self.val_f1.compute()
        self.log('val_f1_water', f1, prog_bar=True)
        self.val_f1.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# --- 3. Instantiate the model and set up trainer ---
unet_finetuner = UNetResNetFineTuner(in_channels=9, lr=3e-4, alpha=alpha, gamma=2.0)

# Callbacks: early stopping and checkpointing
save_dir = Path('/content/drive/MyDrive/satmae_pretrain_ckpts/finetuned_eventwise_unet')
save_dir.mkdir(parents=True, exist_ok=True)
early_stop = pl.callbacks.EarlyStopping('val_f1_water', patience=30, mode='max', verbose=True)
ckpt_cb = pl.callbacks.ModelCheckpoint(
    monitor='val_f1_water', mode='max', save_top_k=1,
    dirpath=save_dir,
    filename='unet-resnet-eventwise10p-best-{epoch}-{val_f1_water:.4f}'
)

# Trainer
trainer = pl.Trainer(
    max_epochs=30,              # U-Net often converges within ~100 epochs on small datasets
    accelerator='auto',
    devices=1,
    accumulate_grad_batches=2,   # effective batch size = 32
    log_every_n_steps=1,
    callbacks=[early_stop, ckpt_cb],
)

print("\n🚀 Starting U-Net (ResNet) fine‑tuning on 10% event‑wise data...")
trainer.fit(unet_finetuner, train_loader, val_loader)
print("\n🎉 U-Net fine‑tuning complete!")


Alpha for FocalLoss (U-Net): 0.640


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /content/drive/MyDrive/satmae_pretrain_ckpts/finetuned_eventwise_unet exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type          | Params | Mode 
--------------------------------------------------
0 | model   | Unet          | 24.5 M | train
1 | loss_fn | FocalLoss     | 0      | train
2 | val_f1  | BinaryF1Score | 0      | train
--------------------------------------------------
24.5 M    Trainable params
0         Non-trainable params
24.5 M    Total params
97.821    Total estimated model params size (MB)
190


🚀 Starting U-Net (ResNet) fine‑tuning on 10% event‑wise data...


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved. New best score: 0.652


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.050 >= min_delta = 0.0. New best score: 0.702


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.044 >= min_delta = 0.0. New best score: 0.746


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.022 >= min_delta = 0.0. New best score: 0.767


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.006 >= min_delta = 0.0. New best score: 0.774


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1_water improved by 0.001 >= min_delta = 0.0. New best score: 0.774


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.



🎉 U-Net fine‑tuning complete!
