In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, ProjectionCARE, ProjectionConfig

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [None]:
# Block until the trigger path TriggerName exists (presumably created by the job
# that generates the training .npz), checking every PollInterval seconds.
# After TimeThreshold seconds we stop waiting and continue anyway, but say so
# instead of silently proceeding on possibly stale data.
import os
import time

TriggerName = '/home/sancere/NextonDisk_1/TimeTrigger/TTGenData2'
PollInterval = 60*5        # seconds between checks
TimeThreshold = 3600*12    # give up after 12 hours
TimeCount = 0
while not os.path.exists(TriggerName) and TimeCount < TimeThreshold:
    time.sleep(PollInterval)
    TimeCount += PollInterval

if not os.path.exists(TriggerName):
    print('Warning: trigger {} not found after {} s of waiting; continuing anyway.'.format(TriggerName, TimeCount))

In [None]:
import os

# Folder holding the training .npz files and the name of the model/data set;
# the same BaseDir/ModelName pair is used when building the model.
BaseDir = '/run/media/sancere/DATA/Lucas_NextonCreated_npz/'
ModelName = 'MariaTraining_projection_40x_bin2_Rfp'

# os.path.join is correct whether or not BaseDir ends with a path separator
# (plain string concatenation silently breaks if the trailing '/' is dropped).
load_path = os.path.join(BaseDir, ModelName + '.npz')

# Hold out 10% of the patches for validation.
(X,Y), (X_val,Y_val), axes = load_training_data(load_path, validation_split=0.1, verbose=True)
# Input/output channel counts are read along the 'C' axis of the loaded arrays.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

In [None]:
# Quick visual sanity check: source patches above their targets.
n_preview = 5
plt.figure(figsize=(12,5))
plot_some(X_val[:n_preview], Y_val[:n_preview])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

In [None]:
# Network/training configuration: 4-level U-Net, 40 epochs of 100 steps with
# batch size 100; train_reduce_lr presumably drives a reduce-on-plateau LR
# schedule (patience 5, factor 0.5) — confirm against CSBDeep docs.
config = ProjectionConfig(
    axes, n_channel_in, n_channel_out,
    unet_n_depth=4,
    train_epochs=40,
    train_steps_per_epoch=100,
    train_batch_size=100,
    train_reduce_lr={'patience': 5, 'factor': 0.5}
)
print(config)
vars(config)

In [None]:
import os

# Build the projection network (name=ModelName under basedir=BaseDir) and
# warm-start it from a previous run's best checkpoint in that model folder.
# NOTE(review): load_weights presumably raises if the checkpoint is absent,
# i.e. this cell assumes an earlier training run exists.
model = ProjectionCARE(config=config, name=ModelName, basedir=BaseDir)
# os.path.join keeps the path correct even if BaseDir lacks a trailing '/'.
model.load_weights(os.path.join(BaseDir, ModelName, 'weights_best.h5'))

In [None]:
# Fit on the training split while monitoring the held-out validation split.
validation_pair = (X_val, Y_val)
history = model.train(X, Y, validation_data=validation_pair)

In [None]:
# List the metrics Keras recorded, then plot the loss curves next to the
# mse/mae curves.
recorded_metrics = sorted(history.history)
print(recorded_metrics)
plt.figure(figsize=(16,5))
plot_history(history, ['loss','val_loss'], ['mse','val_mse','mae','val_mae']);

In [None]:
# Predict on the first validation patches and show input / target / prediction.
# The title previously claimed 5 patches while 20 were plotted; it is now
# derived from the number actually shown.
N_PREDICT_EXAMPLES = 20
n_examples = min(N_PREDICT_EXAMPLES, len(X_val))

plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:n_examples])
if config.probabilistic:
    # Probabilistic models stack two sets of channels along the last axis; keep
    # only the first half (presumably the predicted means — confirm with CSBDeep).
    _P = _P[...,:(_P.shape[-1]//2)]
plot_some(X_val[:n_examples],Y_val[:n_examples],_P,pmax=99.5)
title = ('{} example validation patches\n'
         'top row: input (source),  '
         'middle row: target (ground truth),  '
         'bottom row: predicted from source').format(n_examples)
plt.suptitle(title);

In [None]:
# Export the trained network via CSBDeep's export_TF (presumably a zipped
# TensorFlow SavedModel in the model folder, e.g. for the Fiji plugin — confirm).
model.export_TF()

In [None]:
from csbdeep.utils import Path

# Mark completion by creating a trigger directory (presumably polled by a
# downstream job, like the TTGenData2 trigger this notebook waits on).
# exist_ok=True makes re-running this cell harmless.
TriggerName = '/home/sancere/NextonDisk_1/TimeTrigger/TT_Training'
training_done_marker = Path(TriggerName)
training_done_marker.mkdir(exist_ok=True)