In [None]:
! pip install torchshow

In [None]:
import zipfile
import os
import torch
import tifffile as tiff
import numpy as np
import matplotlib.pyplot as plt
import torchshow as ts
import torchvision.transforms.functional as F
import time
import random

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torchsummary import summary
from torch.nn import BCEWithLogitsLoss

# Record the installed PyTorch version alongside the notebook output.
print(torch.__version__)

In [None]:
# See what GPU is currently being used.
# current_device() and get_device_name() raise on machines without CUDA,
# so only query them when a GPU is actually available.
if torch.cuda.is_available():
    print(torch.cuda.device_count())
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name())
else:
    print('No CUDA device available; running on CPU.')

In [None]:
# Pick the compute device once: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


### PREPARE DATASET

The dataset consists of 19,200 images of cells but only 1,200 masks. The first step is to match up the images with their ground truth masks.


In [None]:

!wget --no-check-certificate \
    "https://data.broadinstitute.org/bbbc/BBBC005/BBBC005_v1_images.zip" \
    -O "/tmp/BBBC005_v1_images.zip"


zip_ref = zipfile.ZipFile('/tmp/BBBC005_v1_images.zip', 'r') #Opens the zip file in read mode
zip_ref.extractall('/tmp') #Extracts the files into the /tmp folder
zip_ref.close()

In [None]:
!wget --no-check-certificate \
    "https://data.broadinstitute.org/bbbc/BBBC005/BBBC005_v1_ground_truth.zip" \
    -O "/tmp/BBBC005_v1_ground_truth.zip"


zip_ref = zipfile.ZipFile('/tmp/BBBC005_v1_ground_truth.zip', 'r') #Opens the zip file in read mode
zip_ref.extractall('/tmp') #Extracts the files into the /tmp folder
zip_ref.close()

In [None]:
# Directories the two archives were extracted into (trailing slash is relied
# on by code that concatenates a file name onto these).
image_path = '/tmp/BBBC005_v1_images/'
mask_path = '/tmp/BBBC005_v1_ground_truth/'

# Directory listings, sorted in place so both lists are in a stable order.
image_paths = os.listdir(image_path)
image_paths.sort()
mask_paths = os.listdir(mask_path)
mask_paths.sort()

(len(image_paths), len(mask_paths))

In [None]:
# Peek at the first few entries of each listing.
image_paths[:5], mask_paths[:5]

We want to remove the '.htaccess' file because it is of no use to us.

In [None]:
# Filter instead of list.remove(): remove() raises ValueError when the entry
# is missing, which also made this cell fail when run a second time.
image_paths = [name for name in image_paths if name != '.htaccess']
mask_paths = [name for name in mask_paths if name != '.htaccess']

len(image_paths), len(mask_paths)

In [None]:
# Confirm the listings now start with image files.
image_paths[:5], mask_paths[:5]

Now we want to create a list of all images that have an associated mask:

In [None]:
# File names present in both directories, i.e. images that have a mask.
new_list = sorted(set(image_paths).intersection(mask_paths))
(len(new_list), type(new_list))

In [None]:
# First few matched file names.
new_list[:5]

Now that we have our list of 1200 images and their associated masks, let's split this into train and test sets:

In [None]:
# Hold out 10% of the matched files for testing. A fixed random_state makes
# the split (and everything indexed from it below) reproducible across runs.
train_list, test_list = train_test_split(new_list, test_size=0.1, random_state=42)

len(train_list), len(test_list)

Now let's visualize one of the training images and its ground truth segmentation:

In [None]:
# Load one training image and the mask stored under the same file name.
img = tiff.imread(os.path.join(image_path, train_list[150]))
mask = tiff.imread(os.path.join(mask_path, train_list[150]))

(img.shape, mask.shape)

In [None]:
fig = plt.figure(figsize=(15,15))

# Left panel: the image.
plt.subplot(1, 2, 1)
plt.imshow(img, cmap='gray')
plt.title('image')

