<a href="https://colab.research.google.com/github/matjesg/DeepFLaSH2/blob/master/Deepflash_SC_all.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Cross Coder Notebook

#### Colab options

In [None]:
try:
    from google.colab import drive
    drive.mount('/content/drive')
    !git clone https://github.com/matjesg/DeepFLaSH2.git /content/drive/My\ Drive/DeepFLaSH2
    %cd /content/drive/My\ Drive/DeepFLaSH2
except:
    pass

Import packages

In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import pandas as pd
from skimage import io
from deepflash import unet, preproc
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline

## Global Settings

In [None]:
# --- Annotators and naming ---
CODER = ['rohini', 'dennis', 'cora', 'manju', 'corinna']  # one sub-folder of MASK_PATH per coder
NAME_PREFIX = 'all_falk'  # first component of the run name
MASK = 'cFOS'             # mask files are named '<file id>_<MASK>.png'
IMAGE = 'red'             # image files are named '<file id>_<IMAGE>.tif'
CHANNELS_IMG = 1

# --- Input locations ---
DATA_PATH = 'data/images'
MASK_PATH = 'data/labels'
ASSIGNMENT_PATH = 'samples_36_final.csv'  # CSV whose 'Nummer' column lists the file ids

# --- Tiling ---
TILE_SHAPE = (540, 540)  # forwarded as tile_shape to the data generator
PADDING = (184, 184)     # forwarded as padding to the data generator

# NOTE(review): SEED is not referenced by the visible cells - confirm it is applied somewhere.
SEED = 0
EL_SIZE = [635.9, 635.9]  # micrometers; stored as 'element_size_um' with every image

# --- Output locations ---
CHECKPOINTS = 'checkpoints_sc'  # forwarded as snapshot_dir to model.train
LOGDIR = 'logs_cv'              # forwarded as log_dir to model.train

### Training params

In [None]:
# Training hyper-parameters used by the model construction / training cell.
# NOTE(review): presumably a snapshot to start from; not referenced by the visible cells.
PRETRAINED = None
PRETAINED = PRETRAINED     # original (misspelled) name kept so existing references keep working
BATCH_NORM = False         # forwarded as batch_norm to unet.Unet2D
EPOCHS = 1                 # forwarded as n_epochs to model.train
CYCLIC_LR = 'triangular'   # forwarded as cyclic_lr to model.train; also part of the run name
SNAPSHOT_INTERVAL = 100    # forwarded as snapshot_interval to model.train

### Weighting params

In [None]:
# Weighting parameters handed to preproc.DataAugmentationGenerator.
# The trailing numbers repeat the values noted by the original author
# (presumably the reference defaults - TODO confirm).
LAMBDA = 50     # border_weight_factor            (50)
V_BAL = 0.1     # foreground_background_ratio     (0.1)
SIGMA_BAL = 10  # foreground_dist_sigma_px        (10)
SIGMA_SEP = 6   # border_weight_sigma_px          (6)

#### Load Data

Assignment list (CSV file); the `Nummer` column holds the file IDs, zero-padded to four characters on load

In [None]:
def _pad_file_id(raw_id):
    """Return the raw 'Nummer' entry as a string left-padded with zeros to width 4."""
    return str(raw_id).zfill(4)


# Route the 'Nummer' column through the converter so the ids are strings of at
# least four characters; these ids are the stems of the image and mask file names.
assignment = pd.read_csv(ASSIGNMENT_PATH, converters={'Nummer': _pad_file_id})
file_ids = assignment['Nummer'].tolist()

Load images

In [None]:
# One '<file id>_<IMAGE>.tif' per assignment entry, located in DATA_PATH.
image_paths = [os.path.join(DATA_PATH, '_'.join((file_id, IMAGE)) + '.tif')
               for file_id in file_ids]

# Read each image as grayscale and insert a new axis at position 2
# (assumes imread returns 2-D arrays, giving (rows, cols, 1) - TODO confirm).
image_list = [np.expand_dims(io.imread(path, as_gray=True), axis=2)
              for path in image_paths]

# Pair every image with the physical element size.
data = [dict(rawdata=img, element_size_um=EL_SIZE) for img in image_list]

Load each coder's masks and combine them with the images into the training data

