In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import albumentations as A
from scipy.ndimage import binary_dilation
from torch.amp import GradScaler, autocast
import gc
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
import time
import logging
from torch.optim.lr_scheduler import CosineAnnealingLR
from google.colab import files

# Enable cuDNN and its benchmark mode (cuDNN auto-selects convolution algorithms).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Route INFO-and-above log records to a local file with timestamped lines.
logging.basicConfig(filename='training_segnet.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Colab-only: opens an upload prompt. The shell lines below expect a file
# named kaggle.json to be present in the working directory afterwards.
files.upload()
!pip install -q kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the credentials file to the current user.
!chmod 600 ~/.kaggle/kaggle.json

# Download the dataset archive into /content and unpack it next to the archive.
!kaggle datasets download -d meteahishali/aerial-imagery-for-standing-dead-tree-segmentation -p /content
!unzip /content/aerial-imagery-for-standing-dead-tree-segmentation.zip -d /content/aerial-imagery-for-standing-dead-tree-segmentation


In [None]:
# pydensecrf is installed straight from its GitHub repository.
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git
# Remaining third-party dependencies used by this notebook (versions are not pinned).
!pip install opencv-python numpy matplotlib scikit-learn torch torchvision albumentations scipy pydensecrf tqdm segmentation-models-pytorch
# These imports live here because the packages are only available after the installs above.
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax
import segmentation_models_pytorch as smp

In [4]:
class SegNet(nn.Module):
    """Three-stage SegNet-style encoder/decoder for dense prediction.

    Each encoder stage is two Conv-BN-ReLU layers followed by a 2x2 max-pool
    that also returns the argmax indices.  The decoder mirrors the encoder and
    uses those indices with ``MaxUnpool2d`` to restore resolution.

    The output has ``out_channels`` channels, the same spatial size as the
    input, and no final activation (raw logits).

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced logit map.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(SegNet, self).__init__()
        # Encoder: 64 -> 128 -> 256 channels, halving resolution at each pool.
        self.enc1 = self._double_conv(in_channels, 64, 64)
        self.pool1 = nn.MaxPool2d(2, 2, return_indices=True)

        self.enc2 = self._double_conv(64, 128, 128)
        self.pool2 = nn.MaxPool2d(2, 2, return_indices=True)

        self.enc3 = self._double_conv(128, 256, 256)
        self.pool3 = nn.MaxPool2d(2, 2, return_indices=True)

        # Decoder: mirror image of the encoder.
        self.unpool3 = nn.MaxUnpool2d(2, 2)
        self.dec3 = self._double_conv(256, 256, 128)

        self.unpool2 = nn.MaxUnpool2d(2, 2)
        self.dec2 = self._double_conv(128, 128, 64)

        self.unpool1 = nn.MaxUnpool2d(2, 2)
        # The last stage ends in a bare convolution so the network emits logits.
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
        )

    @staticmethod
    def _double_conv(in_ch, mid_ch, out_ch):
        """Build Conv-BN-ReLU -> Conv-BN-ReLU with 3x3 kernels and padding 1."""
        return nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return logits with the same height and width as ``x``."""
        # Remember each pre-pool spatial size: max-pooling floors odd sizes,
        # so without an explicit output_size the unpool cannot restore them
        # (it either produces a smaller map or rejects the stored indices).
        x = self.enc1(x)
        size1 = x.shape[-2:]
        x, idx1 = self.pool1(x)

        x = self.enc2(x)
        size2 = x.shape[-2:]
        x, idx2 = self.pool2(x)

        x = self.enc3(x)
        size3 = x.shape[-2:]
        x, idx3 = self.pool3(x)

        x = self.unpool3(x, idx3, output_size=size3)
        x = self.dec3(x)
        x = self.unpool2(x, idx2, output_size=size2)
        x = self.dec2(x)
        x = self.unpool1(x, idx1, output_size=size1)
        x = self.dec1(x)
        return x


In [None]:
class HRNet(nn.Module):
    """Small convolutional head registered under the name ``HRNet``.

    Two 3x3 convolutions with ReLU, a stride-2 transposed convolution, and an
    adaptive average pool that always yields a 256x256 map.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced map.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.upconv = nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1)
        # Output resolution is fixed regardless of the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((256, 256))

    def forward(self, x):
        """Return a ``(N, out_channels, 256, 256)`` tensor without activation."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        upsampled = self.upconv(features)
        return self.pool(upsampled)

In [5]:
def denormalize_image(image, mean, std):
    """Undo per-channel standardisation and clip the result to [0, 1].

    Args:
        image: Channel-first array; the first axis is iterated as channels.
        mean: Per-channel means, indexable by channel number.
        std: Per-channel standard deviations, indexable by channel number.

    Returns:
        A new array (the input is not modified) with ``x * std + mean``
        applied channel by channel, clipped to the range [0, 1].
    """
    restored = image.copy()
    for channel, plane in enumerate(restored):
        # Assigning back into `restored` keeps the input's dtype.
        restored[channel] = (plane * std[channel]) + mean[channel]
    return np.clip(restored, 0, 1)

def preprocess_aerial(image, is_nrg=False, image_path="unknown"):
    """Scale an 8-bit image to [0, 1], optionally equalise channel 0, and brighten.

    Args:
        image: Array whose values are divided by 255; the last axis is channels.
        is_nrg: When true, the image must have exactly 3 channels and channel 0
            is histogram-equalised with OpenCV before brightening.
        image_path: Only used in log messages.

    Returns:
        A float32 array with values in [0, 1] (after a 1.2x gain and clipping),
        or ``None`` when the channel count is wrong or the result holds NaN/Inf.
    """
    scaled = image.astype(np.float32) / 255.0
    if is_nrg:
        if image.shape[-1] != 3:
            logging.warning(f"Expected 3 channels at {image_path}, got {image.shape[-1]}")
            return None
        # equalizeHist works on 8-bit data, so round-trip channel 0 through uint8.
        first_channel_u8 = (scaled[:, :, 0] * 255).astype(np.uint8)
        nir = cv2.equalizeHist(first_channel_u8).astype(np.float32) / 255.0
        scaled = np.stack([nir, scaled[:, :, 1], scaled[:, :, 2]], axis=-1)
    brightened = np.clip(scaled * 1.2, 0, 1)
    if not np.all(np.isfinite(brightened)):
        logging.error(f"NaN/Inf in {image_path}")
        return None
    return brightened

def load_image_and_mask(image_path, mask_path, target_size=(256, 256)):
    """Read one image/mask pair from disk and prepare both for training.

    The image is read in colour, rejected when unreadable or smaller than 128
    pixels on either side, resized, and passed through ``preprocess_aerial``
    (NRG handling is enabled when the file name contains ``"NRG"``).  The mask
    is read as grayscale, resized with nearest-neighbour interpolation,
    binarised at 128 and dilated by one iteration.

    Args:
        image_path: Path of the image file.
        mask_path: Path of the mask file.
        target_size: Size passed to ``cv2.resize`` for both image and mask.

    Returns:
        ``(image, mask)`` on success, otherwise ``(None, None)``.  Pairs whose
        dilated mask has fewer than 5 foreground pixels are rejected and noted
        in ``empty_masks.log``.  Any exception is logged and also yields
        ``(None, None)``.
    """
    rejected = (None, None)
    try:
        raw_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if raw_image is None:
            return rejected
        if raw_image.shape[0] < 128 or raw_image.shape[1] < 128:
            return rejected

        resized_image = cv2.resize(raw_image, target_size, interpolation=cv2.INTER_LINEAR)
        is_nrg = "NRG" in os.path.basename(image_path)
        image = preprocess_aerial(resized_image, is_nrg, image_path)
        if image is None:
            return rejected

        raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            return rejected
        resized_mask = cv2.resize(raw_mask, target_size, interpolation=cv2.INTER_NEAREST)
        binary_mask = (resized_mask > 128).astype(np.float32)
        mask = binary_dilation(binary_mask, iterations=1).astype(np.float32)

        if mask.sum() >= 5:
            return image, mask

        # Too few foreground pixels: record the file and drop the pair.
        with open("empty_masks.log", "a") as log_file:
            log_file.write(f"Empty mask for {image_path} (sum={mask.sum()})\n")
        return rejected
    except Exception as e:
        logging.error(f"Error in {image_path}: {str(e)}")
        return rejected

def apply_nrg_normalization(image, mean, std):
    """Convert an HWC array to a standardised, clamped CHW float tensor.

    Args:
        image: NumPy array laid out as (height, width, channels).
        mean: Per-channel means.
        std: Per-channel standard deviations.

    Returns:
        A float32 tensor of shape (channels, height, width) holding
        ``(x - mean) / (std + 1e-7)`` clamped to [-5, 5], or ``None`` when the
        result contains NaN or Inf (an error is logged in that case).
    """
    def _per_channel(values):
        # Shape (C, 1, 1) so the statistics broadcast over height and width.
        return torch.tensor(values, dtype=torch.float32).reshape(-1, 1, 1)

    chw = torch.from_numpy(image.transpose(2, 0, 1)).float()
    standardized = (chw - _per_channel(mean)) / (_per_channel(std) + 1e-7)
    clamped = torch.clamp(standardized, -5, 5)
    if not torch.isfinite(clamped).all():
        logging.error("NaN/Inf in normalized image")
        return None
    return clamped

def preprocess_and_save(image_paths, mask_paths, mean, std, save_dir):
    """Preprocess every image/mask pair and cache the results as ``.npy`` files.

    For each pair that loads and normalises successfully, two files are written
    into ``save_dir``: ``image_<i>.npy`` (normalised CHW image) and
    ``mask_<i>.npy`` (mask with a leading channel axis), where ``<i>`` is the
    pair's position in the input lists.

    Args:
        image_paths: Image file paths.
        mask_paths: Mask file paths, aligned with ``image_paths``.
        mean: Per-channel means used for normalisation.
        std: Per-channel standard deviations used for normalisation.
        save_dir: Output directory; created when missing.

    Returns:
        The list of positions that were saved successfully.
    """
    os.makedirs(save_dir, exist_ok=True)
    saved_positions = []
    progress = tqdm(zip(image_paths, mask_paths), total=len(image_paths))
    for position, (img_path, mask_path) in enumerate(progress):
        image, mask = load_image_and_mask(img_path, mask_path)
        if image is None or mask is None:
            continue
        normalized = apply_nrg_normalization(image, mean, std)
        if normalized is None:
            continue
        mask_with_channel = torch.from_numpy(mask).float().unsqueeze(0)
        image_file = os.path.join(save_dir, f"image_{position}.npy")
        mask_file = os.path.join(save_dir, f"mask_{position}.npy")
        np.save(image_file, normalized.numpy(), allow_pickle=False)
        np.save(mask_file, mask_with_channel.numpy(), allow_pickle=False)
        saved_positions.append(position)
    return saved_positions

In [6]:
class AerialDataset(Dataset):
    """Dataset over cached ``image_*.npy`` / ``mask_*.npy`` files in one directory.

    Images and masks are paired by their position in the two name-sorted file
    lists.  When an augmenter is supplied, each image is de-normalised,
    augmented together with its mask, and normalised again.

    Args:
        preprocessed_dir: Directory holding the cached arrays.
        mean: Per-channel means used to (de)normalise images.
        std: Per-channel standard deviations used to (de)normalise images.
        augmenter: Optional callable invoked as ``augmenter(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, preprocessed_dir, mean, std, augmenter=None):
        self.image_files = self._sorted_names(preprocessed_dir, "image_")
        self.mask_files = self._sorted_names(preprocessed_dir, "mask_")
        self.preprocessed_dir = preprocessed_dir
        self.mean = mean
        self.std = std
        self.augmenter = augmenter

    @staticmethod
    def _sorted_names(directory, prefix):
        """Return the name-sorted entries of ``directory`` starting with ``prefix``."""
        return sorted(name for name in os.listdir(directory) if name.startswith(prefix))

    def _read_tensor(self, file_name):
        """Load one cached array (memory-mapped, then copied) as a float tensor."""
        array = np.load(os.path.join(self.preprocessed_dir, file_name), mmap_mode='r')
        return torch.from_numpy(array.copy()).float()

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = self._read_tensor(self.image_files[idx])
        mask = self._read_tensor(self.mask_files[idx])
        if not self.augmenter:
            return image, mask

        # The augmenter receives a [0, 1] HWC image and a 2-D mask.
        hwc_image = denormalize_image(image.numpy(), self.mean, self.std).transpose(1, 2, 0)
        result = self.augmenter(image=hwc_image, mask=mask.squeeze().numpy())
        aug_img, aug_mask = result['image'], result['mask']
        if aug_mask.sum() < 5:
            # Augmentation left (almost) no foreground: keep the original sample.
            return image, mask

        image = apply_nrg_normalization(aug_img, self.mean, self.std)
        mask = torch.from_numpy(aug_mask).float().unsqueeze(0)
        return image, mask


In [7]:
# Relative dataset location.
# NOTE(review): the archive is unpacked under /content earlier in the notebook,
# so this presumably relies on /content being the working directory - confirm.
data_dir = "aerial-imagery-for-standing-dead-tree-segmentation/USA_segmentation"
nrg_image_dir = os.path.join(data_dir, "NRG_images")
mask_dir = os.path.join(data_dir, "masks")

# Fail early with a clear message when the expected folders are missing.
for directory in [nrg_image_dir, mask_dir]:
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory {directory} not found")

nrg_image_files = sorted([f for f in os.listdir(nrg_image_dir) if f.endswith('.png')])
# mask_files is not referenced again in this cell; masks are located by name below.
mask_files = sorted([f for f in os.listdir(mask_dir) if f.startswith('mask_') and f.endswith('.png')])

# Pair every image with the mask whose name is "mask_" + the image name with
# every "NRG_" occurrence removed; images without such a mask file are skipped.
image_paths = []
mask_paths = []
for img_file in nrg_image_files:
    img_path = os.path.join(nrg_image_dir, img_file)
    base_name = img_file.replace("NRG_", "")
    mask_file = f"mask_{base_name}"
    mask_path = os.path.join(mask_dir, mask_file)
    if os.path.exists(mask_path):
        image_paths.append(img_path)
        mask_paths.append(mask_path)

# Only the first 200 pairs (in sorted file-name order) are used.
image_paths = image_paths[:200]
mask_paths = mask_paths[:200]

# Hard-coded per-channel normalisation statistics.
# NOTE(review): their origin is not shown in this notebook - confirm they match this dataset.
mean_nrg = np.array([0.5827, 0.4934, 0.6260])
std_nrg = np.array([0.3245, 0.2736, 0.2906])
logging.info(f"NRG mean: {mean_nrg}, std: {std_nrg}")
print(f"NRG mean: {mean_nrg}, std: {std_nrg}")


# Cache preprocessed arrays on disk; valid_indices holds the positions that were saved.
valid_indices = preprocess_and_save(image_paths, mask_paths, mean_nrg, std_nrg, "preprocessed_data_segnet")


NRG mean: [0.5827 0.4934 0.626 ], std: [0.3245 0.2736 0.2906]


100%|██████████| 200/200 [00:05<00:00, 33.58it/s]


In [8]:
def get_augmenter():
    """Build the training-time augmentation pipeline.

    Returns:
        An albumentations ``Compose`` that flips horizontally and vertically
        (each with probability 0.5) and then resizes to 256x256, applied
        jointly to the image and its mask.
    """
    transforms = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Resize(height=256, width=256, p=1.0),
    ]
    return A.Compose(transforms, additional_targets={'mask': 'mask'})

# Split dataset *positions* (0 .. number of cached samples - 1) into train and
# validation parts.  The positions index the name-sorted file lists held by
# AerialDataset, not the original numbering embedded in the file names.
train_idx, val_idx = train_test_split(range(len(valid_indices)), test_size=0.25, random_state=42)

# Two views over the same cache directory: one with augmentation, one without.
_train_source = AerialDataset("preprocessed_data_segnet", mean_nrg, std_nrg, get_augmenter())
_val_source = AerialDataset("preprocessed_data_segnet", mean_nrg, std_nrg, None)
if len(_val_source) != len(valid_indices):
    # Typically caused by files left over from an earlier preprocessing run.
    logging.warning(
        f"Cache holds {len(_val_source)} samples but {len(valid_indices)} were just "
        f"preprocessed; the train/validation split may not cover the intended samples"
    )

# Restrict each dataset to its own indices.  Without this both loaders iterate
# over every sample, so validation metrics are measured on training data.
train_dataset = torch.utils.data.Subset(_train_source, train_idx)
val_dataset = torch.utils.data.Subset(_val_source, val_idx)

def custom_collate_fn(batch):
    """Collate ``(image, mask)`` samples, dropping those whose image is ``None``.

    Args:
        batch: Sequence of ``(image, mask)`` pairs.

    Returns:
        ``(images, masks)`` stacked along a new first axis.  When fewer than
        two usable samples remain, a pair of empty tensors is returned so the
        caller can recognise and skip the batch.
    """
    usable = [sample for sample in batch if sample[0] is not None]
    if len(usable) > 1:
        images, masks = zip(*usable)
        return torch.stack(images), torch.stack(masks)
    return torch.tensor([]), torch.tensor([])

# Both loaders share the batch size, worker count and the None-filtering collate
# function; only the training loader shuffles its samples.
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, collate_fn=custom_collate_fn)


In [9]:

def dice_coefficient(y_true, y_pred, smooth=1e-7):
    """Dice coefficient over all elements of two equally sized arrays/tensors.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values (hard labels or probabilities).
        smooth: Small constant added to numerator and denominator so the ratio
            is defined when both inputs are all zero.

    Returns:
        ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = y_true.flatten()
    prediction = y_pred.flatten()
    overlap = (truth * prediction).sum()
    denominator = truth.sum() + prediction.sum() + smooth
    return (2. * overlap + smooth) / denominator

def jaccard_index(y_true, y_pred, smooth=1e-7):
    """Jaccard index (IoU) over all elements of two equally sized arrays/tensors.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values (hard labels or probabilities).
        smooth: Small constant added to numerator and denominator so the ratio
            is defined when both inputs are all zero.

    Returns:
        ``(intersection + smooth) / (union + smooth)`` where the intersection
        is ``sum(t * p)`` and the union is ``sum(t) + sum(p) - intersection``.
    """
    truth = y_true.flatten()
    prediction = y_pred.flatten()
    overlap = (truth * prediction).sum()
    combined = truth.sum() + prediction.sum() - overlap
    return (overlap + smooth) / (combined + smooth)

def pixel_accuracy(y_true, y_pred):
    """Fraction of elements on which two tensors agree after thresholding at 0.5.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor (labels or probabilities).

    Returns:
        A scalar float tensor in [0, 1].
    """
    truth_is_positive = y_true > 0.5
    prediction_is_positive = y_pred > 0.5
    agreement = truth_is_positive == prediction_is_positive
    return agreement.float().mean()

class CombinedLoss(nn.Module):
    """Weighted sum of class-weighted BCE and focal loss, both computed on logits.

    Args:
        bce_weight: Multiplier of the BCE term.
        focal_weight: Multiplier of the focal term.
        gamma: Focusing exponent of the focal term.
    """

    def __init__(self, bce_weight=0.3, focal_weight=0.7, gamma=2.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.focal_weight = focal_weight
        self.gamma = gamma
        # Positive pixels count 10x in the BCE term; the weight tensor is
        # created on the notebook-level `device`.
        self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10.0).to(device))

    def focal_loss(self, y_pred, y_true):
        """Mean focal loss.

        Args:
            y_pred: Raw logits.
            y_true: Targets in [0, 1] with the same shape as ``y_pred``.

        Returns:
            Scalar tensor ``mean((1 - p_t) ** gamma * BCE(logits, targets))``.
        """
        # The cross-entropy term must see the raw logits: the *_with_logits
        # function applies the sigmoid itself, so feeding it probabilities
        # would squash the predictions twice.
        bce = F.binary_cross_entropy_with_logits(y_pred, y_true, reduction='none')
        probs = torch.sigmoid(y_pred)
        # Probability assigned to the true class of each element.
        pt = y_true * probs + (1 - y_true) * (1 - probs)
        focal_weight = (1 - pt).pow(self.gamma)
        return (focal_weight * bce).mean()

    def forward(self, y_pred, y_true):
        """Return the combined loss for logits ``y_pred`` and targets ``y_true``."""
        bce = self.bce(y_pred, y_true)
        focal = self.focal_loss(y_pred, y_true)
        loss = self.bce_weight * bce + self.focal_weight * focal
        if y_true.sum() == 0:
            # Batches without any positive pixel contribute at 10% strength.
            return 0.1 * loss
        return loss

