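### Cross-validation notebook

Trains one U-Net per coder (annotator) with stratified k-fold cross-validation and saves the predictions on each held-out fold.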

- [ ] Cora
- [ ] Dennis
- [ ] Manju
- [ ] Corinna
- [ ] GT
- [ ] Cross Coder?

### Colab setup (optional)

In [None]:
try:
    from google.colab import drive
    drive.mount('/content/drive')
    !git clone https://github.com/matjesg/DeepFLaSH2.git /content/drive/My\ Drive/DeepFLaSH2
    %cd /content/drive/My\ Drive/DeepFLaSH2
except:
    pass

### Import packages

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from skimage import io
import os
from deepflash import unet, preproc, metrics, lr_finder, utils
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import skimage
from skimage.measure import label
from skimage.color import label2rgb
%matplotlib inline

## Global Settings

In [None]:
# --- Annotators -------------------------------------------------------------
# A separate model is trained for every coder; masks are read from
# MASK_PATH/<coder>/. ('cora' is currently left out of the run.)
CODER = ['manju', 'corinna', 'rohini', 'dennis']

# --- Input data -------------------------------------------------------------
DATA_PATH = "data/images"                 # raw images: <Nummer>_<IMAGE>.tif
MASK_PATH = "data/labels"                 # label masks: <coder>/<Nummer>_<MASK>.png
ASSIGNMENT_PATH = 'samples_36_final.csv'  # sample list with Nummer / Kondition / Area
IMAGE = 'red'                             # file-name suffix of the image channel
MASK = 'cFOS'                             # file-name suffix of the mask
CHANNELS_IMG = 100                        # NOTE(review): not referenced in the visible cells — confirm it is still needed
# Passed to the generators as element_size_um, in micrometers.
# NOTE(review): 635.9 um per pixel looks large — confirm whether this is the
# physical extent of the whole image instead.
EL_SIZE = [635.9, 635.9]

# --- Tiling -----------------------------------------------------------------
TILE_SHAPE = (540,540)   # tile_shape of the train/test generators
PADDING = (184,184)      # padding of the train/test generators

# --- Run control & outputs --------------------------------------------------
SKIP = 0                      # number of leading CV folds the training loop should skip
SEED = 0
NAME_PREFIX = 'falk'          # first component of every model/output name
CHECKPOINTS = 'checkpoints_cv'
LOGDIR = 'logs_cv'

### Training params

In [None]:
# Weights loaded (by layer name, with reshape) into the model before training,
# or None to train from scratch. Previously used alternative:
# 'caffe/caffe_weights.5'.
PRETRAINED = None
# Misspelled alias kept so that cells still referring to PRETAINED keep working.
PRETAINED = PRETRAINED

BATCH_NORM = False       # batch_norm flag of unet.Unet2D
EPOCHS = 100             # training epochs per fold (n_epochs of model.train)
CYCLIC_LR = None         # cyclic_lr mode of model.train, e.g. 'triangular'; presumably None disables it
SNAPSHOT_INTERVAL = 5    # snapshot_interval of model.train, in epochs
N_SPLITS = 4             # number of cross-validation folds

### Weighting params

In [None]:
# Loss-weighting parameters, forwarded identically to the training and test
# generators:
#   LAMBDA    -> border_weight_factor
#   V_BAL     -> foreground_background_ratio
#   SIGMA_BAL -> foreground_dist_sigma_px (pixels)
#   SIGMA_SEP -> border_weight_sigma_px   (pixels)
LAMBDA, V_BAL, SIGMA_BAL, SIGMA_SEP = 50, 0.1, 10, 6

## Load Data

Sample assignment list (CSV): sample number (`Nummer`), condition (`Kondition`) and area (`Area`) of every image.

In [None]:
# Sample numbers are zero-padded to four characters so they match the image and
# mask file names.
assignment = pd.read_csv(ASSIGNMENT_PATH, converters={'Nummer': lambda x: str(x).zfill(4)})
# One integer label per (Kondition, Area) combination; the CV split stratifies on it.
assignment = assignment.assign(
    Group_ID=assignment.groupby(['Kondition', 'Area']).ngroup()
)
# Sample ids in file order; images and masks are loaded in this order.
file_ids = [nummer for nummer in assignment['Nummer']]

Load images

In [None]:
def _read_image(sample_id):
    """Read one raw image as grayscale and append a trailing channel axis at axis 2."""
    path = os.path.join(DATA_PATH, sample_id + '_' + IMAGE + '.tif')
    return np.expand_dims(io.imread(path, as_gray=True), axis=2)

# One array per sample, in file_ids order.
image_list = [_read_image(sample_id) for sample_id in file_ids]

### Train and predict: loop over coders and folds

In [None]:
# For every coder: split the samples into N_SPLITS stratified folds, train a U-Net
# on each training part, predict the held-out part and save the predictions as
# output/<name>/<Nummer>_<MASK>.png.

# Images and sample ids are identical for every coder and fold: convert them once.
images = np.array(image_list)
ids = np.array(file_ids)

for coder in CODER:
    # Masks of the current coder, in the same order as image_list / file_ids.
    mask_list = [io.imread(os.path.join(MASK_PATH, coder, x), as_gray=True).astype('int')
                 for x in [s + '_' + MASK + '.png' for s in file_ids]]
    masks = np.array(mask_list)

    name = NAME_PREFIX + '_' + MASK + '_' + str(CYCLIC_LR) + '_' + coder
    output_path = os.path.join('output', name)
    # exist_ok: re-running the cell (e.g. resuming with SKIP > 0) must not fail
    # because the output folder of an earlier run already exists.
    os.makedirs(output_path, exist_ok=True)

    # Stratify on Group_ID. random_state has no effect without shuffling, and
    # scikit-learn >= 0.24 raises a ValueError when it is passed together with
    # shuffle=False, so only the (deterministic, unshuffled) split is requested.
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=False)
    splits = skf.split(assignment['Nummer'], assignment['Group_ID'])

    # Folds are numbered from 1; the number becomes part of the model name.
    for fold, (train_index, test_index) in enumerate(splits, start=1):
        # Skip the first SKIP folds (e.g. to resume an interrupted run).
        if fold <= SKIP:
            continue

        X_train, X_test = images[train_index], images[test_index]
        y_train, y_test = masks[train_index], masks[test_index]

        data_train = [{'rawdata': img, 'element_size_um': EL_SIZE} for img in X_train]
        data_test = [{'rawdata': img, 'element_size_um': EL_SIZE} for img in X_test]

        ## Generators
        # NOTE(review): element_size_um is None here but EL_SIZE for the test
        # generator — confirm this asymmetry is intended.
        train_generator = preproc.DataAugmentationGenerator(data = data_train,
                                                        classlabels=y_train,
                                                        instancelabels=None,
                                                        tile_shape = TILE_SHAPE,
                                                        padding= PADDING,
                                                        batch_size = 4,
                                                        n_classes=2,
                                                        ignore=None,
                                                        weights=None,
                                                        element_size_um=None,
                                                        rotation_range_deg=(0, 360),
                                                        flip=False,
                                                        deformation_grid=(150, 150),
                                                        deformation_magnitude=(10, 10),
                                                        value_minimum_range=(0, 0),
                                                        value_maximum_range=(0.0, 1),
                                                        value_slope_range=(1, 1),
                                                        shuffle=True,
                                                        foreground_dist_sigma_px=SIGMA_BAL,
                                                        border_weight_sigma_px=SIGMA_SEP,
                                                        border_weight_factor=LAMBDA,
                                                        foreground_background_ratio=V_BAL
                                                       )
        test_generator = preproc.TileGenerator(data = data_test,
                                           classlabels=y_test,
                                           instancelabels=None,
                                           tile_shape = TILE_SHAPE,
                                           padding= PADDING,
                                           n_classes=2,
                                           ignore=None,
                                           weights=None,
                                           element_size_um=EL_SIZE,
                                           foreground_dist_sigma_px=SIGMA_BAL,
                                           border_weight_sigma_px=SIGMA_SEP,
                                           border_weight_factor=LAMBDA,
                                           foreground_background_ratio=V_BAL)

        ## Model
        name_helper = name + '_' + str(fold)

        print(name_helper)

        model = unet.Unet2D(snapshot=None,
                        n_channels=1,
                        n_classes=2,
                        n_levels=4,
                        batch_norm = BATCH_NORM,
                        upsample=False,
                        relu_alpha=0.1,
                        n_features=64, name=name_helper)

        if PRETAINED is not None:
            model.trainModel.load_weights(PRETAINED,reshape=True, by_name=True)

        model.train(train_generator,
                validation_generator=test_generator,
                n_epochs=EPOCHS,
                snapshot_interval= SNAPSHOT_INTERVAL,
                snapshot_dir = CHECKPOINTS,
                snapshot_prefix=name_helper,
                log_dir = LOGDIR,
                cyclic_lr= CYCLIC_LR)

        ## Predict
        # Reload the snapshot of the final epoch. Snapshots are assumed to be named
        # '<prefix>.<epoch:04d>.h5' (i.e. '.0100.h5' for EPOCHS = 100).
        # NOTE(review): presumably requires EPOCHS to be a multiple of
        # SNAPSHOT_INTERVAL so that a final-epoch snapshot exists — confirm.
        final_snapshot = os.path.join(CHECKPOINTS, '{}.{:04d}.h5'.format(name_helper, EPOCHS))
        pred_model = unet.Unet2D(snapshot=final_snapshot,
                        n_channels=1,
                        n_classes=2,
                        n_levels=4,
                        batch_norm =  BATCH_NORM,
                        upsample=False,
                        relu_alpha=0.1,
                        n_features=64,name="U-Net")

        predictions = pred_model.predict(test_generator)

        ## Save
        # NOTE(review): assumes predictions[1] holds one output per test image, in
        # test_index order — confirm against unet.Unet2D.predict.
        for file_id, prediction in zip(ids[test_index], predictions[1]):
            file_name = file_id + '_' + MASK + '.png'
            io.imsave(os.path.join(output_path, file_name), prediction)