# Models

In [24]:
import keras
import keras.backend as K
# from qkeras import QConv2D, QDense
# from qkeras.quantizers import quantized_bits, quantized_relu
import tensorflow as tf
from collections import OrderedDict

# q_relu = quantized_relu(9, 5)
# q_bits = quantized_bits(9, 5)

# Latent vector size produced at each level, in encoder order: ROC, MOD, TRI, ST1, ST2.
# All levels currently share the same size (test configuration).
latent_dims = [8] * 5
# Children grouped per parent at each step: sectors/roc, rocs/module, modules/layer, layers.
multiplicity = [3, 3, 19, 30]

def conv_encoder(output_dims):
    """Build the ROC-level encoder: one 4x4x1 cell grid -> latent vector.

    Three 2x2 'valid' convolutions shrink the grid 4x4 -> 3x3 -> 2x2 -> 1x1
    (64 channels each), which is flattened and fed through a small dense head.

    Args:
        output_dims: size of the produced latent vector.

    Returns:
        keras.Sequential mapping ``(batch, 4, 4, 1)`` -> ``(batch, output_dims)``.
    """
    # Quantized variant (QKeras): swap Conv2D/Dense for QConv2D/QDense with
    # kernel_quantizer=bias_quantizer=quantized_bits(9, 5), hidden activation
    # quantized_relu(9, 5) and output activation quantized_bits(9, 5).
    stack = [keras.layers.Input(shape=(4, 4, 1))]
    stack += [
        keras.layers.Conv2D(filters=64, kernel_size=(2, 2), strides=(1, 1),
                            padding='valid', activation='relu')
        for _ in range(3)
    ]
    stack.append(keras.layers.Flatten())
    stack += [keras.layers.Dense(units=width, activation='relu') for width in (128, 256)]
    stack.append(keras.layers.Dense(units=output_dims, activation='linear'))
    return keras.Sequential(stack)

def conv_decoder(input_dims):
    """Build the ROC-level decoder: latent vector -> one 4x4x1 cell grid.

    Shape walk:
        (input_dims,) -> Dense 256 -> Dense 128 -> reshape (1, 1, 128)
        -> 3x3 transposed conv -> (3, 3, 64)
        -> 3x3 transposed conv -> (5, 5, 64)
        -> 2x2 'valid' conv    -> (4, 4, 64)
        -> 1x1 conv            -> (4, 4, 1)
    The final ReLU keeps reconstructed cell values non-negative.

    Args:
        input_dims: size of the incoming latent vector.

    Returns:
        keras.Sequential mapping ``(batch, input_dims)`` -> ``(batch, 4, 4, 1)``.
    """
    return keras.Sequential([
        # `(input_dims)` is just an int; pass an explicit 1-tuple shape.
        keras.layers.Input(shape=(input_dims,)),
        keras.layers.Dense(units=256, activation='relu'),
        keras.layers.Dense(units=128, activation='relu'),
        keras.layers.Reshape(target_shape=(1, 1, 128)),
        keras.layers.Conv2DTranspose(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='valid', activation='relu'),
        keras.layers.Conv2DTranspose(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='valid', activation='relu'),
        keras.layers.Conv2D(filters=64, kernel_size=(2, 2), strides=(1, 1), padding='valid', activation='relu'),
        keras.layers.Conv2D(filters=1, kernel_size=(1, 1), strides=(1, 1), padding='valid', activation='relu'),
    ])

