In [1]:
import hydra
import monai
import torch

with hydra.initialize(config_path="configs", version_base=None, job_name="training"):
    cfg = hydra.compose(config_name="config.yaml")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
for key in cfg:
    print(key)
    print(cfg[key])
    print()

task_name
retina_segmentation

datamodule
{'_target_': 'src.datamodule.RetinaDataModule', 'train_data_path': 'data/tiles/train_data.npz', 'val_data_path': 'data/tiles/val_data.npz', 'test_data_path': 'data/tiles/test_data.npz', 'batch_size': 32, 'train_transforms': '${transforms.train_transforms}', 'test_transforms': '${transforms.test_transforms}'}

transforms
{'norm_mean': [0, 0, 0], 'norm_std': [1, 1, 1], 'train_transforms': {'_target_': 'monai.transforms.Compose', 'transforms': [{'_target_': 'monai.transforms.RandRotateD', 'keys': ['image', 'mask'], 'range_x': [0, 180], 'prob': 0.5}, {'_target_': 'monai.transforms.RandAxisFlipD', 'keys': ['image', 'mask'], 'prob': 0.5}, {'_target_': 'monai.transforms.ToTensorD', 'keys': ['image', 'mask']}]}, 'test_transforms': {'_target_': 'monai.transforms.Compose', 'transforms': [{'_target_': 'monai.transforms.ToTensorD', 'keys': ['image', 'mask']}]}}

