In [1]:
import numpy as np
from jarvis.train import datasets
from jarvis.train.client import Client
from jarvis.utils.general import tools as jtools
from jarvis.utils.display import imshow
from tensorflow.keras import Input, Model, models, layers, metrics
from tensorflow import losses, optimizers

In [2]:
# --- GPU selection: pick a device and set CUDA_VISIBLE_DEVICES (see the log line below).
# Presumably this must run before TensorFlow first touches a GPU — keep it early.
from jarvis.utils.general import gpus

gpus.autoselect()

[ 2020-11-05 18:48:23 ] CUDA_VISIBLE_DEVICES automatically set to: 3           


In [3]:
# --- Data: a Client is built from the project's YAML config and yields the
# (train, valid) generators consumed by model.fit below.
# NOTE(review): `paths` is not used in any visible cell; kept in case get_paths has side effects.
paths = jtools.get_paths('xr/breast-fgt')

client = Client('/data/raw/xr_breast_fgt/data/ymls/client.yml')
gen_train, gen_valid = client.create_generators()

In [4]:
# --- Yield one example
# xs, ys = next(gen_train)

# for key, arr in xs.items():
#     print('xs key: {} | shape = {}'.format(key.ljust(8), arr.shape))
# for key, arr in ys.items():
#     print('ys key: {} | shape = {}'.format(key.ljust(8), arr.shape))

# imshow(xs['dat'][0])
# imshow(xs['dat'], figsize=(12, 12))

In [5]:
inputs = client.get_inputs(Input)

# Shared Conv3D settings. A (1, 3, 3) kernel has size 1 along the first spatial
# axis, so each convolution only mixes information across the last two spatial axes.
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same',
    # 'kernel_initializer': 'he_normal',  # disabled: the Keras default kernel initializer applies
}


def conv(x, filters, strides):
    """Apply a Conv3D layer built with the shared `kwargs` settings."""
    return layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)


def norm(x):
    """Apply a BatchNormalization layer."""
    return layers.BatchNormalization()(x)


def relu(x):
    """Apply a LeakyReLU activation (Keras default negative slope)."""
    return layers.LeakyReLU()(x)


def conv1(filters, x):
    """Conv -> BatchNorm -> LeakyReLU with stride 1 ('same' padding keeps the spatial size)."""
    return relu(norm(conv(x, filters, strides=1)))


def conv2(filters, x):
    """Conv -> BatchNorm -> LeakyReLU with stride (1, 2, 2).

    With 'same' padding this downsamples the last two spatial axes by 2 (rounding up)
    and leaves the first spatial axis unchanged.
    """
    return relu(norm(conv(x, filters, strides=(1, 2, 2))))

# # T1
# l1 = conv2(48, conv1(48, conv1(48, inputs['dat'])))
# l2 = conv2(56, conv1(56, conv1(56, l1)))
# l3 = conv2(64, conv1(64, conv1(64, l2)))
# l4 = conv2(80, conv1(80, conv1(80, l3)))
# l5 = conv2(96, conv1(96, conv1(96, l4)))
# l6 = conv2(112, conv1(112, conv1(112, l5)))
# l7 = conv2(128, conv1(128, conv1(128, l6)))
# f0 = layers.Reshape((1, 1, 1, 2 * 2 * 128))(l7)
# trial = 1

# # # T2
# l1 = conv2(48, conv1(48, inputs['dat']))
# l2 = conv2(56, conv1(56, l1))
# l3 = conv2(64, conv1(64, l2))
# l4 = conv2(80, conv1(80, l3))
# l5 = conv2(96, conv1(96, l4))
# l6 = conv2(112, conv1(112, l5))
# l7 = conv2(128, conv1(128, l6))
# f0 = layers.Reshape((1, 1, 1, 2 * 2 * 128))(l7)
# trial = 2

# # T3
# l1 = conv2(48, conv1(48, inputs['dat']))
# l2 = conv2(56, conv1(56, l1))
# l3 = conv2(64, conv1(64, l2))
# l4 = conv2(80, conv1(80, l3))
# l5 = conv2(96, conv1(96, l4))
# l6 = conv2(112, conv1(112, l5))
# l7 = conv2(128, conv1(128, l6))
# l8 = conv2(256, conv1(256, l7))
# f0 = layers.Reshape((1, 1, 1, 1 * 1 * 256))(l8)
# trial = 3

# T4: seven encoder stages, each = two stride-1 conv blocks + one downsampling conv block.
def stage(filters, x, n_conv1=2):
    """One encoder stage: `n_conv1` conv1 blocks followed by a single conv2 block.

    Layers are created in the same order as the nested form
    conv2(f, conv1(f, conv1(f, x))). T1 and T4 use n_conv1=2; T2 and T3 use n_conv1=1.
    """
    for _ in range(n_conv1):
        x = conv1(filters, x)
    return conv2(filters, x)


l1 = stage(16, inputs['dat'])
l2 = stage(36, l1)
l3 = stage(48, l2)
l4 = stage(64, l3)
l5 = stage(80, l4)
l6 = stage(112, l5)
l7 = stage(128, l6)

# Fold the final feature map into the channel axis of a 1x1x1 volume so a 1x1x1 conv can act as the head.
# NOTE(review): 2 * 2 * 128 assumes l7 holds exactly 512 values per sample (presumably shape
# (1, 2, 2, 128)) — TODO confirm against the input shape configured in client.yml.
f0 = layers.Reshape((1, 1, 1, 2 * 2 * 128))(l7)
trial = 4

# Output head: a 1x1x1 conv with a sigmoid, so 'lbl' values lie in (0, 1).
# Despite the name, these are post-sigmoid outputs, not raw logits; the name is kept
# because the next cell passes `logits` to Model(...).
logits = {}
logits['lbl'] = layers.Conv3D(filters=1, kernel_size=(1, 1, 1), activation='sigmoid', name='lbl')(f0)

