# Imports

In [None]:
import numpy as np 
import pandas as pd
import os
import cv2

import matplotlib.pyplot as plt
from matplotlib import patches as patches

# Config

In [None]:
# Cloud-pattern class names; the plotting helpers index this positionally
# alongside `colors` (colors[i] is drawn for class_names[i]).
class_names = [
    "Fish",
    "Flower",
    "Gravel",
    "Sugar",
]
colors = [
    'Blue',
    'Red',
    'Gray',
    'Purple',
]

In [None]:
# Dataset root and its image sub-directories, kept as plain strings with a
# trailing slash so later code can append file names directly.
path = '../input/understanding_cloud_organization/'
path_train = f'{path}train_images/'
path_test = f'{path}test_images/'

# Image and mask functions

In [None]:
def np_resize(img, input_shape):
    """
    Resize ``img`` to ``input_shape``, given numpy-style as (height, width).

    cv2.resize takes its target size as (width, height); this wrapper swaps
    the order so callers can keep using the (rows, cols) convention.
    """
    target_h, target_w = input_shape
    cv2_size = (target_w, target_h)
    return cv2.resize(img, cv2_size)

def rle2mask(rle, input_shape, *, one_indexed=True):
    """
    Decode a run-length-encoded string into a binary mask.

    Parameters
    ----------
    rle : str
        Space-separated "start length start length ..." pairs. Pixels are
        numbered in column-major order (down each column, then across), and
        the competition's encoding numbers them starting from 1.
    input_shape : tuple
        Output shape; only the first two entries (rows, cols) are used.
    one_indexed : bool, keyword-only
        Treat start positions as 1-based (default, matching the dataset's
        RLE format). Pass False for strings that use 0-based starts.

    Returns
    -------
    np.ndarray
        uint8 array of shape ``input_shape[:2]``: 1 inside the runs, 0 elsewhere.
        An empty ``rle`` yields an all-zero mask.
    """
    n_rows, n_cols = input_shape[:2]
    mask = np.zeros(n_rows * n_cols, dtype=np.uint8)

    values = np.asarray([int(x) for x in rle.split()], dtype=np.int64)
    # Convert 1-based starts to 0-based flat indices.
    starts = values[0::2] - (1 if one_indexed else 0)
    ends = starts + values[1::2]

    for lo, hi in zip(starts, ends):
        mask[lo:hi] = 1

    # Fill column-major (Fortran order): flat index k -> (k % rows, k // rows).
    return mask.reshape((n_rows, n_cols), order='F')

def build_masks(rles, input_shape, reshape=None):
    """
    Stack one decoded mask per RLE entry into a float array of channels.

    Parameters
    ----------
    rles : sequence
        One entry per output channel. String entries are decoded with
        ``rle2mask``; any other value (e.g. NaN for a row without a mask)
        leaves that channel all zeros.
    input_shape : tuple
        Shape the RLE strings were encoded against; passed to ``rle2mask``.
    reshape : tuple, optional
        If given, each decoded mask is resized to this (height, width) with
        ``np_resize``, and the output uses this spatial shape.

    Returns
    -------
    np.ndarray
        float64 array of shape ``(*input_shape, len(rles))``, or
        ``(*reshape, len(rles))`` when ``reshape`` is given.
    """
    out_shape = input_shape if reshape is None else reshape
    masks = np.zeros((*out_shape, len(rles)))

    for i, rle in enumerate(rles):
        # isinstance also accepts np.str_ (e.g. from a unicode ndarray), which
        # an exact ``type(rle) is str`` check would silently skip.
        if not isinstance(rle, str):
            continue
        mask = rle2mask(rle, input_shape)
        if reshape is not None:
            mask = np_resize(mask, reshape)
        masks[:, :, i] = mask

    return masks

def get_rles(rles_df, ID):
    """
    Return the ``EncodedPixels`` values for every row whose ``ImageId``
    equals ``ID``, as an ndarray in the frame's row order.
    """
    matches_id = rles_df['ImageId'] == ID
    return rles_df.loc[matches_id, 'EncodedPixels'].values

# Visualization functions

In [None]:
def bounding_box(img):
    """
    Return the inclusive (row_min, row_max, col_min, col_max) extent of the
    truthy pixels in a 2-D array.

    Raises IndexError when ``img`` contains no truthy pixel.
    """
    occupied_rows = np.flatnonzero(np.any(img, axis=1))
    occupied_cols = np.flatnonzero(np.any(img, axis=0))
    return (
        occupied_rows[0],
        occupied_rows[-1],
        occupied_cols[0],
        occupied_cols[-1],
    )

