# Training Quantification


Notebook to train a model using given parameters, apply that model to a validation dataset, and then export a variety of statistics about predictions on that data.

## Parameters

In [3]:
# Required parameters: every one of these must be overridden before the
# notebook is executed (the next cell asserts that none is left as None).
# The two *_image_ids values are comma-separated strings of image ids.
data_dir = None
model_dir = None
export_dir = None
train_image_ids = None
val_image_ids = None

# Optional parameters with defaults. `class_names` is kept as a single
# comma-separated string here and split into a list by the next cell.
class_names = 'Cell,StNum,AptNum,CellClump,Marker,Chamber'
seed = 1211
n_epochs = 10

In [4]:
# Fail fast when a required parameter was not supplied.
assert data_dir is not None, 'Must provided "data_dir"'
assert model_dir is not None, 'Must provided "model_dir"'
assert export_dir is not None, 'Must provided "export_dir"'
assert train_image_ids is not None, 'Must provided training image ids'
assert val_image_ids is not None, 'Must provided validation image ids'


def _split_csv(value):
    """Return ``value`` as a list, splitting on commas when it is a string.

    Non-string iterables (e.g. a list produced by a previous run of this
    cell) are copied into a new list, so re-running the cell is harmless.
    """
    if isinstance(value, str):
        return value.split(',')
    return list(value)


# Each id list is parsed from its OWN parameter. Previously the validation
# ids were derived from `train_image_ids`, which by then was already a list
# and therefore raised AttributeError on `.split`.
train_image_ids = _split_csv(train_image_ids)
val_image_ids = _split_csv(val_image_ids)
class_names = _split_csv(class_names)

# Numeric parameters may arrive as strings; normalize them to int.
n_epochs = int(n_epochs)
seed = int(seed)

AssertionError: Must provided "data_dir"

## Initialization

In [2]:
# Execute the shared configuration script in this notebook's namespace.
# NOTE(review): names used later without an explicit import in this notebook
# (e.g. CelldomTrainingConfig, CelldomInferenceConfig) presumably come from
# this script -- confirm against ../config.py.
%run ../config.py
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import os
import os.path as osp
import warnings
import numpy as np
import pandas as pd
import papermill as pm
import matplotlib.pyplot as plt
from mrcnn import visualize as mrcnn_viz
from mrcnn import model as mrcnn_model_lib
from cvutils.mrcnn import model as mrcnn_model
from cvutils.mrcnn.session import init_keras_session
from celldom.dataset import CelldomDataset
# Set up the Keras session via the cvutils helper before any model is built
# (the details live in cvutils.mrcnn.session).
init_keras_session()

In [3]:
# Image paths are resolved against the `data_dir` notebook parameter (which is
# required and asserted above) rather than an undeclared `DATA_DIR` global, so
# the directory actually passed to the notebook is the one that gets used.

# Training dataset
train_image_paths = [osp.join(data_dir, image_id) for image_id in train_image_ids]
dataset_train = CelldomDataset()
dataset_train.initialize(train_image_paths, class_names)
dataset_train.prepare()

# Validation dataset
val_image_paths = [osp.join(data_dir, image_id) for image_id in val_image_ids]
dataset_val = CelldomDataset()
dataset_val.initialize(val_image_paths, class_names)
dataset_val.prepare()

['/lab/data/celldom/data/dataset01/BF_ST_001_APT_000Day0.jpg',
 '/lab/data/celldom/data/dataset01/BF_ST_001_APT_000Day1.jpg',
 '/lab/data/celldom/data/dataset01/BF_ST_001_APT_000Day2.jpg',
 '/lab/data/celldom/data/dataset01/BF_ST_001_APT_000Day3.jpg',
 '/lab/data/celldom/data/dataset01/BF_ST_001_APT_001Day0.jpg']

In [7]:
# Instantiate the training configuration and print its settings for the record.
# NOTE(review): CelldomTrainingConfig is not imported in this notebook;
# presumably it is defined by ../config.py -- confirm.
train_config = CelldomTrainingConfig()
train_config.display()


