In [None]:
# Environment setup: install the third-party packages this notebook imports.
# NOTE(review): none of these installs pins a version (albumentations is installed
# straight from its git repository), so a later re-run may pull different releases.
!pip install segmentation_models_pytorch
!pip install git+https://github.com/albumentations-team/albumentations
!pip install wandb

In [None]:
import pickle
import gzip
import numpy as np
import os
import matplotlib.pyplot as plt
import time
import random

import cv2
import albumentations as A
from albumentations.augmentations.geometric.transforms import ElasticTransform, Affine, GridDistortion, PadIfNeeded
from albumentations.augmentations.crops.transforms import CenterCrop
from albumentations.pytorch import ToTensorV2

import torchvision.transforms as T
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from segmentation_models_pytorch.losses import DiceLoss, JaccardLoss
import segmentation_models_pytorch as smp

from google.colab import drive
# Mount Google Drive at /content/drive so the data files become readable.
drive.mount('/content/drive')
# Directory holding train.pkl / test.pkl and receiving the prediction files.
# The path is relative, so it resolves against the current working directory
# (presumably /content on Colab — TODO confirm).
PATH = "drive/My Drive/aml_task3/data"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
if USE_WANDB:
  !wandb login
  import wandb
  wandb.init(project="heart_valve_segmentation", name=PRED_FILE_NAME)

In [1]:
# PARAMETERS --------------
USE_WANDB = True
PRED_FILE_NAME = 'prediction'

# Preprocessing
NBR_OVERSAMPLINGS = 1
PAD_EXPERT_IMAGES = False
RESIZE_SIZE = 448
BATCH_SIZE = 16
PROBABILITY_AUGMENTATION = 0.6

# Model
ENCODER_NAME = 'resnet18'
ENCODER_WEIGHTS = None
ENCODER_DEPTH = 5
DECODER_ATTENTION_TYPE = None
MODEL_TYPE = 'Unet++'

In [None]:
# helper functions --------------
def load_zipped_pickle(filename):
    """Open the gzip-compressed file ``filename`` and return the object unpickled from it."""
    with gzip.open(filename, mode='rb') as compressed:
        return pickle.load(compressed)

  
def save_zipped_pickle(obj, filename):
    """Pickle ``obj`` with protocol 2 and write it gzip-compressed to ``filename``."""
    with gzip.open(filename, mode='wb') as archive:
        pickle.dump(obj, archive, protocol=2)


def show_side_by_side(video, mask):
    """Draw ``video`` (left) and ``mask`` (right) next to each other in a new figure."""
    canvas = plt.figure()
    for position, picture in enumerate((video, mask), start=1):
        # add_subplot makes the new axes current, so plt.imshow draws into it
        canvas.add_subplot(1, 2, position)
        plt.imshow(picture)


def show_on_top(image1, image2, alpha1=0.5, alpha2=0.5):
    """Overlay ``image2`` on ``image1`` on the current axes, each drawn with its own alpha."""
    for layer, opacity in ((image1, alpha1), (image2, alpha2)):
        plt.imshow(layer, alpha=opacity)

In [None]:
class HeartData(Dataset):
    """
    Custom pytorch dataset over the labelled frames of the heart videos.

    Parameters
    ----------
    file:
        Path of the gzip-compressed pickle with the video samples. Each sample is
        read through its "video", "label", "frames" and "dataset" entries.
    transform:
        Optional callable invoked as ``transform(image=..., mask=...)`` that returns
        a dict with "image" and "mask" entries (e.g. an albumentations ``Compose``).
    pad_expert_images:
        If True, frames of samples whose "dataset" entry is "expert" are zero-padded
        to at least EXPERT_PAD_SIZE x EXPERT_PAD_SIZE, so a later square resize
        keeps their aspect ratio.
    nbr_oversamplings:
        Number of *extra* copies stored for every expert frame (0 disables it).
    idx:
        Iterable of video indices to read; useful for a train/validation split.
        ``None`` (the default) reads every video in the file.
    """

    # Side length (pixels) expert frames are padded up to when padding is enabled.
    EXPERT_PAD_SIZE = 863

    def __init__(self, file, transform=None, pad_expert_images=False, nbr_oversamplings=0, idx=None):
        # load data into memory
        data = load_zipped_pickle(file)
        # A set gives O(1) membership tests; None selects every video instead of
        # relying on a hard-coded video count.
        wanted = set(range(len(data))) if idx is None else set(idx)
        # Only build the padder when it is actually going to be used.
        padder = None
        if pad_expert_images:
            padder = PadIfNeeded(
                min_height=self.EXPERT_PAD_SIZE,
                min_width=self.EXPERT_PAD_SIZE,
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
            )
        # Extra copies per expert frame; negative or falsy values mean "none".
        extra_copies = max(int(nbr_oversamplings), 0)

        # extract labelled frames and masks
        self.images = []
        self.masks = []
        for video_index in range(len(data)):
            if video_index not in wanted:
                continue
            sample = data[video_index]
            is_expert = sample["dataset"] == "expert"
            copies = 1 + (extra_copies if is_expert else 0)
            for frame in sample["frames"]:
                video_frame = sample["video"][:, :, frame]
                mask_frame = sample["label"][:, :, frame]
                if padder is not None and is_expert:
                    padded = padder(image=video_frame, mask=mask_frame.astype(int))
                    video_frame = padded['image']
                    mask_frame = padded['mask']
                # One entry for the frame itself plus the oversampled duplicates.
                self.images.extend([video_frame] * copies)
                self.masks.extend([mask_frame] * copies)
        # store the transformation
        self.transform = transform

    def __len__(self):
        """Number of stored (frame, mask) pairs, oversampled duplicates included."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the (image, mask) pair at ``idx`` as tensors on the active device."""
        image = self.images[idx]
        mask = self.masks[idx]
        if self.transform:
            transformed = self.transform(image=image, mask=mask.astype(int))
            image = transformed['image']
            mask = transformed['mask']
        # Without a transform (or with one that does not end in a tensor conversion)
        # the items are still numpy arrays, which have no ``.to`` method.
        if not torch.is_tensor(image):
            image = torch.as_tensor(image)
        if not torch.is_tensor(mask):
            mask = torch.as_tensor(mask)
        # NOTE: moving to the device here means consumers receive ready-to-use
        # tensors and do not need their own ``.to(device)`` call.
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        return image.to(device), mask.to(device)

