In [None]:
import torch
import torchvision
import torchvision.transforms.v2 as transforms
import PIL
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
import matplotlib.colors as mcolors
from clearml import Task
import random

In [None]:
# Register this run with ClearML; the task name labels this experiment run in the 'Experimentation' project.
task = Task.init(project_name='Experimentation', task_name='exp5_ABX_noFT_aug_750')

In [None]:
# Root of the ABX dataset. Override it with the ABX_DATA_ROOT environment
# variable instead of editing every path when running on another machine.
ABX_DATA_ROOT = os.environ.get("ABX_DATA_ROOT", "/home/ubuntu/Datasets/abx_data_pfai")

# Image / mask directories per split (the "_512_256" suffix presumably marks the
# pre-resized resolution — see the preprocessing cell).
dir_truth = os.path.join(ABX_DATA_ROOT, "Mask_512_256")
dir_input = os.path.join(ABX_DATA_ROOT, "Image_512_256")
dir_truth_val = os.path.join(ABX_DATA_ROOT, "Val_mask_512_256")
dir_input_val = os.path.join(ABX_DATA_ROOT, "Val_512_256")
dir_truth_test = os.path.join(ABX_DATA_ROOT, "Test_mask_512_256")
dir_input_test = os.path.join(ABX_DATA_ROOT, "Test_512_256")

In [None]:
# Copy a reproducible random subset of the training set (masks + matching
# images) into separate directories for small-data experiments.
import shutil

SUBSET_SIZE = 500   # number of mask/image pairs to copy
SUBSET_SEED = 0     # fixed seed so every run draws the same subset

output_dir_truth_subset = "/home/ubuntu/Mask_train_subset"
output_dir_input_subset = "/home/ubuntu/Image_train_subset"

# Create output directories if they don't exist
os.makedirs(output_dir_truth_subset, exist_ok=True)
os.makedirs(output_dir_input_subset, exist_ok=True)

# os.listdir order is filesystem-dependent; sort so the seed alone fixes the subset.
mask_list = sorted(os.listdir(dir_truth))

# random.sample raises ValueError if asked for more items than exist.
subset_masks = random.Random(SUBSET_SEED).sample(mask_list, min(SUBSET_SIZE, len(mask_list)))

# Copy the selected masks and their same-named images
for mask_name in subset_masks:
    image_path = os.path.join(dir_input, mask_name)
    if not os.path.exists(image_path):
        print(f'Skipping {mask_name}: no matching image in {dir_input}')
        continue

    shutil.copy(os.path.join(dir_truth, mask_name), os.path.join(output_dir_truth_subset, mask_name))
    shutil.copy(image_path, os.path.join(output_dir_input_subset, mask_name))

Image preprocessing

In [None]:
from PIL import Image

# Target size of each sample, in PIL's (width, height) order
sample_size = (512, 256)

# Directories for preprocessed datasets: "<dir>_<width>_<height>".
# NOTE(review): dir_truth/dir_input already end in "_512_256", so this yields
# "..._512_256_512_256" directories — confirm this step is still needed.
dir_truth_pp, dir_input_pp = (f'{d}_{sample_size[0]}_{sample_size[1]}' for d in (dir_truth, dir_input))

# Masks hold labels, so nearest-neighbour keeps their values intact; photos use
# bilinear, because nearest-neighbour downsampling aliases natural images.
resample_for = {dir_truth: Image.NEAREST, dir_input: Image.BILINEAR}

for dir_full, dir_pp in ((dir_truth, dir_truth_pp), (dir_input, dir_input_pp)):
    # Skip datasets that were already preprocessed
    if os.path.isdir(dir_pp):
        print(f'Preprocessed directory already exists: {dir_pp}')
        continue

    print(f'Preprocessing: {dir_full}')

    # Walk the directory tree and mirror it under dir_pp
    for root, _, files in os.walk(dir_full):
        if not files:
            continue

        # os.walk yields roots that start with dir_full; swap only that prefix.
        relative = root[len(dir_full):]
        print(f'Preprocessing sub-directory: {relative}')
        root_pp = dir_pp + relative
        os.makedirs(root_pp, exist_ok=True)

        for f in files:
            if not f.endswith('.png'):
                continue

            # `with` closes the source file; PNG is lossless, so no quality option applies.
            with Image.open(os.path.join(root, f)) as img:
                img_resized = img.resize(sample_size, resample_for[dir_full])
            img_resized.save(os.path.join(root_pp, f), 'png')

print(f'Preprocessing done')

In [None]:
from torchvision.datasets import Cityscapes 


