In [None]:
import os
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import segmentation_models_pytorch as smp

from pprint import pprint
from torch.utils.data import DataLoader

In [None]:
import pandas as pd
from glob import glob
from sklearn.model_selection import train_test_split

# Fixed seed so that re-running this cell regenerates the same train/val CSVs.
seed = 1337

# Image paths relative to the dataset root (the CSVs store root-relative paths,
# which the dataset class later joins back onto its `root`).
img_fnames = sorted(
    os.path.relpath(f, 'fashion_segmentation')
    for f in glob('fashion_segmentation/png_images/IMAGES/*.png')
)

# The mask name is built from the 4 characters that precede the '.png' extension
# of the image name. NOTE(review): assumes image names end in a 4-character id
# (e.g. '..._0001.png') — confirm against the dataset.
mask_fnames = [f'png_masks/MASKS/seg_{f[-8:-4]}.png' for f in img_fnames]

# Report (but do not fail on) images whose mask file is missing.
for f in mask_fnames:
    if not os.path.exists(os.path.join('fashion_segmentation', f)):
        print(f)

fnames = pd.DataFrame(data={'img': img_fnames, 'mask': mask_fnames})

# 90/10 split; `random_state` makes the split reproducible across kernel restarts.
train, val = train_test_split(fnames, test_size=0.1, random_state=seed)

# Header-less, index-less CSVs: one "image_path,mask_path" pair per line.
train.to_csv('fashion_segmentation/train.csv', header=None, index=None)
val.to_csv('fashion_segmentation/val.csv', header=None, index=None)


In [None]:
import numpy as np
import cv2


class FashionSegmentationDataset(torch.utils.data.Dataset):
    """Image/mask pairs listed in `<root>/train.csv` or `<root>/val.csv`.

    Each CSV line is expected to be ``image_path,mask_path`` with both paths
    relative to `root`.

    Args:
        root: Dataset root directory containing the split CSVs and the files
            they reference.
        mode: Which split to load, ``"train"`` or ``"val"``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, root, mode="train", transform=None):
        assert mode in {"train", "val"}
        self.root = root
        self.mode = mode
        self.transform = transform
        self.filenames = self._read_split()  # list of (image_path, mask_path) tuples

    def __len__(self):
        """Return the number of image/mask pairs in the split."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``"image"`` (RGB image) and ``"mask"`` (mask read
            unchanged from disk), both passed through `transform` when given.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        image_filename, mask_filename = self.filenames[idx]
        image_path = os.path.join(self.root, image_filename)
        mask_path = os.path.join(self.root, mask_filename)

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly to get an actionable error message.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        return dict(image=image, mask=mask)

    @staticmethod
    def _preprocess_mask(mask):
        """Collapse a {0, 1, 2, 3}-valued mask to a float32 binary mask.

        Values 1 and 3 become 1.0, value 2 becomes 0.0; other values are only
        cast to float32. Not called by `__getitem__` (masks are used as read).
        """
        mask = mask.astype(np.float32)
        mask[mask == 2.0] = 0.0
        mask[(mask == 1.0) | (mask == 3.0)] = 1.0
        return mask

    def _read_split(self):
        """Parse the split CSV for `self.mode` into a list of tuples.

        Blank lines are skipped, so an empty file yields an empty list rather
        than a bogus ``('',)`` entry.
        """
        split_filename = "val.csv" if self.mode == "val" else "train.csv"
        split_path = os.path.join(self.root, split_filename)
        with open(split_path) as f:
            lines = (line.rstrip("\n") for line in f)
            return [tuple(line.split(",")) for line in lines if line.strip()]


Dataset checklist:
- [x] Image height and width are divisible by 32
- [x] Images use CHW axis order
- [x] No leakage between the train and validation splits

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Training pipeline: resize the short side to 256, take a random 256x256 crop,
# apply light geometric and colour jitter, then normalize and convert to a tensor.
train_transform = A.Compose([
    A.SmallestMaxSize(max_size=256),
    A.RandomCrop(height=256, width=256),
    A.ShiftScaleRotate(
        shift_limit=0.05,
        scale_limit=0.05,
        rotate_limit=15,
        p=0.5,
    ),
    A.RGBShift(
        r_shift_limit=15,
        g_shift_limit=15,
        b_shift_limit=15,
        p=0.5,
    ),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
    ToTensorV2(),
])

