# CAMP II Surgical Workflow Analysis Exercise

It is highly recommended to complete the CAMP II Exercises on Classification before this one!  
Make sure to activate the GPU in colab before you start this exercise.

In this exercise, we will segment the liver, using as ground truth annotation provided in the datase. 

This notebook has many codeblocks already in place to help you get started. Places where you have to add your own code are clearly marked with "TASK" and lines ("-----"). When a variable you have to implement is used later on, we placed a name and description in the task bracket (see example below). These markings are only there to guide you toward what you have to implement to complete the exercise, feel free to experiment beyond them.

---




In [1]:
# TASK: description of the task you need to do ---------------------------------
# my_variable_name: a variable that is used later on, so the name should be right

# ------------------------------------------------------------------------------

Install prerequisites

In [None]:
!pip -qq install pytorch_lightning==1.6.2
!pip -qq install -U segmentation-models-pytorch
!pip -qq install opencv-python

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
scvi-tools 0.16.0 requires pytorch-lightning<1.6,>=1.5, but you have pytorch-lightning 1.6.2 which is incompatible.[0m


Download the dataset

In [None]:
# If you get an error here about access being denied - just try again until it works
!gdown --id 1MHp64mCt2m8NxCW3-4kjD39m3Rry0ekA
!unzip -qq liver_endoscopy_dataset
!rm liver_endoscopy_dataset.zip

Install packages

In [2]:
from collections import defaultdict
import cv2
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image
import pytorch_lightning as pl
import random
from re import T
import segmentation_models_pytorch as pytorch_models
from skimage.color import label2rgb
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
from typing import Optional

%load_ext tensorboard

## 0. Dataset
The LiverEndoscopy class processes the dataset so that it can be used for our segementation task

In [3]:
video_splits = {'train': ['01', '09', '17', '20', '25', '27', '28', '35', '37', '43', '55'], 'val': ['12', '24', '26'], 'test': ['18', '48', '52']}


class LiverEndoscopy(Dataset):
    def __init__(self,
                 task_type: str = 'classification', split: str = 'train', balance_data: bool = False, temporal: bool = False,
                 pil_transform: Optional[transforms.Compose] = None, tensor_transform: Optional[transforms.Compose] = None):
        assert split in ['train', 'val', 'test']
        self.split = split
        self.balance_data = balance_data
        self.task_type = task_type
        self.temporal = temporal
        self.pil_transform = pil_transform
        self.tensor_transform = tensor_transform

        export_dataset_path = Path('data')
        self.images_path = export_dataset_path / 'images'
        self.seg_masks_path = export_dataset_path / 'seg_masks'
        with open(export_dataset_path / 'classification_annotations.json', 'r') as f:
            self.classification_annotations = json.load(f)
        with open(export_dataset_path / 'phase_annotations.json', 'r') as f:
            self.workflow_phase_annotations = json.load(f)
        with open(export_dataset_path / 'has_liver.json', 'r') as f:
            self.has_liver = json.load(f)

        if task_type == 'classification' or task_type == 'segmentation' or (task_type == 'workflow' and not temporal):
            self.image_names = []
            for image_path in sorted(self.images_path.glob('*.png')):
                video_id = image_path.name.split('_')[0].replace('video', '')
                if video_id in video_splits[split]:
                    self.image_names.append(image_path.name.replace('.png', ''))
            self.image_names = sorted(self.image_names)

        if balance_data:
            self.do_balance_data(task_type, temporal)

    def do_balance_data(self, task_type, temporal):
        print('Balancing data by oversampling under-represented classes...')
        class_to_samples = defaultdict(list)
        if not temporal:
            for image_name in self.image_names:
                if task_type == 'segmentation':
                    label = self.has_liver[image_name]
                class_to_samples[label].append(image_name)
            max_number = max([len(elem) for elem in class_to_samples.values()])
            self.image_names = []
            for key, value in class_to_samples.items():
                if len(value) < max_number:
                    self.image_names += random.choices(value, k=max_number)
                else:
                    self.image_names += value
            random.shuffle(self.image_names)
        else:
            for video_id, window in self.windows:
                label = self.workflow_phase_annotations[f'video{video_id}_{str(window[-1]).zfill(6)}']
                class_to_samples[label].append((video_id, window))
            max_number = max([len(elem) for elem in class_to_samples.values()])
            self.windows = []
            for key, value in class_to_samples.items():
                if len(value) < max_number:
                    self.windows += random.choices(value, k=max_number)
                else:
                    self.windows += value
            random.shuffle(self.windows)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        image_name = self.image_names[index]
        image_path = self.images_path / f'{image_name}.png'
        seg_mask_path = self.seg_masks_path / f'{image_name}_liver_mask.png'
        image = Image.open(image_path)
        if self.pil_transform is not None:
            image = self.pil_transform(image)
        image_tensor = transforms.ToTensor()(image)
        if self.tensor_transform is not None:
            image_tensor = self.tensor_transform(image_tensor)
        seg_mask = Image.open(seg_mask_path)
        if self.pil_transform is not None:
            seg_mask = self.pil_transform(seg_mask)
        seg_mask_tensor = transforms.ToTensor()(seg_mask)[0].float()
        if self.tensor_transform is not None:
            seg_mask_tensor = self.tensor_transform(seg_mask_tensor)
        intrument_exists = int(self.classification_annotations[image_name])

        if self.task_type == 'segmentation':
            return image_tensor, seg_mask_tensor

## A. Segmentation of the liver in cholec80 dataset




In [4]:
from torchgeometry.losses import DiceLoss
from scipy.spatial import distance

def l1_loss(ground_truth, prediction):
    # TASK: compute the L1 loss ---------------------------------------------------------------
    l1 = torch.nn.L1Loss()
    loss = l1(ground_truth, prediction)
    # -----------------------------------------------------------------------------------------
    return loss

def dice_loss(ground_truth, prediction):
    # TASK: compute the dice loss --------------------------------------------------------------
    dice = DiceLoss()
    loss = dice(ground_truth, prediction)
    # ------------------------------------------------------------------------------------------
    return loss

def dice_score(ground_truth, prediction):
  # TASK: compute the dice score -------------------------------------------------------------
  score = distance.dice(ground_truth, prediction)
  # -----------------------------------------------------------------------------------------
  return score

### A.1 Network
Here we define the network, including it's behaviour during training and validation phase

In [5]:
class LiverSegmentation(pl.LightningModule):
    def __init__(self, model,lr,loss):
        super().__init__()
        self.backbone = model
        self.lr = lr
        self.loss = loss
        self.metric = dice_score
        self.writer = SummaryWriter()

    def forward(self, x):
        y = self.backbone(x)
        return y

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        y = y.unsqueeze(dim=1)
        # TASK: perform the forward pass, compute the loss and the metric of each step --------------
        y_pred = self.forward(x)
        loss = self.loss(y_pred, y)
        metric = self.metric(y, y_pred)
        # -------------------------------------------------------------------------------------------
        return {"loss": loss, "metric": metric}

    def training_epoch_end(self,output):
        loss = 0
        metric = 0
        for o in output:
            # TASK: compute the loss and metric of the epoch -------------------------------------------
            loss = loss + o["loss"]
            metric = metric + o["metric"]
            # ------------------------------------------------------------------------------------------
        loss = loss / len(output)
        metric = metric / len(output)
        self.writer.add_scalar('Epoch_loss/training', loss, self.current_epoch)
        self.writer.add_scalar('Epoch_metric/training', metric, self.current_epoch)

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        y = y.unsqueeze(dim=1)
        # TASK: perform the forward pass, compute the loss and the metric of each step --------------
        y_pred = self.forward(x)
        loss = self.loss(y_pred, y)
        metric = self.metric(y, y_pred)
        # -------------------------------------------------------------------------------------------
        y_pred_plot = np.array(y_pred.cpu(), dtype=float)
        image = self.prepare_visualization(y[0, 0, :, :].cpu().numpy(), y_pred_plot[0, 0, :, :], x[0, 0, :, :].cpu().numpy())
        fig, ax = plt.subplots(nrows=1, ncols=4)
        ax[0].imshow(x[0, 0, :, :].cpu(),cmap="gray")
        ax[0].set_title('Image')
        ax[1].imshow(y[0, 0, :, :].cpu().numpy().astype("uint8"), cmap="gray")
        ax[1].set_title('GT Segm')
        ax[2].imshow(y_pred_plot[0, 0, :, :], cmap="gray")
        ax[2].set_title('Pred Segm')
        ax[3].imshow(image)
        ax[3].set_title('Overlay')
        self.writer.add_figure("Validation/"+str(batch_idx), fig, self.current_epoch)
        plt.close()
        return {"loss": loss, "metric": metric}

    def prepare_visualization(self,y,y_pred,image):
        annotation_pred = (y_pred>0.5).astype("uint8") # It will evaluate the logical expression y_predict>0.25 and return True or False 
        annotation_pred = np.uint8(annotation_pred)
        annotation_gt= np.uint8(y)

        overlay = np.copy(image) 
        image_label_overlay = label2rgb(annotation_pred, image=overlay, bg_label=0, alpha=0.5, colors=["red"])    

        image = cv2.cvtColor(image,cv2.COLOR_GRAY2BGR)
        redImg = np.zeros(image.shape, image.dtype)
        redImg[:,:] = (0, 255, 0)
        redMask = cv2.bitwise_and(redImg, redImg, mask=annotation_gt)
        image_mask1 = np.float32(image_label_overlay)
        image = cv2.addWeighted(redMask, 0.05, image_mask1, 0.95,0.0)
        return image


    def validation_epoch_end(self,output):
        loss = 0
        metric = 0
        for o in output:
            loss = loss + o["loss"]
            metric = metric + o["metric"]

          # TASK: compute the loss and metric of the epoch -------------------------------------------
          # metric
          # loss
          # ------------------------------------------------------------------------------------------
        loss = loss / len(output)
        metric = metric / len(output)
        self.log('val_dice', metric)
        self.writer.add_scalar('Epoch_loss/validation', loss, self.current_epoch)
        self.writer.add_scalar('Epoch_metric/validation', metric, self.current_epoch)

### A.2 Models
Here we define 4 different NN models. We will compare them and find the one that is best performing for our task.

In [6]:
pl.seed_everything(42) # Fix the seed

unet3 = pytorch_models.Unet(
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
    encoder_depth=3,                # Amount of down- and upsampling of the Unet
    decoder_channels=(64, 32,16),   # Amount of channels
    encoder_weights = None,         # Model does not download pretrained weights
    activation = 'sigmoid'            # Activation function to apply after final convolution       
    )


unet5 = pytorch_models.Unet(
    in_channels=3,                  
    classes=1,                      
    encoder_depth=5,                
    decoder_channels=(256, 128, 64, 32, 16), 
    encoder_weights = None,
    activation = 'sigmoid'                    
    )


unet7 = pytorch_models.Unet(
    in_channels=3,                  
    classes=1,                      
    encoder_depth=7,                
    decoder_channels=(1024, 512, 256, 128, 64, 32,16), 
    encoder_weights = None,
    activation = 'sigmoid'                
    )


resnet34 = pytorch_models.Unet(
    in_channels = 3,
    classes=1,
    encoder_name='resnet34', 
    encoder_depth=5, encoder_weights=None,
    decoder_channels=(256, 128, 64, 32, 16),
    activation = 'sigmoid'              
    )


Global seed set to 42


### A.3 Metrics & Losses
Here we code the L1-loss the dice loss and the dice score functions. After that we test their implementation.

In [7]:
def l1_loss(ground_truth, prediction):
    # TASK: compute the L1 loss ---------------------------------------------------------------
    l1 = torch.nn.L1Loss()
    loss = l1(ground_truth, prediction)
    # -----------------------------------------------------------------------------------------
    return loss

def dice_loss(ground_truth, prediction):
    smooth = 1.

    iflat = prediction.view(-1)
    tflat = ground_truth.view(-1)
    intersection = (iflat * tflat).sum()
    
    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

def dice_score(ground_truth, prediction):
    return -1 * dice_loss(ground_truth, prediction) + 1
  # -----------------------------------------------------------------------------------------


Here we test the implementation of the losses and the metric

In [9]:
pil_transform = transforms.Compose([transforms.Resize((224, 224))])
train_dataset = LiverEndoscopy(split='train', balance_data=True, task_type='segmentation', temporal=False, pil_transform=pil_transform)
x_0, y_0 = train_dataset[0]
y_0 = y_0.unsqueeze(dim=1)

x_1, y_1 = train_dataset[0]
y_1 = y_1.unsqueeze(dim=1)

print("Loss test 1", l1_loss(y_0, y_0), dice_loss(y_0, y_0))
print("Loss test 2", l1_loss(y_0, y_1), dice_loss(y_0, y_1))
print("Metric test 1", dice_loss(y_0, y_0), "; Metric test 2", dice_loss(y_0, y_1))

Balancing data by oversampling under-represented classes...
Loss test 1 tensor(0.) tensor(0.)
Loss test 2 tensor(0.) tensor(0.)
Metric test 1 tensor(0.) ; Metric test 2 tensor(0.)


### A.4 Training, finally!


Run this cell if you want to remove all the saved tensorboard logs (`runs`) and/or the checkpoints (`lightning_logs`)

In [10]:
!rm -rf runs
!rm -rf lightning_logs

In this cell we will test different combinations of models, learning rates and losses. You will modify the line of code `net = LiverSegmentation(unet3, lr=0.002, loss=dice_loss)` by including different models (see above), learning rates in [0.1, 1e-5] and L1 or dice loss. You will also need to change `max_epochs` and evaluate its impact on the performance of the network.

In [12]:
pl.seed_everything(42) # Fix the seed

# Split the dataset in train and validation
pil_transform = transforms.Compose([transforms.Resize((224, 224))])
train_dataset = LiverEndoscopy(split='train', balance_data=True, task_type='segmentation', temporal=False, pil_transform=pil_transform)
val_dataset = LiverEndoscopy(split='val', balance_data=True, task_type='segmentation', temporal=False, pil_transform=pil_transform)

# Create the dataloader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

# Define the network
net = LiverSegmentation(unet3, lr=0.002, loss=dice_loss)

# Define how and when to save your models during training
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_dice',
    save_top_k=1,
    mode='max',
    every_n_epochs=1,
    save_last=True
)