# Right panel: its ground truth mask.
plt.subplot(1, 2, 2)
plt.imshow(mask, cmap='gray')
plt.title('ground truth mask')

plt.show()

Here we can see a good example of an image and the ground truth mask. This is what we will be using to train our image segmentation model. The next step is to create a dataset and a dataloader.

In [None]:
class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from TIFF files.

    Args:
        imagePaths: image file names, relative to the image directory.
        maskPaths: mask file names, relative to the mask directory; entry ``i``
            is returned together with ``imagePaths[i]``.
        transforms: callable applied to the image and then, separately, to the
            mask; ``None`` returns the arrays as read.
        imageDir: directory containing the images. When ``None`` the
            notebook-level ``image_path`` is used.
        maskDir: directory containing the masks. When ``None`` the
            notebook-level ``mask_path`` is used.

    Raises:
        ValueError: if ``imagePaths`` and ``maskPaths`` differ in length.
    """

    def __init__(self, imagePaths, maskPaths, transforms, imageDir=None, maskDir=None):
        # __len__ is driven by imagePaths alone, so a shorter maskPaths would
        # otherwise only surface as an IndexError in the middle of an epoch.
        if len(imagePaths) != len(maskPaths):
            raise ValueError(
                "imagePaths and maskPaths must have the same length "
                f"(got {len(imagePaths)} and {len(maskPaths)})"
            )
        # store the image and mask filepaths, and augmentation
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms
        self.imageDir = imageDir
        self.maskDir = maskDir

    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)

    def __getitem__(self, i):
        # Resolve the directories lazily so the notebook globals are only
        # required when no explicit directory was given.
        imageDir = image_path if self.imageDir is None else self.imageDir
        maskDir = mask_path if self.maskDir is None else self.maskDir

        image = tiff.imread(os.path.join(imageDir, self.imagePaths[i]))
        mask = tiff.imread(os.path.join(maskDir, self.maskPaths[i]))

        if self.transforms is not None:
            # The same callable is invoked once per array.
            # NOTE(review): a randomized transform would therefore be sampled
            # independently for image and mask - confirm before adding
            # random augmentation.
            image = self.transforms(image)
            mask = self.transforms(mask)

        return image, mask

In [None]:
train_batch_size = 64
test_batch_size = 64
img_size = 128

transforms = transforms.Compose([transforms.ToPILImage(),
                                 transforms.Resize((img_size, img_size)),
                                 transforms.ToTensor()])

In [None]:
## TRAIN DATASET
train_dataset = SegmentationDataset(train_list, train_list, transforms)

train_loader = DataLoader(train_dataset, 
                          batch_size = train_batch_size, 
                          shuffle = True)

## TEST DATASET
test_dataset = SegmentationDataset(test_list, test_list, transforms)

test_loader = DataLoader(test_dataset, 
                         batch_size = test_batch_size, 
                         shuffle = True)

In [None]:
# Pull one batch from each loader to check the tensor shapes. These batches
# are kept as notebook variables and reused by the display cells below.
trainImg, trainLabels = next(iter(train_loader))
print('trainImg shape: ', trainImg.shape)
print('trainLabels shape: ', trainLabels.shape)

print("\n")
print('----------------------------------------------------')

testImg, testLabels = next(iter(test_loader))
print('testImg shape: ', testImg.shape)
print('testLabels shape: ', testLabels.shape)



In [None]:
def display_batch(image, label, batch_num):
    """Plot one sample of a batch next to its ground-truth mask.

    Args:
        image: 4-D tensor of images; the first channel of sample
            ``batch_num`` is shown.
        label: 4-D tensor of masks, indexed the same way.
        batch_num: index of the sample within the batch.
    """
    # detach().cpu() lets this work for tensors that live on the GPU or
    # carry gradients; .numpy() alone raises for both.
    image = image[batch_num, 0, :, :].detach().cpu().numpy()
    label = label[batch_num, 0, :, :].detach().cpu().numpy()

    figure, ax = plt.subplots(nrows=1, ncols=2, figsize=(7, 7))

    # plot the original image and its mask
    ax[0].imshow(image, cmap='gray')
    ax[1].imshow(label, cmap='gray')

    # set the titles of the subplots
    ax[0].set_title(f"Image # {batch_num}")
    ax[1].set_title(f"Original Mask # {batch_num}")

    # set the layout of the figure and display it
    figure.tight_layout()
    # plt.show() renders under the notebook's inline backend; figure.show()
    # only warns there because the backend has no GUI event loop.
    plt.show()


In [None]:
# Show the first three samples of the training batch fetched above.
for i in range(3):
    display_batch(trainImg, trainLabels, i)

Here we can see a few image/mask pairs drawn from the training dataloader, which confirms that the dataset and dataloader return correctly paired, resized samples. This is what we will be using to train our image segmentation model. The next step is to define the model.


### Define UNet model

In [None]:
import torch.nn as nn

class double_conv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels; the second keeps
    ``out_ch``. The convolutions use padding 1, so height and width are kept.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so parameter names
        # (conv.0, conv.1, conv.3, conv.4) are unchanged.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    """Input stage of the network: a single ``double_conv`` block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``double_conv`` block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = double_conv(in_ch, out_ch)
        # Kept as a two-element Sequential named `mpconv` so parameter names
        # stay mpconv.1.*.
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    """Decoder stage: upsample, pad to the skip connection's size, concatenate, convolve.

    Args:
        in_ch: number of channels after concatenating the upsampled tensor
            with the skip connection (input width of the ``double_conv``).
        out_ch: number of channels produced by the stage.
        bilinear: when True, upsample with parameter-free bilinear
            interpolation; otherwise use a learned 2x2 transposed convolution
            over ``in_ch // 2`` channels.
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear",
                                  align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2,
                                         2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse the two.

        Args:
            x1: tensor coming from the deeper stage.
            x2: skip-connection tensor from the encoder.
        """
        x1 = self.up(x1)

        # Tensors are indexed as (N, C, H, W): dim 2 is height, dim 3 is width.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # Use torch.nn.functional.pad, whose 4-tuple is
        # (left, right, top, bottom). The notebook-level `F` is
        # torchvision.transforms.functional, whose pad expects
        # (left, top, right, bottom), so this argument order mis-padded
        # whenever the size difference was non-zero.
        x1 = nn.functional.pad(x1, (diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2))

        # Concatenate along the channel dimension, skip connection first.
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(nn.Module):
    """Output head: a 1x1 convolution mapping ``in_ch`` to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """U-Net with four encoder stages, four decoder stages and skip connections.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels.

    ``forward`` returns the sigmoid of the final 1x1 convolution, i.e. values
    in (0, 1) rather than raw logits.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Encoder
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        # Decoder (transposed-convolution upsampling: bilinear=False)
        self.up1 = up(1024, 256, False)
        self.up2 = up(512, 128, False)
        self.up3 = up(256, 64, False)
        self.up4 = up(128, 64, False)
        # Output head
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        # Encoder path; every intermediate result is kept as a skip connection.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        # Decoder path, consuming the skip connections deepest-first.
        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        return torch.sigmoid(self.outc(decoded))