Configurations:
BACKBONE                       resnet50
BACKBONE_STRIDES               [4, 8, 16, 32, 64]
BATCH_SIZE                     2
BBOX_STD_DEV                   [0.1 0.1 0.2 0.2]
DETECTION_MAX_INSTANCES        250
DETECTION_MIN_CONFIDENCE       0.7
DETECTION_NMS_THRESHOLD        0.3
GPU_COUNT                      1
GRADIENT_CLIP_NORM             5.0
IMAGES_PER_GPU                 2
IMAGE_MAX_DIM                  384
IMAGE_META_SIZE                19
IMAGE_MIN_DIM                  384
IMAGE_MIN_SCALE                0
IMAGE_RESIZE_MODE              square
IMAGE_SHAPE                    [384 384   3]
LEARNING_MOMENTUM              0.9
LEARNING_RATE                  0.001
LOSS_WEIGHTS                   {'rpn_bbox_loss': 1.0, 'rpn_class_loss': 1.0, 'mrcnn_class_loss': 1.0, 'mrcnn_bbox_loss': 1.0, 'mrcnn_mask_loss': 1.0}
MASK_POOL_SIZE                 14
MASK_SHAPE                     [28, 28]
MAX_GT_INSTANCES               250
MEAN_PIXEL                     [123.7 116.8 103.9]
MIN

## Training

In [8]:
# Build the model in training mode, writing its artefacts under `model_dir`.
# NOTE(review): init_with='coco' presumably starts from COCO-pretrained
# weights -- confirm in cvutils.mrcnn.model.get_model.
model = mrcnn_model.get_model(
    mode="training",
    config=train_config,
    model_dir=model_dir,
    init_with='coco',
)

In [9]:
# Silence two known UserWarnings emitted during training; they have seemed
# irrelevant so far. Each entry is used as the `message` pattern, which the
# warnings module matches against the start of the warning text.
_ignored_warning_messages = (
    'Converting sparse IndexedSlices to a dense Tensor of unknown shape',
    'Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data',
)
for _warning_message in _ignored_warning_messages:
    warnings.filterwarnings('ignore', message=_warning_message, category=UserWarning)

# Train on the training set, validating against the validation set, for the
# requested number of epochs and restricted to the 'heads' layer selection.
_train_options = dict(
    learning_rate=train_config.LEARNING_RATE,
    epochs=n_epochs,
    layers='heads',
)
model.train(dataset_train, dataset_val, **_train_options)


Starting at epoch 0. LR=0.001

Checkpoint Path: models/domain01/celldom20180527T1559/mask_rcnn_celldom_{epoch:04d}.h5
Selecting layers to train
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_conv4       (TimeDistributed)
mrcnn_mask_bn4     

## Quantification

In [8]:
# Rebuild the model in inference mode using the inference configuration.
# NOTE(review): init_with='last' presumably loads the most recent checkpoint
# found under `model_dir` -- confirm in cvutils.mrcnn.model.get_model.
inference_config = CelldomInferenceConfig()
model = mrcnn_model.get_model(
    mode='inference',
    config=inference_config,
    model_dir=model_dir,
    init_with='last',
)

In [68]:
from cvutils.mrcnn import inference as mrcnn_inference
from celldom import inference as celldom_inference

# Generate predictions for the validation dataset and fetch the default set
# of analysis functions applied to each prediction.
pred_gen = mrcnn_inference.prediction_generator(model, dataset_val)
analysis_fns = celldom_inference.get_default_analysis_fns()

# Analyze every prediction, then assemble the results into one DataFrame
# with a row per prediction.
prediction_records = [
    celldom_inference.analyze_prediction(prediction, analysis_fns)
    for prediction in pred_gen
]
df = pd.DataFrame(prediction_records)

