# Part-1: Exploratory Data Analysis of Satellite Imagery

In [None]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.transforms.functional import to_pil_image, center_crop
from torchvision.utils import draw_segmentation_masks, make_grid
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from webcolors import rgb_to_name
from tqdm import tqdm
%matplotlib inline

In [None]:
!pip install webcolors -qq

In [None]:
base_dir = "/kaggle/input/deepglobe-land-cover-classification-dataset"
train_dir = os.path.join(base_dir, "train")
test_dir = os.path.join(base_dir, "test")
valid_dir = os.path.join(base_dir, "valid")

## The data consists of 3 directories (train, test, valid), along with one .csv file holding the classes and another holding the annotations.

In [None]:
classes = pd.read_csv(os.path.join(base_dir, "class_dict.csv"))
annotations = pd.read_csv(os.path.join(base_dir, "metadata.csv"))

In [None]:
image_ids = annotations[annotations["split"] == "train"]["image_id"].values
transformed_masks_path = "/kaggle/input/processed-masks/train_masks"

In [None]:
plt.rcParams["savefig.bbox"] = "tight"
def show(imgs):
    """Render one tensor image, or a list of them, side by side.

    Every item is detached, converted with ``to_pil_image`` and drawn on its
    own axis of a single-row figure with ticks and tick labels removed.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False, figsize=(100, 100))
    # squeeze=False keeps `axes` two-dimensional, so row 0 holds one axis per image
    for axis, tensor in zip(axes[0], images):
        pil_image = to_pil_image(tensor.detach())
        axis.imshow(np.asarray(pil_image))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# mapping from class name (first column of `classes`) to its colour values
# (the remaining columns of the row)
class_to_rgb = {}
for _, row in classes.iterrows():
    # `row` is a Series labelled by column name, so positional access must go
    # through .iloc; plain `row[0]` relies on a deprecated integer fallback.
    class_to_rgb[row.iloc[0]] = row.iloc[1:].to_list()
class_colors = [tuple(x) for x in class_to_rgb.values()]
num_classes = len(class_colors)

In [None]:
# number of files
print(f"Number of files in the train dir: {len(os.listdir(train_dir))}")
print(f"Number of files in the test dir : {len(os.listdir(test_dir))}")
print(f"Number of files in the valid dir: {len(os.listdir(valid_dir))}")

The satellite images are in `.jpg` format, while the masks are in `.png` format.

In [None]:
# there are png and jpg files
print(os.listdir(train_dir)[:10])

The test and valid directories don't contain masks:

In [None]:
def list_ext(dir_path, ext):
    """Print how many entries of ``dir_path`` have a name ending in ``ext``.

    The match is a plain suffix test on the file name, so ``ext`` may be given
    with or without the leading dot.
    """
    matching = sum(1 for entry in os.listdir(dir_path) if entry.endswith(ext))
    print(f"Files with '{ext}' extension in {dir_path} dir: {matching}")


# Report the jpg / png counts for each split directory, in the order
# train, test, valid.
for split_dir in (train_dir, test_dir, valid_dir):
    for extension in ("jpg", "png"):
        list_ext(split_dir, extension)

The `annotations` file contains the annotations of the images, including the satellite image and mask paths and the splits.

In [None]:
annotations

The `classes` file contains the mappings between the land types and the pixel colors

In [None]:
classes

In [None]:
# map class names to RGB values
class_to_rgb = {}
for _, row in classes.iterrows():
    # positional access through .iloc: `row[0]` on a Series labelled by
    # column name relies on a deprecated integer fallback
    class_to_rgb[row.iloc[0]] = row.iloc[1:].to_list()
class_colors = [tuple(x) for x in class_to_rgb.values()]
class_colors

In [None]:
# mapping from land type to colors
{name: rgb_to_name(class_to_rgb[name]) for name in class_to_rgb.keys()}

In [None]:
# dimensions of one satellite image and its mask (first training image id)
sat_img = read_image(os.path.join(train_dir, f"{image_ids[0]}_sat.jpg"))
mask = read_image(os.path.join(train_dir, f"{image_ids[0]}_mask.png"))
print(f"satellite image dimensions: {sat_img.shape}")
print(f"mask image dimensions: {mask.shape}")

The pixels of the segmentation mask are encoded as RGB colors. This is not ideal, so we will convert each mask into a single-channel map of integer class labels ranging from 0 to the number of classes minus one.

In [None]:
def transform_mask(mask, class_colors):
    """
    Convert an RGB mask of shape (3, H, W) into a (H, W) uint8 label map.

    Args:
        mask: tensor of shape (C, H, W) holding one RGB colour per pixel.
        class_colors: sequence of (r, g, b) triples; the position of a colour
            in the sequence is the label written for pixels of that colour.

    Returns:
        torch.uint8 tensor of shape (H, W), on the same device as ``mask``.
        Pixels matching none of the colours keep the initial value 0.
    """
    h, w = mask.shape[1:]  # shape expected to be (C, H, W)
    semantic_map = torch.zeros((h, w), dtype=torch.uint8, device=mask.device)
    # iterate over all the classes
    for idx, color in enumerate(class_colors):
        # (3, 1, 1) so the colour broadcasts against every pixel; built on the
        # mask's device so the comparison also works for non-CPU masks
        color = torch.tensor(color, device=mask.device).view(3, 1, 1)
        # a pixel belongs to the class only if all three channels match
        class_map = torch.all(torch.eq(mask, color), 0)
        semantic_map[class_map] = idx
    return semantic_map


def ohe_mask(mask, num_classes):
    """One-hot encode a (H, W) label tensor into a boolean (num_classes, H, W) tensor."""
    # F.one_hot needs int64 input and appends the class axis last
    one_hot = F.one_hot(mask.type(torch.long), num_classes=num_classes)
    # move the class axis to the front: (H, W, C) -> (C, H, W)
    return one_hot.type(torch.bool).permute(2, 0, 1)

In [None]:
def process_masks(image_ids, class_colors, in_dir, out_dir):
    """Convert the RGB masks of ``image_ids`` to label maps and save them as PNG.

    Args:
        image_ids: ids used to build the ``{image_id}_mask.png`` file names.
        class_colors: colours forwarded to ``transform_mask``.
        in_dir: directory the RGB masks are read from.
        out_dir: directory the label maps are written to (created if missing).
    """
    # Image.save does not create missing directories
    os.makedirs(out_dir, exist_ok=True)
    for image_id in image_ids:
        mask = read_image(os.path.join(in_dir, f"{image_id}_mask.png"))
        label_map = transform_mask(mask, class_colors)
        to_pil_image(label_map).save(os.path.join(out_dir, f"{image_id}_mask.png"))

In [None]:
def plot_sample(n=5, class_idx=None):
    """Plot ``n`` random training images next to their mask overlays.

    Args:
        n: number of image / overlay pairs to draw.
        class_idx: when given, only images whose processed mask contains this
            label are accepted; ``None`` accepts any image.
    """
    sample = []
    for _ in range(n):
        # NOTE(review): this loops forever if no mask contains `class_idx`.
        while True:
            # named `image_id` (not `id`) so the builtin is not shadowed
            image_id = np.random.choice(image_ids, 1)[0]
            if class_idx is None:
                # no class filter: accept without reading the mask at all
                break
            classes_img = np.unique(
                read_image(os.path.join(transformed_masks_path, f"{image_id}_mask.png"))
            )
            if class_idx in classes_img:
                break
        sample.append(image_id)

    imgs = []
    for img_id in sample:
        sat_img = read_image(os.path.join(train_dir, f"{img_id}_sat.jpg"))
        # one boolean mask per class, as expected by draw_segmentation_masks
        masks = ohe_mask(
            read_image(
                os.path.join(transformed_masks_path, f"{img_id}_mask.png")
            ).squeeze(),
            len(class_colors),
        )
        mask_over_image = draw_segmentation_masks(
            sat_img, masks=masks, alpha=0.5, colors=class_colors
        )
        imgs.extend([sat_img, mask_over_image])

    # legend: class name -> colour name
    display({name: rgb_to_name(class_to_rgb[name]) for name in class_to_rgb.keys()})
    show(make_grid(imgs, nrow=2))

In [None]:
# Define the path and id
transformed_masks_path = '/kaggle/input/processed-masks/train_masks'  # replace with your actual path
# named `sample_id` rather than `id` so the builtin id() is not shadowed for
# the rest of the notebook
sample_id = '103730'  # replace with your actual id

# Construct the file path
file_path = os.path.join(transformed_masks_path, f"{sample_id}_mask.png")
print(file_path)  # print to check the constructed file path

# Check if file exists before reading
if os.path.exists(file_path):
    classes_img = np.unique(read_image(file_path))
    print(classes_img)  # print to check the unique classes in the image
else:
    print(f"File not found: {file_path}")

In [None]:
plot_sample(class_idx=0)

In [None]:
# count the number of pixels of each class over all training masks
# int64 keeps the totals exact (a float32 accumulator is only exact up to
# 2**24); the length comes from `num_classes` instead of a hard-coded 7
counts = torch.zeros(num_classes, dtype=torch.int64)
for image_id in image_ids:
    mask = read_image(os.path.join(transformed_masks_path, f"{image_id}_mask.png"))
    counts += torch.bincount(mask.view(-1), minlength=num_classes)

In [None]:
(counts / counts.sum()).tolist()

In [None]:
plt.figure(figsize=(10, 5))
plt.bar(class_to_rgb.keys(), counts / counts.sum())
plt.title("Percentage of pixels per class in the dataset")
plt.show()

# Part-2: Model Training and Validation

In [None]:
# Using [DeepLabV3+](https://arxiv.org/abs/1802.02611) for Land Cover Classification from Satellite Imagery with the [DeepGlobe Land Cover Classification Dataset](https://www.kaggle.com/balraj98/deepglobe-land-cover-classification-dataset).

In [None]:
import os, cv2
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album

In [None]:
pip install -U segmentation-models-pytorch==0.2.0

In [None]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils, deeplabv3

In [None]:
!pip show segmentation_models_pytorch
!pip show albumentations

### Read Data & Create train / valid splits 📁

In [None]:
DATA_DIR = '../input/deepglobe-land-cover-classification-dataset'

metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
# keep only the rows of the 'train' split and the three columns used below
metadata_df = metadata_df[metadata_df['split']=='train']
metadata_df = metadata_df[['image_id', 'sat_image_path', 'mask_path']]
# paths in the csv are relative to DATA_DIR
metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
# Shuffle DataFrame. Seeded: with an unseeded shuffle the seeded split below
# still selects different images on every run, so a checkpoint reloaded from
# an earlier run could have been trained on the current validation images.
metadata_df = metadata_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Perform 90/10 split for train / val
valid_df = metadata_df.sample(frac=0.1, random_state=42)
train_df = metadata_df.drop(valid_df.index)
len(train_df), len(valid_df)

In [None]:
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))
# Get class names
class_names = class_dict['name'].tolist()
# Get class RGB values
class_rgb_values = class_dict[['r','g','b']].values.tolist()

print('All dataset classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)

#### Shortlist specific classes to segment

In [None]:
# Useful to shortlist specific classes in datasets with large number of classes
select_classes = ['urban_land', 'agriculture_land', 'rangeland', 'forest_land', 'water', 'barren_land', 'unknown']

# Get RGB values of required classes (rows of class_rgb_values picked in the
# order given by select_classes)
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
select_class_rgb_values =  np.array(class_rgb_values)[select_class_indices]

# Report the SELECTED subset; the original printed the full class list here.
print('Selected classes and their corresponding RGB values in labels:')
print('Class Names: ', select_classes)
print('Class RGB values: ', select_class_rgb_values.tolist())

### Helper functions for viz. & one-hot encoding/decoding

In [None]:
# helper function for data visualization
def visualize(**images):
    """
    Show the given images side by side in one row.

    Each keyword argument is one panel: its name (underscores replaced by
    spaces, title-cased) becomes the panel title and its value is handed to
    ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(20,8))
    # subplot positions are 1-based
    for position, label in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_',' ').title(), fontsize=20)
        plt.imshow(images[label])
    plt.show()

