In [None]:
import glob
import os
import random

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchmetrics.classification import MulticlassJaccardIndex
from tqdm import tqdm

In [None]:
# Sanity check: the cell output shows whether PyTorch can see a CUDA device.
torch.cuda.is_available()

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention followed by an MLP, each wrapped in a residual.

    Args:
        dim: Embedding size of every token.
        num_heads: Number of attention heads passed to ``nn.MultiheadAttention``.
        mlp_ratio: Hidden width of the MLP, expressed as a multiple of ``dim``.

    Shape:
        The attention layer is built with its default ``batch_first=False``, so
        ``forward`` expects ``(seq_len, batch, dim)`` and returns the same shape.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        # Attribute names are kept as-is so existing state_dict checkpoints still load.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers with residual connections."""
        # Normalise once and reuse the result as query, key and value instead of
        # running the same LayerNorm three times on the same input.
        normed = self.norm1(x)
        # MultiheadAttention returns (output, attention_weights); only the output is used.
        x = x + self.attn(normed, normed, normed)[0]
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
class TransformerUNet(nn.Module):
    """U-Net whose bottleneck feature map is additionally passed through a Transformer block.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels (one logit map per class).
        features: Channel widths of the encoder stages, shallowest first. The
            decoder mirrors them in reverse. The bottleneck has
            ``2 * features[-1]`` channels, which is also the Transformer
            embedding size (the block is created with 8 heads).

    Returns (from ``forward``):
        Raw logits of shape ``(batch, out_channels, H, W)`` for an input of
        shape ``(batch, in_channels, H, W)``.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        # An immutable default avoids sharing one list object between instances.
        features = list(features)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one double convolution per stage.
        for feature in features:
            self.downs.append(self._double_conv(in_channels, feature))
            in_channels = feature

        # Decoder: `ups` alternates [upsample, double conv, upsample, double conv, ...].
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(self._double_conv(feature * 2, feature))

        self.bottleneck = self._double_conv(features[-1], features[-1] * 2)
        self.transformer = TransformerBlock(features[-1] * 2, num_heads=8)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def _double_conv(self, in_channels, out_channels):
        """Return two 3x3 conv -> BatchNorm -> ReLU stages that preserve spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run encoder, bottleneck + Transformer, and decoder with skip connections."""
        skip_connections = []

        # Encoder: remember each stage's output before it is pooled.
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Treat every spatial position as one token: (B, C, H, W) -> (H*W, B, C).
        b, c, h, w = x.shape
        x = x.flatten(2).permute(2, 0, 1)
        x = self.transformer(x)
        # reshape (rather than view) also copes with a non-contiguous layout.
        x = x.permute(1, 2, 0).reshape(b, c, h, w)

        # Decoder: deepest skip connection first.
        for idx, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * idx](x)
            # MaxPool floors odd sizes, so the upsampled map can be smaller than
            # its skip connection; resize it so the concatenation cannot fail.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(
                    x,
                    size=skip_connection.shape[2:],
                    mode="bilinear",
                    align_corners=False,
                )
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[2 * idx + 1](concat_skip)

        return self.final_conv(x)

In [None]:
# Disabled reference implementation (kept commented out), taken from https://stackoverflow.com/questions/43884463/how-to-convert-rgb-image-to-one-hot-encoded-3d-array-based-on-color-using-numpy

# color_dict = {
#         0: (0, 0, 0),
#         1: (1, 1, 1),
#         2: (2, 2, 2),
#         3: (3, 3, 3),
#         4: (4, 4, 4),
# }


# def rgb_to_onehot(rgb_arr, color_dict):
#     num_classes = len(color_dict)
#     shape = rgb_arr.shape[:2]+(num_classes,)
#     arr = np.zeros( shape, dtype=np.int8 )
#     for i, cls in enumerate(color_dict):
#         arr[:,:,i] = np.all(rgb_arr.reshape( (-1,3) ) == color_dict[i], axis=1).reshape(shape[:2])
#     return arr


# def onehot_to_rgb(onehot, color_dict):
#     single_layer = np.argmax(onehot, axis=-1)
#     output = np.zeros( onehot.shape[:2]+(3,) )
#     for k in color_dict.keys():
#         output[single_layer==k] = color_dict[k]
#     return np.uint8(output)

In [None]:
# Taken from https://github.com/hubutui/DiceLoss-PyTorch/blob/master/loss.py
def make_one_hot(input, num_classes):
    """Convert a class-index tensor to a one-hot encoded tensor.

    Args:
        input: An integer tensor of shape [N, 1, *] holding class indices in
            ``[0, num_classes)``. Any integer dtype is accepted.
        num_classes: Number of classes, i.e. size of dim 1 of the result.

    Returns:
        A float tensor of shape [N, num_classes, *] on the same device as
        ``input``, with 1 at each pixel's class index and 0 elsewhere.
    """
    shape = list(input.shape)
    shape[1] = num_classes
    # Allocate on the input's device so GPU targets are not silently moved to the CPU.
    result = torch.zeros(shape, device=input.device)
    # scatter_ only accepts int64 indices, so cast other integer dtypes first.
    return result.scatter_(1, input.long(), 1)


class BinaryDiceLoss(nn.Module):
    """Dice loss for a single (binary) class.

    The loss is ``1 - (2 * sum(x * y) + smooth) / (sum(x^p) + sum(y^p) + smooth)``
    computed per sample, so a perfect prediction yields a loss of 0.

    Args:
        smooth: A float added to numerator and denominator to avoid division
            by zero, default: 1e-5
        p: Exponent used in the denominator, sum(x^p) + sum(y^p), default: 2
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'

    Forward args:
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict

    Returns:
        Loss tensor according to arg reduction

    Raises:
        ValueError: if ``reduction`` is not one of 'mean', 'sum', 'none'.
    """
    def __init__(self, smooth=0.00001, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Flatten everything but the batch dimension.
        predict = predict.reshape(predict.shape[0], -1)
        target = target.reshape(target.shape[0], -1)

        # The factor 2 is part of the Dice coefficient; without it a perfect
        # prediction would score 0.5 and the loss could never reach 0.
        num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise ValueError('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Multi-class Dice loss; the target must be one-hot encoded.

    Softmax is applied to ``predict`` over dim 1, a ``BinaryDiceLoss`` is
    evaluated per class channel, and the (optionally weighted) per-class
    losses are summed and divided by the number of classes.

    Args:
        weight: An array of shape [num_classes,] with one factor per class, or None
        ignore_index: class index to ignore (its loss is skipped, but the sum
            is still divided by the full number of classes)
        other args pass to BinaryDiceLoss

    Forward args:
        predict: A tensor of raw scores of shape [N, C, *]
        target: A tensor of same shape with predict

    Return:
        same as BinaryDiceLoss
    """
    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        assert predict.shape == target.shape, 'predict & target shape do not match'
        num_classes = target.shape[1]
        # Validate the weights once instead of on every loop iteration.
        if self.weight is not None:
            assert self.weight.shape[0] == num_classes, \
                'Expect weight shape [{}], get[{}]'.format(num_classes, self.weight.shape[0])

        dice = BinaryDiceLoss(**self.kwargs)
        predict = F.softmax(predict, dim=1)

        total_loss = 0
        for i in range(num_classes):
            if i == self.ignore_index:
                continue
            dice_loss = dice(predict[:, i], target[:, i])
            if self.weight is not None:
                # The attribute is `weight`; the former `self.weights` lookup
                # raised AttributeError as soon as weights were supplied.
                dice_loss = dice_loss * self.weight[i]
            total_loss = total_loss + dice_loss

        return total_loss / num_classes