# Validation pipeline: same resize and normalization, but a deterministic
# centre crop and no augmentation.
val_transform = A.Compose([
    A.SmallestMaxSize(max_size=256),
    A.CenterCrop(height=256, width=256),
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
    ToTensorV2(),
])

train_dataset = FashionSegmentationDataset(
    'fashion_segmentation',
    mode='train',
    transform=train_transform,
)
val_dataset = FashionSegmentationDataset(
    'fashion_segmentation',
    mode='val',
    transform=val_transform,
)

In [None]:
import copy

def visualize_augmentations(dataset, idx=0, samples=5, figsize=(10, 24)):
    """Plot several augmented versions of one sample next to their masks.

    The dataset is deep-copied and its transform is rebuilt without
    `A.Normalize` / `ToTensorV2`, so images are displayed in their original
    value range and the caller's dataset is left untouched.

    Args:
        dataset: Dataset exposing a `transform` attribute and returning
            dicts with ``"image"`` and ``"mask"`` entries.
        idx: Index of the sample to draw repeatedly.
        samples: Number of rows (independent draws of the same sample).
        figsize: Matplotlib figure size in inches.
    """
    dataset = copy.deepcopy(dataset)
    if dataset.transform is not None:
        dataset.transform = A.Compose(
            [t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))]
        )
    # squeeze=False keeps `ax` two-dimensional even when samples == 1,
    # so the `ax[i, j]` indexing below always works.
    _, ax = plt.subplots(nrows=samples, ncols=2, figsize=figsize, squeeze=False)
    for i in range(samples):
        sample = dataset[idx]  # re-indexing re-applies the random transforms
        ax[i, 0].imshow(sample["image"])
        ax[i, 1].imshow(sample["mask"], interpolation="nearest")
        ax[i, 0].set_title("Augmented image")
        ax[i, 1].set_title("Augmented mask")
        ax[i, 0].set_axis_off()
        ax[i, 1].set_axis_off()
    plt.tight_layout()
    plt.show()

In [None]:
# visualize_augmentations(val_dataset)

In [None]:
# Leak check: no (image, mask) pair may appear in both splits.
assert set(train_dataset.filenames).isdisjoint(set(val_dataset.filenames)), \
    "train and val splits share samples"

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")

# os.cpu_count() can return None when the count is undetermined; fall back to
# 0 workers (load in the main process) instead of passing None to DataLoader.
n_cpu = os.cpu_count() or 0
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_cpu)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=n_cpu)

In [None]:
class FashionModel(pl.LightningModule):
    """Multiclass segmentation model trained with Dice loss and tracked with IoU.

    Args:
        arch: Architecture name passed to `smp.create_model` (e.g. ``"FPN"``).
        encoder_name: Encoder/backbone name passed to `smp.create_model`.
        in_channels: Number of input image channels.
        out_classes: Number of segmentation classes (output channels).
        **kwargs: Forwarded unchanged to `smp.create_model`.

    NOTE(review): the `training_epoch_end` / `validation_epoch_end` hooks used
    here belong to the PyTorch Lightning 1.x API — confirm the pinned version.
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs)
        # Kept so the metrics use the same class count as the network head.
        self.num_classes = out_classes
        self.loss_fn = smp.losses.DiceLoss(smp.losses.MULTICLASS_MODE, from_logits=True)

    def forward(self, x):
        """Return raw per-class logits for a batch of images."""
        return self.model(x)

    def common_step(self, batch, stage):
        """Compute the loss and per-image confusion statistics for one batch.

        Args:
            batch: Mapping with ``"image"`` (4-D tensor) and ``"mask"``
                (3-D tensor of class indices).
            stage: Stage label; unused here, kept for a uniform signature.

        Returns:
            dict with ``"loss"`` and the ``"tp"``, ``"fp"``, ``"fn"``, ``"tn"``
            tensors produced by `smp.metrics.get_stats`.
        """
        image = batch["image"]
        assert image.ndim == 4  # (batch_size, num_channels, height, width)
        # Spatial dims must be divisible by 32 to comply with the network's downscaling factor.
        assert image.shape[2] % 32 == 0 and image.shape[3] % 32 == 0

        mask = batch["mask"].long()
        assert mask.ndim == 3  # (batch_size, height, width) with class indices

        logits_mask = self.forward(image)
        loss = self.loss_fn(logits_mask, mask)  # loss_fn was built with from_logits=True

        # Multiclass prediction = most likely class per pixel. Thresholding a
        # sigmoid and then taking argmax would instead return the *first* class
        # whose score exceeds the threshold, which is not the predicted class.
        pred_mask = logits_mask.argmax(dim=1)

        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask, mask, mode="multiclass", num_classes=self.num_classes)
        return {"loss": loss, "tp": tp, "fp": fp, "fn": fn, "tn": tn}

    def common_epoch_end(self, outputs, stage):
        """Aggregate the per-batch statistics of an epoch and log IoU metrics.

        Logs ``{stage}_per_image_iou`` and ``{stage}_dataset_iou``.
        """
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # IoU computed per image, then averaged over images.
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        # Intersection and union accumulated over the whole dataset, then one IoU.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {f"{stage}_per_image_iou": per_image_iou,
                   f"{stage}_dataset_iou": dataset_iou}
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, "train")

    def training_epoch_end(self, outputs):
        return self.common_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.common_epoch_end(outputs, "valid")

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-4."""
        return torch.optim.Adam(self.parameters(), lr=0.0001)

