# Preprocess mask

Merge the original labels into 10 categories (for now). The number in parentheses is the new class id:
- hard coral (1)
- hard coral bleached (2)
- dead coral (3)
- other invertebrates (4)
- sand/rubble (5)
- other (6)
- algae (7)
- seagrass (8)
- unknown (9)
- no label (0)

## Imports

In [1]:
# load custom scripts
from preprocess_mask import *

# import the necessary packages
from imutils import paths
from skimage import io
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.express as px
import time
import warnings
warnings.filterwarnings('ignore')

## Load masks

In [2]:
# Load the raw (pre-merge) mask filepaths in sorted order. sorted() accepts any
# iterable, so wrapping the result of paths.list_images in list() is redundant.
pre_label_dir = "/data/jantina/CoralNet/dataset/pre_labels_dense"
maskPaths = sorted(paths.list_images(pre_label_dir))
# A wrong or empty directory would otherwise make the merge loop below a silent no-op.
if not maskPaths:
    raise FileNotFoundError(f"no mask images found under {pre_label_dir}")

## New label classes

In [4]:
# Merged class id -> class name. Ids 1-9 are inserted first and the "no label"
# id 0 last; the label-count cell pairs classes.values() positionally with the
# per-class label lists, so this insertion order matters.
classes = dict(enumerate(['hard coral',
                          'hard coral bleached',
                          'dead coral',
                          'other invertebrates',
                          'sand/rubble',
                          'other',
                          'algae',
                          'seagrass',
                          'unknown'], start=1))
classes[0] = 'no label'

### Merging labels

In [5]:
# Original label values grouped by the merged class they fall into.
# Together the groups cover every value 1-207 exactly once, plus 255 for
# "no label". NOTE(review): these lists are not passed to merge_mask below;
# presumably preprocess_mask holds the same mapping -- confirm they agree.
# Contiguous runs are spelled with range() unpacking; element order is ascending.
hard_coral = [
    1, 3, 5, 7, 8, 9, 13, 16, 19, 29, 32, 34, 36, 40, 42,
    46, 49, 52, 54, 55, 57, 59, 64, 66, 70, 74, 76, 81,
    83, 89, 90, 91, 93, 95, 97, 99, 100, 101, 103, 105,
    106, 108, 110, 111, 113, 114, 122, 135, 136, 138,
    *range(140, 145), 148, 150, *range(155, 184),
]

hard_coral_bleached = [
    2, 4, 6, 10, 14, 15, 17, 30, 33, 41, 43, 47, 50,
    53, 56, 58, 60, 65, 67, 71, 75, 77, 82, 92, 94,
    96, 98, 102, 104, 107, 109, 112, 115, 123, 139,
]

dead_coral = [37, 38, 116, 117, 152]

other_invertebrates = [
    11, 12, 18, 28, 44, 45, 63, 68, 69, 72, 73, 80,
    88, 119, 124, 125, 127, 128, 129, 132, 134, 137,
    *range(184, 197), 199, 200,
]

sand_rubble = [79, 118, 120, 145, 149, 151, 201]

other = [35, 48, 51, 86, 126, 153]

algae = [
    *range(20, 28), 31, 39,
    61, 62, 78, 84, 87, 121, 130, 131, 133,
    146, 147, 154, 197, 198, 202, 203,
]

seagrass = list(range(204, 208))

unknown = [85]

no_label = [255]

## Creating new masks and saving to disk

In [6]:
# Merge every raw label mask into the new class scheme and write it to disk
# under the same file name. merge_mask comes from the project's preprocess_mask
# module (star-imported in the first cell).
merged_mask_dir = '/data/jantina/CoralNet/dataset/labels_dense'
# cv2.imwrite does not raise when the target directory is missing -- it just
# returns False -- so create the directory up front and check every write.
os.makedirs(merged_mask_dir, exist_ok=True)

# perf_counter is monotonic, so the elapsed time is immune to clock changes.
startTime = time.perf_counter()

for masks in maskPaths:  # `masks` is a single mask file path
    mask = io.imread(masks)
    new_mask = merge_mask(mask)
    maskPath = os.path.join(merged_mask_dir, os.path.basename(masks))
    if not cv2.imwrite(maskPath, new_mask):
        raise OSError(f"failed to write merged mask to {maskPath}")

endTime = time.perf_counter()
print("[INFO] total time taken to write the new masks: {:.2f}s".format(endTime - startTime))