# Optional: torchvision's built-in Cityscapes splits, only needed to inspect the
# raw data. They stay commented out; uncomment the print together with the
# datasets, otherwise `val_dataset` is undefined and the cell raises NameError.
# train_dataset = Cityscapes('/home/ubuntu/Cityscapes/Cityscapes', split='train', mode='fine', target_type='semantic')
# val_dataset = Cityscapes('/home/ubuntu/Cityscapes/Cityscapes', split='val', mode='fine', target_type='semantic')
# test_dataset = Cityscapes('/home/ubuntu/Cityscapes/Cityscapes', split='test', mode='fine', target_type='semantic')
# print(val_dataset)

Visualizing functions

In [None]:
# Two-colour colormap (black background, red foreground); passed to display_image_with_masks by the preview/result cells.
custom_colormap = mcolors.ListedColormap(['#000000', '#FF0000'])

# Define a function to display an image and its mask

def overlay_transparent_mask(image, mask, transparency=0.8):
    """Highlight mask pixels of an RGB image by boosting its green channel.

    Args:
        image: RGB array of shape (H, W, 3) with values in [0, 255] (e.g. uint8).
        mask: Mask of shape (H, W) or (H, W, 1) with values in [0, 1].
        transparency: Green intensity (0-1) used where the mask is 1.

    Returns:
        uint8 array of shape (H, W, 3). The input arrays are not modified.
    """
    image_float = image.astype(np.float32) / 255.0
    mask_float = np.asarray(mask, dtype=np.float32)
    # The datasets return masks with a trailing channel axis; drop it so the
    # mask broadcasts against a single (H, W) colour channel.
    if mask_float.ndim == 3 and mask_float.shape[-1] == 1:
        mask_float = mask_float[..., 0]

    # Raise green to at least transparency * mask; red and blue are untouched.
    image_float[:, :, 1] = np.maximum(image_float[:, :, 1], mask_float * transparency)

    # Clip to [0, 1] and convert back to uint8 for display
    image_float = np.clip(image_float, 0.0, 1.0)
    return (image_float * 255).astype(np.uint8)




def display_image_with_masks(image, gt_mask, predicted_mask, colormap=None):
    """Show an image next to its ground-truth and predicted mask overlays.

    Args:
        image: RGB array (H, W, 3) with values in [0, 255].
        gt_mask: Ground-truth mask (H, W) with values in [0, 1].
        predicted_mask: Predicted mask (H, W); cast to uint8, so callers are
            expected to pass already-rounded 0/1 predictions.
        colormap: Unused; kept so existing call sites keep working.
    """
    predicted_mask = predicted_mask.astype(np.uint8)

    panels = (
        ('Image', image),
        ('Ground truth Mask', overlay_transparent_mask(image, gt_mask)),
        ('Predicted Mask', overlay_transparent_mask(image, predicted_mask)),
    )

    plt.figure(figsize=(17, 10))
    # One subplot per panel so the ground-truth overlay gets its own axes
    # instead of being drawn on top of the original image.
    for position, (title, panel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel)
        plt.title(title)

    plt.tight_layout()
    plt.show()


def visualize(**images):
    """Plot images in one row, titled by their keyword names.

    Underscores in a keyword become spaces in its title (``road_mask`` -> "Road Mask").
    Arrays with a trailing singleton channel, such as the (H, W, 1) masks the
    datasets return, are squeezed to (H, W) because ``plt.imshow`` only accepts
    (H, W), (H, W, 3) or (H, W, 4) arrays.
    """
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        if getattr(image, 'ndim', 0) == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()



def apply_colormap(mask, colormap):
    """Colourise a segmentation mask.

    Args:
        mask: Either a label map of shape (H, W) or (H, W, 1) holding class ids,
            or a one-hot / per-class score array of shape (H, W, C) with C > 1,
            which is reduced to class ids by argmax over the last axis.
        colormap: Mapping ``{class_id: (r, g, b)}``.

    Returns:
        uint8 array of shape (H, W, 3). Pixels whose class id is not in
        ``colormap`` stay black.
    """
    mask = np.asarray(mask)
    if mask.ndim == 3:
        # Reduce to one class id per pixel. A 3-D boolean index against the
        # (H, W, 3) output would select individual channels, not pixels.
        labels = mask[..., 0] if mask.shape[-1] == 1 else mask.argmax(axis=-1)
    else:
        labels = mask

    height, width = labels.shape[:2]
    colorized_mask = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, color in colormap.items():
        colorized_mask[labels == class_id] = color
    return colorized_mask


Cityscapes class

