# Completing mid-ocean ridge fault detection using semi-supervised learning

![](content/image_faults.jpg)

In [None]:
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import holoviews as hv
hv.extension('bokeh')
from skimage.morphology import binary_dilation
from utils import dataSplit
from model import unet, mean_iou
from ipywidgets import interact, widgets

## Given dataset

In [None]:
# Load the bathymetry grid and the boolean fault mask from their MATLAB files.
bat = scipy.io.loadmat("Bathy.mat")['IMG2']
fault = scipy.io.loadmat("Fault_Bool.mat")['Fault_Bool']
# Thicken the labelled fault traces with a 10x10 square footprint. The
# footprint is passed positionally because scikit-image renamed the `selem`
# keyword to `footprint` (0.19) and later dropped the old name; the positional
# form works on both old and new releases.
fault = binary_dilation(fault, np.ones((10, 10)))

In [None]:
# Global statistics of the bathymetry grid, ignoring NaN cells.
bat_mean = np.nanmean(bat)
bat_std = np.nanstd(bat)

# Cut the grid and the fault mask into tiles with the project helper
# (presumably 256x256 patches with a trailing channel axis -- confirm against
# utils.dataSplit). An untouched copy of every tile is kept in x_all / y_all.
tile_size = 256
x = dataSplit(bat, tile_size)
y = dataSplit(fault, tile_size)
x_all = np.copy(x)
y_all = np.copy(y)

# Training subset, step 1: keep a tile only if at most half of its
# tile_size**2 cells are NaN.
nan_count = np.isnan(x).sum(axis=(1, 2, 3))
idx = nan_count <= (tile_size**2) / 2
x, y = x[idx], y[idx]

# Training subset, step 2: keep a tile only if its mean label value exceeds
# 0.05, i.e. more than 5 % of its cells are marked as fault.
y_mean = y.mean(axis=(1, 2, 3))
enough_faults = y_mean > 0.05
x, y = x[enough_faults], y[enough_faults]

# Standardise every training tile with its own NaN-aware mean and standard
# deviation (computed over axes 1 and 2), then zero whatever is still NaN.
x_mean = np.nanmean(x, axis=(1, 2), keepdims=True)
x_std = np.nanstd(x, axis=(1, 2), keepdims=True)
x = (x - x_mean) / x_std
np.putmask(x, np.isnan(x), 0)

# Same per-tile standardisation for the complete, unfiltered tile set.
x_all_mean = np.nanmean(x_all, axis=(1, 2), keepdims=True)
x_all_std = np.nanstd(x_all, axis=(1, 2), keepdims=True)
x_all = (x_all - x_all_mean) / x_all_std
np.putmask(x_all, np.isnan(x_all), 0)

In [None]:
@interact(index=widgets.IntSlider(min=0, max=len(x) - 1, step=1, value=2))
def display(index):
    """Plot one training tile, its fault labels and the bathymetry histogram.

    Args:
        index: Position of the tile in ``x`` / ``y``. The slider stops at
            ``len(x) - 1`` so that every selectable value is a valid index
            (an upper bound of ``len(x)`` would raise ``IndexError`` on the
            last slider position).
    """
    fig, [ax1, ax2, ax3] = plt.subplots(nrows=1, ncols=3, figsize=(18, 5))

    # Normalised tile; the colour scale is clipped to +/- 2 standard deviations.
    cm1 = ax1.pcolormesh(np.squeeze(x[index]), vmin=-2, vmax=2)
    fig.colorbar(cm1, ax=ax1)
    ax1.set_title("Sea levels")

    # Label mask of the same tile.
    cm2 = ax2.pcolormesh(np.squeeze(y[index]), vmin=0, vmax=1)
    fig.colorbar(cm2, ax=ax2)
    ax2.set_title("Labeled fails")

    # Histogram of the whole bathymetry grid; it does not depend on `index`.
    ax3.hist(bat.ravel())
    ax3.set_title("Sea levels distribution")

    plt.show()

## Training the model

![](content/u-net-architecture.png)

In [None]:
# One-channel 256x256 input, matching the tile size used during preprocessing.
input_shape = (256, 256, 1)

# Build the network from the project's `model` module and compile it with a
# binary cross-entropy loss, tracking the project's mean_iou metric.
model = unet(input_shape=input_shape)
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=[mean_iou],
)

# Fit silently on the filtered training tiles, with a 0.2 validation split.
fit_options = dict(epochs=30, batch_size=4, validation_split=0.2, verbose=0)
model.fit(x, y, **fit_options)

# Predict on every tile, including those excluded from training.
y_pred = model.predict(x_all)

## Evaluating predictions

In [None]:
# Stitch the per-tile predictions back into a single 2-D map. The flat tile
# axis is read as an 8 x 15 grid of 256 x 256 patches; grid axis 1 is laid out
# along the rows and grid axis 0 along the columns, giving a
# (15 * 256, 8 * 256) array.
# NOTE(review): the 8 x 15 grid is hard-coded; presumably it matches the tile
# order produced by dataSplit for this dataset -- confirm before reusing.
grid_a, grid_b, tile = 8, 15, 256
y_pred_all = (
    y_pred.reshape(grid_a, grid_b, tile, tile)
    .transpose(1, 2, 0, 3)
    .reshape(grid_b * tile, grid_a * tile)
)

In [None]:
# Stitch the label tiles back into a single 2-D map. The flat tile axis is
# read as an 8 x 15 grid of 256 x 256 patches; grid axis 1 is laid out along
# the rows and grid axis 0 along the columns, giving a (15 * 256, 8 * 256)
# array.
# NOTE(review): the 8 x 15 grid is hard-coded; presumably it matches the tile
# order produced by dataSplit for this dataset -- confirm before reusing.
grid_a, grid_b, tile = 8, 15, 256
y_all = (
    y_all.reshape(grid_a, grid_b, tile, tile)
    .transpose(1, 2, 0, 3)
    .reshape(grid_b * tile, grid_a * tile)
)

In [None]:
# Show the full bathymetry grid, transposed, as a HoloViews image.
hv.Image(np.transpose(bat))

In [None]:
# Side-by-side layout: stitched labels, raw predictions, and predictions
# thresholded at 0.5.
labels_map = np.squeeze(y_all).T
pred_map = np.squeeze(y_pred_all).T
labels_img = hv.Image(labels_map)
pred_img = hv.Image(pred_map)
pred_mask_img = hv.Image(pred_map > 0.5)
labels_img + pred_img + pred_mask_img