[INFO] total time taken to write the new masks: 1775.07s


## Class representation

### Number of labels per class

In [None]:
# Number of original labels folded into each merged class. The groups are
# listed in the insertion order of `classes` (ids 1-9, then 0), so each count
# lines up with the class name at the same position.
labels_per_class = {
    'name': [*classes.values()],
    'number of labels': [len(group) for group in (hard_coral, hard_coral_bleached,
                                                  dead_coral, other_invertebrates,
                                                  sand_rubble, other, algae,
                                                  seagrass, unknown, no_label)],
}

In [None]:
# Bar chart: how many original labels each merged class absorbed.
fig = px.bar(data_frame=pd.DataFrame(labels_per_class),
             x='name',
             y='number of labels')
fig.show()

### Number of pixels per class

In [None]:
# Count how many pixels of each label value occur across all masks in the
# small label set. Rows are gathered in a plain list and turned into a
# DataFrame once: DataFrame.append was removed in pandas 2.0, and re-copying
# the frame on every iteration was quadratic in the number of masks anyway.
# NOTE(review): this reads dataset/small/labels, not the labels_dense output
# written above -- presumably a subset holding merged labels; confirm.
pixel_count_rows = []

new_maskPaths = sorted(paths.list_images("/data/jantina/CoralNet/dataset/small/labels"))
for masks in new_maskPaths:  # `masks` is a single mask file path
    mask = io.imread(masks)
    unique, counts = np.unique(mask, return_counts=True)
    # one (label value, pixel count) row per distinct value in this mask
    pixel_count_rows.extend(zip(unique.tolist(), counts.tolist()))

# Explicit columns keep the frame well-formed ('label', 'pixel count') even
# when no masks are found, so grouping by 'label' still works.
count = pd.DataFrame(pixel_count_rows, columns=['label', 'pixel count'])
    

In [None]:
# Total pixel count per label value across all masks: a two-column frame
# ('label', 'pixel count') sorted by label, with a fresh 0..n-1 index.
df = count.groupby('label')['pixel count'].sum().reset_index()

In [None]:
# Replace the numeric label ids with their class names for plotting.
# A single column assignment replaces the previous chained
# `df.label[df.label == old] = new` pattern, which writes through an
# intermediate Series: pandas warns about it (warnings are silenced in this
# notebook), and under Copy-on-Write it leaves `df` unchanged.
# Ids that are not keys of `classes` are left as they are.
df['label'] = df['label'].replace(classes)

In [None]:
# Bar chart of the total number of pixels per class.
fig = px.bar(data_frame=df,
             x='label',
             y='pixel count')
fig.show()

## Class weighting

In [None]:
# Inverse-frequency class weights: 1 / (total pixels of each class).
# NOTE(review): the result is indexed by row position in `df`, not by class id;
# if some class never occurs in the masks, positions stop matching class ids --
# confirm against how these weights are consumed.
weights = df["pixel count"].rdiv(1.)
weights

## Expand ground-truth (GT) pixels to 3x3 patches and make the masks boolean

In [None]:
# Expand each ground-truth mask with the project's expand_pixels helper
# (from preprocess_mask) and write the result under the same file name into
# small/masks_3x3.
maskPaths = sorted(paths.list_images("/data/jantina/CoralNet/dataset/small/masks"))
masks_3x3_dir = '/data/jantina/CoralNet/dataset/small/masks_3x3'
# cv2.imwrite does not raise when the target directory is missing -- it just
# returns False -- so create the directory up front and check every write.
os.makedirs(masks_3x3_dir, exist_ok=True)

# perf_counter is monotonic, so the elapsed time is immune to clock changes.
startTime = time.perf_counter()

for masks in maskPaths:  # `masks` is a single mask file path
    mask = io.imread(masks)
    # expand_pixels receives the file path as well as the pixel data.
    new_mask = expand_pixels(mask, masks)
    maskPath = os.path.join(masks_3x3_dir, os.path.basename(masks))
    # NOTE(review): the section heading says the masks become boolean; OpenCV's
    # imwrite does not accept numpy bool arrays, so expand_pixels presumably
    # returns an integer dtype -- confirm.
    if not cv2.imwrite(maskPath, new_mask):
        raise OSError(f"failed to write expanded mask to {maskPath}")

endTime = time.perf_counter()
print("[INFO] total time taken to write the new masks: {:.2f}s".format(endTime - startTime))