In [None]:
# transformation definition ----------------------------------------------
def _downstream_steps():
    """Build fresh instances of the steps shared by the training and validation pipelines."""
    return [
        # contrast (improving on histogram equalization)
        A.CLAHE(p=1),
        # downstream compatibility: square resize, 3 channels, normalisation, tensor
        A.Resize(RESIZE_SIZE, RESIZE_SIZE),
        A.ToRGB(),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]


# Geometric augmentation: one of the two candidates below is chosen whenever the
# OneOf block fires (probability PROBABILITY_AUGMENTATION).
_elastic_step = A.ElasticTransform(alpha=15, sigma=8, alpha_affine=5, p=0.5)
_affine_step = Affine(
    scale=(0.95, 1.05),  # zoom in/out
    translate_percent={'x': (-0.05, 0.03), 'y': (-0.03, 0.1)},  # x shift in [-5%, +3%], y shift in [-3%, +10%]
    rotate=(-10, 10),
    shear=(-10, 10),  # change in perspective
    keep_ratio=True,
    p=0.5,
)

train_transform = A.Compose([
    A.OneOf([_elastic_step, _affine_step], p=PROBABILITY_AUGMENTATION),
    GridDistortion(num_steps=5, distort_limit=0.2, p=PROBABILITY_AUGMENTATION),
    *_downstream_steps(),
])

# Validation uses the deterministic preprocessing only, without augmentation.
val_transform = A.Compose(_downstream_steps())

In [None]:
# create train/val split
# NUM_VIDEOS is the number of videos assumed to be in train.pkl; NUM_TRAIN_VIDEOS of
# them go to training and the rest to validation.
NUM_VIDEOS = 65
NUM_TRAIN_VIDEOS = 52
# A dedicated, seeded generator makes the split identical on every run without
# touching the global `random` state used elsewhere.
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)
train_idx = split_rng.sample(range(NUM_VIDEOS), NUM_TRAIN_VIDEOS)
val_idx = sorted(set(range(NUM_VIDEOS)) - set(train_idx))

# instantiate data sets
train_dataset = HeartData(
    "{}/train.pkl".format(PATH),
    train_transform,
    pad_expert_images=PAD_EXPERT_IMAGES,
    nbr_oversamplings=NBR_OVERSAMPLINGS,
    idx=train_idx
)

val_dataset = HeartData(
    "{}/train.pkl".format(PATH),
    val_transform,
    pad_expert_images=PAD_EXPERT_IMAGES,
    nbr_oversamplings=0, # don't oversample in the validation set!
    idx=val_idx
)

# instantiate loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# No shuffling for validation: the loss is computed per batch, so a fixed batch
# composition keeps the validation loss comparable from epoch to epoch.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# debug augmentation ----------------------------------------------
# Draw one batch from the training loader and display the first channel of its
# first image next to the matching mask.
image, mask = next(iter(train_loader))
show_side_by_side(image[0, 0].cpu(), mask[0].cpu())

## Training Loop