In [None]:
# Single-channel input, single-channel output, moved to the selected device.
model = UNet(n_channels=1, n_classes=1).float()
model.to(device)

# Layer-by-layer summary for one (1, img_size, img_size) input.
summary(model, (1, img_size, img_size))

In [None]:
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, activation='sigmoid'):
    """Soft F-beta score between a prediction and its ground truth.

    With ``beta=1`` the value is (2*tp + eps) / (2*tp + fn + fp + eps),
    i.e. the Dice coefficient.

    Args:
        pr (torch.Tensor): predictions.
        gt (torch.Tensor): ground truth, same shape as ``pr``.
        beta (float): weight of false negatives relative to false positives.
        eps (float): epsilon to avoid zero division.
        threshold: if not None, predictions are binarized with
            ``pr > threshold`` after the activation.
        activation: ``None``/``"none"`` (use ``pr`` as is), ``"sigmoid"`` or
            ``"softmax2d"``; applied to ``pr`` before scoring.
    Returns:
        torch.Tensor: scalar F-beta score.
    Raises:
        NotImplementedError: for any other ``activation`` value.
    """
    if activation is None or activation == "none":
        pass  # predictions are used unchanged
    elif activation == "sigmoid":
        pr = torch.nn.Sigmoid()(pr)
    elif activation == "softmax2d":
        pr = torch.nn.Softmax2d()(pr)
    else:
        raise NotImplementedError(
            "Activation implemented for sigmoid and softmax2d"
        )

    if threshold is not None:
        pr = (pr > threshold).float()

    beta_sq = beta ** 2
    true_pos = torch.sum(gt * pr)
    false_pos = torch.sum(pr) - true_pos
    false_neg = torch.sum(gt) - true_pos

    numerator = (1 + beta_sq) * true_pos + eps
    denominator = (1 + beta_sq) * true_pos + beta_sq * false_neg + false_pos + eps
    return numerator / denominator


