In [1]:

# imports
import os
import sys
import types
import json

# Kernel setup cell: configures plotting defaults, reports which source files
# the kernel has loaded, switches the working directory and finally wipes the
# interactive namespace.

# figure size/format
fig_width = 7
fig_height = 5
fig_format = 'retina'
fig_dpi = 96

# matplotlib defaults / format
# Best effort: any failure in this block (e.g. matplotlib or IPython not
# importable) is swallowed and the defaults are simply left untouched.
try:
  import matplotlib.pyplot as plt
  plt.rcParams['figure.figsize'] = (fig_width, fig_height)
  plt.rcParams['figure.dpi'] = fig_dpi
  plt.rcParams['savefig.dpi'] = fig_dpi
  from IPython.display import set_matplotlib_formats
  set_matplotlib_formats(fig_format)
except Exception:
  pass

# plotly use connected mode
# Best effort as well: silently skipped when plotly cannot be imported.
try:
  import plotly.io as pio
  pio.renderers.default = "notebook_connected"
except Exception:
  pass

# enable pandas latex repr when targeting pdfs
# fig_format is 'retina' above, so the option is not set in this notebook.
try:
  import pandas as pd
  if fig_format == 'pdf':
    pd.set_option('display.latex.repr', True)
except Exception:
  pass



# output kernel dependencies
# Maps the file path of every loaded module to its modification time and
# prints the mapping as a single JSON line.
kernel_deps = dict()
for module in list(sys.modules.values()):
  # Some modules play games with sys.modules (e.g. email/__init__.py
  # in the standard library), and occasionally this can cause strange
  # failures in getattr.  Just ignore anything that's not an ordinary
  # module.
  if not isinstance(module, types.ModuleType):
    continue
  path = getattr(module, "__file__", None)
  # modules without a (non-empty) __file__ have nothing to report
  if not path:
    continue
  # drop the trailing character so "x.pyc"/"x.pyo" becomes "x.py"
  if path.endswith(".pyc") or path.endswith(".pyo"):
    path = path[:-1]
  # skip paths that do not exist on disk (os.stat below would fail)
  if not os.path.exists(path):
    continue
  kernel_deps[path] = os.stat(path).st_mtime
print(json.dumps(kernel_deps))

# set run_path if requested
# NOTE(review): the condition is a non-empty string literal and therefore
# always true, so the chdir always runs. The absolute path is machine
# specific -- presumably filled in by the tool that generated this cell;
# confirm before running elsewhere.
if r'/home/borea17/GIT/borea17.github.io/blog/paper_summaries/u_net':
  os.chdir(r'/home/borea17/GIT/borea17.github.io/blog/paper_summaries/u_net')

# reset state
# IPython magic (not plain Python): clears the interactive namespace.
%reset

def ojs_define(**kwargs):
  """Publish Python values as an ``ojs-define`` script tag in the cell output.

  Every keyword argument becomes one ``{"name": ..., "value": ...}`` entry of
  a ``contents`` list. The list is serialised with ``json.dumps``, wrapped in
  ``<script type="ojs-define">`` and displayed with the output metadata
  ``ojs_define=True``.

  Args:
      **kwargs: name/value pairs to publish. Values must be JSON
          serialisable, except that objects whose type is exactly
          ``pandas.Series`` or ``pandas.DataFrame`` are converted to a plain
          dict first (see ``convert``).

  Returns:
      None. The function only displays output.
  """
  # imported locally so the function does not rely on module-level names
  import json
  try:
    # IPython 7.14 preferred import
    from IPython.display import display, HTML
  except ImportError:
    # fallback location, used only when the import above is unavailable
    from IPython.core.display import display, HTML

  # do some minor magic for convenience when handling pandas
  # dataframes
  def convert(v):
    """Turn a pandas Series/DataFrame into a dict; pass other values through."""
    try:
      import pandas as pd
    except ModuleNotFoundError: # don't do the magic when pandas is not available
      return v
    # exact type checks on purpose: subclasses are passed through unchanged
    if type(v) == pd.Series:
      v = pd.DataFrame(v)
    if type(v) == pd.DataFrame:
      # after transposing, "index" holds the original column labels and
      # "data" holds one list of values per column
      j = json.loads(v.T.to_json(orient='split'))
      return dict(zip(j["index"], j["data"]))
    else:
      return v

  v = dict(contents=[dict(name=key, value=convert(value)) for (key, value) in kwargs.items()])
  display(HTML('<script type="ojs-define">' + json.dumps(v) + '</script>'), metadata=dict(ojs_define = True))
