In [3]:
import random

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import VarianceScaling
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import roc_auc_score

from mmoe import MMoE

In [4]:
# Simple callback to print out ROC-AUC
class ROCCallback(Callback):
    """Print per-task ROC-AUC for the train, validation and test sets after every epoch.

    Args:
        training_data: ``(X, Y)`` pair for the training set.
        validation_data: ``(X, Y)`` pair for the validation set.
        test_data: ``(X, Y)`` pair for the test set.

    ``Y`` must provide one label entry per model output. It may be either:

    * a sequence ordered like ``model.output_names`` (e.g. a list of one-hot
      arrays, one per task), or
    * a mapping / ``pandas.DataFrame`` keyed by output name (e.g. a frame with
      one integer 0/1 column per task).

    Only ``on_epoch_end`` is overridden. All other hooks fall back to the no-op
    implementations inherited from ``Callback``.
    """

    def __init__(self, training_data, validation_data, test_data):
        # Let the Keras base class initialise its own attributes (e.g. ``model``)
        # before we add ours.
        super().__init__()
        self.train_X = training_data[0]
        self.train_Y = training_data[1]
        self.validation_X = validation_data[0]
        self.validation_Y = validation_data[1]
        self.test_X = test_data[0]
        self.test_Y = test_data[1]

    def on_epoch_end(self, epoch, logs=None):
        """Compute and print train/validation/test ROC-AUC for every model output."""
        output_names = self.model.output_names
        num_outputs = len(output_names)

        # verbose=0 keeps predict() progress bars from flooding the epoch log.
        train_prediction = self._as_output_list(
            self.model.predict(self.train_X, verbose=0), num_outputs)
        validation_prediction = self._as_output_list(
            self.model.predict(self.validation_X, verbose=0), num_outputs)
        test_prediction = self._as_output_list(
            self.model.predict(self.test_X, verbose=0), num_outputs)

        # Iterate through each task and output their ROC-AUC across different datasets
        for index, output_name in enumerate(output_names):
            train_roc_auc = self._roc_auc(
                self._task_labels(self.train_Y, index, output_name),
                train_prediction[index])
            validation_roc_auc = self._roc_auc(
                self._task_labels(self.validation_Y, index, output_name),
                validation_prediction[index])
            test_roc_auc = self._roc_auc(
                self._task_labels(self.test_Y, index, output_name),
                test_prediction[index])
            print(
                'ROC-AUC-{}-Train: {} ROC-AUC-{}-Validation: {} ROC-AUC-{}-Test: {}'.format(
                    output_name, round(train_roc_auc, 4),
                    output_name, round(validation_roc_auc, 4),
                    output_name, round(test_roc_auc, 4)
                )
            )

    @staticmethod
    def _as_output_list(prediction, num_outputs):
        """Wrap a single-output prediction so it can be indexed per task like a multi-output one."""
        # For a single-output model, predict() returns one array rather than a
        # list of arrays, so indexing it by task would index rows instead.
        if num_outputs == 1 and not isinstance(prediction, (list, tuple)):
            return [prediction]
        return prediction

    @staticmethod
    def _task_labels(labels, index, output_name):
        """Select one task's labels: by output name for mappings/DataFrames, else by position."""
        if isinstance(labels, (dict, pd.DataFrame)):
            return np.asarray(labels[output_name])
        return np.asarray(labels[index])

    @staticmethod
    def _roc_auc(y_true, y_score):
        """Return ROC-AUC for one task, or NaN when scikit-learn cannot compute it."""
        y_score = np.asarray(y_score)
        # 1-D integer labels paired with a 2-unit softmax: score the positive
        # (index 1) class, since roc_auc_score wants 1-D scores for binary labels.
        if y_true.ndim == 1 and y_score.ndim == 2 and y_score.shape[1] == 2:
            y_score = y_score[:, 1]
        try:
            return roc_auc_score(y_true, y_score)
        except ValueError:
            # Raised e.g. when only one class is present in y_true. Report NaN
            # rather than aborting training from inside a metrics callback.
            return float('nan')