class DiceLoss(nn.Module):
    """Dice loss: one minus the ``beta=1`` ``f_score`` of prediction and target.

    Args:
        eps: smoothing term forwarded to ``f_score``.
        activation: activation name forwarded to ``f_score``.
    """
    __name__ = 'dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid'):
        super().__init__()
        self.eps = eps
        self.activation = activation

    def forward(self, y_pr, y_gt):
        # No thresholding, so the score stays differentiable.
        score = f_score(
            y_pr,
            y_gt,
            beta=1.,
            eps=self.eps,
            threshold=None,
            activation=self.activation,
        )
        return 1 - score


class BCEDiceLoss(DiceLoss):
    """Weighted sum of Dice loss and binary cross-entropy.

    Args:
        eps: smoothing term for the Dice part.
        activation: activation applied to the predictions by the Dice part.
            ``None`` or ``"none"`` means the predictions are already
            probabilities, so plain ``BCELoss`` is used; any other value
            selects ``BCEWithLogitsLoss``.
        lambda_dice: weight of the Dice term.
        lambda_bce: weight of the BCE term.
    """
    __name__ = 'bce_dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid', lambda_dice=1.0, lambda_bce=1.0):
        super().__init__(eps, activation)
        # Treat "none" like None, matching f_score: previously "none" skipped
        # the activation in the Dice term but still selected
        # BCEWithLogitsLoss, which applied a sigmoid to the same predictions.
        if activation is None or activation == "none":
            self.bce = nn.BCELoss(reduction='mean')
        else:
            # NOTE(review): BCEWithLogitsLoss applies a sigmoid, which only
            # matches the Dice term when activation == 'sigmoid'.
            self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.lambda_dice = lambda_dice
        self.lambda_bce = lambda_bce

    def forward(self, y_pr, y_gt):
        dice = super().forward(y_pr, y_gt)
        bce = self.bce(y_pr, y_gt)
        return (self.lambda_dice * dice) + (self.lambda_bce * bce)
    

def dice(img1, img2):
    """Dice coefficient of two binary masks.

    Both inputs are converted to boolean arrays (any non-zero value counts as
    foreground), so callers must binarize soft predictions beforehand.

    Args:
        img1: array-like mask.
        img2: array-like mask of the same shape.
    Returns:
        numpy float: 2 * |A & B| / (|A| + |B|); 1.0 when both masks are empty.
    """
    # The builtin `bool` replaces the `np.bool` alias, which was removed in
    # NumPy 1.24 and raises AttributeError there.
    img1 = np.asarray(img1).astype(bool)
    img2 = np.asarray(img2).astype(bool)

    total = img1.sum() + img2.sum()
    if total == 0:
        # Two empty masks agree perfectly; also avoids a 0/0 -> nan result.
        return np.float64(1.0)

    intersection = np.logical_and(img1, img2)

    return 2.0 * intersection.sum() / total