# register explicitly in the module globals under the name "ojs_define"
globals()["ojs_define"] = ojs_define


In [2]:
from PIL import Image
import torch
import torchvision.transforms as transforms


def load_dataset(data_dir='./Dataset/train', num_img=30, img_size=512):
    """Load the training images and their segmentation labels into tensors.

    Files are expected at ``<data_dir>/image/<i>.png`` and
    ``<data_dir>/label/<i>.png`` for ``i = 0 .. num_img - 1``. Each file is
    converted with torchvision's ``ToTensor`` and cast to float32.

    Args:
        data_dir (str): directory containing the ``image`` and ``label``
            sub-directories
        num_img (int): number of image/label pairs to load
        img_size (int): side length of the (square, single channel) images

    Returns:
        imgs (torch tensor): images, shape [num_img, 1, img_size, img_size]
        labels (torch tensor): labels, shape [num_img, 1, img_size, img_size]
    """
    # the transform is stateless, so one instance serves every file
    to_tensor = transforms.ToTensor()
    # initialize
    imgs = torch.zeros(num_img, 1, img_size, img_size)
    labels = torch.zeros(num_img, 1, img_size, img_size)
    # fill tensors with data
    for index in range(num_img):
        cur_name = str(index) + '.png'

        # context managers make sure the file handles are released even if
        # the conversion or the assignment fails
        with Image.open(data_dir + '/image/' + cur_name) as img_frame:
            imgs[index] = to_tensor(img_frame).type(torch.float32)
        with Image.open(data_dir + '/label/' + cur_name) as label_frame:
            labels[index] = to_tensor(label_frame).type(torch.float32)
    return imgs, labels


# read all image/label pairs from disk once; both tensors are handed to the
# dataset class that is built in the next cell
imgs, labels = load_dataset()

In [3]:
from torch.utils.data import Dataset
import numpy as np
from scipy.ndimage.interpolation import map_coordinates
from scipy.signal import convolve2d
import torchvision.transforms.functional as TF


