Importing libraries

In [1]:
import segmentation_models_pytorch as smp
from torchvision import transforms
import torch
from torch.utils.data import Dataset, DataLoader
import wandb
from PIL import Image
import os
import torchmetrics
import numpy as np
import torchvision.transforms.functional as TF
from ptflops import get_model_complexity_info

Creating classes

In [2]:
# With this class we create a wrapper to the DeepLab one, allowing us to normalize the output
# using sigmoid function
class NormalizedDecoder(smp.DeepLabV3Plus):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        x = super().forward(x)
        x = torch.sigmoid(x)
        return x
    
# Dataset custom class, used to create the datasets that we will use during train loop and test loop
class CustomDataset(Dataset):
    def __init__(self, img_dir, transform_image=None, transform_annotation=None):
        self.img_dir = img_dir
        self.transform_image = transform_image
        self.transform_annotation = transform_annotation

    def __len__(self):
        return len(self.img_dir[0])

    def __getitem__(self, idx):
        normalized_path = self.img_dir[0][idx]
        image = Image.open(normalized_path).convert('RGB')
        image = TF.to_tensor(image)
        annotation_path = self.img_dir[1][idx]
        annotation = Image.open(annotation_path)
        annotation = TF.to_tensor(annotation)
        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_annotation:
            annotation = self.transform_annotation(annotation)
        return image, annotation

Defining functions

In [3]:
def get_groundtruth_images(image_folder, folders):
    path_annotations = []
    for folder_name in folders:
        folder_path = os.path.join(image_folder, folder_name)
        if os.path.isdir(folder_path):
            tile_folder_path = os.path.join(folder_path, 'groundtruth')
            for filename in os.listdir(tile_folder_path):
                if filename.endswith('iMap.png'):
                    image_path = os.path.join(tile_folder_path, filename)
                    path_annotations.append(image_path)
    return path_annotations

# With this function we read all the paths for the images from the input path
def get_normalized_paths(image_folder, folders):
    train_paths = []
    for folder_name in folders:
        folder_path = os.path.join(image_folder, folder_name)
        if os.path.isdir(folder_path):
            tile_folder_path = os.path.join(folder_path, 'tile', 'RGB')
            for filename in os.listdir(tile_folder_path):
                image_path = os.path.join(tile_folder_path, filename)
                train_paths.append(image_path)
    return train_paths


# With this function we read all the paths for the images from the input path and we split them using the
# % in input
def get_normalized_paths_split(image_folder, folders, train_validation_split):
    train_paths = []
    validation_paths = []
    for folder_name in folders:
        folder_path = os.path.join(image_folder, folder_name)
        if os.path.isdir(folder_path):
            tile_folder_path = os.path.join(folder_path, 'tile', 'RGB')
            image_filenames = os.listdir(tile_folder_path)
            num_train = int(len(image_filenames) * train_validation_split)
            train_filenames = image_filenames[:num_train]
            validation_filenames = image_filenames[num_train:]
            for filename in train_filenames:
                image_path = os.path.join(tile_folder_path, filename)
                train_paths.append(image_path)
            for filename in validation_filenames:
                image_path = os.path.join(tile_folder_path, filename)
                validation_paths.append(image_path)
    return train_paths, validation_paths


# With this function we read all the annotations from the input path
def get_annotation_paths(annotation_folder: str, folders):
    path_annotations = []
    for folder_name in folders:
        folder_path = os.path.join(annotation_folder, folder_name)
        if os.path.isdir(folder_path):
            ndvi_folder = os.path.join(folder_path, 'tile', 'NDVI')
            if os.path.exists(ndvi_folder):
                for filename in os.listdir(ndvi_folder):
                    if filename.endswith('.png'):
                        image_path = os.path.join(ndvi_folder, filename)
                        path_annotations.append(image_path)
    return path_annotations

