In [56]:
import sys
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from easydict import EasyDict
from pytorch_lightning.loggers import WandbLogger
from skimage import color
from torch import Generator
from torch.optim import SGD, RMSprop, Adam
from torch.utils.data import DataLoader, Dataset, random_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from pytorch_toolbelt import losses as L

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback
pl.seed_everything(1)

Global seed set to 1


1

In [57]:
! wandb offline

W&B offline, running your script from this directory will only write metadata locally.


In [58]:
conf = EasyDict(
    dataset_folder=Path("/home/shamil/PycharmProjects/covid_segmentation/data/preprocessed"),
    batch_size=2,
    val_size=0.1,
)

max_epochs = 20


In [59]:
# Utils

def get_loss(cfg: dict):
    if cfg['loss_fn'] == 'cross_entropy':
        return nn.CrossEntropyLoss()
    elif cfg['loss_fn'] == "focal_loss":
        return smp.losses.FocalLoss(mode='binary')
    elif cfg['loss_fn'] == 'dice_loss':
        return smp.losses.DiceLoss(mode='binary')
    elif cfg['loss_fn'] == 'tversky_loss':
        return smp.losses.TverskyLoss(mode='binary')
    else:
        raise NotImplemented("Loss not found!")


def get_optimizer(cfg: dict, params, lr: float):
    if cfg['optim'] == 'SGD':
        return SGD(params, lr)
    elif cfg['optim'] == 'Adam':
        return Adam(params, lr)
    elif cfg['optim'] == 'RMSprop':
        return RMSprop(params, lr)
    else:
        raise NotImplemented("Optim not found!")


def dict_to_str(data: dict):
    result = []
    for key, value in data.items():
        line = f"{key}={value}"
        result.append(line)

    return ", ".join(result)


In [60]:
# Utils for visualising

def visualize_prediction(image: np.array, prediction: np.array, true: np.array):
    image = (image - image.min()) / (image.max() - image.min())
    image = color.gray2rgb(image)
    new_image = image.copy()

    red, green = (1, 0, 0), (0, 1, 0)
    image[true == 1] = red
    new_image[prediction == 1] = green
    return np.concatenate([image, new_image], axis=1)



In [61]:
# Augmentations

train_transform = A.Compose([
    A.OneOf([
        A.Blur(),
        A.MotionBlur(),
    ]),
    A.ShiftScaleRotate(shift_limit=0.1, rotate_limit=10),
    A.OneOf([
        A.HorizontalFlip(),
        A.VerticalFlip(),
    ]),
    A.Normalize(mean=(0.485,), std=(0.229,)),
    ToTensorV2(),
])

valid_transform = A.Compose([
    A.Normalize(mean=(0.485,), std=(0.229,)),
    ToTensorV2(),
])

In [62]:
class CTDataset(Dataset):
    """
    Returns
        - train - images with shape (1, 512, 512), labels with shape (512, 512) and frame name
        - test - images with shpae (1, 512, 512) and frame name
    """

    def __init__(self, transform: A.Compose, train: bool):
        self.train = train
        self.transform = transform

        if train:
            self.images = np.load(conf.dataset_folder / "training_images.npy")
            self.labels = np.load(conf.dataset_folder / "training_labels.npy")
            self.frame_names = (conf.dataset_folder / "training_frame_names.txt").read_text().split()

        else:
            self.images = np.load(conf.dataset_folder / "testing_images.npy")
            self.frame_names = (conf.dataset_folder / "testing_frame_names.txt").read_text().split()

    def __getitem__(self, item):
        image = self.images[item]
        image = np.expand_dims(image, axis=2)
        assert image.shape == (512, 512, 1)

        frame_name = self.frame_names[item]

        if self.train:
            label = self.labels[item]
            # label = np.expand_dims(label, axis=0)
            transformed = self.transform(image=image, mask=label)
            return transformed['image'], transformed['mask'], frame_name

        else:
            transformed = self.transform(image=image)
            return transformed['image'], frame_name

    def __len__(self):
        return len(self.images)