In [14]:
# Name of the architecture to train; must be a key of model_configs below.
model_name = "SegNet"
# Registry of supported architectures.  Each entry provides:
#   "class"       - the constructor to call,
#   "save_prefix" - stem used for checkpoint and result file names,
#   "params"      - keyword arguments passed to the constructor.
# NOTE(review): every save_prefix ends in "hit_uav", which does not match the
# dataset used in this notebook - possibly carried over from another project.
# Renaming it would change the checkpoint/result file names, so confirm first.
model_configs = {
    "DeepLabV3Plus": {
        "class": smp.DeepLabV3Plus,
        "save_prefix": "deeplabv3plus_hit_uav",
        "params": {"encoder_name": "resnet50", "in_channels": 3, "classes": 1}
    },
    "Unet": {
        "class": smp.Unet,
        "save_prefix": "unet_hit_uav",
        "params": {"encoder_name": "resnet50", "in_channels": 3, "classes": 1}
    },
    "PSPNet": {
        "class": smp.PSPNet,
        "save_prefix": "pspnet_hit_uav",
        "params": {"encoder_name": "resnet50", "in_channels": 3, "classes": 1}
    },
    "SegNet": {
        "class": SegNet,
        "save_prefix": "segnet_hit_uav",
        "params": {"in_channels": 3, "out_channels": 1}
    },
    "HRNet": {
        "class": HRNet,
        "save_prefix": "hrnet_hit_uav",
        "params": {"in_channels": 3, "out_channels": 1}
    }
}