# With this function we read all the paths for the annotations from the input path and we split them using the
# % in input
def get_annotation_paths_split(annotation_folder: str, folders, train_validation_split):
    path_annotations = []
    validation_path_annotations = []
    for folder_name in folders:
        folder_path = os.path.join(annotation_folder, folder_name)
        if os.path.isdir(folder_path):
            ndvi_folder = os.path.join(folder_path, 'tile', 'NDVI')
            if os.path.exists(ndvi_folder):
                image_filenames = os.listdir(ndvi_folder)
                num_train = int(len(image_filenames) * train_validation_split)
                train_filenames = image_filenames[:num_train]
                validation_filenames = image_filenames[num_train:]
                for filename in train_filenames:
                    if filename.endswith('.png'):
                        image_path = os.path.join(ndvi_folder, filename)
                        path_annotations.append(image_path)
                for filename in validation_filenames:
                    if filename.endswith('.png'):
                        image_path = os.path.join(ndvi_folder, filename)
                        validation_path_annotations.append(image_path)
    return path_annotations, validation_path_annotations


# We use this function to validate the model after an epoch during the train loop to check if the model has improved
def validate_model(model, validation_loader):
    model.eval()
    mse_validation = torchmetrics.MeanSquaredError().to(device)

    with torch.no_grad():
        for images, annotations in validation_loader:
            images = images.to(device)
            annotations = annotations.to(device)

            outputs = model(images)
            mse_validation.update(outputs,annotations)

    mse_validation = mse_validation.compute()
    wandb.log({'mse_validation': mse_validation})
    return mse_validation.item()


# We use this function to create new tensor to test the error for each class, the value in input define
# which class we are analyzing at the moment
# The tensor image and the tensor annotations will have all their pixel that are not in the class we are analyzing
# removed, and after they will be returned as output out the function
def get_selected_pixels_tensor(image, annotations, imap, value):
    # Crea una maschera booleana in base al valore desiderato
    mask = imap == value
    assert imap.shape == image.shape
    new_image = image[mask]
    new_annotations = annotations[mask]

    return new_image, new_annotations


# We use this function to return the correct grountruth tensor to calculate the error for each class
def get_grountruth(groundtruth_input, idx):
    groundtruth_return = None
    j = 0
    for other, groundtruth in groundtruth_input:
        groundtruth_return = groundtruth
        if j == idx:
            break
        j += 1
    return groundtruth_return

Defining data trasformation functions

In [4]:
# We define transformations for the images that will be applied during the execution of the program
data_transforms= transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.NEAREST, antialias=False),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                         0.229, 0.224, 0.225])
])

annotation_transforms = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.NEAREST)
])

data_transforms_augmentation = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.NEAREST, antialias=False),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                         0.229, 0.224, 0.225]),
    #Augmentation transform
    transforms.RandomHorizontalFlip(p=0.5)
])

annotation_transforms_augmentation = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.NEAREST),
    # Augmentation transform
    transforms.RandomHorizontalFlip(p=0.5)
])

  transforms.Resize((224, 224), interpolation=Image.NEAREST, antialias=False),
  transforms.Resize((224, 224), interpolation=Image.NEAREST)
  transforms.Resize((224, 224), interpolation=Image.NEAREST, antialias=False),
  transforms.Resize((224, 224), interpolation=Image.NEAREST),


Configuring wandb config and general parameters

In [5]:
# Authenticate API key
wandb.login(key="59882b3dbe33bfba4a2007aee502f5b9803bb409")
# Initialize Wandb
wandb.init()

# Setting batch size
batch_size = 30

# Define a configuration object
config = wandb.config
config.encoder_name = 'efficientnet-b2'
config.encoder_weights = 'imagenet'
config.classes = 1
config.lr = 0.0002
config.batch_size = batch_size
config.num_epochs = 1000

# Log the configuration to Wandb
wandb.config.update(config)

