In [None]:
import gc
import logging
import os
import random
import time

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google.colab import files
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
    logging.warning("CUDA is not available. Running on CPU, which will be slow.")
    print("CUDA is not available. Running on CPU, which will be slow.")

files.upload()

!pip install -q kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

!kaggle datasets download -d pandrii000/hituav-a-highaltitude-infrared-thermal-dataset -p /content
!unzip /content/hituav-a-highaltitude-infrared-thermal-dataset.zip -d /content/hituav-a-highaltitude-infrared-thermal-dataset

In [None]:
# Extra dependencies: pydensecrf (built from source; imported below for CRF
# post-processing) and segmentation-models-pytorch (DeepLabV3+/U-Net/PSPNet models).
# NOTE(review): versions are unpinned, and reinstalling torch/torchvision inside a running
# Colab session may require a runtime restart — consider pinned `%pip install pkg==x.y`.
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git
!pip install opencv-python numpy matplotlib scikit-learn torch torchvision albumentations scipy pydensecrf tqdm segmentation-models-pytorch
import segmentation_models_pytorch as smp
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

In [4]:
class SegNet(nn.Module):
    """SegNet-style encoder/decoder producing per-pixel class logits.

    Three conv blocks each end in a 2x2 max-pool that also returns the argmax
    indices; the decoder restores resolution with ``MaxUnpool2d`` using those indices.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output channels (one logit map per class).

    ``forward`` maps (N, in_channels, H, W) to (N, num_classes, H, W). H and W must be
    at least 8 but need not be divisible by 8: the pre-pool feature sizes are passed to
    ``MaxUnpool2d`` as ``output_size`` so odd intermediate sizes are restored exactly.
    """

    def __init__(self, in_channels=1, num_classes=5):
        super(SegNet, self).__init__()
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(2, 2, return_indices=True)

        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(2, 2, return_indices=True)

        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.pool3 = nn.MaxPool2d(2, 2, return_indices=True)

        # Decoder
        self.unpool3 = nn.MaxUnpool2d(2, 2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.unpool2 = nn.MaxUnpool2d(2, 2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.unpool1 = nn.MaxUnpool2d(2, 2)
        # Final conv has no BN/ReLU: it emits raw class logits.
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=3, padding=1)
        )

    def forward(self, x):
        # Remember each feature size *before* pooling: without it MaxUnpool2d infers
        # 2 * pooled_size, which is one pixel short whenever the pre-pool size is odd,
        # and the stored indices then point outside the output.
        x = self.enc1(x)
        size1 = x.size()
        x, idx1 = self.pool1(x)
        x = self.enc2(x)
        size2 = x.size()
        x, idx2 = self.pool2(x)
        x = self.enc3(x)
        size3 = x.size()
        x, idx3 = self.pool3(x)

        x = self.dec3(self.unpool3(x, idx3, output_size=size3))
        x = self.dec2(self.unpool2(x, idx2, output_size=size2))
        x = self.dec1(self.unpool1(x, idx1, output_size=size1))
        return x

In [5]:

class HRNet(nn.Module):
    """Lightweight two-conv segmentation network producing per-pixel class logits.

    NOTE(review): despite the name this is not an HRNet — there are no parallel
    multi-resolution branches, just conv(3x3)+ReLU twice and a 1x1 classifier, all at
    the input resolution.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output logit channels.
    """

    def __init__(self, in_channels=1, num_classes=5):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        # Not used by forward(); kept so the attribute still exists on instances.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.out = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        features = x
        for layer in (self.conv1, self.conv2):
            features = F.relu(layer(features))
        return self.out(features)

In [6]:
def denormalize_image(image, mean, std):
    """Undo ``(x - mean) / std`` normalization and clip the result to [0, 1].

    The input array is not modified; `mean` and `std` broadcast against `image`.
    """
    restored = image.copy() * std + mean
    return np.clip(restored, 0, 1)

