In [20]:
import tensorflow as tf
import tensorflow_datasets as tfds
from galaxies_datasets import datasets

from sklearn.metrics import confusion_matrix, jaccard_score
from keras import layers
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint

In [33]:
RUN_FROM = 'local'

# Number of epochs passed to `fit`.
NUM_EPOCHS = 7

# Side length, in pixels, that images and masks are resized to.
# Changing it may require different batch sizes.
SIZE = 64

# Key of the mask feature used as the segmentation target.
MASK = 'spiral_mask'

# "all": train on every image of the training split.
# "only": keep only the galaxies whose selected mask passes the MIN_VOTE filter.
TRAIN_WITH = 'only'

# Minimum number of votes the most-voted pixel of a mask must reach for the
# galaxy to be considered a spiral-arm (or barred) galaxy.
MIN_VOTE = 3

# Minimum number of votes a single pixel needs to be classified as
# spiral arm (or bar) when the mask is binarised.
THRESHOLD = 6

# Number of epochs without improvement in the loss after which training
# can be stopped early.
PATIENCE = 3

In [15]:
def resize(input_image, input_mask):
    """Resize an image and its mask to ``SIZE`` x ``SIZE`` pixels.

    Both tensors are resampled with nearest-neighbour interpolation.

    Returns:
        The ``(image, mask)`` pair after resizing.
    """
    target_shape = (SIZE, SIZE)
    resized_image = tf.image.resize(input_image, target_shape, method="nearest")
    resized_mask = tf.image.resize(input_mask, target_shape, method="nearest")
    return resized_image, resized_mask


def augment(input_image, input_mask):
    """Randomly flip an image and its mask together.

    Two independent uniform draws are made; whenever a draw exceeds 0.5 the
    corresponding flip (left-right first, then up-down) is applied to both
    the image and the mask so that they stay aligned.

    Returns:
        The ``(image, mask)`` pair after the random flips.
    """
    image, mask = input_image, input_mask

    flip_horizontally = tf.random.uniform(()) > 0.5
    if flip_horizontally:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)

    flip_vertically = tf.random.uniform(()) > 0.5
    if flip_vertically:
        image = tf.image.flip_up_down(image)
        mask = tf.image.flip_up_down(mask)

    return image, mask


def normalize(input_image):
    """Cast the image to ``float32`` and divide every value by 255.

    Assumes 8-bit pixel values so the result lies in [0, 1] -- TODO confirm.
    """
    as_float = tf.cast(input_image, tf.float32)
    return as_float / 255.0


def binary_mask(input_mask):
    """Binarise a vote-count mask using ``THRESHOLD``.

    Entries strictly below ``THRESHOLD`` become 0 and all other entries
    become 1; the result keeps the shape and dtype of ``input_mask``.
    """
    below_threshold = input_mask < THRESHOLD
    background = tf.zeros_like(input_mask)
    foreground = tf.ones_like(input_mask)
    return tf.where(below_threshold, background, foreground)


def load_image_train(datapoint):
    """Turn one dataset element into an augmented ``(image, mask)`` pair.

    Reads ``datapoint['image']`` and ``datapoint[MASK]``, then applies, in
    order: resize, random flips, image normalisation and mask binarisation.
    """
    image, mask = resize(datapoint['image'], datapoint[MASK])
    image, mask = augment(image, mask)
    return normalize(image), binary_mask(mask)


def load_image_test(datapoint):
    """Turn one dataset element into an ``(image, mask)`` pair, no augmentation.

    Reads ``datapoint['image']`` and ``datapoint[MASK]``, resizes both, then
    normalises the image and binarises the mask.
    """
    image, mask = resize(datapoint['image'], datapoint[MASK])
    return normalize(image), binary_mask(mask)

In [31]:
# Two slices of the 'train' split are loaded: the first is used for training,
# the second is later divided into validation and test data.
ds, info = tfds.load('galaxy_zoo3d', split=['train[2:3420]', 'train[3666:6999]'], with_info=True)
ds_train, ds_test = ds[0], ds[1]