In [None]:
class CityscapesSearchDataset(torchvision.datasets.Cityscapes):
    """torchvision Cityscapes dataset returning (RGB image, binary road mask) numpy pairs.

    The semantic target is reduced to a single float32 "road" channel of shape
    (H, W, 1), so the data fits the same binary segmentation pipeline as the ABX
    set. Optional ``augmentation`` / ``preprocessing`` callables are invoked as
    ``fn(image=..., mask=...)`` and must return a dict with 'image' and 'mask'.
    """

    # Cityscapes labelId used for the "road" class
    ROAD_CLASS_ID = 7

    def __init__(self, *args, augmentation=None, preprocessing=None, **kwargs):
        super().__init__(*args, **kwargs)
        if "semantic" not in self.target_type:
            raise ValueError("CityscapesSearchDataset requires target_type to include 'semantic'")
        # Position of the semantic map within self.targets[index]
        self.semantic_target_type_index = list(self.target_type).index("semantic")
        self.colormap = self._generate_colormap()
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _generate_colormap(self):
        """Map each trainId to its labelId, skipping ignored classes (trainId -1 or 255)."""
        colormap = {}
        for class_ in self.classes:
            if class_.train_id in (-1, 255):
                continue
            colormap[class_.train_id] = class_.id
        return colormap

    def _convert_to_segmentation_mask(self, mask):
        """Return a float32 (H, W, 1) mask: 1.0 where the label is road, 0.0 elsewhere."""
        road_mask = (mask == self.ROAD_CLASS_ID).astype(np.float32)
        return np.expand_dims(road_mask, axis=-1)

    def __getitem__(self, index):
        image_path = self.images[index]
        mask_path = self.targets[index][self.semantic_target_type_index]

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on unreadable files.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = self._convert_to_segmentation_mask(mask)

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

ABX custom dataset class

In [None]:
from PIL import Image




class CustomDataset(Dataset):
    """ABX road-segmentation dataset of (RGB image, binary mask) pairs.

    Images and masks are paired by identical file names in ``image_dir`` and
    ``mask_dir``; images without a mask file are skipped. Masks are read as
    grayscale and returned as floats in [0, 1] with shape (H, W, 1).

    ``augmentation`` / ``preprocessing`` are invoked as ``fn(image=..., mask=...)``
    and must return a dict with 'image' and 'mask' keys.
    """

    def __init__(self, image_dir, mask_dir, augmentation=None, preprocessing=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sorted so dataset indices are stable across runs and filesystems
        # (os.listdir order is arbitrary).
        self.image_list = sorted(os.listdir(self.image_dir))

        # Indices into image_list of images that have a same-named mask file
        self.valid_indices = [
            i for i, image_id in enumerate(self.image_list)
            if os.path.exists(os.path.join(self.mask_dir, image_id))
        ]

        self.images_fps = [os.path.join(self.image_dir, self.image_list[i]) for i in self.valid_indices]
        self.masks_fps = [os.path.join(self.mask_dir, self.image_list[i]) for i in self.valid_indices]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, i):
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals unreadable files by returning None.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        # (H, W) values in [0, 255] -> (H, W, 1) floats in [0, 1]
        mask = np.expand_dims(mask.astype(np.uint8), axis=-1) / 255.0

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask


In [None]:
# Sanity check: show one raw training sample (no augmentation) with its mask.
sample_index = 12

dataset = CustomDataset(dir_input, dir_truth)
image, mask = dataset[sample_index]
# Masks are (H, W, 1); plt.imshow needs (H, W).
visualize(image=image, mask=mask[..., 0])

In [None]:
custom_dataset = CityscapesSearchDataset(
    root='/home/ubuntu/Cityscapes/Cityscapes',  # Replace with the actual path to your dataset
    split='train',
    mode='fine',
    target_type='semantic',
)

# Preview the first samples. There is no prediction yet, so show each image
# next to its road mask (masks are (H, W, 1); imshow needs (H, W)).
for i in range(5):  # Change the range to the number of samples you want to visualize
    image, mask = custom_dataset[i]
    visualize(image=image, road_mask=mask[..., 0])


Data augmentation (albumentations)

In [None]:
import albumentations as albu