In [63]:
class CTDataLoader(pl.LightningDataModule):
    def __init__(self, train_aug: bool):
        super(CTDataLoader, self).__init__()

        self.train_dataset = CTDataset(train=True, transform=train_transform if train_aug else valid_transform)
        self.test_dataset = CTDataset(train=False, transform=valid_transform)

        self.val_images = int(conf.val_size * len(self.train_dataset))
        self.train_images = len(self.train_dataset) - self.val_images

    def train_dataloader(self):
        train, _ = random_split(self.train_dataset, [self.train_images, self.val_images],
                                generator=Generator().manual_seed(0))

        return DataLoader(dataset=train,
                          batch_size=conf.batch_size,
                          shuffle=True,
                          num_workers=2)

    def val_dataloader(self):
        _, val = random_split(self.train_dataset, [self.train_images, self.val_images],
                              generator=Generator().manual_seed(0))
        return DataLoader(dataset=val,
                          batch_size=conf.batch_size,
                          shuffle=False,
                          num_workers=2)

    def test_dataloader(self):
        return DataLoader(dataset=self.test_dataset,
                          batch_size=conf.batch_size,
                          shuffle=False,
                          num_workers=2)

In [64]:
def to_numpy(x):
    return x.detach().cpu().numpy()


class ImageLoggerCallback(pl.Callback):
    def __init__(self, train_loader: DataLoader, val_loader: DataLoader, log_batches: int = 2, threshold: float = 0.5):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.log_batches = log_batches
        self.threshold = threshold

    def make_predictions(self, pl_module, loader):
        images, pred_masks, true_masks = [], [], []

        for ii, batch in enumerate(loader):

            if ii >= self.log_batches:
                break

            x, y, label = batch

            with torch.no_grad():
                x = x.cuda()
                output = pl_module(x)
                output = output.cpu()

            image = torch.squeeze(x)
            pred_mask = (output >= self.threshold).int().reshape(-1, 512, 512)
            true_mask = y.reshape(-1, 512, 512)

            images.extend(list(to_numpy(image)))
            pred_masks.extend(list(to_numpy(pred_mask)))
            true_masks.extend(list(to_numpy(true_mask)))

        return images, pred_masks, true_masks

    @staticmethod
    def make_grid(images, pred_masks, true_masks):
        grid = []
        for image, pred, true in zip(images, pred_masks, true_masks):
            visualization = visualize_prediction(image, pred, true)

            grid.append(visualization)

        grid = np.concatenate(grid, axis=0)
        return grid

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        images, pred_masks, true_masks = self.make_predictions(pl_module, self.val_loader)
        grid = self.make_grid(images, pred_masks, true_masks)

        trainer.logger.experiment.log({
            "val/predictions": wandb.Image(grid, caption="Red are ground truth, Green are predictions"),
            "global_step": trainer.global_step
        })

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        images, pred_masks, true_masks = self.make_predictions(pl_module, self.train_loader)
        grid = self.make_grid(images, pred_masks, true_masks)

        trainer.logger.experiment.log({
            "train/predictions": wandb.Image(grid, caption="Red are ground truth, Green are predictions"),
            "global_step": trainer.global_step
        })