def dice_no_threshold(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: float = None,
):
    """Dice coefficient of ``outputs`` against ``targets``.

    ``outputs`` are used as given unless ``threshold`` is set, in which case
    they are first binarized with ``outputs > threshold``. ``eps`` is added to
    the denominator only.

    Reference:
    https://catalyst-team.github.io/catalyst/_modules/catalyst/dl/utils/criterion/dice.html
    """
    predictions = outputs if threshold is None else (outputs > threshold).float()

    overlap = torch.sum(targets * predictions)
    total_mass = torch.sum(targets) + torch.sum(predictions)

    return 2 * overlap / (total_mass + eps)

In [None]:
# activation=None: UNet.forward already returns sigmoid outputs, so the loss
# must not apply another activation.
criterion = BCEDiceLoss(eps=1.0, activation=None)
optimizer = Adam(model.parameters(), lr=0.005)
# Learning rate of the first parameter group.
current_lr = optimizer.param_groups[0]['lr']
# Multiply the learning rate by 0.2 when the monitored value stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.2,
    patience=2,
    cooldown=2,
)

In [None]:
# number of epochs to train the model
n_epochs = 4
train_loss_list = []
valid_loss_list = []
train_dice_list = []
val_dice_list = []
lr_rate_list = []

# Best validation loss seen so far. np.inf is the supported spelling; the
# np.Inf alias was removed in NumPy 2.0.
valid_loss_min = np.inf
for epoch in range(1, n_epochs+1):

    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    train_dice_score = 0.0
    val_dice_score = 0.0
    ###################
    # train the model #
    ###################
    model.train()
    bar = tqdm(train_loader, postfix={"train_loss":0.0})
    for data, target in bar:
        # Move tensors to the selected device (not a hard-coded .cuda()), so
        # the loop also runs on the CPU fallback chosen earlier.
        data, target = data.to(device), target.to(device)
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the batch loss
        loss = criterion(output, target)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # Accumulate sums weighted by batch size so the epoch averages below
        # stay correct when the last batch is smaller.
        train_loss += loss.item()*data.size(0)

        # Dice is computed on the device and only the scalar is transferred;
        # detach() keeps the metric out of the autograd graph.
        train_dice_cof = dice_no_threshold(output.detach(), target).item()
        train_dice_score +=  train_dice_cof * data.size(0)

        bar.set_postfix(ordered_dict={"train_loss":loss.item()})
    ######################
    # validate the model #
    ######################
    model.eval()
    # Drop the last training batch before validation starts.
    del data, target
    with torch.no_grad():
        # Initial postfix key matches the one updated inside the loop.
        bar = tqdm(test_loader, postfix={"valid_loss":0.0, "dice_score":0.0})
        for data, target in bar:
            data, target = data.to(device), target.to(device)
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # update the running validation loss
            valid_loss += loss.item()*data.size(0)

            val_dice_cof = dice_no_threshold(output, target).item()
            val_dice_score +=  val_dice_cof * data.size(0)

            bar.set_postfix(ordered_dict={"valid_loss":loss.item(), "dice_score":val_dice_cof})

    # calculate average losses
    train_loss = train_loss/len(train_loader.dataset)
    valid_loss = valid_loss/len(test_loader.dataset)
    train_dice_score = train_dice_score/len(train_loader.dataset)
    val_dice_score = val_dice_score/len(test_loader.dataset)
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)
    train_dice_list.append(train_dice_score)
    val_dice_list.append(val_dice_score)
    # One entry per epoch: the learning rate of every parameter group.
    lr_rate_list.append([param_group['lr'] for param_group in optimizer.param_groups])

    # print training/validation statistics
    print('Epoch: {}  Training Loss: {:.6f}  Validation Loss: {:.6f} Traning Dice Score: {:.6f} Valid Dice Score: {:.6f}'.format(
        epoch, train_loss, valid_loss, train_dice_score, val_dice_score))

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(model.state_dict(), 'model_capstone.pt')
        valid_loss_min = valid_loss

    # Let the scheduler react to this epoch's validation loss.
    scheduler.step(valid_loss)

