In [16]:
import os
import numpy as np
from glob import glob
from PIL import Image
from torch.utils.data import Dataset
from typing import Optional, Callable
from torchvision import transforms
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import cv2
import torch.nn as nn
import numpy as np
import albumentations as A
import os
import time
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


class CropSegmentationDataset(Dataset):
    ROOT_PATH: str = "/net/ens/am4ip/datasets/project-dataset/"
    id2cls: dict = {0: "background",
                    1: "crop",
                    2: "weed",
                    3: "partial-crop",
                    4: "partial-weed"}

    cls2id: dict = {"background": 0,
                    "crop": 1,
                    "weed": 2,
                    "partial-crop": 3,
                    "partial-weed": 4}

    def __init__(self, set_type: str = "train", transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 merge_small_items: bool = True,
                 remove_small_items: bool = False):
        """Class to load datasets for the Project.

        Remark: `target_transform` is applied before merging items (this eases data augmentation).

        :param set_type: Define if you load training, validation or testing sets. Should be either "train", "val" or "test".
        :param transform: Callable to be applied on inputs.
        :param target_transform: Callable to be applied on labels.
        :param merge_small_items: Boolean to either merge classes of small or occluded objects.
        :param remove_small_items: Boolean to consider as background class small or occluded objects. If `merge_small_items` is set to `True`, then this parameter is ignored.
        """
        super(CropSegmentationDataset, self).__init__()
        self.transform = transform
        self.target_transform = target_transform
        self.merge_small_items = merge_small_items
        self.remove_small_items = remove_small_items

        if set_type not in ["train", "val", "test"]:
            raise ValueError("'set_type has an unknown value. "
                             f"Got '{set_type}' but expected something in ['train', 'val', 'test'].")

        self.set_type = set_type
        images = glob(os.path.join(self.ROOT_PATH, set_type, "images/*"))
        images.sort()
        self.images = np.array(images)

        labels = glob(os.path.join(self.ROOT_PATH, set_type, "labels/*"))
        labels.sort()
        self.labels = np.array(labels)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index: int):
        input_img = Image.open(self.images[index], "r")
        target = Image.open(self.labels[index], "r")

        if self.transform is not None:
            input_img = self.transform(input_img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        target_np = np.array(target)

        if self.merge_small_items:
            target_np[target_np == self.cls2id["partial-crop"]] = self.cls2id["crop"]
            target_np[target_np == self.cls2id["partial-weed"]] = self.cls2id["weed"]
        elif self.remove_small_items:
            target_np[target_np == self.cls2id["partial-crop"]] = self.cls2id["background"]
            target_np[target_np == self.cls2id["partial-weed"]] = self.cls2id["background"]

        # Convert back to PIL image
        target = Image.fromarray(target_np)

        return input_img, target

    def get_class_number(self):
        if self.merge_small_items or self.remove_small_items:
            return 3
        else:
            return 5



In [17]:
def preprocess_image(path):

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    img = img.astype('float32')
    mx = np.max(img)
    if mx:
        img/=mx

    img = np.transpose(img, (2, 0, 1))
    img_ten = torch.tensor(img)
    return img_ten

def preprocess_mask(path):

    msk = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    msk = msk.astype('float32')
    msk/=255.0
    msk = np.repeat(msk[:, :, np.newaxis], 3, axis=-1)
    msk_ten = torch.tensor(msk)
    msk_ten= np.transpose(msk_ten, (2, 0, 1))

    return msk_ten



In [18]:
class CustomDataset(Dataset):
    def __init__(self, image_files, mask_files, input_size=(256, 256), augmentation_transforms=None):
        self.image_files = image_files
        self.mask_files = mask_files
        self.input_size = input_size
        self.augmentation_transforms = augmentation_transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):

        image_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        image = preprocess_image(image_path)
        mask = preprocess_mask(mask_path)

        if self.augmentation_transforms:
            image, mask = self.augmentation_transforms(image, mask)

        return image, mask

