In [None]:
import matplotlib.pyplot as plt
from PIL import Image

import random
from pathlib import Path

# Loading the dataset

In [None]:
# Load the DRIVE dataset
dataset = Path('data/DRIVE')

# Load the training dataset
train_images = sorted(dataset.glob('training/images/*.tif'))
train_labels = sorted(dataset.glob('training/1st_manual/*.gif'))
train_mask = sorted(dataset.glob('training/mask/*.gif'))

In [None]:
len(train_images), train_images[:3], train_labels[:3], train_mask[:3]

In [None]:
# Load the test dataset
test_images = sorted(dataset.glob('test/images/*.tif'))
test_mask = sorted(dataset.glob('test/mask/*.gif'))
len(test_images), test_images[:3], test_mask[:3]

# Visualizing the dataset

In [None]:
# Display a sample image, mask and label
image = Image.open(train_images[0])
mask = Image.open(train_mask[0])
label = Image.open(train_labels[0])

plt.figure(figsize=(12, 8))
plt.subplot(131)
plt.imshow(image)
plt.title(f'Image {image.size}, {image.mode}')
plt.subplot(132)
plt.imshow(mask, cmap='gray')
plt.title(f'Mask {mask.size}, {mask.mode}')
plt.subplot(133)
plt.imshow(label, cmap='gray')
plt.title(f'Label {label.size}, {label.mode}')
plt.show()


In [None]:
# Displaying the image in R,G,B channels
red, green, blue = image.split()

plt.figure(figsize=(16, 8))
plt.subplot(141)
plt.imshow(image)
plt.subplot(142)
plt.imshow(red)
plt.title('red')
plt.subplot(143)
plt.imshow(green)
plt.title('green')
plt.subplot(144)
plt.imshow(blue)
plt.title('blue')
plt.show()

# Implementing the preprocessing method

In [None]:
# Preprocessing training images

# 1. take only the green channel
# 2. apply morphological opening with a three-pixel diameter disk structuring element
# 3. The local background gray level is computed by applying a 69×69 mean filter to the image. The
# background is then subtracted and the resulting gray levels are scaled from 0 to 1.
# 4. a constant is added to the image gray levels so the mode gray level value in image is set to 0.5
# 5.  top-hat transformation on the complement of the image using an eight-pixel radius
# disk as the structuring element
# 

import numpy as np
from skimage import morphology
from skimage import exposure
from skimage import filters
from skimage import img_as_float
from skimage import transform

def preprocess(image, mask):
    # Convert the image and mask to float32 tensors
    image = img_as_float(image)
    mask = img_as_float(mask)
    
    # Take only the green channel
    image = image[:, :, 1]
    
    # Apply morphological opening with a 3-pixel disk structuring element
    selem = morphology.disk(3)
    image = morphology.opening(image, selem)
    
    # Compute the local mean of the image
    local_mean = filters.rank.mean(image, selem)
    
    # Subtract the local mean from the image
    image = image - local_mean
    
    # Scale the image so that its values range from 0 to 1
    image = exposure.rescale_intensity(image)
    
    # Add a constant to the image so that its minimum value is 0
    image = image - image.min()
    
    # Normalize the image so its values sum to 1
    image = image / image.sum()
    
    # Apply a top-hat transformation to the image
    selem = morphology.disk(8)
    image = morphology.white_tophat(image, selem)
    
    # Normalize the image so its values sum to 1
    image = image / image.sum()
    
    # Apply the mask to the image
    image = image * mask
    
    # # Convert the image and the mask to PyTorch tensors
    # image = torch.from_numpy(image).unsqueeze(0)
    # mask = torch.from_numpy(mask).unsqueeze(0)
    
    return image

In [None]:
p_image = preprocess(image, mask)
print(p_image.shape)

plt.figure(figsize=(16, 8))
plt.subplot(121)
plt.title(f'Preprocessed image {p_image.shape}')
plt.imshow(p_image, cmap='gray')
plt.subplot(122)
plt.title(f'ground truth {label.size}')
plt.imshow(label, cmap='gray')
plt.show()

In [None]:
# Feature extraction from preprocessed image
from numpy.lib.stride_tricks import sliding_window_view

windows = sliding_window_view(p_image, (9, 9))
windows.shape



In [None]:
label_windows = sliding_window_view(label, (9, 9))
label_windows.shape

In [None]:
# visualizing the windows randomly

