In [None]:

!nvidia-smi

In [None]:
import os
import glob
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import matplotlib.pyplot as plt
import random
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Subset

In [None]:
def seed_everything(seed: int = 42) -> None:
    os.environ["PYTHONHASHSEED"] = str(seed)           
    random.seed(seed)                                 
    np.random.seed(seed)                              
    torch.manual_seed(seed)                           
    torch.cuda.manual_seed(seed)                       
    torch.cuda.manual_seed_all(seed)                   
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True         
    torch.backends.cudnn.benchmark = False 

In [None]:

# Configuration


dataset_path = "./final_dataset"               # root containing train/ & test/
checkpoint_path = "./sam_vit_h_4b8939.pth"     # SAM ViT‑H checkpoint
model_type = "vit_h"                           # sam encoder type
output_dir = "./predictions"                   # where to save models & figs
os.makedirs(output_dir, exist_ok=True)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:

# Dataset


class RetinalHemorrhageDataset(Dataset):
    """Dataset loading RGB fundus images and binary masks."""

    def __init__(self, root_dir: str, split: str = "train", transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.images_dir = os.path.join(root_dir, split, "images")
        self.masks_dir = os.path.join(root_dir, split, "masks")
        self.image_files = sorted(glob.glob(os.path.join(self.images_dir, "*.*")))
        self.mask_files = sorted(glob.glob(os.path.join(self.masks_dir, "*.*")))

        assert len(self.image_files) == len(self.mask_files), (
            f"Number of images ({len(self.image_files)}) and masks "
            f"({len(self.mask_files)}) don't match!"
        )
        print(f"Found {len(self.image_files)} samples in '{split}' split")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Image
        img_path = self.image_files[idx]
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Mask
        msk_path = self.mask_files[idx]
        msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        if msk is None:
            raise ValueError(f"Failed to load mask: {msk_path}")
        msk = (msk > 0).astype(np.float32)

        # To tensor
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0  # C,H,W
        msk = torch.from_numpy(msk).unsqueeze(0).float()              # 1,H,W
        return {"image": img, "mask": msk, "filename": os.path.basename(img_path)}


# Loss Function



class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, gamma=0.75, eps=1e-6):
        super(FocalTverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.eps = eps

    def forward(self, inputs, targets):
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()

        tversky = (TP + self.eps) / (TP + self.alpha * FP + self.beta * FN + self.eps)
        focal_tversky = (1 - tversky) ** self.gamma

        return focal_tversky


In [None]:

# U‑Net style decoder blocks


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)

class UNetDecoder(nn.Module):

    def __init__(self, in_ch: int = 256):
        super().__init__()
        self.block1 = DoubleConv(in_ch, 256)
        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.up3 = UpBlock(64, 32)
        self.up4 = UpBlock(32, 16)
        self.final = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):
        x = self.block1(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)
        return self.final(x)


# Fine‑tuner combining SAM encoder + U‑Net decoder

class SAMFineTuner(nn.Module):
    def __init__(self, checkpoint_path: str, model_type: str):
        super().__init__()
        print(f"Loading SAM encoder '{model_type}' …")
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam.to(device)
        print("SAM encoder loaded.")

        # Freeze encoder parameters
        for p in self.sam.image_encoder.parameters():
            p.requires_grad = False

        # Decoder
        self.decoder = UNetDecoder(in_ch=256)

        # Pre‑processing helper
        self.transform = ResizeLongestSide(self.sam.image_encoder.img_size)
        self.pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
        self.pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)

    @torch.no_grad()
    def preprocess_single(self, img_tensor: torch.Tensor) -> torch.Tensor:
       
        img_np = (img_tensor.cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.float32)
        img_np = self.transform.apply_image(img_np)
        t = torch.from_numpy(img_np).permute(2, 0, 1).float()
        t = (t - self.pixel_mean) / self.pixel_std
        return t

    def forward(self, imgs: torch.Tensor):  # imgs: B,C,H,W in [0,1]
        B, _, H, W = imgs.shape
        processed = torch.stack([self.preprocess_single(im) for im in imgs]).to(device)

        with torch.no_grad():
            emb = self.sam.image_encoder(processed)  # B,256,h,w

        dec_out = self.decoder(emb)                 # B,1,h',w'
        dec_out = F.interpolate(dec_out, size=(H, W), mode="bilinear", align_corners=False)
        return torch.sigmoid(dec_out)

