In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import TensorDataset
import torch.optim as optim
import albumentations as A
from torch.amp import GradScaler, autocast
import gc
from sklearn.metrics import confusion_matrix
import logging
from tqdm import tqdm
import time
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
from google.colab import files


# Select the compute device; the rest of the notebook also runs on CPU, just slowly.
# NOTE(review): `device` is assigned again, identically, in the model-configuration cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
    # Reported through both logging and stdout so it is visible in the notebook output.
    logging.warning("CUDA is not available. Running on CPU, which will be slow.")
    print("CUDA is not available. Running on CPU, which will be slow.")

# Colab-only interactive file picker. Upload the Kaggle API token here: the `cp`
# below expects `kaggle.json` in the current working directory.
files.upload()

# Install the Kaggle CLI, place the uploaded token in ~/.kaggle and make it
# readable by the owner only.
!pip install -q kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the thermal human-detection dataset and extract it under /content.
!kaggle datasets download -d sikdermdsaiful/thermal-images-for-human-detection -p /content
!unzip /content/thermal-images-for-human-detection.zip -d /content/thermal-images-for-human-detection

In [None]:
# Extra dependencies.
# NOTE(review): versions are unpinned, so results may change between runs. pydensecrf
# is installed twice (from git here, then again from PyPI on the next line); consider
# dropping it from the second command and pinning versions for reproducibility.
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git
!pip install opencv-python numpy matplotlib scikit-learn torch torchvision albumentations scipy pydensecrf tqdm segmentation-models-pytorch
import segmentation_models_pytorch as smp
# NOTE(review): the two CRF imports below are not used anywhere in the visible notebook.
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

In [4]:
class SegNet(nn.Module):
    """Minimal encoder-decoder segmentation network producing per-pixel logits.

    Encoder: two (3x3 conv -> ReLU -> 2x2 max-pool) stages, 64 then 128
    channels, shrinking the spatial size by 4.
    Decoder: two (3x3 conv -> ReLU -> 2x bilinear upsample) stages, then a
    3x3 conv to `out_channels` raw logits (no activation).

    Despite the name, this does not use SegNet's pooling-index unpooling.
    Input height and width should be divisible by 4 for the output to have
    the same spatial size as the input.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        down_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        ]
        self.encoder = nn.Sequential(*down_layers)

        up_layers = [
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        ]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        features = self.encoder(x)
        return self.decoder(features)


In [5]:
class HRNet(nn.Module):
    """Tiny fully convolutional baseline registered under the name "HRNet".

    Two 3x3 conv + ReLU layers (64 channels, padding 1, so the spatial size is
    preserved) followed by a 1x1 conv that emits `out_channels` raw logits.

    This is not the multi-resolution HRNet architecture. `self.up` is
    constructed but never used in `forward`; it holds no parameters.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        width = 64
        self.conv1 = nn.Conv2d(in_channels, width, 3, padding=1)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.out = nn.Conv2d(width, out_channels, 1)

    def forward(self, x):
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden))
        return self.out(hidden)

In [6]:
def preprocess_infrared(image, image_path="unknown"):
    """Enhance a grayscale thermal frame and replicate it into three channels.

    Pipeline: 3x3 median blur -> robust contrast stretch between the 2nd and
    98th percentiles, clipped to [0, 1] -> CLAHE (clip limit 1.0, 8x8 tiles)
    on an 8-bit copy -> back to float in [0, 1] -> 3x3 Gaussian blur ->
    stack into an (H, W, 3) array.

    Args:
        image: 2-D image array (the caller in this notebook passes an 8-bit
            grayscale image).
        image_path: identifier used only in the diagnostic messages.

    Returns:
        float32 array of shape (H, W, 3), or None when the raw input contains
        NaN/Inf, the percentile range is degenerate (e.g. a constant image),
        or the processed image contains NaN/Inf.
    """
    frame = image.astype(np.float32)
    raw_min, raw_max = np.min(frame), np.max(frame)
    print(f"Raw image min/max for {image_path}: {raw_min}, {raw_max}")
    if not np.all(np.isfinite(frame)):
        print(f"Warning: NaN or Inf in raw image {image_path}")
        return None

    frame = cv2.medianBlur(frame, 3)

    # Percentile bounds make the stretch robust to a few very hot/cold pixels.
    lo = np.percentile(frame, 2)
    hi = np.percentile(frame, 98)
    if hi <= lo:
        print(f"Warning: Invalid min/max values in {image_path}: min={lo}, max={hi}")
        return None
    stretched = np.clip((frame - lo) / (hi - lo + 1e-7), 0, 1)

    # CLAHE operates on 8-bit images, so round-trip through uint8.
    eight_bit = (stretched * 255).astype(np.uint8)
    equalized = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8, 8)).apply(eight_bit)
    smoothed = cv2.GaussianBlur(equalized.astype(np.float32) / 255.0, (3, 3), 0)

    if not np.all(np.isfinite(smoothed)):
        print(f"Warning: NaN or Inf in preprocessed image {image_path}")
        return None

    # Replicate the single channel so 3-channel (RGB-style) encoders can consume it.
    return np.repeat(smoothed[..., np.newaxis], 3, axis=-1).astype(np.float32)