In this particular case the difference between the dataset_iou and per_image_iou scores will be small. However, for a dataset with "empty" images (images without the target class), a large gap can be observed: empty images strongly influence per_image_iou and affect dataset_iou much less.

In [None]:
# FPN architecture with a resnet34 encoder: 3-channel input, 59 output classes.
model = FashionModel(
    arch="FPN",
    encoder_name="resnet34",
    in_channels=3,
    out_classes=59,
)

In [None]:
# Use one GPU when CUDA is available; otherwise fall back to CPU (gpus=0)
# instead of failing at Trainer construction on machines without a GPU.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=10,
)

# Train for `max_epochs`, running the validation loop on `val_dataloader`.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# run validation dataset
# Run the validation loop once more and show the logged metrics
# (verbose=False suppresses Lightning's own printout, so pprint is the only output).
valid_metrics = trainer.validate(
    model,
    dataloaders=val_dataloader,
    verbose=False,
)
pprint(valid_metrics)

In [None]:
batch = next(iter(val_dataloader))
model.eval()
with torch.no_grad():
    # Send the input to whichever device the model currently lives on.
    logits = model(batch["image"].to(model.device))
# Predicted class per pixel. argmax over raw logits equals argmax over their
# sigmoid (monotonic), so the extra activation is not needed.
pr_masks = logits.argmax(dim=1).cpu()

# Undo the A.Normalize step so imshow receives RGB values in [0, 1];
# plotting normalized tensors directly clips them and distorts the colours.
mean = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1)
std = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)
images = (batch["image"] * std + mean).clamp(0, 1)

n_rows = len(images)        # the batch may hold fewer than `batch_size` samples
n_classes = logits.shape[1]  # shared colour scale for ground truth and prediction

# squeeze=False keeps `ax` two-dimensional even for a single-sample batch.
figure, ax = plt.subplots(nrows=n_rows, ncols=3, figsize=(10, 1.5 * n_rows), squeeze=False)
for i, (image, gt_mask, pr_mask) in enumerate(zip(images, batch["mask"], pr_masks)):
    ax[i, 0].imshow(image.numpy().transpose(1, 2, 0))  # CHW -> HWC for matplotlib
    ax[i, 1].imshow(gt_mask.numpy(), interpolation="nearest", vmin=0, vmax=n_classes - 1)
    ax[i, 2].imshow(pr_mask.numpy(), interpolation="nearest", vmin=0, vmax=n_classes - 1)
    ax[i, 0].set_title("image")
    ax[i, 1].set_title("gt mask")
    ax[i, 2].set_title("pr mask")
    ax[i, 0].set_axis_off()
    ax[i, 1].set_axis_off()
    ax[i, 2].set_axis_off()
plt.tight_layout()
plt.show()
