In [1]:
import os
from typing import Dict, List
import numpy as np
import cv2
import torch
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
import torch.nn as nn
from pprint import pprint
from torch.utils.data import DataLoader
from glob import glob
from torch.utils.data import Dataset, DataLoader

torch.multiprocessing.set_sharing_strategy('file_system')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3'


class CODIPAIDataSet(Dataset):
    def __init__(self, imgs_dir: str, masks_dir: str):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir

        self.ids = [
            os.path.splitext(fi)[0]
            for fi in os.listdir(self.masks_dir)
            if not fi.startswith(".")
        ]

    def __len__(self) -> int:
        return len(self.ids)

    def __getitem__(self, i) -> Dict[str, torch.tensor]:
        idx = self.ids[i]

        img_path = os.path.join(self.imgs_dir, f"{idx}.npy")
        mask_path = os.path.join(self.masks_dir, f"{idx}.npy")

        img_file = glob(img_path)
        mask_file = glob(mask_path)

        assert (
                len(mask_file) == 1
        ), f"Either no mask or multiple masks found for the ID {idx}: {mask_file}"
        assert (
                len(img_file) == 1
        ), f"Either no image or multiple images found for the ID {idx}: {img_file}"

        image = np.load(img_file[0], allow_pickle=True)
#         print('1: ', image.shape)
        mask = np.load(mask_file[0], allow_pickle=True)
        mask = mask.squeeze()

        image = np.transpose(image, (2, 0, 1))
#         mask = np.transpose(mask, (2, 0, 1))

        return {
            "image": torch.tensor(image, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.float32),
        }

class CODIPAIModel(pl.LightningModule):
    def __init__(
        self,
        arch="unetplusplus",
        encoder_name="efficientnet-b3",
        in_channels=3,
        out_classes=1,
        **kwargs,
    ):
        super().__init__()
        self.model = smp.create_model(
            arch,
            encoder_name=encoder_name,
            in_channels=in_channels,
            classes=out_classes,
            **kwargs,
        )
        
        # Add attention layers
#         self.self_attention = nn.MultiheadAttention(embed_dim=256, num_heads=8)
#         self.encoder_attention = nn.MultiheadAttention(embed_dim=256, num_heads=8)

        

        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, image):
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32,
        # encoder and decoder connected by `skip connections` and usually encoder have 5 stages of
        # downsampling by factor 2 (2 ^ 5 = 32); e.g. if we have image with shape 65x65 we will have
        # following shapes of features in encoder and decoder: 84, 42, 21, 10, 5 -> 5, 10, 20, 40, 80
        # and we will get an error trying to concat these features
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]        
        mask = mask.unsqueeze(1)
#         print(mask.shape)

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4
        print('aaa: ',mask.max())
        # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Lets compute metrics for some threshold
        # first convert mask values to probabilities, then
        # apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute IoU metric by two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class
        # these values will be aggregated in the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(
            pred_mask.long(), mask.long(), mode="binary"
        )

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(
            tp, fp, fn, tn, reduction="micro-imagewise"
        )
        per_image_accuracy = smp.metrics.accuracy(
            tp, fp, fn, tn, reduction="micro-imagewise"
        )
        per_image_precision = smp.metrics.precision(
            tp, fp, fn, tn, reduction="micro-imagewise"
        )
        per_image_recall = smp.metrics.recall(
            tp, fp, fn, tn, reduction="micro-imagewise"
        )
        per_image_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset
        # with "empty" images (images without target class) a large gap could be observed.
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        dataset_accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro")
        dataset_precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")
        dataset_recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro")
        dataset_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_per_image_accuracy": per_image_accuracy,
            f"{stage}_per_image_precision": per_image_precision,
            f"{stage}_per_image_recall": per_image_recall,
            f"{stage}_per_image_f1": per_image_f1,
            f"{stage}_dataset_iou": dataset_iou,
            f"{stage}_dataset_accuracy": dataset_accuracy,
            f"{stage}_dataset_precision": dataset_precision,
            f"{stage}_dataset_recall": dataset_recall,
            f"{stage}_dataset_f1": dataset_f1,
        }

        self.log_dict(metrics, prog_bar=True)
        pprint(metrics)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)


if __name__ == "__main__":
    train_dataloaders = DataLoader(
        CODIPAIDataSet(
            imgs_dir='./data/data2/train_512/image/',
            masks_dir='./data/data2/train_512/mask/'),
        batch_size=16,
        num_workers=4,
    )

#     val_dataloaders = DataLoader(
#         CODIPAIDataSet(
#             imgs_dir='../Image_Segmentation_2/dataset_1/valid/',
#             masks_dir='../Image_Segmentation_2/dataset_1/valid_GT'),
#         batch_size=10,
#         num_workers=8,    
#     )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=50,
        log_every_n_steps=1,
        default_root_dir="./new_codipai/pan",
    )

model = CODIPAIModel("pan", "efficientnet-b3", in_channels=3, out_classes=1)
trainer.fit(model, train_dataloaders, train_dataloaders)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0, 1, 2, 3]

  | Name    | Type     | Params
-------------------------------------
0 | model   | PAN      | 21.5 M
1 | loss_fn | DiceLoss | 0     
-------------------------------------
21.5 M    Trainable params
0         Non-trainable params
21.5 M    Total params
85.903    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

aaa:  tensor(1., device='cuda:0')


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