In [19]:
def augment_image(image, mask):

    image_np = image.permute(1, 2, 0).numpy()
    mask_np = mask.permute(1, 2, 0).numpy()

    transform = A.Compose([
        A.Resize(256,256, interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        A.RandomCrop(height=256, width=256, always_apply=True),
        A.RandomBrightness(p=1),
        A.OneOf(
            [
                A.Blur(blur_limit=3, p=1),
                A.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

    ])

    augmented = transform(image=image_np, mask=mask_np)
    augmented_image, augmented_mask = augmented['image'], augmented['mask']

    augmented_image = torch.tensor(augmented_image, dtype=torch.float32).permute(2, 0, 1)
    augmented_mask = torch.tensor(augmented_mask, dtype=torch.float32).permute(2, 0, 1)

    return augmented_image, augmented_mask

Training the model

In [20]:
base_path = '/net/ens/am4ip/datasets/project-dataset/'

dataset_train = '/net/ens/am4ip/datasets/project-dataset/train'
dataset_val = '/net/ens/am4ip/datasets/project-dataset/val'


images_path_train = os.path.join(base_path, dataset_train, 'images')
labels_path_train = os.path.join(base_path, dataset_train, 'labels')

images_path_val = os.path.join(base_path, dataset_val, 'images')
labels_path_val = os.path.join(base_path, dataset_val, 'labels')

images_train = sorted([os.path.join(images_path_train, f) for f in os.listdir(images_path_train) if f.endswith('.png')])
label_train = sorted([os.path.join(labels_path_train, f) for f in os.listdir(labels_path_train) if f.endswith('.png')])

images_val = sorted([os.path.join(images_path_val, f) for f in os.listdir(images_path_val) if f.endswith('.png')])
label_val = sorted([os.path.join(labels_path_val, f) for f in os.listdir(labels_path_val) if f.endswith('.png')])


In [21]:
train_dataset = CustomDataset(images_train, label_train, augmentation_transforms=augment_image)
val_dataset = CustomDataset(images_val, label_val, augmentation_transforms=augment_image)

In [22]:
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [23]:
# Assuming you have already created train_dataloader and val_dataloader
train_dataset_size = len(train_dataloader.dataset)
val_dataset_size = len(val_dataloader.dataset)

print("Train dataset size:", train_dataset_size)
print("Validation dataset size:", val_dataset_size)


Train dataset size: 1407
Validation dataset size: 422


In [24]:
# for batch_idx, (batch_images, batch_masks) in enumerate(train_dataloader):
#     print("Batch", batch_idx + 1)
#     print("Image batch shape:", batch_images.shape)
#     print("Mask batch shape:", batch_masks.shape)
    
#     for image, mask, image_path, mask_path in zip(batch_images, batch_masks, images_train, label_train):
       
#         image = image.permute((1, 2, 0)).numpy()*255.0
#         image = image.astype('uint8')
#         mask = (mask*255).numpy().astype('uint8')
        
#         image_filename = os.path.basename(image_path)
#         mask_filename = os.path.basename(mask_path)
        
#         plt.figure(figsize=(15, 10))
        
#         plt.subplot(2, 4, 1)
#         plt.imshow(image, cmap='gray')
#         plt.title(f"Original Image - {image_filename}")
        
#         plt.subplot(2, 4, 2)
#         plt.imshow(mask, cmap='gray')
#         plt.title(f"Mask Image - {mask_filename}")
        
#         plt.tight_layout()
#         plt.show()
#     break

In [25]:
# for batch_idx, (batch_images, batch_masks) in enumerate(train_dataloader):
#     print("Batch", batch_idx + 1)
#     print("Image batch shape:", batch_images.shape)
#     print("Mask batch shape:", batch_masks.shape)

In [26]:
# for batch_idx, (batch_images, batch_masks) in enumerate(val_dataloader):
#     print("Batch", batch_idx + 1)
#     print("Image batch shape:", batch_images.shape)
#     print("Mask batch shape:", batch_masks.shape)

In [27]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
device = get_default_device()
device

device(type='cuda')

In [28]:
class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class UpConv(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(UpConv, self).__init__()

        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x


class AttentionBlock(nn.Module):
    """Attention block with learnable parameters"""

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: number of feature maps (channels) in previous layer
        :param F_l: number of feature maps in corresponding encoder layer, transferred via skip connection
        :param n_coefficients: number of learnable multi-dimensional attention coefficients
        """
        super(AttentionBlock, self).__init__()

        self.W_gate = nn.Sequential(
            nn.Conv2d(F_g, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_coefficients)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_coefficients)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal from previous layer
        :param skip_connection: activation from corresponding encoder layer
        :return: output activations
        """
        g1 = self.W_gate(gate)
        x1 = self.W_x(skip_connection)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        psi = F.softmax(psi, dim=1)  # Apply softmax here

        out = skip_connection * psi
        return out

class AttentionUNet(nn.Module):

    def __init__(self, img_ch=3, output_ch=3):
        super(AttentionUNet, self).__init__()

        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = ConvBlock(img_ch, 64)
        self.Conv2 = ConvBlock(64, 128)
        self.Conv3 = ConvBlock(128, 256)
        self.Conv4 = ConvBlock(256, 512)
        self.Conv5 = ConvBlock(512, 1024)

        self.Up5 = UpConv(1024, 512)
        self.Att5 = AttentionBlock(F_g=512, F_l=512, n_coefficients=256)
        self.UpConv5 = ConvBlock(1024, 512)

        self.Up4 = UpConv(512, 256)
        self.Att4 = AttentionBlock(F_g=256, F_l=256, n_coefficients=128)
        self.UpConv4 = ConvBlock(512, 256)

        self.Up3 = UpConv(256, 128)
        self.Att3 = AttentionBlock(F_g=128, F_l=128, n_coefficients=64)
        self.UpConv3 = ConvBlock(256, 128)

        self.Up2 = UpConv(128, 64)
        self.Att2 = AttentionBlock(F_g=64, F_l=64, n_coefficients=32)
        self.UpConv2 = ConvBlock(128, 64)

        self.Conv = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        e : encoder layers
        d : decoder layers
        s : skip-connections from encoder layers to decoder layers
        """
        e1 = self.Conv1(x)

        e2 = self.MaxPool(e1)
        e2 = self.Conv2(e2)

        e3 = self.MaxPool(e2)
        e3 = self.Conv3(e3)

        e4 = self.MaxPool(e3)
        e4 = self.Conv4(e4)

        e5 = self.MaxPool(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)

        s4 = self.Att5(gate=d5, skip_connection=e4)
        d5 = torch.cat((s4, d5), dim=1)
        d5 = self.UpConv5(d5)

        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e3)
        d4 = torch.cat((s3, d4), dim=1)
        d4 = self.UpConv4(d4)

        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e2)
        d3 = torch.cat((s2, d3), dim=1)
        d3 = self.UpConv3(d3)

        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e1)
        d2 = torch.cat((s1, d2), dim=1)
        d2 = self.UpConv2(d2)

        out = self.Conv(d2)

        return out

In [29]:
import numpy as np

def dice_coeff(prediction, target):
    # Move tensors to CPU
    prediction_cpu = prediction.cpu().numpy()
    target_cpu = target.cpu().numpy()

    # Assuming that prediction and target have the same number of channels
    num_channels = prediction_cpu.shape[1]

    dice_scores = []

    for channel in range(num_channels):
        mask = np.zeros_like(prediction_cpu[:, channel, :, :])
        mask[prediction_cpu[:, channel, :, :] >= 0.5] = 1

        inter = np.sum(mask * target_cpu[:, channel, :, :])
        union = np.sum(mask) + np.sum(target_cpu[:, channel, :, :])

        dice = (2 * inter) / (union + 1e-8)  # Add a small epsilon to avoid division by zero
        dice_scores.append(dice)

    # Return the average dice score across channels
    return np.mean(dice_scores)


In [30]:
import numpy as np

def dice_coeff(prediction, target):
    # Move tensors to CPU
    prediction_cpu = prediction.cpu().numpy()
    target_cpu = target.cpu().numpy()

    # Assuming that prediction and target have the same number of channels
    num_channels = prediction_cpu.shape[1]

    dice_scores = []

    if len(prediction_cpu.shape) == 4:
        # Iterate over channels
        for channel in range(num_channels):
            mask = np.zeros_like(prediction_cpu[:, channel, :, :])
            mask[prediction_cpu[:, channel, :, :] >= 0.5] = 1
            inter = np.sum(mask * target_cpu[:, channel, :, :])
            union = np.sum(mask) + np.sum(target_cpu[:, channel, :, :])
            dice = (2 * inter) / (union + 1e-8)  # Add a small epsilon to avoid division by zero
            dice_scores.append(dice)
    elif len(prediction_cpu.shape) == 3:
        # For binary segmentation
        mask = np.zeros_like(prediction_cpu[:, 0, :, :])
        mask[prediction_cpu[:, 0, :, :] >= 0.5] = 1
        inter = np.sum(mask * target_cpu)


In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, size_average=None, ignore_index=-100, reduce=None, balance_param=1.0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.reduce = reduce
        self.balance_param = balance_param

    def forward(self, input, target):
        print(input.shape, target.shape)

        assert len(input.shape) == len(target.shape)

        # Calculate probabilities and log probabilities
        probs = F.softmax(input, dim=1)
        log_probs = F.log_softmax(input, dim=1)

        # Calculate focal loss
        pt = torch.exp(log_probs) * (1 - probs)  # Fixing the one-hot encoding here
        focal_loss = -((1 - pt) ** self.gamma) * log_probs

        # Apply the ignore index if specified
        if self.ignore_index >= 0:
            ignore_mask = target != self.ignore_index
            focal_loss = focal_loss[ignore_mask]
            target = target[ignore_mask]

        # Calculate the balanced focal loss
        balanced_focal_loss = self.balance_param * focal_loss

        # Take the mean over non-ignored samples
        if self.size_average:
            return balanced_focal_loss.mean()
        elif self.reduce == 'sum':
            return balanced_focal_loss.sum()
        else:
            return balanced_focal_loss


In [32]:
dataloaders = {
    'training': train_dataloader,
    'test': val_dataloader
}

In [33]:
import time
import torch
import numpy as np
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def train_and_test(model, dataloaders, optimizer, criterion, num_epochs=100, show_images=False):
    since = time.time()
    best_loss = float('inf')  # Initialize with a large value
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    fieldnames = ['epoch', 'training_loss', 'test_loss', 'training_dice_coeff', 'test_dice_coeff']
    train_epoch_losses = []
    test_epoch_losses = []

    for epoch in range(1, num_epochs + 1):

        print(f'Epoch {epoch}/{num_epochs}')
        print('-' * 10)
        
        batchsummary = {a: [0] for a in fieldnames}
        batch_train_loss = 0.0
        batch_test_loss = 0.0

        for phase in ['training', 'test']:
            if phase == 'training':
                model.train()  
            else:
                model.eval() 

            for sample in iter(dataloaders[phase]):
                inputs = sample[0].to(device)
                masks = sample[1].to(device)
                
                #masks = masks.unsqueeze(1)
                print('pipline mask',masks.shape)
                
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'training'):
                    outputs = model(inputs)

                    loss = criterion(outputs, masks)

                    # Calculate multi-class Dice coefficient directly on the GPU
                    y_pred = torch.argmax(outputs, dim=1)
                    dice_coefficient = dice_coeff(y_pred, masks)

                    batchsummary[f'{phase}_dice_coeff'].append(dice_coefficient.item())  # Convert to Python scalar

                    if phase == 'training':
                        loss.backward()
                        optimizer.step()

                        batch_train_loss += loss.item() * inputs.size(0)

                    else:
                        batch_test_loss += loss.item() * inputs.size(0)

            if phase == 'training':
                epoch_train_loss = batch_train_loss / len(dataloaders['training'])
                train_epoch_losses.append(epoch_train_loss)
            else:
                epoch_test_loss = batch_test_loss / len(dataloaders['test'])
                test_epoch_losses.append(epoch_test_loss)

            batchsummary['epoch'] = epoch
            
            print('{} Loss: {:.4f}'.format(phase, loss))

        best_loss = np.min(batchsummary['test_dice_coeff'])
        for field in fieldnames[3:]:
            batchsummary[field] = np.mean(batchsummary[field])
        print(
            f'\t\t\t train_dice_coeff: {batchsummary["training_dice_coeff"]}, test_dice_coeff: {batchsummary["test_dice_coeff"]}')

    print('Best dice coefficient: {:4f}'.format(best_loss))

    return model, train_epoch_losses, test_epoch_losses


In [34]:
epochs = 25

def train():
    model = AttentionUNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = FocalLoss(gamma=2)

    trained_model, train_epoch_losses, test_epoch_losses = train_and_test(model, dataloaders, optimizer, criterion, num_epochs=epochs)

    return trained_model, train_epoch_losses, test_epoch_losses


trained_model, train_epoch_losses, test_epoch_losses = train()

Epoch 1/25
----------




pipline mask torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])


IndexError: too many indices for array: array is 3-dimensional, but 4 were indexed

In [None]:
pth_weigth_file = torch.save(trained_model.state_dict(), 'trained_model.pth')