In [65]:
class CTSemanticSegmentation(pl.LightningModule):

    def __init__(self, cfg: dict, threshold: float = 0.5):
        super(CTSemanticSegmentation, self).__init__()
        self.model = smp.Unet('mobilenet_v2', encoder_weights="imagenet", classes=1, activation='sigmoid',
                              encoder_depth=5,
                              decoder_channels=[256, 128, 64, 32, 16], in_channels=1)

        self.cfg = cfg
        self.loss_fn = get_loss(cfg)
        self.threshold = threshold

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = get_optimizer(self.cfg, self.model.parameters(), lr=self.cfg.lr)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y, label = batch
        output = self(x)
        # print(output)
        loss = self.loss_fn(output, y)

        self.log("train/loss_step", loss.item())

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y, label = batch
        output = self(x)
        output = (output > self.threshold).int()
        loss = self.loss_fn(output, y)

        self.log("val/loss_step", loss.item())

        return {"loss": loss}

    def training_epoch_end(self, outputs) -> None:
        avg_losses = torch.hstack([loss["loss"] for loss in outputs]).mean()
        self.log("train/loss_epoch", avg_losses)

    def validation_epoch_end(self, outputs) -> None:
        avg_losses = torch.hstack([loss["loss"] for loss in outputs]).mean()
        self.log("val/loss_epoch", avg_losses)

In [66]:
## Tuning parameters
def train_num_epoch(hparam, num_epochs, num_gpus):
    full_config = hparam | conf.__dict__
    dataloader = CTDataLoader(train_aug=full_config['train_aug'])
    train_dataloader, val_dataloader = dataloader.train_dataloader(), dataloader.val_dataloader()

    logger = WandbLogger(name=f"{dict_to_str(hparam)}", project="CTSegmentationTuning", config=full_config)

    tune_report_callback = TuneReportCallback({
        'loss': 'val/loss_epoch'
    }, on="validation_end")

    trainer = pl.Trainer(
        logger=logger,
        gpus=num_gpus,
        max_epochs=num_epochs,
        callbacks=[tune_report_callback],
        progress_bar_refresh_rate=0
    )

    model = CTSemanticSegmentation(cfg=full_config)
    trainer.fit(model, train_dataloader, val_dataloader)


def tuning_ct_segmentation():
    tuning_conf = {
        "loss_fn": tune.choice(['cross_entropy', 'focal_loss', 'dice_loss', 'tversky_loss']),
        "optim": tune.choice(['SGD', "Adam", "RMSprop"]),
        "train_aug": tune.choice([True, False]),
        "lr": tune.choice([0.01, 0.001, 0.0001]),
    }

    scheduler = ASHAScheduler(
        max_t=max_epochs,
        grace_period=5,  # 5 train epochs before stopping some samples
        reduction_factor=4
    )

    reporter = CLIReporter(
        parameter_columns=list(tuning_conf.keys()),
        metric_columns=['loss', 'training_iteration']
    )

    analysis = tune.run(
        tune.with_parameters(
            train_num_epoch,
            num_epochs=max_epochs,
            num_gpus=0),
        resources_per_trial={
            "cpu": 1,
            "gpu": 0
        },
        metric="loss",
        mode='min',
        num_samples=72,
        config=tuning_conf,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="tune_ct_segmentation")

    print(f"Best parameters found: {analysis.best_config}")
    return analysis.best_config


tuning_example = tuning_ct_segmentation()

== Status ==
Current time: 2021-11-16 10:49:04 (running for 00:00:00.37)
Memory usage on this node: 8.3/15.5 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 20.000: None | Iter 5.000: None
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/5.41 GiB heap, 0.0/2.71 GiB objects
Result logdir: /home/shamil/ray_results/tune_ct_segmentation
Number of trials: 16/72 (16 PENDING)
+-----------------------------+----------+-------+---------------+---------+-------------+--------+
| Trial name                  | status   | loc   | loss_fn       | optim   | train_aug   |     lr |
|-----------------------------+----------+-------+---------------+---------+-------------+--------|
| train_num_epoch_ab933_00000 | PENDING  |       | focal_loss    | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00001 | PENDING  |       | tversky_loss  | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00002 | PENDING  |       | focal_loss    | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00

