In [1]:
%cd /content/drive/MyDrive/Colab\ Notebooks

/content/drive/MyDrive/Colab Notebooks


In [2]:
%cd Thesis/eyediease/

/content/drive/MyDrive/Colab Notebooks/Thesis/eyediease


In [3]:
!pip install -q catalyst==20.12
!pip install -q pytorch-toolbelt==0.4.2
!pip install -q segmentation-models-pytorch==0.1.3
!pip install -q albumentations==0.4.6

In [4]:
from pathlib import Path
import os
import torch
from torch.optim import SGD, Adam, RMSprop, AdamW
from torch.optim.lr_scheduler import (
    ExponentialLR,
    CyclicLR,
    MultiStepLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ReduceLROnPlateau
)
from torch.utils.data import Dataset, DataLoader
from catalyst import dl
from catalyst.dl import SupervisedRunner, CriterionCallback, EarlyStoppingCallback, SchedulerCallback, MetricAggregationCallback, IouCallback, DiceCallback
from catalyst import utils
from pytorch_toolbelt.losses import *
from catalyst.core import Callback, CallbackOrder
from sklearn.metrics import average_precision_score, roc_auc_score, roc_curve, precision_recall_curve, auc
from torch import Tensor

import albumentations as A
import segmentation_models_pytorch as smp

from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split

from collections import OrderedDict

In [5]:
img_path = Path(os.getcwd()) / 'data/raw/FGADR/Seg-set/Original_Images'
ex_path = Path(os.getcwd()) / 'data/raw/FGADR/Seg-set/HardExudate_Masks'

In [6]:
print(len(list(img_path.glob('*.png'))))
print(len(list(ex_path.glob('*.png'))))

1842
1842


In [7]:
train_idx, test_idx = train_test_split(range(1842), test_size=0.2, random_state=1999)

In [8]:
train_idx, val_idx = train_test_split(range(len(train_idx)), test_size=0.2, random_state=1999)

In [9]:
class Lesion(Dataset):
  def __init__(self, img_path, ex_path, transform, preprocessing_fn):
    self.img_path = img_path
    self.ex_path = ex_path
    self.transform = transform
    self.preprocessing_fn = preprocessing_fn

  def __len__(self):
    return len(self.img_path)

  def __getitem__(self, idx):
    img = Image.open(self.img_path[idx])
    mask = Image.open(self.ex_path[idx])
    mask = mask.convert('1')
    img = np.asarray(img)
    mask = np.asarray(mask).astype(np.float32)

    if self.transform:
      result = self.transform(image=img, mask=mask)
      img = result['image']
      mask = result['mask']
    if self.preprocessing_fn:
      img = self.preprocessing_fn(img)
    
    img = torch.FloatTensor(img).permute(2, 0, 1)
    mask = torch.FloatTensor(mask)
    mask = torch.unsqueeze(mask, dim=0)
    filename = self.img_path[idx].name

    return {
        'image': img,
        'target': mask,
        'image_id': filename
    }

In [10]:
images = sorted(list(img_path.glob('*.png')))
mask = sorted(list(ex_path.glob('*.png')))

images = np.array(images)
mask = np.array(mask)

train_imgs = images[train_idx]
train_mask = mask[train_idx]

val_imgs = images[val_idx]
val_mask = mask[val_idx]

test_imgs = images[test_idx]
test_mask = mask[test_idx]

In [11]:
print(len(train_imgs), len(train_mask))
print(len(val_imgs), len(val_mask))
print(len(test_imgs), len(test_mask))

1178 1178
295 295
369 369


In [12]:
base_train_transform = A.Compose([
                                  A.HorizontalFlip(),
                                  A.RandomRotate90(),
                                  A.RandomResizedCrop(1280, 1280),
                                  A.Resize(1024, 1024)
])

base_val_transform = A.Compose([A.Resize(1024, 1024)])

In [13]:
def get_preprocessing_fn(dataset_name: str):
    if dataset_name == "IDRiD":
        mean = [0.44976714,0.2186806,0.06459363]
        std = [0.33224553,0.17116262,0.086509705]
    elif dataset_name == 'FGADR':
        mean = [0.4554011,0.2591345,0.13285689]
        std = [0.28593522,0.185085,0.13528904]
    else:
        mean, std = None, None

    def preprocessing(x, mean=mean, std=std, **kwargs):
        x = x / 255.0
        if mean is not None:
            mean = np.array(mean)
            x = x - mean

        if std is not None:
            std = np.array(std)
            x = x / std
        return x

    return preprocessing

In [14]:
preprocessing_fn = get_preprocessing_fn('FGADR')

In [15]:
train_ds = Lesion(list(train_imgs), list(train_mask), base_train_transform, preprocessing_fn)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_ds = Lesion(list(val_imgs), list(val_mask), base_val_transform, preprocessing_fn)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2, pin_memory=True)

In [16]:
loader = OrderedDict()
loader['train'] = train_loader
loader['valid'] = val_loader