# Perform one hot encoding on label
def one_hot_encode(label, label_values):
    """
    Convert a colour-coded segmentation label into a stack of boolean masks.

    # Arguments
        label: array whose last axis holds the colour of each pixel
        label_values: iterable of colours, one per class

    # Returns
        Boolean array with the leading axes of `label` and a last axis of
        length len(label_values); channel k is True where the pixel colour
        equals label_values[k] on every component.
    """
    per_class_masks = [
        np.all(np.equal(label, colour), axis=-1) for colour in label_values
    ]
    return np.stack(per_class_masks, axis=-1)
    
# Perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """
    Collapse the last (class) axis of a one-hot / score array into class keys.

    # Arguments
        image: array whose last axis has one entry per class

    # Returns
        Array with the last axis removed, where each value is the index of
        the largest entry along that axis.
    """
    return np.argmax(image, axis=-1)

# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """
    Replace every class key in `image` with that class's colour.

    # Arguments
        image: array of class keys (cast to int before the lookup)
        label_values: colours, indexed by class key

    # Returns
        Array of the shape of `image` plus the colour axis, for visualization
    """
    palette = np.array(label_values)
    class_keys = image.astype(int)
    # row lookup into the palette, one colour per class key
    return np.take(palette, class_keys, axis=0)

In [None]:
class LandCoverDataset(torch.utils.data.Dataset):

    """DeepGlobe Land Cover Classification Challenge Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        df (pandas.DataFrame): DataFrame with the columns 'sat_image_path'
            and 'mask_path'
        class_rgb_values (list): RGB values of select classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    """
    def __init__(
            self,
            df,
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.image_paths = df['sat_image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _read_rgb(path):
        """Read the file at ``path`` with OpenCV and return it in RGB order."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports an unreadable / missing file by returning
            # None; fail here with the offending path instead of later
            # inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the (image, mask) pair at position ``i``.

        The mask is one-hot encoded against ``class_rgb_values`` and cast to
        float before the optional augmentation and preprocessing steps.
        """
        # read images and masks
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])

        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        # number of image / mask pairs
        return len(self.image_paths)

#### Visualize Sample Image and Mask 📈

In [None]:
dataset = LandCoverDataset(train_df, class_rgb_values=select_class_rgb_values)
# show a randomly chosen sample (the index was drawn but then ignored in
# favour of a hard-coded 2)
random_idx = random.randint(0, len(dataset)-1)
image, mask = dataset[random_idx]

visualize(
    original_image = image,
    ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    one_hot_encoded_mask = reverse_one_hot(mask)
)

### Defining Augmentations 🙃

In [None]:
def get_training_augmentation():
    """Build the training pipeline: a 1024x1024 random crop followed by
    horizontal and vertical flips, each applied with probability 0.5."""
    return album.Compose([
        album.RandomCrop(height=1024, width=1024, always_apply=True),
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ])


def get_validation_augmentation():
    """Build the validation pipeline: a single 1024x1024 centre crop."""
    valid_transform = [
        album.CenterCrop(height=1024, width=1024, always_apply=True),
    ]
    return album.Compose(valid_transform)


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front (HWC -> CHW) and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = np.transpose(x, (2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Construct the preprocessing transform.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image only (can be specific for each pretrained network);
            skipped when falsy.
    Return:
        transform: albumentations.Compose that optionally normalizes the
            image and then applies ``to_tensor`` to both image and mask.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(steps)

#### Visualize Augmented Images & Masks

In [None]:
augmented_dataset = LandCoverDataset(
    train_df,
    augmentation=get_training_augmentation(),
    class_rgb_values=select_class_rgb_values,
)

random_idx = random.randint(0, len(augmented_dataset)-1)

# Different augmentations on the SAME image/mask pair: fetching one index
# repeatedly re-runs the random pipeline (the original indexed with the loop
# counter and left random_idx unused, showing three different images).
for _ in range(3):
    image, mask = augmented_dataset[random_idx]
    visualize(
        original_image = image,
        ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
        one_hot_encoded_mask = reverse_one_hot(mask)
    )

## Training DeepLabV3+

<h3><center>DeepLabV3+ Model Architecture</center></h3>
<img src="https://miro.medium.com/max/1000/1*2mYfKnsX1IqCCSItxpXSGA.png" width="750" height="750"/>
<h4></h4>
<h4><center><a href="https://arxiv.org/abs/1802.02611">Image Source: DeepLabV3+ [Liang-Chieh Chen et al.]</a></center></h4>

### Model Definition

In [None]:
import torch
import segmentation_models_pytorch as smp

ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = select_classes
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation

# create segmentation model with pretrained encoder
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Print model 
print(model)

#### Get Train / Val DataLoaders

In [None]:
# Get train and val dataset instances
train_dataset = LandCoverDataset(
    train_df, 
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

valid_dataset = LandCoverDataset(
    valid_df, 
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

# Get train and val data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4)

#### Set Hyperparams

In [None]:
# Set flag to train the model or not. If set to 'False', only prediction is performed (using an older model checkpoint)
TRAINING = True

# Set num of epochs
EPOCHS = 5

# Set device: `cuda` or `cpu`
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load best saved model checkpoint from previous commit (if present).
# This has to happen BEFORE the optimizer is created: the optimizer keeps
# references to the parameters of the model object it was built from, so
# rebinding `model` afterwards would leave it updating the discarded model.
# NOTE(review): torch.load unpickles a complete model object -- only load
# checkpoints from a trusted source.
PRETRAINED_CHECKPOINT = '../input/deepglobe-land-cover-classification-deeplabv3/best_model.pth'
if os.path.exists(PRETRAINED_CHECKPOINT):
    model = torch.load(PRETRAINED_CHECKPOINT, map_location=DEVICE)
    print('Loaded pre-trained DeepLabV3+ model!')

# define loss function
loss = smp.utils.losses.DiceLoss()

# define metrics
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# define optimizer
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.00008),
])

# define learning rate scheduler (not used in this NB)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, T_mult=2, eta_min=5e-5,
)

In [None]:
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

### Training DeepLabV3+

In [None]:
%%time

if TRAINING:

    best_iou_score = 0.0
    train_logs_list, valid_logs_list = [], []

    for i in range(EPOCHS):

        # One pass over the training data, then one over the validation data
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Checkpoint only when the validation IoU improves on the best so far
        current_iou = valid_logs['iou_score']
        if current_iou > best_iou_score:
            best_iou_score = current_iou
            torch.save(model, './best_model.pth')
            print('Model saved!')

### Prediction on Test Data

In [None]:
# load best saved model checkpoint from the current run
if os.path.exists('./best_model.pth'):
    best_model = torch.load('./best_model.pth', map_location=DEVICE)
    print('Loaded DeepLabV3+ model from this run.')

# load best saved model checkpoint from previous commit (if present)
elif os.path.exists('../input/deepglobe-land-cover-classification-deeplabv3/best_model.pth'):
    best_model = torch.load('../input/deepglobe-land-cover-classification-deeplabv3/best_model.pth', map_location=DEVICE)
    print('Loaded DeepLabV3+ model from a previous commit.')

In [None]:
# create test dataloader to be used with DeepLabV3+ model (with preprocessing operation: to_tensor(...))
test_dataset = LandCoverDataset(
    valid_df,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

test_dataloader = DataLoader(test_dataset)

# test dataset for visualization (without preprocessing transformations).
# Built from the SAME dataframe as `test_dataset` so that index i refers to
# the same image in both; the original used train_df here, which paired each
# prediction with an unrelated picture.
test_dataset_vis = LandCoverDataset(
    valid_df,
    augmentation=get_validation_augmentation(),
    class_rgb_values=select_class_rgb_values,
)

# get a random test image/mask index
random_idx = random.randint(0, len(test_dataset_vis)-1)
image, mask = test_dataset_vis[random_idx]

visualize(
    original_image = image,
    ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    one_hot_encoded_mask = reverse_one_hot(mask)
)


In [None]:
# Output directory for the side-by-side prediction images
sample_preds_folder = 'sample_predictions/'
os.makedirs(sample_preds_folder, exist_ok=True)

In [None]:
# Keep the checkpoint loaded by the cell above; fall back to the in-memory
# model only when no checkpoint file was found (an unconditional assignment
# would discard the best checkpoint in favour of the last-epoch weights).
if 'best_model' not in globals():
    best_model = model

In [None]:
# Inference only: make sure the network is in evaluation mode
best_model.eval()

for idx in range(len(test_dataset)):

    image, gt_mask = test_dataset[idx]
    image_vis = test_dataset_vis[idx][0].astype('uint8')
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    # Predict test image; no_grad avoids building an autograd graph
    with torch.no_grad():
        pred_mask = best_model(x_tensor)
    pred_mask = pred_mask.detach().squeeze().cpu().numpy()
    # Convert pred_mask from `CHW` format to `HWC` format
    pred_mask = np.transpose(pred_mask,(1,2,0))
    # Get prediction channel corresponding to urban_land
    pred_urban_land_heatmap = pred_mask[:,:,select_classes.index('urban_land')]
    pred_mask = colour_code_segmentation(reverse_one_hot(pred_mask), select_class_rgb_values)
    # Convert gt_mask from `CHW` format to `HWC` format
    gt_mask = np.transpose(gt_mask,(1,2,0))
    gt_mask = colour_code_segmentation(reverse_one_hot(gt_mask), select_class_rgb_values)
    # [:,:,::-1] flips RGB to BGR, the channel order cv2.imwrite expects
    cv2.imwrite(os.path.join(sample_preds_folder, f"sample_pred_{idx}.png"), np.hstack([image_vis, gt_mask, pred_mask])[:,:,::-1])

    visualize(
        original_image = image_vis,
        ground_truth_mask = gt_mask,
        predicted_mask = pred_mask,
        pred_urban_land_heatmap = pred_urban_land_heatmap
    )

### Model Evaluation on Test Dataset

In [None]:
# Evaluate the same network that produced the sample predictions above
# (`best_model`), not whatever `model` currently refers to.
test_epoch = smp.utils.train.ValidEpoch(
    best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

valid_logs = test_epoch.run(test_dataloader)
print("Evaluation on Test Data: ")
print(f"Mean IoU Score: {valid_logs['iou_score']:.4f}")
print(f"Mean Dice Loss: {valid_logs['dice_loss']:.4f}")

### Plot Dice Loss & IoU Metric for Train vs. Val

In [None]:
train_logs_df = pd.DataFrame(train_logs_list)
valid_logs_df = pd.DataFrame(valid_logs_list)
train_logs_df.T

In [None]:
plt.figure(figsize=(20,8))
plt.plot(train_logs_df.index.tolist(), train_logs_df.iou_score.tolist(), lw=3, label = 'Train')
plt.plot(valid_logs_df.index.tolist(), valid_logs_df.iou_score.tolist(), lw=3, label = 'Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('IoU Score', fontsize=20)
plt.title('IoU Score Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('iou_score_plot.png')
plt.show()

In [None]:
plt.figure(figsize=(20,8))
plt.plot(train_logs_df.index.tolist(), train_logs_df.dice_loss.tolist(), lw=3, label = 'Train')
plt.plot(valid_logs_df.index.tolist(), valid_logs_df.dice_loss.tolist(), lw=3, label = 'Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Dice Loss', fontsize=20)
plt.title('Dice Loss Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('dice_loss_plot.png')
plt.show()

# method 02: Landcover Classification on DeepGlobe with NestedUnet

## Install & Import Dependencies

In [None]:
!pip install -q git+https://github.com/PyTorchLightning/pytorch-lightning
!pip install -q git+https://github.com/qubvel/segmentation_models.pytorch
!pip install -q git+https://github.com/albumentations-team/albumentations
!pip install -q torchinfo

In [None]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pytorch_lightning as pl
from torch import nn
from tqdm.notebook import tqdm

## Extract Data

In [None]:
IMAGE_SIZE = 320
BATCH_SIZE = 16
EPOCHS = 5

color_dict = pd.read_csv('../input/deepglobe-land-cover-classification-dataset/class_dict.csv')
CLASSES = color_dict['name']
print(color_dict)

In [None]:
from glob import glob
from sklearn.utils import shuffle

# Pair every image with its mask: both globs are sorted, so row i holds the
# i-th .jpg and the i-th .png of the train directory.
# NOTE(review): this pairing assumes every image has exactly one mask with a
# matching file-name prefix -- confirm against the directory contents.
pd_dataset = pd.DataFrame({
    'IMAGES': sorted(glob("../input/deepglobe-land-cover-classification-dataset/train/*.jpg")),
    'MASKS': sorted(glob("../input/deepglobe-land-cover-classification-dataset/train/*.png"))
})
# Seeded shuffle: the train/val/test splits below use random_state=0, but they
# are only reproducible if the row order they start from is reproducible too.
pd_dataset = shuffle(pd_dataset, random_state=0)
pd_dataset.reset_index(inplace=True, drop=True)
pd_dataset.head()

In [None]:
from sklearn.model_selection import train_test_split

pd_train, pd_test = train_test_split(pd_dataset, test_size=0.25, random_state=0)
pd_train, pd_val = train_test_split(pd_train, test_size=0.2, random_state=0)

print("Training set size:", len(pd_train))
print("Validation set size:", len(pd_val))
print("Testing set size:", len(pd_test))

In [None]:
index = 200

sample_img = cv2.imread(pd_train.iloc[index].IMAGES)
sample_img = cv2.cvtColor(sample_img, cv2.COLOR_BGR2RGB)

sample_msk = cv2.imread(pd_train.iloc[index].MASKS)
sample_msk = cv2.cvtColor(sample_msk, cv2.COLOR_BGR2RGB)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 10))

ax1.set_title('IMAGE')
ax1.imshow(sample_img)

ax2.set_title('MASK')
ax2.imshow(sample_msk)

## Utility Functions

In [None]:
def rgb2category(rgb_mask):
    """Map an RGB mask to an int8 array of class indices.

    For every row of ``color_dict`` the pixels whose three channels equal that
    row's (r, g, b) get the row's index added to the (initially zero) output.
    """
    spatial_shape = rgb_mask.shape[:2]
    category_mask = np.zeros(spatial_shape, dtype=np.int8)
    flat_pixels = rgb_mask.reshape((-1, 3))
    for class_idx, entry in color_dict.iterrows():
        colour = (entry['r'], entry['g'], entry['b'])
        hits = np.all(flat_pixels == colour, axis=1).reshape(spatial_shape)
        category_mask += (hits * class_idx)
    return category_mask

def category2rgb(category_mask):
    """Map an array of class indices back to a uint8 RGB image.

    Pixels equal to a row index of ``color_dict`` receive that row's
    (r, g, b); all other pixels stay black.
    """
    canvas = np.zeros(category_mask.shape[:2] + (3,))
    for class_idx, entry in color_dict.iterrows():
        selected = category_mask == class_idx
        canvas[selected] = (entry['r'], entry['g'], entry['b'])
    return np.uint8(canvas)

## Data Augmentations & Transformations

In [None]:
import albumentations as aug

# Training pipeline: resize plus random flips and random brightness/contrast.
train_augment = aug.Compose([
    aug.Resize(IMAGE_SIZE, IMAGE_SIZE),
    aug.HorizontalFlip(p=0.5),
    aug.VerticalFlip(p=0.5),
    aug.RandomBrightnessContrast(p=0.3)
])

# Validation / test pipeline: deterministic resize only. A random
# brightness/contrast step here would make the reported validation and test
# metrics vary from run to run for the same weights.
test_augment = aug.Compose([
    aug.Resize(IMAGE_SIZE, IMAGE_SIZE),
])

## Create PyTorch Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs read from the paths stored in a DataFrame.

    Args:
        df: DataFrame with the columns ``IMAGES`` and ``MASKS`` (file paths).
        augmentations: optional callable invoked as
            ``augmentations(image=..., mask=...)`` that returns a mapping with
            the keys ``'image'`` and ``'mask'``.
    """

    def __init__(self, df, augmentations=None):
        self.df = df
        self.augmentations = augmentations

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index: int):
        """Return ``(image, mask)`` for row ``index``.

        ``image`` is a float tensor in channel-first layout scaled by 1/255;
        ``mask`` is a long tensor of class indices with a leading axis of
        size 1.
        """
        row = self.df.iloc[index]

        image = cv2.imread(row.IMAGES)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(row.MASKS)
        # Convert the MASK that was just read. The original passed `image`
        # here, which replaced the ground truth with the satellite picture.
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

        if self.augmentations:
            data = self.augmentations(image=image, mask=mask)
            image = data['image']
            mask = data['mask']

        # RGB colours -> class indices
        mask = rgb2category(mask)

        # HWC -> CHW for the image, add a leading axis to the mask
        image = np.transpose(image, (2, 0, 1)).astype(np.float64)
        mask = np.expand_dims(mask, axis=0)

        image = torch.Tensor(image) / 255.0
        mask = torch.Tensor(mask).long()

        return image, mask