def preprocess_infrared(image, image_path="unknown"):
    """Contrast-enhance a single-channel 8-bit infrared frame.

    Pipeline: scale to [0, 1] -> global histogram equalization -> x1.5 gain (clipped)
    -> CLAHE (clipLimit 2.0, 8x8 tiles) -> x1.2 gain (clipped).

    Args:
        image: 2-D or (H, W, 1) array of 0..255 intensities.
        image_path: used only in log messages.

    Returns:
        (H, W, 1) float32 array in [0, 1], or None if the input does not have exactly
        one channel (or, defensively, if the result contains NaN/Inf).
    """
    scaled = image.astype(np.float32) / 255.0
    if scaled.ndim == 2:
        scaled = scaled[..., np.newaxis]
    channel_count = scaled.shape[-1]
    if channel_count != 1:
        logging.warning(f"Expected 1 channel at {image_path}, got {channel_count}")
        return None

    def _as_uint8(arr):
        # OpenCV's equalizeHist/CLAHE operate on 8-bit single-channel images.
        return (arr.squeeze() * 255).astype(np.uint8)

    equalized = cv2.equalizeHist(_as_uint8(scaled)).astype(np.float32) / 255.0
    boosted = np.clip(equalized * 1.5, 0, 1)

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(_as_uint8(boosted)).astype(np.float32) / 255.0
    result = np.clip(enhanced * 1.2, 0, 1)

    if np.isnan(result).any() or np.isinf(result).any():
        logging.error(f"NaN/Inf in {image_path}")
        return None
    return result[..., np.newaxis]

def apply_infrared_normalization(image, mean, std):
    """Convert an (H, W, C) array to a normalized (C, H, W) float32 tensor.

    Each channel is standardized as ``(x - mean) / (std + 1e-7)`` and clamped to
    [-5, 5]. Returns None (after logging an error) if the result contains NaN or Inf.
    """
    chw = torch.from_numpy(image.transpose(2, 0, 1)).float()
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
    normalized = torch.clamp((chw - mean_t) / (std_t + 1e-7), -5, 5)
    if not torch.isfinite(normalized).all():
        logging.error("NaN/Inf in normalized image")
        return None
    return normalized

def load_image_and_mask(image_path, label_path, classes, target_size=(256, 256)):
    """Load an infrared image and rasterize its YOLO box labels into a class mask.

    Args:
        image_path: image file, read as single-channel grayscale.
        label_path: YOLO-format label file; each non-blank line is
            ``class_id x_center y_center width height`` with coordinates normalized to [0, 1].
        classes: foreground class names; ids outside ``0..len(classes)-1`` are skipped
            with a warning.
        target_size: ``(width, height)`` as passed to ``cv2.resize``.

    Returns:
        ``(image, mask)``: image is the output of ``preprocess_infrared`` and mask is an
        (height, width) int64 array holding ``class_id + 1`` inside each box and 0
        elsewhere (later boxes overwrite earlier ones). Returns ``(None, None)`` when the
        image is unreadable or smaller than 128 px on a side, preprocessing fails, the sum
        of mask values is below 5 (recorded in empty_masks.log), or any error occurs.
    """
    try:
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None or image.shape[0] < 128 or image.shape[1] < 128:
            return None, None
        image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
        image = preprocess_infrared(image, image_path)
        if image is None:
            return None, None

        # cv2.resize takes (width, height), so the resized image is (height, width);
        # the mask must use the same row/column order.
        img_width, img_height = target_size
        mask = np.zeros((img_height, img_width), dtype=np.int64)

        with open(label_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Blank lines carry no object; previously they raised IndexError and
                    # the whole sample was discarded by the except below.
                    continue
                if len(fields) < 5:
                    logging.warning(f"Malformed line {line_no} in {label_path}: {line.strip()!r}")
                    continue
                values = list(map(float, fields[:5]))
                class_id = int(values[0])
                if class_id < 0 or class_id >= len(classes):
                    logging.warning(f"Invalid class_id {class_id} in {label_path}")
                    continue
                x_center, y_center, width, height = values[1:5]
                x_min = int((x_center - width / 2) * img_width)
                x_max = int((x_center + width / 2) * img_width)
                y_min = int((y_center - height / 2) * img_height)
                y_max = int((y_center + height / 2) * img_height)
                x_min, x_max = max(0, x_min), min(img_width, x_max)
                y_min, y_max = max(0, y_min), min(img_height, y_max)
                # +1 reserves label 0 for background.
                mask[y_min:y_max, x_min:x_max] = class_id + 1

        # Threshold is on the sum of label values, not on a pixel count.
        if mask.sum() < 5:
            with open("empty_masks.log", "a") as f:
                f.write(f"Empty mask for {image_path} (sum={mask.sum()})\n")
            return None, None
        return image, mask
    except Exception as e:
        logging.error(f"Error in {image_path}: {str(e)}")
        return None, None

In [7]:
def _plot_preprocessed_samples(images, masks, classes):
    """Plot each preview image (top row) above its ground-truth mask (bottom row)."""
    if not images:
        return
    class_names = ['Background'] + list(classes)
    colors = ['#000000', '#FF0000', '#00FF00', '#0000FF', '#FFFF00'][:len(class_names)]
    cmap = plt.cm.colors.ListedColormap(colors)
    # One colour band per integer label: label k covers [k, k + 1).
    norm = plt.cm.colors.BoundaryNorm(range(len(class_names) + 1), cmap.N)
    # Put each tick in the middle of its band so the label lines up with the colour.
    tick_positions = np.arange(len(class_names)) + 0.5

    n_cols = len(images)
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 5), squeeze=False)
    for col, (image, mask) in enumerate(zip(images, masks)):
        ax_img, ax_mask = axes[0, col], axes[1, col]
        ax_img.imshow(image, cmap='gray')
        ax_img.set_title('Raw Infrared Image')
        ax_img.axis('off')

        im = ax_mask.imshow(mask, cmap=cmap, norm=norm, alpha=0.8)
        ax_mask.set_title('Ground Truth Mask')
        ax_mask.axis('off')
        cbar = fig.colorbar(im, ax=ax_mask, ticks=tick_positions, shrink=0.5)
        cbar.ax.set_yticklabels(class_names)
        cbar.set_label('Classes')

    fig.tight_layout()
    plt.show()


