In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

# Eager execution is explicitly enabled here; disable it instead if problems with prob. conv layers occur.
from tensorflow.python.framework.ops import disable_eager_execution, enable_eager_execution
enable_eager_execution()

import os
import sys
sys.path.insert(0, '/mnt/home/raheppt1/projects/age_prediction')
import yaml
import numpy as np
import datetime
import argparse
from pathlib import Path

# Written for tensorflow-gpu 2.1.0 (note: the saved version-check output below reports 2.0.0)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.optimizers import *

from dataset import AgeData
from models.models3d import age_regression_models
from misc.utils import init_gpu

In [2]:
# Print the TensorFlow version of the running kernel for the record.
print(tf.__version__)

2.0.0


In [30]:
# Input and pipeline configuration shared by the cells below.
# Passed to AgeData; with a trailing channel axis appended it is also the model input shape.
image_size = [100, 120, 100]
# Passed to AgeData; presumably the voxel spacing per axis -- TODO confirm units.
image_spacing = [1.5, 1.5, 1.5]
# Number of samples per batch in ds_train / ds_val.
batch_size = 16    
# Buffer size used by the shuffle() step of the validation pipeline.
shuffle_buffer_size = 32

In [61]:
# Load training and validation data.
# AgeData is a project-local helper (imported from `dataset`); it is configured
# with the target image size and spacing defined in the configuration cell.
age_data = AgeData(image_size,
                   image_spacing,
                   shuffle_training_images=True,
                   save_debug_images=False)
# Training split and its number of entries (used to bound one pass of train_gen).
dataset_train = age_data.dataset_train()
train_samples = dataset_train.num_entries()
# Validation split and its number of entries (used to bound one pass of val_gen).
dataset_val = age_data.dataset_val()
val_samples = dataset_val.num_entries()

# Define training and validation datasets from generators.
def _iter_image_age_pairs(data):
    """Yield one ``(image, age)`` pair for every entry of ``data``.

    Shared implementation of ``train_gen`` and ``val_gen`` (previously two
    copy-pasted loops).

    Args:
        data: Object providing ``num_entries()`` and ``get_next()``, where
            ``get_next()`` returns a dict whose ``['generators']`` entry holds
            an ``'image'`` array (4 axes) and an ``'age'`` value.

    Yields:
        Tuple ``(image, age)``; ``image`` is a float32 array whose first axis
        has been moved to the end, ``age`` is passed through unchanged.
    """
    # The number of entries is read once per pass instead of on every iteration.
    for _ in range(data.num_entries()):
        generators = data.get_next()['generators']
        # Move axis 0 to the end to obtain the DHWC tensor format.
        image = generators['image'].transpose([1, 2, 3, 0]).astype('float32')
        yield image, generators['age']


def train_gen():
    """Generate one pass of ``(image, age)`` pairs from ``dataset_train``."""
    yield from _iter_image_age_pairs(dataset_train)


def val_gen():
    """Generate one pass of ``(image, age)`` pairs from ``dataset_val``."""
    yield from _iter_image_age_pairs(dataset_val)

def _image_age_dataset(generator):
    """Wrap an ``(image, age)`` generator function in a ``tf.data.Dataset``.

    Images are declared as float32 tensors with four axes of unknown size,
    ages as float32 tensors of shape ``(1,)``.
    """
    return tf.data.Dataset.from_generator(
        generator,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape([None] * 4), tf.TensorShape([1])))


# Training stream: repeat indefinitely, then batch. No tf.data shuffle is
# applied here (AgeData is created with shuffle_training_images=True).
# TODO: decide whether repeat() is wanted and whether to reshuffle every epoch.
ds_train = _image_age_dataset(train_gen).repeat().batch(batch_size=batch_size)
# NOTE(review): shard() comes after batch(), so each shard keeps every second
# *batch*; shard index 0 / 1 feed the two branches of the Siamese model.
ds_train_a, ds_train_b = (ds_train.shard(num_shards=2, index=k) for k in range(2))

# Validation stream: repeat, shuffle through a small buffer, then batch.
# NOTE(review): shuffle() after repeat() can mix samples across epoch boundaries.
ds_val = (_image_age_dataset(val_gen)
          .repeat()
          .shuffle(buffer_size=shuffle_buffer_size)
          .batch(batch_size=batch_size))
ds_val_a, ds_val_b = (ds_val.shard(num_shards=2, index=k) for k in range(2))

loaded 416 ids
loaded 47 ids


In [56]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv3D, MaxPooling3D, Dropout, ReLU

In [None]:
def create_encoder(input_shape):
    """Build the 3D convolutional encoder used by both Siamese branches.

    The encoder stacks five stages of ``Conv3D`` (3x3x3 kernel, 'same'
    padding) -> ``ReLU`` -> ``MaxPooling3D`` (2x2x2 pool) with 8, 16, 32, 64
    and 128 filters, followed by a ``Flatten`` layer.

    Args:
        input_shape: Input shape handed to the first ``Conv3D`` layer.

    Returns:
        The assembled ``Sequential`` model.
    """
    encoder = Sequential()
    for stage, n_filters in enumerate((8, 16, 32, 64, 128)):
        # Only the very first layer declares the input shape.
        extra = {'input_shape': input_shape} if stage == 0 else {}
        encoder.add(Conv3D(n_filters, kernel_size=(3, 3, 3), padding='same', **extra))
        encoder.add(ReLU())
        encoder.add(MaxPooling3D(pool_size=(2, 2, 2)))
    encoder.add(Flatten())
    return encoder

