In [None]:
import torch 
import torch.nn as nn
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import numpy as np
from typing import Any
import lightning.pytorch as pl
import segmentation_models_pytorch as smp
import os 
import random
from torch.utils.data import Dataset, DataLoader

In [None]:
class cfg:
    """Central place for the settings a reader may want to tune.

    The attributes are plain class-level constants; the class is used as a
    namespace and is never instantiated.
    """

    # Where the samples live (relative to the notebook's working directory).
    dataset_path = "../dataset"

    # Segmentation architecture and the backbone used as its encoder.
    model_name, encoder_name = "Unet", "resnet34"

    # Hardware target for training.
    device = "cpu"

    # Training schedule: batch sizes for each split and number of epochs.
    train_batch_size = valid_batch_size = 8
    num_epochs = 5

In [None]:
class HandDataset(Dataset):
    """Map-style dataset yielding hand images and their binary masks.

    Every entry of ``image_paths`` is a sample directory ``<parent>/<name>``
    that is expected to contain ``<name>.jpg`` (the image) and
    ``<name>_mask.jpg`` (the mask).

    Args:
        image_paths: Sequence of sample directory paths using "/" separators.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries (e.g. an ``A.Compose``).
            When ``None`` the raw image and mask are used unchanged.
        mask_threshold: Mask pixels strictly greater than this value become
            foreground (1); all others become background (0). The default of
            0 keeps the historical behaviour.
            NOTE(review): the masks are stored as JPEG, so compression noise
            may produce small non-zero background values; a higher threshold
            (e.g. 127) may be more appropriate. Confirm against the data.
    """

    def __init__(self, image_paths, transforms = None, mask_threshold = 0):
        self.image_paths = image_paths
        self.transforms = transforms
        self.mask_threshold = mask_threshold
        # Always applied to the image after the optional augmentations:
        # normalisation followed by conversion to a tensor.
        self.post_transforms = A.Compose([
                                A.Normalize(),
                                ToTensorV2(),
                            ])

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    @staticmethod
    def _sample_paths(image_path):
        """Return ``(image_file, mask_file)`` for one sample directory."""
        # Tolerate a trailing separator, which would otherwise yield an
        # empty sample name and therefore a wrong file name.
        sample_dir = image_path.rstrip("/")
        sample_name = sample_dir.rsplit("/", 1)[-1]
        return (
            f"{sample_dir}/{sample_name}.jpg",
            f"{sample_dir}/{sample_name}_mask.jpg",
        )

    def __getitem__(self, idx):
        """Load, augment and tensorise the sample at position ``idx``.

        Returns:
            A dict with ``"image"`` (output of ``post_transforms``) and
            ``"mask"`` (float tensor of zeros and ones with a leading
            channel dimension added).

        Raises:
            FileNotFoundError: If the image or the mask file cannot be read.
        """
        image_file, mask_file = self._sample_paths(self.image_paths[idx])

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly to report which file is the problem.
        image = cv2.imread(image_file)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_file}")
        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_file}")

        # OpenCV loads colour images as BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Augmentations are optional (the constructor default is None).
        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        image = self.post_transforms(image = image)["image"]

        # Binarise the mask and add the channel dimension.
        mask = (mask > self.mask_threshold).astype(np.uint8)
        mask = torch.from_numpy(mask).unsqueeze(0).float()

        return {"image": image, "mask": mask}