In [17]:
model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', encoder_depth=5, decoder_use_batchnorm=True, decoder_attention_type='scse', in_channels=3, classes=1)

In [18]:
param_group = []
if hasattr(model, 'encoder'):
    encoder_params = filter(lambda p: p.requires_grad, model.encoder.parameters())
    param_group += [{'params': encoder_params, 'lr': 1e-5}]        
if hasattr(model, 'decoder'):
    decoder_params = filter(lambda p: p.requires_grad, model.decoder.parameters())
    param_group += [{'params': decoder_params}]        
if hasattr(model, 'segmentation_head'):
    head_params = filter(lambda p: p.requires_grad, model.segmentation_head.parameters())
    param_group += [{'params': head_params}]        

optimizer = AdamW(param_group, 1e-3, weight_decay=1e-5)

In [19]:
scheduler = MultiStepLR(optimizer, milestones=[int(60 * 0.4), int(60 * 0.7)], gamma=0.1)

In [20]:
from pytorch_toolbelt.utils import to_numpy
from pytorch_toolbelt.utils.distributed import all_gather
from typing import Callable, Optional

class AucPRMetricCallback(Callback):
    """
    Auc Precision-Recall score metric
    """

    def __init__(
        self,
        outputs_to_probas: Callable[[Tensor], Tensor] = torch.sigmoid,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "auc_pr",
        average="macro",
        ignore_index: Optional[int] = None,
    ):
        """
        Args:
            input_key: input key to use for accuracy calculation;
                specifies our `y_true`
            output_key: output key to use for accuracy calculation;
                specifies our `y_pred`
            prefix: key for the metric's name
        """

        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.output_key = output_key
        self.input_key = input_key
        self.ignore_index = ignore_index
        self.outputs_to_probas = outputs_to_probas
        self.y_trues = []
        self.y_preds = []
        self.average = average

    def on_loader_start(self, state):
        self.y_trues = []
        self.y_preds = []

    @torch.no_grad()
    def on_batch_end(self, runner):
        pred_probas = self.outputs_to_probas(runner.output[self.output_key])
        true_labels = runner.input[self.input_key]

        self.y_trues.extend(to_numpy(true_labels))
        self.y_preds.extend(to_numpy(pred_probas))

    def on_loader_end(self, runner):
        y_trues = np.concatenate(all_gather(self.y_trues))
        y_preds = np.concatenate(all_gather(self.y_preds))
        precision, recall, _ = precision_recall_curve(y_trues.reshape(-1), y_preds.reshape(-1))
        score = auc(recall, precision)
        runner.loader_metrics[self.prefix] = float(score)

In [21]:
criterion = {}
criterion['bce'] = SoftBCEWithLogitsLoss()
criterion['log_dice'] = DiceLoss(mode="binary", log_loss=True)

In [22]:
callbacks = []
criterion_config = {'bce':0.8, 'log_dice':0.2}
losses = []
for loss_name, loss_weight in criterion_config.items():
    criterion_callback = CriterionCallback(
        input_key="target",
        output_key="logits",
        criterion_key=loss_name,
        prefix="loss_"+loss_name,
        multiplier=float(loss_weight)
    )

    callbacks.append(criterion_callback)
    losses.append(criterion_callback.prefix)

callbacks += [MetricAggregationCallback(
    prefix="loss",
    mode="sum",
    metrics=losses
)]

early_stopping = EarlyStoppingCallback(
        patience=10, metric='dice', minimize=False)

# iou_scores = IouCallback(
#     input_key="target",
#     activation="Sigmoid",
#     threshold=0.5
# )

dice_scores = DiceCallback(
    input_key="target",
    activation="Sigmoid",
threshold=0.5
)

# aucpr_scores = AucPRMetricCallback(
#     input_key="target",
# )

In [23]:
callbacks += [ early_stopping,
               dice_scores]

In [26]:
log_dir = 'test_fgadr_2'
os.makedirs(log_dir, exist_ok=False)

In [27]:
runner = SupervisedRunner(
    device=utils.get_device(), input_key="image", input_target_key="target")

runner.train( 
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    callbacks=callbacks,
    logdir=log_dir,
    loaders=loader,
    num_epochs=60,
    scheduler=scheduler,
    main_metric='dice',
    minimize_metric=False,
    timeit=True,
    fp16=dict(amp=True),
    resume=None,
    verbose=True
)

