## Clinoform Identification Using Image Segmentation (U-Net)

In [1]:
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.nn import functional as F
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import cv2
import glob
%matplotlib inline
import matplotlib.pyplot as plt
import re
import os
import random
from random import randint

In [2]:
def sp_noise(image, prob):
    '''
    Apply salt-and-pepper noise to an image.

    For every (row, col) position one ``random.random()`` value is drawn, in
    row-major order. A draw below ``prob`` sets the position to 0 (pepper), a
    draw above ``1 - prob`` sets it to 255 (salt), and anything else keeps the
    original value. Pepper wins if the two ranges overlap (prob > 0.5).

    Args:
        image: array indexed as image[row][col]; its values are copied into a
            new uint8 array.
        prob: probability of each of the pepper and salt outcomes.

    Returns:
        A new ``np.uint8`` array with the same shape as ``image``.
    '''
    noisy = np.zeros(image.shape, np.uint8)
    salt_threshold = 1 - prob
    n_rows, n_cols = image.shape[0], image.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            draw = random.random()
            if draw < prob:
                value = 0
            elif draw > salt_threshold:
                value = 255
            else:
                value = image[r][c]
            noisy[r][c] = value
    return noisy


def gaussian_noise(image, mean=0, sigma=1):
    '''
    Add Gaussian noise to an image.

    Draws i.i.d. samples from N(mean, sigma**2) with ``np.random.normal`` (the
    global NumPy RNG) and adds them element-wise. For 2-D input with the
    default arguments, the draws are identical to the previous
    ``np.random.normal(0, 1, (row, col))`` call.

    Args:
        image: numeric array of any shape.
        mean: mean of the noise (default 0).
        sigma: standard deviation of the noise (default 1).

    Returns:
        A new floating-point array ``image + noise`` with the same shape as
        ``image``. Values are not clipped, so they may fall outside the
        input's range (e.g. below 0 or above 255 for uint8 images).
    '''
    noise = np.random.normal(mean, sigma, image.shape)
    return image + noise

Extend the `torch.utils.data.Dataset` abstract class to load seismic images and their clinoform masks:

In [3]:
class SeismicData(data.Dataset):
    """Seismic image / clinoform-mask dataset with on-the-fly augmentation.

    Every file in ``image_path`` is one sample. Its mask is looked up in
    ``mask_path`` as ``<ID>_mask.png``, where ``<ID>`` is the first
    ``Line<digits>`` token found in the image path.

    Args:
        image_path (str): directory containing the seismic images.
        mask_path (str): directory containing the masks.
        crop_width (int): images and masks are cropped to their first
            ``crop_width`` columns (default 1216, a multiple of 16, so the
            four 2x2 max-pools of the U-Net divide it evenly).
        augment (bool): when True (default), each sample is randomly flipped
            and/or has noise added every time it is loaded.
    """

    # Raw string: "\d" in a normal string literal is an invalid escape sequence.
    _ID_PATTERN = re.compile(r"Line\d+")

    def __init__(self, image_path, mask_path, crop_width=1216, augment=True):
        self.image_path = image_path
        self.mask_path = mask_path
        self.crop_width = crop_width
        self.augment = augment
        # glob returns files in arbitrary order; sort so index -> file is reproducible.
        self.mask_arr = sorted(glob.glob(os.path.join(str(mask_path), "*")))
        self.image_arr = sorted(glob.glob(os.path.join(str(image_path), "*")))
        # Samples are indexed through image_arr, so its length is the dataset size
        # (previously the mask count, which breaks if the two folders differ).
        self.data_len = len(self.image_arr)

    @staticmethod
    def _read_gray(path):
        """Read an image with OpenCV and convert it from BGR to a 2-D grayscale array."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image file {path!r}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    @staticmethod
    def _augment(img, mask):
        """Randomly flip image and mask together, then maybe add noise to the image only."""
        if randint(0, 1) == 1:
            # Horizontal flip (flipCode=1) applied to both, so labels stay aligned.
            img = cv2.flip(img, 1)
            mask = cv2.flip(mask, 1)

        add_noise = randint(0, 2)
        if add_noise == 1:
            img = gaussian_noise(img)
        elif add_noise == 2:
            img = sp_noise(img, 0.05)
        return img, mask

    def __getitem__(self, index):
        """Load, crop, (optionally) augment and tensorize sample ``index``.

        Args:
            index: position in the sorted list of image files.
        Returns:
            (image, mask) tensors of shape (1, H, W) with W <= crop_width.
            The image is int32 holding raw pixel values (plus truncated noise
            when augmented). The mask is int64 holding ``mask / 255``
            truncated, i.e. class indices 0/1 for 0/255 masks.
        Raises:
            ValueError: if the image path contains no ``Line<digits>`` ID.
            FileNotFoundError: if the image or its mask cannot be read.
        """
        single_image_name = self.image_arr[index]

        match = self._ID_PATTERN.search(single_image_name)
        if match is None:
            raise ValueError(f"No 'Line<digits>' ID found in image path {single_image_name!r}")
        img_id = match.group(0)
        single_mask_name = os.path.join(self.mask_path, f"{img_id}_mask.png")

        img = self._read_gray(single_image_name)
        mask = self._read_gray(single_mask_name)

        # Crop (not resize) to the first crop_width columns; the height is unchanged.
        img = img[:, :self.crop_width]
        mask = mask[:, :self.crop_width]

        # Masks are stored as 0/255; map them to 0/1.
        mask = mask / 255

        if self.augment:
            img, mask = self._augment(img, mask)

        # (H, W) -> (1, H, W): add the channel axis expected by Conv2d.
        img_as_tensor = torch.from_numpy(img).int().unsqueeze(0)
        # int64 class indices, as required by nn.CrossEntropyLoss targets and tensor indexing.
        mask_as_tensor = torch.from_numpy(mask).long().unsqueeze(0)
        return (img_as_tensor, mask_as_tensor)

    def __len__(self):
        return self.data_len


Define the training and test datasets

In [4]:
# Both splits share the layout ./seismic/<split>/images/ and ./seismic/<split>/masks/.
train_dataset, test_dataset = (
    SeismicData(mask_path='./seismic/train/masks/', image_path='./seismic/train/images/'),
    SeismicData(mask_path='./seismic/test/masks/', image_path='./seismic/test/images/'),
)

Create the data loaders. Data augmentation (random flips and noise) is applied on the fly each time a sample is drawn.

In [5]:
# Training loader: mini-batches of 5. Test loader: one sample at a time.
# Both reshuffle their dataset at the start of every pass.
train_data_load, test_data_load = (
    DataLoader(train_dataset, batch_size=5, shuffle=True),
    DataLoader(test_dataset, batch_size=1, shuffle=True),
)

### Build the Model (U-Net)

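#### Define the loss function (Jaccard / IoU loss)

Note: the training loop below minimizes `nn.CrossEntropyLoss`; this Jaccard loss is defined here as an alternative objective.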

In [6]:
def jaccard_loss(true, logits, eps=1e-7):
    """Compute the soft Jaccard (IoU) loss, ``1 - mean soft IoU``.

    PyTorch optimizers minimize, so instead of maximizing the soft IoU this
    returns ``1 - IoU`` averaged over classes; 0 means perfect overlap.

    Args:
        true: integer class-index tensor of shape [B, H, W] or [B, 1, H, W].
        logits: tensor of shape [B, C, H, W] with the raw model outputs.
            With C == 1 the channel is the logit of the positive class
            (sigmoid); with C > 1 a softmax over the channel axis is used.
        eps: added to the denominator for numerical stability.

    Returns:
        Scalar tensor: the Jaccard loss.
    """
    num_classes = logits.shape[1]
    # Index with int64 labels into a one-hot table living on the logits' device.
    labels = true.squeeze(1).long()
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1, device=logits.device)[labels]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        # Reorder channels to [positive, negative] so they line up with probas below.
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes, device=logits.device)[labels]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        # Softmax over the class axis of the logits (previously referenced undefined `probas`).
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    # Sum over batch and all spatial axes, keeping the class axis. Taken from the
    # 4-D one-hot tensor so [B, H, W] and [B, 1, H, W] targets reduce identically.
    dims = (0,) + tuple(range(2, true_1_hot.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    jacc_index = (intersection / (union + eps)).mean()
    return (1 - jacc_index)

#### Build the Neural Network

In [7]:
class U_net(nn.Module):
    """U-Net for single-channel images producing 2-class per-pixel logits.

    The encoder has four (conv-ReLU-conv-ReLU, 2x2 max-pool) levels with 32, 64,
    128 and 256 filters, followed by a 512-filter bottleneck. The decoder mirrors
    it with 2x2 transposed convolutions and concatenated skip connections.

    Input:  float tensor of shape (B, 1, H, W).
    Output: raw logits of shape (B, 2, H', W'). H' == H and W' == W when H and W
            are divisible by 16. Otherwise H'/W' can be slightly smaller, because
            each skip connection is center-cropped to the upsampled feature map.
    """

    def __init__(self):
        super().__init__()
        double_conv = self._double_conv

        ### Downward (contracting) path: conv block then 2x2 max-pool, doubling filters.
        self.conv1_block = double_conv(1, 32)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_block = double_conv(32, 64)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_block = double_conv(64, 128)
        self.max3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4_block = double_conv(128, 256)
        self.max4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottom of the network.
        self.conv5_block = double_conv(256, 512)

        ### Upward (expanding) path: each 2x2 transposed conv halves the channels and
        ### doubles H/W. After concatenating the skip, the conv block sees twice the channels.
        self.up_1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_up_1 = double_conv(512, 256)
        self.up_2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_up_2 = double_conv(256, 128)
        self.up_3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_up_3 = double_conv(128, 64)
        self.up_4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv_up_4 = double_conv(64, 32)

        # Final 1x1 conv: 32 features -> 2 class logits per pixel.
        self.conv_final = nn.Conv2d(32, 2, kernel_size=1, padding=0, stride=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return the (3x3 conv, ReLU, 3x3 conv, ReLU) stack used at every level."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _center_crop(skip, target):
        """Center-crop ``skip`` spatially to the height and width of ``target``."""
        height, width = target.shape[2], target.shape[3]
        top = (skip.shape[2] - height) // 2
        left = (skip.shape[3] - width) // 2
        # End = start + target size (the previous `dim - start` was off by one for odd gaps).
        return skip[:, :, top:top + height, left:left + width]

    def _up_step(self, x, skip, up, conv):
        """Upsample ``x``, concatenate the cropped skip after it on channels, then convolve."""
        x = up(x)
        x = torch.cat([x, self._center_crop(skip, x)], dim=1)
        return conv(x)

    def forward(self, x):
        # Encoder: keep each block's output for the matching skip connection.
        skip1 = self.conv1_block(x)
        skip2 = self.conv2_block(self.max1(skip1))
        skip3 = self.conv3_block(self.max2(skip2))
        skip4 = self.conv4_block(self.max3(skip3))

        # Bottom of the network.
        x = self.conv5_block(self.max4(skip4))

        # Decoder, deepest level first.
        x = self._up_step(x, skip4, self.up_1, self.conv_up_1)
        x = self._up_step(x, skip3, self.up_2, self.conv_up_2)
        x = self._up_step(x, skip2, self.up_3, self.conv_up_3)
        x = self._up_step(x, skip1, self.up_4, self.conv_up_4)

        return self.conv_final(x)
        

   



#### Train the Neural Network

In [13]:
# Load model
model = U_net()
model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.module.parameters(), lr=0.001)

In [None]:
# Train for a fixed number of epochs; augmentation is re-sampled each time a batch is drawn.
N_EPOCHS = 2
model.train()
# Move inputs and targets to wherever the model's parameters live (CPU or first GPU).
model_device = next(model.parameters()).device

for epoch in range(N_EPOCHS):
    running_loss = 0.0
    # Loop variable is `batch`, not `data`: `data` is the torch.utils.data alias from the imports.
    for i, batch in enumerate(train_data_load):
        inputs, labels = batch
        optimizer.zero_grad()

        outputs = model(inputs.float().to(model_device))
        # CrossEntropyLoss expects int64 class indices of shape [B, H, W], on the same
        # device as the outputs; the dataset yields masks of shape [B, 1, H, W].
        targets = labels.squeeze(1).long().to(outputs.device)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    n_batches = max(len(train_data_load), 1)
    print(f"epoch {epoch + 1}/{N_EPOCHS}: mean training loss {running_loss / n_batches:.4f}")