In [None]:
figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

# plot the learning rate, the losses and the dice scores per epoch
ax[0].plot(lr_rate_list, marker='o', label="learning rate")
ax[1].plot(train_loss_list,  marker='o', label="Training Loss")
ax[1].plot(valid_loss_list,  marker='o', label="Validation Loss")
ax[2].plot(train_dice_list, marker='o', label="Training Dice")
ax[2].plot(val_dice_list, marker='o', label="Validation Dice")

# set the titles of the subplots
ax[0].set_title("Learning rate during training")
ax[1].set_title("Loss during training")
ax[2].set_title("Dice during training")

# add legend for loss and dice
ax[1].legend(loc='right')
ax[2].legend(loc='right')

# set the axis labels of the subplots; the y-labels differ per panel, so they
# are set once here instead of being re-applied on every loop iteration
for i in range(3):
    ax[i].set_xlabel('epoch')
ax[0].set_ylabel('learning rate')
ax[1].set_ylabel('loss')
ax[2].set_ylabel('dice')

# set the layout of the figure and display it
figure.tight_layout()
# plt.show() renders under the notebook's inline backend; figure.show() only
# warns there because the backend has no GUI event loop.
plt.show()

In [None]:
# load best model
# Load the checkpoint with the lowest validation loss. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU can
# also be loaded on a CPU-only machine.
model.load_state_dict(torch.load('model_capstone.pt', map_location=device))
model.eval();

In [None]:
# Inference only: no_grad() avoids building an autograd graph for the forward
# pass, so the result needs no detach() before being moved to the CPU.
with torch.no_grad():
    outputs = model(testImg.to(device)).cpu()
outputs.shape

In [None]:
def display_pred(image, label, pred, batch_num, threshold=0.5):
    """Plot an image, its ground-truth mask and the predicted mask side by side.

    Args:
        image: 4-D tensor of input images.
        label: 4-D tensor of ground-truth masks.
        pred: 4-D tensor of model outputs for the same batch.
        batch_num: index of the sample within the batch.
        threshold: cut-off used to binarize prediction and mask before the
            Dice score shown in the title is computed.
    """
    image = image[batch_num, 0, :, :].detach().cpu().numpy()
    label = label[batch_num, 0, :, :].detach().cpu().numpy()
    pred = pred[batch_num, 0, :, :].detach().cpu().numpy()

    # dice() casts its inputs to bool, so any non-zero value counts as
    # foreground. Binarize first; otherwise the sigmoid output is almost
    # entirely "foreground" and the score is meaningless.
    dice_score = float(dice(pred > threshold, label > threshold))

    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))

    # plot the original image, its mask, and the predicted mask
    ax[0].imshow(image, cmap='gray')
    ax[1].imshow(label, cmap='gray')
    ax[2].imshow(pred, cmap='gray')

    # set the titles of the subplots
    ax[0].set_title("Original Image")
    ax[1].set_title("Ground Truth Mask")
    ax[2].set_title("Predicted Mask (dice score: {:.4f})".format(dice_score))

    for i in range(3):
        ax[i].set_xticks([])
        ax[i].set_yticks([])

    # set the layout of the figure and display it
    figure.tight_layout()
    # plt.show() renders under the notebook's inline backend; figure.show()
    # only warns there because the backend has no GUI event loop.
    plt.show()

In [None]:
# Show up to three random samples of the test batch. The range starts at 0 so
# the first sample can be drawn too, and the sample size is capped so batches
# with fewer than three items do not raise ValueError.
num_examples = min(3, testImg.shape[0])
for i in random.sample(range(testImg.shape[0]), num_examples):
    display_pred(testImg, testLabels, outputs, i)