In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
sess = tf.Session(config=tf_config)
# Register the session with Keras. Without this, Keras' TF backend lazily creates
# its own default session and the allow_growth setting above is never applied.
from keras import backend as K
K.set_session(sess)
import sys
from pathlib import Path

import keras
import pandas as pd
from keras import backend as K
from keras import layers, models, optimizers, regularizers
from keras.initializers import Constant
from keras.callbacks import CSVLogger, LearningRateScheduler, ModelCheckpoint
from keras.datasets import mnist
from keras.layers import (
    Activation,
    BatchNormalization,
    Concatenate,
    Conv1D,
    Conv2D,
    Conv2DTranspose,
    Dense,
    Dropout,
    Flatten,
    Input,
    Lambda,
    Layer,
    MaxPooling2D,
    PReLU,
    Reshape,
    Softmax,
)
from keras.utils import to_categorical, multi_gpu_model

# Display the installed (standalone) Keras version; this notebook was last run with 2.3.1.
keras.__version__

Using TensorFlow backend.


'2.3.1'

In [2]:
import os

# Location of the encapZulate `src` directory. Set the PHOTOZ_SRC environment
# variable to point elsewhere instead of editing this cell.
path_photoz = Path(os.environ.get("PHOTOZ_SRC", "/home/bid13/code/encapZulate-1/src"))

# Make `encapzulate` importable. Guarded so re-running the cell does not keep
# inserting duplicate sys.path entries.
if str(path_photoz) not in sys.path:
    sys.path.insert(1, str(path_photoz))

In [3]:
import encapzulate
from encapzulate.base.deepCapsLayers import (
    CapsToScalars,
    CapsuleLayer,
    Conv2DCaps,
    ConvCapsuleLayer3D,
    ConvertToCaps,
    FlattenCaps,
    Mask_CID,
    squash
)
from encapzulate.base.loss import (
        margin_loss,
        quantile_loss,
        central_mse,
        central_bias,
    )
from encapzulate.data_loader.data_loader import load_data
from encapzulate.models.multi_gpu import MultiGPUModel
from encapzulate.utils import metrics
from encapzulate.utils.fileio import load_config, load_model
from encapzulate.utils.metrics import Metrics, bins_to_redshifts, probs_to_redshifts

In [4]:
# Base experiment configuration, followed by notebook-local overrides for a short test run.
config = load_config(path_photoz / "encapzulate" / "configs" / "morphCaps_4.yml")
# CapsNet(**config) takes `input_shape`; mirror it from the config's `image_shape`.
config["input_shape"] = config["image_shape"]

NOTEBOOK_OVERRIDES = {
    "run_name": "test",
    "epochs": 4,
    "frac_train": 0.2,
    "learning_rate": 0.001,
}
for key, value in NOTEBOOK_OVERRIDES.items():
    config[key] = value


config:
{   'bands': ('u', 'g', 'r', 'i', 'z'),
    'batch_size': 200,
    'checkpoint': None,
    'compile_on': 'cpu',
    'dataset': 'sdss_gz1_final_iter2',
    'decay_rate': 0.95,
    'dim_capsule': 16,
    'epochs': 75,
    'frac_dev': 0.1,
    'frac_train': 0.02,
    'image_scale': 10,
    'image_shape': (64, 64, 5),
    'img_augmentation': 1,
    'lam_recon': 0.005,
    'lam_redshift': 2,
    'learning_rate': 0.001,
    'logistic': True,
    'model_name': 'morphCapsDeep_2',
    'num_class': 2,
    'num_gpus': 2,
    'num_quantiles': False,
    'path_data': '/data/bid13/photoZ/data/pasquet2019',
    'path_results': None,
    'random_state': 200,
    'routings': 3,
    'run_name': 'paper1_regression_2perc_4',
    'timeline': False,
    'use_vals': False,
    'z_max': 0.4,
    'z_min': 0.0}



In [5]:
import os