In [None]:

# Metrics


def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, thr: float = 0.5):
    y_pred_bin = (y_pred > thr).astype(np.uint8)
    y_true = y_true.astype(np.uint8)
    tp = np.sum((y_pred_bin == 1) & (y_true == 1))
    fp = np.sum((y_pred_bin == 1) & (y_true == 0))
    fn = np.sum((y_pred_bin == 0) & (y_true == 1))
    tn = np.sum((y_pred_bin == 0) & (y_true == 0))

    acc = (tp + tn) / (tp + fp + fn + tn + 1e-6)
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    iou = tp / (tp + fp + fn + 1e-6)
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1_score": f1, "jaccard": iou}


# Train / Eval loops


def train_epoch(model, loader, criterion, optimizer, epoch):
    model.train()
    total_loss = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for batch in pbar:
        imgs = batch["image"].to(device)
        msks = batch["mask"].to(device)
        preds = model(imgs)
        loss = criterion(preds, msks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    return total_loss / len(loader)

def evaluate(model, loader, criterion, save_preds: bool = False):
    model.eval()
    val_loss = 0.0
    stats = {k: [] for k in ["accuracy", "precision", "recall", "f1_score", "jaccard"]}
    pbar = tqdm(loader, desc="Eval")
    with torch.no_grad():
        for batch in pbar:
            imgs = batch["image"].to(device)
            msks = batch["mask"].to(device)
            fnames = batch["filename"]
            preds = model(imgs)
            loss = criterion(preds, msks)
            val_loss += loss.item()
            preds_np = preds.cpu().numpy()
            msks_np = msks.cpu().numpy()
            for i in range(len(imgs)):
                metrics = compute_metrics(msks_np[i, 0], preds_np[i, 0])
                for k in stats:
                    stats[k].append(metrics[k])
                if save_preds:
                    save_visualization(imgs[i], msks_np[i, 0], preds_np[i, 0], metrics, fnames[i])
    avg_stats = {k: float(np.mean(v)) for k, v in stats.items()}
    return val_loss / len(loader), avg_stats


# Helper to save side‑by‑side visualisations


def save_visualization(img_t, true_msk, pred_msk, metrics, fname):
    img = (img_t.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
    pred = (pred_msk > 0.5).astype(np.uint8) * 255
    true = (true_msk > 0.5).astype(np.uint8) * 255

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img)
    axes[0].set_title("Image")
    axes[0].axis("off")
    axes[1].imshow(true, cmap="gray")
    axes[1].set_title("Ground Truth")
    axes[1].axis("off")
    axes[2].imshow(pred, cmap="gray")
    axes[2].set_title("Prediction")
    axes[2].axis("off")

    txt = (
        f"IoU={metrics['jaccard']:.3f} | F1={metrics['f1_score']:.3f} | "
        f"Prec={metrics['precision']:.3f} | Rec={metrics['recall']:.3f}"
    )
    
    # Add more space at the bottom for the text
    plt.subplots_adjust(bottom=0.2)
    fig.text(0.5, 0.05, txt, ha="center", fontsize=10)

    out_path = os.path.join(output_dir, f"{os.path.splitext(fname)[0]}_viz.png")
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close()


In [None]:
# Seeding
seed_everything(42)

# Data
train_ds = RetinalHemorrhageDataset(dataset_path, split="train")
full_test_ds = RetinalHemorrhageDataset(dataset_path, split="test")

# Split test dataset into fixed 50% validation and 50% test
val_size = len(full_test_ds) // 2
test_size = len(full_test_ds) - val_size
val_ds, test_ds = torch.utils.data.random_split(full_test_ds, [val_size, test_size], generator=torch.Generator().manual_seed(42))

# Dataloaders
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

# Model / loss / optim
model = SAMFineTuner(checkpoint_path, model_type).to(device)
criterion = FocalTverskyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=4,
    threshold=1e-4,
    verbose=True,
)