# Reject unknown names here rather than failing later during model construction.
if model_name not in model_configs:
    raise ValueError(f"Model {model_name} is not supported. Choose from {list(model_configs.keys())}")

In [None]:
def get_segmentation_model(model_name):
    """Instantiate the model registered under ``model_name`` on the active device.

    The three ``smp``-backed architectures additionally receive
    ``encoder_weights="imagenet"`` and ``activation=None``; the locally defined
    ones are built from their registered parameters only.

    Args:
        model_name: Key of ``model_configs``.

    Returns:
        The constructed model, moved to ``device``.

    Raises:
        KeyError: ``model_name`` is not a key of ``model_configs``.
        ValueError: ``model_name`` is registered but belongs to neither group.
    """
    entry = model_configs[model_name]
    model_class = entry["class"]
    model_params = entry["params"]

    pretrained_encoder_models = ("DeepLabV3Plus", "Unet", "PSPNet")
    plain_models = ("SegNet", "HRNet")

    if model_name in pretrained_encoder_models:
        extra_kwargs = {"encoder_weights": "imagenet", "activation": None}
    elif model_name in plain_models:
        extra_kwargs = {}
    else:
        raise ValueError(f"Unsupported model type: {model_name}")
    return model_class(**model_params, **extra_kwargs).to(device)