def get_training_augmentation():
    """Build the albumentations training pipeline for (image, mask) pairs.

    Geometric transforms (flip, shift/scale, pad + random 320x320 crop,
    perspective) move image and mask together. Photometric transforms (colour
    jitter, noise, CLAHE/brightness/gamma, sharpen/blur, brightness-contrast/hue)
    alter only the image.

    Returns:
        albu.Compose: call as ``transform(image=..., mask=...)``.
    """
    train_transform = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

        # Fixed 320x320 training patches. p=1 replaces always_apply=True, which is
        # deprecated in recent albumentations releases.
        albu.PadIfNeeded(min_height=320, min_width=320, p=1, border_mode=0),
        albu.RandomCrop(height=320, width=320, p=1),
        albu.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.3, p=0.5),

        albu.GaussNoise(p=0.4),
        albu.Perspective(p=0.35),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                # Brightness-only equivalent of the removed albu.RandomBrightness
                # (same default limit of 0.2).
                albu.RandomBrightnessContrast(contrast_limit=0, p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.Sharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.RandomBrightnessContrast(p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)

    



def get_validation_augmentation():
    """Pad validation images (and masks) to at least 384x480 (height x width).

    NOTE(review): PadIfNeeded only enforces a minimum size; it does not round
    larger inputs up to a multiple of 32.
    """
    test_transform = [
        albu.PadIfNeeded(384, 480),
        # albu.Lambda(image=print_shape, mask= print_shape)
    ]
    return albu.Compose(test_transform)






def to_tensor(x, **kwargs):
    """Convert an HWC (or HW) numpy array to a float32 CHW array.

    Used as an albumentations ``Lambda`` target, so extra keyword arguments are
    accepted and ignored. A 2-D (H, W) input is treated as single-channel and
    returned as (1, H, W) instead of failing in ``transpose``.
    """
    if x.ndim == 2:
        x = x[:, :, None]
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the final, encoder-specific preprocessing step.

    The image is normalised with ``preprocessing_fn``; afterwards both image and
    mask are converted to float32 channels-first arrays via ``to_tensor``.

    Args:
        preprocessing_fn (callable): per-encoder normalisation function.

    Returns:
        albumentations.Compose: call as ``transform(image=..., mask=...)``.
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, channels_first])


Data augmentation (torchvision transforms)

In [None]:
from torchvision.transforms import functional as F
from torchvision import transforms as T
import random





def pad_if_smaller(img, size, fill=0):
    """Pad a PIL image on the right/bottom so both sides are at least ``size``.

    Args:
        img: PIL image; ``img.size`` is ``(width, height)``.
        size: Minimum side length in pixels.
        fill: Value used for the padded area.

    Returns:
        ``img`` itself when it is already large enough, otherwise a padded copy.
    """
    if min(img.size) >= size:
        return img
    width, height = img.size
    pad_right = max(size - width, 0)
    pad_bottom = max(size - height, 0)
    # F.pad's 4-tuple is (left, top, right, bottom)
    return F.pad(img, (0, 0, pad_right, pad_bottom), fill=fill)


class Compose:
    """Chain paired transforms, each mapping ``(image, target) -> (image, target)``."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            image, target = step(image, target)
        return image, target


class RandomResize:
    """Resize image and target to a random size drawn from [min_size, max_size].

    ``F.resize`` with an int matches the shorter edge to that length.
    """

    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        self.max_size = min_size if max_size is None else max_size

    def __call__(self, image, target):
        side = random.randint(self.min_size, self.max_size)
        resized_image = F.resize(image, side, antialias=True)
        # Nearest-neighbour so label values in the target are not blended.
        resized_target = F.resize(target, side, interpolation=T.InterpolationMode.NEAREST)
        return resized_image, resized_target


class RandomHorizontalFlip:
    """Mirror image and target left-right together with probability ``flip_prob``."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() >= self.flip_prob:
            return image, target
        return F.hflip(image), F.hflip(target)


class RandomCrop:
    """Crop the same random ``size`` x ``size`` window from image and target.

    Inputs smaller than ``size`` are padded first; the target is padded with 255.
    NOTE(review): 255 is not a valid label for 0/1 road masks.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        padded_image = pad_if_smaller(image, self.size)
        padded_target = pad_if_smaller(target, self.size, fill=255)
        top, left, height, width = T.RandomCrop.get_params(padded_image, (self.size, self.size))
        return (
            F.crop(padded_image, top, left, height, width),
            F.crop(padded_target, top, left, height, width),
        )


class CenterCrop:
    """Crop the central ``size`` region from image and target alike."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        return F.center_crop(image, self.size), F.center_crop(target, self.size)


class PILToTensor:
    """Convert a PIL image to a CHW tensor (pixel dtype kept) and the target to an int64 tensor."""

    def __call__(self, image, target):
        image_tensor = F.pil_to_tensor(image)
        target_tensor = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image_tensor, target_tensor


class ToDtype:
    """Cast the image tensor to ``dtype``; the target passes through unchanged.

    With ``scale=True`` the values are rescaled as well, via ``F.convert_image_dtype``.
    Otherwise this is a plain ``.to(dtype)`` cast.
    """

    def __init__(self, dtype, scale=False):
        self.dtype = dtype
        self.scale = scale

    def __call__(self, image, target):
        converted = (
            F.convert_image_dtype(image, self.dtype)
            if self.scale
            else image.to(dtype=self.dtype)
        )
        return converted, target