In [None]:
class HandModel(pl.LightningModule):
    """Lightning module wrapping an SMP network for binary segmentation.

    Training minimises a Dice loss computed on logits. At the end of every
    train and validation epoch the per-image and dataset-wide IoU are logged.

    Args:
        model_name: Architecture to build, either ``"FPN"`` or ``"Unet"``.
        encoder_name: Encoder (backbone) name forwarded to the SMP model.
        in_channels: Number of input image channels.
        out_classes: Number of output channels (1 for binary segmentation).
        **kwargs: Extra keyword arguments forwarded to the SMP model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    def __init__(
            self,
            model_name: str = "FPN",
            encoder_name: str = "resnet34", 
            in_channels: int = 3 , 
            out_classes: int = 1,
            **kwargs) -> None:
        super().__init__()
        if model_name == "FPN":
            self.model = smp.FPN(encoder_name, in_channels = in_channels, classes = out_classes, **kwargs)
        elif model_name == "Unet":
            self.model = smp.Unet(encoder_name, in_channels = in_channels, classes = out_classes, **kwargs)
        else:
            # Fail here instead of later with an AttributeError on self.model.
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected 'FPN' or 'Unet'"
            )

        # for image segmentation dice loss could be the best first choice
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

        # Per-step results accumulated during an epoch and cleared at its end.
        # (The attribute name keeps its historical spelling.)
        self.training_step_ouputs = []
        self.validation_step_outputs = []
    
    def forward(self, image):
        """Return the raw logits predicted for ``image``."""
        return self.model(image)
    
    def shared_step(self, batch, stage):
        """Run one train/validation step.

        Args:
            batch: Mapping with ``"image"`` (4-D tensor whose height and width
                are divisible by 32) and ``"mask"`` (4-D tensor with values
                in [0, 1]).
            stage: ``"train"`` or ``"valid"``; selects the list the step
                result is appended to.

        Returns:
            Dict with the ``"loss"`` to optimise and the per-image
            ``"tp"``, ``"fp"``, ``"fn"``, ``"tn"`` statistics.
        """
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32:
        # encoder and decoder are connected by skip connections and the encoder
        # usually has 5 stages of downsampling by a factor of 2 (2 ^ 5 = 32).
        # E.g. for an 84x84 image the encoder features would have sizes
        # 84, 42, 21, 10, 5 while the decoder upsamples 5 -> 10, 20, 40, 80,
        # so concatenating the mismatched features would raise an error.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Convert logits to probabilities, then threshold to get a hard mask.
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # Compute true/false positive/negative pixel counts per image; they
        # are aggregated into IoU scores at the end of the epoch.
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        # Keep only a detached copy of the loss in the epoch buffers so that
        # the autograd graph of every step is not retained for the whole epoch.
        step_record = {
            "loss": loss.detach(),
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }
        if stage == "train":
            self.training_step_ouputs.append(step_record)
        elif stage == "valid":
            self.validation_step_outputs.append(step_record)

        # The returned loss must stay attached to the graph for backprop.
        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        """Aggregate step statistics and log IoU metrics for ``stage``.

        Args:
            outputs: List of per-step dicts holding ``"tp"``, ``"fp"``,
                ``"fn"`` and ``"tn"`` tensors.
            stage: Prefix used for the logged metric names.
        """
        # Nothing was collected (e.g. no batches ran): torch.cat would fail.
        if not outputs:
            return

        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image 
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        
        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset 
        # with "empty" images (images without target class) a large gap could be observed. 
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Lightning hook: run one training step."""
        return self.shared_step(batch, "train")            

    def on_train_epoch_end(self):
        """Lightning hook: log train metrics and reset the step buffer."""
        ret_val = self.shared_epoch_end(self.training_step_ouputs, "train")
        self.training_step_ouputs.clear()
        return ret_val

    def on_training_epoch_end(self):
        """Backward-compatible alias of ``on_train_epoch_end``.

        Lightning never calls a hook with this name, so the logic lives in
        ``on_train_epoch_end``; this wrapper is kept for existing callers.
        """
        return self.on_train_epoch_end()

    def validation_step(self, batch, batch_idx):
        """Lightning hook: run one validation step."""
        return self.shared_step(batch, "valid")

    def on_validation_epoch_end(self):
        """Lightning hook: log validation metrics and reset the step buffer."""
        ret_val = self.shared_epoch_end(self.validation_step_outputs, "valid")
        self.validation_step_outputs.clear()
        return ret_val

    def configure_optimizers(self):
        """Return the Adam optimizer used for all parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.0001)

In [None]:
# Seed Python, NumPy and PyTorch so that the shuffle below (and therefore the
# train/validation split) and the weight initialisation are reproducible on
# every "Restart Kernel -> Run All".
SEED = 42
TRAIN_FRACTION = 0.8  # share of samples used for training
IMAGE_SIZE = 256      # square input size; must be divisible by 32 for the model
pl.seed_everything(SEED, workers=True)

device = cfg.device
model = HandModel(cfg.model_name, cfg.encoder_name)
dataset_path = cfg.dataset_path

# os.listdir returns entries in arbitrary order, so sort them to make the
# seeded shuffle deterministic. Keep only directories: every sample is a
# directory, and stray files (e.g. ".DS_Store") would become unreadable samples.
image_paths = sorted(
    f"{dataset_path}/{x}"
    for x in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, x))
)
assert image_paths, f"No sample directories found in {dataset_path}"
random.shuffle(image_paths)

split_index = int(TRAIN_FRACTION * len(image_paths))
train_image_paths = image_paths[:split_index]
val_image_paths = image_paths[split_index:]
print(f"Training on {len(train_image_paths)} images")
print(f"Validating on {len(val_image_paths)} images")


# Training applies random augmentations; validation only resizes.
train_transform = A.Compose([
    A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2)
])

val_transform = A.Compose([
    A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
])

train_dataset = HandDataset(train_image_paths, transforms=train_transform)
val_dataset = HandDataset(val_image_paths, transforms=val_transform)

train_loader = DataLoader(train_dataset, batch_size=cfg.train_batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=cfg.valid_batch_size, shuffle=False, num_workers=0)

trainer = pl.Trainer(accelerator=device, max_epochs=cfg.num_epochs, default_root_dir="../models/")
trainer.fit(model, train_loader, val_loader)