In [9]:
# Training loop with early stopping
num_epochs = 50
best_f1 = -1.0
patience = 10
no_improve_epochs = 0

for epoch in range(1, num_epochs + 1):
    tr_loss = train_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss, metrics = evaluate(model, val_loader, criterion)
    scheduler.step(val_loss)
    lr_now = optimizer.param_groups[0]["lr"]

    print(
        f"Epoch {epoch}/{num_epochs} | "
        f"Train Loss: {tr_loss:.4f} | Val Loss: {val_loss:.4f} | "
        f"IoU: {metrics['jaccard']:.4f} | "
        f"F1: {metrics['f1_score']:.4f} | "
        f"Precision: {metrics['precision']:.4f} | "
        f"Recall: {metrics['recall']:.4f} | "
        f"Accuracy: {metrics['accuracy']:.4f}"
    )

    if metrics["f1_score"] > best_f1:
        best_f1 = metrics["f1_score"]
        no_improve_epochs = 0  # reset counter
        torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
        print(f"New best model saved (F1: {best_f1:.4f})")
    else:
        no_improve_epochs += 1
        print(f"No improvement in F1 for {no_improve_epochs} epoch(s)")

    print(f"Current lr: {lr_now}")

    if no_improve_epochs >= patience:
        print("Early stopping triggered.")
        break


Epoch 1: 100%|██████████| 477/477 [10:24<00:00,  1.31s/it, loss=0.8608]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 1/50 | Train Loss: 0.9607 | Val Loss: 0.9527 | IoU: 0.0235 | F1: 0.0426 | Precision: 0.0237 | Recall: 0.8766 | Accuracy: 0.4044
New best model saved (F1: 0.0426)
Current lr: 0.0001


Epoch 2: 100%|██████████| 477/477 [10:00<00:00,  1.26s/it, loss=0.8793]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 2/50 | Train Loss: 0.9511 | Val Loss: 0.9456 | IoU: 0.0317 | F1: 0.0565 | Precision: 0.0337 | Recall: 0.8127 | Accuracy: 0.6308
New best model saved (F1: 0.0565)
Current lr: 0.0001


Epoch 3: 100%|██████████| 477/477 [10:18<00:00,  1.30s/it, loss=0.9661]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 3/50 | Train Loss: 0.9426 | Val Loss: 0.9436 | IoU: 0.0451 | F1: 0.0815 | Precision: 0.0568 | Recall: 0.6371 | Accuracy: 0.8542
New best model saved (F1: 0.0815)
Current lr: 0.0001


Epoch 4: 100%|██████████| 477/477 [10:07<00:00,  1.27s/it, loss=0.7911]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 4/50 | Train Loss: 0.9299 | Val Loss: 0.9295 | IoU: 0.0807 | F1: 0.1406 | Precision: 0.1101 | Recall: 0.4863 | Accuracy: 0.9492
New best model saved (F1: 0.1406)
Current lr: 0.0001


Epoch 5: 100%|██████████| 477/477 [10:11<00:00,  1.28s/it, loss=0.9121]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 5/50 | Train Loss: 0.9138 | Val Loss: 0.9192 | IoU: 0.0861 | F1: 0.1469 | Precision: 0.1075 | Recall: 0.5549 | Accuracy: 0.9422
New best model saved (F1: 0.1469)
Current lr: 0.0001


Epoch 6: 100%|██████████| 477/477 [10:14<00:00,  1.29s/it, loss=0.9506]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


Epoch 6/50 | Train Loss: 0.8840 | Val Loss: 0.9100 | IoU: 0.0801 | F1: 0.1397 | Precision: 0.1261 | Recall: 0.4396 | Accuracy: 0.9485
No improvement in F1 for 1 epoch(s)
Current lr: 0.0001


Epoch 7: 100%|██████████| 477/477 [10:34<00:00,  1.33s/it, loss=0.9868]
Eval: 100%|██████████| 13/13 [00:14<00:00,  1.08s/it]


Epoch 7/50 | Train Loss: 0.8421 | Val Loss: 0.8679 | IoU: 0.1178 | F1: 0.1975 | Precision: 0.1657 | Recall: 0.4191 | Accuracy: 0.9683
New best model saved (F1: 0.1975)
Current lr: 0.0001