# Defining folders
folder = 'RedEdge'
train_folders = ['000', '001', '002', '004']
test_folders = ['003']

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mi-decosmis[0m ([33moctopus-canneller[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Ivan/.netrc


Creating test set, validation set, training set and groundtruth set

In [6]:
groundtruth = get_groundtruth_images(folder, test_folders)

groundtruth_train = get_groundtruth_images(folder, train_folders)

# Getting normalized images from folders, splitted in 2 lists, one for train(80%) and one for validation(20%)
normalized_train_paths = get_normalized_paths_split(folder, train_folders, 0.8)

# Getting normalized images for test
normalized_test_paths = get_normalized_paths(folder, test_folders)

# Getting classes annotations from folders, splitted in 2 lists, one for train(80%) and one for validation(20%)
train_annotations_paths = get_annotation_paths_split(folder, train_folders, 0.8)

# Getting classes annotations from folders for test
test_annotations_paths = get_annotation_paths(folder, test_folders)

# Creating each different set with images and annotations
train_matrix = np.array([normalized_train_paths[0], train_annotations_paths[0]])
validation_matrix = np.array([normalized_train_paths[1], train_annotations_paths[1]])
test_matrix = np.array([normalized_test_paths, test_annotations_paths])
groundtruth_matrix = np.array([groundtruth, groundtruth])

groundtruth_train_matrix = np.array([groundtruth_train, groundtruth_train])

# Creating data loader for train, validation and test set
train_data = CustomDataset(train_matrix, data_transforms_augmentation, annotation_transforms_augmentation)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

validation_data = CustomDataset(validation_matrix, data_transforms_augmentation, annotation_transforms_augmentation)
validation_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

test_data = CustomDataset(test_matrix, data_transforms,annotation_transforms)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

groundtruth_data = CustomDataset(groundtruth_matrix, annotation_transforms, annotation_transforms)
groundtruth_loader = DataLoader(groundtruth_data, batch_size=batch_size, shuffle=False)

groundtruth_train_data = CustomDataset(groundtruth_train_matrix, annotation_transforms, annotation_transforms)
groundtruth_train_loader = DataLoader(groundtruth_train_data, batch_size=batch_size, shuffle=False)

Creating torch device, model, and metrics

In [7]:
# Setting the device to gpu if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Creating the model
model = NormalizedDecoder(
    encoder_name='efficientnet-b2',
    encoder_weights='imagenet',
    classes=1
).to(device)

# Defining loss function for the train loop to check if model is improving each epoch
mse_train = torchmetrics.MeanSquaredError().to(device)
mae_train = torchmetrics.MeanAbsoluteError().to(device)

# Defining function to calculate mean error after the inference loop
mse_test = torchmetrics.MeanSquaredError().to(device)
mae_test = torchmetrics.MeanAbsoluteError().to(device)

mse_test_one = torchmetrics.MeanSquaredError().to(device)
mse_test_two = torchmetrics.MeanSquaredError().to(device)
mse_test_three = torchmetrics.MeanSquaredError().to(device)

mae_test_one = torchmetrics.MeanAbsoluteError().to(device)
mae_test_two = torchmetrics.MeanAbsoluteError().to(device)
mae_test_three = torchmetrics.MeanAbsoluteError().to(device)

mse_class_average = torchmetrics.MeanSquaredError().to(device)
mae_class_average = torchmetrics.MeanAbsoluteError().to(device)

mse_train_one = torchmetrics.MeanSquaredError().to(device)
mse_train_two = torchmetrics.MeanSquaredError().to(device)
mse_train_three = torchmetrics.MeanSquaredError().to(device)

mae_train_one = torchmetrics.MeanAbsoluteError().to(device)
mae_train_two = torchmetrics.MeanAbsoluteError().to(device)
mae_train_three = torchmetrics.MeanAbsoluteError().to(device)


# Wrappo il modello e l'optimizer con wandb
wandb.watch(model)

[]

Printing set elements and total number of parameters

In [8]:
# Showing data in console
print("Number of elements in the train set:", len(train_data))
print("Number of elements in the validation set:", len(validation_data))
print("Number of elements in the test set:", len(test_data))


total_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters: {total_params}')

Number of elements in the train set: 362
Number of elements in the validation set: 93
Number of elements in the test set: 102
Total number of parameters: 8642739


Defining hyperparameters and training loop for the model

In [9]:
# Variables for training loop
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
num_epochs = 1000
best_loss = np.inf
patience = 40
delta_threshold = 0.0001
counter = 0
saved_epoch = 0
# Training the model
for epoch in range(num_epochs):
    cicle = enumerate(train_loader)
    j = 0
    for i, (images, annotations) in cicle:
        images = images.to(device)
        annotations = annotations.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, annotations)
        loss.backward()
        optimizer.step()
        mse_train.update(outputs, annotations)
        mae_train.update(outputs, annotations)
        groundtruth_test = get_grountruth(groundtruth_loader, j)
        for img, annotation, img_groundtruth in zip(outputs, annotations, groundtruth_test):
            new_img, new_groundtruth = get_selected_pixels_tensor(img, annotation, img_groundtruth, 0)
            mse_train_one.update(new_img, new_groundtruth)
            mae_train_one.update(new_img, new_groundtruth)
        for img, annotation, img_groundtruth in zip(outputs, annotations, groundtruth_test):
            new_img, new_groundtruth = get_selected_pixels_tensor(img, annotation, img_groundtruth, 2)
            mse_train_two.update(new_img, new_groundtruth)
            mae_train_two.update(new_img, new_groundtruth)
        for img, annotation, img_groundtruth in zip(outputs, annotations, groundtruth_test):
            new_img, new_groundtruth = get_selected_pixels_tensor(img, annotation, img_groundtruth, 10000)
            mse_train_three.update(new_img, new_groundtruth)
            mae_train_three.update(new_img, new_groundtruth)
        j += 1
    # Verifica se la loss corrente è migliore della migliore loss finora
    mse_train = mse_train.compute()
    mae_train = mae_train.compute()
    mse_train_one = mse_train_one.compute()
    mse_train_two = mse_train_two.compute()
    mse_train_three = mse_train_three.compute()

    mae_train_one = mae_train_one.compute()
    mae_train_two = mae_train_two.compute()
    mae_train_three = mae_train_three.compute()
    print(f"Epoch: {epoch + 1}/{num_epochs} | Mse loss: {mse_train.item():.4f}")
    wandb.log({'mse_training': mse_train, 'mae_training': mae_train, 'mse_train_class_one': mse_train_one,
               'mse_train_class_two': mse_train_two, 'mse_train_class_thress': mse_train_three,
               'mae_train_class_one': mae_train_one, 'mae_train_class_two': mae_train_two,
               'mae_train_class_three': mae_train_three})
    mse_train = torchmetrics.MeanSquaredError().to(device)
    mae_train = torchmetrics.MeanAbsoluteError().to(device)
    mse_train_one = torchmetrics.MeanSquaredError().to(device)
    mse_train_two = torchmetrics.MeanSquaredError().to(device)
    mse_train_three = torchmetrics.MeanSquaredError().to(device)

    mae_train_one = torchmetrics.MeanAbsoluteError().to(device)
    mae_train_two = torchmetrics.MeanAbsoluteError().to(device)
    mae_train_three = torchmetrics.MeanAbsoluteError().to(device)
    mse_validation = validate_model(model, validation_loader)
    if best_loss - mse_validation > delta_threshold:
        print("entro:", best_loss)
        saved_epoch = epoch + 1
        best_loss = mse_validation
        counter = 0  # Reimposta il contatore
        torch.save(model.state_dict(), 'model_backup')
    else:
        counter += 1
    if counter >= patience:
        print(f"Stopping at epoch: {epoch + 1}")
        print("Using the model saved from epoch:", saved_epoch)
        break

Epoch: 1/1000 | Mse loss: 0.1247
entro: inf
Epoch: 2/1000 | Mse loss: 0.0639
entro: 0.10802991688251495
Epoch: 3/1000 | Mse loss: 0.0163
entro: 0.019754493609070778
Epoch: 4/1000 | Mse loss: 0.0146
entro: 0.017978545278310776
Epoch: 5/1000 | Mse loss: 0.0155
entro: 0.016914425417780876
Epoch: 6/1000 | Mse loss: 0.0155
entro: 0.01580314338207245
Epoch: 7/1000 | Mse loss: 0.0127
entro: 0.013095824979245663
Epoch: 8/1000 | Mse loss: 0.0124
entro: 0.012529371306300163
Epoch: 9/1000 | Mse loss: 0.0120
Epoch: 10/1000 | Mse loss: 0.0121
entro: 0.01208651252090931
Epoch: 11/1000 | Mse loss: 0.0118
Epoch: 12/1000 | Mse loss: 0.0121
entro: 0.011679112911224365
Epoch: 13/1000 | Mse loss: 0.0117
Epoch: 14/1000 | Mse loss: 0.0113
entro: 0.01147305965423584
Epoch: 15/1000 | Mse loss: 0.0110
Epoch: 16/1000 | Mse loss: 0.0112
entro: 0.011099523864686489
Epoch: 17/1000 | Mse loss: 0.0111
entro: 0.01088991854339838
Epoch: 18/1000 | Mse loss: 0.0109
Epoch: 19/1000 | Mse loss: 0.0108
Epoch: 20/1000 | Mse 

Calculating the complexity of the model

In [10]:
with torch.cuda.device(0):
    macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
                                           print_per_layer_stat=False, verbose=False)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       436.09 MMac
Number of parameters:           8.64 M  


Inference loop with testing data

In [11]:
# Loading the last better model and setting it in evaluation mode
model.load_state_dict(torch.load('model_backup'))
model.eval()

# Starting inference with test data
with torch.no_grad():
    i = 0
    counter = 1
    for images, annotations in test_loader:
        images = images.to(device)
        annotations = annotations.to(device)
        # Passaggio Forward
        outputs = model.predict(images)
        # Updating error
        mse_test.update(outputs,annotations)
        mae_test.update(outputs,annotations)
        groundtruth_test = get_grountruth(groundtruth_loader, i)
        for img, annotation, img_groundtruth in zip(outputs, annotations, groundtruth_test):
            new_img, new_groundtruth = get_selected_pixels_tensor(img, annotation, img_groundtruth, 0)
            mse_test_one.update(new_img, new_groundtruth)
            mae_test_one.update(new_img, new_groundtruth)
        for img, annotation, img_groundtruth in zip(outputs, annotations, groundtruth_test):
            new_img, new_groundtruth = get_selected_pixels_tensor(img, annotation, img_groundtruth, 2)
            mse_test_two.update(new_img, new_groundtruth)
            mae_test_two.update(new_img, new_groundtruth)
        for img, annotation, img_groundtruth in zip(outputs, annotations, groundtruth_test):
            new_img, new_groundtruth = get_selected_pixels_tensor(img, annotation, img_groundtruth, 10000)
            mse_test_three.update(new_img, new_groundtruth)
            mae_test_three.update(new_img, new_groundtruth)
        i += 1

        # Creating wandb table for logs images for easy comparison
        table = wandb.Table(columns=['Type', 'Image'])
        input_images = wandb.Image(
            images
        )
        annotations_images = wandb.Image(
            annotations
        )
        output_images = wandb.Image(
            outputs
        )
        table.add_data("Input", input_images)
        table.add_data("Ground truth", annotations_images)
        table.add_data("Output", output_images)
        wandb.log({'Table_{}'.format(counter): table})
        counter += 1
    # Computing total error
    mse_test = mse_test.compute()
    mae_test = mae_test.compute()
    mse_test_one = mse_test_one.compute()
    mse_test_two = mse_test_two.compute()
    mse_test_three = mse_test_three.compute()
    mae_test_one = mae_test_one.compute()
    mae_test_two = mae_test_two.compute()
    mae_test_three = mae_test_three.compute()
    mae_test_average = (mae_test_one.item() + mae_test_two.item() + mae_test_three.item()) / 3
    mse_test_average = (mse_test_one.item() + mse_test_two.item() + mse_test_three.item()) / 3

Logging last information on wandb

In [12]:
# Logging average error for classes after inference
wandb.summary["MSE microaveraged"] = mse_test
wandb.summary["MAE microaveraged"] = mae_test
wandb.summary["mae_class_one"] = mae_test_one
wandb.summary["mae_class_two"] = mae_test_two
wandb.summary["mae_class_three"] = mae_test_three
wandb.summary["mse_class_one"] = mse_test_one
wandb.summary["mse_class_two"] = mse_test_two
wandb.summary["mse_class_three"] = mse_test_three
wandb.summary["MAE macroaveraged"] = mae_test_average
wandb.summary["MSE macroaveraged"] = mse_test_average


# Saving logs locally
wandb.save('model.pth')

[]