model = get_segmentation_model(model_name)

loss_fn = CombinedLoss(bce_weight=0.3, focal_weight=0.7)

# ---- Hyper-parameters --------------------------------------------------------
num_epochs = 200
warmup_epochs = 5                       # length of the linear LR warm-up, in epochs
base_lr = 1e-3                          # peak LR, reached in the last warm-up epoch
warmup_lr = base_lr / warmup_epochs     # LR used during the very first epoch
accumulation_steps = 4                  # micro-batches per optimizer step
patience = 15                           # epochs without a better validation IoU before stopping
val_every = 5                           # run validation every N epochs
max_grad_norm = 1.0                     # gradient-clipping threshold

optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-2)
# NOTE(review): T_max=10 is much shorter than num_epochs, so the LR is periodic
# rather than a single decay - confirm this is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
# Mixed precision (and therefore gradient scaling) is only used on CUDA.
use_amp = device.type == "cuda"
scaler = GradScaler('cuda', enabled=use_amp)

best_val_iou = 0.0
history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': []}
epochs_no_improve = 0
# Most recent *measured* validation metrics; NaN until the first validation run.
last_val_loss, last_val_iou = float('nan'), float('nan')


def _apply_accumulated_gradients():
    """Unscale, clip and apply the gradients accumulated so far, then reset them."""
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)


