In [None]:
# # setup for the colab

# import os
# os.environ['KAGGLE_USERNAME'] = "kirillfedyanin"
# os.environ['KAGGLE_KEY'] = ""
# !pip install imageio
# !pip install keras 
# !pip install kaggle

# !kaggle competitions download -c tgs-salt-identification-challenge
# !mkdir -p test
# !mkdir -p train
# !unzip test.zip -d test
# !unzip train.zip -d train

In [None]:
import os
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import imageio
from skimage.transform import resize

In [None]:
# Every image / mask in this dataset is 101 x 101 pixels.
HEIGHT = 101
WIDTH = 101

In [None]:
root_path = './'
train_path = os.path.join(root_path, "train")


def _read_train_png(subdir, file_id):
    """Read ``<train_path>/<subdir>/<file_id>.png`` and return it as a uint8 numpy array."""
    png_path = os.path.join(train_path, subdir, file_id + '.png')
    return np.array(imageio.imread(png_path), dtype=np.uint8)


def get_image(file_id):
    """Load the training image for `file_id` from ``train/images`` as a uint8 array.

    The array keeps whatever channels the PNG file has.
    """
    return _read_train_png("images", file_id)


def get_mask(file_id):
    """Load the training mask for `file_id` from ``train/masks`` as a uint8 array."""
    return _read_train_png("masks", file_id)


In [None]:
# Training ids plus their RLE-encoded masks. Resolve the CSV relative to
# root_path, the same base directory used for the image/mask PNGs.
train_values = pd.read_csv(os.path.join(root_path, 'train.csv'))
file_list = train_values['id'].tolist()

# Research
Exploratory data analysis: check that the RLE masks match the PNG masks, and look at depths and salt coverage.

In [None]:
def rle_to_mask(rle_string, height=None, width=None):
    """Decode a run-length-encoded mask string into a 2-D array.

    The string is a whitespace-separated sequence of ``start length`` pairs.
    ``start`` is 1-based, and pixels are numbered column by column
    (top-to-bottom, then left-to-right). Every pixel covered by a run is set
    to 255; all other pixels are 0.

    Args:
        rle_string: The encoded mask. A missing value (``None`` / NaN) or an
            empty string means "no salt" and yields an all-zero mask.
        height: Number of output rows; defaults to ``HEIGHT``.
        width: Number of output columns; defaults to ``WIDTH``.

    Returns:
        float64 array of shape ``(height, width)`` with values 0 or 255.

    Raises:
        ValueError: If the string does not contain an even number of integers.
    """
    height = HEIGHT if height is None else height
    width = WIDTH if width is None else width

    # pandas reads empty rle_mask cells as float NaN.
    if rle_string is None or (
        isinstance(rle_string, (float, np.floating)) and np.isnan(rle_string)
    ):
        return np.zeros((height, width))

    rle_numbers = [int(num) for num in rle_string.split()]
    if len(rle_numbers) % 2:
        raise ValueError(
            "RLE string must contain start/length pairs, got %d numbers" % len(rle_numbers)
        )

    mask = np.zeros(height * width)
    for start, length in zip(rle_numbers[0::2], rle_numbers[1::2]):
        mask[start - 1:start - 1 + length] = 255

    # Column-major numbering: flat index = col * height + row. order='F' encodes
    # exactly that, and stays correct for non-square masks.
    return mask.reshape((height, width), order='F')


In [None]:
# Spot-check 15 random training samples: the decoded RLE string must match the
# PNG mask pixel-for-pixel (this catches transposed / mis-oriented decoding).
for _ in range(15):
    # file_list is built from train_values['id'], so row i lines up in both.
    i = random.randrange(len(file_list))
    file_id = file_list[i]
    image, mask = get_image(file_id), get_mask(file_id)
    rle_mask = rle_to_mask(train_values['rle_mask'].iloc[i])
    _, axarr = plt.subplots(1, 3)
    axarr[0].imshow(image)
    axarr[0].set_title('image')
    axarr[1].imshow(mask, cmap='gray')
    axarr[1].set_title('PNG mask')
    axarr[2].imshow(rle_mask, cmap='gray')
    axarr[2].set_title('RLE mask')
    print(i, 'is correct: ', (mask == rle_mask).all())
    

In [None]:
depths = pd.read_csv(os.path.join(root_path, "depths.csv"))

# Look depths up by id. depths.csv is keyed by `id` and is not guaranteed to be
# in train.csv's row order, so assigning depths['z'] directly (which aligns on
# the row index) could pair images with the wrong depths.
train_values['depths'] = train_values['id'].map(depths.set_index('id')['z'])

plt.figure(figsize=(6, 6))
plt.hist(train_values['depths'], bins=50)
plt.title("Depth distribution of training images")
plt.xlabel("depth (z)")
plt.ylabel("count");


In [None]:
# Largest possible pixel sum of a HEIGHT x WIDTH mask whose salt pixels are 255.
# salt_concentration no longer needs it; it is kept in case other cells read it.
norm = HEIGHT * WIDTH * 255.0

def salt_concentration(mask):
    """Return the fraction of salt pixels in `mask`, in [0, 1].

    Assumes salt pixels are 255 and background pixels are 0. The sum is
    normalised by ``mask.size`` rather than the global HEIGHT * WIDTH, so the
    result is correct for masks of any shape. For HEIGHT x WIDTH masks it is
    identical to the old normalisation.
    """
    return np.sum(mask) / (mask.size * 255.0)

masks = [get_mask(file_id) for file_id in train_values['id']]
train_values['salt_concentration'] = [salt_concentration(mask) for mask in masks]