# Root directory for run outputs. Set PHOTOZ_RESULTS to override instead of editing the cell.
path_output = Path(os.environ.get("PHOTOZ_RESULTS", "/home/bid13/code/photozCapsNet/results"))
# Runs are grouped by the run-name prefix before the first "_":
#   <path_output>/<prefix>/<run_name>/results/{logs,weights}
path_results = (
    path_output / config["run_name"].split("_")[0] / config["run_name"] / "results"
)
path_logs = path_results / "logs"
path_weights = path_results / "weights"
for directory in (path_logs, path_weights):
    directory.mkdir(parents=True, exist_ok=True)

In [6]:
# Load the train / dev / test splits (with catalogs) described by `config`.
train_split, dev_split, test_split = load_data(load_cat=True, **config)
# Each split unpacks into five parts, used below as images (x), labels (y),
# auxiliary values, spectroscopic redshifts and the catalog -- exact contents
# are defined by load_data; confirm there.
x_train, y_train, vals_train, z_spec_train, cat_train = train_split
x_dev, y_dev, vals_dev, z_spec_dev, cat_dev = dev_split
x_test, y_test, vals_test, z_spec_test, cat_test = test_split

In [7]:
def _conv_caps(num_atoms, strides):
    """Return an un-called Conv2DCaps layer with the settings shared by this network.

    Every capsule conv here uses 32 capsule types, a 3x3 kernel, r_num=1 and
    b_alphas=[1, 1, 1]; only the second argument and the strides vary.

    Args:
        num_atoms: second positional Conv2DCaps argument (presumably atoms per
            capsule, mirroring ConvCapsuleLayer3D's ``num_atoms``).
        strides: 2-tuple of spatial strides.
    """
    return Conv2DCaps(
        32, num_atoms, kernel_size=(3, 3), strides=strides, r_num=1, b_alphas=[1, 1, 1]
    )


def _caps_residual_block(inputs, num_atoms, make_skip=None):
    """Strided capsule conv, then a two-conv main path plus a skip path, summed.

    Layers are created and called in the order downsample, skip, main conv 1,
    main conv 2, Add, so Keras' auto-generated layer names stay in the same
    sequence as a hand-unrolled version.

    Args:
        inputs: capsule tensor feeding the block.
        num_atoms: second Conv2DCaps argument used by every conv in the block.
        make_skip: optional zero-argument callable returning the skip-path layer.
            It is called only after the downsampling conv has been applied.
            Defaults to a stride-1 Conv2DCaps.

    Returns:
        Output tensor of the ``layers.Add`` merging the main and skip paths.
    """
    down = _conv_caps(num_atoms, strides=(2, 2))(inputs)
    if make_skip is None:
        skip_layer = _conv_caps(num_atoms, strides=(1, 1))
    else:
        skip_layer = make_skip()
    skip = skip_layer(down)
    main = _conv_caps(num_atoms, strides=(1, 1))(down)
    main = _conv_caps(num_atoms, strides=(1, 1))(main)
    return layers.Add()([main, skip])


def _build_redshift_head(dim_capsule):
    """MLP (128-64-32-16 units, PReLU) mapping a ``dim_capsule`` vector to one scalar.

    Returns:
        A Model named "redshift_model".
    """
    redshift_input = Input(shape=(dim_capsule,))
    r = redshift_input
    for units in (128, 64, 32, 16):
        r = Dense(units, kernel_initializer="he_normal")(r)
        r = PReLU()(r)
    redshift_out = Dense(1)(r)
    return models.Model(redshift_input, redshift_out, name="redshift_model")