Epoch 8: 100%|██████████| 477/477 [10:43<00:00,  1.35s/it, loss=0.5812]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 8/50 | Train Loss: 0.7893 | Val Loss: 0.8607 | IoU: 0.1219 | F1: 0.2011 | Precision: 0.1999 | Recall: 0.3077 | Accuracy: 0.9765
New best model saved (F1: 0.2011)
Current lr: 0.0001


Epoch 9: 100%|██████████| 477/477 [09:56<00:00,  1.25s/it, loss=0.6088]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 9/50 | Train Loss: 0.7370 | Val Loss: 0.8259 | IoU: 0.1676 | F1: 0.2701 | Precision: 0.3174 | Recall: 0.2923 | Accuracy: 0.9827
New best model saved (F1: 0.2701)
Current lr: 0.0001


Epoch 10: 100%|██████████| 477/477 [10:41<00:00,  1.35s/it, loss=0.6479]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 10/50 | Train Loss: 0.6916 | Val Loss: 0.8257 | IoU: 0.1535 | F1: 0.2523 | Precision: 0.3155 | Recall: 0.2704 | Accuracy: 0.9828
No improvement in F1 for 1 epoch(s)
Current lr: 0.0001


Epoch 11: 100%|██████████| 477/477 [10:44<00:00,  1.35s/it, loss=0.8297]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


Epoch 11/50 | Train Loss: 0.6547 | Val Loss: 0.8188 | IoU: 0.1460 | F1: 0.2392 | Precision: 0.2647 | Recall: 0.2839 | Accuracy: 0.9806
No improvement in F1 for 2 epoch(s)
Current lr: 0.0001


Epoch 12: 100%|██████████| 477/477 [10:26<00:00,  1.31s/it, loss=0.7531]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 12/50 | Train Loss: 0.6315 | Val Loss: 0.8248 | IoU: 0.1623 | F1: 0.2646 | Precision: 0.4222 | Recall: 0.2122 | Accuracy: 0.9846
No improvement in F1 for 3 epoch(s)
Current lr: 0.0001


Epoch 13: 100%|██████████| 477/477 [09:51<00:00,  1.24s/it, loss=0.7219]
Eval: 100%|██████████| 13/13 [00:14<00:00,  1.12s/it]


Epoch 13/50 | Train Loss: 0.6130 | Val Loss: 0.8124 | IoU: 0.1506 | F1: 0.2478 | Precision: 0.3104 | Recall: 0.2575 | Accuracy: 0.9826
No improvement in F1 for 4 epoch(s)
Current lr: 0.0001


Epoch 14: 100%|██████████| 477/477 [10:44<00:00,  1.35s/it, loss=0.5274]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 14/50 | Train Loss: 0.5886 | Val Loss: 0.8178 | IoU: 0.1546 | F1: 0.2539 | Precision: 0.3761 | Recall: 0.2253 | Accuracy: 0.9838
No improvement in F1 for 5 epoch(s)
Current lr: 0.0001


Epoch 15: 100%|██████████| 477/477 [10:39<00:00,  1.34s/it, loss=0.4254]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 15/50 | Train Loss: 0.5816 | Val Loss: 0.7867 | IoU: 0.1878 | F1: 0.2945 | Precision: 0.4456 | Recall: 0.2532 | Accuracy: 0.9846
New best model saved (F1: 0.2945)
Current lr: 0.0001


Epoch 16: 100%|██████████| 477/477 [10:02<00:00,  1.26s/it, loss=0.3820]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 16/50 | Train Loss: 0.5659 | Val Loss: 0.7779 | IoU: 0.1701 | F1: 0.2735 | Precision: 0.2902 | Recall: 0.3264 | Accuracy: 0.9816
No improvement in F1 for 1 epoch(s)
Current lr: 0.0001


Epoch 17: 100%|██████████| 477/477 [10:35<00:00,  1.33s/it, loss=0.4221]
Eval: 100%|██████████| 13/13 [00:14<00:00,  1.09s/it]


