### Settings

In [None]:
from IPython.core.display import display, HTML
import sys,cv2,gc
sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas
from Utils.utils import *
from ipywidgets import interact
import deepdish as dd
from skimage import io, transform

%matplotlib inline
# Widen the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Automatically reload edited modules (e.g. Utils.utils) before each cell is executed.
%load_ext autoreload
%autoreload 2

from jupyterthemes import jtplot
# set "context" (paper, notebook, talk, poster)
jtplot.style(theme='grade3',context='talk', fscale=2.5, spines=True, gridlines='',ticks=True, grid=True, figsize=(6, 4.5))
# RGB color tuple; NOTE(review): not passed to any plotting call in the cells shown — confirm it is still needed.
plotcolor = (0, 0.6, 1.0)

# Load augmentation stuff
from albumentations import (
    Flip, OneOf, RandomContrast, RandomBrightness, RandomRotate90, CLAHE,
    IAASharpen, HueSaturationValue,IAAAdditiveGaussianNoise, MedianBlur, GaussNoise,
    IAAPiecewiseAffine, OneOf, Compose, Transpose,MotionBlur, Blur
)
def augment(p=1):
    """Build the albumentations pipeline used to synthesize extra samples.

    Geometric transforms (random 90-degree rotation, flip, transpose) come first,
    followed by one noise transform, one blur transform and one contrast/sharpness
    transform (each group with its own probability) and a hue/saturation shift.

    Args:
        p: probability that the composed pipeline is applied at all.

    Returns:
        An albumentations ``Compose``. It is not an image: apply it as
        ``augment()(image=img)['image']``.

    NOTE(review): CLAHE only accepts uint8 input and HueSaturationValue expects
    3-channel images in many albumentations versions; confirm both work for the
    1-channel (green) and 4-channel images this notebook produces. The IAA* and
    RandomContrast/RandomBrightness transforms are also deprecated in newer releases.
    """
    geometric = [RandomRotate90(), Flip(), Transpose()]
    noise = OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.25)
    blur = OneOf(
        [MotionBlur(p=.2), MedianBlur(blur_limit=3, p=.1), Blur(blur_limit=3, p=.1)],
        p=0.25,
    )
    contrast = OneOf(
        [CLAHE(clip_limit=2), IAASharpen(), RandomContrast(), RandomBrightness()],
        p=0.5,
    )
    color = HueSaturationValue(p=0.25)
    return Compose(geometric + [noise, blur, contrast, color], p=p)

import os
import random

# Root folder of the HPA data; must contain train.csv and a train/ image folder.
# Set the HPA_DATA_FOLDER environment variable to use another location without editing code.
# os.path.join(..., '') guarantees a trailing separator, since paths are built by concatenation.
data_folder = os.path.join(os.environ.get('HPA_DATA_FOLDER', 'D:/data/HPA/all/'), '')
target_count = 10         # initial number of real samples drawn per class in step 1
USE_ALL_CHANNELS = False  # True: stack green/red/blue/yellow channels; False: green channel only
RESIZE = True             # resize images to 224x224 before storing them

# The augmentation step is random; seed the global RNGs so the generated dataset is reproducible.
# NOTE(review): presumably albumentations draws from `random`/`np.random` here; newer versions
# may additionally need an explicit seed on Compose — confirm for the installed version.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
print('Done.')

### EDA

In [None]:
%%time
#Read Labels
label_csv = pandas.read_csv(data_folder+'train.csv')
samplecount = label_csv['Id'].size
labels = np.zeros([samplecount,28],dtype = np.bool)

for i, row in label_csv.iterrows():
    labelNr = list(map(int,row['Target'].split(' ')))
    labels[i,labelNr] = True #Convert labels to bool, where entry is true if class is present
    
frequencies = labels.sum(axis=0)
classes = np.argsort(frequencies)
frequencies.sort()

In [None]:
# The theme is set again here; without this it does not seem to apply to this figure (cause unknown).
# set "context" (paper, notebook, talk, poster)
jtplot.style(theme='grade3',context='talk', fscale=2.5, spines=True, gridlines='',ticks=True, grid=True, figsize=(6, 4.5))
plotcolor = (0, 0.6, 1.0)

# Bar chart of class frequencies in ascending order; each tick shows the class id of its bar.
fig, ax = plt.subplots(figsize=(15, 6))
ax.bar(range(28), frequencies, width=0.5)
xtick_labels = list(map(str, classes))
ax.set_yscale('log')
ax.set_xticks(range(28))
ax.set_xticklabels(xtick_labels)
ax.set(title='Number of samples per class',
       xlabel='Class id (sorted by frequency)',
       ylabel='Samples (log scale)')
plt.show()

print(frequencies)

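### Sample and augment data
Sampling happens in two steps:
- Step 1: draw real samples for each class, rarest class first, until the class reaches `target_count`. Every image is used at most once.
- Step 2: augment under-represented classes with transformed copies of the step-1 samples until they reach a raised target count.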

In [None]:
# Setup: free memory from earlier runs and reset the containers that the
# sampling (step 1) and augmentation (step 2) cells append to.
print("Collecting garbage...")
gc.collect()

X, selected_labels = [], []           # images and their label vectors, kept in parallel
class_representation = np.zeros(28)   # running number of collected samples per class
print("Done.")

In [None]:
#Step 1
# Draw real samples class by class, rarest class first (order given by `classes`), until each
# class reaches `target_count`. A sample carries all of its labels, so picking it for a rare
# class also raises the count of every other class it contains.
# Run the Setup cell first: this cell appends to X / selected_labels / class_representation.
target_size = (224, 224)  # output size when RESIZE is enabled
# Rows already copied into X. label_csv and labels are left untouched, so row position i refers to
# the same sample in both (deleting rows while iterating shifts `labels` out of step with the index).
already_sampled = np.zeros(len(labels), dtype=bool)
channel_colors = ['green', 'red', 'blue', 'yellow'] if USE_ALL_CHANNELS else ['green']

for classNr in classes: #Iterate over classes starting with least represented ones
    print("######################## Sampling class: ", classNr, '######################## ')
    print("Current class representation: ", class_representation)
    current_class_representation = class_representation[classNr]

    # Samples that contain the target class and were not taken for an earlier class, in file order.
    for i in np.flatnonzero(labels[:, classNr] & ~already_sampled):
        #if enough samples collected for class, move on
        if current_class_representation >= target_count:
            break

        fn = data_folder + 'train/' + label_csv['Id'].iloc[i]
        channels = []
        for color in channel_colors:
            path = fn + '_' + color + '.png'
            channel = cv2.imread(path, 0)  # flag 0: read as single-channel grayscale
            if channel is None:  # cv2.imread returns None instead of raising for missing/unreadable files
                raise FileNotFoundError('Could not read image: ' + path)
            channels.append(channel)
        img = np.stack(channels, axis=2)  # (H, W, number of channels)

        if RESIZE:
            img = transform.resize(img, target_size + img.shape[2:], preserve_range=True)
            # resize returns float64; round back to uint8 so resized and full-size samples share a
            # dtype and uint8-only augmentations (e.g. CLAHE) work in step 2.
            img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
        X.append(img)

        selected_labels.append(labels[i].copy()) #store labels
        class_representation += labels[i] #update representation
        current_class_representation += 1
        already_sampled[i] = True  # each file is used at most once

        printProgressBar (current_class_representation, target_count, prefix = 'Sampling class...', suffix = '(' + str(current_class_representation) + '/' + str(target_count) + ')')

print_horizontal_divider()
print("Class representation after sampling: ", class_representation)
print("Done.")
            

In [None]:
# Step 1 usually over-samples the frequent classes, because every sample drawn for a rare class
# also counts for all the other classes it contains. Raise the per-class target before augmenting.
# NOTE(review): the intent was described as "the middle ground between the maximum and the mean
# representation", i.e. (max + mean) / 2, but this computes (target_count + max - mean) / 2 —
# confirm which one is meant. Re-running this cell changes target_count again (it builds on it).
target_count = np.round(((target_count + class_representation.max()) - class_representation.mean()) * 0.5)
print("New target count =", target_count)

# Augmented samples are kept apart until step 2 finishes, so they are never augmented again.
augmentedImages, augmentedLabels = [], []

In [None]:
#Step 2
# Top up every class to the raised `target_count` with augmented copies of the step-1 samples
# that contain it. Augmented images go to separate lists, so only original images get augmented.
# Run the previous cell first: this cell extends X / selected_labels.
augmentation = augment()  # augment() only builds the pipeline; the pipeline is then applied per image
for classNr in classes: #Iterate over classes starting with least represented ones
    print("######################## Augmenting class: ", classNr, '######################## ')
    print("Current class representation: ", class_representation)
    current_class_representation = class_representation[classNr]

    # Indices into X / selected_labels of the original samples that contain this class.
    sources = [j for j, source_label in enumerate(selected_labels) if source_label[classNr]]
    if not sources:
        # Nothing to augment from; without this guard the loop below would never terminate.
        print("No samples of class ", classNr, " available to augment, skipping.")
        continue

    k = 0
    while current_class_representation < target_count: # iterate as long as needed to get enough samples
        i = sources[k % len(sources)]  # cycle through the source samples round-robin
        k += 1
        label = selected_labels[i]

        # albumentations pipelines take keyword inputs and return a dict of outputs
        augmentedImages.append(augmentation(image=X[i])['image'])
        augmentedLabels.append(label)

        # the copy counts towards every class in its label vector, not only classNr
        class_representation += label
        current_class_representation += 1

        printProgressBar (current_class_representation, target_count, prefix = 'Augmenting class...', suffix = '(' + str(current_class_representation) + '/' + str(target_count) + ')')
    print("Done with class ", classNr)

#add augmented images
X.extend(augmentedImages)
selected_labels.extend(augmentedLabels)

print_horizontal_divider()
print("Final class representation: ", class_representation)
print("Done.")

### Store the created dataset

In [None]:
# store data
# The output file name encodes the channel mode (all channels vs. green only) and
# whether the images were resized ("_small").
data = {'X': X, 'labels': selected_labels}
dd.io.save(
    data_folder + {
        (True, True): 'all_channel_augmented_small.h5',
        (True, False): 'all_channel_augmented.h5',
        (False, True): 'poi_augmented_small.h5',
        (False, False): 'poi_augmented.h5',
    }[(bool(USE_ALL_CHANNELS), bool(RESIZE))],
    data,
    compression=('blosc', 8),
)
print("Done.")

### Show sample data

In [None]:
# Load a stored dataset for visual inspection.
# NOTE(review): 'poi_0_small.h5' is not one of the file names written by the store cell
# (those are '*_augmented*.h5') — confirm which dataset is meant to be shown here.
d = dd.io.load(f'{data_folder}poi_0_small.h5')
print('Done.')

In [None]:
# Show one stored sample.
# NOTE: this rebinds X, replacing the dataset assembled above; re-run the sampling cells before storing again.
# np.asarray also handles the case where a list of arrays was stored, which is what the store cell
# writes; presumably deepdish returns it as a list — confirm.
X = np.asarray(d['X'])
print(X.shape)
sample_idx = min(42, len(X) - 1)  # arbitrary example index, clamped to the dataset size
# NOTE(review): for 4-channel images imshow treats the data as RGBA and ignores cmap.
plt.imshow(X[sample_idx].squeeze(), cmap='gray')
plt.title('Sample {}'.format(sample_idx));