In [None]:
# Build the training arrays: the same file ids are read for every coder, so the
# image stack is repeated once per coder and paired with that coder's masks.
#
# The pieces are collected in lists and concatenated once at the end instead of
# calling np.append inside the loop (which re-copies the growing arrays on every
# iteration). Each list starts with an empty default-dtype (float) array so the
# result keeps the same dtype promotion as before and is still a well-formed
# empty array when CODER is empty.
image_stacks = [np.empty((0,) + image_list[0].shape)]
mask_stacks = [np.empty((0,) + image_list[0][..., 0].shape)]

# The image stack does not depend on the coder, so it is built only once.
image_stack = np.array(image_list)

for coder in CODER:
    # Masks live in MASK_PATH/<coder>/<file id>_<MASK>.png.
    # NOTE(review): as_gray=True leaves single-channel images unchanged but converts
    # colour images to floats, which the int cast would truncate - this assumes the
    # masks are stored single-channel; TODO confirm.
    mask_list = [io.imread(os.path.join(MASK_PATH, coder, file_id + '_' + MASK + '.png'),
                           as_gray=True).astype('int')
                 for file_id in file_ids]

    image_stacks.append(image_stack)
    mask_stacks.append(np.array(mask_list))

X_train = np.concatenate(image_stacks, axis=0)
y_train = np.concatenate(mask_stacks, axis=0)

# One record per training image, in the format handed to the data generator.
data_train = [{'rawdata': img, 'element_size_um': EL_SIZE} for img in X_train]

Train Model

In [None]:
# Run name '<NAME_PREFIX>_<MASK>_<CYCLIC_LR>'; used below as the model name and
# as the snapshot prefix.
name = '_'.join([NAME_PREFIX, MASK, str(CYCLIC_LR)])
print(name)

## Generators
# All generator options gathered in one place, grouped by purpose.
generator_options = dict(
    # labels / tiling / batching
    instancelabels=None,
    tile_shape=TILE_SHAPE,
    padding=PADDING,
    batch_size=4,
    n_classes=2,
    ignore=None,
    weights=None,
    element_size_um=None,
    shuffle=True,
    # geometric augmentation
    rotation_range_deg=(0, 360),
    flip=False,
    deformation_grid=(150, 150),
    deformation_magnitude=(10, 10),
    # intensity augmentation
    value_minimum_range=(0, 0),
    value_maximum_range=(0.0, 1),
    value_slope_range=(1, 1),
    # weighting, taken from the "Weighting params" constants
    foreground_dist_sigma_px=SIGMA_BAL,
    border_weight_sigma_px=SIGMA_SEP,
    border_weight_factor=LAMBDA,
    foreground_background_ratio=V_BAL,
)

# Augmenting generator over the stacked training images and their masks.
train_generator = preproc.DataAugmentationGenerator(data=data_train,
                                                    classlabels=y_train,
                                                    **generator_options)


# Build the 2-D U-Net. The input channel count comes from the CHANNELS_IMG
# setting instead of a second hard-coded 1, so the two cannot drift apart.
# NOTE(review): snapshot is fixed to None here; the training-params constant
# PRETAINED looks like it was meant for this argument - confirm before wiring it.
model = unet.Unet2D(snapshot=None,
                    n_channels=CHANNELS_IMG,
                    n_classes=2,
                    n_levels=4,
                    batch_norm=BATCH_NORM,
                    upsample=False,
                    relu_alpha=0.1,
                    n_features=64,
                    name=name)

# Step multiplier for the cyclic learning-rate schedule: 9 divided by the number
# of coders, rounded to the nearest integer.
# NOTE(review): the constant 9 is unexplained in this notebook - confirm its origin.
lr_step_multiplier = int(np.round(9 / len(CODER)))

# Train without a validation generator; snapshots go to CHECKPOINTS under the run
# name, logs to LOGDIR. 'step_muliplier' is spelled exactly as in the original
# call and must stay that way to match the keyword it is passed as.
model.train(
    train_generator,
    validation_generator=None,
    n_epochs=EPOCHS,
    snapshot_interval=SNAPSHOT_INTERVAL,
    snapshot_dir=CHECKPOINTS,
    snapshot_prefix=name,
    log_dir=LOGDIR,
    cyclic_lr=CYCLIC_LR,
    step_muliplier=lr_step_multiplier,
)