for epoch in range(num_epochs):
    start_time = time.time()

    # Linear warm-up.  The LR is set *before* the epoch runs so the first epoch
    # really starts at the reduced rate.
    if epoch < warmup_epochs:
        warm_lr = base_lr * (epoch + 1) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = warm_lr
    current_lr = optimizer.param_groups[0]['lr']   # LR actually used this epoch

    # ---- Training ------------------------------------------------------------
    model.train()
    train_loss, train_iou, train_count = 0.0, 0.0, 0
    optimizer.zero_grad(set_to_none=True)
    pending_batches = 0      # micro-batches back-propagated since the last optimizer step
    data_time = 0.0          # host-to-device transfer time only
    forward_time = 0.0       # forward + backward (+ optimizer step) time
    for images, masks in train_loader:
        data_start = time.time()
        if images.numel() == 0 or images.size(0) <= 1:
            continue
        images, masks = images.to(device), masks.to(device)
        data_time += time.time() - data_start

        forward_start = time.time()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = loss_fn(outputs, masks) / accumulation_steps
        if not torch.isfinite(loss):
            logging.warning(f"Skipping batch with non-finite loss in epoch {epoch+1}")
            continue
        scaler.scale(loss).backward()
        pending_batches += 1
        # Counting processed micro-batches (not loader positions) keeps the
        # accumulation window intact when batches are skipped above.
        if pending_batches == accumulation_steps:
            _apply_accumulated_gradients()
            pending_batches = 0
        forward_time += time.time() - forward_start

        train_loss += loss.item() * accumulation_steps * images.size(0)
        preds = torch.sigmoid(outputs).detach()
        train_iou += jaccard_index(masks, preds).item() * images.size(0)
        train_count += images.size(0)
        del images, masks, outputs, preds

    # Apply gradients of a trailing, incomplete accumulation window instead of
    # discarding them at the start of the next epoch.
    if pending_batches > 0:
        _apply_accumulated_gradients()
        pending_batches = 0

    if train_count > 0:
        train_loss /= train_count
        train_iou /= train_count

    # ---- Validation (every `val_every` epochs) -------------------------------
    ran_validation = (epoch + 1) % val_every == 0
    if ran_validation:
        model.eval()
        val_loss_sum, val_iou_sum, val_count = 0.0, 0.0, 0
        with torch.no_grad():
            for images, masks in val_loader:
                if images.numel() == 0 or images.size(0) <= 1:
                    continue
                images, masks = images.to(device), masks.to(device)
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs = model(images)
                    loss = loss_fn(outputs, masks)
                val_loss_sum += loss.item() * images.size(0)
                preds = torch.sigmoid(outputs).detach()
                val_iou_sum += jaccard_index(masks, preds).item() * images.size(0)
                val_count += images.size(0)
                del images, masks, outputs, preds
        if val_count > 0:
            last_val_loss = val_loss_sum / val_count
            last_val_iou = val_iou_sum / val_count
        else:
            last_val_loss, last_val_iou = 0.0, 0.0

    # Between validation runs the last measured values are carried forward;
    # training metrics are never substituted for validation metrics.
    val_loss, val_iou = last_val_loss, last_val_iou

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_iou'].append(train_iou)
    history['val_iou'].append(val_iou)

    epoch_summary = (f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                     f"Train IoU: {train_iou:.4f}, Val IoU: {val_iou:.4f}, LR: {current_lr:.6f}, "
                     f"Data: {data_time:.2f}s, Forward: {forward_time:.2f}s, Time: {time.time() - start_time:.2f}s")
    logging.info(epoch_summary)
    print(epoch_summary)

    # ---- Checkpointing -------------------------------------------------------
    # Only a freshly measured validation IoU may promote a checkpoint to "best".
    if ran_validation and val_iou > best_val_iou:
        best_val_iou = val_iou
        model_path = f"best_{model_configs[model_name]['save_prefix']}.pth"
        torch.save(model.state_dict(), model_path)
        logging.info(f"New best model saved at epoch {epoch+1} with Val IoU: {val_iou:.4f}")
        print(f"New best model saved at epoch {epoch+1} with Val IoU: {val_iou:.4f}")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        model_path = f"last_{model_configs[model_name]['save_prefix']}.pth"
        torch.save(model.state_dict(), model_path)

    # Cosine schedule takes over once warm-up has reached the peak LR.
    if epoch >= warmup_epochs - 1:
        scheduler.step()

    gc.collect()
    torch.cuda.empty_cache()

    # ---- Early stopping ------------------------------------------------------
    if epochs_no_improve >= patience:
        logging.info(f"Early stopping at epoch {epoch+1}: no Val IoU improvement for {epochs_no_improve} epochs")
        print(f"Early stopping at epoch {epoch+1}: no Val IoU improvement for {epochs_no_improve} epochs")
        break