def encoder(input_dims, output_dims, n):
    """Build a dense encoder that compresses ``n`` concatenated child latents into one.

    Args:
        input_dims: size of each child latent vector.
        output_dims: size of the produced latent vector.
        n: number of children concatenated on the input (input width ``n * input_dims``).

    Returns:
        keras.Sequential mapping ``(batch, n * input_dims)`` -> ``(batch, output_dims)``.
    """
    # Quantized variant (QKeras): swap each Dense for QDense with
    # kernel_quantizer=bias_quantizer=quantized_bits(9, 5), hidden activation
    # quantized_relu(9, 5) and output activation quantized_bits(9, 5).
    return keras.Sequential([
        # `(n*input_dims)` is just an int; pass an explicit 1-tuple shape.
        keras.layers.Input(shape=(n * input_dims,)),
        keras.layers.Dense(units=32, activation='relu'),
        keras.layers.Dense(units=64, activation='relu'),
        keras.layers.Dense(units=128, activation='relu'),
        keras.layers.Dense(units=256, activation='relu'),
        keras.layers.Dense(output_dims, activation='linear'),
    ])

def decoder(input_dims, output_dims, n, hidden_activation='linear'):
    """Build a dense decoder that expands one latent into ``n`` concatenated child latents.

    Args:
        input_dims: size of the incoming latent vector.
        output_dims: size of each reconstructed child latent.
        n: number of children produced (output width ``n * output_dims``).
        hidden_activation: activation of the four hidden Dense layers.
            Defaults to 'linear' (the original behavior).

    Returns:
        keras.Sequential mapping ``(batch, input_dims)`` -> ``(batch, n * output_dims)``.
    """
    # NOTE(review): with the default 'linear', the whole stack composes to a
    # single affine map, so the hidden layers add parameters but no capacity.
    # `encoder` uses 'relu' hidden layers; pass hidden_activation='relu' if a
    # nonlinear decoder was intended.
    return keras.Sequential([
        # Explicit 1-tuple shape (a bare int is not a valid shape in all Keras versions).
        keras.layers.Input(shape=(input_dims,)),
        keras.layers.Dense(units=256, activation=hidden_activation),
        keras.layers.Dense(units=128, activation=hidden_activation),
        keras.layers.Dense(units=64, activation=hidden_activation),
        keras.layers.Dense(units=32, activation=hidden_activation),
        keras.layers.Dense(units=n * output_dims, activation='linear'),
    ])

class Reshaper(keras.models.Model):
    """Functional model that reshapes its input to ``(-1, *output_shape)``.

    Unlike ``keras.layers.Reshape`` (which always preserves the batch axis),
    the leading ``-1`` lets the batch axis absorb every dimension that does not
    fit into ``output_shape``. For example, an input of shape
    ``(B, 30, 19, 3, 3, 4, 4, 1)`` with ``output_shape=[4, 4, 1]`` becomes
    ``(B*30*19*3*3, 4, 4, 1)``. The opposite pair of shapes splits them back out.

    Args:
        input_shape: per-sample input shape (excluding the batch axis).
        output_shape: trailing target shape; the leading axis is inferred.
    """

    def __init__(self, input_shape, output_shape):
        model_in = keras.layers.Input(shape=input_shape)
        # A plain Python tuple is enough for the target shape: no tf.Variable
        # needs to be allocated for the constant -1, and the Lambda closes over
        # ordinary ints instead of a tensor created outside of it.
        target_shape = (-1, *output_shape)
        shaper = keras.layers.Lambda(lambda x: K.reshape(x, target_shape))(model_in)
        super(Reshaper, self).__init__(inputs=model_in, outputs=shaper)