plt.figure(figsize=(80, 16))
for i in range(40):
    xidx = random.randint(0, windows.shape[0])
    yidx = random.randint(0, windows.shape[1])

    if i < 20:
        plt.subplot(4, 20, i+1)
        plt.imshow(windows[xidx, yidx], cmap='gray')

        plt.subplot(4, 20, i+21)
        plt.imshow(label_windows[xidx, yidx], cmap='gray')

        if i == 0:
            plt.title(f'window {xidx}, {yidx}')
            plt.title(f'label {xidx}, {yidx}')

    if i >= 20:
        plt.subplot(4, 20, i-20+41)
        plt.imshow(windows[xidx, yidx], cmap='gray')

        plt.subplot(4, 20, i-20+61)
        plt.imshow(label_windows[xidx, yidx], cmap='gray')

        if i == 20:
            plt.title(f'window {xidx}, {yidx}')
            plt.title(f'label {xidx}, {yidx}')

plt.show()

# Preparing the dataset

In [None]:
# Generating the features of the image
# for each image in the dataset, implement the below steps
# 1. take the image and mask and preprocess the image
# 2. Compute the following features on the preprocessed image, in a 9x9 window around each pixel in the image.
#   a. intensity of center pixel
#  b. absolute difference between the intensity of center pixel and min, max and mean intensity of the window
# c. standard deviation of the window

def extract_features(p_image):
    # Compute the window
    windows = sliding_window_view(p_image, (9, 9)).copy()

    # pixel centers
    centers = windows[:, :, 4, 4]

    # min, max, mean, and standard deviation of each window
    w_min, w_max, w_mean, w_std = windows.min(axis=(2, 3)), windows.max(axis=(2, 3)), windows.mean(axis=(2, 3)), windows.std(axis=(2, 3))

    # features
    features = np.stack([centers, np.abs(centers - w_min), np.abs(centers - w_max), np.abs(centers - w_mean), w_std], axis=2)
    
    return features


In [None]:
def extract_labels(label):
    # Compute the window
    windows = sliding_window_view(label, (9, 9)).copy()

    # pixel centers
    centers = windows[:, :, 4, 4]

    # labels
    labels = (centers > 0).astype(np.uint8)
    
    return labels

In [None]:
label_values = extract_labels(label)
label.size, label_values.shape

In [None]:
for image_path, mask_path, label_path in zip(train_images, train_mask, train_labels):
    print(image_path, mask_path, label_path)

In [None]:
train_features = []

for image_path, mask_path in zip(train_images, train_mask):
    image = Image.open(image_path)
    mask = Image.open(mask_path)

    p_image = preprocess(image, mask)
    print(p_image.shape)
    features = extract_features(p_image)
    print(features.shape)
    train_features.append(features)
    print(f'Extracted {features.shape} features from {image_path.name}')


In [None]:
train_y = []
for label_path in train_labels:
    labels = extract_labels(label)
    print(labels.shape)
    train_y.append(labels)
    print(f'Extracted {labels.shape} labels from {label_path.name}')

In [None]:
train_set = np.stack(train_features), np.stack(train_y)
train_set[0].shape, train_set[1].shape

In [None]:
# for all images in test set, extract the features
test_features = []
for image_path, mask_path in zip(test_images, test_mask):
    image = Image.open(image_path)
    mask = Image.open(mask_path)

    p_image = preprocess(image, mask)
    features = extract_features(p_image)
    test_features.append(features)
    print(f'Extracted {features.shape} features from {image_path.name}')

test_set = np.stack(test_features)
test_set.shape

In [None]:
# Save and load the train_set and test_set

np.savez_compressed('data/DRIVE/train_set.npz', X=train_set[0], y=train_set[1])
np.savez_compressed('data/DRIVE/test_set.npz', X=test_set)

# Sampling postive and negative examples from train_set

In [None]:
# load the train_set and test_set
train_set = np.load('data/DRIVE/train_set.npz', allow_pickle=True)
test_set = np.load('data/DRIVE/test_set.npz', allow_pickle=True)

In [None]:
train_set['X'].shape, train_set['y'].shape

In [None]:
np.unique(train_set['y'], return_counts=True)

In [None]:
plt.hist(train_set['y'].ravel())

In [None]:
# Positive samples
pos_samples = train_set['X'][train_set['y'] == 1]

In [None]:
# Variance of negative samples

neg_samples = train_set['X'][train_set['y'] == 0]
neg_samples.shape


In [None]:
# plotting the std of negative samples
plt.figure(figsize=(12, 8))
plt.hist(neg_samples[:, -1].ravel(), bins=100)
plt.title('std of negative samples')
plt.show()

In [None]:
# sampling the negative samples equal to the number of positive samples
sample_neg_idx = np.random.choice(np.arange(neg_samples.shape[0]), size=train_set['y'].sum())
sample_neg = neg_samples[sample_neg_idx]
sample_neg.shape


