In [None]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate the Colab user, then build a PyDrive client that uses the
# application-default credentials.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# NOTE(review): the name `drive` is rebound by `from google.colab import drive`
# in the next cell, so this GoogleDrive client is no longer reachable after that
# cell runs — confirm whether the PyDrive client is needed at all.
drive = GoogleDrive(gauth)

In [None]:
# Mount Google Drive at /content/drive; the unzip cell below reads its archive
# from that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the BEH archive from the mounted Drive into the current working
# directory (-qq keeps unzip quiet).
# NOTE(review): later cells read /content/BEH/..., so this presumably runs with
# /content as the working directory — confirm. The FAU data read further down
# is not extracted by this cell.
!unzip -qq '/content/drive/My Drive/Colab Notebooks/Glaucoma detection/Data/BEH.zip'

In [None]:
# Install keras-unet directly from its GitHub repository.
# NOTE(review): the install is unpinned, so a re-run may pick up a different
# revision — consider pinning a tag or commit.
!pip install git+https://github.com/karolzak/keras-unet

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import glob
import os
import sys
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torchvision.utils import make_grid
from torch.utils.data import Dataset, random_split, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random
from torchvision.utils import make_grid
from PIL import Image
from keras_unet.utils import plot_imgs
from sklearn.model_selection import train_test_split
from keras_unet.models import custom_unet
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam, SGD
from keras_unet.metrics import iou, iou_thresholded
from keras_unet.losses import jaccard_distance
from keras_unet.utils import plot_imgs, plot_segm_history

In [None]:
# Load FAU dataset: images and their ground-truth masks, resized to size x size.

# glob.glob gives no ordering guarantee, so both listings are sorted before they
# are paired with zip(); unsorted lists can attach a mask to the wrong image.
# NOTE(review): this assumes image and mask file names sort in the same order
# (e.g. shared stems) — confirm against the FAU folder layout.
orgs = sorted(glob.glob("/content/FAU/training/original/*"))
masks = sorted(glob.glob("/content/FAU/training/mask/*"))
size = 512

# Fail early with a clear message instead of an obscure shape error later on.
if not orgs:
    raise FileNotFoundError("No FAU images found under /content/FAU/training/original/")
# zip() would silently drop the surplus files and hide a broken pairing.
if len(orgs) != len(masks):
    raise ValueError(
        "FAU image/mask count mismatch: %d images vs %d masks" % (len(orgs), len(masks))
    )

imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    with Image.open(image) as org_im:
        # Keep only channel index 1 of the resized image.
        imgs_list.append(np.array(org_im.resize((size, size)))[:, :, 1])
    with Image.open(mask) as mask_im:
        # Use the same `size` as the images rather than a second hard-coded 512.
        masks_list.append(np.array(mask_im.resize((size, size))))

imgs_np = np.asarray(imgs_list)
masks_np = np.asarray(masks_list)

print('Original Images:', imgs_np.shape, ' Ground Truth images:', masks_np.shape)
# plot_imgs(org_imgs=imgs_np, mask_imgs=masks_np, nm_img_to_plot=10, figsize=6)

In [None]:
def _load_channel_1(paths, side):
    """Open each path, resize it to side x side and keep channel index 1.

    Returns a list of 2-D arrays in the same order as `paths`.
    """
    channels = []
    for path in paths:
        with Image.open(path) as im:
            channels.append(np.array(im.resize((side, side)))[:, :, 1])
    return channels


# Sorted listings make the image order reproducible: a later cell names the
# saved files by their index in this list, and glob.glob gives no ordering
# guarantee.
dataset_glaucoma = sorted(glob.glob("/content/BEH/Train/glaucoma/*.jpg"))
dataset_normal = sorted(glob.glob("/content/BEH/Train/normal/*.jpg"))

# Glaucoma images come first, then normal ones; the cell that saves the
# predictions relies on this layout via len(dataset_glaucoma).
dataset = _load_channel_1(dataset_glaucoma, size) + _load_channel_1(dataset_normal, size)

dataset_np = np.asarray(dataset)
# Scale to float32 by 1/255 and append a trailing channel axis.
dataset_x = np.asarray(dataset_np, dtype=np.float32)/255
dataset_x = dataset_x.reshape(dataset_x.shape[0], dataset_x.shape[1], dataset_x.shape[2], 1)
print('Dataset:', dataset_x.shape)
# NOTE(review): masks_np holds the FAU masks, not masks for these BEH images, so
# the overlay shown here pairs unrelated images and masks — confirm this is
# only meant as a quick look at the BEH images.
plot_imgs(org_imgs=dataset_np, mask_imgs=masks_np, nm_img_to_plot=10, figsize=6)

In [None]:
# Convert images and masks to float32, scale them by 1/255 and give each a
# trailing channel axis; value ranges and shapes are printed before and after.
print(imgs_np.max(), masks_np.max())
x = imgs_np.astype(np.float32) / 255
y = masks_np.astype(np.float32) / 255
print(x.max(), y.max())
print(x.shape, y.shape)
y = y.reshape(y.shape[0], y.shape[1], y.shape[2], 1)
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)
print(x.shape, y.shape)