class TPG(keras.models.Model):
    """Hierarchical autoencoder built from per-level encoders and decoders.

    Encoding walks the levels ROC -> MOD -> TRI -> ST1 -> ST2. Before each
    encoder, a Reshaper groups ``multiplicity[i]`` sibling latents into one row,
    which that level's encoder compresses into a single latent. Decoding walks
    the same levels in reverse, and each decoder's output is split back into
    its children. ``self.depth`` selects how many levels are active: the first
    ``depth`` encoders and the matching last ``depth`` decoders.
    """

    def __init__(self):
        super(TPG, self).__init__()
        roc_shape = [4, 4, 1]  # per-ROC grid consumed by conv_encoder / produced by conv_decoder
        # Full per-event shape, outermost grouping first (multiplicity is innermost-first).
        event_shape = [*reversed(multiplicity), *roc_shape]

        self.encoders = OrderedDict([
            ('ROC', conv_encoder(output_dims=latent_dims[0])),
            ('MOD', encoder(input_dims=latent_dims[0], output_dims=latent_dims[1], n=multiplicity[0])),
            ('TRI', encoder(input_dims=latent_dims[1], output_dims=latent_dims[2], n=multiplicity[1])),
            ('ST1', encoder(input_dims=latent_dims[2], output_dims=latent_dims[3], n=multiplicity[2])),
            ('ST2', encoder(input_dims=latent_dims[3], output_dims=latent_dims[4], n=multiplicity[3])),
        ])
        self.decoders = OrderedDict([
            ('ST2', decoder(input_dims=latent_dims[4], output_dims=latent_dims[3], n=multiplicity[3])),
            ('ST1', decoder(input_dims=latent_dims[3], output_dims=latent_dims[2], n=multiplicity[2])),
            ('TRI', decoder(input_dims=latent_dims[2], output_dims=latent_dims[1], n=multiplicity[1])),
            ('MOD', decoder(input_dims=latent_dims[1], output_dims=latent_dims[0], n=multiplicity[0])),
            ('ROC', conv_decoder(input_dims=latent_dims[0])),
        ])
        # Group sizes come from `multiplicity`, the same values the encoders/decoders
        # are built with, so the two cannot drift apart.
        self.encoder_reshapers = OrderedDict([
            ('ROC', Reshaper(event_shape, roc_shape)),
            ('MOD', Reshaper([latent_dims[0]], [multiplicity[0] * latent_dims[0]])),
            ('TRI', Reshaper([latent_dims[1]], [multiplicity[1] * latent_dims[1]])),
            ('ST1', Reshaper([latent_dims[2]], [multiplicity[2] * latent_dims[2]])),
            ('ST2', Reshaper([latent_dims[3]], [multiplicity[3] * latent_dims[3]])),
        ])
        self.decoder_reshapers = OrderedDict([
            ('ST2', Reshaper([multiplicity[3] * latent_dims[3]], [latent_dims[3]])),
            ('ST1', Reshaper([multiplicity[2] * latent_dims[2]], [latent_dims[2]])),
            ('TRI', Reshaper([multiplicity[1] * latent_dims[1]], [latent_dims[1]])),
            ('MOD', Reshaper([multiplicity[0] * latent_dims[0]], [latent_dims[0]])),
            ('ROC', Reshaper(roc_shape, event_shape)),
        ])
        self.depth = len(self.encoders)  # depth of autoencoding; all levels active initially

    def encode(self, x, training=True):
        """Apply the first ``self.depth`` encoder levels, starting at ROC."""
        for level, level_encoder in list(self.encoders.items())[:self.depth]:
            x = level_encoder(self.encoder_reshapers[level](x), training=training)
        return x

    def decode(self, x, training=True):
        """Apply the last ``self.depth`` decoder levels, ending at ROC."""
        levels = list(self.decoders.items())
        # Slice from an explicit start index: `levels[-0:]` would select every
        # level when depth is 0, and the clamp keeps depth > len meaning "all".
        start = max(len(levels) - self.depth, 0)
        for level, level_decoder in levels[start:]:
            # The training flag goes to the decoder (as `encode` does for
            # encoders); the Reshaper has no training-dependent behavior.
            x = self.decoder_reshapers[level](level_decoder(x, training=training))
        return x

    def train_step(self, data):
        """One optimization step on a batch ``data = (x, y)``."""
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Include layer regularization penalties, as Keras' default train_step does.
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Frozen levels (trainable=False) are excluded from trainable_variables.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate loss/metrics on a batch ``data = (x, y)`` in inference mode."""
        x, y = data
        y_pred = self(x, training=False)
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x, training=True):
        """Encode then decode ``x`` through the active ``self.depth`` levels."""
        z = self.encode(x, training=training)
        y = self.decode(z, training=training)
        return y

    def set_trainable(self, d):
        """Freeze/unfreeze levels, set the autoencoding depth, and recompile.

        Args:
            d: dict mapping level names ('ROC', 'MOD', 'TRI', 'ST1', 'ST2') to a
               truthy (train) or falsy (freeze) flag. The encoder and decoder of
               each listed level both get that flag; unlisted levels are untouched.

        ``self.depth`` becomes the position (1-based, in encoder order) of the
        deepest level flagged trainable, regardless of the key order in ``d``.

        Returns:
            dict mapping each key of ``d`` to ``(encoder.trainable, decoder.trainable)``.

        Raises:
            TypeError: if ``d`` is not a dict (dict subclasses are accepted).
            KeyError: if ``d`` names an unknown level.
            ValueError: if no level in ``d`` is flagged trainable.
        """
        if not isinstance(d, dict):
            raise TypeError("set_trainable expected input type dict.")
        levels = list(self.encoders.keys())
        for k in d:
            if k not in levels:
                raise KeyError(f"unknown level {k!r}; expected one of {levels}")
        # Validate before mutating anything so a bad call leaves the model unchanged.
        trainable_depths = [levels.index(k) + 1 for k, v in d.items() if v]
        if not trainable_depths:
            raise ValueError("set_trainable needs at least one level flagged trainable.")

        for k, v in d.items():
            self.encoders[k].trainable = bool(v)
            self.decoders[k].trainable = bool(v)
        self.depth = max(trainable_depths)
        # Recompile AFTER changing trainable flags so the new train function sees them.
        # NOTE(review): only loss and optimizer are carried over; compiled metrics are dropped.
        self.compile(loss=self.loss, optimizer=self.optimizer)
        return {k: (self.encoders[k].trainable, self.decoders[k].trainable) for k in d}

In [68]:
import pickle
import numpy as np
import matplotlib.pyplot as plt

# NOTE(review): absolute, machine-specific path; consider a configurable DATA_DIR.
DATA_PATH = '/home/don/Desktop/electrons_config1_1_to_500GeV_4Tesla.pickle'
N_TRAIN = 800  # leading events used for training; the remainder is held out for validation
# Per-event layout expected by TPG's ROC reshaper: 30 x 19 x 3 x 3 groups of 4x4x1 grids.
EVENT_SHAPE = (30, 19, 3, 3, 4, 4, 1)

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(DATA_PATH, 'rb') as f:  # context manager so the file handle is closed
    data = pickle.load(f)

# Reshape once, then split: works for any event count instead of exactly 1000.
digi = np.asarray(data['digi']).reshape((-1,) + EVENT_SHAPE)
x_train = digi[:N_TRAIN]
x_test = digi[N_TRAIN:]

print(x_test.shape)

In [70]:
import tensorflow as tf
import keras
# Disable Grappler's layout optimizer so it does not convert NHWC tensors to
# NCHW (a GPU-oriented rewrite); presumably this interferes with TPG's reshapes.
tf.config.optimizer.set_experimental_options({'layout_optimizer': False})

SEED = 42
LEARNING_RATE = 1e-4
EPOCHS = 4
BATCH_SIZE = 2

tf.random.set_seed(SEED)  # reproducible weight initialization and shuffling

opt = keras.optimizers.Adam(LEARNING_RATE)
# Train (1) or freeze (0) each level, listed ROC first.
trainable_dict = {'ROC': 1, 'MOD': 1, 'TRI': 1, 'ST1': 1, 'ST2': 1}

# Runs on CPU: the full model did not fit in the author's GPU memory.
with tf.device('/CPU:0'):
    model = TPG()
    model.compile(loss='mse', optimizer=opt)
    model.set_trainable(trainable_dict)  # also sets depth and recompiles
    history = model.fit(x_train, x_train, validation_data=(x_test, x_test),
                        epochs=EPOCHS, batch_size=BATCH_SIZE)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