def preprocess_and_save(image_paths, label_paths, mean, std, save_dir, classes, clear_existing=True):
    """Preprocess image/label pairs, cache them as .npy files and preview a few.

    For every pair accepted by ``load_image_and_mask``, the normalized image tensor is
    saved as ``image_{i}.npy`` and its mask as ``mask_{i}.npy`` in `save_dir`, where
    ``i`` is the pair's position in the input lists.

    Args:
        image_paths, label_paths: parallel lists of image files and YOLO label files.
        mean, std: channel statistics passed to ``apply_infrared_normalization``.
        save_dir: output directory (created if missing).
        classes: foreground class names; mask value k + 1 is labelled ``classes[k]``.
        clear_existing: first delete ``image_*.npy`` / ``mask_*.npy`` already in
            `save_dir`, so a rerun over fewer inputs does not leave stale samples behind
            for the dataset (which reads every such file) to pick up.

    Returns:
        List of input indices that were successfully preprocessed and saved.
    """
    os.makedirs(save_dir, exist_ok=True)
    if clear_existing:
        for name in os.listdir(save_dir):
            if name.endswith(".npy") and name.startswith(("image_", "mask_")):
                os.remove(os.path.join(save_dir, name))

    valid_indices = []
    images_to_display = []
    masks_to_display = []
    display_limit = 3

    for i, (img_path, lbl_path) in enumerate(tqdm(zip(image_paths, label_paths), total=len(image_paths))):
        image, mask = load_image_and_mask(img_path, lbl_path, classes)
        if image is None or mask is None:
            continue

        normalized = apply_infrared_normalization(image, mean, std)
        if normalized is None:
            continue
        np.save(os.path.join(save_dir, f"image_{i}.npy"), normalized.numpy(), allow_pickle=False)
        np.save(os.path.join(save_dir, f"mask_{i}.npy"), mask.astype(np.int64, copy=False), allow_pickle=False)
        valid_indices.append(i)

        # Preview only samples that were actually saved.
        if len(images_to_display) < display_limit:
            images_to_display.append(image.squeeze())
            masks_to_display.append(mask)

    _plot_preprocessed_samples(images_to_display, masks_to_display, classes)
    return valid_indices


