Install and import the dependencies.

In [None]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch import cat
import os
import warnings
import numpy as np
import pandas as pd
import random
!pip install rasterio
import rasterio
!pip install pytorch_lightning
import pytorch_lightning
import logging

In [None]:
# Mount Google Drive (Colab only); the dataset and the split CSVs used below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Utility functions for collecting, pairing, loading and preprocessing the image/mask files

In [3]:
def extract_paths(directory):
    """
    Walk ``directory`` and collect the paths of every content image and every mask image.

    Only files inside sub-directories whose name contains "TCGA" (one folder per patient)
    are considered; the root folder and anything else is skipped. The numeric image id is
    parsed from the file name itself, which is expected to look like ``<prefix>_<id>.<ext>``
    for content images and ``<prefix>_<id>_mask.<ext>`` for mask images.

    :param directory: str, root directory of the dataset
    :return:
    masks:
            dataframe with columns patient_id, full_mask_dir, mask_id
    contents:
            dataframe with columns patient_id, full_content_dir, content_id
    :raises ValueError: if a file name does not end in an integer id
    """
    mask_rows = []
    content_rows = []
    # walk through all files in the given directory
    for root, _dirs, files in os.walk(directory, topdown=False):
        # the patient id is the name of the folder holding the images;
        # normpath drops a trailing separator so basename is never empty
        patient_id = os.path.basename(os.path.normpath(root))

        # skip the root folder and anything that is not a patient folder
        if "TCGA" not in patient_id:
            continue

        for file_name in files:
            full_dir = os.path.join(root, file_name)
            # decide mask/content and parse the id from the FILE NAME, not from fixed
            # offsets into the full path, so any root directory (and any root whose
            # name happens to contain "mask") works
            stem = os.path.splitext(file_name)[0]
            is_mask = stem.endswith("_mask")
            if is_mask:
                stem = stem[:-len("_mask")]
            image_id = int(stem.rsplit("_", 1)[-1])

            if is_mask:
                mask_rows.append((patient_id, full_dir, image_id))
            else:
                content_rows.append((patient_id, full_dir, image_id))

    masks = pd.DataFrame(mask_rows, columns=["patient_id", "full_mask_dir", "mask_id"])
    contents = pd.DataFrame(content_rows, columns=["patient_id", "full_content_dir", "content_id"])
    return masks, contents


def sort_combine_paths(masks, contents, random_state=None):
    """
    Pair every content image path with its mask path in a single, shuffled dataframe.

    Both inputs are sorted by (patient_id, id) so that row i of each refers to the same
    image, then they are combined side by side. Every row is checked: the content path
    without ".tif" (last 4 characters) must equal the mask path without "_mask.tif"
    (last 9 characters).

    :param masks: dataframe with columns patient_id, full_mask_dir, mask_id
    :param contents: dataframe with columns patient_id, full_content_dir, content_id
    :param random_state: optional seed forwarded to ``DataFrame.sample`` to make the shuffle reproducible
    :return: dataframe with columns patient_id, content_path, mask_path, rows shuffled
    :raises Exception: if the number of masks and contents differ, or any row pairs a
        content image with the mask of a different image
    """
    # sort so that mask path and content path correspond at each row index
    contents = contents.sort_values(by=["patient_id", "content_id"], ignore_index=True)
    masks = masks.sort_values(by=["patient_id", "mask_id"], ignore_index=True)

    # unequal lengths would otherwise be index-aligned into NaN rows
    if len(contents) != len(masks):
        raise Exception("Something failed for matching process")

    dir_df = pd.DataFrame({
        "patient_id": contents["patient_id"],
        "content_path": contents["full_content_dir"],
        "mask_path": masks["full_mask_dir"],
    })

    # validate every row rather than a single random one
    if any(content[:-4] != mask[:-9]
           for content, mask in zip(dir_df["content_path"], dir_df["mask_path"])):
        raise Exception("Something failed for matching process")

    return dir_df.sample(frac=1, random_state=random_state)  # shuffle dataframe