# NOTE(review): the lengths below are hard-coded and are larger than the slices
# requested above (3418 and 3333 examples before any filtering) -- confirm they
# still describe the data actually loaded.
if TRAIN_WITH == 'all':
    BUFFER_SIZE, BATCH_SIZE = 1000, 64
    TRAIN_LENGTH, VAL_SIZE, TEST_SIZE = 22360, 4992, 2461
elif TRAIN_WITH == 'only':
    BUFFER_SIZE, BATCH_SIZE = 300, 16
    if MASK == 'spiral_mask':
        TRAIN_LENGTH, VAL_SIZE, TEST_SIZE = 4883, 1088, 551
    elif MASK == 'bar_mask':
        TRAIN_LENGTH, VAL_SIZE, TEST_SIZE = 3783, 832, 421
    else:
        # Fail here instead of with a NameError on TRAIN_LENGTH further down.
        raise ValueError(f"MASK must be 'spiral_mask' or 'bar_mask', got {MASK!r}")

    def has_enough_votes(example):
        """Keep examples whose most-voted pixel in the selected mask reaches MIN_VOTE."""
        return tf.reduce_max(example[MASK]) >= MIN_VOTE

    ds_train = ds_train.filter(has_enough_votes)
    ds_test = ds_test.filter(has_enough_votes)
else:
    # Fail here instead of with a NameError on BUFFER_SIZE / BATCH_SIZE further down.
    raise ValueError(f"TRAIN_WITH must be 'all' or 'only', got {TRAIN_WITH!r}")

In [16]:
# Deterministic preprocessing only (resize, normalise, binarise). This is the
# part that is safe to cache, so the non-augmenting loader is used for both
# the training and the evaluation data.
train_preprocessed = ds_train.map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = ds_test.map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)

# The random flips must run *after* cache(): cache() replays exactly the
# elements produced on the first pass, so augmenting before it would freeze the
# flips drawn in epoch 1 for the whole training run. Flipping commutes with the
# element-wise normalisation/binarisation, so the samples are otherwise the same
# as those produced by load_image_train.
train_dataset = train_preprocessed.cache().map(augment, num_parallel_calls=tf.data.AUTOTUNE)

train_batches = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_batches = train_batches.prefetch(buffer_size=tf.data.AUTOTUNE)

# The first VAL_SIZE evaluation examples are used for validation and the next
# TEST_SIZE examples for testing.
validation_batches = test_dataset.take(VAL_SIZE).batch(BATCH_SIZE)
test_batches = test_dataset.skip(VAL_SIZE).take(TEST_SIZE).batch(BATCH_SIZE)