def _build_decoder(input_shape, dim_capsule):
    """Decoder mapping a ``dim_capsule`` vector back to an ``input_shape`` image.

    A Dense layer expands to prod(input_shape) units and is reshaped to the image
    shape. Four same-padded 3x3 Conv2DTranspose+PReLU layers (64, 32, 16, 8
    filters) follow. A final tanh Conv2DTranspose restores ``input_shape[-1]``
    channels.

    Returns:
        A Model named "decoder_model" whose output layer is named "out_recon".
    """
    decoder_input = Input(shape=(dim_capsule,))
    d = Dense(np.prod(input_shape), kernel_initializer="he_normal")(decoder_input)
    d = PReLU()(d)
    d = Reshape(input_shape)(d)
    for filters in (64, 32, 16, 8):
        d = Conv2DTranspose(
            filters, (3, 3), padding="same", kernel_initializer="he_normal"
        )(d)
        d = PReLU()(d)
    d = Conv2DTranspose(
        input_shape[-1],
        (3, 3),
        padding="same",
        activation="tanh",
        kernel_initializer="he_normal",
    )(d)
    decoder_output = Reshape(target_shape=input_shape, name="out_recon")(d)
    return models.Model(decoder_input, decoder_output, name="decoder_model")


def CapsNet(input_shape, num_class, routings, dim_capsule, **kwargs):
    """Build a deep capsule classifier with reconstruction and redshift heads.

    Args:
        input_shape: image shape, e.g. (64, 64, 5).
        num_class: number of output class capsules.
        routings: routing iterations for the final CapsuleLayer.
        dim_capsule: dimensionality of each class capsule. It is also the input
            size of the decoder and redshift heads.
        **kwargs: ignored; lets the function be called as ``CapsNet(**config)``.

    Returns:
        ``(train_model, eval_model, manipulate_model, decoder, redshift)``:
        - train_model: inputs [image, class vector y]. Outputs are
          [capsule scores, reconstruction, redshift]; the last two are computed
          from capsules masked with ``y``.
        - eval_model: input [image]. Outputs are [masked capsules, all class
          capsules, capsule scores, reconstruction, redshift], masked without
          ``y``.
        - manipulate_model: input [image]. Outputs are [masked capsules,
          capsule scores, reconstruction, redshift].
        - decoder / redshift: the standalone head models shared by the above.

    Side effect: prints ``train_model.summary()``.
    """
    # --- Encoder -----------------------------------------------------------
    x = Input(shape=input_shape)

    # Common conv stem, then lift the feature maps into capsules.
    l = Conv2D(
        128,
        (3, 3),
        strides=(1, 1),
        activation="relu",
        padding="same",
        kernel_initializer="he_normal",
    )(x)
    l = BatchNormalization()(l)
    l = ConvertToCaps()(l)

    # Four residual capsule stages, each downsampling by 2.
    l = _caps_residual_block(l, num_atoms=4)
    l = _caps_residual_block(l, num_atoms=8)
    l1 = _caps_residual_block(l, num_atoms=8)
    # The last stage's skip path uses a routed 3D conv-capsule layer.
    # NOTE(review): routings is fixed at 3 here, independent of the `routings`
    # argument -- confirm this is intentional.
    l2 = _caps_residual_block(
        l1,
        num_atoms=8,
        make_skip=lambda: ConvCapsuleLayer3D(
            kernel_size=3,
            num_capsule=32,
            num_atoms=8,
            strides=1,
            padding="same",
            routings=3,
        ),
    )

    # Flatten the two deepest stages and concatenate them (axis=-2, presumably
    # the capsule-count axis) before routing into class capsules.
    la = FlattenCaps()(l2)
    lb = FlattenCaps()(l1)
    l = layers.Concatenate(axis=-2)([la, lb])

    digits_caps = CapsuleLayer(
        num_capsule=num_class,
        dim_capsule=dim_capsule,
        routings=routings,
        channels=0,
        name="digit_caps",
    )(l)

    # Output layer named "capsnet" so compile() can attach metrics by name.
    l = CapsToScalars(name="capsnet")(digits_caps)
    m_capsnet = models.Model(inputs=x, outputs=l, name="capsnet_model")

    # --- Masking -----------------------------------------------------------
    # Training masks with the provided class vector y; inference masks without it.
    y = Input(shape=(num_class,))
    masked_by_y = Mask_CID()([digits_caps, y])
    masked = Mask_CID()(digits_caps)

    # --- Heads (built after masking to keep layer-creation order) -----------
    redshift = _build_redshift_head(dim_capsule)
    decoder = _build_decoder(input_shape, dim_capsule)

    # --- Assembled models --------------------------------------------------
    train_model = models.Model(
        [x, y],
        [m_capsnet.output, decoder(masked_by_y), redshift(masked_by_y)],
    )
    eval_model = models.Model(
        [x],
        [
            masked,
            digits_caps,
            m_capsnet.output,
            decoder(masked),
            redshift(masked),
        ],
    )
    manipulate_model = models.Model(
        [x],
        [masked, m_capsnet.output, decoder(masked), redshift(masked)],
    )
    train_model.summary()

    return train_model, eval_model, manipulate_model, decoder, redshift