In [None]:
# Disabled Keras reference implementation (kept commented out), taken from https://stackoverflow.com/questions/72195156/correct-implementation-of-dice-loss-in-tensorflow-keras

# def dice_coef(y_true, y_pred, smooth):
#     y_true_f = K.flatten(y_true)
#     y_pred_f = K.flatten(y_pred)
#     intersection = K.sum(y_true_f * y_pred_f)
#     dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
#     return dice

# def dice_coef_loss(y_true, y_pred, smooth):
#     return 1 - dice_coef(y_true, y_pred, smooth)

# def dice_coef_multilabel(y_true, y_pred, M, smooth=0.00001):
#     dice = 0
#     for index in range(M):
#         dice += dice_coef(y_true[:,:,:,index], y_pred[:,:,:,index], smooth)
#     return dice

In [None]:
def read_rgb_mask(img, num_classes=5):
    """One-hot encode a label mask whose pixel value is the class index.

    A pixel belongs to class ``k`` when every channel of that pixel equals
    ``k`` (e.g. ``(2, 2, 2)`` is class 2). Pixels that match no class in
    ``range(num_classes)`` stay all-zero in the output.

    Args:
        img: Array-like mask of shape (H, W, C), or (H, W) for a
            single-channel mask.
        num_classes: Number of classes to encode; classes are the integers
            ``0 .. num_classes - 1``.

    Returns:
        A ``uint8`` array of shape (num_classes, H, W) with 1 where the pixel
        belongs to the class and 0 elsewhere.
    """
    pixels = np.asarray(img)
    if pixels.ndim == 2:
        # Give single-channel masks a channel axis so they go through the same path.
        pixels = pixels[..., np.newaxis]

    height, width = pixels.shape[:2]
    output = np.zeros((num_classes, height, width), dtype=np.uint8)
    # Derive the classes from num_classes rather than a fixed 5-entry table,
    # so the parameter is actually honoured.
    for class_index in range(num_classes):
        output[class_index] = np.all(pixels == class_index, axis=-1)
    return output

