# PaccMann modeling sandbox

In [None]:
from __future__ import annotations

import tensorflow as tf

from tensorflow import keras

from cdrpy.data import datasets
from cdrpy.splits import Split
from cdrpy.layers import util as layer_utils

In [None]:
# SMILES branch hyperparameters.
smiles_length = 100  # token positions per SMILES input
# NOTE(review): placeholder; must be set to the SMILES vocabulary size
# before the embedding layers can be built.
smiles_vocab = None
smiles_embedding_size = 100
# Formerly ``filter`` (shadowed the builtin). The model cell reads its filter
# counts from ``params["filter"]``, not from this variable.
smiles_filters = []

# Keep the raw Input under its own name: the model cell passes
# ``input_smiles`` to ``keras.Model(inputs=...)``.
input_smiles = keras.Input(shape=(smiles_length,), name="input_smiles")
# Same (vocab, embedding size) arguments as the padding embedding in the
# model cell. The model cell inserts this tensor as the first attention level.
embedding_smiles = layer_utils.Embedding(smiles_vocab, smiles_embedding_size)(input_smiles)
# Presumably inserts a channel axis at position 3 so Conv2D can consume the
# embedding; confirm against cdrpy.layers.util.ExpandDim.
smiles_expand = layer_utils.ExpandDim(axis=3)(embedding_smiles)
x = smiles_expand  # kept so ``x`` is bound as before


In [None]:
# NOTE(review): ``params``, ``SqueezeLayer``, ``DenseAttentionLayer`` and
# ``ContextualAttentionLayer`` are not defined or imported in any visible
# cell; they must be in scope before this cell runs. ``input_smiles``,
# ``embedding_smiles`` and ``smiles_expand`` are expected from the SMILES
# input cell.

# Padding token: a length-1 input embedded and expanded like the SMILES
# tokens, then repeated on both sides of the sequence before each conv.
input_zeros = keras.Input(shape=(1,), name="input_zeros")
# NOTE(review): this Embedding has its own weights, separate from the one
# applied to ``input_smiles``; confirm whether it should share that table.
pad = layer_utils.Embedding(smiles_vocab, smiles_embedding_size)(input_zeros)
pad = layer_utils.ExpandDim(axis=3)(pad)

# One conv -> squeeze -> dropout -> batch-norm stack per
# (filter count, kernel size) pair.
convolved_smiles = []
for filter_size, kernel_size in zip(params["filter"], params["kernels"]):
    # kernel_size[0] // 2 pad positions on each side of the sequence axis.
    half_kernel = kernel_size[0] // 2
    smiles_pad = keras.layers.concatenate(
        [pad] * half_kernel + [smiles_expand] + [pad] * half_kernel, axis=1
    )

    conv_smiles = keras.layers.Conv2D(
        filters=filter_size, kernel_size=kernel_size, activation=tf.nn.relu
    )(smiles_pad)
    conv_smiles = SqueezeLayer(axis=2)(conv_smiles)
    conv_smiles = keras.layers.Dropout(rate=params["dropout"])(conv_smiles)
    convolved_smiles.append(keras.layers.BatchNormalization()(conv_smiles))

# Level 0 is the un-convolved embedding; level i + 1 is conv stack i.
convolved_smiles.insert(0, embedding_smiles)

input_genes = keras.Input(shape=(params["genes_number"],), name="input_genes")
# One gene-attention encoding per entry of params["multiheads"].
encoded_genes = [
    DenseAttentionLayer(params["genes_number"])(input_genes)
    for _ in params["multiheads"]
]

# params["multiheads"][level] contextual-attention heads per SMILES level,
# each attending with that level's gene encoding.
encoding_coefficients = []
for layer, smiles_features in enumerate(convolved_smiles):
    hidden_size = smiles_features.shape[2]
    for _ in range(params["multiheads"][layer]):
        head = ContextualAttentionLayer(
            attention_size=params["smiles_attention_size"],
            hidden_size=hidden_size,
            num_genes=params["genes_number"],
        )([encoded_genes[layer], smiles_features])
        encoding_coefficients.append(head)

encoding = keras.layers.concatenate(encoding_coefficients, axis=1)
# Flat width = sum over levels of (heads x feature width). This assumes each
# ContextualAttentionLayer emits ``hidden_size`` features (same assumption
# as before). It is derived from the tensors actually built, so it cannot
# drift from params["smiles_embedding_size"] / params["filter"].
encoding_width = sum(
    heads * smiles_features.shape[2]
    for heads, smiles_features in zip(params["multiheads"], convolved_smiles)
)
encoding = keras.layers.Reshape((encoding_width,))(encoding)

x = keras.layers.BatchNormalization()(encoding)

# Dense -> BatchNorm -> ReLU -> Dropout for each hidden size.
for size in params["stacked_dense_hidden_sizes"]:
    x = keras.layers.Dense(size, activation=None)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Dropout(rate=params["dropout"])(x)

# Single linear unit, trained with MSE below.
output = keras.layers.Dense(1)(x)

model = keras.Model(inputs=[input_smiles, input_zeros, input_genes], outputs=output)
# Use the same ``keras`` namespace as the rest of the notebook; pass
# ``learning_rate=...`` to Adam to override its default.
opt = keras.optimizers.Adam()
loss = keras.losses.MeanSquaredError()
model.compile(optimizer=opt, loss=loss, metrics=["mse"])