In [7]:

def apply_imagenet_normalization(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Standardize a channels-last image per channel: ``(image - mean) / std``.

    The defaults are the ImageNet RGB statistics expected by ImageNet-pretrained
    encoders. Everything is computed in float32 with plain NumPy broadcasting;
    the input array is not modified.

    Args:
        image: array of shape (H, W, C), presumably already scaled to [0, 1].
            A single-channel (H, W, 1) input is broadcast against the
            three default statistics, giving (H, W, 3).
        mean: per-channel means (length C, or broadcastable to C).
        std: per-channel standard deviations (same length as ``mean``).

    Returns:
        New float32 array of shape (H, W, C).

    Raises:
        ValueError: if ``image`` is not 3-dimensional.
    """
    image = np.asarray(image, dtype=np.float32)
    if image.ndim != 3:
        raise ValueError(f"Expected an (H, W, C) image, got shape {image.shape}")
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)
    # Broadcasting over the trailing channel axis replaces the earlier
    # numpy -> torch -> numpy round-trip with two transposes.
    return ((image - mean_arr) / std_arr).astype(np.float32, copy=False)

def load_image_and_mask(image_path, label_path, num_classes=1, target_size=(224, 224)):
    """Load a thermal image and rasterize its YOLO-style boxes into a binary mask.

    The image is read as grayscale, resized to ``target_size`` and passed through
    ``preprocess_infrared``. Each annotation line ``class x_center y_center width height``
    (coordinates normalized by the image size) becomes a filled rectangle in the
    mask. The rectangle is enlarged 1.5x, forced to at least 30 px per side, and
    clipped to the image. Every box is counted as 'Human', whatever its class id.

    Args:
        image_path: path to the image file.
        label_path: path to the matching annotation ``.txt`` file.
        num_classes: unused; kept for call-site compatibility.
        target_size: output size as ``(height, width)``. The image and the mask
            both come out with exactly this spatial shape.

    Returns:
        ``(image, mask, class_counts)``:
        - ``image`` is the output of ``preprocess_infrared``.
        - ``mask`` is an ``(height, width)`` float32 array in {0, 1}.
        - ``class_counts`` is ``{'Human': n_boxes_drawn}``.

        ``(None, None, None)`` is returned when:
        - the image cannot be read or preprocessed;
        - the label file is missing;
        - the mask has fewer than 100 foreground pixels;
        - any exception occurs (reported and swallowed: best-effort loading).
    """
    box_scale = 1.5        # each annotated box is enlarged by this factor
    min_box_px = 30        # minimum box side length, in output pixels
    min_mask_pixels = 100  # samples with less foreground than this are discarded

    try:
        height, width = target_size

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            print(f"Warning: Failed to load {image_path}")
            return None, None, None

        # cv2.resize takes dsize as (width, height); passing target_size directly
        # would transpose non-square sizes relative to the (height, width) mask.
        image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LINEAR)
        image = preprocess_infrared(image, image_path)
        if image is None:
            print(f"Warning: Preprocessing failed for {image_path}")
            return None, None, None

        mask = np.zeros((height, width), dtype=np.float32)
        class_counts = {'Human': 0}

        if not os.path.exists(label_path):
            print(f"Warning: Annotation {label_path} does not exist")
            return None, None, None

        with open(label_path, 'r') as f:
            # Blank lines are not annotations; skip them instead of reporting
            # them as malformed.
            annotations = [line for line in f if line.strip()]

        print(f"Image {image_path}: {len(annotations)} annotations")
        for ann in annotations:
            parts = ann.split()
            if len(parts) != 5:
                print(f"Invalid annotation format in {label_path}: {ann}")
                continue

            # The class id is parsed (a non-numeric id still raises) but not used.
            _class_id, x_center, y_center, box_w, box_h = map(float, parts)

            x_center *= width
            y_center *= height
            box_w = max(box_w * width * box_scale, min_box_px)
            box_h = max(box_h * height * box_scale, min_box_px)

            x1 = int(x_center - box_w / 2)
            y1 = int(y_center - box_h / 2)
            x2 = int(x_center + box_w / 2)
            y2 = int(y_center + box_h / 2)

            x2 = max(x2, x1 + min_box_px)
            y2 = max(y2, y1 + min_box_px)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(width - 1, x2), min(height - 1, y2)

            if x1 >= x2 or y1 >= y2:
                print(f"Invalid bounding box in {image_path}: ({x1}, {y1}, {x2}, {y2})")
                continue

            # Filled rectangle (thickness -1) drawn directly with value 1.0.
            cv2.rectangle(mask, (x1, y1), (x2, y2), 1.0, -1)
            class_counts['Human'] += 1

        mask_sum = mask.sum()
        print(f"Image {image_path}: Mask pixel sum before resize: {mask_sum}")
        if mask_sum < min_mask_pixels:
            print(f"Warning: Empty or nearly empty mask for {image_path} (sum={mask_sum})")
            return None, None, None

        return image, mask, class_counts

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None, None, None

In [None]:
data_dir = "/content/thermal-images-for-human-detection/dataset"
train_image_dir = os.path.join(data_dir, "train", "images")
train_label_dir = os.path.join(data_dir, "train", "labels")
test_image_dir = os.path.join(data_dir, "test", "images")
test_label_dir = os.path.join(data_dir, "test", "labels")

for directory in [train_image_dir, train_label_dir, test_image_dir, test_label_dir]:
    if not os.path.exists(directory):
        print(f"Error: Directory {directory} does not exist")
        raise FileNotFoundError(f"Directory {directory} not found")

train_image_files = sorted([f for f in os.listdir(train_image_dir) if f.endswith('.jpg')])
train_label_files = sorted([f for f in os.listdir(train_label_dir) if f.endswith('.txt')])
test_image_files = sorted([f for f in os.listdir(test_image_dir) if f.endswith('.jpg')])
test_label_files = sorted([f for f in os.listdir(test_label_dir) if f.endswith('.txt')])

print(f"Found {len(train_image_files)} training images and {len(train_label_files)} training labels")
print(f"Found {len(test_image_files)} test images and {len(test_label_files)} test labels")

# Output (height, width) shared by images and masks; defined before first use.
target_size = (224, 224)
# Cap on raw training images, to bound memory after 4x augmentation.
max_train_images = 600

train_image_files = train_image_files[:max_train_images]
print(f"Using {len(train_image_files)} images for training")

images, masks, image_labels = [], [], []
# File name of every successfully loaded sample, index-aligned with `images`.
loaded_image_files = []
all_class_counts = {'Human': 0}
skipped_images = 0

for img_file in train_image_files:
    img_path = os.path.join(train_image_dir, img_file)
    label_path = os.path.join(train_label_dir, os.path.splitext(img_file)[0] + '.txt')

    image, mask, class_counts = load_image_and_mask(img_path, label_path, num_classes=1, target_size=target_size)
    if image is None or mask is None or class_counts is None:
        print(f"Failed to load image: {img_path}")
        skipped_images += 1
        continue

    # The loader already returns masks at target_size, so no resize is needed here.
    images.append(image)
    masks.append(mask)
    image_labels.append(class_counts)
    loaded_image_files.append(img_file)
    for class_name, count in class_counts.items():
        all_class_counts[class_name] = all_class_counts.get(class_name, 0) + count

print(f"Loaded {len(images)} samples, skipped {skipped_images}")
if not images:
    raise RuntimeError("No training samples could be loaded; check the dataset paths and annotations.")

images = np.array(images, dtype=np.float32)
masks = np.array(masks, dtype=np.float32)

avg_foreground_pixels = masks.sum(axis=(1, 2)).mean()
print(f"Average foreground pixels per mask: {avg_foreground_pixels:.2f}")
print(f"Overall class distribution: {all_class_counts}")

# Raw vs. preprocessed image and mask for the first few *loaded* samples. Index
# loaded_image_files (not train_image_files) so the raw image still matches its
# processed counterpart when earlier files were skipped.
plt.figure(figsize=(15, 10))
for i in range(min(3, len(images))):
    raw_image = cv2.imread(os.path.join(train_image_dir, loaded_image_files[i]), cv2.IMREAD_GRAYSCALE)
    raw_image = cv2.resize(raw_image, (target_size[1], target_size[0]), interpolation=cv2.INTER_LINEAR)
    plt.subplot(3, 3, i*3 + 1)
    plt.imshow(raw_image, cmap='gray')
    plt.title(f'Raw Image {i+1}')
    plt.axis('off')
    plt.subplot(3, 3, i*3 + 2)
    plt.imshow(images[i][:, :, 0], cmap='gray')
    plt.title(f'Processed Image {i+1}')
    plt.axis('off')
    plt.subplot(3, 3, i*3 + 3)
    plt.imshow(masks[i], cmap='gray')
    plt.title(f'Mask (Sample {i+1})')
    plt.axis('off')
plt.show()


In [None]:
def get_augmenter(height=224, width=224):
    """Build the training-time augmentation pipeline (applied jointly to image and mask).

    Photometric jitter (brightness/contrast, gamma) plus geometric transforms
    (flips, rotation, random scale, affine). The final resize restores the
    (height, width) output size that RandomScale changes.

    Rotation pads with a constant border instead of relying on the library
    default. In albumentations 1.x that default is a reflected border, which
    mirrors person boxes near the edge into the mask and creates foreground
    where no person is.

    Args:
        height: output height in pixels.
        width: output width in pixels.

    Returns:
        An ``A.Compose`` callable: ``augmenter(image=..., mask=...)``.
    """
    # 'mask' is already a built-in target of A.Compose, so no additional_targets are needed.
    return A.Compose([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=60, border_mode=cv2.BORDER_CONSTANT, p=0.8),
        A.RandomScale(scale_limit=0.3, p=0.6),
        A.Affine(translate_percent=0.2, scale=(0.8, 1.2), rotate=(-45, 45), p=0.6),
        A.RandomGamma(gamma_limit=(90, 110), p=0.5),
        A.Resize(height=height, width=width, p=1.0),
    ])

augmenter = get_augmenter()
augmentation_factor = 3  # random augmented copies added per *training* image
categories = ['Human']


def get_multi_class_labels(labels, categories):
    """Convert per-image class-count dicts into a multi-hot presence matrix.

    Args:
        labels: iterable of dicts mapping class name -> box count.
        categories: ordered class names; defines the column order.

    Returns:
        int array of shape (len(labels), len(categories)); an entry is 1 when
        that class has at least one box in the image, else 0.
    """
    return np.array([
        [1 if label_counts.get(cat, 0) > 0 else 0 for cat in categories]
        for label_counts in labels
    ])


def _normalize_and_augment(split_images, split_masks, split_label_counts, n_augmented):
    """ImageNet-normalize every sample and append `n_augmented` random augmentations of it.

    Augmentation runs on the [0, 1] image before normalization. Augmentations
    with an unexpected shape, NaN/Inf values, or an exception are reported and
    skipped.

    Returns:
        (images float32 (N, H, W, 3), masks float32 (N, H, W), list of class-count dicts)
    """
    expected_mask_shape = tuple(target_size)
    expected_image_shape = (*expected_mask_shape, 3)
    out_images, out_masks, out_labels = [], [], []
    for img_idx, (img, mask, label_counts) in enumerate(zip(split_images, split_masks, split_label_counts)):
        out_images.append(apply_imagenet_normalization(img))
        out_masks.append(mask)
        out_labels.append(label_counts)
        for aug_iter in range(n_augmented):
            try:
                augmented = augmenter(image=np.clip(img, 0, 1), mask=np.clip(mask, 0, 1))
                aug_img, aug_mask = augmented['image'], augmented['mask']
                if aug_img.shape != expected_image_shape or aug_mask.shape != expected_mask_shape:
                    print(f"Invalid shape after augmentation for image {img_idx}: image={aug_img.shape}, mask={aug_mask.shape}")
                    continue
                if not (np.all(np.isfinite(aug_img)) and np.all(np.isfinite(aug_mask))):
                    print(f"Warning: NaN or Inf in augmented data for image {img_idx}, iteration {aug_iter}")
                    continue
                out_images.append(apply_imagenet_normalization(np.clip(aug_img, 0, 1)))
                out_masks.append(aug_mask)
                # NOTE: the source image's counts are reused; geometric transforms
                # may have moved some boxes out of frame.
                out_labels.append(label_counts)
            except Exception as e:
                print(f"Augmentation failed for image {img_idx}, iteration {aug_iter}: {e}")
    return np.array(out_images, dtype=np.float32), np.array(out_masks, dtype=np.float32), out_labels


# Split the ORIGINAL images before augmenting. Augmenting first put augmented
# copies of the same photo on both sides of the split, so validation IoU
# rewarded memorization. Validation now holds only un-augmented originals.
image_presence = get_multi_class_labels(image_labels, categories)
train_idx, val_idx = train_test_split(
    np.arange(len(images)),
    test_size=0.2,
    random_state=42,
    stratify=image_presence,
)
assert not set(train_idx.tolist()) & set(val_idx.tolist()), "train/val overlap"

X_train, train_masks_hw, train_label_counts = _normalize_and_augment(
    images[train_idx], masks[train_idx], [image_labels[i] for i in train_idx], augmentation_factor)
X_val, val_masks_hw, val_label_counts = _normalize_and_augment(
    images[val_idx], masks[val_idx], [image_labels[i] for i in val_idx], 0)

# Trailing channel axis: masks are (N, H, W, 1), the layout used by the training cell.
y_train = np.expand_dims(train_masks_hw, axis=-1)
y_val = np.expand_dims(val_masks_hw, axis=-1)
train_labels = get_multi_class_labels(train_label_counts, categories)
val_labels = get_multi_class_labels(val_label_counts, categories)

gc.collect()
torch.cuda.empty_cache()

print(f"Train images min/max: {np.min(X_train)}, {np.max(X_train)}")
print(f"Train masks min/max: {np.min(y_train)}, {np.max(y_train)}")
print(f"Val images min/max: {np.min(X_val)}, {np.max(X_val)}")
print(f"Val masks min/max: {np.min(y_val)}, {np.max(y_val)}")
# Summaries instead of one number per sample, which flooded the output.
train_mask_sums = y_train.sum(axis=(1, 2, 3))
val_mask_sums = y_val.sum(axis=(1, 2, 3))
print(f"Train mask pixel sums min/mean/max: {train_mask_sums.min():.0f} / {train_mask_sums.mean():.1f} / {train_mask_sums.max():.0f}")
print(f"Val mask pixel sums min/mean/max: {val_mask_sums.min():.0f} / {val_mask_sums.mean():.1f} / {val_mask_sums.max():.0f}")

val_class_presence = {cat: 0 for cat in categories}
for labels in val_labels:
    for idx, present in enumerate(labels):
        if present == 1:
            val_class_presence[categories[idx]] += 1
print(f"Validation set class presence: {val_class_presence}")

print(f"Train dataset size: {len(X_train)} (from {len(train_idx)} originals), Val dataset size: {len(X_val)}")
print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}")
print(f"y_val shape: {y_val.shape}")
print(f"Class pixel counts: {y_train.sum()}")

assert X_train.shape[1:] == (*target_size, 3), "Invalid image shape in X_train"
assert y_train.shape[1:] == (*target_size, 1), "Invalid mask shape in y_train"


def _show_pairs(split_images, split_masks, split_name, n=2):
    """Plot channel 0 of the first `n` images of a split next to their masks."""
    plt.figure(figsize=(15, 10))
    for i in range(min(n, len(split_images))):
        plt.subplot(n, 2, i * 2 + 1)
        plt.imshow(split_images[i][:, :, 0], cmap='gray')
        plt.title(f'{split_name} Image {i+1}')
        plt.axis('off')
        plt.subplot(n, 2, i * 2 + 2)
        plt.imshow(split_masks[i].squeeze(), cmap='gray')
        plt.title(f'{split_name} Mask')
        plt.axis('off')
    plt.show()


_show_pairs(X_train, y_train, 'Train')
_show_pairs(X_val, y_val, 'Val')

In [13]:
def dice_coefficient(y_true, y_pred, smooth=1e-7):
    """Soft Dice coefficient, averaged over samples and channels.

    Args:
        y_true: target tensor of shape (N, C, ...), binary or soft.
        y_pred: prediction tensor of the same shape (e.g. sigmoid probabilities,
            or a 0/1 thresholded mask).
        smooth: additive constant that keeps the ratio defined for empty masks;
            two empty masks score 1.

    Returns:
        0-dim tensor: the mean over (N, C) of (2|A∩B| + smooth) / (|A| + |B| + smooth).
    """
    # reshape (not view) so non-contiguous inputs, e.g. after permute(), also work.
    y_true = y_true.reshape(y_true.size(0), y_true.size(1), -1)
    y_pred = y_pred.reshape(y_pred.size(0), y_pred.size(1), -1)
    intersection = (y_true * y_pred).sum(dim=2)
    dice = (2. * intersection + smooth) / (y_true.sum(dim=2) + y_pred.sum(dim=2) + smooth)
    return dice.mean()

def jaccard_index(y_true, y_pred, smooth=1e-7):
    """Soft Jaccard index (IoU), averaged over samples and channels.

    Args:
        y_true: target tensor of shape (N, C, ...), binary or soft.
        y_pred: prediction tensor of the same shape (e.g. sigmoid probabilities,
            or a 0/1 thresholded mask).
        smooth: additive constant that keeps the ratio defined for empty masks;
            two empty masks score 1.

    Returns:
        0-dim tensor: the mean over (N, C) of (|A∩B| + smooth) / (|A∪B| + smooth),
        with |A∪B| = |A| + |B| - |A∩B|.
    """
    # Each tensor is flattened with its OWN (N, C) sizes, as in dice_coefficient
    # (previously the sizes were mixed between y_true and y_pred). reshape also
    # accepts non-contiguous inputs, unlike view.
    y_true = y_true.reshape(y_true.size(0), y_true.size(1), -1)
    y_pred = y_pred.reshape(y_pred.size(0), y_pred.size(1), -1)
    intersection = (y_true * y_pred).sum(dim=2)
    union = y_true.sum(dim=2) + y_pred.sum(dim=2) - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean()

def pixel_accuracy(y_true, y_pred):
    """Fraction of elements where the thresholded target and prediction agree.

    Both tensors are binarized at 0.5 (strictly greater counts as foreground)
    before they are compared.

    Returns:
        0-dim float tensor in [0, 1].
    """
    agreement = (y_true > 0.5) == (y_pred > 0.5)
    return agreement.float().mean()

class DiceLoss(nn.Module):
    """Soft Dice loss on raw logits, with symmetric label smoothing.

    Targets are mapped from {0, 1} to {label_smoothing, 1 - label_smoothing}.
    Logits are clamped to [-1e5, 1e5] and passed through a sigmoid. The loss is
    1 minus the mean Dice score over (batch, channel). A batch whose target sum
    is exactly 0 returns a constant 0 tensor that is not connected to the
    model's graph.
    """

    def __init__(self, smooth=1e-7, label_smoothing=0.05):
        super().__init__()
        self.smooth = smooth
        self.label_smoothing = label_smoothing

    def forward(self, y_pred, y_true):
        if y_true.sum() == 0:
            return torch.tensor(0.0, device=y_pred.device, requires_grad=True)

        eps = self.label_smoothing
        probs = torch.sigmoid(torch.clamp(y_pred, -1e5, 1e5))
        target = y_true * (1 - 2 * eps) + eps

        n_batch, n_chan = target.size(0), target.size(1)
        target = target.view(n_batch, n_chan, -1)
        probs = probs.view(probs.size(0), probs.size(1), -1)

        overlap = (target * probs).sum(dim=2)
        denominator = target.sum(dim=2) + probs.sum(dim=2) + self.smooth
        score = (2. * overlap + self.smooth) / denominator
        return 1 - score.mean()

class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, with symmetric label smoothing.

    The per-element loss is ``alpha * (1 - pt) ** gamma * BCE``:
    - BCE is the binary cross-entropy with logits against the smoothed target
      (``{0, 1}`` -> ``{label_smoothing, 1 - label_smoothing}``);
    - ``pt = exp(-BCE)``.

    The per-element losses are averaged over all elements. A batch whose target
    sum is exactly 0 returns a constant 0 tensor that is not connected to the
    model's graph.
    """

    def __init__(self, alpha=1.0, gamma=2.0, label_smoothing=0.05):
        super().__init__()
        self.alpha, self.gamma, self.label_smoothing = alpha, gamma, label_smoothing

    def forward(self, y_pred, y_true):
        if y_true.sum() == 0:
            return torch.tensor(0.0, device=y_pred.device, requires_grad=True)

        eps = self.label_smoothing
        logits = y_pred.clamp(-1e5, 1e5)
        soft_target = y_true * (1 - 2 * eps) + eps
        per_elem_bce = F.binary_cross_entropy_with_logits(logits, soft_target, reduction='none')
        modulator = (1 - torch.exp(-per_elem_bce)) ** self.gamma
        return (self.alpha * modulator * per_elem_bce).mean()


In [None]:
model_name = "Unet"


def _smp_model_entry(arch, save_prefix):
    """Registry entry for a segmentation_models_pytorch architecture (ResNet-50 encoder, 3 in, 1 out)."""
    return {
        "class": arch,
        "save_prefix": save_prefix,
        "params": {"encoder_name": "resnet50", "in_channels": 3, "classes": 1},
    }


def _local_model_entry(arch, save_prefix):
    """Registry entry for a model class defined in this notebook (3 in, 1 out)."""
    return {
        "class": arch,
        "save_prefix": save_prefix,
        "params": {"in_channels": 3, "out_channels": 1},
    }


# Selectable architectures:
# - "params" are passed as keyword arguments to the constructor;
# - "save_prefix" begins the checkpoint file name written during training.
model_configs = {
    "DeepLabV3Plus": _smp_model_entry(smp.DeepLabV3Plus, "deeplabv3plus_hit_uav"),
    "Unet": _smp_model_entry(smp.Unet, "unet_hit_uav"),
    "PSPNet": _smp_model_entry(smp.PSPNet, "pspnet_hit_uav"),
    "SegNet": _local_model_entry(SegNet, "segnet_hit_uav"),
    "HRNet": _local_model_entry(HRNet, "hrnet_hit_uav"),
}

if model_name not in model_configs:
    supported = list(model_configs.keys())
    raise ValueError(f"Model {model_name} is not supported. Choose from {supported}")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


In [None]:

torch.manual_seed(42)  # reproducible weight initialization and DataLoader shuffling

model_config = model_configs[model_name]
model_class = model_config["class"]
model_params = model_config["params"]
model = model_class(**model_params).to(device)


def init_weights(m):
    """Kaiming-normal (fan_out, ReLU) init for Conv2d weights; BatchNorm2d weight=1, bias=0.

    Conv biases keep their constructor defaults.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


# segmentation_models_pytorch loads ImageNet encoder weights unless encoder_weights=None.
# Applying init_weights to the whole model overwrote those pretrained weights, which also
# made the ImageNet input normalization pointless. Only the from-scratch models are
# re-initialized now.
uses_pretrained_encoder = (
    "encoder_name" in model_params
    and model_params.get("encoder_weights", "imagenet") is not None
)
if uses_pretrained_encoder:
    print(f"Keeping pretrained weights for the {model_params['encoder_name']} encoder")
else:
    model.apply(init_weights)

X_train_torch = torch.from_numpy(X_train).permute(0, 3, 1, 2).float().to(device)
y_train_torch = torch.from_numpy(y_train).permute(0, 3, 1, 2).float().to(device)
X_val_torch = torch.from_numpy(X_val).permute(0, 3, 1, 2).float().to(device)
y_val_torch = torch.from_numpy(y_val).permute(0, 3, 1, 2).float().to(device)

train_dataset = TensorDataset(X_train_torch, y_train_torch)
val_dataset = TensorDataset(X_val_torch, y_val_torch)

batch_size = 2
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
print(f"Dataset size: {len(train_dataset)}, Batches: {len(train_loader)}")

# Foreground fraction of the TRAINING masks (values in {0, 1}, so the mean is the
# fraction). The previous formula divided the pixel sum by size * H * W, but `.size`
# already counts every element, so the fraction was ~224x too small and pos_weight
# was always clamped to 50. Validation masks are excluded so that held-out data
# does not shape the loss.
foreground_pixels = float(y_train.mean())
background_pixels = 1 - foreground_pixels
pos_weight = torch.tensor([min(background_pixels / (foreground_pixels + 1e-7), 50.0)]).to(device)
print(f"Foreground fraction (train): {foreground_pixels:.4f}")
print(f"Pos weight for BCE: {pos_weight.item()}")

bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
dice_loss = DiceLoss(smooth=1e-7, label_smoothing=0.05)
focal_loss = FocalLoss(alpha=1.0, gamma=2.0, label_smoothing=0.05)


def combined_loss(y_pred, y_true, bce_weight=0.1, dice_weight=0.5, phase="train"):
    """Weighted sum of the BCE (with pos_weight), Dice and Focal losses on raw logits.

    The focal term gets the remaining weight, 1 - bce_weight - dice_weight.
    Returns a constant 0 tensor, disconnected from the graph and so contributing
    no gradient, when the target batch is empty or any component is NaN/Inf.
    `phase` only labels the warning messages.
    """
    if y_true.sum() == 0:
        print(f"Warning: Empty mask in {phase} batch, skipping loss")
        return torch.tensor(0.0, device=y_pred.device, requires_grad=True)
    bce = bce_loss(y_pred, y_true)
    dice = dice_loss(y_pred, y_true)
    focal = focal_loss(y_pred, y_true)
    if torch.isnan(bce) or torch.isinf(bce) or torch.isnan(dice) or torch.isinf(dice) or torch.isnan(focal) or torch.isinf(focal):
        print(f"Warning: Invalid loss components in {phase} - BCE: {bce.item()}, Dice: {dice.item()}, Focal: {focal.item()}")
        return torch.tensor(0.0, device=y_pred.device, requires_grad=True)
    total_loss = bce_weight * bce + dice_weight * dice + (1 - bce_weight - dice_weight) * focal
    return total_loss


In [None]:

optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7)

num_epochs = 50
warmup_epochs = 5
base_lr = 3e-4
warmup_lr = 3e-5
best_val_iou = 0.0
best_state = None  # in-memory copy of the best weights from THIS run
patience = 15
epochs_no_improve = 0
history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': []}
best_model_path = f"{model_config['save_prefix']}_thermal_images_human_detection.pth"

for epoch in range(num_epochs):
    model.train()
    train_loss, train_iou, train_samples = 0.0, 0.0, 0

    if epoch < warmup_epochs:
        # Linear ramp that reaches base_lr on the last warmup epoch. Using epoch
        # instead of epoch + 1 stopped at ~82% of base_lr, which training then
        # kept because nothing restored the LR afterwards.
        lr = warmup_lr + (base_lr - warmup_lr) * ((epoch + 1) / warmup_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # batch_* names keep the loop from overwriting the notebook-level `images`/`masks`.
    for batch_idx, (batch_images, batch_masks) in enumerate(train_loader):
        batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

        optimizer.zero_grad()
        outputs = model(batch_images)
        if isinstance(outputs, dict):
            outputs = outputs['out']
        if epoch == 0 and batch_idx == 0:
            print(f"Outputs shape: {outputs.shape}, Masks shape: {batch_masks.shape}")

        loss = combined_loss(outputs, batch_masks, phase="train")
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Invalid loss at epoch {epoch+1}, batch {batch_idx}: {loss.item()}")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        n_batch = batch_images.size(0)
        train_loss += loss.item() * n_batch
        with torch.no_grad():  # metric only; no need to extend the autograd graph
            train_iou += jaccard_index(batch_masks, torch.sigmoid(outputs)).item() * n_batch
        train_samples += n_batch

    # Average over the samples actually used: drop_last and skipped batches can
    # make this fewer than len(train_loader.dataset).
    train_loss /= max(train_samples, 1)
    train_iou /= max(train_samples, 1)

    model.eval()
    val_loss, val_iou = 0.0, 0.0
    with torch.no_grad():
        for batch_images, batch_masks in val_loader:
            batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)
            outputs = model(batch_images)
            if isinstance(outputs, dict):
                outputs = outputs['out']
            loss = combined_loss(outputs, batch_masks, phase="val")

            val_loss += loss.item() * batch_images.size(0)
            preds = torch.sigmoid(outputs)
            val_iou += jaccard_index(batch_masks, preds).item() * batch_images.size(0)

    val_loss /= len(val_loader.dataset)
    val_iou /= len(val_loader.dataset)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_iou'].append(train_iou)
    history['val_iou'].append(val_iou)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
          f"Train IoU: {train_iou:.4f}, Val IoU: {val_iou:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

    # During warmup the ramp sets the LR explicitly. Stepping the plateau scheduler
    # then only accumulates plateau state and applies reductions the ramp overwrites.
    if epoch >= warmup_epochs:
        scheduler.step(val_loss)

    if val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save(model.state_dict(), best_model_path)
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        print(f"New best model saved at epoch {epoch+1} with Val IoU: {val_iou:.4f}")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break

# Evaluate the best weights from this run, not whatever the last epoch produced.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best model (Val IoU: {best_val_iou:.4f})")

In [None]:
def postprocess(preds):
    """Convert raw logits to sigmoid probabilities as a NumPy array on the host.

    Args:
        preds: tensor of logits, on any device.

    Returns:
        numpy array with the same shape as `preds`, values in [0, 1].
    """
    return preds.sigmoid().cpu().numpy()

model.eval()
y_pred = []
with torch.no_grad():
    # `batch_images` avoids overwriting the notebook-level `images` array.
    for batch_images, _ in val_loader:
        outputs = model(batch_images.to(device))
        if isinstance(outputs, dict):
            outputs = outputs['out']
        y_pred.append(outputs.cpu())
y_pred = torch.cat(y_pred, dim=0)
y_pred = postprocess(y_pred)  # sigmoid probabilities, same order as y_val (val_loader is not shuffled)

y_pred_torch = torch.from_numpy(y_pred).float().to(device)
y_val_torch = y_val_torch.to(device)
y_pred_binary = (y_pred_torch > 0.5).float()

# Soft scores use the probabilities directly. The thresholded scores binarize
# at 0.5, which matches the predicted masks shown below.
val_iou = jaccard_index(y_val_torch, y_pred_torch)
val_dice = dice_coefficient(y_val_torch, y_pred_torch)
val_iou_binary = jaccard_index(y_val_torch, y_pred_binary)
val_dice_binary = dice_coefficient(y_val_torch, y_pred_binary)
val_accuracy = pixel_accuracy(y_val_torch, y_pred_torch)

print("\nValidation Metrics:")
print(f"Jaccard Index (soft): {val_iou.item():.4f}")
print(f"Jaccard Index (threshold 0.5): {val_iou_binary.item():.4f}")
print(f"Dice Coefficient (soft): {val_dice.item():.4f}")
print(f"Dice Coefficient (threshold 0.5): {val_dice_binary.item():.4f}")
print(f"Pixel Accuracy: {val_accuracy.item():.4f}")

n_show = min(3, len(X_val))
plt.figure(figsize=(15, 10))
for i in range(n_show):
    plt.subplot(3, 3, i * 3 + 1)
    plt.imshow(X_val[i][:, :, 0], cmap='gray')
    plt.title('Original')
    plt.axis('off')
    plt.subplot(3, 3, i * 3 + 2)
    plt.imshow(y_val[i].squeeze(), cmap='gray')
    plt.title('Ground Truth')
    plt.axis('off')
    plt.subplot(3, 3, i * 3 + 3)
    plt.imshow(y_pred[i].squeeze(), cmap='gray')
    plt.title('Predicted Mask')
    plt.axis('off')
plt.show()


plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss Value')
plt.legend()
plt.grid(True)
plt.ylim(bottom=0)  # lower bound only; a fixed upper bound could clip early epochs
plt.subplot(1, 2, 2)
plt.plot(history['train_iou'], label='Train IoU')
plt.plot(history['val_iou'], label='Val IoU')
plt.title('IoU')
plt.xlabel('Epochs')
plt.ylabel('IoU Value')
plt.legend()
plt.grid(True)
plt.ylim(0, 1)
plt.show()