In [17]:
import keras
import tensorflow as tf
import numpy as np
import matplotlib as plt
import keras.layers as layers

In [18]:
class CustomerTicketModel(keras.Model):
    """Multi-output model that scores a ticket's priority and predicts its department.

    The ``title``, ``text_body`` and ``tags`` inputs are concatenated, mixed by a
    single ReLU ``Dense`` layer, and then fed to two heads:

    * priority   -- ``Dense(1, activation="sigmoid")``: one score per ticket in [0, 1].
    * department -- ``Dense(num_departments, activation="softmax")``: a probability
      distribution over ``num_departments`` classes.

    Args:
        num_departments: Number of department classes (width of the softmax head).
        hidden_units: Width of the shared mixing layer. Defaults to 64, the value
            this model previously hard-coded.
        **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name=...``).
    """

    def __init__(self, num_departments, hidden_units=64, **kwargs):
        super().__init__(**kwargs)
        # Joins the three feature tensors along the last axis (Concatenate's default).
        self.concat_layer = layers.Concatenate()
        # Shared hidden representation consumed by both output heads.
        self.mixing_layer = layers.Dense(hidden_units, activation="relu")
        self.priority_scorer = layers.Dense(1, activation="sigmoid")
        self.department_classifier = layers.Dense(num_departments, activation="softmax")

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Mapping with the keys ``"title"``, ``"text_body"`` and ``"tags"``.
                The values are concatenated on their last axis, so they presumably
                share all leading (batch) dimensions -- e.g. the 2-D
                ``(batch, features)`` arrays built in this notebook.

        Returns:
            A ``(priority, department)`` tuple. For 2-D inputs these have shapes
            ``(batch, 1)`` (sigmoid score) and ``(batch, num_departments)``
            (softmax probabilities).
        """
        features = self.concat_layer(
            [inputs["title"], inputs["text_body"], inputs["tags"]]
        )
        features = self.mixing_layer(features)
        priority = self.priority_scorer(features)
        department = self.department_classifier(features)
        return priority, department
        

In [19]:

# Synthetic dataset dimensions.
num_samples = 1280
vocabulary_size = 10000
num_tags = 10
num_departments = 4
seed = 42

# A seeded generator makes Restart & Run All reproduce exactly the same data.
rng = np.random.default_rng(seed)

# Random 0/1 (multi-hot) input features.
title_data = rng.integers(0, 2, size=(num_samples, vocabulary_size))
text_body_data = rng.integers(0, 2, size=(num_samples, vocabulary_size))
tags_data = rng.integers(0, 2, size=(num_samples, num_tags))
print(title_data.shape)
print(text_body_data.shape)
print(tags_data.shape)

# Target for the sigmoid priority head: uniform floats in [0, 1).
priority_data = rng.random(size=(num_samples, 1))
# Target for the softmax department head. categorical_crossentropy expects exactly
# one 1 per row, so draw a class index per sample and one-hot encode it; independent
# 0/1 draws per column would yield rows with zero or several active classes.
department_labels = rng.integers(0, num_departments, size=num_samples)
department_data = np.eye(num_departments, dtype=np.int64)[department_labels]


(1280, 10000)
(1280, 10000)
(1280, 10)


In [20]:
# Instantiate with the same class count used to generate the department targets,
# then call the model once on the data; the first call creates the layer weights,
# whose shapes depend on the input widths.
model = CustomerTicketModel(num_departments=num_departments)
priority, department = model({"title": title_data,
                              "text_body": text_body_data,
                              "tags": tags_data})

In [None]:
# Losses and metrics are matched to the model's outputs by position:
# the first entry applies to `priority`, the second to `department`.
model.compile(
    optimizer="rmsprop",
    loss=["mean_squared_error", "categorical_crossentropy"],
    metrics=[["mean_absolute_error"], ["accuracy"]],
)

# Take the callback from the same `keras` package the model subclasses; mixing
# `tf.keras` and standalone `keras` objects fails when they resolve to
# different Keras installations (e.g. legacy `tf_keras` vs Keras 3).
tf_callback = keras.callbacks.TensorBoard(log_dir="./logs_sub_model")

# Keep the History object so per-epoch losses/metrics can be inspected later.
history = model.fit(
    {
        "title": title_data,
        "tags": tags_data,
        "text_body": text_body_data,
    },
    [priority_data, department_data],
    epochs=1,
    callbacks=[tf_callback],
)