class Normalize:
    """Normalise the image tensor channel-wise with ``mean``/``std``; the target is untouched."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        return F.normalize(image, mean=self.mean, std=self.std), target
    



class ColorJitter:
    """Randomly jitter brightness, contrast, saturation and hue of the image only."""

    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
        jitter_kwargs = dict(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.transform = T.ColorJitter(**jitter_kwargs)

    def __call__(self, image, target):
        return self.transform(image), target





class GaussianBlur:
    """Blur the image (not the target) with a Gaussian kernel.

    Args:
        kernel_size: Kernel size passed to ``F.gaussian_blur``.
        sigma: Either a fixed standard deviation, or a ``(min, max)`` range from
            which a sigma is drawn uniformly on every call. A 2-tuple passed
            straight to ``F.gaussian_blur`` would instead be read as
            (sigma_x, sigma_y), giving a fixed anisotropic blur.
    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        self.kernel_size = kernel_size
        self.sigma = sigma

    def __call__(self, image, target):
        sigma = self.sigma
        if isinstance(sigma, (tuple, list)):
            sigma = random.uniform(sigma[0], sigma[1])

        # PIL image -> tensor -> blur -> back to a PIL image
        image = F.to_tensor(image)
        image = F.gaussian_blur(image, self.kernel_size, sigma)
        image = F.to_pil_image(image)

        return image, target


class RandomPerspective:
    """Apply one random perspective warp to both image and target with probability ``p``.

    The image uses torchvision's default (bilinear) interpolation. The target uses
    nearest-neighbour, so mask values stay valid labels instead of being blended
    along edges.
    """

    def __init__(self, distortion_scale=0.5, p=0.5):
        self.distortion_scale = distortion_scale
        self.p = p

    def __call__(self, image, target):
        if torch.rand(1) >= self.p:
            return image, target

        # Sample once and share, so image and target get the identical warp.
        # image.width / image.height assume a PIL image.
        startpoints, endpoints = T.RandomPerspective.get_params(
            width=image.width,
            height=image.height,
            distortion_scale=self.distortion_scale,
        )
        transformed_image = F.perspective(image, startpoints, endpoints)
        transformed_target = F.perspective(
            target, startpoints, endpoints,
            interpolation=T.InterpolationMode.NEAREST,
        )
        return transformed_image, transformed_target

class CustomRandomApply:
    """With probability ``p`` run all paired ``transforms`` in order; otherwise pass inputs through."""

    def __init__(self, transforms, p):
        self.transforms = transforms
        self.p = p

    def __call__(self, image, target):
        if random.random() >= self.p:
            return image, target
        for step in self.transforms:
            image, target = step(image, target)
        return image, target


Torchvision augmentation pipeline (used only for Experiment 1)

In [None]:

# Experiment-1 torchvision pipeline, kept for reproducibility.
# CustomDataset / CityscapesSearchDataset call their augmentation as
# fn(image=..., mask=...) on numpy arrays, whereas this pipeline is called as
# fn(image, target) on PIL images. Binding it to `get_training_augmentation`
# would shadow the albumentations factory defined earlier and break the dataset
# cells below, so it is only bound when this flag is set.
USE_TORCHVISION_AUGMENTATION = False


def get_torchvision_training_augmentation():
    """Build the Experiment-1 pipeline for paired PIL (image, target) inputs.

    Steps: resize to 320, random horizontal flip, random 320x320 crop, colour
    jitter, Gaussian blur (p=0.4) and random perspective.
    """
    return Compose([
        RandomResize(min_size=320, max_size=320),
        RandomHorizontalFlip(flip_prob=0.5),
        RandomCrop(size=320),
        ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        CustomRandomApply([GaussianBlur(kernel_size=3, sigma=0.5)], p=0.4),
        # RandomPerspective has its own default p=0.5, so overall the warp
        # happens on about 25% of samples.
        CustomRandomApply([RandomPerspective(distortion_scale=0.7)], p=0.5),
    ])


if USE_TORCHVISION_AUGMENTATION:
    get_training_augmentation = get_torchvision_training_augmentation


def get_preprocessing_image(preprocessing_fn):
    """Torchvision-style image preprocessing: apply preprocessing_fn, then HWC -> CHW float32."""
    print("IN image  Preprocess")

    image_transform = transforms.Compose([
        transforms.Lambda(lambda x: preprocessing_fn(x)),
        transforms.Lambda(lambda x: to_tensor(x)),
    ])

    return image_transform


