In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import imageio

import torch
from torch.utils import data

In [None]:
# Base directory for the input data; later cells join it with 'train.csv',
# 'depths.csv' and the 'train' folder.
root = '../input/'

In [None]:
# Print the names of the entries found directly under the data directory.
print(os.listdir(root))

In [None]:
class TGSSaltDataset(data.Dataset):
    """Dataset yielding ``(image, mask)`` array pairs loaded from PNG files.

    For an entry ``name`` of ``files`` the image is read from
    ``<root>/images/<name>.png`` and the mask from ``<root>/masks/<name>.png``.
    Both are converted to ``uint8`` NumPy arrays.
    """

    def __init__(self, root, files):
        self.root, self.files = root, files

    def __len__(self):
        """Number of entries in ``files``."""
        return len(self.files)

    def _read_png(self, folder, name):
        """Read ``<root>/<folder>/<name>.png`` and return it as a uint8 array."""
        path = os.path.join(self.root, folder, '{}.png'.format(name))
        return np.array(imageio.imread(path), dtype=np.uint8)

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair for the ``i``-th entry of ``files``."""
        name = self.files[i]
        # The image is read before the mask.
        image = self._read_png('images', name)
        mask = self._read_png('masks', name)
        return image, mask

In [None]:
# Read train.csv from the data directory into a DataFrame; later cells use its
# 'id' and 'rle_mask' columns.
train_masks = pd.read_csv(os.path.join(root, 'train.csv'))
# Show the first rows as the cell output.
train_masks.head()

In [None]:
# Read depths.csv from the data directory; it is merged into train_masks in a later cell.
# NOTE(review): the 'z' column plotted at the end presumably comes from this file -- confirm.
depths = pd.read_csv(os.path.join(root, 'depths.csv'))
# Show the first rows as the cell output.
depths.head()

In [None]:
# Collect the 'id' column as a plain list; TGSSaltDataset formats each entry
# into a '<id>.png' file name.
files = list(train_masks['id'].values)
# Dataset rooted at <root>/train, where TGSSaltDataset looks for 'images' and
# 'masks' subfolders.
dataset = TGSSaltDataset(os.path.join(root, 'train'), files)

In [None]:
def plot2x2array(image, mask):
    """Draw ``image`` and ``mask`` in a one-row, two-column figure.

    The first axes shows ``image`` titled 'Image', the second shows ``mask``
    titled 'Mask'. ``grid()`` is called once on each axes. Nothing is returned.
    """
    _, axes = plt.subplots(1, 2)

    panels = ((image, 'Image'), (mask, 'Mask'))
    for ax, (pixels, title) in zip(axes, panels):
        ax.imshow(pixels)
        ax.grid()
        ax.set_title(title)

In [None]:
# Preview five randomly chosen (image, mask) pairs; the same index may be drawn more than once.
for _ in range(5):
    sample_index = np.random.randint(0, len(dataset))
    image, mask = dataset[sample_index]
    plot2x2array(image, mask)

In [None]:
# Histogram of the depths frame using 50 bins.
ax = depths.plot(kind='hist', bins=50)
ax.set_title('Distribution of Depth');

In [None]:
def rle_to_mask(rle, width, height):
    """Decode a run-length-encoded string into a 2-D ``uint8`` mask.

    ``rle`` is a whitespace-separated list of ``start length`` pairs. Starts
    are 1-based positions into the flattened mask, and the flattened mask is
    laid out column by column (which is why the buffer is reshaped to
    ``(width, height)`` and then transposed).

    Parameters
    ----------
    rle : str or None or float
        The encoded runs. ``None``, ``NaN`` (what pandas yields for an empty
        CSV cell) or a blank string all mean "no mask".
    width, height : int
        Mask dimensions in pixels.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(height, width)`` and dtype ``uint8`` holding 255
        inside the encoded runs and 0 elsewhere.

    Raises
    ------
    ValueError
        If ``rle`` is a non-blank string that is malformed: a non-integer
        token, an odd number of tokens, or a run outside the mask.
    """
    # Only the "no mask" case is treated as empty; anything else that fails to
    # parse is reported instead of being silently turned into an empty mask.
    if not isinstance(rle, str) or not rle.strip():
        return np.zeros((height, width), dtype=np.uint8)

    parts = [int(part) for part in rle.split()]
    if len(parts) % 2 != 0:
        raise ValueError('RLE must contain an even number of values, got {}'.format(len(parts)))

    flat = np.zeros(width * height, dtype=np.uint8)

    for start, length in zip(parts[0::2], parts[1::2]):
        begin = start - 1  # RLE start positions are 1-based
        end = begin + length
        if begin < 0 or length < 0 or end > flat.size:
            raise ValueError(
                'RLE run ({}, {}) does not fit a {}x{} mask'.format(start, length, width, height)
            )
        flat[begin:end] = 255

    # Column-major layout: rows of the reshaped buffer are image columns, so
    # transpose to get (height, width). transpose() returns a new view, hence
    # its result must be the value that is returned.
    return flat.reshape(width, height).transpose()

In [None]:
def salty(image):
    """Return the fraction of pixels in ``image`` that are non-zero.

    For a mask that is 0 for background and non-zero for salt, this is the
    salt coverage: 0.0 for an empty mask and 1.0 for a fully covered one. The
    pixel count is taken from the array itself, so any mask size works.

    Parameters
    ----------
    image : array-like or None
        The mask to score.

    Returns
    -------
    float
        Coverage in ``[0.0, 1.0]``; ``0.0`` when ``image`` is ``None`` or has
        no elements.
    """
    if image is None:
        return 0.0

    pixels = np.asarray(image)
    if pixels.size == 0:
        return 0.0

    # Count the non-zero pixels directly. The count of the smallest value
    # present would score an all-zero and an all-salt mask identically.
    return np.count_nonzero(pixels) / float(pixels.size)

In [None]:
# Decode each 'rle_mask' entry with rle_to_mask(width=101, height=101) and keep
# the result in a new 'mask' column.
train_masks['mask'] = train_masks['rle_mask'].apply(rle_to_mask, args=(101, 101))
# Score every decoded mask with salty() and store it in a new 'salty' column.
train_masks['salty'] = train_masks['mask'].apply(salty)

In [None]:
# Left-join depths onto train_masks using the columns the two frames share,
# keeping every row of train_masks.
merged = pd.merge(train_masks, depths, how='left')
# Show the first rows as the cell output.
merged.head()

In [None]:
# Scatter plot of the 'salty' column (x axis) against the 'z' column (y axis).
ax = merged.plot(kind='scatter', x='salty', y='z')