Epoch 17/50 | Train Loss: 0.5590 | Val Loss: 0.7772 | IoU: 0.1860 | F1: 0.2922 | Precision: 0.3788 | Recall: 0.2832 | Accuracy: 0.9838
No improvement in F1 for 2 epoch(s)
Current lr: 0.0001


Epoch 18: 100%|██████████| 477/477 [10:44<00:00,  1.35s/it, loss=0.5690]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.06s/it]


Epoch 18/50 | Train Loss: 0.5478 | Val Loss: 0.7630 | IoU: 0.1997 | F1: 0.3042 | Precision: 0.3876 | Recall: 0.2927 | Accuracy: 0.9838
New best model saved (F1: 0.3042)
Current lr: 0.0001


Epoch 19: 100%|██████████| 477/477 [11:17<00:00,  1.42s/it, loss=0.7941]
Eval: 100%|██████████| 13/13 [00:14<00:00,  1.10s/it]


Epoch 19/50 | Train Loss: 0.5517 | Val Loss: 0.7898 | IoU: 0.1597 | F1: 0.2597 | Precision: 0.3026 | Recall: 0.3154 | Accuracy: 0.9820
No improvement in F1 for 1 epoch(s)
Current lr: 0.0001


Epoch 20: 100%|██████████| 477/477 [10:40<00:00,  1.34s/it, loss=0.2862]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 20/50 | Train Loss: 0.5387 | Val Loss: 0.7946 | IoU: 0.1689 | F1: 0.2744 | Precision: 0.3803 | Recall: 0.2551 | Accuracy: 0.9840
No improvement in F1 for 2 epoch(s)
Current lr: 0.0001


Epoch 21: 100%|██████████| 477/477 [10:39<00:00,  1.34s/it, loss=0.4624]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 21/50 | Train Loss: 0.5389 | Val Loss: 0.7810 | IoU: 0.1786 | F1: 0.2737 | Precision: 0.3339 | Recall: 0.2884 | Accuracy: 0.9832
No improvement in F1 for 3 epoch(s)
Current lr: 0.0001


Epoch 22: 100%|██████████| 477/477 [10:43<00:00,  1.35s/it, loss=0.4823]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 22/50 | Train Loss: 0.5332 | Val Loss: 0.7788 | IoU: 0.1854 | F1: 0.2881 | Precision: 0.3713 | Recall: 0.2718 | Accuracy: 0.9838
No improvement in F1 for 4 epoch(s)
Current lr: 0.0001


Epoch 23: 100%|██████████| 477/477 [10:58<00:00,  1.38s/it, loss=0.6493]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.06s/it]


Epoch 23/50 | Train Loss: 0.5206 | Val Loss: 0.7677 | IoU: 0.1926 | F1: 0.3016 | Precision: 0.3920 | Recall: 0.2890 | Accuracy: 0.9849
No improvement in F1 for 5 epoch(s)
Current lr: 1e-05


Epoch 24: 100%|██████████| 477/477 [10:46<00:00,  1.35s/it, loss=0.6201]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 24/50 | Train Loss: 0.4909 | Val Loss: 0.7659 | IoU: 0.1985 | F1: 0.3079 | Precision: 0.4003 | Recall: 0.2837 | Accuracy: 0.9846
New best model saved (F1: 0.3079)
Current lr: 1e-05


Epoch 25: 100%|██████████| 477/477 [10:39<00:00,  1.34s/it, loss=0.4517]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 25/50 | Train Loss: 0.4782 | Val Loss: 0.7726 | IoU: 0.1906 | F1: 0.3002 | Precision: 0.3993 | Recall: 0.2783 | Accuracy: 0.9848
No improvement in F1 for 1 epoch(s)
Current lr: 1e-05


Epoch 26: 100%|██████████| 477/477 [10:40<00:00,  1.34s/it, loss=0.7832]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 26/50 | Train Loss: 0.4708 | Val Loss: 0.7709 | IoU: 0.1940 | F1: 0.3017 | Precision: 0.4009 | Recall: 0.2772 | Accuracy: 0.9846
No improvement in F1 for 2 epoch(s)
Current lr: 1e-05