In [None]:
class SegmentationDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train / val / test DataFrames.

    Training batches use ``batch_size`` with shuffling; validation and test
    batches use half of it (integer division) without shuffling.
    """

    def __init__(self, pd_train, pd_val, pd_test, batch_size=10):
        super().__init__()
        self.pd_train, self.pd_val, self.pd_test = pd_train, pd_val, pd_test
        self.batch_size = batch_size

    def setup(self, stage=None):
        # only the training split gets the training augmentations
        self.train_dataset = SegmentationDataset(self.pd_train, train_augment)
        self.val_dataset = SegmentationDataset(self.pd_val, test_augment)
        self.test_dataset = SegmentationDataset(self.pd_test, test_augment)

    def _eval_loader(self, dataset):
        """Loader settings shared by the validation and test splits."""
        return DataLoader(dataset, batch_size=self.batch_size // 2, shuffle=False, num_workers=1)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        return self._eval_loader(self.val_dataset)

    def test_dataloader(self):
        return self._eval_loader(self.test_dataset)

In [None]:
data_module = SegmentationDataModule(pd_train, pd_val, pd_test, batch_size=BATCH_SIZE)
data_module.setup()

In [None]:
image, mask = next(iter(data_module.train_dataloader()))
image.shape, mask.shape

## Build Loss and Model

In [None]:
from segmentation_models_pytorch import UnetPlusPlus
from segmentation_models_pytorch.losses import DiceLoss
from segmentation_models_pytorch.metrics import get_stats, iou_score, accuracy, precision, recall, f1_score

class SegmentationModel(pl.LightningModule):
    """UnetPlusPlus segmentation network with a multiclass Dice loss.

    ``forward`` returns the raw network output when called without targets,
    and ``(loss, metrics, outputs)`` when targets are supplied. The three
    Lightning step hooks share one implementation that logs the loss and the
    metrics under a ``train/``, ``val/`` or ``test/`` prefix.
    """

    # metric names, in the order they are logged after the loss
    _METRIC_NAMES = ("IoU", "Accuracy", "Precision", "Recall", "F1score")

    def __init__(self):
        super().__init__()
        # NOTE(review): activation="softmax" leaves the softmax axis implicit;
        # "softmax2d" looks like the explicit per-pixel variant -- confirm
        # against the segmentation_models_pytorch version in use.
        self.model = UnetPlusPlus(
            encoder_name="timm-regnety_120",
            encoder_weights="imagenet",
            in_channels=3,
            classes=len(CLASSES),
            activation="softmax"
        )
        self.criterion = DiceLoss(mode="multiclass", from_logits=False)

    def forward(self, inputs, targets=None):
        """Run the network; also compute loss and metrics when ``targets`` is given."""
        outputs = self.model(inputs)
        if targets is None:
            return outputs

        loss = self.criterion(outputs, targets)
        # hard class prediction per pixel, with the axis of size 1 restored
        predictions = outputs.argmax(dim=1).unsqueeze(1).type(torch.int64)
        tp, fp, fn, tn = get_stats(predictions, targets, mode='multiclass', num_classes=len(CLASSES))
        metrics = {
            "Accuracy": accuracy(tp, fp, fn, tn, reduction="micro-imagewise"),
            "IoU": iou_score(tp, fp, fn, tn, reduction="micro-imagewise"),
            "Precision": precision(tp, fp, fn, tn, reduction="micro-imagewise"),
            "Recall": recall(tp, fp, fn, tn, reduction="micro-imagewise"),
            "F1score": f1_score(tp, fp, fn, tn, reduction="micro-imagewise")
        }
        return loss, metrics, outputs

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.0001)

    def _shared_step(self, batch, stage):
        """Compute loss/metrics for ``batch`` and log them as ``{stage}/<name>``."""
        images, masks = batch

        loss, metrics, _ = self(images, masks)
        logged = {f"{stage}/Loss": loss}
        for name in self._METRIC_NAMES:
            logged[f"{stage}/{name}"] = metrics[name]
        self.log_dict(logged, prog_bar=True, logger=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

In [None]:
from torchinfo import summary

# Build the model with its default constructor arguments and print a summary
# for an input of shape (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE).
model = SegmentationModel()
summary(model, input_size=(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE))

## Train the Model

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger

# Keep only the single best checkpoint, ranked by validation F1.
# F1 is a higher-is-better score, so the comparison mode must be "max";
# with "min" the callback would retain the checkpoint of the *worst* epoch.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val/F1score",
    mode="max"
)

# Metrics are written as CSV under lightning_logs/landcover-classification-log/.
logger = CSVLogger("lightning_logs", name="landcover-classification-log")

# Stop once validation accuracy has not improved for 5 consecutive checks.
# EarlyStopping defaults to mode="min", which would treat a *falling* accuracy
# as improvement, so the mode is set explicitly.
early_stopping_callback = EarlyStopping(monitor="val/Accuracy", patience=5, mode="max")

trainer = pl.Trainer(
    logger=logger,
    log_every_n_steps=31,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=EPOCHS,
    accelerator="gpu",
    devices=1
)

In [None]:
# Train `model` using the loaders provided by `data_module`.
trainer.fit(model, data_module)

## Test the Model

In [None]:
# Evaluate the in-memory `model` on the test loader provided by `data_module`.
trainer.test(model, data_module)

In [None]:
# Read the metrics written by the CSVLogger for the current run. Using
# logger.log_dir instead of a hard-coded ".../version_0" keeps the plot in sync
# when the notebook is re-run and the logger moves on to a new version folder.
metrics = pd.read_csv(os.path.join(logger.log_dir, "metrics.csv"))
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(18, 10))

axes = [ax1, ax2, ax3, ax4, ax5, ax6]
names = ['Loss', 'IoU', 'Accuracy', 'Precision', 'Recall', 'F1score']

for axis, name in zip(axes, names):
    # Each column is only populated on some CSV rows (hence dropna). Re-index
    # afterwards so the x-axis counts logged epochs rather than raw row numbers.
    axis.plot(metrics[f'train/{name}'].dropna().reset_index(drop=True))
    axis.plot(metrics[f'val/{name}'].dropna().reset_index(drop=True))
    axis.set_title(f'{name}: Train/Val')
    axis.set_ylabel(name)
    axis.set_xlabel('Epoch')
    # Legend on the subplot being drawn (previously always attached to ax1).
    axis.legend(['training', 'validation'], loc="upper right")

# Unet++ Inference Benchmarking

Installing the necessary packages

In [None]:
!pip install segmentation-models-pytorch
!pip install pytorch-lightning
!pip install GPUtil

Making the imports

In [None]:
import torch
import os
import ssl
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
import multiprocessing as mp
from torch import nn
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from GPUtil import showUtilization as gpu_usage
import pandas as pd
from pytorch_lightning.loggers import CSVLogger

Ensuring that the GPU memory is free

In [None]:
# Ask PyTorch to release its cached CUDA memory, then print the current
# GPU utilization via GPUtil.
torch.cuda.empty_cache()
gpu_usage()

Defining global variables

In [None]:
# Model / training configuration shared by the cells below.
ENCODER = 'efficientnet-b3'
ENCODER_WEIGHTS = 'imagenet'
CHANNELS = 3
CLASSES = 7
ACTIVATION = 'sigmoid'
LR = 0.001
BATCH_SIZE = 8
IMG_SIZE = (512,512)

# Use every visible GPU when CUDA is available. On CPU, the Trainer's `devices`
# argument is a number of training *processes*, not cores, so request a single
# process instead of one per core; data-loading parallelism is governed
# separately by WORKERS.
if torch.cuda.is_available():
    DEVICE, NUM_DEVICES = "cuda", torch.cuda.device_count()
else:
    DEVICE, NUM_DEVICES = "cpu", 1

# Number of DataLoader worker processes (one per CPU core).
WORKERS = mp.cpu_count()
print(f'Running on {NUM_DEVICES} {DEVICE}(s)')

Determining the training, validation and test sets using pandas

In [None]:
DATA_DIR = '../input/deepglobe-land-cover-classification-dataset'

# Only rows of the "train" split are used; image and mask paths in the
# metadata are relative to DATA_DIR, so they are made absolute here.
metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
metadata_df = metadata_df[metadata_df['split']=='train']
metadata_df = metadata_df[['image_id', 'sat_image_path', 'mask_path']]
metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
# Shuffle DataFrame. The shuffle is seeded: an unseeded shuffle would make the
# seeded samples below pick different images in every session, so resuming
# from a checkpoint could validate/test on images the model already trained on.
metadata_df = metadata_df.sample(frac=1, random_state=42).reset_index(drop=True)

# 10% of the data for validation, then 10% of the remaining 90% (~9% overall)
# for test; everything else (~81%) is used for training.
valid_df = metadata_df.sample(frac=0.1, random_state=42)
minus_valid_df = metadata_df.drop(valid_df.index)
test_df = minus_valid_df.sample(frac=0.1, random_state=42)
train_df = minus_valid_df.drop(test_df.index)

len(train_df), len(valid_df), len(test_df)

Determining the class names and their corresponding RGB colors

In [None]:
# class_dict.csv is read for its `name`, `r`, `g` and `b` columns.
class_df = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))
# Get class names
class_names = class_df['name'].tolist()
# Get class RGB values as a list of [r, g, b] lists, in the same row order as class_names
class_rgb_values = class_df[['r','g','b']].values.tolist()

print('All dataset classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)

Helper functions to perform the one hot transformation on the masks. This is needed because the masks are multiclass rgb files.

In [None]:
# Perform one hot encoding on label
def one_hot_encode(label, label_values):
    """
    Convert a colour-coded segmentation label to one-hot format by replacing
    each pixel colour with a boolean vector of length num_classes.

    # Arguments
        label: array-like whose last axis holds the colour channels
            (e.g. an H x W x 3 RGB mask). Anything `np.asarray` accepts works.
        label_values: sequence of colours, one per class; the order defines
            the class index of each output channel.

    # Returns
        A boolean array with the same height and width as the input and a
        depth of num_classes. A pixel whose colour matches none of
        `label_values` is False in every channel.

    # Raises
        ValueError: if `label_values` is empty (nothing to stack).
    """
    # Convert once up front; otherwise a non-array input (such as a PIL image)
    # would be converted again for every colour inside the loop.
    label = np.asarray(label)
    semantic_map = [
        np.all(np.equal(label, colour), axis=-1)  # True where every channel matches
        for colour in label_values
    ]
    return np.stack(semantic_map, axis=-1)
    
# Perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """
    Collapse a one-hot (or per-class score) array into class keys.

    # Arguments
        image: array whose FIRST axis indexes the classes.

    # Returns
        An array with the first axis removed, where each value is the index
        of the largest entry along that axis (the first one on ties).
    """
    return np.argmax(image, axis=0)

# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """
    Map an array of class keys to the colours of those classes.

    # Arguments
        image: array of class keys; it is cast to int before being used as
            an index into the colour table.
        label_values: sequence of colours, indexed by class key.

    # Returns
        An array shaped like `image` with one extra trailing axis holding
        the colour of each pixel's class.
    """
    palette = np.array(label_values)
    class_keys = image.astype(int)
    return palette[class_keys]

Implementing the Segmentation Model Class (subclassing pl.LightningModule)

In [None]:
class SegmentationModel(pl.LightningModule):
    """LightningModule that wraps a segmentation network, a loss and a learning rate.

    Every step runs the network, logs four segmentation metrics and the loss
    under a stage prefix ("train", "valid" or "test"), and the optimizer is
    Adam over the wrapped network's parameters.
    """

    # Metric names logged by every step, in logging order.
    _METRIC_KEYS = ('iou_score', 'f1_score', 'accuracy', 'recall')

    def __init__(self, net, loss, lr):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.loss = loss
        self.net = net

    def forward(self, x):
        """Return the wrapped network's output for `x`."""
        return self.net(x)

    def shared_step(self, preds, labels):
        """Compute IoU, F1, accuracy and recall for a batch of predictions.

        `preds` and `labels` (cast to long) are binarised at 0.5 to obtain
        the tp/fp/fn/tn counts from which the metrics are derived. Returns a
        dict keyed by 'iou_score', 'f1_score', 'accuracy' and 'recall'.
        """
        stats = smp.metrics.get_stats(preds, labels.long(), mode='binary', threshold=0.5)

        # Each metric keeps its own reduction (see the smp metric docs).
        return {
            'iou_score': smp.metrics.iou_score(*stats, reduction="micro"),
            'f1_score': smp.metrics.f1_score(*stats, reduction="micro"),
            'accuracy': smp.metrics.accuracy(*stats, reduction="macro"),
            'recall': smp.metrics.recall(*stats, reduction="micro-imagewise"),
        }

    def _evaluate_batch(self, batch, stage):
        """Forward `batch`, log '<stage>_<metric>' and '<stage>_loss', return the loss."""
        imgs, labels = batch
        preds = self(imgs)

        scores = self.shared_step(preds, labels)
        for key in self._METRIC_KEYS:
            self.log(f'{stage}_{key}', scores[key], on_epoch=True)

        loss = self.loss(preds, labels)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Log the training metrics and return the loss for backpropagation."""
        return self._evaluate_batch(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log the validation metrics and loss; nothing is returned."""
        self._evaluate_batch(batch, 'valid')

    def test_step(self, batch, batch_idx):
        """Log the test metrics and loss; nothing is returned."""
        self._evaluate_batch(batch, 'test')

    def configure_optimizers(self):
        """Return a one-element list holding Adam over the wrapped network's parameters."""
        return [torch.optim.Adam(self.net.parameters(), lr=self.lr)]