# Hold out 10% of the pairs for validation, with a fixed seed for the split.
x_train, x_val, y_train, y_val = train_test_split(
    x, y, test_size=0.1, random_state=0
)

for label, split in (
    ("x_train: ", x_train),
    ("y_train: ", y_train),
    ("x_val: ", x_val),
    ("y_val: ", y_val),
):
    print(label, split.shape)

from keras_unet.utils import get_augmented, plot_imgs

# Augmentation settings handed to get_augmented through data_gen_args.
augmentation_args = {
    'rotation_range': 5.,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 40,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'vertical_flip': False,
    'fill_mode': 'constant',
}

train_gen = get_augmented(
    x_train,
    y_train,
    batch_size=8,
    data_gen_args=augmentation_args,
)

# Draw one batch to check its shapes and look at a few augmented pairs.
sample_batch = next(train_gen)
xx, yy = sample_batch
print(xx.shape, yy.shape)

# Plot Dataset and Masks
plot_imgs(org_imgs=xx, mask_imgs=yy, nm_img_to_plot=3, figsize=6)

# Initialize network: a 4-level custom_unet sized to a single training sample.
input_shape = x_train[0].shape

model = custom_unet(
    input_shape,
    num_layers=4,
    filters=32,
    dropout=0.3,
    dropout_change_per_layer=0.0,
    use_batch_norm=True,
)

# Checkpoint to model_filename, keeping only the best model by 'val_loss'.
model_filename = 'segm_model_v3.h5'
callback_checkpoint = ModelCheckpoint(
    model_filename,
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

# Alternatives left disabled by the author:
#   optimizer=SGD(lr=0.01, momentum=0.99)
#   loss=jaccard_distance
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(),
    metrics=[iou, iou_thresholded],
)


In [None]:
# Train from the augmenting generator and validate on the held-out arrays.
# Model.fit accepts generators directly; Model.fit_generator is deprecated and
# no longer exists in current Keras releases.
# steps_per_epoch is required because the number of batches per epoch is given
# here rather than derived from the generator.
history = model.fit(
    train_gen,
    steps_per_epoch=200,
    epochs=3,
    validation_data=(x_val, y_val),
    callbacks=[callback_checkpoint]
)

In [None]:
# Plot the training history returned by the fit call above.
plot_segm_history(history)

In [None]:
import torch

# Segment the validation split using the weights saved by the checkpoint callback.
model.load_weights(model_filename)
y_pred = model.predict(x_val)
# Plot while the predictions are still channels-last, like x_val and y_val;
# moving the channel axis first hands plot_imgs arrays with mismatched layouts.
# The dataset section below uses the same order: plot, then move the axis.
plot_imgs(org_imgs=x_val, mask_imgs=y_val, pred_imgs=y_pred, nm_img_to_plot=8)
y_pred = np.moveaxis(y_pred, -1, 1)

# Segment dataset
dataset_y_pred = model.predict(dataset_x)
plot_imgs(org_imgs=dataset_x, mask_imgs=dataset_y_pred, pred_imgs=dataset_y_pred, nm_img_to_plot=8)
# Move the channel axis to position 1 (channels-first) for the torch tensor below.
# NOTE(review): dataset_x is overwritten here, so re-running this cell without
# re-running the loading cell feeds channels-first data to model.predict.
dataset_x = np.moveaxis(dataset_x, -1, 1)
dataset_y_pred = np.moveaxis(dataset_y_pred, -1, 1)
print(dataset_x.shape, dataset_y_pred.shape)

# The saving cell below reads the predictions through `x`.
x = torch.Tensor(dataset_y_pred)

In [None]:
from torchvision.utils import save_image
from pathlib import Path


def _save_masks(preds, indices, out_dir):
    """Write preds[i][0] as '<i>_BEH.jpg' into out_dir for every i in indices.

    The directory is created first (parents included), because nothing else in
    the notebook creates it and saving into a missing directory fails.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i in indices:
        save_image(preds[i][0], out_dir.joinpath(str(i) + '_BEH.jpg'), padding=0)


# Predictions are ordered glaucoma first, then normal, so the first
# len(dataset_glaucoma) entries go to the glaucoma folder and the rest to the
# normal folder. File names keep the global index i.
n_glaucoma = len(dataset_glaucoma)
_save_masks(x, range(n_glaucoma), '/content/ORIGA_af/glaucoma')
_save_masks(x, range(n_glaucoma, len(x)), '/content/ORIGA_af/normal')

In [None]:
# Zip segmented dataset
# NOTE(review): this archives /content/BEH/, while the cell above writes the
# segmented images to /content/ORIGA_af/ — confirm which directory is meant.
# -j stores bare file names, so the glaucoma/normal sub-folders are not kept in
# the archive.
!zip -r -j BEH '/content/BEH/'