In [8]:
class HITUAVDataset(Dataset):
    """Dataset over the ``image_<k>.npy`` / ``mask_<k>.npy`` pairs in a directory.

    Images are loaded as float32 tensors as stored (already normalized when produced by
    ``preprocess_and_save``); masks as int64 class maps. Files are paired by their shared
    ``<k>`` suffix; an image without a matching mask is skipped with a warning.

    When `augmenter` is given, the first image channel is denormalized with
    `mean`/`std`, augmented jointly with the mask, and renormalized. If the augmented
    mask sums to less than 5, the unaugmented sample is returned instead.
    """

    def __init__(self, preprocessed_dir, mean, std, augmenter=None):
        names = set(os.listdir(preprocessed_dir))
        image_files, mask_files = [], []
        for image_name in sorted(n for n in names if n.startswith("image_")):
            # Pair by suffix, not by sorted position: with two independent sorted lists a
            # single missing file would shift every later image onto the wrong mask.
            mask_name = "mask_" + image_name[len("image_"):]
            if mask_name in names:
                image_files.append(image_name)
                mask_files.append(mask_name)
            else:
                logging.warning(f"No mask for {image_name} in {preprocessed_dir}; skipping it")
        self.image_files = image_files
        self.mask_files = mask_files
        self.preprocessed_dir = preprocessed_dir
        self.mean = mean
        self.std = std
        self.augmenter = augmenter

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as (float32, int64) tensors.

        With an augmenter, `image` may be None if renormalization fails.
        """
        image = np.load(os.path.join(self.preprocessed_dir, self.image_files[idx]), mmap_mode='r')
        mask = np.load(os.path.join(self.preprocessed_dir, self.mask_files[idx]), mmap_mode='r')
        # Copy out of the memory map before handing the data to torch.
        image = torch.from_numpy(image.copy()).float()
        mask = torch.from_numpy(mask.copy()).long()
        if self.augmenter:
            img_1ch = denormalize_image(image.numpy(), self.mean, self.std)[0]
            mask_np = mask.numpy()
            augmented = self.augmenter(image=img_1ch, mask=mask_np)
            aug_img, aug_mask = augmented['image'], augmented['mask']
            if aug_mask.sum() >= 5:
                aug_img = aug_img[..., np.newaxis]
                image = apply_infrared_normalization(aug_img, self.mean, self.std)
                mask = torch.from_numpy(aug_mask).long()
        return image, mask

In [None]:
# Collect (image, YOLO label) path pairs from the extracted HIT-UAV dataset.
data_dir = "/content/hituav-a-highaltitude-infrared-thermal-dataset/hit-uav"
image_dir = os.path.join(data_dir, "images")
label_dir = os.path.join(data_dir, "labels")

for directory in [image_dir, label_dir]:
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory {directory} not found")

# NOTE(review): the dataset's own train/val split is merged here and re-split randomly later.
splits = ['train', 'val']
image_paths = []
label_paths = []
for split in splits:
    split_image_dir = os.path.join(image_dir, split)
    split_label_dir = os.path.join(label_dir, split)
    label_files = sorted(f for f in os.listdir(split_label_dir) if f.endswith('.txt'))
    for lbl_file in label_files:
        # Swap only the trailing extension (str.replace would also rewrite '.txt' inside the stem).
        img_file = os.path.splitext(lbl_file)[0] + '.jpg'
        img_path = os.path.join(split_image_dir, img_file)
        lbl_path = os.path.join(split_label_dir, lbl_file)
        if os.path.exists(img_path):
            image_paths.append(img_path)
            label_paths.append(lbl_path)

# Cap the sample count for quick experiments; set to None to use every pair.
MAX_SAMPLES = 200
image_paths = image_paths[:MAX_SAMPLES]
label_paths = label_paths[:MAX_SAMPLES]

# Fixed normalization constants (not statistics measured on this data).
mean_ir = np.array([0.5])
std_ir = np.array([0.2])
logging.info(f"Infrared mean: {mean_ir}, std: {std_ir}")
print(f"Infrared mean: {mean_ir}, std: {std_ir}")

# Foreground class names; mask value k + 1 marks the label file's class_id k.
# NOTE(review): the prediction-plot cell labels ids 1/2 as Car/Bicycle, the reverse of
# this list — confirm the order against the HIT-UAV label definition.
classes = ['Person', 'Bicycle', 'Car', 'OtherVehicle']

valid_indices = preprocess_and_save(image_paths, label_paths, mean_ir, std_ir, "preprocessed_data_hit_uav", classes)

In [10]:
def get_augmenter():
    """Build the training-time augmentation: random horizontal/vertical flips, then a 256x256 resize."""
    transforms = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Resize(height=256, width=256, p=1.0),
    ]
    return A.Compose(transforms, additional_targets={'mask': 'mask'})

# Split sample positions of the cached dataset into train/val and wrap each view in a
# Subset. Previously train_idx/val_idx were computed but unused, so both loaders
# iterated over *all* samples and validation metrics were measured on training data.
full_train_dataset = HITUAVDataset("preprocessed_data_hit_uav", mean_ir, std_ir, get_augmenter())
full_val_dataset = HITUAVDataset("preprocessed_data_hit_uav", mean_ir, std_ir, None)
train_idx, val_idx = train_test_split(list(range(len(full_train_dataset))), test_size=0.25, random_state=42)
train_dataset = Subset(full_train_dataset, train_idx)
val_dataset = Subset(full_val_dataset, val_idx)

def custom_collate_fn(batch):
    """Stack ``(image, mask)`` samples into batch tensors, dropping samples whose image is None.

    When fewer than two usable samples remain, two empty tensors are returned; the
    training and evaluation loops treat that as "skip this batch".
    """
    usable = [sample for sample in batch if sample[0] is not None]
    if len(usable) < 2:
        return torch.tensor([]), torch.tensor([])
    stacked_images = torch.stack([img for img, _ in usable])
    stacked_masks = torch.stack([msk for _, msk in usable])
    return stacked_images, stacked_masks

# Shared DataLoader settings; only shuffling differs between the train and val loaders.
batch_size = 16
_loader_options = dict(batch_size=batch_size, num_workers=2, pin_memory=True, collate_fn=custom_collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)


In [11]:
def dice_coefficient(y_true, y_pred, num_classes, smooth=1e-7, ignore_absent=False):
    """Mean per-class Dice score between two integer label tensors.

    Args:
        y_true, y_pred: label tensors of the same shape.
        num_classes: classes ``0..num_classes-1`` are scored.
        smooth: added to numerator and denominator to avoid 0/0.
        ignore_absent: if False (default, original behaviour) a class absent from both
            tensors scores ``smooth / smooth = 1`` and still counts in the mean, which
            inflates scores on images containing few classes. If True, such classes are
            left out of the mean.

    Returns:
        NumPy float64 mean Dice (NaN when no class contributes a score).
    """
    dice_scores = []
    for cls in range(num_classes):
        true_cls = (y_true == cls).float()
        pred_cls = (y_pred == cls).float()
        true_sum, pred_sum = true_cls.sum(), pred_cls.sum()
        if ignore_absent and (true_sum + pred_sum).item() == 0:
            continue
        intersection = (true_cls * pred_cls).sum()
        score = (2. * intersection + smooth) / (true_sum + pred_sum + smooth)
        dice_scores.append(score.item())
    if not dice_scores:
        return np.float64(np.nan)
    return np.mean(dice_scores)

def jaccard_index(y_true, y_pred, num_classes, smooth=1e-7, ignore_absent=False):
    """Mean per-class intersection-over-union between two integer label tensors.

    Args:
        y_true, y_pred: label tensors of the same shape.
        num_classes: classes ``0..num_classes-1`` are scored.
        smooth: added to numerator and denominator to avoid 0/0.
        ignore_absent: if False (default, original behaviour) a class absent from both
            tensors scores ``smooth / smooth = 1`` and still counts in the mean, which
            inflates per-image IoU when an image contains few classes. If True, such
            classes are left out of the mean.

    Returns:
        NumPy float64 mean IoU (NaN when no class contributes a score).
    """
    iou_scores = []
    for cls in range(num_classes):
        true_cls = (y_true == cls).float()
        pred_cls = (y_pred == cls).float()
        true_sum, pred_sum = true_cls.sum(), pred_cls.sum()
        if ignore_absent and (true_sum + pred_sum).item() == 0:
            continue
        intersection = (true_cls * pred_cls).sum()
        union = true_sum + pred_sum - intersection
        score = (intersection + smooth) / (union + smooth)
        iou_scores.append(score.item())
    if not iou_scores:
        return np.float64(np.nan)
    return np.mean(iou_scores)

def pixel_accuracy(y_true, y_pred):
    """Fraction of positions where the two label tensors agree, as a 0-dim float32 tensor."""
    matches = torch.eq(y_true, y_pred)
    return matches.to(torch.float32).mean()

In [14]:
model_name = "PSPNet"


def _smp_config(model_class, save_prefix):
    """Registry entry for a segmentation_models_pytorch model with a ResNet-50 encoder."""
    return {
        "class": model_class,
        "save_prefix": save_prefix,
        "params": {"encoder_name": "resnet50", "in_channels": 1, "classes": 5},
    }


def _custom_config(model_class, save_prefix):
    """Registry entry for a model defined in this notebook.

    Its class count is stored under "out_channels"; get_segmentation_model passes it
    on as ``num_classes``.
    """
    return {
        "class": model_class,
        "save_prefix": save_prefix,
        "params": {"in_channels": 1, "out_channels": 5},
    }


# Architecture name -> constructor, checkpoint filename prefix and constructor params.
model_configs = {
    "DeepLabV3Plus": _smp_config(smp.DeepLabV3Plus, "deeplabv3plus_hit_uav"),
    "Unet": _smp_config(smp.Unet, "unet_hit_uav"),
    "PSPNet": _smp_config(smp.PSPNet, "pspnet_hit_uav"),
    "SegNet": _custom_config(SegNet, "segnet_hit_uav"),
    "HRNet": _custom_config(HRNet, "hrnet_hit_uav"),
}

In [None]:
if model_name not in model_configs:
    raise ValueError(f"Model {model_name} is not supported. Choose from {list(model_configs.keys())}")


def get_segmentation_model(model_name):
    """Instantiate the architecture registered under `model_name` in ``model_configs``.

    segmentation_models_pytorch models are built with ``encoder_weights="imagenet"`` and
    ``activation=None`` (raw logits). The notebook's SegNet/HRNet are built from the
    config's ``in_channels`` and ``out_channels`` (passed as ``num_classes``).

    Raises:
        ValueError: if `model_name` is not registered or has no construction rule.
    """
    config = model_configs.get(model_name)
    if config is None:
        raise ValueError(f"Model {model_name} not found in model_configs")

    build = config["class"]
    params = dict(config["params"])

    if model_name in ("DeepLabV3Plus", "Unet", "PSPNet"):
        params.update(encoder_weights="imagenet", activation=None)
        return build(**params)
    if model_name in ("SegNet", "HRNet"):
        return build(in_channels=params["in_channels"], num_classes=params["out_channels"])
    raise ValueError(f"Unsupported model type: {model_name}")


num_classes = 5
model = get_segmentation_model(model_name).to(device)

In [None]:
# Training setup: class-weighted cross-entropy (background down-weighted), AdamW,
# linear LR warm-up then cosine annealing, mixed precision and gradient accumulation.
class_weights = torch.tensor([0.1, 2.0, 2.0, 2.0, 2.0], device=device)
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
base_lr = 1e-3
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-2)
# NOTE(review): with T_max=10 over 500 epochs the cosine schedule reaches eta_min after
# 10 scheduler steps and then climbs back; set T_max to the post-warm-up epoch count if a
# single decay was intended.
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
amp_device = 'cuda' if torch.cuda.is_available() else 'cpu'
scaler = GradScaler(amp_device)
accumulation_steps = 4
warmup_epochs = 2
num_epochs = 500
val_every = 5          # validate every N epochs (and on the final epoch)
best_val_iou = 0.0
patience = 3           # early-stop after this many validation rounds without improvement
epochs_no_improve = 0
history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': []}
save_prefix = model_configs[model_name]['save_prefix']


def _train_one_epoch(model, loader, optimizer, scaler, loss_fn, accumulation_steps, amp_device):
    """Run one training pass with AMP and gradient accumulation.

    Returns:
        (mean_loss, mean_iou, data_time, compute_time); the means are weighted by batch
        size over the batches actually used (NaN if none were).
    """
    model.train()
    optimizer.zero_grad(set_to_none=True)
    total_loss, total_iou, sample_count = 0.0, 0.0, 0
    data_time, compute_time = 0.0, 0.0
    pending_backward = 0  # backward passes accumulated since the last optimizer step

    def _optimizer_step():
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    batch_end = time.time()
    for images, masks in loader:
        # custom_collate_fn yields empty tensors for batches it could not build.
        if images.numel() == 0 or images.size(0) <= 1:
            batch_end = time.time()
            continue
        images, masks = images.to(device), masks.to(device)
        data_time += time.time() - batch_end  # loader wait + host-to-device copy

        compute_start = time.time()
        with autocast(amp_device):
            outputs = model(images)
            loss = loss_fn(outputs, masks) / accumulation_steps
        if not torch.isfinite(loss):
            batch_end = time.time()
            continue
        scaler.scale(loss).backward()
        pending_backward += 1
        if pending_backward == accumulation_steps:
            _optimizer_step()
            pending_backward = 0
        compute_time += time.time() - compute_start

        batch_size = images.size(0)
        total_loss += loss.item() * accumulation_steps * batch_size
        preds = torch.argmax(outputs.detach(), dim=1)
        total_iou += float(jaccard_index(masks, preds, num_classes=num_classes)) * batch_size
        sample_count += batch_size
        batch_end = time.time()

    # Apply gradients from a final, incomplete accumulation group instead of discarding them.
    if pending_backward > 0:
        _optimizer_step()

    if sample_count == 0:
        return float('nan'), float('nan'), data_time, compute_time
    return total_loss / sample_count, total_iou / sample_count, data_time, compute_time


def _validate(model, loader, loss_fn, amp_device):
    """Return (mean_loss, mean_iou) over `loader`, weighted by batch size (NaN if empty)."""
    model.eval()
    total_loss, total_iou, sample_count = 0.0, 0.0, 0
    with torch.no_grad():
        for images, masks in loader:
            if images.numel() == 0 or images.size(0) <= 1:
                continue
            images, masks = images.to(device), masks.to(device)
            with autocast(amp_device):
                outputs = model(images)
                loss = loss_fn(outputs, masks)
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            preds = torch.argmax(outputs, dim=1)
            total_iou += float(jaccard_index(masks, preds, num_classes=num_classes)) * batch_size
            sample_count += batch_size
    if sample_count == 0:
        return float('nan'), float('nan')
    return total_loss / sample_count, total_iou / sample_count


for epoch in range(num_epochs):
    start_time = time.time()

    # Linear warm-up is applied *before* the epoch so epoch 0 really trains at the reduced LR.
    if epoch < warmup_epochs:
        for param_group in optimizer.param_groups:
            param_group['lr'] = base_lr * (epoch + 1) / warmup_epochs
    current_lr = optimizer.param_groups[0]['lr']

    train_loss, train_iou, data_time, forward_time = _train_one_epoch(
        model, train_loader, optimizer, scaler, loss_fn, accumulation_steps, amp_device)

    # Validation metrics exist only on validation epochs; elsewhere they are NaN rather
    # than a copy of the training metrics, so checkpoint selection uses real val scores.
    is_val_epoch = (epoch + 1) % val_every == 0 or epoch + 1 == num_epochs
    if is_val_epoch:
        val_loss, val_iou = _validate(model, val_loader, loss_fn, amp_device)
    else:
        val_loss, val_iou = float('nan'), float('nan')

    if epoch >= warmup_epochs:
        scheduler.step()

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_iou'].append(train_iou)
    history['val_iou'].append(val_iou)

    summary = (f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
               f"Train IoU: {train_iou:.4f}, Val IoU: {val_iou:.4f}, LR: {current_lr:.6f}, "
               f"Data: {data_time:.2f}s, Forward: {forward_time:.2f}s, Time: {time.time() - start_time:.2f}s")
    logging.info(summary)
    print(summary)

    # The "last" checkpoint always reflects the most recent epoch.
    torch.save(model.state_dict(), f"last_{save_prefix}.pth")

    gc.collect()
    torch.cuda.empty_cache() if torch.cuda.is_available() else None

    if is_val_epoch:
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), f"best_{save_prefix}.pth")
            logging.info(f"New best {model_name} model saved at epoch {epoch+1} with Val IoU: {val_iou:.4f}")
            print(f"New best {model_name} model saved at epoch {epoch+1} with Val IoU: {val_iou:.4f}")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                logging.info(f"Early stopping at epoch {epoch+1}: no Val IoU improvement in {patience} validations")
                print(f"Early stopping at epoch {epoch+1}: no Val IoU improvement in {patience} validations")
                break

In [None]:
def evaluate_model(model, val_loader, mean_ir, std_ir, use_crf=True, num_classes=5):
    """Load the best checkpoint (if present) and report segmentation metrics on `val_loader`.

    Reports dataset-level IoU, Dice and pixel accuracy over all evaluated pixels, the
    mean ± std of per-image IoU, and a confusion matrix (logged and printed).

    Args:
        model: segmentation model returning per-class logits along dim 1.
        val_loader: yields ``(images, masks)``; empty or single-sample batches are skipped.
        mean_ir, std_ir: currently unused; kept for call compatibility.
        use_crf: currently unused — predictions are the plain per-pixel argmax and no
            CRF refinement is applied. NOTE(review): pydensecrf is imported for this.
        num_classes: number of classes scored by the metrics and confusion matrix.

    Returns:
        ``(y_pred, y_true, per_image_iou)``: stacked predicted and ground-truth label
        arrays plus a list of per-image IoU values. If no batch was evaluated, the two
        arrays are empty.
    """
    start_time = time.time()
    model.eval()
    model_path = f"best_{model_configs[model_name]['save_prefix']}.pth"
    if os.path.exists(model_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
        # weights_only avoids unpickling arbitrary objects from the file.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        logging.info(f"Loaded {model_path}")
        print(f"Loaded {model_path}")
    else:
        logging.warning("No saved model")
        print("No saved model")

    amp_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    y_pred, y_true, per_image_iou = [], [], []
    with torch.no_grad():
        for images, masks in val_loader:
            if images.numel() == 0 or images.size(0) <= 1:
                continue
            images, masks = images.to(device), masks.to(device)
            with autocast(amp_device):
                outputs = model(images)
            # Softmax in float32 so half-precision rounding cannot create artificial ties.
            preds = torch.softmax(outputs.float(), dim=1).argmax(dim=1)

            y_pred.append(preds.cpu().numpy())
            y_true.append(masks.cpu().numpy())
            for i in range(images.size(0)):
                iou = jaccard_index(masks[i:i+1], preds[i:i+1], num_classes=num_classes)
                per_image_iou.append(float(iou))

    if not y_pred:
        logging.warning("Validation: no non-empty batches to evaluate")
        print("Validation: no non-empty batches to evaluate")
        return np.empty((0,), dtype=np.int64), np.empty((0,), dtype=np.int64), per_image_iou

    y_pred = np.concatenate(y_pred, axis=0)
    y_true = np.concatenate(y_true, axis=0)
    true_t, pred_t = torch.from_numpy(y_true), torch.from_numpy(y_pred)

    val_iou = jaccard_index(true_t, pred_t, num_classes=num_classes)
    val_dice = dice_coefficient(true_t, pred_t, num_classes=num_classes)
    val_accuracy = float(pixel_accuracy(true_t, pred_t))

    summary = (f"Validation: IoU: {val_iou:.4f}, Dice: {val_dice:.4f}, Accuracy: {val_accuracy:.4f}, "
               f"IoU (mean ± std): {np.mean(per_image_iou):.4f} ± {np.std(per_image_iou):.4f}, "
               f"Time: {time.time() - start_time:.2f}s")
    logging.info(summary)
    print(summary)

    cm = confusion_matrix(y_true.flatten(), y_pred.flatten(), labels=list(range(num_classes)))
    logging.info(f"Confusion Matrix:\n{cm}")
    print(f"Confusion Matrix:\n{cm}")

    return y_pred, y_true, per_image_iou


y_pred, y_true, per_image_iou = evaluate_model(model, val_loader, mean_ir, std_ir, use_crf=True)


In [None]:
# Plot input, ground truth and prediction for a few samples of one validation batch.
batch_idx = 1
image_indices = [8, 1, 15]

# NOTE(review): this label order (Car before Bicycle) is the reverse of the `classes`
# list used for preprocessing; mask value k + 1 is the label file's class_id k, so
# confirm the order against the HIT-UAV label definition.
classes_with_bg = ['Background', 'Person', 'Car', 'Bicycle', 'OtherVehicle']
colors = ['#000000', '#FF0000', '#00FF00', '#0000FF', '#FFFF00']
cmap = plt.cm.colors.ListedColormap(colors)
norm = plt.cm.colors.BoundaryNorm(range(len(classes_with_bg) + 1), cmap.N)
tick_positions = np.arange(len(classes_with_bg)) + 0.5  # centre of each colour band
amp_device = 'cuda' if torch.cuda.is_available() else 'cpu'

selected_batch = None
for i, batch in enumerate(val_loader):
    if i == batch_idx:
        selected_batch = batch
        break

shown = []
if selected_batch is None or selected_batch[0].numel() == 0:
    print(f"Validation batch {batch_idx} is unavailable or empty; nothing to plot.")
else:
    images, masks = selected_batch
    # Batches (especially the last one) can be smaller than the largest requested index.
    shown = [k for k in image_indices if k < images.size(0)]
    skipped = [k for k in image_indices if k >= images.size(0)]
    if skipped:
        print(f"Indices {skipped} are out of range for batch {batch_idx} (size {images.size(0)}); skipping them.")

if shown:
    model.eval()  # use running BatchNorm statistics
    with torch.no_grad(), autocast(amp_device):
        outputs = model(images[shown].to(device))
    refined_preds = torch.softmax(outputs.float(), dim=1).argmax(dim=1).cpu().numpy()
    images_np = images[shown].numpy()
    masks_np = masks[shown].numpy()

    fig, axes = plt.subplots(len(shown), 3, figsize=(15, 5 * len(shown)), squeeze=False)
    for row, idx in enumerate(shown):
        ax_img, ax_gt, ax_pred = axes[row]
        ax_img.imshow(denormalize_image(images_np[row], mean_ir, std_ir)[0], cmap='gray')
        ax_img.set_title(f'Image {idx} in Batch {batch_idx}')
        ax_img.axis('off')

        for ax, label_map, title in ((ax_gt, masks_np[row], 'Ground Truth'),
                                     (ax_pred, refined_preds[row], 'Prediction')):
            im = ax.imshow(label_map, cmap=cmap, norm=norm, alpha=0.8)
            ax.set_title(title)
            ax.axis('off')
            cbar = fig.colorbar(im, ax=ax, ticks=tick_positions, shrink=0.5)
            cbar.ax.set_yticklabels(classes_with_bg)

    fig.tight_layout()
    plt.show()