Implementing the DeepGlobe Dataset Class (subclassing Dataset)

In [None]:
class DeepGlobeSet(Dataset):
    """Dataset of (satellite image, one-hot mask) pairs listed in a DataFrame.

    # Arguments
        df: DataFrame with `sat_image_path` and `mask_path` columns.
        transform: optional callable applied to the image (and to the mask,
            unless `mask_transform` is given).
        preprocess_fn: optional callable applied to the image array before
            `transform`; its result is converted to a float32 array.
        class_rgb_values: sequence of class colours used to one-hot encode
            the RGB mask. Required for item access.
        mask_transform: optional callable applied to the mask instead of
            `transform`. Defaults to `transform`, which preserves the original
            behaviour. NOTE(review): if `transform` contains an interpolating
            resize (torchvision's Resize is bilinear by default), the one-hot
            mask is blended at class borders; pass a nearest-neighbour
            `mask_transform` to keep the labels hard.
    """

    def __init__(self, df, transform=None, preprocess_fn=None, class_rgb_values=None, mask_transform=None):
        self.image_paths = df['sat_image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.transform = transform
        self.mask_transform = mask_transform
        self.preprocess_fn = preprocess_fn
        self.class_rgb_values = class_rgb_values

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, encode and transform the idx-th (image, mask) pair."""
        # Context managers close the underlying files deterministically, which
        # matters when many DataLoader workers keep opening images.
        with Image.open(self.image_paths[idx]) as raw_img:
            img = np.array(raw_img.convert('RGB'))

        # Convert the mask to an array once, mirroring the image handling,
        # rather than handing a PIL image to one_hot_encode.
        with Image.open(self.mask_paths[idx]) as raw_mask:
            mask = np.array(raw_mask.convert('RGB'))
        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        if self.preprocess_fn:
            img = self.preprocess_fn(img)
            img = np.array(img, dtype=np.float32)

        if self.transform is not None:
            img = self.transform(img)

        # The mask falls back to the image transform when no dedicated one is set.
        mask_transform = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_transform is not None:
            mask = mask_transform(mask)

        return img, mask

Implementing the DeepGlobe DataModule Class (subclassing pl.LightningDataModule)

In [None]:
class DeepGlobeDataModule(pl.LightningDataModule):
    """LightningDataModule serving the train / validation / test DeepGlobe splits.

    All three splits share one transform (ToTensor followed by a resize to
    `img_size`), the same preprocessing function and the same class colours.
    """

    def __init__(self, train_df, valid_df, test_df, batch_size=2, img_size=(256,256), preprocess_fn=None, class_rgb_values=None):
        super().__init__()
        self.train_df, self.valid_df, self.test_df = train_df, valid_df, test_df
        self.batch_size = batch_size
        self.img_size = img_size
        self.preprocess_fn = preprocess_fn
        self.class_rgb_values = class_rgb_values

        to_tensor = transforms.ToTensor()
        resize = transforms.Resize(size=self.img_size)
        self.transform = transforms.Compose([to_tensor, resize])

    def _make_set(self, df):
        """Wrap `df` in a DeepGlobeSet configured with this module's settings."""
        return DeepGlobeSet(
            df,
            transform=self.transform,
            preprocess_fn=self.preprocess_fn,
            class_rgb_values=self.class_rgb_values,
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over `dataset` with the module-wide batch size and workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=WORKERS,
            pin_memory=True,
        )

    def setup(self, stage=None):
        """Create the three datasets; `stage` is accepted but not used."""
        self.trainset = self._make_set(self.train_df)
        self.validset = self._make_set(self.valid_df)
        self.testset = self._make_set(self.test_df)

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return self._make_loader(self.trainset, True)

    def val_dataloader(self):
        """Unshuffled loader over the validation set."""
        return self._make_loader(self.validset, False)

    def test_dataloader(self):
        """Unshuffled loader over the test set."""
        return self._make_loader(self.testset, False)

* Defining the net that will be used in our model (Unet++)
* Defining the loss function
* Defining the optimizer

In [None]:
# Disables HTTPS certificate verification for the whole process.
# NOTE(review): presumably a workaround so the pretrained encoder weights can
# be downloaded in this environment; it is a security trade-off. Confirm it is
# still needed.
ssl._create_default_https_context = ssl._create_unverified_context

# NOTE(review): ACTIVATION is 'sigmoid', so the network outputs sigmoid values,
# while nn.CrossEntropyLoss (below) is documented to expect raw, unnormalized
# logits. The pairing looks unintended. If the activation is removed, the 0.5
# threshold used for the metrics in SegmentationModel.shared_step should be
# revisited at the same time.
net = smp.UnetPlusPlus(
    encoder_name=ENCODER,                # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=ENCODER_WEIGHTS,     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=CHANNELS,                       # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=CLASSES,                           # model output channels (number of classes in your dataset)
    activation= ACTIVATION                  # activation applied to the network output
)

# Input preprocessing function matching the chosen encoder and its pretrained weights.
preprocess_input = get_preprocessing_fn(ENCODER, pretrained=ENCODER_WEIGHTS)

LOSS = nn.CrossEntropyLoss()

Defining our trainer

In [None]:
# LR is re-assigned here with the same value as in the global-variables cell.
LR = 0.001
EPOCHS = 27
OUTPUT_DIR = '/kaggle/working/'

# Stop when 'valid_loss' (logged in SegmentationModel.validation_step) has not
# decreased by at least min_delta for 5 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='valid_loss', 
    min_delta=0.00001, 
    patience=5, 
    mode='min')