In [None]:
# plotting the std of negative samples
plt.figure(figsize=(12, 8))
plt.hist(sample_neg[:, -1].ravel(), bins=100)
plt.title('std of negative samples')
plt.show()

In [None]:
# balanced dataset = positive samples + negative samples
balanced_x = np.concatenate([pos_samples, sample_neg], axis=0)
balanced_y = np.concatenate([np.ones(pos_samples.shape[0]), np.zeros(sample_neg.shape[0])], axis=0)
balanced_x.shape, balanced_y.shape

In [None]:
# feature scaling of balanced_x
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
balanced_x = scaler.fit_transform(balanced_x.reshape(-1, 5)).reshape(balanced_x.shape)
balanced_x.shape

In [None]:
balanced_y.sum(), np.unique(balanced_y, return_counts=True) 

In [None]:
np.savez_compressed('data/DRIVE/balanced_train_set.npz', X=balanced_x, y=balanced_y)

# Training the Neural Network using MLP

In [31]:
import torch
import numpy as np

In [32]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = 'cpu'

In [33]:
# Pytorch dataloader for balanced data
from torch.utils.data import Dataset

class DRIVE(Dataset):
    def __init__(self, M=30000):
        super().__init__()
        balanced_set = np.load('data/DRIVE/balanced_train_set.npz')
        samples = np.random.choice(np.arange(balanced_set['X'].shape[0]), size=M)
        self.X = balanced_set['X'][samples].astype(np.float32)
        self.y = balanced_set['y'].reshape(-1,1)[samples].astype(np.float32)

        print('Loaded the dataset', self.X.shape, self.X.dtype, self.y.shape, self.y.dtype)
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [34]:
# MLP with 5 inputs, three hidden layers with 15 nodes each, and one output

import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, input_size=5, hidden_size=[15,15,15], output_size=1, dropout=0.5):
        super().__init__()
        self.p_dropout = dropout
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        for i in range(1, len(hidden_size)):
            setattr(self, f'fc{i+1}', nn.Linear(hidden_size[i-1], hidden_size[i]))
        self.fc4 = nn.Linear(hidden_size[-1], output_size)
    
    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = F.relu6(self.fc1(x))
        x = F.relu6(self.fc2(x))
        x = F.relu6(self.fc3(x))
        x = torch.sigmoid(self.fc4(x))
        return x

In [35]:
model = MLP().to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.BCELoss()

print(model)

MLP(
  (fc1): Linear(in_features=5, out_features=15, bias=True)
  (fc2): Linear(in_features=15, out_features=15, bias=True)
  (fc3): Linear(in_features=15, out_features=15, bias=True)
  (fc4): Linear(in_features=15, out_features=1, bias=True)
)


In [36]:
from torch.utils.data import DataLoader

epochs = 10
BATCH_SIZE = 1024
lr = 0.01

train_set = DRIVE()
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

model.train()
for epoch in range(epochs):
    losses = []
    for batch_num, input_data in enumerate(train_loader):
        optimizer.zero_grad()
        x, y = input_data
        x = x.to(device).float()
        y = y.to(device)

        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        losses.append(loss.item())

        optimizer.step()

        if batch_num % 40 == 0:
            print('\tEpoch %d | Batch %d | Loss %6.2f' % (epoch, batch_num, loss.item()))
    print('Epoch %d | Loss %6.2f' % (epoch, sum(losses)/len(losses)))

Loaded the dataset (30000, 5) float32 (30000, 1) float32
	Epoch 0 | Batch 0 | Loss   0.69
Epoch 0 | Loss   0.69
	Epoch 1 | Batch 0 | Loss   0.68
Epoch 1 | Loss   0.68
	Epoch 2 | Batch 0 | Loss   0.67
Epoch 2 | Loss   0.66
	Epoch 3 | Batch 0 | Loss   0.66
Epoch 3 | Loss   0.64
	Epoch 4 | Batch 0 | Loss   0.64
Epoch 4 | Loss   0.63
	Epoch 5 | Batch 0 | Loss   0.62
Epoch 5 | Loss   0.62
	Epoch 6 | Batch 0 | Loss   0.60
Epoch 6 | Loss   0.61
	Epoch 7 | Batch 0 | Loss   0.63
Epoch 7 | Loss   0.60
	Epoch 8 | Batch 0 | Loss   0.60
Epoch 8 | Loss   0.60
	Epoch 9 | Batch 0 | Loss   0.59
Epoch 9 | Loss   0.59