Epoch 27: 100%|██████████| 477/477 [10:36<00:00,  1.33s/it, loss=0.5727]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 27/50 | Train Loss: 0.4656 | Val Loss: 0.7676 | IoU: 0.1994 | F1: 0.3086 | Precision: 0.4146 | Recall: 0.2755 | Accuracy: 0.9845
New best model saved (F1: 0.3086)
Current lr: 1e-05


Epoch 28: 100%|██████████| 477/477 [10:40<00:00,  1.34s/it, loss=0.4837]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 28/50 | Train Loss: 0.4621 | Val Loss: 0.7720 | IoU: 0.1981 | F1: 0.3071 | Precision: 0.4262 | Recall: 0.2643 | Accuracy: 0.9847
No improvement in F1 for 1 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 29: 100%|██████████| 477/477 [10:39<00:00,  1.34s/it, loss=0.3347]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 29/50 | Train Loss: 0.4574 | Val Loss: 0.7644 | IoU: 0.2033 | F1: 0.3139 | Precision: 0.4297 | Recall: 0.2772 | Accuracy: 0.9847
New best model saved (F1: 0.3139)
Current lr: 1.0000000000000002e-06


Epoch 30: 100%|██████████| 477/477 [10:40<00:00,  1.34s/it, loss=0.5149]
Eval: 100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 30/50 | Train Loss: 0.4558 | Val Loss: 0.7669 | IoU: 0.1983 | F1: 0.3061 | Precision: 0.3969 | Recall: 0.2802 | Accuracy: 0.9845
No improvement in F1 for 1 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 31: 100%|██████████| 477/477 [10:44<00:00,  1.35s/it, loss=0.7662]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 31/50 | Train Loss: 0.4549 | Val Loss: 0.7693 | IoU: 0.2006 | F1: 0.3091 | Precision: 0.4277 | Recall: 0.2702 | Accuracy: 0.9847
No improvement in F1 for 2 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 32: 100%|██████████| 477/477 [09:58<00:00,  1.25s/it, loss=0.2576]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 32/50 | Train Loss: 0.4541 | Val Loss: 0.7743 | IoU: 0.1970 | F1: 0.3047 | Precision: 0.4243 | Recall: 0.2627 | Accuracy: 0.9846
No improvement in F1 for 3 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 33: 100%|██████████| 477/477 [10:42<00:00,  1.35s/it, loss=0.5279]
Eval: 100%|██████████| 13/13 [00:15<00:00,  1.16s/it]


Epoch 33/50 | Train Loss: 0.4534 | Val Loss: 0.7622 | IoU: 0.2015 | F1: 0.3109 | Precision: 0.4047 | Recall: 0.2882 | Accuracy: 0.9843
No improvement in F1 for 4 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 34: 100%|██████████| 477/477 [10:32<00:00,  1.33s/it, loss=0.7174]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 34/50 | Train Loss: 0.4527 | Val Loss: 0.7771 | IoU: 0.1956 | F1: 0.3044 | Precision: 0.4485 | Recall: 0.2552 | Accuracy: 0.9848
No improvement in F1 for 5 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 35: 100%|██████████| 477/477 [09:52<00:00,  1.24s/it, loss=0.3691]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 35/50 | Train Loss: 0.4521 | Val Loss: 0.7557 | IoU: 0.2066 | F1: 0.3164 | Precision: 0.3990 | Recall: 0.2993 | Accuracy: 0.9843
New best model saved (F1: 0.3164)
Current lr: 1.0000000000000002e-06