# Writes 'lightning_trained' checkpoints to OUTPUT_DIR every epoch.
# NOTE(review): no `monitor` is set, so this is presumably the most recent
# epoch rather than the best one. Confirm against the ModelCheckpoint docs.
checkpoint_callback = ModelCheckpoint(
    every_n_epochs=1,
    dirpath=OUTPUT_DIR,
    filename='lightning_trained'
)

logger = CSVLogger(OUTPUT_DIR, name='lightning_logs')

trainer = pl.Trainer(
        accelerator=DEVICE,
        devices=NUM_DEVICES,
        max_epochs=EPOCHS,
        callbacks=[early_stop_callback, checkpoint_callback],
        logger=logger,
    )

Instantiating the model and the data module

In [None]:
# Wrap the network, loss and learning rate in the LightningModule.
segmodel = SegmentationModel(net, LOSS, LR)

# Data module over the three splits, using the encoder's preprocessing
# function and the class colours for one-hot encoding the masks.
deepglobeData = DeepGlobeDataModule(train_df, valid_df, test_df, BATCH_SIZE, IMG_SIZE, preprocess_input, class_rgb_values)

Running our training based on a previously saved checkpoint if available, otherwise we start from scratch.
This is also useful to load our model from a checkpoint.

In [None]:
CHECKPOINT_DIR = '../input/savedckpt/'
checkpoint_file = os.path.join(CHECKPOINT_DIR, 'lightning_trained-v1.ckpt')