def load_data(df_dir: pd.DataFrame, start_idx: int = 0, shuffle: bool = False, batch_size: int = 1,
              random_state=None):
    """
    Read one batch of content images and mask images from the paths listed in ``df_dir``.

    Rows ``start_idx`` .. ``start_idx + batch_size - 1`` are read (by position). Column 1
    of ``df_dir`` holds the content image path and column 2 the mask image path.

    :param df_dir: dataframe
    :param start_idx: int, position of the first row of the batch
    :param shuffle: boolean, shuffle a copy of the rows before slicing (``df_dir`` is not modified)
    :param batch_size: int
    :param random_state: optional seed for the shuffle
    :return: ndarray, ndarray -- stacked content images and stacked mask images
    :raises IndexError: if the requested batch does not fit inside ``df_dir``
    """
    # check if random shuffle; sample() returns a new frame, so it must be re-bound
    if shuffle:
        df_dir = df_dir.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # the whole batch, starting at start_idx, must fit inside the frame
    if start_idx < 0 or start_idx + batch_size > df_dir.shape[0]:
        raise IndexError("batch size is too large")

    content_images = []
    mask_images = []
    # silence rasterio's NotGeoreferencedWarning only while reading, instead of
    # installing a process-wide filter on every call
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=rasterio.errors.NotGeoreferencedWarning)
        for row in range(start_idx, start_idx + batch_size):
            content_path = df_dir.iloc[row, 1]
            mask_path = df_dir.iloc[row, 2]
            # context managers close the dataset handles instead of leaking them
            with rasterio.open(content_path) as src:
                content_images.append(src.read())
            with rasterio.open(mask_path) as src:
                mask_images.append(src.read())

    return np.array(content_images), np.array(mask_images)


def adjust_data(img, mask):
    """
    Scale ``img`` by 1/255 and turn ``mask`` into a 0/1 mask.

    Both inputs are divided by 255, which produces new arrays (the arguments are not
    modified). After scaling, mask values above 0.5 become 1 and values at or below 0.5
    become 0.

    :param img: ndarray of pixel values (presumably 0..255)
    :param mask: ndarray of mask values (presumably 0..255)
    :return: (scaled img, binary mask)
    """
    scaled_mask = mask / 255
    # both selections are taken before writing; setting foreground pixels to 1 can never
    # move them into the <= 0.5 set, so this matches thresholding one step after the other
    foreground = scaled_mask > 0.5
    background = scaled_mask <= 0.5
    scaled_mask[foreground] = 1
    scaled_mask[background] = 0
    return img / 255, scaled_mask


Build the custom neural-network building blocks and the loss

In [4]:
class DoubleConv(nn.Module):
    """
    Two stacked (Conv2d 3x3 -> BatchNorm2d -> ReLU) stages.

    padding='same' keeps the spatial size (H x W). No pooling happens here, so the output
    can also be reused as a skip connection during up-sampling.
    """

    def __init__(self, in_channel, out_channel):
        super(DoubleConv, self).__init__()
        self.mid_channel = out_channel
        self.in_channel = in_channel
        self.out_channel = out_channel

        stages = []
        # (input channels, output channels) for each of the two convolution stages
        for stage_in, stage_out in ((self.in_channel, self.mid_channel),
                                    (self.mid_channel, self.out_channel)):
            stages.append(nn.Conv2d(stage_in, stage_out, kernel_size=3, padding='same'))
            stages.append(nn.BatchNorm2d(stage_out))
            # inplace activation saves memory
            stages.append(nn.ReLU(inplace=True))
        # kept as one nn.Sequential named `double_conv` so state_dict keys stay the same
        self.double_conv = nn.Sequential(*stages)

    def forward(self, input_layer):
        """
        Apply both convolution stages.

        :param input_layer: tensor [batch size, input channel size, H, W]
        :return: tensor [batch size, output channel size, H, W]
        """
        return self.double_conv(input_layer)


class Up(nn.Module):
    """
    Decoder stage: up-sample, concatenate the skip connection, then apply a DoubleConv.

    ``input_layer1`` (``in_channel`` channels) is up-sampled 2x by a transposed convolution
    to ``out_channel`` channels. It is then concatenated along the channel axis with the skip
    tensor ``input_layer2``, which must therefore have ``out_channel`` channels. The resulting
    ``2 * out_channel`` channels are reduced to ``out_channel`` by a DoubleConv.
    """

    def __init__(self, in_channel, out_channel):
        super(Up, self).__init__()
        self.mid_channel = out_channel
        # an up-conv layer: H_out = (H_in - 1) * stride - 2 * padding + kernel_size = 2 * H_in
        self.up = nn.ConvTranspose2d(in_channel, self.mid_channel, stride=2, kernel_size=2)
        # input channel size is 2 * mid channel size because of the concatenation in forward
        self.double_conv = DoubleConv(self.mid_channel * 2, out_channel)

    def forward(self, input_layer1, input_layer2):
        """
        :param input_layer1: tensor [N, in_channel, H, W] from the deeper stage
        :param input_layer2: skip tensor [N, out_channel, H2, W2] from the encoder
        :return: tensor [N, out_channel, H2, W2]
        """
        upsampled = self.up(input_layer1)
        # when an encoder map had an odd size, max pooling floored it and the up-sampled map
        # no longer matches the skip map; pad (or crop, for negative differences) it so the
        # concatenation works for any input size. No-op when the sizes already agree.
        diff_h = input_layer2.shape[-2] - upsampled.shape[-2]
        diff_w = input_layer2.shape[-1] - upsampled.shape[-1]
        if diff_h or diff_w:
            upsampled = torch.nn.functional.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        merged = torch.cat((upsampled, input_layer2), dim=1)  # axis: channel
        return self.double_conv(merged)