In [6]:
model = Model(inputs=inputs, outputs=logits)

# Losses considered for 'lbl':
#   * mean absolute error (MAE)  <- used
#   * mean squared error (MSE)   <- tracked as a metric
#   * Huber loss
model.compile(
    optimizer=optimizers.Adam(learning_rate=5e-4),
    loss={'lbl': losses.MeanAbsoluteError()},
    # Use the stateful Keras metric rather than a loss object; its default name is
    # 'mean_squared_error', matching the keys recorded in the results log.
    metrics={'lbl': metrics.MeanSquaredError()},
    # NOTE(review): this flag is deprecated in newer TF 2.x releases — confirm it is still needed.
    experimental_run_tf_function=False)

In [7]:
# --- Stage 1: 4 epochs x 500 steps at lr 5e-4; validation (500 steps) runs after epochs 2 and 4.
# NOTE(review): per the Keras docs, use_multiprocessing=True uses process-based workers for
# generator input. The repeated "Epoch 1/4" lines in the output may be stdout re-flushed by a
# forked child — confirm gen_train/gen_valid are safe to use this way.
model.fit(
    gen_train,
    epochs=4,
    steps_per_epoch=500,
    validation_data=gen_valid,
    validation_freq=2,
    validation_steps=500,
    use_multiprocessing=True,
)


Epoch 1/4
Epoch 2/4
Epoch 1/4
Epoch 3/4
Epoch 4/4
Epoch 1/4


<tensorflow.python.keras.callbacks.History at 0x7f0f38568470>

In [13]:
#11052020 - (T4) - trained on 4 epochs, lr: 5e-4
#train again on lower lr: 5e-5
# Recompiling builds a fresh Adam optimizer, so its moment estimates restart from zero;
# the weights learned in stage 1 are kept.
model.compile(
    optimizer=optimizers.Adam(learning_rate=5e-5),
    loss={'lbl': losses.MeanAbsoluteError()},
    # Stateful Keras metric; default name 'mean_squared_error' matches the results log.
    metrics={'lbl': metrics.MeanSquaredError()},
    # NOTE(review): deprecated flag in newer TF 2.x releases — confirm it is still needed.
    experimental_run_tf_function=False)

# 2 more epochs; with validation_freq=2, validation runs only after epoch 2.
model.fit(
    x=gen_train,
    steps_per_epoch=500,
    epochs=2,
    validation_data=gen_valid,
    validation_steps=500,
    validation_freq=2,
    use_multiprocessing=True)

Epoch 1/2
Epoch 2/2
Epoch 1/2


<tensorflow.python.keras.callbacks.History at 0x7f0e107d48d0>

In [14]:
# Save the trained model; the filename records which architecture trial (T1-T4) produced it.
model.save(f'model_trial_{trial}.hdf5')

In [None]:
#DATE       MODEL         EPOCH    LR      RESULTS
#11042020 - (SU OG)       4        5e-5    loss: 0.0494 - mean_squared_error: 0.0060 - val_loss: 0.0643 - val_mean_squared_error: 0.0092
#11052020 - (T1)          4        5e-5    loss: 0.0436 - mean_squared_error: 0.0047 - val_loss: 0.0590 - val_mean_squared_error: 0.0082
#11052020 - (T2)          4        5e-5    loss: 0.0417 - mean_squared_error: 0.0043 - val_loss: 0.0613 - val_mean_squared_error: 0.0089
#11052020 - (T3)          4        5e-5    loss: 0.0506 - mean_squared_error: 0.0061 - val_loss: 0.0657 - val_mean_squared_error: 0.0099
#11052020 - (T3)          6        5e-5    loss: 0.0455 - mean_squared_error: 0.0049 - val_loss: 0.0626 - val_mean_squared_error: 0.0090
#11052020 - (T2)          6        5e-5    loss: 0.0344 - mean_squared_error: 0.0028 - val_loss: 0.0635 - val_mean_squared_error: 0.0100
#11052020 - (T1)          6        5e-5    loss: 0.0362 - mean_squared_error: 0.0029 - val_loss: 0.0667 - val_mean_squared_error: 0.0097
#11052020 - (T4)          6        5e-4    loss: 0.0435 - mean_squared_error: 0.0050 - val_loss: 0.0692 - val_mean_squared_error: 0.0093

#best so far: T4 - 4 epochs at lr 5e-4 (cell 7), then continued for 2 epochs at lr 5e-5 (cell 13)
#11052020 - (T4)          4        5e-4    loss: 0.0498 - mean_squared_error: 0.0063 - val_loss: 0.0473 - val_mean_squared_error: 0.0061
#    cont - (T4)          2        5e-5    loss: 0.0332 - mean_squared_error: 0.0029 - val_loss: 0.0422 - val_mean_squared_error: 0.0054

In [9]:
# test_train, test_valid = client.create_generators(test=True)
# # xs, ys = next(test_valid)
# xs, ys = next(gen_valid)
# logits = model.predict(xs['dat'])

In [10]:
# def mse(a, b):
#     mse = ((a - b)**2).mean()
#     return mse

In [11]:
# #evaluation
# model = models.load_model('model_11042020.hdf5', compile=False)

# losses = []

# for x, y in gen_valid:
#     # --- Predict Percentage
#     logits = model.predict(x['dat'])
#     if type(logits) is dict:
#         logits = logits['lbl']
        
#     pred = logits
#     trues = y['lbl']
    
#     loss = mse(trues, pred)
# #     print('.', end='')
#     losses.append(loss)

# losses = np.array(losses)

In [12]:
# print(np.mean(losses))