def create_common_dense_layers(dropout, lambda_l2, outputs):
    """Build the dense head applied to the concatenated encoder features.

    Layout: ``Dense(1024, relu)`` -> ``Dense(512, relu)`` ->
    ``Dense(outputs, linear)``. When ``dropout`` is truthy, a ``Dropout(0.5)``
    layer is inserted in front of each of the three dense layers.

    Args:
        dropout: Whether to insert the dropout layers.
        lambda_l2: L2 factor for the kernel regularizer of the two hidden
            dense layers (the output layer is not regularized).
        outputs: Number of units of the final linear layer.

    Returns:
        The assembled ``Sequential`` model.
    """
    head = Sequential()

    def add_dropout():
        if dropout:
            head.add(Dropout(0.5))

    for units in (1024, 512):
        add_dropout()
        head.add(Dense(units, activation='relu',
                       kernel_regularizer=tf.keras.regularizers.l2(lambda_l2)))

    add_dropout()
    head.add(Dense(outputs, activation='linear'))
    return head

class Siamese(tf.keras.Model):
    """Siamese network: one shared encoder, two inputs, one dense head.

    Both inputs are passed through the *same* encoder instance (shared
    weights). The two flattened feature vectors are concatenated along the
    feature axis and mapped to ``outputs`` values by the dense head.

    Args:
        input_shape: Shape of a single input, forwarded to ``create_encoder``.
        dropout: Whether the dense head contains dropout layers.
        lambda_l2: L2 regularization factor for the hidden dense layers.
        outputs: Number of output units of the dense head.
    """

    def __init__(self,
                 input_shape,
                 dropout=False,
                 lambda_l2=0.0,
                 outputs=1):
        super(Siamese, self).__init__()
        self.encoder = create_encoder(input_shape)
        self.common_dense = create_common_dense_layers(dropout, lambda_l2, outputs)

    def call(self, x_a, x_b=None, training=None):
        """Run the forward pass for a pair of input batches.

        Args:
            x_a: First input batch, or a 2-element sequence ``(x_a, x_b)``
                when ``x_b`` is omitted (the usual Keras single-``inputs``
                convention).
            x_b: Second input batch.
            training: Forwarded to the sub-models so that the dropout layers
                of the dense head can be switched between training and
                inference behavior. ``None`` keeps the Keras default.

        Returns:
            Output of the dense head for the concatenated features.
        """
        if x_b is None:
            # Both inputs were handed over as one pair.
            x_a, x_b = x_a
        z_a = self.encoder(x_a, training=training)
        z_b = self.encoder(x_b, training=training)
        # Concatenate along the feature axis; the batch axis stays untouched.
        z = tf.concat([z_a, z_b], axis=1)
        return self.common_dense(z, training=training)


# Instantiate the model; the input shape is the image size plus one trailing channel axis.
siamese_model = Siamese(input_shape=image_size + [1])

In [66]:
# Smoke test: feed ten paired batches through the Siamese model and print the
# label differences, the second branch's labels, the step counter and the output.
paired_batches = tf.data.Dataset.zip((ds_train_a, ds_train_b))
for i, (sample_a, sample_b) in enumerate(paired_batches, start=1):
    images_a, labels_a = sample_a
    images_b, labels_b = sample_b
    print(labels_a - labels_b)
    print(labels_b)
    print(i)
    y = siamese_model(images_a, images_b)
    print(y)
    # The training stream repeats forever, so stop explicitly.
    if i == 10:
        break

tf.Tensor(
[[-17.17454  ]
 [  8.454481 ]
 [-19.956192 ]
 [-14.666664 ]
 [  3.526352 ]
 [  6.652977 ]
 [-26.861053 ]
 [ 12.388775 ]
 [  4.416155 ]
 [  7.5154037]
 [ -7.008896 ]
 [ -1.7111549]
 [ 13.971252 ]
 [ 17.831623 ]
 [ -6.140999 ]
 [ 11.121151 ]], shape=(16, 1), dtype=float32)
tf.Tensor(
[[46.431213]
 [29.779604]
 [58.99247 ]
 [70.24777 ]
 [58.568104]
 [54.21492 ]
 [76.78029 ]
 [38.902122]
 [28.933607]
 [64.75017 ]
 [53.305954]
 [60.84052 ]
 [34.283367]
 [21.566051]
 [54.19302 ]
 [58.568104]], shape=(16, 1), dtype=float32)
1
tf.Tensor(
[[-0.01092665]
 [-0.01960653]
 [-0.00202544]
 [-0.01143807]
 [-0.01995041]
 [-0.01205333]
 [ 0.0108896 ]
 [-0.034601  ]
 [-0.00402762]
 [-0.01382279]
 [-0.01654757]
 [-0.03232229]
 [-0.02265441]
 [-0.01112285]
 [-0.02781778]
 [ 0.00093283]
 [-0.01609883]
 [-0.00944746]
 [-0.01273163]
 [-0.00381302]
 [-0.00639603]
 [-0.00703232]
 [-0.01882287]
 [-0.00832454]
 [-0.01213564]
 [-0.03545582]
 [-0.01957696]
 [-0.00746209]
 [-0.0180815 ]
 [-0.01520977]
 [ 

KeyboardInterrupt: 

In [None]:
# Run Siamese
# ADNI/IXI Selection - Test vs. validate
# Motion Blurring Test 
# Switch test / Validate
# Select run 
# Test results test on ADNI (covariate shift)
# Test MAP / Aleatoric - visualize
# scale callback - shuffle data