def get_preprocessing_mask(preprocessing_fn):
    """Torchvision-style mask preprocessing: HWC -> CHW float32 only (preprocessing_fn is not applied)."""
    print("IN mask  Preprocess")

    mask_transform = transforms.Compose([
        transforms.Lambda(lambda x: to_tensor(x)),
    ])

    return mask_transform


In [None]:
# Preview one augmented training sample; augmentation is random, so re-run for a new draw.
dataset = CustomDataset(
    dir_input,
    dir_truth,
    augmentation=get_training_augmentation(),
)

image, mask = dataset[7]
# Masks are (H, W, 1); plt.imshow needs (H, W).
visualize(image=image, mask=mask[..., 0])

In [None]:
custom_dataset = CityscapesSearchDataset(
    root='/home/ubuntu/Cityscapes/Cityscapes',  # Replace with the actual path to your dataset
    split='train',
    mode='fine',
    target_type='semantic',
    augmentation=get_training_augmentation(),
)

# Show each augmented image next to its road mask (masks are (H, W, 1); imshow needs (H, W)).
for i in range(5):  # Change the range to the number of samples you want to visualize
    image, mask = custom_dataset[i]
    visualize(image=image, road_mask=mask[..., 0])

Model training

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

In [None]:
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['road']
# The model must emit probabilities: the loss/metric below are constructed without
# an activation (smp.utils.losses.DiceLoss, smp.utils.metrics.IoU(threshold=0.5)),
# and the evaluation cells `.round()` the output of `predict`. Use None only
# together with a from_logits / *WithLogits loss.
ACTIVATION = 'sigmoid'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# U-Net with an ImageNet-pretrained encoder and one output channel per class
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# Fine-tuning (FT) runs start from previously trained Cityscapes weights instead:
# model.load_state_dict(torch.load("/home/ubuntu/Cityscapes_noaug_weights.pth"))

# Encoder-specific input normalisation, used by get_preprocessing()
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

Dataloaders

In [None]:
from torch.utils.data import DataLoader


# Cityscapes alternative: build both splits with
# CityscapesSearchDataset(root='/home/ubuntu/Cityscapes/Cityscapes', split='train' / 'val',
# mode='fine', target_type='semantic', preprocessing=get_preprocessing(preprocessing_fn))
# instead of CustomDataset.

# ABX training set: random augmentation followed by encoder preprocessing
train_dataset = CustomDataset(
    dir_input,
    dir_truth,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# ABX validation set: preprocessing only, so validation scores are deterministic
valid_dataset = CustomDataset(
    dir_input_val,
    dir_truth_val,
    preprocessing=get_preprocessing(preprocessing_fn),
)

train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=4, shuffle=False)

Hyperparameters

In [None]:
from segmentation_models_pytorch import utils


# Soft Dice loss and IoU thresholded at 0.5.
# NOTE(review): both are built without an `activation`, so they presumably expect
# the model to output probabilities rather than raw logits — confirm in the model cell.
loss = smp.utils.losses.DiceLoss()

# Commented-out alternatives (each was given an explicit __name__):
#   smp.losses.SoftBCEWithLogitsLoss()                      -> 'soft_bce'
#   smp.losses.TverskyLoss(mode='binary', from_logits=True) -> 'Twersky_Loss'
#   smp.losses.LovaszLoss(mode='binary', from_logits=True)  -> 'Lovasz_Loss'

metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Adam over all model parameters in a single parameter group.
# (SGD alternative: lr=0.1, momentum=0.9, weight_decay=0.0005)
optimizer = torch.optim.Adam([{"params": model.parameters(), "lr": 1e-4}])


In [None]:
# Attach the model object to the ClearML task under the name 'Model'.
# NOTE(review): confirm what ClearML actually records when an nn.Module is connected.
task.connect(mutable = model,name='Model')

In [None]:
# Epoch runners: TrainEpoch also receives the optimizer; ValidEpoch only evaluates.
shared_epoch_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **shared_epoch_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **shared_epoch_kwargs)

# torch.save(model.state_dict(), "/home/ubuntu/all_Augment_finetunedabx.pth")

Training

In [None]:
# Train for a fixed number of epochs, keeping the checkpoint with the best validation IoU.
NUM_EPOCHS = 30
# NOTE(review): the file name mentions BCE although the active loss is Dice.
BEST_MODEL_PATH = '/home/ubuntu/NO_FT_ABXaug_FM_750_BCE.pth'