trainer = pl.Trainer(gpus=1, 
                    precision=16, 
                    callbacks=checkpoint_callback,
                    check_val_every_n_epoch=1,
                    log_every_n_steps=5,
                     max_epochs=10
                     )

# Train!
trainer.fit(net, train_loader, val_loader)

Global seed set to 42


Balancing data by oversampling under-represented classes...
Balancing data by oversampling under-represented classes...


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type | Params
----------------------------------
0 | backbone | Unet | 21.5 M
----------------------------------
21.5 M    Trainable params
0         Non-trainable params
21.5 M    Total params
42.971    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Validation: 0it [00:00, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

In [13]:
%tensorboard --logdir runs

Reusing TensorBoard on port 6006 (pid 42124), started 0:19:27 ago. (Use '!kill 42124' to kill it.)

In [None]:
pl.seed_everything(42) # Fix the seed

# Split the dataset in train and validation
pil_transform = transforms.Compose([transforms.Resize((224, 224))])
train_dataset = LiverEndoscopy(split='train', balance_data=True, task_type='segmentation', temporal=False, pil_transform=pil_transform)
val_dataset = LiverEndoscopy(split='val', balance_data=True, task_type='segmentation', temporal=False, pil_transform=pil_transform)

# Create the dataloader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

# Define the network
net = LiverSegmentation(unet5, lr=0.002, loss=dice_loss)

# Define how and when to save your models during training
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_dice',
    save_top_k=1,
    mode='max',
    every_n_epochs=1,
    save_last=True
)

trainer = pl.Trainer(gpus=1, 
                    precision=16, 
                    callbacks=checkpoint_callback,
                    check_val_every_n_epoch=1,
                    log_every_n_steps=5,
                     max_epochs=5
                     )

# Train!
trainer.fit(net, train_loader, val_loader)