In [None]:
def evaluate_model(model, val_loader, mean_nrg, std_nrg, use_crf=False):
    """Evaluate ``model`` on ``val_loader`` using the best saved checkpoint if present.

    The checkpoint ``best_<save_prefix>.pth`` of the currently selected
    ``model_name`` is loaded into ``model`` when the file exists; otherwise the
    model is evaluated with its current weights.  Predictions are thresholded
    at 0.5.  Overall IoU, Dice, pixel accuracy, per-image IoU statistics and
    the confusion matrix are logged and printed.

    Args:
        model: Network to evaluate; its weights are overwritten when a
            checkpoint is found.
        val_loader: Iterable of ``(images, masks)`` batches.  Empty batches and
            batches with a single sample are skipped.
        mean_nrg: Per-channel means, used to de-normalise images for the CRF.
        std_nrg: Per-channel standard deviations, used likewise.
        use_crf: Refine each probability map with ``apply_crf_wrapper``, which
            must be defined elsewhere in the notebook.

    Returns:
        ``(y_pred, y_true, per_image_iou)``: stacked boolean predictions,
        stacked ground-truth masks, and the list of per-image IoU values.

    Raises:
        NameError: ``use_crf`` is true but ``apply_crf_wrapper`` is not defined.
        ValueError: ``val_loader`` produced no usable batch.
    """
    start_time = time.time()
    if use_crf and 'apply_crf_wrapper' not in globals():
        # Fail before any work is done instead of in the middle of the loop.
        raise NameError("use_crf=True requires apply_crf_wrapper to be defined")

    model.eval()
    model_path = f"best_{model_configs[model_name]['save_prefix']}.pth"
    if os.path.exists(model_path):
        # map_location lets a checkpoint written on a GPU be loaded on a
        # CPU-only runtime; weights_only restricts unpickling to tensor data.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        logging.info(f"Loaded {model_path}")
        print(f"Loaded {model_path}")
    else:
        logging.warning("No saved model")
        print("No saved model")

    use_amp = device.type == "cuda"
    y_pred, y_true, per_image_iou = [], [], []
    with torch.no_grad():
        for images, masks in val_loader:
            if images.numel() == 0 or images.size(0) <= 1:
                continue
            images, masks = images.to(device), masks.to(device)
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
            probs = torch.sigmoid(outputs).cpu().numpy()
            images_np = images.cpu().numpy()

            if use_crf:
                # The CRF receives the de-normalised HWC image and a 2-D probability map.
                refined_preds = np.stack([
                    apply_crf_wrapper((denormalize_image(images_np[i], mean_nrg, std_nrg).transpose(1, 2, 0), probs[i].squeeze()))
                    for i in range(probs.shape[0])
                ])
            else:
                refined_preds = probs > 0.5

            best_preds = refined_preds > 0.5
            best_preds = torch.from_numpy(best_preds).to(device)

            y_pred.append(best_preds.cpu().numpy())
            y_true.append(masks.cpu().numpy())
            for i in range(images.size(0)):
                iou = jaccard_index(masks[i:i+1], best_preds[i:i+1])
                per_image_iou.append(iou.item())

            del images, masks, outputs, probs, images_np, refined_preds, best_preds
            torch.cuda.empty_cache()

    if not y_pred:
        raise ValueError("val_loader produced no usable batches; nothing to evaluate")

    y_pred = np.concatenate(y_pred, axis=0)
    y_true = np.concatenate(y_true, axis=0)

    # Convert once and reuse for all three aggregate metrics.
    true_tensor = torch.tensor(y_true)
    pred_tensor = torch.tensor(y_pred)
    val_iou = jaccard_index(true_tensor, pred_tensor)
    val_dice = dice_coefficient(true_tensor, pred_tensor)
    val_accuracy = pixel_accuracy(true_tensor, pred_tensor)

    summary = (f"Validation: IoU: {val_iou:.4f}, Dice: {val_dice:.4f}, Accuracy: {val_accuracy:.4f}, "
               f"IoU (mean ± std): {np.mean(per_image_iou):.4f} ± {np.std(per_image_iou):.4f}, Time: {time.time() - start_time:.2f}s")
    logging.info(summary)
    print(summary)

    tn, fp, fn, tp = confusion_matrix(y_true.flatten(), y_pred.flatten(), labels=[0, 1]).ravel()
    logging.info(f"Confusion Matrix: TN={tn}, FP={fp}, FN={fn}, TP={tp}")
    print(f"Confusion Matrix: TN={tn}, FP={fp}, FN={fn}, TP={tp}")

    return y_pred, y_true, per_image_iou