class MyLoss(nn.Module):
    """
    Overlap loss for binary segmentation, returned as a negative score (lower is better).

    * ``dice_loss_mode=True``:  ``-(2|A∩B| + smooth) / (|A| + |B| + smooth)``
    * ``dice_loss_mode=False``: ``-(|A∩B| + smooth) / (|A∪B| + smooth)`` (Jaccard)

    Both modes use the soft (un-rounded) predictions so gradients reach the model.
    """

    def __init__(self, dice_loss_mode: bool = True, smooth: float = 0.01):
        super(MyLoss, self).__init__()
        self.dice_loss_mode = dice_loss_mode
        self.smooth = smooth

    def forward(self, pred, target):
        """
        :param pred: tensor of predicted probabilities
        :param target: tensor with the same number of elements holding the binary mask;
            integer/bool masks are cast to pred's dtype
        :return: scalar tensor, the negative overlap score
        """
        pred_flattened = pred.reshape(-1)
        # torch.dot requires matching dtypes
        target_flattened = target.reshape(-1).to(pred_flattened.dtype)

        intersect = torch.dot(pred_flattened, target_flattened)
        sum_two = torch.sum(pred_flattened) + torch.sum(target_flattened)

        if self.dice_loss_mode:
            # dice_score = 2 * |A∩B| / (|A| + |B|)
            return -(2 * intersect + self.smooth) / (sum_two + self.smooth)

        # Jaccard = |A∩B| / |A∪B|, with |A∪B| = |A| + |B| - |A∩B|.
        # pred is deliberately NOT rounded: torch.round has zero gradient, which would make
        # this mode untrainable. For binary inputs the soft form equals the hard one.
        union = sum_two - intersect
        return -(intersect + self.smooth) / (union + self.smooth)

In [5]:
def iou_score(pred, target, smooth: float = 0.001):
    """
    Hard Jaccard index (IoU) between a prediction and a target mask.

    ``pred`` is rounded to a 0/1 mask first, so this is an evaluation metric, not a
    differentiable loss.

    :param pred: tensor of predicted probabilities
    :param target: tensor holding the binary target mask, broadcastable against pred
        (normally the same shape)
    :param smooth: float added to numerator and denominator to avoid division by zero
    :return: scalar tensor, (|A∩B| + smooth) / (|A∪B| + smooth)
    """
    pred_binary = torch.round(pred)  # convert into binary mask

    # intersection
    intersect = torch.sum(pred_binary * target)

    # union: for 0/1 inputs ceil((a + b) / 2) is the element-wise logical OR
    union = torch.ceil((pred_binary + target) / 2)

    # Jaccard = |A∩B| / |A∪B|
    return (intersect + smooth) / (torch.sum(union) + smooth)

def dice_score(pred, target, smooth: float = 0.001):
    """
    Dice coefficient between ``pred`` and ``target`` (soft unless pred is already rounded).

    :param pred: tensor of predictions
    :param target: tensor with the same number of elements holding the target mask;
        integer/bool masks are cast to pred's dtype
    :param smooth: float added to numerator and denominator to avoid division by zero
    :return: scalar tensor, (2|A∩B| + smooth) / (|A| + |B| + smooth)
    """
    # flatten pred and target
    pred_flattened = pred.reshape(-1)
    # torch.dot requires both operands to share a dtype
    target_flattened = target.reshape(-1).to(pred_flattened.dtype)

    # intersection
    intersect = torch.dot(pred_flattened, target_flattened)

    # sum
    sum_two = torch.sum(pred_flattened) + torch.sum(target_flattened)

    # dice_score = 2 * |A∩B| / (|A| + |B|)
    return (2 * intersect + smooth) / (sum_two + smooth)


Build the U-Net model