def plot_cloud(img_id):
    """
    Show a training image with a labelled bounding box around each connected
    region of each class mask.

    The image is read from ``path_train`` and resized to 525x350 (w x h). Its
    RLE strings are looked up in the module-level ``train_df``, decoded at
    (1400, 2100) and resized to the same (350, 525) grid as the image.

    Raises FileNotFoundError if the image cannot be read.
    """
    img_path = os.path.join(path_train, img_id)
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(img_path)
    # cv2 loads channels as BGR; matplotlib's imshow expects RGB.
    img = cv2.cvtColor(cv2.resize(img, (525, 350)), cv2.COLOR_BGR2RGB)

    rles = get_rles(train_df, img_id)
    masks = build_masks(rles, (1400, 2100), (350, 525))

    _, ax = plt.subplots(figsize=(8, 4))
    ax.imshow(img)
    ax.axis('off')  # once per figure, even when the image has no masks

    # Morphological closing (dilate, then erode) joins nearby fragments so
    # they are boxed as one region.
    kernel = np.ones((4, 4), np.uint8)
    for i, label in enumerate(class_names):
        mask = masks[..., i]
        if not mask.any():
            continue
        mask = cv2.dilate(mask, kernel, iterations=2)
        mask = cv2.erode(mask, kernel, iterations=2)
        num_component, component = cv2.connectedComponents(mask.astype(np.uint8))

        # Label 0 is the background; 1..num_component-1 are the regions.
        for j in range(1, num_component):
            rmin, rmax, cmin, cmax = bounding_box(component == j)
            bbox = patches.Rectangle(
                (cmin, rmin), cmax - cmin, rmax - rmin,
                linewidth=3, edgecolor=colors[i], facecolor='none',
            )
            ax.add_patch(bbox)
            ax.text(cmin, rmin, label, bbox=dict(fill=True, color=colors[i]))
                
def plot_simple(img_id):
    """
    Show a training image next to one panel per class mask.

    The image is read from ``path_train`` and resized to 525x350 (w x h). Its
    RLE strings are looked up in the module-level ``train_df``, decoded at
    (1400, 2100) and resized to the same (350, 525) grid. Each mask panel is
    titled with the matching entry of ``class_names``.

    Raises FileNotFoundError if the image cannot be read.
    """
    img_path = os.path.join(path_train, img_id)
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(img_path)
    # cv2 loads channels as BGR; matplotlib's imshow expects RGB.
    img = cv2.cvtColor(cv2.resize(img, (525, 350)), cv2.COLOR_BGR2RGB)

    rles = get_rles(train_df, img_id)
    masks = build_masks(rles, (1400, 2100), (350, 525))

    _, axes = plt.subplots(nrows=1, ncols=1 + len(class_names), sharey=True, figsize=(20, 4))
    axes[0].imshow(img)
    axes[0].axis('off')

    for i, (ax, name) in enumerate(zip(axes[1:], class_names)):
        ax.imshow(masks[..., i])
        ax.set_title(name)
        ax.axis('off')

# Make some visualizations

## Load data

In [None]:
train_df = pd.read_csv(path + 'train.csv')
# Image_Label is split on '_': the part before it becomes ImageId, the part
# after it becomes ClassId.
train_df['ImageId'] = train_df['Image_Label'].map(lambda label: label.split('_')[0])
train_df['ClassId'] = train_df['Image_Label'].map(lambda label: label.split('_')[1])
# True where EncodedPixels is present (not NaN), i.e. the row has an RLE mask.
train_df['hasMask'] = train_df['EncodedPixels'].notna()

print(train_df.shape)
train_df.head()

## Simple visualization

In [None]:
# Plot the first 30 distinct images. Show each figure immediately so figures
# do not accumulate; matplotlib warns once more than 20 are open at a time.
for image_id in train_df['ImageId'].unique()[:30]:
    plot_simple(image_id)
    plt.show()

## Mask box visualization

In [None]:
# Plot the first 30 distinct images. Show each figure immediately so figures
# do not accumulate; matplotlib warns once more than 20 are open at a time.
for image_id in train_df['ImageId'].unique()[:30]:
    plot_cloud(image_id)
    plt.show()