### Settings

In [None]:
from IPython.core.display import display, HTML
import sys,cv2,gc
sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas
from Utils.utils import *
from ipywidgets import interact
import deepdish as dd
from skimage import io, transform

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Inject CSS so the notebook container uses the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Load the autoreload extension in mode 2, so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from jupyterthemes import jtplot
# set "context" (paper, notebook, talk, poster)
jtplot.style(theme='grade3',context='talk', fscale=2.5, spines=True, gridlines='',ticks=True, grid=True, figsize=(6, 4.5))
# RGB tuple. NOTE(review): not referenced by any cell visible in this notebook - confirm it is still needed.
plotcolor = (0, 0.6, 1.0)

# Load augmentation stuff
from albumentations import (
    IAAFliplr,IAAFlipud, OneOf, RandomContrast, RandomBrightness, RandomRotate90,
    IAASharpen, HueSaturationValue,IAAAdditiveGaussianNoise, MedianBlur, GaussNoise,
    IAAPiecewiseAffine, OneOf, Compose, Transpose,MotionBlur, Blur
)
def augment(p=1):
    """Build the augmentation pipeline that is applied to a single image.

    The pipeline chains a 90-degree rotation, left-right and up-down flips and
    a transpose, followed by three ``OneOf`` groups: noise (p=0.25), blur
    (p=0.25) and sharpen/contrast/brightness (p=0.5).

    Args:
        p: Probability handed to the enclosing ``Compose``.

    Returns:
        The assembled ``Compose`` object; call it as ``pipeline(image=img)``
        and read the result from the ``'image'`` key.
    """
    geometric = [RandomRotate90(), IAAFliplr(), IAAFlipud(), Transpose()]
    noise = OneOf([GaussNoise()], p=0.25)
    blur = OneOf(
        [
            MotionBlur(p=.2),
            MedianBlur(blur_limit=3, p=.1),
            Blur(blur_limit=3, p=.1),
        ],
        p=0.25,
    )
    tone = OneOf([IAASharpen(), RandomContrast(), RandomBrightness()], p=0.5)
    return Compose(geometric + [noise, blur, tone], p=p)

# --- Run configuration ---
# True: stack the green/red/blue/yellow images into one 4-channel image.
# False: keep only the green image.
USE_ALL_CHANNELS = True
# True: rescale every image to 224x224 before it is stored.
RESIZE = False
# Per-class sample target. Set this lower than the count actually wanted:
# the achieved counts end up higher anyway.
target_count = 500
# Root folder containing train.csv and the train/ image directory.
data_folder = 'D:/data/HPA/all/'
print('Done.')

### EDA

In [None]:
%%time
#Read Labels
# One row per sample: 'Id' is the file name stem, 'Target' holds the
# space-separated class ids present in that sample.
label_csv = pandas.read_csv(data_folder+'train.csv')
samplecount = label_csv['Id'].size
filenames = label_csv["Id"].values

# Multi-hot label matrix: labels[i, c] is True when sample i contains class c.
# The builtin bool is used because the np.bool alias was removed in NumPy 1.24.
labels = np.zeros([samplecount,28],dtype = bool)

# enumerate() yields the positional row number, so the assignment does not
# depend on the DataFrame index happening to be 0..n-1.
for i, target in enumerate(label_csv['Target']):
    labelNr = list(map(int,target.split(' ')))
    labels[i,labelNr] = True

# classes: class ids ordered from least to most frequent.
# frequencies: the per-class counts in that same order (derived from the same
# permutation, so the two arrays always line up).
class_counts = labels.sum(axis=0)
classes = np.argsort(class_counts)
frequencies = class_counts[classes]
print(frequencies)

#shuffle labels and filenames with one shared permutation so they stay aligned
# NOTE(review): the shuffle is unseeded, so the sampled dataset differs between runs.
idx = np.arange(samplecount)
np.random.shuffle(idx)
filenames = filenames[idx].tolist()
labels = labels[idx]

In [None]:
# The theme is applied a second time here; for reasons not yet understood the
# call in the settings cell alone is not sufficient.
# set "context" (paper, notebook, talk, poster)
jtplot.style(theme='grade3',context='talk', fscale=2.5, spines=True, gridlines='',ticks=True, grid=True, figsize=(6, 4.5))
plotcolor = (0, 0.6, 1.0)

# Bar chart of the per-class frequencies on a log scale. Bars follow the order
# of `frequencies`; the tick labels show the matching class ids from `classes`.
positions = range(28)
xtick_labels = [str(class_id) for class_id in classes]

plt.figure(figsize=(20,4))
plt.bar(positions,frequencies,width = 0.5)
plt.xlabel("Class Index")
plt.ylabel("Frequency")
plt.yscale('log')
_ = plt.xticks(positions,xtick_labels)

### Sample and augment data
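We build the dataset in two steps:
- Step 1: for each class (rarest first), take original samples until `target_count` is reached or no more samples of that class are left.
- Step 2: augment already-selected samples of every class that is still below `target_count`.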

In [None]:
#Setup
# Run a garbage collection pass before the dataset is built.
print("Collecting garbage...")
gc.collect()

# Accumulators filled by the sampling and augmentation cells:
#   X                    - list of selected images
#   selected_labels      - label row of every image in X (same order)
#   class_representation - running per-class sample count
X, selected_labels = [], []
class_representation = np.zeros(28)
print("Done.")

In [None]:
#Step 1
# Draw original (non-augmented) samples for every class, rarest class first,
# until the class reaches target_count or no more samples of it are left.
# A drawn sample is removed from filenames/labels so that it cannot be
# selected a second time.

# Channel files to read per sample; this order defines the channel order of the stored image.
channel_names = ('green','red','blue','yellow') if USE_ALL_CHANNELS else ('green',)

for classNr in classes: #Iterate over classes starting with least represented ones
    print("######################## Sampling class: ", classNr, "Current class samples = ", class_representation[classNr], '######################## ')
    current_class_representation = class_representation[classNr]

    i = 0
    # Stop when the class has enough samples or every remaining file was inspected.
    while i < len(filenames) and current_class_representation < target_count:
        #check sample contains target class
        if not labels[i][classNr]:
            i += 1
            continue

        # Read the images first, so that nothing is recorded for a sample
        # whose files cannot be loaded.
        fn = data_folder+'train/'+filenames[i]
        channels = []
        for channel_name in channel_names:
            channel_path = fn+'_'+channel_name+'.png'
            channel = cv2.imread(channel_path,0)
            if channel is None: # cv2.imread returns None instead of raising
                raise FileNotFoundError('Could not read image file: ' + channel_path)
            channels.append(channel)

        selected_labels.append(labels[i]) #store labels
        class_representation += labels[i] #update representation
        current_class_representation += 1

        # Store image, already handle resizing and channels
        if USE_ALL_CHANNELS:
            all_img = np.asarray(channels).transpose(1,2,0) # (channel, H, W) -> (H, W, channel)
            if RESIZE:
                all_img = transform.resize(all_img.squeeze(), (224, 224), preserve_range=True)
            X.append(all_img.astype(np.uint8))
        else:
            green = channels[0]
            if RESIZE:
                green = transform.resize(green, (224, 224), preserve_range=True)
            X.append(np.expand_dims(green,axis=2).astype(np.uint8))

        #Drop this sample from remaining data.
        # i is intentionally NOT advanced afterwards: the deletion shifts the
        # next sample into position i, and incrementing would skip it.
        filenames = np.delete(filenames,i, axis=0)
        labels = np.delete(labels, i, axis=0)

        printProgressBar (current_class_representation, target_count, prefix = 'Sampling class...', suffix = '(' + str(current_class_representation) + '/' + str(target_count) + ')')

print_horizontal_divider()
print("Class representation after sampling: ", class_representation)
print("Done.")
            

In [None]:
# Disabled refinement: raise the target count at this point, because the
# overrepresented classes have certainly been oversampled already. The idea is
# to aim for the middle ground between the overrepresentation (the max) and
# the mean representation.
# target_count = np.round((target_count + np.max(class_representation) - np.mean(class_representation)) / 2.0)
# print("New target count =", target_count)
# print(X[0].shape)

# Augmented samples are kept apart from X/selected_labels, because we do not
# want to augment images that are themselves augmentations.
augmentedImages, augmentedLabels = [], []

In [None]:
#Step 2
# Top up every class that is still below target_count with augmented copies of
# the samples selected so far. Augmentations are collected in
# augmentedImages/augmentedLabels and only merged into X/selected_labels at the
# end, so an augmented image is never used as the source of another one.

# Built once; the pipeline does not need to be re-created for every image.
aug = augment()

for classNr in classes: #Iterate over classes starting with least represented ones
    current_class_representation = class_representation[classNr]
    if current_class_representation >= target_count:
        continue

    # Only samples that contain the class can serve as augmentation sources.
    candidates = [idx for idx, candidate_label in enumerate(selected_labels) if candidate_label[classNr]]
    if not candidates:
        # With no source sample the loop below could never reach the target
        # and would spin forever, so the class is skipped instead.
        print("No samples available for class", classNr, "- skipping augmentation.")
        continue

    step = 0
    while current_class_representation < target_count: # iterate as long as needed to get enough samples
        # Cycle through the candidates in their original order, repeatedly if needed.
        i = candidates[step % len(candidates)]
        step += 1
        label = selected_labels[i]

        #augment image
        img = aug(image=X[i].squeeze().astype(np.uint8))['image']
        if not USE_ALL_CHANNELS:
            img = np.expand_dims(img,axis=2) # restore the channel axis removed by squeeze()

        #add image and label
        augmentedImages.append(img)
        augmentedLabels.append(label)

        #update representation (a multi-label sample raises every class it contains)
        class_representation += label
        current_class_representation += 1

        if  i % 25 == 0 or (current_class_representation == target_count):
            printProgressBar (current_class_representation, target_count, prefix = 'Augmenting class ' + str(classNr) + '...', suffix = '(' + str(current_class_representation) + '/' + str(target_count) + ')')


#add augmented images
X.extend(augmentedImages)
selected_labels.extend(augmentedLabels)

print_horizontal_divider()
print("Final class representation: ", class_representation)
print("Total samples = ", len(X))
print("Done.")

### Store the created dataset

In [None]:
%%time
# store data
# The output file name is selected by the (USE_ALL_CHANNELS, RESIZE) combination.
output_names = {
    (True, True): 'all_channel_augmented_small.h5',
    (True, False): 'all_channel_augmented.h5',
    (False, True): 'poi_augmented_small.h5',
    (False, False): 'poi_augmented.h5',
}
data = {'X': np.asarray(X), 'labels': np.asarray(selected_labels)}
output_name = output_names[(bool(USE_ALL_CHANNELS), bool(RESIZE))]
dd.io.save(data_folder+output_name, data,compression=('blosc', 8))
print("Done.")

### Show sample data

In [None]:
# NOTE(review): 'poi_0_small.h5' is not one of the file names written by the
# store cell of this notebook - confirm this is the intended file.
sample_path = data_folder+'poi_0_small.h5'
d = dd.io.load(sample_path)
print("Done.")

In [None]:
# Display one image (index 42) of the loaded dataset in grayscale.
X = d['X']
print(X.shape)
sample_image = X[42].squeeze()
plt.imshow(sample_image,cmap='gray')