# Example of how to train a DeepParticleNet (DPN) model

## Imports

In [None]:
import os
import sys

from imgaug import augmenters as iaa
from keras.callbacks import EarlyStopping
import numpy as np

# Make the project root (the parent of the current working directory)
# importable, unless it is already on the module search path.
root_dir = os.path.abspath(os.pardir)
if root_dir not in sys.path:
    sys.path.append(root_dir)
    
# Import the DeepParticleNet modules.
from dpn.config import Config
from dpn.dataset import Dataset
from dpn.model import Model

from external.CLR.clr_callback import CyclicLR

## Setup config

In [None]:
class MpacConfig(Config):
    """Configuration used by this training example.

    Only the attributes listed here are overridden; everything else keeps
    the value inherited from the base ``Config`` class.
    """

    # --- General ---
    NAME = "example"
    COMMENT = ""

    # --- Hardware ---
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # --- Dataset ---
    DATASET_PATH = os.path.join(root_dir, "datasets")
    DATASET_SUBSET_TRAIN = "train"
    DATASET_SUBSET_VAL = "val"
    NUMBER_OF_SAMPLES_TRAIN = 400
    NUMBER_OF_SAMPLES_VAL = 100
    MEAN_PIXEL = [60.2, 60.2, 60.2]

    # --- Model ---
    BACKBONE = "resnet50"
    USE_PRETRAINED_WEIGHTS = "coco"

    # --- Augmentation ---
    # Between zero and two of the listed augmenters are picked by SomeOf.
    # Reference: http://imgaug.readthedocs.io/en/latest/source/augmenters.html
    AUGMENTATION = iaa.SomeOf(
        (0, 2),
        [
            iaa.Fliplr(0.5),
            iaa.Flipud(0.5),
            # Exactly one of the three right-angle rotations.
            iaa.OneOf([iaa.Affine(rotate=angle) for angle in (90, 180, 270)]),
            iaa.Multiply((0.8, 1.5)),
            iaa.GaussianBlur(sigma=(0.0, 5.0)),
        ],
    )

    # --- Custom callbacks ---
    # step_size is twice the number of training samples divided by the number
    # of images processed per step (GPU_COUNT * IMAGES_PER_GPU), rounded up.
    CYCLIC_LEARNING_RATE = CyclicLR(
        base_lr=0.0005,
        max_lr=0.0037,
        step_size=2 * np.ceil(
            NUMBER_OF_SAMPLES_TRAIN / (GPU_COUNT * IMAGES_PER_GPU)
        ),
        mode="triangular",
    )

    # Early stopping watches the validation loss with a patience of 20.
    EARLY_STOPPING = EarlyStopping(
        monitor="val_loss",
        min_delta=0,
        patience=20,
        verbose=0,
        mode="auto",
    )

    CUSTOM_CALLBACKS = [CYCLIC_LEARNING_RATE, EARLY_STOPPING]


# Instantiate the configuration used by the cells below.
config = MpacConfig()

## Setup datasets

In [None]:
class MpacDataset(Dataset):
    """Dataset subclass for data that contains a single object class."""

    # The dataset only has one class, so every instance gets this label.
    MONOCLASS = "sphere"


# Build the training dataset first, then the validation dataset.
dataset_train, dataset_val = (
    MpacDataset(config=config, dataset_name=subset_name)
    for subset_name in ("training", "validation")
)

## Setup model

In [None]:
# Logs and model checkpoints go to "<root_dir>/logs".
logging_dir = os.path.join(root_dir, "logs")

# Build the model in training mode with the configuration defined above.
model = Model(mode="training", config=config, model_dir=logging_dir)

## Training

In [None]:
# Train on the training dataset and validate on the validation dataset.
# NOTE(review): save_best_only=True presumably keeps only the best
# checkpoint; confirm against Model.train.
history = model.train(
    dataset_train,
    dataset_val,
    save_best_only=True,
)