Epoch 36: 100%|██████████| 477/477 [09:49<00:00,  1.24s/it, loss=0.5754]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 36/50 | Train Loss: 0.4514 | Val Loss: 0.7688 | IoU: 0.1997 | F1: 0.3085 | Precision: 0.4189 | Recall: 0.2735 | Accuracy: 0.9845
No improvement in F1 for 1 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 37: 100%|██████████| 477/477 [09:49<00:00,  1.24s/it, loss=0.5774]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 37/50 | Train Loss: 0.4508 | Val Loss: 0.7721 | IoU: 0.1948 | F1: 0.3028 | Precision: 0.4135 | Recall: 0.2689 | Accuracy: 0.9845
No improvement in F1 for 2 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 38: 100%|██████████| 477/477 [10:03<00:00,  1.27s/it, loss=0.5086]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 38/50 | Train Loss: 0.4503 | Val Loss: 0.7726 | IoU: 0.1965 | F1: 0.3062 | Precision: 0.4378 | Recall: 0.2673 | Accuracy: 0.9846
No improvement in F1 for 3 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 39: 100%|██████████| 477/477 [09:48<00:00,  1.23s/it, loss=0.5295]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 39/50 | Train Loss: 0.4498 | Val Loss: 0.7598 | IoU: 0.2037 | F1: 0.3130 | Precision: 0.4037 | Recall: 0.2926 | Accuracy: 0.9844
No improvement in F1 for 4 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 40: 100%|██████████| 477/477 [09:48<00:00,  1.23s/it, loss=0.5494]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 40/50 | Train Loss: 0.4492 | Val Loss: 0.7689 | IoU: 0.1986 | F1: 0.3075 | Precision: 0.4205 | Recall: 0.2710 | Accuracy: 0.9846
No improvement in F1 for 5 epoch(s)
Current lr: 1.0000000000000002e-07


Epoch 41: 100%|██████████| 477/477 [09:48<00:00,  1.23s/it, loss=0.4504]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 41/50 | Train Loss: 0.4485 | Val Loss: 0.7698 | IoU: 0.1985 | F1: 0.3059 | Precision: 0.4171 | Recall: 0.2722 | Accuracy: 0.9845
No improvement in F1 for 6 epoch(s)
Current lr: 1.0000000000000002e-07


Epoch 42: 100%|██████████| 477/477 [09:48<00:00,  1.23s/it, loss=0.4136]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 42/50 | Train Loss: 0.4484 | Val Loss: 0.7829 | IoU: 0.1901 | F1: 0.2984 | Precision: 0.4382 | Recall: 0.2478 | Accuracy: 0.9847
No improvement in F1 for 7 epoch(s)
Current lr: 1.0000000000000002e-07


Epoch 43: 100%|██████████| 477/477 [09:48<00:00,  1.23s/it, loss=0.3470]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 43/50 | Train Loss: 0.4484 | Val Loss: 0.7725 | IoU: 0.1956 | F1: 0.3045 | Precision: 0.4267 | Recall: 0.2639 | Accuracy: 0.9846
No improvement in F1 for 8 epoch(s)
Current lr: 1.0000000000000002e-07


Epoch 44: 100%|██████████| 477/477 [09:48<00:00,  1.23s/it, loss=0.2444]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 44/50 | Train Loss: 0.4483 | Val Loss: 0.7805 | IoU: 0.1903 | F1: 0.2973 | Precision: 0.4326 | Recall: 0.2520 | Accuracy: 0.9847
No improvement in F1 for 9 epoch(s)
Current lr: 1.0000000000000002e-07


Epoch 45: 100%|██████████| 477/477 [09:48<00:00,  1.23s/it, loss=0.3082]
Eval: 100%|██████████| 13/13 [00:12<00:00,  1.05it/s]

Epoch 45/50 | Train Loss: 0.4482 | Val Loss: 0.7668 | IoU: 0.2011 | F1: 0.3097 | Precision: 0.4185 | Recall: 0.2741 | Accuracy: 0.9844
No improvement in F1 for 10 epoch(s)
Current lr: 1.0000000000000004e-08
Early stopping triggered.





In [None]:
# Final evaluation with visualisations
print("Generating visualizations for best model")
model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))
_, final_metrics = evaluate(model, test_loader, criterion, save_preds=True)
print("Final Test Metrics:")
for k, v in final_metrics.items():
    print(f"  {k.capitalize()}: {v:.4f}")
print(f"Predictions & visualizations saved to: {output_dir}")

Generating visualizations for best model


Eval: 100%|██████████| 14/14 [00:18<00:00,  1.30s/it]

Final Test Metrics:
  Accuracy: 0.9891
  Precision: 0.4709
  Recall: 0.2234
  F1_score: 0.2789
  Jaccard: 0.1720
Predictions & visualizations saved to: ./predictions