max_score = 0
for i in range(NUM_EPOCHS):
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    print(train_logs)

    if valid_logs['iou_score'] > max_score:
        max_score = valid_logs['iou_score']
        # Saves the whole pickled module (not only the state_dict); the
        # "Load model" cell restores it with torch.load.
        torch.save(model, BEST_MODEL_PATH)
        print('Model saved!')

Load model

In [None]:
# Load the best checkpoint saved during training. It is a whole pickled model, so
# weights_only=False is required on PyTorch versions whose torch.load defaults to
# weights_only=True. Only load checkpoints you trust (unpickling runs code).
# map_location puts the weights on the current DEVICE.
best_model = torch.load('/home/ubuntu/NO_FT_ABXaug_FM_750_BCE.pth', map_location=DEVICE, weights_only=False)

In [None]:
# Held-out ABX test split: encoder preprocessing only, no augmentation.
test_dataset = CustomDataset(
    image_dir=dir_input_test,
    mask_dir=dir_truth_test,
    preprocessing=get_preprocessing(preprocessing_fn),
)


In [None]:
# Cityscapes alternative: evaluate on
# CityscapesSearchDataset(root='/home/ubuntu/Cityscapes/Cityscapes', split='val',
# mode='fine', target_type='semantic', preprocessing=get_preprocessing(preprocessing_fn)).

# One sample per batch, in dataset order.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

Testing

In [None]:
# Evaluate the best checkpoint on the test split with the validation loss/metrics.
test_epoch = smp.utils.train.ValidEpoch(
    best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)
logs = test_epoch.run(test_dataloader)

In [None]:
# Test-set loss and metric values returned by test_epoch.run
print(logs)

Display results

In [None]:
import random
from skimage.metrics import structural_similarity as ssim    

# Visual spot check: for a few random test images, show ground truth vs.
# prediction and report the SSIM between the two binary masks.
NUM_EXAMPLES = 4
vis_rng = np.random.default_rng(0)  # seeded so the same examples are shown each run

# Same files without preprocessing, used to display the original RGB image
test_dataset_vis = CustomDataset(dir_input_test, dir_truth_test)

# Fixed indices that can be used instead of random picks
indices = [770, 78, 909, 295]

for i in range(NUM_EXAMPLES):
    n = int(vis_rng.integers(len(test_dataset)))

    image_vis = test_dataset_vis[n][0].astype('uint8')
    image, gt_mask = test_dataset[n]
    gt_mask = gt_mask.squeeze()

    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy().round()

    # Both masks are 2-D single-channel arrays; multichannel=True would treat
    # the image width as a channel axis and compare column by column.
    ssim_score = ssim(gt_mask, pr_mask, data_range=1)
    print(f"SSIM Score for Image {i + 1}: {ssim_score}")

    display_image_with_masks(image_vis, gt_mask, pr_mask, custom_colormap)

Results

In [None]:
# Predict every test image and collect binary masks plus per-image SSIM.
all_pr_masks = []
all_gt_masks = []
SSIM_scores = []
comparisons = []  # (display image, ground truth, prediction, SSIM) per test sample

for i in range(len(test_dataset)):
    image, gt_mask = test_dataset[i]
    gt_mask = gt_mask.squeeze()

    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor).squeeze().cpu().numpy().round()
    # Masks are 2-D, so no multichannel / channel_axis argument.
    ssim_score = ssim(gt_mask, pr_mask, data_range=1)

    all_pr_masks.append(pr_mask)
    all_gt_masks.append(gt_mask)
    SSIM_scores.append(ssim_score)
    # Each entry gets its own display image; `image_vis` from the previous cell
    # would attach the same picture to every sample.
    comparisons.append((test_dataset_vis[i][0].astype('uint8'), gt_mask, pr_mask, ssim_score))

all_pr_masks = torch.from_numpy(np.array(all_pr_masks)).float()
all_gt_masks = torch.from_numpy(np.array(all_gt_masks)).long()
print(all_gt_masks.shape)
print(all_pr_masks.shape)
print(np.mean(SSIM_scores))

In [None]:
from torchmetrics import ConfusionMatrix


# Raw binary stats from segmentation_models_pytorch (thresholded at 0.5)
tp, fp, fn, tn = smp.metrics.get_stats(
    all_pr_masks, all_gt_masks, mode='binary', threshold=0.5
)

# Pixel-level binary confusion matrix over the whole test set
conf_mat = ConfusionMatrix(task="binary", num_classes=2)
conf_matrix = conf_mat(all_pr_masks, all_gt_masks)
conf_mat.plot()
plt.show()

In [None]:
from torchmetrics import Recall, Accuracy,F1Score, FBetaScore,JaccardIndex


