In [2]:
import os
import glob
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import matplotlib.pyplot as plt
import random
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Subset

In [3]:
def seed_everything(seed: int = 42) -> None:
    """Seed every RNG used in this notebook and request deterministic kernels.

    Args:
        seed: value applied to Python's ``random``, NumPy's legacy global RNG
            and torch's CPU/CUDA generators.

    Note:
        ``PYTHONHASHSEED`` only takes effect for interpreters started after it
        is set (e.g. subprocess workers); it cannot change the hash seed of the
        kernel that is already running.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # cuBLAS is only deterministic with a fixed workspace configuration; without
    # it use_deterministic_algorithms(..., warn_only=True) merely warns. It must
    # be set before the first cuBLAS call; setdefault respects a user-exported value.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # all visible GPUs; no-op without CUDA
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:

# Configuration
# All paths are relative to the notebook's working directory.
dataset_path = "./final_dataset"              # root holding the train/ and test/ splits
checkpoint_path = "../sam_vit_h_4b8939.pth"   # pretrained SAM ViT-H weights
model_type = "vit_h"                          # key into sam_model_registry
output_dir = "./predictions"                  # destination for checkpoints and figures
os.makedirs(output_dir, exist_ok=True)

# Device: use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:

# Dataset

class RetinalHemorrhageDataset(Dataset):
    """RGB fundus images paired with binary hemorrhage masks.

    Expected layout: ``<root_dir>/<split>/images`` and ``<root_dir>/<split>/masks``.
    Images and masks are paired purely by position after sorting each folder's
    file list, so the two folders must sort into the same order (e.g. files
    share stems). NOTE(review): mispaired names are not detected here.

    Each item is a dict:
        "image":    float32 tensor (3, H, W), RGB, scaled to [0, 1]
        "mask":     float32 tensor (1, H, W), 1.0 where the mask pixel is non-zero
        "filename": basename of the image file
    """

    def __init__(self, root_dir: str, split: str = "train", transform=None):
        """
        Args:
            root_dir: dataset root containing one sub-folder per split.
            split: split sub-folder name, e.g. "train" or "test".
            transform: optional callable applied to each sample dict (shape
                described in the class docstring); it must return a dict with
                the same keys.

        Raises:
            AssertionError: if the number of images and masks differs.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.images_dir = os.path.join(root_dir, split, "images")
        self.masks_dir = os.path.join(root_dir, split, "masks")
        self.image_files = sorted(glob.glob(os.path.join(self.images_dir, "*.*")))
        self.mask_files = sorted(glob.glob(os.path.join(self.masks_dir, "*.*")))

        assert len(self.image_files) == len(self.mask_files), (
            f"Number of images ({len(self.image_files)}) and masks "
            f"({len(self.mask_files)}) don't match!"
        )
        print(f"Found {len(self.image_files)} samples in '{split}' split")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, binarise and tensorise sample ``idx``.

        Raises:
            ValueError: if a file cannot be decoded, or the mask's height/width
                differ from the image's (the loss would otherwise fail later
                with an opaque broadcasting error).
        """
        # Image (OpenCV decodes BGR; convert to RGB)
        img_path = self.image_files[idx]
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Mask: any non-zero grey value counts as foreground
        msk_path = self.mask_files[idx]
        msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        if msk is None:
            raise ValueError(f"Failed to load mask: {msk_path}")
        if msk.shape != img.shape[:2]:
            raise ValueError(
                f"Mask {msk_path} has shape {msk.shape}, expected {img.shape[:2]} "
                f"to match image {img_path}"
            )
        msk = (msk > 0).astype(np.float32)

        # To tensor
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0  # C,H,W
        msk = torch.from_numpy(msk).unsqueeze(0).float()              # 1,H,W
        sample = {"image": img, "mask": msk, "filename": os.path.basename(img_path)}

        # Previously `transform` was accepted but silently ignored.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


# Loss

class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss for binary segmentation on probability maps.

    Tversky index: TI = (TP + eps) / (TP + alpha*FP + beta*FN + eps), computed
    over every element of the batch at once (not per image); the loss is
    (1 - TI) ** gamma. With beta > alpha, false negatives cost more than false
    positives, which favours recall.

    NOTE: with the default tiny ``eps``, a sample whose target is all zeros has
    TP == 0 and FN == 0, so TI = eps / (alpha*FP + eps) ~ 0 and the loss
    saturates near 1 for any non-trivial FP mass, with almost no gradient.
    A larger ``eps`` (e.g. 1.0) gives such samples a usable signal.
    """

    def __init__(self, alpha=0.3, beta=0.7, gamma=0.75, eps=1e-6):
        super(FocalTverskyLoss, self).__init__()
        self.alpha = alpha  # weight on false positives
        self.beta = beta    # weight on false negatives
        self.gamma = gamma  # focal exponent applied to (1 - TI)
        self.eps = eps      # smoothing term in numerator and denominator

    def forward(self, inputs, targets):
        """
        Args:
            inputs: predictions expected in [0, 1] (the model returns sigmoid
                outputs); any shape.
            targets: binary ground truth with the same number of elements.

        Returns:
            Scalar loss tensor.
        """
        # reshape instead of view: accepts non-contiguous tensors as well
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()

        tversky = (TP + self.eps) / (TP + self.alpha * FP + self.beta * FN + self.eps)
        return (1 - tversky) ** self.gamma


