## 0. Imports

In [1]:

# Trains an MNIST digit recognizer using PyTorch Lightning,
# and uses Mlflow to log metrics, params and artifacts
# NOTE: This example requires you to first install
# pytorch-lightning (using pip install pytorch-lightning)
#       and mlflow (using pip install mlflow).
#
# pylint: disable=arguments-differ
# pylint: disable=unused-argument
# pylint: disable=abstract-method
import pytorch_lightning as pl
import cv2
import random
import os
import torch
import urllib
import tarfile
import numpy as np
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
from argparse import ArgumentParser
from collections import OrderedDict
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
from torchmetrics.functional import accuracy
from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms

IMAGE_SIZE = 512
TRAIN = True #False
from pathlib import Path

data_path = Path('data/')
download_path = data_path/'BBBC05.tar.gz'
image_dir = data_path/'BBBC005_v1_images'
mask_dir = data_path/'BBBC005_v1_ground_truth'

## 1. Data

In [2]:
class CellDataset(Dataset):
    def __init__(self, images_directory, masks_directory, mask_filenames, transform=None):
        self.images_directory = images_directory
        self.masks_directory = masks_directory
        self.filenames = mask_filenames
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):        
        filename = self.filenames[idx]
        
        image = cv2.imread(os.path.join(self.images_directory, filename))
        if image is None or image.size==0:
           while True:
             newidx=random.randint(0,len(self.filenames)-1)
             filename = self.filenames[newidx]
             image = cv2.imread(os.path.join(self.images_directory, filename))
             if image is not None:
               break
             if image.size!=0:
               break
          
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        mask = cv2.imread(os.path.join(self.masks_directory, filename), -1)#cv2.IMREAD_UNCHANGED)
        mask = mask.astype(np.float32)
        mask /= 255.0

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
            mask = torch.unsqueeze(mask,0)
           # mask=np.expand_dims(mask,axis=0)

        return image, mask

In [3]:
class CellDataModule(pl.LightningDataModule):
    def __init__(self, image_dir, mask_dir, max_images=None, **kwargs):
        """
        Initialization of inherited lightning data module
        """
        super().__init__()
        self.df_train = None
        self.df_val = None
        self.df_test = None
        self.train_data_loader = None
        self.val_data_loader = None
        self.test_data_loader = None
        self.args = kwargs
        self.max_images = max_images

        self.image_dir = image_dir 
        self.mask_dir = mask_dir
        self.filenames = None
        self.batch_size = 10
        self.num_workers = kwargs['num_workers'] if 'num_workers' in kwargs else 0

        self.transform_train = A.Compose([
            A.Resize(IMAGE_SIZE, IMAGE_SIZE,  always_apply=True),
            A.VerticalFlip(p=0.2),
            A.Blur(p=0.2),
            A.RandomBrightnessContrast(p=0.2),
            A.RandomSunFlare(p=0.2, src_radius=200),
            A.RandomShadow(p=0.2),
            A.RandomFog(p=0.2),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            A.pytorch.ToTensorV2(transpose_mask=True)
            ])
        
        self.transform_valid = A.Compose([
            A.Resize(IMAGE_SIZE, IMAGE_SIZE, always_apply=True),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 
            ToTensorV2(transpose_mask=True)
            ])
              
        self.transform_predict = A.Compose([
            A.Resize(IMAGE_SIZE, IMAGE_SIZE, always_apply=True),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
            ])
  
    def prepare_data(self):

        # dset_url = "https://www.kaggle.com/datasets/vbookshelf/synthetic-cell-images-and-masks-bbbc005-v1"
        dset_url = "https://www.googleapis.com/drive/v3/files/1UJMyQmI8bZCWB2UKWqa0F00I7f_nz-2q?alt=media&key=AIzaSyDhmuR1Oj_myOqYXEXBQ0J3FN1-cwvR9zI"
        
        if not image_dir.exists() or not mask_dir.exists():
            if not data_path.exists():
                data_path.mkdir()
            if not download_path.exists():
                urllib.request.urlretrieve(dset_url , filename = download_path)
            with tarfile.open(download_path, 'r') as zip_ref:
                zip_ref.extractall(data_path)  

        import glob
        fnames=[]
        fnames_nofilter=[]
        widths=[]
        count = 0
        for file in os.listdir(self.mask_dir):
            if file.endswith(".TIF"): #
              fnames_nofilter.append(file)
              image = cv2.imread(os.path.join(self.image_dir, file))
              if image is not None and image.size!=0 and 0.9 <=image.shape[1]/image.shape[0]<1.4:
                  fnames.append(file)
                  widths.append(image.shape[1])
                  count += 1 
                  if self.max_images:
                    if count > self.max_images:
                        break
        print("Initial images: {}, keeping {}".format(len(fnames_nofilter), 
                                                      len(fnames)))          
        self.filenames = fnames #list(sorted(fnames))

        random.seed(43)
        random.shuffle(self.filenames)
        n_val = int(len(self.filenames)*0.2)
        self.filenames_valid = self.filenames[:n_val]
        self.filenames_train = self.filenames[n_val:]
        print("{} train, {} validation".format(len(self.filenames_train),
                                               len(self.filenames_valid)))


    def setup(self, stage:str):
        """
        Create train and valid datasets
        """

        print(f"In CellDataModule.setup; stage = {stage}")
        if stage == "fit":
            self.dataset_train = CellDataset(self.image_dir, self.mask_dir, 
                                             self.filenames_train, 
                                             transform=self.transform_train)
            
            self.dataset_valid = CellDataset(self.image_dir, self.mask_dir, 
                                             self.filenames_valid, 
                                             transform=self.transform_valid)
                        
        if stage == "test":
            raise NotImplemented
        
        if stage == "predict":
            raise NotImplemented

    def create_data_loader(self, dataset:Dataset):
        """
        Generic data loader function
        """
        return DataLoader(dataset, batch_size=self.batch_size, drop_last=False, 
                          shuffle=False, num_workers=self.num_workers,
                          worker_init_fn=self.worker_init)

    def train_dataloader(self):
        """
        :return: output - Train data loader for the given input
        """
        return self.create_data_loader(self.dataset_train)

    def val_dataloader(self):
        """
        :return: output - Validation data loader for the given input
        """
        return self.create_data_loader(self.dataset_valid)

    def test_dataloader(self):
        """
        :return: output - Test data loader for the given input
        """
        raise NotImplemented

    @staticmethod
    def worker_init(worker_id):
        np.random.seed(42 + worker_id)