In [17]:
def double_conv_block(x, n_filters):
    """Apply two consecutive 3x3 convolutions to ``x``.

    Each convolution has ``n_filters`` filters, "same" padding, ReLU
    activation and He-normal kernel initialisation.

    Returns:
        The output tensor of the second convolution.
    """
    for _ in range(2):
        x = layers.Conv2D(
            n_filters,
            3,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(x)
    return x


def downsample_block(x, n_filters):
    """Encoder stage: double convolution, 2x2 max-pooling, then dropout (0.3).

    Returns:
        A ``(features, pooled)`` tuple: the double-convolution output
        (used later as a skip connection) and the pooled, dropped-out tensor
        that feeds the next stage.
    """
    features = double_conv_block(x, n_filters)
    pooled = layers.MaxPool2D(2)(features)
    return features, layers.Dropout(0.3)(pooled)


def upsample_block(x, conv_features, n_filters):
    """Decoder stage: upsample ``x`` and merge it with encoder features.

    ``x`` goes through a stride-2 transposed convolution (3x3 kernel, "same"
    padding), is concatenated with ``conv_features``, passed through
    dropout (0.3) and finally through a double convolution.

    Returns:
        The output tensor of the double convolution.
    """
    upsampled = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
    merged = layers.concatenate([upsampled, conv_features])
    regularised = layers.Dropout(0.3)(merged)
    return double_conv_block(regularised, n_filters)

In [24]:
def build_unet_model():
    """Assemble the U-Net segmentation network.

    The encoder consists of four ``downsample_block`` stages, followed by a
    ``double_conv_block`` bottleneck and four ``upsample_block`` stages that
    consume the encoder features in reverse order. Filter counts are derived
    from ``SIZE``: ``SIZE // 2`` at the outermost level, doubling at every
    stage up to ``SIZE * 8`` in the bottleneck.

    Returns:
        An un-compiled ``tf.keras.Model`` named "U-Net" that takes
        ``(SIZE, SIZE, 3)`` inputs and outputs a 2-channel per-pixel softmax.

    Raises:
        ValueError: if ``SIZE`` is not a positive multiple of 16.
    """
    # Four 2x poolings followed by four 2x upsamplings only give back matching
    # spatial sizes for the skip concatenations when SIZE is a multiple of 16.
    if SIZE <= 0 or SIZE % 16 != 0:
        raise ValueError(f"SIZE must be a positive multiple of 16, got {SIZE!r}")

    inputs = layers.Input(shape=(SIZE, SIZE, 3))

    # Integer division: a filter count has to be an int, `SIZE / 2` is a float.
    f1, p1 = downsample_block(inputs, SIZE // 2)
    f2, p2 = downsample_block(p1, SIZE)
    f3, p3 = downsample_block(p2, SIZE * 2)
    f4, p4 = downsample_block(p3, SIZE * 4)

    bottleneck = double_conv_block(p4, SIZE * 8)

    u6 = upsample_block(bottleneck, f4, SIZE * 4)
    u7 = upsample_block(u6, f3, SIZE * 2)
    u8 = upsample_block(u7, f2, SIZE)
    u9 = upsample_block(u8, f1, SIZE // 2)

    # 1x1 convolution producing two per-pixel class probabilities.
    outputs = layers.Conv2D(2, 1, padding="same", activation="softmax")(u9)

    unet_model = tf.keras.Model(inputs, outputs, name="U-Net")

    return unet_model

In [34]:
# NOTE(review): absolute, machine-specific location -- consider making this
# configurable before sharing the notebook.
path = '/home/jose/git-repos/imgalaxy/imgalaxy/resources/models'
filename = 'poc'

# Every artefact of a run goes into <path>/<filename>/. The separator between
# `path` and `filename` used to be missing for the CSV log and the "last"
# checkpoint, which sent them to '<path><filename>/' instead.
run_dir = f'{path}/{filename}'

# Appends the per-epoch metrics to a CSV file.
csv_log = CSVLogger(f'{run_dir}/{filename}.csv', append=True)
# Stops training after PATIENCE epochs without a decrease in validation loss.
early_stop = EarlyStopping(monitor='val_loss', patience=PATIENCE, mode='min')
# Keeps only the checkpoint with the highest validation accuracy.
mcp_save_best = ModelCheckpoint(
    filepath=f'{run_dir}/{filename}_best.h5', monitor='val_accuracy', mode='max', save_best_only=True
)
# Checkpoint written to the same file on every save.
mcp_save_last = ModelCheckpoint(filepath=f'{run_dir}/{filename}_last.h5')

In [36]:
unet_model = build_unet_model()

# The masks hold integer class ids (0 or 1) while the network outputs two
# softmax channels, hence the sparse categorical cross-entropy loss.
unet_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
    # `metrics` is documented as a list; a bare string is rejected by newer Keras.
    metrics=["accuracy"]
)

# The training dataset repeats indefinitely, so the epoch length is set explicitly.
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Only a fraction (1 / VAL_SUBSPLITS) of the evaluation batches is used for
# each validation pass.
VAL_SUBSPLITS = 5
TEST_LENGTH = VAL_SIZE + TEST_SIZE
VALIDATION_STEPS = TEST_LENGTH // BATCH_SIZE // VAL_SUBSPLITS

# Total number of training steps over the whole run.
print(STEPS_PER_EPOCH * NUM_EPOCHS)
model_history = unet_model.fit(
    train_batches,
    epochs=NUM_EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    validation_data=validation_batches,
    #callbacks=[csv_log, early_stop, mcp_save_best, mcp_save_last]
)

2135
Epoch 1/7








Epoch 2/7
Epoch 3/7
Epoch 4/7
Epoch 5/7
Epoch 6/7
Epoch 7/7