class EM_Dataset(Dataset):
    """EM Dataset (from ISBI 2012) to train U-Net on including data
    augmentation as proposed by Ronneberger et al. (2015)

    Every image is enlarged by mirroring its borders and then cut into
    overlapping 572x572 image tiles with matching 388x388 label tiles
    (overlap-tile strategy). One dataset item is one (image tile, label tile)
    pair.

    Args:
        imgs (tensor): torch tensor containing input images,
            each of shape [1, 512, 512]
        labels (tensor): torch tensor containing segmented images,
            each of shape [1, 512, 512]
        stride (int): stride that is used for overlap-tile strategy,
            Note: stride must be a positive divisor of 512 - 388 = 124 such
            that all labels are retrieved
        transformation (bool): transform should be applied (True) or not (False)
    ------- transformation related -------
        probability (float): probability that transformation is applied
        alpha (float): intensity of elastic deformation
        sigma (float): std dev. of Gaussian kernel, i.e., smoothing parameter
        kernel dim (int): kernel size is [kernel_dim, kernel_dim]
    """

    # geometry of the overlap-tile strategy (side lengths in pixels)
    _IMG_SIZE = 512                       # raw image / label
    _IN_TILE = 572                        # image tile
    _OUT_TILE = 388                       # label tile
    _PAD = (_IN_TILE - _OUT_TILE) // 2    # 92, mirrored onto every border

    def __init__(self, imgs, labels, stride, transformation=False,
                 probability=None, alpha=None, sigma=None, kernel_dim=None):
        super().__init__()
        # largest offset a label tile can have inside a label (124)
        max_offset = EM_Dataset._IMG_SIZE - EM_Dataset._OUT_TILE
        # checked before any division so that stride == 0 and negative
        # strides are rejected here instead of failing later
        assert isinstance(stride, int) and 0 < stride <= max_offset and \
            max_offset % stride == 0, \
            'stride must be a positive integer that divides 124'
        self.orig_imgs = imgs
        self.imgs = EM_Dataset._extrapolate_by_mirroring(imgs)
        self.labels = labels
        self.stride = stride
        self.transformation = transformation
        if transformation:
            assert 0 <= probability <= 1
            self.probability = probability
            self.alpha = alpha
            self.kernel = EM_Dataset._create_gaussian_kernel(kernel_dim, sigma)

    def _tiles_per_axis(self):
        """number of tile positions along one image axis"""
        max_offset = EM_Dataset._IMG_SIZE - EM_Dataset._OUT_TILE
        return 1 + max_offset // self.stride

    def __getitem__(self, index):
        """images and labels are divided into several overlapping parts using
        the overlap-tile strategy

        Args:
            index (int): item index; items are ordered image by image and,
                within an image, row by row. Negative indexes count from the
                end.

        Returns:
            (img, label): image tile [1, 572, 572] and label tile [1, 388, 388]
        """
        tiles_1d = self._tiles_per_axis()
        # divmod floors, so negative indexes select the matching image from
        # the end (true division followed by int() would truncate to image 0)
        img_index, tile_index = divmod(int(index), tiles_1d ** 2)
        tile_row, tile_col = divmod(tile_index, tiles_1d)
        top = tile_row * self.stride
        left = tile_col * self.stride

        # the mirrored image is larger than the label by _PAD on each side,
        # so the same offsets address the image tile around the label tile
        img = self.imgs[img_index, :,
                        top:top + EM_Dataset._IN_TILE,
                        left:left + EM_Dataset._IN_TILE]
        label = self.labels[img_index, :,
                            top:top + EM_Dataset._OUT_TILE,
                            left:left + EM_Dataset._OUT_TILE]
        if self.transformation and \
                np.random.random() > 1 - self.probability:
            img, label = EM_Dataset.elastic_deform(img, label, self.alpha,
                                                   self.kernel)
        return (img, label)

    def __len__(self):
        """total number of tiles, i.e., images times tiles per image"""
        return len(self.imgs) * self._tiles_per_axis() ** 2

    @staticmethod
    def gray_value_variations(image, sigma):
        """applies gray value variations by adding Gaussian noise

        Args:
            image (torch tensor): extrapolated image tensor [1, 572, 572]
            sigma (float): std. dev. of Gaussian distribution

        Returns:
            image (torch tensor): image tensor w. gray value var. [1, 572, 572]
        """
        # see https://stats.stackexchange.com/a/383976
        noise = torch.randn(image.shape, dtype=torch.float32) * sigma
        return image + noise

    @staticmethod
    def affine_transform(image, label, angle, translate):
        """applies random affine translations and rotation on image and label

        Args:
            image (torch tensor): extrapolated image tensor [1, 572, 572]
            label (torch tensor): label tensor [1, 388, 388]
            angle (float): rotation angle
            translate (list): entries correspond to horizontal and vertical shift

        Returns:
            image (torch tensor): transformed image tensor [1, 572, 572]
            label (torch tensor): transformed label tensor [1, 388, 388]
        """
        # transform to PIL
        image = transforms.ToPILImage()(image[0])
        label = transforms.ToPILImage()(label[0])
        # apply the same affine transformation to both
        image = TF.affine(image, angle=angle, translate=translate,
                          scale=1, shear=0)
        label = TF.affine(label, angle=angle, translate=translate,
                          scale=1, shear=0)
        # transform back to tensor
        image = transforms.ToTensor()(np.array(image))
        label = transforms.ToTensor()(np.array(label))
        return image, label

    @staticmethod
    def elastic_deform(image, label, alpha, gaussian_kernel):
        """apply smooth elastic deformation on image and label data as
        described in

        [Simard2003] "Best Practices for Convolutional Neural Networks applied
        to Visual Document Analysis"

        One random displacement field is drawn for the image; the label is
        deformed with the central crop of that field so that both stay
        aligned. The label is assumed to be centered inside the image.

        Args:
            image (torch tensor): extrapolated image tensor [1, 572, 572]
            label (torch tensor): label tensor [1, 388, 388]
            alpha (float): intensity of transformation
            gaussian_kernel (np array): gaussian kernel used for smoothing

        Returns:
            deformed_img (torch tensor): deformed image tensor [1, 572, 572]
            deformed_label (torch tensor): deformed label tensor [1, 388, 388]

        code is adapted from https://github.com/vsvinayak/mnist-helper
        """
        # sizes are read from the tensors (572 / 388 for the documented shapes)
        img_h, img_w = image.shape[-2:]
        lab_h, lab_w = label.shape[-2:]
        # offset of the label inside the image (92 for the documented shapes)
        off_y = (img_h - lab_h) // 2
        off_x = (img_w - lab_w) // 2
        # generate standard coordinate grids
        x_i, y_i = np.meshgrid(np.arange(img_w), np.arange(img_h))
        x_l, y_l = np.meshgrid(np.arange(lab_w), np.arange(lab_h))
        # generate random displacement fields (uniform distribution [-1, 1])
        dx = 2*np.random.rand(*x_i.shape) - 1
        dy = 2*np.random.rand(*y_i.shape) - 1
        # smooth by convolving with gaussian kernel
        dx = alpha * convolve2d(dx, gaussian_kernel, mode='same')
        dy = alpha * convolve2d(dy, gaussian_kernel, mode='same')
        # part of the displacement field that lies on top of the label
        dx_label = dx[off_y:off_y + lab_h, off_x:off_x + lab_w]
        dy_label = dy[off_y:off_y + lab_h, off_x:off_x + lab_w]
        # one dimensional coordinates (necessary for map_coordinates)
        x_img = np.reshape(x_i + dx, (-1, 1))
        y_img = np.reshape(y_i + dy, (-1, 1))
        x_label = np.reshape(x_l + dx_label, (-1, 1))
        y_label = np.reshape(y_l + dy_label, (-1, 1))
        # deformation using map_coordinates interpolation (spline not bicubic)
        deformed_img = map_coordinates(np.asarray(image[0]), [y_img, x_img],
                                       order=1, mode='reflect')
        deformed_label = map_coordinates(np.asarray(label[0]),
                                         [y_label, x_label],
                                         order=1, mode='reflect')
        # reshape into desired shape and cast to tensor
        deformed_img = torch.from_numpy(deformed_img.reshape(image.shape))
        deformed_label = torch.from_numpy(deformed_label.reshape(label.shape))
        return deformed_img, deformed_label

    @staticmethod
    def _extrapolate_by_mirroring(data):
        """increase data by mirroring (needed for overlap-tile strategy)

        Args:
            data (torch tensor): shape [num_samples, 1, 512, 512]

        Returns:
            extrapol_data (torch tensor): shape [num_samples, 1, 696, 696]
        """
        pad = EM_Dataset._PAD
        size = EM_Dataset._IMG_SIZE
        padded = size + 2 * pad
        num_samples = len(data)
        extrapol_data = torch.zeros(num_samples, 1, padded, padded)

        # put data into center of extrapol data
        extrapol_data[:, :, pad:pad + size, pad:pad + size] = data
        # mirror left
        extrapol_data[:, :, pad:pad + size, :pad] = \
            data[:, :, :, :pad].flip(3)
        # mirror right
        extrapol_data[:, :, pad:pad + size, pad + size:] = \
            data[:, :, :, -pad:].flip(3)
        # mirror top (uses the already widened rows, so corners are filled)
        extrapol_data[:, :, :pad, :] = \
            extrapol_data[:, :, pad:2 * pad, :].flip(2)
        # mirror bottom (same, with the last `pad` data rows)
        extrapol_data[:, :, pad + size:, :] = \
            extrapol_data[:, :, size:size + pad, :].flip(2)
        return extrapol_data

    @staticmethod
    def _create_gaussian_kernel(kernel_dim, sigma):
        """returns a 2D Gaussian kernel with the standard deviation
        denoted by sigma

        Args:
            kernel_dim (int): kernel size will be [kernel_dim, kernel_dim]
            sigma (float): std dev of Gaussian (smoothing parameter)

        Returns:
            gaussian_kernel (numpy array): centered gaussian kernel
                (float64, entries sum to one)

        Raises:
            ValueError: if kernel_dim is even

        code is adapted from https://github.com/vsvinayak/mnist-helper
        """
        # check if the dimension is odd
        if kernel_dim % 2 == 0:
            raise ValueError("Kernel dimension should be odd")
        # index of the center pixel; floor division is required here, true
        # division would put the center between two pixels and skew the kernel
        center = kernel_dim // 2
        # signed distance of every row/column to the center
        offsets = np.arange(kernel_dim, dtype=np.float64) - center
        squared_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2
        # unnormalised Gaussian; a constant factor would cancel below
        kernel = np.exp(-squared_dist / (2 * sigma ** 2))
        # normalise it
        return kernel / kernel.sum()


# generate datasets
# stride 124 equals 512 - 388, i.e. two tile positions per axis and therefore
# four tiles per image
stride = 124
# elastic deformation is enabled and applied to a fetched tile with
# probability 0.5; alpha is the deformation intensity, sigma and kernel_dim
# define the Gaussian kernel that smooths the displacement field
whole_dataset = EM_Dataset(imgs, labels, stride=stride,
                           transformation=True, probability=0.5, alpha=50,
                           sigma=5, kernel_dim=25)