# Build the multi-task MMoE model: a shared MMoE layer feeding one small tower
# per task, each tower ending in a softmax output named after its task.

# Reset Keras' global state so re-running this cell rebuilds the model with
# the same auto-generated layer names instead of accumulating new ones.
K.clear_session()

# Seed every RNG the model may touch so weight initialisation (and therefore
# the metrics printed during training) is reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# (number of classes, task / output-layer name) for each task, in output order.
output_info = [(2, 'click'), (2, 'deep_read')]

input_layer = Input(shape=(1,))

# Set up MMoE layer; it is iterated below as one output tensor per task.
mmoe_layers = MMoE(
    units=4,
    num_experts=8,
    num_tasks=len(output_info)
)(input_layer)
assert len(mmoe_layers) == len(output_info), 'MMoE must emit one output per task'

output_layers = []

# Build one tower layer per task on top of the matching MMoE output
for (num_classes, task_name), task_layer in zip(output_info, mmoe_layers):
    tower_layer = Dense(
        units=8,
        activation='relu',
        kernel_initializer=VarianceScaling())(task_layer)
    output_layer = Dense(
        units=num_classes,
        name=task_name,
        activation='softmax',
        kernel_initializer=VarianceScaling())(tower_layer)
    output_layers.append(output_layer)

model = Model(inputs=[input_layer], outputs=output_layers)
adam_optimizer = Adam()
model.compile(
    # One loss per output, keyed by the task names above so the two lists can
    # never drift apart. NOTE(review): binary_crossentropy on a 2-unit softmax
    # presumably expects one-hot targets per task — confirm against the data cell.
    loss={task_name: 'binary_crossentropy' for _, task_name in output_info},
    optimizer=adam_optimizer,
    metrics=['accuracy']
)

# Print out model architecture summary
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
m_mo_e (MMoE)                   [(None, 4), (None, 4 96          input_1[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 8)            40          m_mo_e[0][0]                     
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 8)            40          m_mo_e[0][1]                     
______________________________________________________________________________________________

In [11]:
# Build toy training / test / validation data.
# NOTE(review): all three splits are identical copies of the same three rows,
# so the validation and test ROC-AUC printed below are not held-out estimates.
toy_features = {
    'user_status': [0, 1, 2]
}
toy_labels = {
    'click': [0, 1, 1],
    'deep_read': [0, 0, 1]
}


def to_task_targets(label_frame):
    """Convert a frame of integer class labels into per-task one-hot target arrays.

    Args:
        label_frame: DataFrame with one integer-label column per task, named
            like the model's output layers.

    Returns:
        A list with one ``to_categorical`` array per entry of ``output_info``,
        in that order. That is the order the model's outputs were built in, so
        element ``i`` lines up with the model's ``i``-th softmax output.
    """
    return [
        to_categorical(label_frame[task_name], num_classes=num_classes)
        for num_classes, task_name in output_info
    ]


# Training data
train_frame = pd.DataFrame(toy_features)
train_label_frame = pd.DataFrame(toy_labels)

# Test data
test_frame = pd.DataFrame(toy_features)
test_label_frame = pd.DataFrame(toy_labels)

# Validation data
validation_frame = pd.DataFrame(toy_features)
validation_label_frame = pd.DataFrame(toy_labels)

# Each task's output is a 2-unit softmax, so feed one-hot targets per task
# (one array per output) rather than a single 2-column integer frame.
train_targets = to_task_targets(train_label_frame)
validation_targets = to_task_targets(validation_label_frame)
test_targets = to_task_targets(test_label_frame)

# Train the model
history = model.fit(
    x=train_frame,
    y=train_targets,
    validation_data=(validation_frame, validation_targets),
    callbacks=[
        ROCCallback(
            training_data=(train_frame, train_targets),
            validation_data=(validation_frame, validation_targets),
            test_data=(test_frame, test_targets)
        )
    ],
    epochs=1
)

Unnamed: 0,click,deep_read
0,0,0
1,1,0
2,1,1