In [None]:

# U-Net-style decoder blocks (upsampling path only; there are no skip connections)


class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; H and W are unchanged.

    The layers are kept in ``self.net`` at indices 0-5, so checkpoint keys stay
    ``net.0.weight``, ``net.1.*``, ``net.3.weight`` and ``net.4.*``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class UpBlock(nn.Module):
    """Bilinear 2x spatial upsampling followed by a DoubleConv (in_ch -> out_ch)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        # upsample first, then refine features at the new resolution
        return self.conv(self.upsample(x))

class UNetDecoder(nn.Module):
    """Upsampling decoder: one DoubleConv, four 2x UpBlocks, then a 1x1 head.

    Channels go in_ch -> 256 -> 128 -> 64 -> 32 -> 16 -> 1, and the spatial
    size grows by 2**4 = 16x. Despite the name there are no skip connections;
    the only input is the encoder embedding. Output is raw logits.
    """

    def __init__(self, in_ch: int = 256):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys.
        self.block1 = DoubleConv(in_ch, 256)
        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.up3 = UpBlock(64, 32)
        self.up4 = UpBlock(32, 16)
        self.final = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):
        for stage in (self.block1, self.up1, self.up2, self.up3, self.up4):
            x = stage(x)
        return self.final(x)


# Fine‑tuner combining SAM encoder + U‑Net decoder


class SAMFineTuner(nn.Module):
    """Frozen SAM image encoder followed by a trainable upsampling decoder.

    Images are preprocessed the way SAM expects: longest side resized to the
    encoder's ``img_size``, per-channel normalisation, zero-padding to a square.
    The encoder embedding is decoded to a one-channel map. That map is cropped
    back to the un-padded region, resized to the input H x W and passed
    through a sigmoid.
    """

    def __init__(self, checkpoint_path: str, model_type: str):
        super().__init__()
        print(f"Loading SAM encoder '{model_type}' …")
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam.to(device)  # notebook-level `device` from the config cell
        print("SAM encoder loaded.")

        # Freeze encoder parameters: only the decoder is trained.
        for p in self.sam.image_encoder.parameters():
            p.requires_grad = False

        # Decoder
        self.decoder = UNetDecoder(in_ch=256)

        # Pre-processing helpers. mean/std are in 0-255 pixel units and stay on
        # CPU because preprocessing runs on CPU. They are plain attributes, not
        # buffers, so existing checkpoints' state_dict keys are unchanged.
        self.transform = ResizeLongestSide(self.sam.image_encoder.img_size)
        self.pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
        self.pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)

    @torch.no_grad()
    def preprocess_single(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """Turn one (3, H, W) image in [0, 1] into a padded, normalised SAM input.

        ``ResizeLongestSide.apply_image`` is documented to take HWC uint8, so the
        image is rounded back to uint8 before resizing. After normalisation the
        tensor is zero-padded on the bottom/right to ``img_size`` x ``img_size``,
        matching SAM's own ``preprocess``. Without the padding, non-square
        inputs do not fit the encoder's position embedding.

        Returns:
            CPU float tensor of shape (3, img_size, img_size).
        """
        img_np = img_tensor.detach().cpu().numpy().transpose(1, 2, 0)
        img_u8 = np.clip(np.rint(img_np * 255.0), 0, 255).astype(np.uint8)
        resized = self.transform.apply_image(img_u8)
        t = torch.from_numpy(resized).permute(2, 0, 1).float()
        t = (t - self.pixel_mean) / self.pixel_std

        size = self.sam.image_encoder.img_size
        pad_h = size - t.shape[-2]
        pad_w = size - t.shape[-1]
        return F.pad(t, (0, pad_w, 0, pad_h))

    def forward(self, imgs: torch.Tensor):  # imgs: B,C,H,W in [0,1]
        """Return per-pixel foreground probabilities of shape (B, 1, H, W)."""
        _, _, H, W = imgs.shape
        # Follow the encoder's actual device instead of the global `device`.
        enc_device = next(self.sam.image_encoder.parameters()).device
        processed = torch.stack([self.preprocess_single(im) for im in imgs]).to(enc_device)

        with torch.no_grad():
            emb = self.sam.image_encoder(processed)  # frozen encoder: no graph needed

        dec_out = self.decoder(emb)

        # Drop the padded border before resizing back, as SAM's postprocess does.
        # For square inputs the valid region is the whole map, so this is a no-op.
        size = self.sam.image_encoder.img_size
        new_h, new_w = self.transform.get_preprocess_shape(H, W, size)
        out_h, out_w = dec_out.shape[-2:]
        crop_h = max(1, round(new_h * out_h / size))
        crop_w = max(1, round(new_w * out_w / size))
        dec_out = dec_out[..., :crop_h, :crop_w]

        dec_out = F.interpolate(dec_out, size=(H, W), mode="bilinear", align_corners=False)
        return torch.sigmoid(dec_out)

In [None]:

# Metrics


def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, thr: float = 0.5,
                    empty_score: float = 1.0):
    """Pixel-wise binary segmentation metrics for a single image.

    Args:
        y_true: ground-truth mask; pixels equal to 1 after casting to uint8 are
            foreground.
        y_pred: predicted probabilities; pixels > ``thr`` are foreground.
        thr: binarisation threshold applied to ``y_pred``.
        empty_score: value reported for precision, recall, F1 and IoU when
            neither the ground truth nor the prediction has any foreground, i.e.
            a correct "nothing here" prediction. Without this rule every ratio
            is 0 / eps = 0, penalising a perfect prediction. Pass 0.0 to get
            that behaviour.

    Returns:
        dict with keys "accuracy", "precision", "recall", "f1_score", "jaccard".
        Denominators carry a 1e-6 term to avoid division by zero.
    """
    y_pred_bin = (y_pred > thr).astype(np.uint8)
    y_true = y_true.astype(np.uint8)
    tp = np.sum((y_pred_bin == 1) & (y_true == 1))
    fp = np.sum((y_pred_bin == 1) & (y_true == 0))
    fn = np.sum((y_pred_bin == 0) & (y_true == 1))
    tn = np.sum((y_pred_bin == 0) & (y_true == 0))

    acc = (tp + tn) / (tp + fp + fn + tn + 1e-6)

    if tp + fp + fn == 0:
        # Both masks are empty: perfect agreement on the foreground class.
        return {"accuracy": acc, "precision": empty_score, "recall": empty_score,
                "f1_score": empty_score, "jaccard": empty_score}

    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    iou = tp / (tp + fp + fn + 1e-6)
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1_score": f1, "jaccard": iou}

# Train / Eval loops


def train_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader``.

    Each batch's "image" and "mask" are moved to the notebook-level ``device``.
    The last batch loss is shown in the progress bar.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()
    running = 0.0
    progress = tqdm(loader, desc=f"Epoch {epoch}")
    for sample in progress:
        images = sample["image"].to(device)
        targets = sample["mask"].to(device)
        batch_loss = criterion(model(images), targets)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        value = batch_loss.item()
        running += value
        progress.set_postfix(loss=f"{value:.4f}")
    return running / len(loader)

def evaluate(model, loader, criterion, save_preds: bool = False):
    """Score ``model`` on ``loader`` without updating it.

    Returns:
        ``(mean_batch_loss, mean_metrics)``. ``mean_metrics`` averages the
        per-image output of ``compute_metrics`` (keys: accuracy, precision,
        recall, f1_score, jaccard). When ``save_preds`` is true, every image is
        also rendered through ``save_visualization``.
    """
    model.eval()
    metric_keys = ("accuracy", "precision", "recall", "f1_score", "jaccard")
    per_image = {key: [] for key in metric_keys}
    loss_sum = 0.0
    progress = tqdm(loader, desc="Eval")
    with torch.no_grad():
        for batch in progress:
            images = batch["image"].to(device)
            targets = batch["mask"].to(device)
            names = batch["filename"]
            probs = model(images)
            loss_sum += criterion(probs, targets).item()
            probs_np = probs.cpu().numpy()
            targets_np = targets.cpu().numpy()
            for j in range(images.shape[0]):
                scores = compute_metrics(targets_np[j, 0], probs_np[j, 0])
                for key in per_image:
                    per_image[key].append(scores[key])
                if save_preds:
                    save_visualization(images[j], targets_np[j, 0], probs_np[j, 0], scores, names[j])
    mean_metrics = {key: float(np.mean(vals)) for key, vals in per_image.items()}
    return loss_sum / len(loader), mean_metrics


# Helper to save side‑by‑side visualizations


def save_visualization(img_t, true_msk, pred_msk, metrics, fname, out_dir=None):
    """Save an Image | Ground Truth | Prediction panel with a metrics caption.

    Args:
        img_t: (3, H, W) image tensor in [0, 1] (any device).
        true_msk: ground-truth mask array; values > 0.5 are drawn as foreground.
        pred_msk: predicted probability array; values > 0.5 are drawn as foreground.
        metrics: dict with "jaccard", "f1_score", "precision" and "recall".
        fname: source filename; its stem names the PNG ("<stem>_viz.png").
        out_dir: destination folder; defaults to the notebook-level ``output_dir``.

    Returns:
        Path of the written PNG.
    """
    if out_dir is None:
        out_dir = output_dir

    img = (img_t.detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
    pred = (pred_msk > 0.5).astype(np.uint8) * 255
    true = (true_msk > 0.5).astype(np.uint8) * 255

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = ((img, "Image", None), (true, "Ground Truth", "gray"), (pred, "Prediction", "gray"))
        for ax, (data, title, cmap) in zip(axes, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        txt = (
            f"IoU={metrics['jaccard']:.3f} | F1={metrics['f1_score']:.3f} | "
            f"Prec={metrics['precision']:.3f} | Rec={metrics['recall']:.3f}"
        )
        # Leave room at the bottom for the caption.
        fig.subplots_adjust(bottom=0.2)
        fig.text(0.5, 0.05, txt, ha="center", fontsize=10)

        out_path = os.path.join(out_dir, f"{os.path.splitext(fname)[0]}_viz.png")
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if saving fails; evaluate() may call
        # this once per test image.
        plt.close(fig)
    return out_path


In [8]:

seed_everything(42)

# Data
train_ds = RetinalHemorrhageDataset(dataset_path, split="train")
full_test_ds = RetinalHemorrhageDataset(dataset_path, split="test")

# Split the held-out set into fixed halves: validation (model selection) and test.
# A dedicated seeded generator keeps the split identical across runs.
val_size = len(full_test_ds) // 2
test_size = len(full_test_ds) - val_size
val_ds, test_ds = torch.utils.data.random_split(
    full_test_ds, [val_size, test_size], generator=torch.Generator().manual_seed(42)
)

# Dataloaders. Pinned host memory only helps host->GPU copies.
pin = device.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0, pin_memory=pin)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=pin)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=pin)

# Model / loss / optim
model = SAMFineTuner(checkpoint_path, model_type).to(device)
criterion = FocalTverskyLoss()
# The frozen image encoder is filtered out. SAM's unused prompt encoder and mask
# decoder still pass the filter, but they are never in the forward graph, so
# their .grad stays None and Adam skips them.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
# The LR schedule tracks validation loss, while checkpointing/early stopping
# track validation F1. `verbose` is deprecated in recent PyTorch; the training
# loop prints the current LR every epoch instead.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=4,
    threshold=1e-4,
)


Found 3192 samples in 'train' split
Found 225 samples in 'test' split
Loading SAM encoder 'vit_h' …


  state_dict = torch.load(f)


SAM encoder loaded.




In [9]:
# Training loop with early stopping.
# Checkpointing and early stopping follow validation F1; the LR scheduler follows validation loss.
num_epochs = 30
best_f1 = -1.0
patience = 8
no_improve_epochs = 0

best_ckpt = os.path.join(output_dir, "best_model.pth")
for epoch in range(1, num_epochs + 1):
    tr_loss = train_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss, metrics = evaluate(model, val_loader, criterion)
    scheduler.step(val_loss)
    lr_now = optimizer.param_groups[0]["lr"]

    print(" | ".join([
        f"Epoch {epoch}/{num_epochs}",
        f"Train Loss: {tr_loss:.4f}",
        f"Val Loss: {val_loss:.4f}",
        f"IoU: {metrics['jaccard']:.4f}",
        f"F1: {metrics['f1_score']:.4f}",
        f"Precision: {metrics['precision']:.4f}",
        f"Recall: {metrics['recall']:.4f}",
        f"Accuracy: {metrics['accuracy']:.4f}",
    ]))

    improved = metrics["f1_score"] > best_f1
    if improved:
        best_f1 = metrics["f1_score"]
        no_improve_epochs = 0
        torch.save(model.state_dict(), best_ckpt)
        print(f"New best model saved (F1: {best_f1:.4f})")
    else:
        no_improve_epochs += 1
        print(f"No improvement in F1 for {no_improve_epochs} epoch(s)")

    print(f"Current lr: {lr_now}")

    if no_improve_epochs >= patience:
        print("Early stopping triggered.")
        break


  attn = (q * self.scale) @ k.transpose(-2, -1)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
  x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
Epoch 1: 100%|██████████| 3192/3192 [1:09:03<00:00,  1.30s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 1/30 | Train Loss: 0.9833 | Val Loss: 0.9591 | IoU: 0.0363 | F1: 0.0655 | Precision: 0.0410 | Recall: 0.4877 | Accuracy: 0.9112
New best model saved (F1: 0.0655)
Current lr: 0.0001


Epoch 2: 100%|██████████| 3192/3192 [1:08:05<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:55<00:00,  1.03s/it]


Epoch 2/30 | Train Loss: 0.9644 | Val Loss: 0.9077 | IoU: 0.1110 | F1: 0.1842 | Precision: 0.1709 | Recall: 0.3340 | Accuracy: 0.9850
New best model saved (F1: 0.1842)
Current lr: 0.0001


Epoch 3: 100%|██████████| 3192/3192 [1:08:09<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 3/30 | Train Loss: 0.8952 | Val Loss: 0.8307 | IoU: 0.1334 | F1: 0.2136 | Precision: 0.2274 | Recall: 0.2851 | Accuracy: 0.9885
New best model saved (F1: 0.2136)
Current lr: 0.0001


Epoch 4: 100%|██████████| 3192/3192 [1:08:20<00:00,  1.28s/it, loss=0.9959]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 4/30 | Train Loss: 0.8260 | Val Loss: 0.8230 | IoU: 0.1533 | F1: 0.2374 | Precision: 0.3891 | Recall: 0.2169 | Accuracy: 0.9914
New best model saved (F1: 0.2374)
Current lr: 0.0001


Epoch 5: 100%|██████████| 3192/3192 [1:08:24<00:00,  1.29s/it, loss=0.7270]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 5/30 | Train Loss: 0.7989 | Val Loss: 0.8231 | IoU: 0.1459 | F1: 0.2260 | Precision: 0.3303 | Recall: 0.2283 | Accuracy: 0.9910
No improvement in F1 for 1 epoch(s)
Current lr: 0.0001


Epoch 6: 100%|██████████| 3192/3192 [1:08:28<00:00,  1.29s/it, loss=0.9626]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 6/30 | Train Loss: 0.7857 | Val Loss: 0.8167 | IoU: 0.1584 | F1: 0.2445 | Precision: 0.4052 | Recall: 0.2177 | Accuracy: 0.9916
New best model saved (F1: 0.2445)
Current lr: 0.0001


Epoch 7: 100%|██████████| 3192/3192 [1:08:28<00:00,  1.29s/it, loss=0.5423]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 7/30 | Train Loss: 0.7787 | Val Loss: 0.8099 | IoU: 0.1609 | F1: 0.2484 | Precision: 0.3893 | Recall: 0.2326 | Accuracy: 0.9915
New best model saved (F1: 0.2484)
Current lr: 0.0001


Epoch 8: 100%|██████████| 3192/3192 [1:08:30<00:00,  1.29s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 8/30 | Train Loss: 0.7703 | Val Loss: 0.8032 | IoU: 0.1618 | F1: 0.2503 | Precision: 0.3604 | Recall: 0.2499 | Accuracy: 0.9912
New best model saved (F1: 0.2503)
Current lr: 0.0001


Epoch 9: 100%|██████████| 3192/3192 [1:08:26<00:00,  1.29s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 9/30 | Train Loss: 0.7614 | Val Loss: 0.8073 | IoU: 0.1653 | F1: 0.2537 | Precision: 0.4099 | Recall: 0.2318 | Accuracy: 0.9916
New best model saved (F1: 0.2537)
Current lr: 0.0001


Epoch 10: 100%|██████████| 3192/3192 [1:08:09<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:55<00:00,  1.03s/it]


Epoch 10/30 | Train Loss: 0.7540 | Val Loss: 0.8020 | IoU: 0.1578 | F1: 0.2466 | Precision: 0.3311 | Recall: 0.2702 | Accuracy: 0.9909
No improvement in F1 for 1 epoch(s)
Current lr: 0.0001


Epoch 11: 100%|██████████| 3192/3192 [1:08:14<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 11/30 | Train Loss: 0.7514 | Val Loss: 0.7888 | IoU: 0.1669 | F1: 0.2588 | Precision: 0.3254 | Recall: 0.2911 | Accuracy: 0.9905
New best model saved (F1: 0.2588)
Current lr: 0.0001


Epoch 12: 100%|██████████| 3192/3192 [1:08:22<00:00,  1.29s/it, loss=0.4927]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 12/30 | Train Loss: 0.7534 | Val Loss: 0.8022 | IoU: 0.1693 | F1: 0.2581 | Precision: 0.4298 | Recall: 0.2371 | Accuracy: 0.9916
No improvement in F1 for 1 epoch(s)
Current lr: 0.0001


Epoch 13: 100%|██████████| 3192/3192 [1:08:14<00:00,  1.28s/it, loss=0.6677]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 13/30 | Train Loss: 0.7446 | Val Loss: 0.8073 | IoU: 0.1707 | F1: 0.2573 | Precision: 0.4603 | Recall: 0.2210 | Accuracy: 0.9917
No improvement in F1 for 2 epoch(s)
Current lr: 0.0001


Epoch 14: 100%|██████████| 3192/3192 [1:08:07<00:00,  1.28s/it, loss=0.4698]
Eval: 100%|██████████| 112/112 [01:55<00:00,  1.04s/it]


Epoch 14/30 | Train Loss: 0.7396 | Val Loss: 0.7994 | IoU: 0.1726 | F1: 0.2621 | Precision: 0.4510 | Recall: 0.2411 | Accuracy: 0.9917
New best model saved (F1: 0.2621)
Current lr: 0.0001


Epoch 15: 100%|██████████| 3192/3192 [1:08:09<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 15/30 | Train Loss: 0.7360 | Val Loss: 0.7945 | IoU: 0.1763 | F1: 0.2673 | Precision: 0.4384 | Recall: 0.2485 | Accuracy: 0.9917
New best model saved (F1: 0.2673)
Current lr: 0.0001


Epoch 16: 100%|██████████| 3192/3192 [1:08:12<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 16/30 | Train Loss: 0.7331 | Val Loss: 0.7947 | IoU: 0.1647 | F1: 0.2520 | Precision: 0.3072 | Recall: 0.2804 | Accuracy: 0.9902
No improvement in F1 for 1 epoch(s)
Current lr: 1e-05


Epoch 17: 100%|██████████| 3192/3192 [1:08:13<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 17/30 | Train Loss: 0.7110 | Val Loss: 0.7975 | IoU: 0.1753 | F1: 0.2657 | Precision: 0.4356 | Recall: 0.2428 | Accuracy: 0.9918
No improvement in F1 for 2 epoch(s)
Current lr: 1e-05


Epoch 18: 100%|██████████| 3192/3192 [1:08:15<00:00,  1.28s/it, loss=0.7691]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 18/30 | Train Loss: 0.7008 | Val Loss: 0.7991 | IoU: 0.1754 | F1: 0.2645 | Precision: 0.4453 | Recall: 0.2371 | Accuracy: 0.9918
No improvement in F1 for 3 epoch(s)
Current lr: 1e-05


Epoch 19: 100%|██████████| 3192/3192 [1:08:18<00:00,  1.28s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 19/30 | Train Loss: 0.6964 | Val Loss: 0.7945 | IoU: 0.1773 | F1: 0.2682 | Precision: 0.4447 | Recall: 0.2461 | Accuracy: 0.9918
New best model saved (F1: 0.2682)
Current lr: 1e-05


Epoch 20: 100%|██████████| 3192/3192 [1:08:10<00:00,  1.28s/it, loss=0.4012]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 20/30 | Train Loss: 0.6924 | Val Loss: 0.7971 | IoU: 0.1757 | F1: 0.2666 | Precision: 0.4410 | Recall: 0.2419 | Accuracy: 0.9918
No improvement in F1 for 1 epoch(s)
Current lr: 1e-05


Epoch 21: 100%|██████████| 3192/3192 [1:08:09<00:00,  1.28s/it, loss=0.6623]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 21/30 | Train Loss: 0.6894 | Val Loss: 0.8050 | IoU: 0.1708 | F1: 0.2601 | Precision: 0.4598 | Recall: 0.2273 | Accuracy: 0.9919
No improvement in F1 for 2 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 22: 100%|██████████| 3192/3192 [1:08:10<00:00,  1.28s/it, loss=0.3795]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 22/30 | Train Loss: 0.6861 | Val Loss: 0.7977 | IoU: 0.1754 | F1: 0.2664 | Precision: 0.4493 | Recall: 0.2391 | Accuracy: 0.9918
No improvement in F1 for 3 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 23: 100%|██████████| 3192/3192 [1:08:09<00:00,  1.28s/it, loss=0.6156]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 23/30 | Train Loss: 0.6846 | Val Loss: 0.8000 | IoU: 0.1749 | F1: 0.2651 | Precision: 0.4592 | Recall: 0.2349 | Accuracy: 0.9919
No improvement in F1 for 4 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 24: 100%|██████████| 3192/3192 [1:08:28<00:00,  1.29s/it, loss=0.1380]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 24/30 | Train Loss: 0.6838 | Val Loss: 0.7978 | IoU: 0.1753 | F1: 0.2661 | Precision: 0.4507 | Recall: 0.2388 | Accuracy: 0.9918
No improvement in F1 for 5 epoch(s)
Current lr: 1.0000000000000002e-06


Epoch 25: 100%|██████████| 3192/3192 [1:08:32<00:00,  1.29s/it, loss=0.9849]
Eval: 100%|██████████| 112/112 [01:56<00:00,  1.04s/it]


Epoch 25/30 | Train Loss: 0.6831 | Val Loss: 0.7940 | IoU: 0.1777 | F1: 0.2692 | Precision: 0.4411 | Recall: 0.2465 | Accuracy: 0.9918
New best model saved (F1: 0.2692)
Current lr: 1.0000000000000002e-06


Epoch 26: 100%|██████████| 3192/3192 [1:08:32<00:00,  1.29s/it, loss=1.0000]
Eval: 100%|██████████| 112/112 [01:57<00:00,  1.04s/it]


Epoch 26/30 | Train Loss: 0.6824 | Val Loss: 0.7994 | IoU: 0.1754 | F1: 0.2654 | Precision: 0.4672 | Recall: 0.2357 | Accuracy: 0.9919
No improvement in F1 for 1 epoch(s)
Current lr: 1.0000000000000002e-07


Epoch 27:  46%|████▌     | 1474/3192 [31:42<36:57,  1.29s/it, loss=0.8312] 


KeyboardInterrupt: 

In [None]:
# Final evaluation with visualisations
print("Generating visualizations for best model")
# The checkpoint is a plain state_dict: weights_only=True avoids unpickling
# arbitrary objects (and the FutureWarning). Loading to CPU avoids a second
# full-size copy of the model on the GPU; load_state_dict copies into the
# existing parameters.
best_state = torch.load(
    os.path.join(output_dir, "best_model.pth"),
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(best_state)
del best_state
_, final_metrics = evaluate(model, test_loader, criterion, save_preds=True)
print("Final Test Metrics:")
for k, v in final_metrics.items():
    print(f"  {k.capitalize()}: {v:.4f}")
print(f"Predictions & visualizations saved to: {output_dir}")

Generating visualizations for best model


  model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))
Eval:  13%|█▎        | 15/113 [00:18<01:55,  1.18s/it]