# Resume only when the checkpoint file actually exists; ckpt_path=None is the
# Trainer.fit default and means a fresh start.
has_checkpoint = os.path.isfile(checkpoint_file)
print('Resuming training from previous checkpoint...' if has_checkpoint
      else 'Starting training from scratch...')
trainer.fit(segmodel, datamodule=deepglobeData,
            ckpt_path=checkpoint_file if has_checkpoint else None)

Loading our checkpoint and testing it.

In [None]:
# Optional: uncomment the lines below to rebuild the model and the data module
# from a saved checkpoint instead of testing the in-memory `segmodel`.
# CHECKPOINT_DIR = '../input/saved-checkpoints/'
# checkpoint_file = os.path.join(CHECKPOINT_DIR, 'lightning_trained-v1.ckpt')

# segmodel = SegmentationModel.load_from_checkpoint(checkpoint_file)

# deepglobeData = DeepGlobeDataModule(train_df, valid_df, test_df, 
#                                     BATCH_SIZE, IMG_SIZE, 
#                                     preprocess_input, class_rgb_values)

# Evaluate `segmodel` on the test loader of the data module.
trainer.test(segmodel, datamodule=deepglobeData)

Helper function to visualize the data

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel: the keyword name, with
    underscores turned into spaces and title-cased, is the panel title and
    the value is passed to `plt.imshow`. Axis ticks are hidden.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, total, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