In [None]:
def preprocess_data(image, mask, train):
    """Turn one raw image/mask pair into model-ready tensors.

    Args:
        image: Image array handed to the albumentations pipeline
            (presumably HWC uint8 as produced by ``np.array(PIL image)`` - confirm with caller).
        mask: Label mask forwarded to ``read_rgb_mask``.
        train: When true, a colour augmentation (one of grayscale, hue/saturation
            shift or brightness/contrast jitter) is applied with probability 0.2.

    Returns:
        ``(image, mask)`` where the image tensor has been divided by 255 and
        the mask is the float tensor built from ``read_rgb_mask(mask)``.
    """
    steps = []
    if train:
        # Only colour-level augmentations are used, so the mask needs no matching transform.
        colour_jitter = A.OneOf(
            [
                A.ToGray(),
                A.HueSaturationValue(hue_shift_limit=3, sat_shift_limit=3, val_shift_limit=3),
                A.RandomBrightnessContrast(brightness_limit=0.01, contrast_limit=0.01, brightness_by_max=False),
            ],
            p=0.2,
        )
        steps.append(colour_jitter)
    steps.append(ToTensorV2())

    pipeline = A.Compose(steps)
    image_tensor = pipeline(image=image)['image'] / 255

    mask_tensor = torch.Tensor(read_rgb_mask(mask))

    return image_tensor, mask_tensor

In [None]:
def transform_mask(mask):
    """Collapse dim 1 of ``mask`` to the index of its largest entry (one-hot -> class index)."""
    return mask.argmax(dim=1)

In [None]:
def output_to_rgb(output, color_map=None):
    """Render the first prediction of a batch as an RGB image.

    Args:
        output: Tensor of per-class scores; the class axis is dim 1 and only
            batch element 0 is rendered.
        color_map: Optional ``{class_index: [r, g, b]}`` mapping. When omitted
            a built-in five-class palette is used. Classes missing from the
            mapping are drawn black.

    Returns:
        A ``uint8`` numpy array of shape (H, W, 3).
    """
    palette = color_map if color_map is not None else {
        0: [0, 0, 0],
        1: [201,0,118],
        2: [34,97,38],
        3: [41,134,204],
        4: [116,71,0]
    }

    scores = output.detach().cpu().numpy()
    # Winning class per pixel, for the first sample in the batch only.
    class_ids = scores.argmax(axis=1)[0]

    rows, cols = class_ids.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
    for label in palette:
        canvas[class_ids == label] = palette[label]

    return canvas

In [None]:
def visualize_model_output(model, input_image):
    """Run ``model`` on ``input_image`` and display the predicted class map.

    The model is switched to eval mode (and left there), inference runs
    without gradient tracking, and the prediction is coloured with
    ``output_to_rgb`` before being shown in a 5x5 inch figure without axes.
    """
    model.eval()
    with torch.no_grad():
        prediction = model(input_image)

    plt.figure(figsize=(5, 5))
    plt.imshow(output_to_rgb(prediction))
    plt.axis('off')
    plt.show()