1/60 * Epoch (train): 100% 294/294 [06:09<00:00,  1.26s/it, _timer/_fps=3.225, dice=0.285, loss=0.324, loss_bce=0.062, loss_log_dice=0.262]
1/60 * Epoch (valid): 100% 74/74 [01:55<00:00,  1.57s/it, _timer/_fps=13.361, dice=0.177, loss=0.366, loss_bce=0.005, loss_log_dice=0.361]
[2021-04-21 16:17:45,046] 
1/60 * Epoch 1 (_base): lr=1.000e-05 | momentum=0.9000
1/60 * Epoch 1 (train): _timer/_fps=3.1943 | _timer/batch_time=1.2530 | _timer/data_time=0.9498 | _timer/model_time=0.3032 | dice=0.2956 | loss=0.4275 | loss_bce=0.0751 | loss_log_dice=0.3524
1/60 * Epoch 1 (valid): _timer/_fps=7.7263 | _timer/batch_time=1.5703 | _timer/data_time=1.2840 | _timer/model_time=0.2864 | dice=0.3647 | loss=0.2856 | loss_bce=0.0350 | loss_log_dice=0.2506
2/60 * Epoch (train): 100% 294/294 [06:16<00:00,  1.28s/it, _timer/_fps=3.199, dice=0.646, loss=0.108, loss_bce=0.014, loss_log_dice=0.094]
2/60 * Epoch (valid): 100% 74/74 [00:36<00:00,  2.04it/s, _timer/_fps=13.751, dice=0.226, loss=0.338, loss_bce=0.00

Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


5/60 * Epoch (valid): 100% 74/74 [00:36<00:00,  2.03it/s, _timer/_fps=13.942, dice=0.306, loss=0.260, loss_bce=0.004, loss_log_dice=0.256]
[2021-04-21 16:45:52,763] 
5/60 * Epoch 5 (_base): lr=1.000e-05 | momentum=0.9000
5/60 * Epoch 5 (train): _timer/_fps=3.1858 | _timer/batch_time=1.2570 | _timer/data_time=0.9533 | _timer/model_time=0.3037 | dice=0.4202 | loss=0.2443 | loss_bce=0.0419 | loss_log_dice=0.2023
5/60 * Epoch 5 (valid): _timer/_fps=9.8458 | _timer/batch_time=0.4937 | _timer/data_time=0.2213 | _timer/model_time=0.2724 | dice=0.4797 | loss=0.2121 | loss_bce=0.0212 | loss_log_dice=0.1909
6/60 * Epoch (train): 100% 294/294 [06:10<00:00,  1.26s/it, _timer/_fps=3.197, dice=0.431, loss=0.194, loss_bce=0.017, loss_log_dice=0.178]
6/60 * Epoch (valid): 100% 74/74 [00:36<00:00,  2.04it/s, _timer/_fps=13.807, dice=0.294, loss=0.277, loss_bce=0.003, loss_log_dice=0.274]
[2021-04-21 16:52:54,561] 
6/60 * Epoch 6 (_base): lr=1.000e-05 | momentum=0.9000
6/60 * Epoch 6 (train): _timer/_fp

In [28]:
test_ds = Lesion(list(test_imgs), list(test_mask), base_val_transform, preprocessing_fn)
test_loader = DataLoader(test_ds, batch_size=4, num_workers=2)

In [34]:
dice_score = smp.utils.metrics.Fscore(threshold=0.5)

In [32]:
predictions = []
for batch_logits in runner.predict_loader(loader=test_loader):
  logits = batch_logits['logits'].detach().cpu()
  batch_predictions = logits.sigmoid()
  predictions.append(batch_predictions)

In [42]:
test_dice = 0
i = 0
for (batch, prediction) in zip(test_ds, predictions):
  dice_batch = dice_score(prediction, batch['target']).numpy()
  print(dice_batch)
  test_dice += dice_batch
  i +=1

test_dice /= i

4.454343e-11
1.2172855e-11
4.157313e-12
3.9105273e-12
2.1285654e-11
0.014533853
0.029219348
0.0016998027
0.01685484
2.5680534e-11
1.3008977e-11
3.4590109e-12
7.402473e-12
0.10196934
0.036270693
0.0019421247
8.616901e-13
4.2269e-12
4.9531924e-12
1.0411244e-11
1.1623852e-11
4.131378e-12
0.0050251256
4.0165484e-12
5.095022e-12
4.662222e-12
0.0515735
3.0774912e-12
9.443763e-12
0.063821524
2.8129396e-11
1.4146273e-11
0.0032785032
0.0028781632
1.5581179e-11
3.940421e-12
8.1566066e-11
5.0208366e-12
1.9116438e-12
2.266289e-12
6.686727e-12
4.3177892e-11
9.228497e-12
5.062522e-12
0.0074025784
7.902015e-12
1.1997889e-12
2.9359952e-11
0.088656865
0.000733326
8.837045e-12
1.43266475e-11
0.034032684
0.0027820012
1.530222e-11
1.473405e-11
0.0050474284
6.9637885e-12
6.6702243e-12
3.7002775e-12
4.842615e-12
0.0071260273
3.7069987e-12
3.7328756e-12
3.101641e-12
2.7410027e-12
0.017725058
0.005453837
9.057971e-12
4.30552e-12
7.160144e-13
1.8319716e-12
5.201831e-12
1.1818934e-11
0.0012527298
0.0057698474
2

In [41]:
print(test_dice)

0.0065523723230245965