def _evaluate_metric(metric):
    """Feed all test predictions/targets into a torchmetrics metric and return its scalar value."""
    metric.update(all_pr_masks, all_gt_masks)
    return metric.compute().item()


# Binary torchmetrics over all test pixels
iou_score = JaccardIndex(task="binary", num_classes=2)
f1_metric = F1Score(task="binary", num_classes=2)
f2_metric = FBetaScore(task="binary", beta=2.0, num_classes=2)
recall_metric = Recall(task="binary", num_classes=2)
accuracy_metric = Accuracy(task="binary", num_classes=2)

iou = _evaluate_metric(iou_score)
f1_score = _evaluate_metric(f1_metric)
f2_score = _evaluate_metric(f2_metric)
recall = _evaluate_metric(recall_metric)
accuracy = _evaluate_metric(accuracy_metric)

for label, value in (
    ("IoU:", iou),
    ("F1 Score:", f1_score),
    ("F2 Score:", f2_score),
    ("Recall:", recall),
    ("Accuracy:", accuracy),
):
    print(label, value)

Metrics for generalization

In [None]:
from sklearn.metrics import roc_curve, roc_auc_score,precision_recall_curve,auc
import matplotlib.pyplot as plt


# ROC over all test pixels.
# NOTE(review): all_pr_masks holds rounded 0/1 predictions (results cell), so the
# curve has a single non-trivial operating point. Feed un-rounded probabilities
# for a meaningful AUROC and FPR@95%TPR.
y_true = all_gt_masks.flatten()
y_score = all_pr_masks.flatten()
fpr, tpr, thresholds = roc_curve(y_true, y_score)
auroc = roc_auc_score(y_true, y_score)

# First ROC point whose TPR reaches the target; report its FPR in percent
desired_tpr = 0.95  # Set the desired TPR value
index = next(k for k, value in enumerate(tpr) if value >= desired_tpr)
fpr_at_desired_tpr = fpr[index] * 100
print(f"FPR at {desired_tpr * 100}% TPR: {fpr_at_desired_tpr}")

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'AUROC = {auroc:.2f}')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate (FPR)')
ax.set_ylabel('True Positive Rate (TPR)')
ax.set_title('Receiver Operating Characteristic (ROC) Curve')
ax.legend(loc="lower right")
plt.show()

print(f'AUROC Score: {auroc:.2f}')

In [None]:
# Precision-recall curve over all test pixels (uses the same rounded predictions as the ROC).
precision, recall_val, thresholds = precision_recall_curve(all_gt_masks.flatten(), all_pr_masks.flatten())
aupr = auc(recall_val, precision)
print("AUPR: " , aupr)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall_val, precision, color='blue', lw=2, label='Precision-Recall Curve (AUPR = {:.2f})'.format(aupr))
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve')
ax.grid()
ax.legend(loc='lower right')
ax.set_xlim([0.0, 1.01])
ax.set_ylim([0.0, 1.05])
plt.show()

In [None]:
# Test IoU per training-set size for the four experiment variants.
train_size = [0, 50, 100, 250, 500, 750]

Test_IoU_FT_aug = [0, 91, 92.8, 95, 95.9, 95.7]
Test_IoU_FT_noaug = [0, 93, 94, 95, 94.8, 95.6]
Test_IoU_noFT = [0, 92.9, 94.4, 95, 95.5, 94.8]
Test_IoU_noFTaug = [0, 93.5, 94.8, 94.65, 94.7, 94.3]

series = [
    ('Test_IoU_FT_aug', Test_IoU_FT_aug, 'b'),
    ('Test_IoU_FT_noaug', Test_IoU_FT_noaug, 'g'),
    ('Test_IoU_noFT', Test_IoU_noFT, 'r'),
    ('Test_IoU_noFTaug', Test_IoU_noFTaug, 'y'),
]

# Four bars per group: 0.2 each leaves a gap between groups (4 x 0.25 would fill
# the whole unit and make neighbouring groups touch).
bar_width = 0.2
index = np.arange(len(train_size))

for k, (label, values, color) in enumerate(series):
    plt.bar(index + k * bar_width, values, color=color, width=bar_width, label=label)

plt.xlabel('Train Size')
plt.ylabel('Test IoU')
plt.title('Test IoU for Different Scenarios')
# Bars are centred at offsets 0 .. (n-1) * bar_width; put each tick at the group centre.
plt.xticks(index + (len(series) - 1) * bar_width / 2, train_size)
plt.legend()

plt.show()

In [None]:
# Mark the ClearML task as completed at the end of the notebook.
# NOTE(review): confirm whether task.close() is preferable for the in-process task in this ClearML version.
task.mark_completed()