# Training

In [None]:
from argparse import ArgumentParser, Namespace
from glob import glob
from pathlib import Path
from typing import Tuple

from csbdeep.utils import normalize
from csbdeep.data import shuffle_inplace
import matplotlib.pyplot as plt
import numpy as np
from tifffile import imread
from tqdm import tqdm


from iterative_biofilm_annotation.unet.utils import crop

In [None]:
def train(
    basedir: Path,
    modelname: str,
    dataset: str,
    patch_size: Tuple[int, ...],
    epochs: int,
    steps: int,
) -> None:
    """Placeholder training entry point; currently does nothing.

    The working pipeline is written out in the notebook cells that follow.
    Those cells read module-level variables named ``basedir``, ``modelname``,
    ``patch_size``, ``epochs`` and ``steps``, not these arguments.
    ``dataset`` is not used anywhere yet: the data paths are hard-coded.

    Args:
        basedir: Directory under which the model would be stored.
        modelname: Name of the model to train.
        dataset: Dataset identifier (unused).
        patch_size: Variable-length tuple of patch dimensions, e.g. (48, 96, 96).
        epochs: Number of training epochs.
        steps: Training steps per epoch.

    Returns:
        None. No work is performed.
    """
    # TODO: move the notebook's load/normalize/train steps in here so this
    # function can be driven from the command line.
    return None

In [None]:
# --- Run configuration -------------------------------------------------
# CARE stores the trained model under basedir / modelname.
basedir = Path('models')
modelname = 'care_bcm3d_target2_v2'

# Size of the patch cropped out of each training volume
# (presumably (Z, Y, X) — confirm against the crop helper).
patch_size = (48, 96, 96)

# Training schedule.
epochs, steps = 100, 100


In [None]:
    # Root of the patch dataset. Each split folder holds raw inputs in
    # `images/` and the corresponding targets in `target_bacm3d_2/`.
    data_root = 'training_data/patches-semimanual-raw-64x128x128'

    def _load_split(split):
        """Load one split's (input, target) volumes, each cropped via `crop`.

        Inputs and targets are paired purely by sorted filename order. A
        missing or extra file would therefore silently misalign every later
        pair, so a count mismatch is rejected up front.

        Args:
            split: Name of the sub-folder to load, e.g. 'train' or 'valid'.

        Returns:
            Tuple (xs, ys) of lists of cropped input and target arrays.

        Raises:
            ValueError: if the split contains no input images, or the number
                of input and target files differs.
        """
        x_files = sorted(glob(f'{data_root}/{split}/images/*.tif'))
        y_files = sorted(glob(f'{data_root}/{split}/target_bacm3d_2/*.tif'))
        if not x_files:
            raise ValueError(f'no input images found in {data_root}/{split}/images')
        if len(x_files) != len(y_files):
            raise ValueError(
                f'{split}: {len(x_files)} input images but {len(y_files)} target images'
            )
        # crop out central patch (for simplicity)
        xs = [crop(imread(f), patch_size) for f in x_files]
        ys = [crop(imread(f), patch_size) for f in y_files]
        return xs, ys

    X_train, Y_train = _load_split('train')
    X_valid, Y_valid = _load_split('valid')

    # Percentile-normalize the inputs only (pmin=1, pmax=99.8); targets are
    # used as loaded.
    X_train = [normalize(x, 1, 99.8) for x in tqdm(X_train)]
    X_valid = [normalize(x, 1, 99.8) for x in tqdm(X_valid)]

    # Stack into arrays and append a trailing singleton channel axis
    # (presumably (N, Z, Y, X, 1), matching the 'SZYXC' axes used for CARE).
    X_train, Y_train = np.expand_dims(np.stack(X_train), -1), np.expand_dims(np.stack(Y_train), -1)
    X_valid, Y_valid = np.expand_dims(np.stack(X_valid), -1), np.expand_dims(np.stack(Y_valid), -1)

In [None]:
    # Sanity check: array shapes (first axis = number of stacked patches).
    (X_train.shape, X_valid.shape)

# Use CARE

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.models import Config, CARE

In [None]:
# Preview the first five validation pairs to sanity-check inputs vs. targets.
first_five = slice(5)
plt.figure(figsize=(12, 5))
plot_some(X_valid[first_five], Y_valid[first_five])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

In [None]:
# Optional: uncomment to limit TensorFlow's GPU memory allocation to a
# fraction (here 1/2) of the device, e.g. when the GPU is shared.
#limit_gpu_memory(fraction=1/2)

In [None]:
# Axes of the training arrays: sample (S), Z, Y, X, channel (C).
axes = 'SZYXC'
# Use the run configuration defined above instead of hard-coded values, so
# editing `epochs` / `steps` actually changes the training schedule.
config = Config(
    axes,
    n_channel_in=1,
    n_channel_out=1,
    train_epochs=epochs,
    train_steps_per_epoch=steps,
)
model = CARE(config, modelname, basedir=basedir)
history = model.train(X_train, Y_train, validation_data=(X_valid, Y_valid))

In [None]:
# List the metric names recorded during training (iterating a dict yields
# its keys, so no .keys()/list() wrapping is needed).
print(sorted(history.history))
# Plot the loss curves and the MSE/MAE metrics for training and validation;
# each list is drawn as its own panel.
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae']);