In [8]:
# Build the network from the config, then wrap the training model for
# data-parallel training when more than one GPU is configured.
train_model, eval_model, manipulate_model, decoder, redshift_model = CapsNet(**config)

n_gpus = int(config.get("num_gpus", 2))
if n_gpus > 1:
    parallel_train_model = MultiGPUModel(train_model, gpus=n_gpus)
else:
    # Data-parallel wrapping needs >= 2 GPUs (keras.utils.multi_gpu_model raises
    # for gpus <= 1), so train the single-device model directly.
    parallel_train_model = train_model
# Later cells compile/fit `train_model`, so point it at the (possibly) parallel model.
train_model = parallel_train_model

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Tensor("conv_capsule_layer3d_1/stack:0", shape=(5,), dtype=int32)
Instructions for updating:
dim is deprecated, use axis instead
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 64, 64, 5)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 64, 64, 128)  5888        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64, 64, 128)  512         conv2d_1[0][0]                   
________________________________________________________________________________

In [9]:
 compile_kwargs = {
                "optimizer": optimizers.Adam(lr=config["learning_rate"]),
                "loss": [margin_loss, "mse", "mse"],
                "loss_weights": [
                    1.0,
                    config["lam_recon"] * np.prod(config["input_shape"]),
                    config["lam_redshift"],
                ],
                "metrics": {
                    "capsnet": "accuracy",
                    "redshift_model": [central_mse(**config), central_bias(**config)],
                },
            }

train_model.compile(**compile_kwargs)

In [14]:
# Train the model (no image augmentation in this notebook).

# Exponential decay: lr(epoch) = learning_rate * decay_rate ** epoch.
lr_decay = LearningRateScheduler(
    schedule=lambda epoch: config["learning_rate"] * (config["decay_rate"] ** epoch)
)
log = CSVLogger(str(path_logs / "log.csv"))
# Weights are saved every epoch. monitor/mode only take effect if save_best_only
# is switched on; val_loss must be minimised, so mode="min" (mode="max" would
# keep the worst epoch).
cp = ModelCheckpoint(
    filepath=str(path_weights / "weights-{epoch:02d}.h5"),
    monitor="val_loss",
    mode="min",
    save_best_only=False,
    save_weights_only=True,
    verbose=1,
)

# Inputs are [images, class vector]. Targets follow the output order
# [capsule scores <- y, reconstruction <- x, redshift <- z_spec].
history = train_model.fit(
    [x_train, y_train],
    [y_train, x_train, z_spec_train],
    batch_size=config["batch_size"],
    epochs=config["epochs"],
    initial_epoch=0,
    validation_data=([x_dev, y_dev], [y_dev, x_dev, z_spec_dev]),
    callbacks=[log, cp, lr_decay],
)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Train on 103305 samples, validate on 51653 samples
Epoch 1/4
  1400/103305 [..............................] - ETA: 16:13 - loss: 4.6049 - capsnet_loss: 0.3021 - decoder_model_loss: 0.0167 - redshift_model_loss: 1.2951 - capsnet_accuracy: 0.4971 - redshift_model_central_mse_metric: 0.0076 - redshift_model_central_bias_metric: -0.0574

KeyboardInterrupt: 