In [None]:
pip install -Uq segmentation-models-pytorch

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import cv2
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt

In [None]:
# Global OpenCV configuration for the data pipeline:
#  - OpenCL acceleration is disabled;
#  - setNumThreads(0) makes OpenCV run its functions without internal threading.
# NOTE(review): presumably so cv2 calls inside DataLoader worker processes don't
# oversubscribe CPU cores — confirm.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


class CityscapesDataset(torchvision.datasets.Cityscapes):
    """Cityscapes semantic-segmentation dataset returning ``(image, mask)`` pairs.

    Images are read with OpenCV and converted to RGB. The semantic target image is
    remapped to contiguous train ids ``0 .. len(self.colormap) - 1``. Every pixel whose
    value does not match a mapped label id is filled with ``len(self.colormap)``, which
    acts as the "void" index. If ``transform`` is set, it is called
    albumentations-style as ``transform(image=..., mask=...)``.

    All arguments are forwarded to ``torchvision.datasets.Cityscapes``, except
    ``target_type``, which is always ``"semantic"``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, target_type="semantic")
        # Position of the semantic target inside each ``self.targets[i]`` entry.
        self.semantic_target_type_index = list(self.target_type).index("semantic")
        # train_id -> Cityscapes label id, for every class that has a real train id.
        self.colormap = self._generate_colormap()

    def _generate_colormap(self):
        """Return ``{train_id: label_id}`` for classes whose train id is not -1/255."""
        colormap = {}
        for class_ in self.classes:
            # -1 and 255 mark classes that are excluded from training.
            if class_.train_id in (-1, 255):
                continue
            colormap[class_.train_id] = class_.id
        return colormap

    def _convert_to_segmentation_mask(self, mask):
        """Map a label-id image to a 2-D int64 array of train ids.

        Pixels whose value is not a mapped label id get ``len(self.colormap)``.
        """
        height, width = mask.shape[:2]
        # Explicit int64: NumPy's default integer dtype is platform dependent
        # (int32 on Windows with NumPy < 2), while nn.CrossEntropyLoss needs
        # int64 (Long) class-index targets.
        segmentation_mask = np.full((height, width), len(self.colormap), dtype=np.int64)
        for label_index, label in self.colormap.items():
            segmentation_mask[mask == label] = label_index
        return segmentation_mask

    def to_color_mask(self, segmentation_mask):
        """Render a train-id mask as an ``(H, W, 3)`` uint8 RGB image.

        Accepts a NumPy array or a CPU tensor of shape ``(H, W)``, optionally with
        singleton leading dimensions such as ``(1, H, W)``. Pixels that are not a
        known train id (e.g. the void index) stay black.
        """
        segmentation_mask = np.asarray(segmentation_mask)
        height, width = segmentation_mask.shape[-2:]
        # Drop singleton leading dims; raises if they hold more than one mask.
        segmentation_mask = segmentation_mask.reshape(height, width)
        color_mask = np.zeros((height, width, 3), dtype=np.uint8)
        # Look colors up by label id rather than assuming self.classes[i].id == i.
        colors_by_id = {class_.id: class_.color for class_ in self.classes}
        for label_index, label in self.colormap.items():
            color_mask[segmentation_mask == label_index] = colors_by_id[label]
        return color_mask

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index``, transformed if configured.

        Raises:
            FileNotFoundError: if the image or its semantic target cannot be read.
        """
        image_path = self.images[index]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None, which
            # would otherwise surface later as an obscure cvtColor error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = self.targets[index][self.semantic_target_type_index]
        # IMREAD_UNCHANGED keeps the raw label values (no color conversion).
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        mask = self._convert_to_segmentation_mask(mask)

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        return image, mask

In [None]:
# The encoder name/weights are shared by the model and its input preprocessing;
# defining them once keeps the two from drifting apart.
ENCODER_NAME = "efficientnet-b0"
ENCODER_WEIGHTS = "imagenet"

model = smp.Unet(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    # NOTE(review): presumably the labeled Cityscapes train ids plus the one "void"
    # index CityscapesDataset uses for unlabeled pixels — keep in sync with
    # len(dataset.colormap) + 1.
    classes=20,
)

# Input normalization matching the encoder's pretraining.
preprocess_input = get_preprocessing_fn(ENCODER_NAME, pretrained=ENCODER_WEIGHTS)