[2m[36m(pid=18872)[0m 2021-11-16 10:49:13,437	ERROR function_runner.py:268 -- Runner Thread raised error.
[2m[36m(pid=18872)[0m Traceback (most recent call last):
[2m[36m(pid=18872)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 262, in run
[2m[36m(pid=18872)[0m     self._entrypoint()
[2m[36m(pid=18872)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 330, in entrypoint
[2m[36m(pid=18872)[0m     return self._trainable_func(self.config, self._status_reporter,
[2m[36m(pid=18872)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/util/tracing/tracing_helper.py", line 451, in _resume_span
[2m[36m(pid=18872)[0m     return method(self, *_args, **_kwargs)
[2m[36m(pid=18872)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ra

== Status ==
Current time: 2021-11-16 10:49:09 (running for 00:00:05.42)
Memory usage on this node: 9.8/15.5 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 20.000: None | Iter 5.000: None
Resources requested: 8.0/8 CPUs, 0/0 GPUs, 0.0/5.41 GiB heap, 0.0/2.71 GiB objects
Result logdir: /home/shamil/ray_results/tune_ct_segmentation
Number of trials: 24/72 (16 PENDING, 8 RUNNING)
+-----------------------------+----------+---------------------+---------------+---------+-------------+--------+
| Trial name                  | status   | loc                 | loss_fn       | optim   | train_aug   |     lr |
|-----------------------------+----------+---------------------+---------------+---------+-------------+--------|
| train_num_epoch_ab933_00000 | RUNNING  | 192.168.31.72:18874 | focal_loss    | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00001 | RUNNING  | 192.168.31.72:18875 | tversky_loss  | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00002 | RUNNING  | 

[2m[36m(pid=19120)[0m 2021-11-16 10:49:19,626	ERROR function_runner.py:268 -- Runner Thread raised error.
[2m[36m(pid=19120)[0m Traceback (most recent call last):
[2m[36m(pid=19120)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 262, in run
[2m[36m(pid=19120)[0m     self._entrypoint()
[2m[36m(pid=19120)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 330, in entrypoint
[2m[36m(pid=19120)[0m     return self._trainable_func(self.config, self._status_reporter,
[2m[36m(pid=19120)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/util/tracing/tracing_helper.py", line 451, in _resume_span
[2m[36m(pid=19120)[0m     return method(self, *_args, **_kwargs)
[2m[36m(pid=19120)[0m   File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ra

== Status ==
Current time: 2021-11-16 10:49:14 (running for 00:00:10.44)
Memory usage on this node: 11.9/15.5 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 20.000: None | Iter 5.000: None
Resources requested: 7.0/8 CPUs, 0/0 GPUs, 0.0/5.41 GiB heap, 0.0/2.71 GiB objects
Result logdir: /home/shamil/ray_results/tune_ct_segmentation
Number of trials: 27/72 (4 ERROR, 16 PENDING, 7 RUNNING)
+-----------------------------+----------+---------------------+---------------+---------+-------------+--------+
| Trial name                  | status   | loc                 | loss_fn       | optim   | train_aug   |     lr |
|-----------------------------+----------+---------------------+---------------+---------+-------------+--------|
| train_num_epoch_ab933_00000 | RUNNING  | 192.168.31.72:18874 | focal_loss    | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00002 | RUNNING  | 192.168.31.72:18877 | focal_loss    | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00005 | R

2021-11-16 10:49:20,143	ERROR worker.py:79 -- Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): [36mray::ImplicitFunc.train_buffered()[39m (pid=18877, ip=192.168.31.72, repr=<ray.tune.function_runner.ImplicitFunc object at 0x7f6551865d90>)
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/trainable.py", line 224, in train_buffered
    result = self.train()
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/trainable.py", line 283, in train
    result = self.step()
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 381, in step
    self._report_thread_runner_error(block=True)
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 528, in _report_thread_runner_error
    raise TuneError(
ray.tune.error.TuneError: Trial raised an exception. T

== Status ==
Current time: 2021-11-16 10:49:19 (running for 00:00:15.44)
Memory usage on this node: 12.3/15.5 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 20.000: None | Iter 5.000: None
Resources requested: 8.0/8 CPUs, 0/0 GPUs, 0.0/5.41 GiB heap, 0.0/2.71 GiB objects
Result logdir: /home/shamil/ray_results/tune_ct_segmentation
Number of trials: 29/72 (6 ERROR, 15 PENDING, 8 RUNNING)
+-----------------------------+----------+---------------------+---------------+---------+-------------+--------+
| Trial name                  | status   | loc                 | loss_fn       | optim   | train_aug   |     lr |
|-----------------------------+----------+---------------------+---------------+---------+-------------+--------|
| train_num_epoch_ab933_00000 | RUNNING  | 192.168.31.72:18874 | focal_loss    | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00002 | RUNNING  | 192.168.31.72:18877 | focal_loss    | SGD     | True        | 0.001  |
| train_num_epoch_ab933_00008 | R

2021-11-16 10:49:25,152	ERROR worker.py:79 -- Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): [36mray::ImplicitFunc.train_buffered()[39m (pid=19120, ip=192.168.31.72, repr=<ray.tune.function_runner.ImplicitFunc object at 0x7f9678b19d00>)
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/trainable.py", line 224, in train_buffered
    result = self.train()
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/trainable.py", line 283, in train
    result = self.step()
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 381, in step
    self._report_thread_runner_error(block=True)
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/function_runner.py", line 528, in _report_thread_runner_error
    raise TuneError(
ray.tune.error.TuneError: Trial raised an exception. T

== Status ==
Current time: 2021-11-16 10:49:24 (running for 00:00:20.46)
Memory usage on this node: 10.6/15.5 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 20.000: None | Iter 5.000: None
Resources requested: 8.0/8 CPUs, 0/0 GPUs, 0.0/5.41 GiB heap, 0.0/2.71 GiB objects
Result logdir: /home/shamil/ray_results/tune_ct_segmentation
Number of trials: 35/72 (12 ERROR, 15 PENDING, 8 RUNNING)
+-----------------------------+----------+---------------------+---------------+---------+-------------+--------+
| Trial name                  | status   | loc                 | loss_fn       | optim   | train_aug   |     lr |
|-----------------------------+----------+---------------------+---------------+---------+-------------+--------|
| train_num_epoch_ab933_00009 | RUNNING  | 192.168.31.72:19120 | tversky_loss  | RMSprop | False       | 0.001  |
| train_num_epoch_ab933_00012 | RUNNING  | 192.168.31.72:19213 | cross_entropy | SGD     | False       | 0.0001 |
| train_num_epoch_ab933_00014 | 

2021-11-16 10:49:31,191	ERROR trial_runner.py:924 -- Trial train_num_epoch_ab933_00019: Error processing event.
Traceback (most recent call last):
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/trial_runner.py", line 890, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/tune/ray_trial_executor.py", line 788, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/home/shamil/PycharmProjects/covid_segmentation/venv/lib/python3.9/site-packages/ray/worker.py", line 1625, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): [36mray::ImplicitFunc.train_buffered()[39m (pid=19317, ip=1

Result for train_num_epoch_ab933_00019:
  date: 2021-11-16_10-49-30
  experiment_id: 3a714ec3a7934fbab5e017a99cd68f77
  hostname: shamil-hpprobook430g7
  node_ip: 192.168.31.72
  pid: 19317
  timestamp: 1637048970
  trial_id: ab933_00019
  
Result for train_num_epoch_ab933_00018:
  date: 2021-11-16_10-49-30
  experiment_id: 45e72f8823504f379b9a9762d0a31391
  hostname: shamil-hpprobook430g7
  node_ip: 192.168.31.72
  pid: 19313
  timestamp: 1637048970
  trial_id: ab933_00018
  
Result for train_num_epoch_ab933_00017:
  date: 2021-11-16_10-49-30
  experiment_id: f68b7208f14342629f622cdc8cc93f32
  hostname: shamil-hpprobook430g7
  node_ip: 192.168.31.72
  pid: 19310
  timestamp: 1637048970
  trial_id: ab933_00017
  




KeyboardInterrupt: 