In [None]:
# Attach each training id's depth (`z`) from depths.csv. Join explicitly on `id`
# so that any other column the two frames happen to share never becomes a join
# key, and fail loudly if depths.csv lists an id more than once.
train_val = train_values.merge(depths, on='id', how='left', validate='many_to_one')

In [None]:
# Plot the `z` column from the merge with depths.csv: it is the depth matched
# to each training image by id.
plt.figure(figsize=(12, 6))
plt.scatter(train_val['salt_concentration'], train_val['z'])
plt.title("Depths vs salt concentration")
plt.xlabel("salt concentration")
plt.ylabel("depth (z)");

# Model training

**TODO**
- add dropout
- save the trained model
- post-process predictions
- add U-Net skip connections


In [None]:
from keras.layers import Input, Dense, Conv2D, UpSampling2D, MaxPooling2D, concatenate, ZeroPadding2D, Cropping2D
from keras.models import Model

def _conv_pair(x, filters, **conv_kwargs):
    """Apply two stacked 3x3 Conv2D layers with `filters` output channels each."""
    x = Conv2D(filters, (3, 3), **conv_kwargs)(x)
    return Conv2D(filters, (3, 3), **conv_kwargs)(x)


def salt_detector(input_shape=(HEIGHT, WIDTH, 1)):
    """Build and compile a convolutional encoder-decoder for per-pixel salt masks.

    The input is zero-padded on the bottom/right up to the next power of two
    (at least 8), which is 101 -> 128 for the default shape. This makes three
    2x pool / 2x upsample stages line up exactly. The sigmoid output is then
    cropped back to the input size.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.

    Returns:
        A compiled keras ``Model`` mapping an image batch to per-pixel salt
        probabilities of shape ``(height, width, 1)``.
    """
    common_atr = {'activation': 'relu', 'padding': 'same'}

    height, width = input_shape[0], input_shape[1]
    # Smallest power of two >= max(size, 8); always divisible by 2**3.
    pad_h = (1 << (max(height, 8) - 1).bit_length()) - height
    pad_w = (1 << (max(width, 8) - 1).bit_length()) - width

    input_image = Input(shape=input_shape)
    x = ZeroPadding2D(((0, pad_h), (0, pad_w)))(input_image)

    # Encoder: 16 -> 32 -> 64 filters, halving the resolution after each pair.
    for filters in (16, 32, 64):
        x = _conv_pair(x, filters, **common_atr)
        x = MaxPooling2D((2, 2), padding='same')(x)

    # Decoder: mirror of the encoder, doubling the resolution after each pair.
    for filters in (64, 32, 16):
        x = _conv_pair(x, filters, **common_atr)
        x = UpSampling2D((2, 2))(x)

    x = _conv_pair(x, 16, **common_atr)
    decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
    decoded_cropped = Cropping2D(((0, pad_h), (0, pad_w)))(decoded)

    autoencoder = Model(input_image, decoded_cropped)
    # NOTE(review): Adadelta's default learning rate differs between standalone
    # Keras 2.x (1.0) and tf.keras (0.001); confirm which applies here, because
    # 0.001 trains very slowly.
    autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

    return autoencoder

detector = salt_detector()
detector.summary()


In [None]:
def prepare(images):
    """Stack images into a float32 batch of shape ``(N, H, W, 1)`` divided by 255.

    Args:
        images: Sequence of equally-sized uint8 arrays, each either ``(H, W)``
            or ``(H, W, C)``. For multi-channel inputs only the first channel
            is kept.

    Returns:
        float32 array of shape ``(N, H, W, 1)``.
    """
    batch = np.stack(images)
    if batch.ndim == 3:
        # Single-channel (H, W) inputs: add the channel axis the model expects.
        batch = batch[:, :, :, np.newaxis]
    return batch[:, :, :, :1].astype('float32') / 255.

In [None]:
# Masks come back 2-D, so give them an explicit trailing channel axis before
# batching; images already carry their own channel axis.
labels = prepare([get_mask(fid)[..., np.newaxis] for fid in file_list])
images = prepare([get_image(fid) for fid in file_list])

In [None]:
# Hold out the last `val_size` samples (in train.csv order) for validation.
# The split point is computed explicitly because `x[:-0]` is empty, which
# would leave no training data when val_size is 0. It is clamped at 0 for
# val_size > len(images).
val_size = 512
assert len(images) == len(labels), "images and labels must be row-aligned"
split_at = max(len(images) - val_size, 0)
images_train, images_val = images[:split_at], images[split_at:]
labels_train, labels_val = labels[:split_at], labels[split_at:]

In [None]:
detector.fit(
    images_train,
    labels_train,
    validation_data=(images_val, labels_val),
    epochs=10,
    batch_size=64,
    shuffle=True,
)

In [None]:
# Visual spot-check on held-out samples. The first rows of `images` belong to
# the training split, so showing them would overstate how well the model
# generalises. Predict the whole batch in one call instead of one at a time.
n_show = min(10, len(images_val))
predictions = detector.predict(images_val[:n_show])
for image, label, prediction in zip(images_val[:n_show], labels_val[:n_show], predictions):
    _, axarr = plt.subplots(1, 3)
    axarr[0].imshow(image[:, :, 0], cmap='gray')
    axarr[0].set_title('image')
    axarr[1].imshow(label[:, :, 0], cmap='gray', vmin=0, vmax=1)
    axarr[1].set_title('label')
    axarr[2].imshow(prediction[:, :, 0], cmap='gray', vmin=0, vmax=1)
    axarr[2].set_title('prediction')