In [4]:
#Test Data Loader

celldm = CellDataModule(image_dir, mask_dir)
celldm.prepare_data()
celldm.setup(stage="fit")

Initial images: 1200, keeping 1200
960 train, 240 validation
In CellDataModule.setup; stage = fit


## 2. Define Model Architecture

In [5]:
class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)        
        enc2 = self.encoder2(self.pool1(enc1))        
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

In [7]:
#dice loss function
class DiceLoss(torch.nn.Module):

    def __init__(self):
        super(DiceLoss, self).__init__()
        self.smooth = 1.0

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        intersection = (y_pred * y_true).sum()
        dsc = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)
        return 1. - dsc

In [8]:
class BCELoss(torch.nn.Module):
      def __init__(self):
        super().__init__()
        
      def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        return torch.mean(y_true * torch.log(y_pred) + (1-y_true)*torch.log(1-y_pred))


In [9]:
class CellClassifier(pl.LightningModule):
    def __init__(self, in_channels=3, out_channels=1, init_features=32, lr = 0.0001):
        super().__init__()
        self.optimizer = None
        self.scheduler = None
        self.lr = lr
        self.unet = UNet()
        self.diceloss = DiceLoss()

    def forward(self, x):
        return self.unet(x)

    def training_step(self, train_batch, batch_idx):
        input, labels = train_batch
        pred = self.forward(input)
        loss = self.diceloss(pred, labels)
        self.log("train_step_loss", loss, 
                 on_step=True, on_epoch=False, prog_bar=True, logger=True)
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, 
                 on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def validation_step(self, val_batch, batch_idx):
        input, labels = val_batch
        pred = self.forward(input)
        loss = self.diceloss(pred, labels)
        self.log("valid_step_loss", loss, 
                 on_step=True, on_epoch=False, prog_bar=True, logger=True)
        return {"valid_step_loss": loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["valid_step_loss"] for x in outputs]).mean()
        self.log("valid_loss", avg_loss, 
                 on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="min", factor=0.2, patience=2, min_lr=self.lr*0.02, verbose=True,
            ),
            "monitor": "valid_loss",
        }
        return [self.optimizer], [self.scheduler]




## 3. Train Model

In [11]:
model = CellClassifier(init_features=2)
dm = CellDataModule(image_dir, mask_dir, max_images = 5, num_workers=0)
trainer = pl.Trainer(max_epochs=1)
# Train the model
trainer.fit(model, dm)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /Users/francis/Code/crim2023/lightning_logs

  | Name     | Type     | Params
--------------------------------------
0 | unet     | UNet     | 7.8 M 
1 | diceloss | DiceLoss | 0     
--------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.052    Total estimated model params size (MB)


Initial images: 6, keeping 6
5 train, 1 validation
In CellDataModule.setup; stage = fit


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