In [6]:
class Unet(nn.Module):
    """
    U-Net producing a single-channel probability map.

    Encoder: five DoubleConv stages (32, 64, 128, 256, 512 channels) with 2x2 max pooling
    between them. Decoder: four Up stages, each merging the matching encoder output.
    A 1x1 transposed convolution maps 32 channels to 1, followed by a sigmoid.
    """

    def __init__(self, in_channel):
        super(Unet, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        # encoder (attribute names are part of the saved state_dict keys)
        self.down1 = DoubleConv(in_channel, 32)
        self.down2 = DoubleConv(32, 64)
        self.down3 = DoubleConv(64, 128)
        self.down4 = DoubleConv(128, 256)
        self.down5 = DoubleConv(256, 512)
        # decoder
        self.up6 = Up(512, 256)
        self.up7 = Up(256, 128)
        self.up8 = Up(128, 64)
        self.up9 = Up(64, 32)
        # output a single channel (binary)
        self.up10 = nn.ConvTranspose2d(32, 1, stride=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, im_input):
        """
        :param im_input: tensor [N, in_channel, H, W]
        :return: tensor [N, 1, H, W] of per-pixel probabilities (H and W presumably
            divisible by 16 so the skip connections line up)
        """
        skips = []
        features = im_input
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            features = encoder(features)
            skips.append(features)  # stored before pooling, consumed by the decoder
            features = self.max_pool(features)
        features = self.down5(features)  # bottleneck: not followed by max pooling

        decoders = (self.up6, self.up7, self.up8, self.up9)
        for decoder, skip in zip(decoders, reversed(skips)):
            features = decoder(features, skip)
        return self.sigmoid(self.up10(features))

Create a validation function

In [7]:
def validation(unet: Unet, validation_img: pd.DataFrame, device, batch_size: int = 10):
    """
    Score ``unet`` on one batch drawn from the validation set.

    A batch of ``batch_size`` rows is loaded through ``load_data`` (with shuffling requested),
    rescaled with ``adjust_data``, and the predictions are rounded to a 0/1 mask before Dice
    and IoU are computed. Gradients are disabled during scoring. The model's train/eval mode
    is restored to whatever it was on entry, even if scoring raises.

    :param unet: model to evaluate
    :param validation_img: dataframe of validation image/mask paths (format expected by load_data)
    :param device: torch device the tensors are moved to
    :param batch_size: number of validation images to score
    :return: (dice score, iou score) as scalar tensors
    """
    validation_content, validation_mask = load_data(validation_img, shuffle=True, batch_size=batch_size)
    validation_content, validation_mask = adjust_data(validation_content, validation_mask)
    # cast float64 -> float32 before moving, so only float32 data is transferred
    validation_content_tensor = torch.from_numpy(validation_content).float().to(device)
    validation_mask_tensor = torch.from_numpy(validation_mask).float().to(device)

    was_training = unet.training
    unet.eval()  # BatchNorm2d uses its running statistics during validation
    try:
        with torch.no_grad():
            pred_mask_tensor = torch.round(unet(validation_content_tensor))
            pred_dice_score = dice_score(pred_mask_tensor, validation_mask_tensor)
            pred_iou_score = iou_score(pred_mask_tensor, validation_mask_tensor)
    finally:
        unet.train(was_training)  # restore the caller's mode
    return pred_dice_score, pred_iou_score

Create a training function

In [8]:
def _select_device(device_: str):
    """Validate ``device_`` ('cuda' or 'cpu') and fall back to CPU when CUDA is unavailable."""
    if device_ not in ('cuda', 'cpu'):
        raise ValueError("invalid device")
    if device_ == 'cuda' and not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device(device_)


def train(train_img: pd.DataFrame,
          validation_img: pd.DataFrame,
          epoch: int = 1,
          lrate: float = 0.0001,
          shuffle: bool = False,
          batch_size: int = 32,
          device_: str = 'cpu',
          use_cel: bool = True,
          interval: int = 20,
          save_threshold: float = 0.85):
    """
    Train a fresh 3-channel U-Net on ``train_img``, periodically scoring it on ``validation_img``.

    Each iteration loads one batch and drops trailing rows that do not fill a full batch.
    It then takes one Adam step and prints the loss and training Dice. Every ``interval``
    iterations a validation batch is scored. If the validation Dice exceeds
    ``save_threshold``, the weights are saved to ``unet_epoche<epoch>_iter<iteration>.pth``
    in the working directory.

    :param train_img: dataframe of training image/mask paths (format expected by load_data)
    :param validation_img: dataframe of validation image/mask paths
    :param epoch: number of passes over the training data
    :param lrate: Adam learning rate
    :param shuffle: shuffle the training rows once at the start of every epoch
    :param batch_size: images per iteration
    :param device_: 'cuda' or 'cpu'; 'cuda' falls back to CPU when it is unavailable
    :param use_cel: add binary cross-entropy to the Dice loss
    :param interval: run validation every ``interval`` iterations
    :param save_threshold: validation Dice above which a checkpoint is written
    :return: the trained model
    :raises ValueError: if ``device_`` is neither 'cuda' nor 'cpu'
    """
    device = _select_device(device_)
    print("using", device.type)

    # number of full batches per epoch (the remainder is dropped)
    iteration = train_img.shape[0] // batch_size

    print("initialize model")
    unet = Unet(3).to(device)  # 3 RGB channels
    optimizer = torch.optim.Adam(unet.parameters(), lr=lrate)
    # the model outputs ONE channel that already went through a sigmoid, so binary
    # cross-entropy on probabilities is the matching criterion. CrossEntropyLoss would
    # softmax over that single channel and is therefore identically zero.
    criterion = nn.BCELoss()
    dice_loss = MyLoss(dice_loss_mode=True, smooth=0.01)

    for e in range(epoch):
        # shuffle once per epoch so every row is visited exactly once per epoch
        epoch_img = train_img.sample(frac=1).reset_index(drop=True) if shuffle else train_img
        for i in range(iteration):
            train_content, train_mask = load_data(epoch_img,
                                                  start_idx=i * batch_size,
                                                  shuffle=False,
                                                  batch_size=batch_size)
            # rescale content data and convert mask data into binary digits
            train_content, train_mask = adjust_data(train_content, train_mask)

            # torch.from_numpy keeps the 4-D (N, C, H, W) batch; the data is float64,
            # so convert to float32 to match the conv weights before moving it
            train_content_tensor = torch.from_numpy(train_content).float().to(device)
            train_mask_tensor = torch.from_numpy(train_mask).float().to(device)

            unet.train()  # set training mode on every step
            optimizer.zero_grad()  # clear grad, avoid accumulation
            pred_mask_tensor = unet(train_content_tensor)

            loss = dice_loss(pred_mask_tensor, train_mask_tensor)
            if use_cel:
                loss = loss + criterion(pred_mask_tensor, train_mask_tensor)
            accuracy = dice_score(pred_mask_tensor.detach(), train_mask_tensor)

            loss.backward()  # backpropagation
            optimizer.step()  # update model weight

            print('epoch-%s-iteration-%s: loss %s accuracy %s' % (e + 1, i + 1, loss.item(), accuracy.item()))
            # validate every `interval` iterations and checkpoint good models
            if (i + 1) % interval == 0:
                valid_dice_score, valid_iou_score = validation(unet, validation_img, device)
                print('valid iou score: %s valid dice score: %s' % (valid_iou_score.item(), valid_dice_score.item()))
                if valid_dice_score.item() > save_threshold:
                    torch.save({'model': unet.state_dict()}, 'unet_epoche%s_iter%s.pth' % (e + 1, i + 1))
    return unet

Split the data into training/validation/test sets and train the model

In [None]:
# Folder holding the per-patient TCGA_* image folders, and folder holding the split CSVs.
DATASET_DIR = "/content/drive/MyDrive/kaggle_3m"
SPLIT_DIR = "/content/drive/MyDrive/My model 2"
# Set to True to rebuild the train/test split from the raw image folders.
rewrite = False

if rewrite:
    # select all file paths into two dataframes
    masks, contents = extract_paths(DATASET_DIR)
    # sort paths and combine dataframes
    dir_df = sort_combine_paths(masks, contents)
    # hold out 10% as the test set; the split is by image, not by patient id
    train_dirs, test_dirs = train_test_split(dir_df, test_size=0.1)
    # write to the same folder the `else` branch reads from, so a rebuilt split is actually used
    os.makedirs(SPLIT_DIR, exist_ok=True)
    test_dirs.to_csv(os.path.join(SPLIT_DIR, "test.csv"))
    train_dirs.to_csv(os.path.join(SPLIT_DIR, "train.csv"))
else:
    train_dirs = pd.read_csv(os.path.join(SPLIT_DIR, "train.csv"), index_col=0)

# carve the validation set out of the training rows (same seed for both branches)
train_dirs, validation_dirs = train_test_split(train_dirs, test_size=0.2, random_state=40)

# assign the result so the notebook does not echo it as cell output
trained_unet = train(train_img=train_dirs, validation_img=validation_dirs, lrate=0.001,
                     use_cel=True, epoch=10, batch_size=32, device_='cuda')