In [None]:
def evaluate(model, val_loader, loss_fn):
    """
    Compute the sample-weighted average loss of ``model`` over ``val_loader``.

    Parameters
    ----------
    model: callable mapping a batch of images to predictions. When it is an
        ``nn.Module`` in training mode, it is switched to eval mode for the
        duration of the call and put back into training mode afterwards.
    val_loader: iterable yielding ``(images, masks)`` batches.
    loss_fn: callable invoked as ``loss_fn(y_pred=..., y_true=...)`` returning a
        scalar tensor.

    Returns
    -------
    float: sum of (batch loss * batch size) divided by the number of samples, or
    ``nan`` when the loader yields no samples.
    """
    # Evaluating in training mode would keep dropout active and let batch-norm
    # update its running statistics, so force eval mode and restore it on exit.
    restore_training = isinstance(model, nn.Module) and model.training
    if restore_training:
        model.eval()

    val_loss_cum = 0.0
    num_samples = 0
    try:
        with torch.no_grad():
            for images, masks in val_loader:
                # forward pass
                outputs = model(images)
                loss = loss_fn(y_pred=outputs, y_true=masks)
                # weight each batch loss by its size so a smaller last batch
                # does not skew the average
                batch_size = images.shape[0]
                num_samples += batch_size
                val_loss_cum += loss.item() * batch_size
    finally:
        if restore_training:
            model.train()

    if num_samples == 0:
        # An empty loader has no defined average; avoid a ZeroDivisionError.
        return float("nan")
    return val_loss_cum / num_samples

In [None]:
# NN
# Map every supported MODEL_TYPE to its segmentation_models_pytorch class so both
# variants are built with exactly the same arguments.
MODEL_BUILDERS = {
    'Unet': smp.Unet,
    'Unet++': smp.UnetPlusPlus,
}
if MODEL_TYPE not in MODEL_BUILDERS:
    # Fail loudly instead of leaving `model` undefined (or silently reusing a
    # model left over from an earlier run of this cell).
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(MODEL_BUILDERS)}"
    )