Unnamed: 0,counts,image_id,image_info,scores_aptnum,scores_cell,scores_cellclump,scores_cellunion,scores_chamber,scores_stnum,stats_cell,stats_cellclump,stats_cellunion,stats_chamber
0,"{'pred': [28, 1, 1, 1, 1, 1], 'true': [34, 1, ...",0,/lab/data/celldom/data/dataset01/BF_ST_001_APT...,"{'pred_status': 'Valid', 'dice': 0.85014049328...","{'pred_status': 'Valid', 'dice': 0.90818495880...","{'pred_status': 'Valid', 'dice': 0.92190548108...","{'pred_status': 'Valid', 'dice': 0.93618937347...","{'pred_status': 'Valid', 'dice': 0.96056108463...","{'pred_status': 'Valid', 'dice': 0.82280788993...","{'pred': [28.0, 388.67857142857144, 477.209795...","{'pred': [1.0, 7859.0, nan, 7859.0, 7859.0, 78...","{'pred': [1.0, 7514.0, nan, 7514.0, 7514.0, 75...","{'pred': [1.0, 39310.0, nan, 39310.0, 39310.0,..."
1,"{'pred': [74, 1, 1, 1, 1], 'true': [105, 1, 1,...",1,/lab/data/celldom/data/dataset01/BF_ST_001_APT...,"{'pred_status': 'Valid', 'dice': 0.79703703703...","{'pred_status': 'Valid', 'dice': 0.88430854069...","{'pred_status': 'Empty', 'dice': nan, 'true_st...","{'pred_status': 'Valid', 'dice': 0.88632875920...","{'pred_status': 'Valid', 'dice': 0.95103313654...","{'pred_status': 'Valid', 'dice': 0.77904328018...","{'pred': [74.0, 340.52702702702703, 268.837080...","{'pred': [0.0, nan, nan, nan, nan, nan, nan, n...","{'pred': [1.0, 20289.0, nan, 20289.0, 20289.0,...","{'pred': [1.0, 41668.0, nan, 41668.0, 41668.0,..."
2,"{'pred': [6, 1, 1, 1, 1], 'true': [4, 2, 1, 1,...",2,/lab/data/celldom/data/dataset01/BF_ST_001_APT...,"{'pred_status': 'Valid', 'dice': 0.88285004142...","{'pred_status': 'Valid', 'dice': 0.70649689846...","{'pred_status': 'Empty', 'dice': nan, 'true_st...","{'pred_status': 'Valid', 'dice': nan, 'true_st...","{'pred_status': 'Valid', 'dice': 0.94012183392...","{'pred_status': 'Valid', 'dice': 0.90926493108...","{'pred': [6.0, 306.1666666666667, 46.884610125...","{'pred': [0.0, nan, nan, nan, nan, nan, nan, n...","{'pred': [1.0, 1815.0, nan, 1815.0, 1815.0, 18...","{'pred': [1.0, 36663.0, nan, 36663.0, 36663.0,..."
3,"{'pred': [1, 1, 1, 1, 1], 'true': [1, 1, 1, 1,...",3,/lab/data/celldom/data/dataset01/BF_ST_001_APT...,"{'pred_status': 'Valid', 'dice': 0.82600732600...","{'pred_status': 'Valid', 'dice': 0.94562647754...","{'pred_status': 'Empty', 'dice': nan, 'true_st...","{'pred_status': 'Valid', 'dice': 0.94562647754...","{'pred_status': 'Valid', 'dice': 0.92945220835...","{'pred_status': 'Valid', 'dice': 0.90671579299...","{'pred': [1.0, 208.0, nan, 208.0, 208.0, 208.0...","{'pred': [0.0, nan, nan, nan, nan, nan, nan, n...","{'pred': [1.0, 208.0, nan, 208.0, 208.0, 208.0...","{'pred': [1.0, 37360.0, nan, 37360.0, 37360.0,..."
4,"{'pred': [1, 1, 1, 1, 1], 'true': [1, 1, 1, 1,...",4,/lab/data/celldom/data/dataset01/BF_ST_001_APT...,"{'pred_status': 'Valid', 'dice': 0.86308286308...","{'pred_status': 'Valid', 'dice': 0.88391038696...","{'pred_status': 'Empty', 'dice': nan, 'true_st...","{'pred_status': 'Valid', 'dice': 0.88391038696...","{'pred_status': 'Valid', 'dice': 0.92573175507...","{'pred_status': 'Valid', 'dice': 0.88046166529...","{'pred': [1.0, 223.0, nan, 223.0, 223.0, 223.0...","{'pred': [0.0, nan, nan, nan, nan, nan, nan, n...","{'pred': [1.0, 223.0, nan, 223.0, 223.0, 223.0...","{'pred': [1.0, 36640.0, nan, 36640.0, 36640.0,..."


In [None]:
# Print a summary of the statistics frame (columns, dtypes, non-null counts).
df.info()

In [None]:
# Preview the first rows of the statistics frame.
df.head()

In [70]:
# Make sure the export directory exists so the pickle write cannot fail on a
# missing folder (a no-op when the directory is already present).
os.makedirs(export_dir, exist_ok=True)
export_path = osp.join(export_dir, 'stats.pkl')

# Persist the per-prediction statistics.
df.to_pickle(export_path)

# Record the output location only after the file has been written, so the
# recorded path never points at a file that failed to materialize.
pm.record('stats_path', export_path)

{'pred': Cell         28
 Chamber       1
 AptNum        1
 CellClump     1
 Marker        1
 StNum         1
 dtype: int64, 'true': Cell         34
 Chamber       1
 AptNum        1
 CellClump     1
 Marker        1
 StNum         1
 dtype: int64}