In [16]:
import tensorflow as tf

from models import MultiGateMixtureOfExperts, MultiTaskBCE

# --- Configuration -----------------------------------------------------------
# Everything a reader might want to tune lives here.
SEED = 42  # fixes the synthetic data across "Restart & Run All"

TRAIN_SIZE = 4096
NUM_TASKS = 2
NUM_CATEGORICAL = 5
NUM_CONTINUOUS = 10
# Used both as the number of classes the sparse ids are sampled from and as the
# `num_emb` argument of the model, so the two stay in sync.
NUM_EMBEDDING = 100

BATCH_SIZE = 512
VALIDATION_SPLIT = 0.1
NUM_EPOCHS = 5

# Seed TensorFlow's global generator before the first random op so the
# synthetic inputs and labels below are identical on every fresh run.
# NOTE(review): depending on the Keras version, weight initialisation and
# batch shuffling may need `tf.keras.utils.set_random_seed` as well - confirm.
tf.random.set_seed(SEED)

# --- Synthetic data ----------------------------------------------------------
# One row of random logits per example; `tf.random.categorical` then samples
# NUM_CATEGORICAL ids per row, each in [0, NUM_EMBEDDING).
sparse_logits = tf.random.uniform((TRAIN_SIZE, NUM_EMBEDDING), minval=0, maxval=1)
sparse_inputs = tf.random.categorical(sparse_logits, NUM_CATEGORICAL, dtype=tf.int32)

# Continuous features drawn uniformly from [-1, 1).
dense_inputs = tf.random.uniform((TRAIN_SIZE, NUM_CONTINUOUS), minval=-1.0, maxval=1.0, dtype=tf.float32)

# One independent coin-flip (0.0 / 1.0) label per task.
labels = tf.cast(tf.random.uniform(shape=(TRAIN_SIZE, NUM_TASKS)) > 0.5, tf.float32)

print(f"Sparse inputs: {sparse_inputs.shape}")
print(f"Dense inputs: {dense_inputs.shape}")
print(f"Labels: {labels.shape}")

# --- Model -------------------------------------------------------------------
# Hyper-parameters are passed straight through to the project-local model; see
# `models.MultiGateMixtureOfExperts` for their exact meaning.
model = MultiGateMixtureOfExperts(
    num_tasks=NUM_TASKS,
    num_emb=NUM_EMBEDDING,
    dim_emb=4,
    num_experts=2,
    num_hidden_expert=2,
    dim_hidden_expert=32,
    dropout_expert=0.0,
    gate_function="softmax",
    num_hidden_tasks=2,
    dim_hidden_tasks=32,
    dim_out_tasks=1,
    dropout_tasks=0.0,
)

# --- Training ----------------------------------------------------------------
model.compile(optimizer="adam", loss=MultiTaskBCE(NUM_TASKS))
history = model.fit(
    x=(sparse_inputs, dense_inputs),
    y=labels,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    epochs=NUM_EPOCHS,
)

Sparse inputs: (4096, 5)
Dense inputs: (4096, 10)
Labels: (4096, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