In [None]:
def to_tensor(x, **kwargs):
    """Convert an ``(H, W, C)`` image array to a channels-first ``(C, H, W)`` float32 array.

    ``**kwargs`` is accepted and ignored so the function can serve as an
    albumentations ``A.Lambda`` image callback.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the albumentations pipeline that prepares an image for the model.

    The pipeline applies ``preprocessing_fn`` to the image (e.g. the encoder's
    normalization), then converts it to a channels-first float32 array via
    ``to_tensor``. Only ``image=`` callbacks are given, so masks should pass
    through unchanged.
    """
    normalize_step = A.Lambda(name="image_preprocessing", image=preprocessing_fn)
    layout_step = A.Lambda(name="to_tensor", image=to_tensor)
    return A.Compose([normalize_step, layout_step])

In [None]:
DATA_ROOT = "../input/cityscapes/cityscapes"
CROP_SIZE = 384

# Random crops augment the training data. Validation uses a deterministic center
# crop, so validation scores change only because the model changes, not the crop.
train_transform = A.Compose([
    A.RandomCrop(CROP_SIZE, CROP_SIZE),
    get_preprocessing(preprocess_input),
])
valid_transform = A.Compose([
    A.CenterCrop(CROP_SIZE, CROP_SIZE),
    get_preprocessing(preprocess_input),
])

train_dataset = CityscapesDataset(DATA_ROOT, split="train", mode="fine", transform=train_transform)
valid_dataset = CityscapesDataset(DATA_ROOT, split="val", mode="fine", transform=valid_transform)

# Sanity check: one sample yields a channels-first image and a 2-D mask of the crop size.
sample_image, sample_mask = train_dataset[1]
assert sample_image.shape == (3, CROP_SIZE, CROP_SIZE), sample_image.shape
assert sample_mask.shape == (CROP_SIZE, CROP_SIZE), sample_mask.shape

In [None]:
# Preview the (randomly cropped) ground-truth masks of the first few training samples.
for i in range(3):
    _, mask = train_dataset[i]
    fig, ax = plt.subplots()
    ax.imshow(train_dataset.to_color_mask(mask))
    ax.set_title(f"Train sample {i}: ground-truth mask")
    ax.axis("off")
    plt.show()

In [None]:
# Training batches are shuffled; validation goes one image at a time in a fixed order.
loader_kwargs = dict(num_workers=2)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=False, **loader_kwargs)

In [None]:
# Pixel-wise cross-entropy over every output class of the model.
loss = nn.CrossEntropyLoss()
# NOTE(review): the smp.utils epoch runners appear to use `loss.__name__` as the
# log key for this loss, so the attribute must be set.
loss.__name__ = 'ce_loss'

# No additional metrics are tracked during training.
metrics = list()

optimizer = torch.optim.Adam(
    [{"params": model.parameters(), "lr": 0.001}],
)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Arguments shared by the training and the validation runner.
epoch_kwargs = dict(loss=loss, metrics=metrics, device=device, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **epoch_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **epoch_kwargs)

In [None]:
NUM_EPOCHS = 3

# Keep every epoch's logs (not just the last one) so the loss curves can be inspected.
# `train_logs` / `valid_logs` still hold the most recent epoch after the loop.
train_history, valid_history = [], []
for i in range(NUM_EPOCHS):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    train_history.append(train_logs)
    valid_history.append(valid_logs)

In [None]:
# Spot-check: score the model on the first validation batch.
model.eval()  # inference mode for BatchNorm/Dropout, regardless of which cell ran last
img, smnt = next(iter(valid_loader))
with torch.no_grad():  # no autograd graph is needed for inference; saves memory
    pred = model(img.to(device))

# Labeled train ids are 0 .. len(colormap) - 1. The dataset fills unlabeled pixels
# with len(colormap), so that value is excluded from the IoU via ignore_index.
num_labeled_classes = len(valid_dataset.colormap)
tp, fp, fn, tn = smp.metrics.get_stats(
    pred.argmax(axis=1),
    smnt.to(device),
    mode='multiclass',
    num_classes=num_labeled_classes,
    ignore_index=num_labeled_classes,
)
smp.metrics.functional.iou_score(tp, fp, fn, tn, reduction='micro-imagewise')

In [None]:
# Side-by-side comparison of predicted and ground-truth masks for the first image of the batch.
pred_mask = torch.argmax(pred, axis=1).cpu().numpy()[0]
true_mask = smnt.cpu().numpy()[0]

fig, (ax_pred, ax_true) = plt.subplots(1, 2, figsize=(12, 6))
ax_pred.imshow(valid_dataset.to_color_mask(pred_mask))
ax_pred.set_title("Prediction")
ax_true.imshow(valid_dataset.to_color_mask(true_mask))
ax_true.set_title("Ground truth")
for ax in (ax_pred, ax_true):
    ax.axis("off")
plt.show()