backbone
{'_target_': 'monai.networks.nets.UNet', 'spatial_dims': 2, 'in_channels': 3, 'out_cha

In [4]:
# DataModule
datamodule = hydra.utils.instantiate(cfg['datamodule'])

2023-03-06 08:40:43,971 - Created a temporary directory at /tmp/tmp3hbss6yv
2023-03-06 08:40:43,973 - Writing /tmp/tmp3hbss6yv/_remote_module_non_scriptable.py


In [21]:
print('Instantiating backbone \n')
backbone = torch.nn.Sequential(
    hydra.utils.instantiate(cfg['backbone']),
    torch.nn.Sigmoid()
)

print('Instantiating optimizer \n')
optimizer = hydra.utils.instantiate(cfg['optimizer'], params=backbone.parameters())

print('Instantiating loss and metric functions \n')
loss_function = hydra.utils.instantiate(cfg['losses'])
metric = hydra.utils.instantiate(cfg['metrics'])

print('Instantiating model \n')
model = hydra.utils.instantiate(
    cfg['model'], 
    backbone=backbone, 
    loss_function=loss_function, 
    metric=metric, 
    optimizer=optimizer
)

Instantiating backbone 

Instantiating optimizer 

Instantiating loss and metric functions 

Instantiating model 



In [22]:
print('Instantiating trainer \n')
trainer = hydra.utils.instantiate(cfg['trainer'])

Instantiating trainer 

2023-03-06 08:42:55,723 - GPU available: True (cuda), used: True
2023-03-06 08:42:55,725 - TPU available: False, using: 0 TPU cores
2023-03-06 08:42:55,726 - IPU available: False, using: 0 IPUs
2023-03-06 08:42:55,727 - HPU available: False, using: 0 HPUs




In [23]:
trainer.fit(model, datamodule)

2023-03-06 08:42:58,372 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
2023-03-06 08:42:58,376 - 
  | Name          | Type       | Params
---------------------------------------------
0 | backbone      | Sequential | 597 K 
1 | loss_function | DiceLoss   | 0     
---------------------------------------------
597 K     Trainable params
0         Non-trainable params
597 K     Total params
2.392     Total estimated model params size (MB)
Epoch 0:   6%|▌         | 12/209 [00:01<00:24,  8.06it/s, loss=0.821, v_num=12, train_loss=0.813]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [6]:
datamodule.prepare_data()

In [7]:
len(datamodule.train_datadict)

5173

In [8]:
datamodule.setup(stage='fit')

In [9]:
len(datamodule.train_dataset)

5173

In [12]:
datamodule.train_dataset[0]

{'original_image': array([[[0.584, 0.592, 0.584, ..., 0.639, 0.616, 0.616],
         [0.569, 0.588, 0.588, ..., 0.62 , 0.631, 0.627],
         [0.576, 0.584, 0.588, ..., 0.635, 0.635, 0.631],
         ...,
         [0.525, 0.51 , 0.51 , ..., 0.6  , 0.588, 0.573],
         [0.529, 0.525, 0.522, ..., 0.565, 0.569, 0.576],
         [0.518, 0.525, 0.533, ..., 0.58 , 0.573, 0.557]],
 
        [[0.451, 0.459, 0.451, ..., 0.467, 0.439, 0.439],
         [0.435, 0.447, 0.447, ..., 0.439, 0.443, 0.439],
         [0.435, 0.443, 0.447, ..., 0.447, 0.447, 0.435],
         ...,
         [0.435, 0.416, 0.416, ..., 0.451, 0.431, 0.416],
         [0.443, 0.431, 0.427, ..., 0.424, 0.427, 0.427],
         [0.427, 0.427, 0.435, ..., 0.431, 0.424, 0.416]],
 
        [[0.314, 0.322, 0.314, ..., 0.314, 0.286, 0.286],
         [0.298, 0.314, 0.314, ..., 0.298, 0.302, 0.298],
         [0.302, 0.31 , 0.314, ..., 0.298, 0.298, 0.29 ],
         ...,
         [0.31 , 0.282, 0.282, ..., 0.31 , 0.294, 0.278],
      

In [13]:
len(datamodule.train_dataset[0])

3

In [10]:
dl = datamodule.train_dataloader()

In [30]:
type(dl)

monai.data.dataloader.DataLoader

In [31]:
dir(dl)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_index_sampler',
 '_is_protocol',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'pin_memory_device',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_init_fn']

In [19]:
dl.batch_size

32

In [11]:
batch = next(iter(dl))

In [18]:
len(batch)

3

In [12]:
type(batch)

dict

In [13]:
batch.keys()

dict_keys(['original_image', 'image', 'mask'])

In [14]:
type(batch['image']), batch['image'].shape

(monai.data.meta_tensor.MetaTensor, torch.Size([32, 3, 64, 64]))

In [18]:
ds = datamodule.train_dataset

In [22]:
type(ds[0]), len(ds[0]), ds[0].keys()

NameError: name 'ds' is not defined

In [15]:
loss_function = monai.losses.DiceLoss(sigmoid=True)
dice_metric = monai.metrics.DiceMetric(include_background=False, reduction="mean")

In [24]:
images, gts = batch['image'], batch['mask']

In [25]:
images = images.to('cuda')
gts = gts.to('cuda')

In [41]:
type(images), type(gts)

(monai.data.meta_tensor.MetaTensor, monai.data.meta_tensor.MetaTensor)

In [26]:
preds = model(images)

In [27]:
loss_function(preds, gts)

tensor(0.7927, device='cuda:0', grad_fn=<AliasBackward0>)

In [54]:
preds_binary = torch.where(preds > 0.5, True, False)

In [55]:
preds_binary.shape, gts.shape

(torch.Size([32, 1, 64, 64]), torch.Size([32, 1, 64, 64]))

In [56]:
dice_metric(preds_binary, gts)

tensor([[0.4921],
        [0.3212],
        [0.1916],
        [0.5040],
        [0.7364],
        [0.7457],
        [0.2917],
        [0.1364],
        [0.4073],
        [0.2401],
        [0.7744],
        [0.4841],
        [0.2787],
        [0.3616],
        [0.6299],
        [0.7564],
        [0.2275],
        [0.4470],
        [0.4311],
        [0.3476],
        [0.0904],
        [0.8150],
        [0.2240],
        [0.4211],
        [0.5540],
        [0.4000],
        [0.4669],
        [0.1304],
        [0.2746],
        [0.5170],
        [0.1955],
        [0.7615]], device='cuda:0')

In [59]:
v = preds_binary[:3]
v.shape

torch.Size([3, 1, 64, 64])

In [60]:
a = preds_binary[0]
a.shape

torch.Size([1, 64, 64])

In [65]:
a0 = torch.where(a==False, True, False)[None, ...]
a1 = torch.where(a==True, True, False)[None, ...]

In [66]:
a0.shape, a1.shape

(torch.Size([1, 1, 64, 64]), torch.Size([1, 1, 64, 64]))

In [75]:
c = torch.cat([a0, a1], dim=0)
c.shape

torch.Size([2, 1, 64, 64])

PROVARE SE A DARGLI UN TENSOR ONE-HOT-ENCODED LA METRICA FUNZIONA BENE

In [76]:
dice_metric(c, c)

tensor([[1.],
        [1.]], device='cuda:0')

NO :(

## DICE

- Dice metric: 2TP/(2TP+FP+FN) AKA 2*intersection/union (in the union, the overlap should be added).
- Dice Loss: 1 - Dice metric

In [159]:
# Dummy example
grnd = torch.zeros(1, 4, 4)
pred = torch.zeros(1, 4, 4)
grnd[..., 1, 1] = grnd[..., 1, 2] = grnd[..., 2, 1] = grnd[..., 2, 2] = 1
pred[..., 1, 1] = pred[..., 1, 2] = pred[..., 2, 0] = pred[..., 2, 1] = 1

print(grnd)
print()
print(pred)

tensor([[[0., 0., 0., 0.],
         [0., 1., 1., 0.],
         [0., 1., 1., 0.],
         [0., 0., 0., 0.]]])

tensor([[[0., 0., 0., 0.],
         [0., 1., 1., 0.],
         [1., 1., 0., 0.],
         [0., 0., 0., 0.]]])


Dice metric = (2 * 3) / (2 * 3 + 1 + 1) = 0.75

In [161]:
# PyTorch Dice
import torchmetrics
dice_metric_torch = torchmetrics.Dice()

dice_metric_torch(pred, torch.tensor(grnd, dtype=torch.int8))

  dice_metric_torch(pred, torch.tensor(grnd, dtype=torch.int8))


tensor(0.7500)

In [162]:
# PyTorch Dice adding the batch dimension
predn = pred[None]
grndn = grnd[None]
print(predn.shape)

dice_metric_torch(predn, torch.tensor(grndn, dtype=torch.int8))

torch.Size([1, 1, 4, 4])


  dice_metric_torch(predn, torch.tensor(grndn, dtype=torch.int8))


tensor(0.7500)

In [163]:
# PyTorch Dice simulating a real batch
batch_pred = torch.cat([predn, predn, predn])
batch_grnd = torch.cat([grndn, grndn, grndn])
print(batch_pred.shape)

dice_metric_torch(batch_pred, torch.tensor(batch_grnd, dtype=torch.int8))

torch.Size([3, 1, 4, 4])


  dice_metric_torch(batch_pred, torch.tensor(batch_grnd, dtype=torch.int8))


tensor(0.7500)

In [175]:
# Monai's
dice_metric_monai = monai.metrics.DiceMetric()
dice_metric_monai(pred, grnd)

tensor([[   nan, 1.0000, 0.5000,    nan]])

In [176]:
# Monai's with batch channel
dice_metric_monai = monai.metrics.DiceMetric()
dice_metric_monai(predn, grndn)

tensor([[0.7500]])

In [194]:
# Monai's with proper batch
dice_metric_monai(batch_pred, batch_grnd)

tensor([[0.7500],
        [0.7500],
        [0.7500]])

It returns a freaking vector!

In [190]:
# Mine
class RetinaDiceMetric():
    def __init__(self, threshold=0.5):
        self.threshold = threshold
    
    def __call__(self, preds, targets):
        preds = preds.clone().detach()
        targets = targets.clone().detach()
        preds, targets = self.binarize(preds), self.binarize(targets)
        n = len(preds)
        dices = []
        for i in range(n):
            img, mask = preds[i], targets[i]
            intersection = torch.sum(img * mask)
            union = img.sum() + mask.sum() + 1e-12
            dice = intersection / union
            dices.append(dice.unsqueeze(0))
        batch_dice = 2*torch.mean(torch.cat(dices))
        return batch_dice

    def binarize (self, a):
        return torch.where(a > self.threshold, True, False)

In [191]:
dice_metric_retina = RetinaDiceMetric()

dice_metric_retina(pred, grnd)

tensor(0.7500)

In [179]:
dice_metric_retina(predn, grndn)

tensor(0.7500)

In [181]:
datamodule.prepare_data()
datamodule.setup(stage='fit')
dl = datamodule.train_dataloader()
batch = next(iter(dl))
images, gts = batch['image'].to('cuda'), batch['mask'].to('cuda')
preds = model(images)

In [188]:
# Pytorch With retina predictions
dice_metric_torch(
    preds.detach().cpu(),
    torch.tensor(gts, dtype=torch.int8).detach().cpu()
)

# tensor(0.1781)

  torch.tensor(gts, dtype=torch.int8).detach().cpu()


tensor(0.1781)

In [192]:
# Mine With retina predictions
dice_metric_retina(
    preds.detach().cpu(),
    torch.tensor(gts, dtype=torch.int8).detach().cpu()
)

# tensor(0.1610)
# A bit different from PyTorch's implementation

  torch.tensor(gts, dtype=torch.int8).detach().cpu()


tensor(0.1610)

In [195]:
# Monai's with retina predictions
result = dice_metric_monai(preds, gts)
result.shape



torch.Size([32, 1])

In [201]:
# Monai dice loss (not metric!) with and without sigmoid
dice_loss_monai = monai.losses.DiceLoss(sigmoid=False)
print(dice_loss_monai(batch_pred, batch_grnd))

dice_loss_monai = monai.losses.DiceLoss(sigmoid=True)
print(dice_loss_monai(batch_pred, batch_grnd))

tensor(0.2500)
tensor(0.5832)


In [203]:
preds.max()

tensor(0.9969, device='cuda:0', grad_fn=<AliasBackward0>)

## CALLBACKS

In [8]:
cfg['callbacks']

{'early_stopping': {'_target_': 'pytorch_lightning.callbacks.EarlyStopping', 'monitor': 'val_loss', 'mode': 'min', 'min_delta': 0.001, 'patient': 10}, 'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'dirpath': 'model_checkpoints', 'filename': 'checkpoint', 'monitor': 'val_loss', 'mode': 'min', 'save_top_k': 1}}

In [25]:
from src.utils import instantiate_callbacks

In [26]:
l = instantiate_callbacks(cfg['callbacks'])
l

[<pytorch_lightning.callbacks.early_stopping.EarlyStopping at 0x7f5185a84d90>,
 <pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint at 0x7f5185b20c10>]