# Model training

After parsing the raw image data, we expect the following directory structure, in which the data arrays are saved as `.npy` files in subdirectories named after the class labels (e.g. `Class Positive`, `Class Negative`, etc.).

    /data/parsed/
        Experiment 001/
            Day 1/
                Sample A/
                    Replicate 1/
                        Class A/
                            A__32e88e1ac3a8f44bf8f77371155553b9.npy
                            A__3dc56a0c446942aa0da170acfa922091.npy
                        Class B/
                            B__8068ef7dcddd89da4ca9740bd2ccb31e.npy
        Experiment 002/
            Day 1/
                Sample A/
                    Replicate 1/
                        Class A/
                            A__8348deaa70dfc95c46bd02984d28b873.npy
                        Class B/
                            B__c1ecbca7bd98c01c1d3293b64cd6739a.npy
                            B__c56cfb8e7e7121dd822e47c67d07e2d4.npy

        ...
                
These data can be used to train a model that classifies each image as one of the classes. The `deepometry.utils.load` function selects the images used to train the model and generates the labels for the training images.
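
The internals of `deepometry.utils.load` are not shown in this notebook. As a rough sketch of the idea, based only on the directory layout above, the hypothetical helper `collect_labeled_paths` below walks the parsed-data root, keeps `.npy` files whose parent directory name contains a class keyword (mirroring the `class_option` filter used later), and uses that directory name as the label. Files from different experiments that share a class directory name are pooled under one label.

```python
import os
from typing import Dict, List


def collect_labeled_paths(root: str, keyword: str = "class") -> Dict[str, List[str]]:
    """Hypothetical helper: map each class directory name to its .npy files.

    The label of a file is the name of the directory that directly contains it,
    e.g. ".../Replicate 1/Class A/A__32e8....npy" -> "Class A".
    """
    labeled: Dict[str, List[str]] = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        label = os.path.basename(dirpath)
        # Only directories whose name contains the keyword count as class folders
        if keyword.lower() not in label.lower():
            continue
        for name in sorted(filenames):
            if name.endswith(".npy"):
                labeled.setdefault(label, []).append(os.path.join(dirpath, name))
    return labeled
```

With the example tree, `collect_labeled_paths("/data/parsed/")` would return two `Class A` files from `Experiment 001`, one from `Experiment 002`, and so on, all under the key `"Class A"`.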

Suppose there is a large imbalance between the number of samples per class in each experiment. Undersampling the over-represented classes balances the data seen by the model during training. Additionally, class weights (the `class_weight` argument of `model.fit`) can be used together with undersampling to improve prediction accuracy for underrepresented classes. The `deepometry.utils.load` function performs undersampling across classes (per experiment) with `sample=True`.
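
The following is a generic NumPy sketch of the two ideas in this paragraph, not deepometry's implementation. It assumes integer class labels and a parallel array of experiment identifiers. `max_per_class` is a hypothetical cap that plays a role similar to `n_samples`. The weighting formula is the common "balanced" heuristic (the same one scikit-learn uses for `class_weight='balanced'`), swapped in here for illustration. All function names are hypothetical.

```python
import numpy as np


def undersample(labels, max_per_class=None, rng=None):
    """Return sorted indices that keep at most `max_per_class` samples per class.

    If `max_per_class` is None, every class is cut down to the size of the
    smallest class, which balances the classes exactly.
    """
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    cap = counts.min() if max_per_class is None else max_per_class
    keep = [
        rng.choice(np.flatnonzero(labels == c), size=min(cap, n), replace=False)
        for c, n in zip(classes, counts)
    ]
    return np.sort(np.concatenate(keep))


def undersample_per_experiment(labels, experiments, max_per_class=None, rng=None):
    """Apply `undersample` separately within each experiment."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    experiments = np.asarray(experiments)
    keep = []
    for experiment in np.unique(experiments):
        idx = np.flatnonzero(experiments == experiment)
        # Map the within-experiment indices back to positions in the full array
        keep.append(idx[undersample(labels[idx], max_per_class, rng)])
    return np.sort(np.concatenate(keep))


def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count).

    Rarer classes get larger weights. The result maps integer class indices to
    float weights, the form Keras expects for `class_weight`.
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}
```

If undersampling balances the classes exactly, these weights all come out as 1.0. They only make a difference when some imbalance remains, for example when a fixed cap leaves the rarer classes below it, which is why the two techniques complement each other.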

# User settings

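```python
input_dir = '/Data/STEP1_Parsing'
output_dir = '/Data/STEP2_Trained_model'
class_option = 'class'  # keep only subdirectories whose name contains this keyword

# Some hyperparameters
n_samples = None  # subsample over-represented classes; None disables subsampling
validation_split = 0.2  # fraction of the data held out for validation
batch_size = 32  # samples per mini-batch
epochs = 10  # maximum number of passes over the training data
```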

# Executable

```python
%matplotlib inline

import glob
import os

import keras
import matplotlib.pyplot
import numpy
import pandas
import pkg_resources
import tensorflow

import deepometry.model
import deepometry.utils
```

```python
# Build a TensorFlow 1.x session that allocates GPU memory on demand
configuration = tensorflow.ConfigProto()
configuration.gpu_options.allow_growth = True
# Uncomment to restrict training to a specific GPU (here GPU 0)
# configuration.gpu_options.visible_device_list = "0"
session = tensorflow.Session(config=configuration)

# Make Keras use this session
keras.backend.set_session(session)
```

```python
# Every directory name under input_dir is a candidate label
all_subdirs = [dirpath for dirpath, _, _ in os.walk(input_dir)]
possible_labels = sorted({os.path.basename(path) for path in all_subdirs})
# Keep only directory names that contain the class keyword, e.g. "Class A"
labels_of_interest = [label for label in possible_labels if class_option.lower() in label.lower()]

pathnames_of_interest = deepometry.utils.collect_pathnames(input_dir, labels_of_interest, n_samples=n_samples)
```

```python
x, y, _ = deepometry.utils._load(pathnames_of_interest, labels_of_interest)

# One output unit per class; labels_of_interest is already free of duplicates
units = len(labels_of_interest)
```

The training and target data (`x` and `y`, respectively) are next passed to the model for training. The model is configured to withhold 20% of the training data for validation; use `validation_split` to adjust the size of this partition.
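
Keras's built-in `validation_split` takes the validation samples from the end of `x` and `y` before any shuffling, so if `x` is ordered by class, the held-out set may not be representative. This notebook does not say whether `deepometry.model.Model.fit` shuffles first. The hypothetical helper `train_validation_split` below is a hedged sketch of a shuffled 20% hold-out, shown only to make the partition concrete.

```python
import numpy as np


def train_validation_split(x, y, validation_split=0.2, seed=0):
    """Randomly hold out a `validation_split` fraction of (x, y) for validation."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))  # shuffle so every class can appear in both parts
    n_val = round(len(x) * validation_split)
    val_idx, train_idx = order[:n_val], order[n_val:]
    return (x[train_idx], y[train_idx]), (x[val_idx], y[val_idx])
```

With 1,000 images and `validation_split = 0.2`, this keeps 800 images for training and 200 for validation.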

The model iterates over the training data at most `epochs` times (10 in the settings above). Training terminates early if the validation loss fails to improve for 20 epochs. Training and validation data are provided to the model in batches of 32 samples; use `batch_size` to configure the number of samples per batch. A smaller `batch_size` requires less memory.
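
The pure-Python sketch below covers the two numbers in this paragraph. It assumes that early stopping follows the usual patience rule, as in Keras's `keras.callbacks.EarlyStopping(monitor="val_loss", patience=20)`. Deepometry's internals may differ, and the helper names are hypothetical.

```python
import math


def batches_per_epoch(n_samples, batch_size=32):
    """Number of mini-batches per epoch; the last batch may be smaller."""
    return math.ceil(n_samples / batch_size)


def early_stopping_epoch(val_losses, patience=20, max_epochs=512):
    """Return how many epochs run before patience-based early stopping.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation loss seen so far, or after `max_epochs` epochs.
    """
    best = math.inf
    wait = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best:
            best, wait = loss, 0  # improvement: reset the patience counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return min(len(val_losses), max_epochs)
```

For example, 1,000 training images with `batch_size = 32` give `ceil(1000 / 32) = 32` batches per epoch. The last batch holds only 8 images.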

```python
model = deepometry.model.Model(shape=x.shape[1:], units=units)

model.compile()

print('Model training... Please wait!')

model.fit(
    x,
    y,
    balance_train=False,
    class_weight=None,
    validation_split=validation_split,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1
)
```

# Visualize training accuracy and loss

```python
# Per-epoch training log (accuracy and loss), read from the deepometry package data
csv = pandas.read_csv(pkg_resources.resource_filename("deepometry", "data/training.csv"))
```

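```python
_, (ax0, ax1) = matplotlib.pyplot.subplots(ncols=2, figsize=(16, 4))

# Left: accuracy per epoch; right: loss per epoch (training red, validation blue)
ax0.plot(csv["acc"], c="r", label="training")
ax0.plot(csv["val_acc"], c="b", label="validation")
ax0.set(xlabel="epoch", ylabel="accuracy")
ax0.legend()

ax1.plot(csv["loss"], c="r", label="training")
ax1.plot(csv["val_loss"], c="b", label="validation")
ax1.set(xlabel="epoch", ylabel="loss")
ax1.legend();
```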