model = MODEL_BUILDERS[MODEL_TYPE](
    encoder_name=ENCODER_NAME,
    encoder_depth=ENCODER_DEPTH, # between 3-5 (default=5)
    encoder_weights=ENCODER_WEIGHTS,
    decoder_attention_type=DECODER_ATTENTION_TYPE,
    activation=None, # we are interested in the logits
    classes=1
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Loss and Optimizer
# from_logits=True matches activation=None above: the model emits raw logits.
loss_fn = JaccardLoss(mode='binary', from_logits=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Running epoch counter printed and incremented by the training loop below.
# It lives in its own cell, presumably so that re-running the training cell
# continues the count instead of resetting it — TODO confirm.
globaliter = 0

In [None]:
nbr_epochs = 5

for epoch in range(nbr_epochs):
    # ---- training phase -------------------------------------------------
    model.train()
    train_loss_cum, num_samples_epoch = 0.0, 0
    epoch_start = time.time()

    print('batch_idx: ', end='')
    for batch_idx, (images, masks) in enumerate(train_loader):
        # forward pass
        loss = loss_fn(y_pred=model(images), y_true=masks)

        # backward pass, parameter update, then clear gradients for the next batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # accumulate the batch loss weighted by the batch size
        batch_size = images.shape[0]
        num_samples_epoch += batch_size
        train_loss_cum += batch_size * loss.item()
        print(batch_idx, ' ', end='')

    # sample-weighted mean of the batch losses seen in this epoch
    avg_train_loss = train_loss_cum / num_samples_epoch

    # ---- validation phase -----------------------------------------------
    model.eval()
    validation_loss = evaluate(model, val_loader, loss_fn)
    epoch_duration = time.time() - epoch_start

    if USE_WANDB:
        wandb.log({"train loss": avg_train_loss, "val_loss": validation_loss})

    # ---- epoch summary --------------------------------------------------
    print()
    summary = (
        f'Epoch: {globaliter} |'
        f'Train loss: {avg_train_loss:.4f} |'
        f'Validation loss: {validation_loss:.4f}|'
        f'Duration: {epoch_duration:.2f} sec'
    )
    print(summary)

    globaliter += 1

In [None]:
# Close the wandb run once training is over (only when tracking is enabled).
if USE_WANDB:
  wandb.finish()

# Prediction

In [None]:
def save_single_model_pred(model, filename, data_test, processing_transform, pad_expert_images=False):
    """
    Predict a boolean mask for every frame of every test video and save the result.

    Each frame is optionally padded, run through ``processing_transform`` and the
    model, and the resulting logits are resized back to the frame's original shape
    (resize to 1007x1007 followed by a centre crop when padding is enabled) before
    being thresholded at 0. The list of ``{"name", "prediction"}`` dicts is written
    to ``{PATH}/{filename}.pkl``; nothing is returned.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if pad_expert_images:
        padder = PadIfNeeded(min_height=1007, min_width=1007, border_mode=cv2.BORDER_CONSTANT, value=0)

    submission = []
    with torch.no_grad():
        for item in data_test:
            name = item["name"]
            video = item["video"]
            height, width = video[:, :, 0].shape
            # Steps that map the network output back onto the original frame geometry.
            if pad_expert_images:
                restore_steps = [A.Resize(1007, 1007), CenterCrop(height, width)]
            else:
                restore_steps = [A.Resize(height, width)]

            frame_masks = []
            for frame_no in range(video.shape[2]):
                frame = video[:, :, frame_no]
                if pad_expert_images:
                    frame = padder(image=frame)['image']
                # add the batch dimension and move the input next to the model
                net_input = processing_transform(image=frame)['image'][None, :].to(device)
                logits = model(net_input).cpu().detach().numpy().squeeze()
                for step in restore_steps:
                    logits = step(image=logits)["image"]
                # positive logit <=> foreground
                frame_masks.append(logits > 0)

            stacked = np.stack(frame_masks, axis=2)
            submission.append({"name": name, "prediction": stacked})
    save_zipped_pickle(submission, f"{PATH}/{filename}.pkl")


def save_pred_for_ensembling(model, filename, data_test, processing_transform, pad_expert_images=False):
    """
    Apply all the necessary preprocessing steps to a test set and save
    the prediction in a format suitable for ensembling.

    For every video a boolean tensor of shape (frames, 1, height, width) is built,
    where height/width are those of the original frames. The list of these tensors
    (one per item of ``data_test``, in order) is written to
    ``{PATH}/{filename}.pkl``; nothing is returned.

    Parameters
    ----------
    model: network producing one logit map per input image.
    filename: base name (without extension) of the output file.
    data_test: iterable of samples exposing a "video" array of shape (H, W, frames).
    processing_transform: callable invoked as ``processing_transform(image=...)``
        returning a dict whose "image" entry is the network input tensor.
    pad_expert_images: if True, frames are zero-padded to at least
        ``pad_size`` x ``pad_size`` before preprocessing and the prediction is
        un-padded afterwards.
    """
    # The size used for padding MUST be the size the prediction is resized back to
    # before the centre crop; otherwise the cropped mask is at the wrong scale.
    # 1007 matches save_single_model_pred.
    pad_size = 1007
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    padder = None
    if pad_expert_images:
        padder = PadIfNeeded(min_height=pad_size, min_width=pad_size, border_mode=cv2.BORDER_CONSTANT, value=0)

    submission = []
    for item in data_test:
        video = item["video"]
        height, width = video[:, :, 0].shape
        if pad_expert_images:
            # PadIfNeeded only grows dimensions below the minimum, so the padded
            # frame measures max(dim, pad_size) along each axis.
            padded_height = max(height, pad_size)
            padded_width = max(width, pad_size)
            restore_steps = [A.Resize(padded_height, padded_width), CenterCrop(height, width)]
        else:
            restore_steps = [A.Resize(height, width)]

        predictions = torch.empty((video.shape[2], 1, height, width), dtype=torch.bool)
        for frame_no in range(video.shape[2]):
            frame = video[:, :, frame_no]
            if padder is not None:
                frame = padder(image=frame)['image']
            # add the batch dimension and move the input next to the model
            net_input = processing_transform(image=frame)['image'][None, :].to(device)
            with torch.no_grad():
                logits = model(net_input)
            logits = logits.cpu().detach().numpy().squeeze()
            # map the logits back onto the original frame geometry
            for step in restore_steps:
                logits = step(image=logits)["image"]
            # add the channel dimension; positive logit <=> foreground
            predictions[frame_no] = torch.tensor(logits[None, :]) > 0
        submission.append(predictions)
    save_zipped_pickle(submission, f"{PATH}/{filename}.pkl")

In [None]:
data_test = load_zipped_pickle(f"{PATH}/test.pkl")

# Test-time preprocessing: contrast enhancement, square resize, 3 channels,
# normalisation and tensor conversion (no augmentation).
_inference_steps = [
    A.CLAHE(p=1),
    A.Resize(RESIZE_SIZE, RESIZE_SIZE),
    A.ToRGB(),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
]
processing_transform = A.Compose(_inference_steps)

# Write both output flavours with the same model and settings: first the
# per-video dict format, then the tensor format used for ensembling.
_prediction_jobs = (
    (save_single_model_pred, PRED_FILE_NAME + '_individual'),
    (save_pred_for_ensembling, PRED_FILE_NAME),
)
for _saver, _output_name in _prediction_jobs:
    _saver(
        model=model,
        filename=_output_name,
        data_test=data_test,
        processing_transform=processing_transform,
        pad_expert_images=PAD_EXPERT_IMAGES,
    )