Instantiating a test dataset for visualization (i.e., without the preprocessing function)

In [None]:
# Test set without the preprocessing step, for visualization purposes
# (`testsetvis` below); `testset` keeps the preprocessing and feeds the model.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=IMG_SIZE),
])

testset = DeepGlobeSet(test_df, transform=transform, preprocess_fn=preprocess_input, class_rgb_values=class_rgb_values)
testsetvis = DeepGlobeSet(test_df, transform=transform, preprocess_fn=None, class_rgb_values=class_rgb_values)

# Converts a tensor or array back to a PIL image of size IMG_SIZE for plotting.
reverse_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(size=IMG_SIZE),
        ])

Visualizing the predictions

In [None]:
# Inference only: switch the model to eval mode so layers such as batch norm
# and dropout behave deterministically on a single-image batch.
segmodel.eval()

# Draw up to 20 distinct test samples (no repeated images).
num_samples = min(20, len(testset))
for n in np.random.choice(len(testset), size=num_samples, replace=False):
    image, gt_mask = testset[n]
    imagevis = testsetvis[n][0]

    # Create a batch of size 1 (unsqueeze(0)) on the model's device and run
    # the forward pass without building an autograd graph.
    with torch.no_grad():
        pred = segmodel(image.unsqueeze(0).to(segmodel.device))

    # Drop the batch axis and move the prediction to a CPU numpy array.
    pred = pred.squeeze(0).cpu().numpy()

    # Ground-truth mask as a uint8 numpy array.
    gt_mask = gt_mask.cpu().numpy().astype(np.uint8)

    # Reverse the one-hot encoding (applied in the dataset) for both the
    # ground truth and the prediction: class axis -> class keys.
    reversed_gt_mask = reverse_one_hot(gt_mask)
    reversed_pred = reverse_one_hot(pred)

    # Re-apply the original class colours and get the results as uint8.
    ground_truth_mask = colour_code_segmentation(reversed_gt_mask, class_rgb_values).astype(np.uint8)
    prediction = colour_code_segmentation(reversed_pred, class_rgb_values).astype(np.uint8)

    visualize(
        image=reverse_transform(imagevis),
        prediction=reverse_transform(prediction),
        #one_hot_mask=reverse_transform(reversed_gt_mask.astype(np.uint8)),
        ground_truth_mask=reverse_transform(ground_truth_mask),
    )