y_pred, y_true, per_image_iou = evaluate_model(model, val_loader, mean_nrg, std_nrg, use_crf=False)


In [None]:
# Qualitative check: show image / ground truth / prediction for up to three
# samples taken from the first usable validation batch.
plt.figure(figsize=(12, 4))
for i, (images, masks) in enumerate(val_loader):
    # Skip the placeholder batches produced by the collate function.
    if images.numel() == 0 or images.size(0) <= 1:
        continue
    images = images[:3].to(device)
    with torch.no_grad():
        with autocast('cuda'):
            outputs = model(images)
            probs = torch.sigmoid(outputs).cpu().numpy()
            # images_np is not used below.
            images_np = images.cpu().numpy()

    images = images.cpu().numpy()
    masks = masks[:3].cpu().numpy()
    # One row per sample: columns are input image, true mask, predicted mask.
    for j in range(min(3, len(images))):
        plt.subplot(3, 3, j*3 + 1)
        plt.imshow(denormalize_image(images[j], mean_nrg, std_nrg).transpose(1, 2, 0))
        plt.title('NRG Image')
        plt.axis('off')
        plt.subplot(3, 3, j*3 + 2)
        plt.imshow(masks[j].squeeze(), cmap='gray')
        plt.title('True Mask')
        plt.axis('off')
        plt.subplot(3, 3, j*3 + 3)
        plt.imshow(probs[j].squeeze() > 0.5, cmap='gray')
        plt.title('Predicted Mask')
        plt.axis('off')
    # Only the first usable batch is visualised.
    break
plt.tight_layout()
plt.show()

# Learning curves recorded in `history` by the training cell.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.title('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history['train_iou'], label='Train IoU')
plt.plot(history['val_iou'], label='Val IoU')
plt.title('Jaccard Index')
plt.legend()
plt.show()

# Persist the evaluation outputs under the selected model's save prefix.
np.save(f'y_pred_{model_configs[model_name]["save_prefix"]}.npy', y_pred)
np.save(f'y_true_{model_configs[model_name]["save_prefix"]}.npy', y_true)
np.save(f'per_image_iou_{model_configs[model_name]["save_prefix"]}.npy', per_image_iou)