In [None]:
class CustomDataset(Dataset):
    """Dataset of (image, mask) file pairs loaded lazily with PIL.

    Args:
        image_paths: Sequence of image file paths.
        target_paths: Sequence of mask file paths; entry ``i`` must belong to
            ``image_paths[i]``.
        preprocess_fn: Callable invoked as ``preprocess_fn(image, mask, train=...)``
            with numpy arrays; its return value is the dataset item.
        train: Forwarded to ``preprocess_fn`` as the ``train`` keyword.

    Raises:
        ValueError: if the two path sequences differ in length.
    """

    def __init__(self, image_paths, target_paths, preprocess_fn, train=False):
        # Pairs are matched purely by position, so a length mismatch means the
        # images and masks are misaligned; fail early instead of training on garbage.
        if len(image_paths) != len(target_paths):
            raise ValueError(
                "image_paths and target_paths differ in length: {} vs {}".format(
                    len(image_paths), len(target_paths)
                )
            )
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.preprocess_fn = preprocess_fn
        self.train = train

    def __getitem__(self, index):
        # Context managers close the underlying files right after the pixel
        # data has been copied into numpy arrays, so handles do not pile up.
        with Image.open(self.image_paths[index]) as image_file:
            image = np.array(image_file)
        with Image.open(self.target_paths[index]) as mask_file:
            mask = np.array(mask_file, dtype=np.int8)

        return self.preprocess_fn(image, mask, train=self.train)

    def __len__(self):
        return len(self.image_paths)

In [None]:
# Training hyper-parameters.
batch_size = 8
num_epochs = 100
learning_rate = 0.001
num_classes = 5  # Including background
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs before the model and data loaders are created, so weight
# initialisation and shuffling can be repeated from run to run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def _sorted_tifs(directory):
    """Return the ``*.tif`` paths directly inside ``directory``, sorted."""
    return sorted(glob.glob(os.path.join(directory, "*.tif")))


# Images and masks are paired by position, which relies on both listings
# sorting into the same order.
train_img_paths = _sorted_tifs("landcover.ai.v1/train/image")
train_mask_paths = _sorted_tifs("landcover.ai.v1/train/label")

val_img_paths = _sorted_tifs("landcover.ai.v1/val/image")
val_mask_paths = _sorted_tifs("landcover.ai.v1/val/label")

In [None]:
# Only the training split is built with train=True, which enables augmentation
# inside preprocess_data.
train_dataset = CustomDataset(train_img_paths, train_mask_paths, preprocess_data, train=True)
val_dataset = CustomDataset(val_img_paths, val_mask_paths, preprocess_data, train=False)

# Shuffle training batches every epoch; keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Three input channels, one output logit map per class.
model = TransformerUNet(in_channels=3, out_channels=num_classes)
model = model.to(device)

# Dice loss with its defaults (no class weights, no ignored class).
loss_fn = DiceLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Mean training loss of every finished epoch, in order.
epoch_train_loss = []
# NOTE(review): stays empty for as long as the validation pass below is commented out.
epoch_val_loss = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}:")
    
    # Put the model in training mode for this epoch.
    model.train()
    # Running sum of per-batch losses; turned into a mean after the loop.
    train_loss = 0
    
    for images, masks in tqdm(train_loader, desc="Training"):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        
        # Clear gradients from the previous step, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        train_loss += loss.item()

    # Average over the number of batches (not the number of samples).
    train_loss = train_loss / len(train_loader)

    epoch_train_loss.append(train_loss)
    
    print(f"Train Loss: {train_loss:.4f}")

    # The same file is overwritten after every epoch, so it always holds the
    # most recent weights rather than the best ones.
    torch.save(model.state_dict(), "landcover_seg_model.pth")
    
    # NOTE(review): validation is disabled; re-enable the block below to fill epoch_val_loss.
    # model.eval()
    # val_loss = 0
    
    # with torch.inference_mode():
    #     for images, masks in tqdm(val_loader, desc="Validation"):
    #         images, masks = images.to(device), masks.to(device)
    
    #         outputs = model(images)
    #         loss = loss_fn(outputs, masks)
    
    #         val_loss += loss.item()
    
    # val_loss = val_loss / len(val_loader)

    # epoch_val_loss.append(val_loss)
    
    # print(f"Validation Loss: {val_loss:.4f}")
    print("-----------------------------")

In [None]:
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU machine can still be restored when only a CPU is available.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load("landcover_seg_model.pth", map_location=device))

In [None]:
# Jaccard index (IoU) metric over `num_classes` classes, placed on the same device as the model.
metric = MulticlassJaccardIndex(num_classes=num_classes).to(device)

In [None]:
model.eval()
# Drop statistics left over from an earlier run of this cell, so that
# re-running it does not mix old and new batches.
metric.reset()
with torch.inference_mode():
    for images, masks in tqdm(val_loader, desc="Validation"):
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)

        # Accumulate statistics over the whole validation set. Averaging
        # per-batch IoU values instead would weight a short final batch like a
        # full one and distort classes that are absent from individual batches.
        metric.update(outputs, transform_mask(masks))

# IoU computed once from the accumulated statistics of every batch.
val_iou = metric.compute()
print(f"Validation IoU: {val_iou.item():.4f}")