<a href="https://colab.research.google.com/github/hkvision/bigdl-demo/blob/main/friesian_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Environment Preparation

In [1]:
# Install jdk8
# (-qq plus the redirect to /dev/null keep the apt output out of the notebook.)
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
import os
# Set environment variable JAVA_HOME.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
# Select the JDK 8 binary as the `java` alternative, then print the version to confirm.
!update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java
!java -version

update-alternatives: using /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java to provide /usr/bin/java (java) in manual mode
openjdk version "1.8.0_342"
OpenJDK Runtime Environment (build 1.8.0_342-8u342-b07-0ubuntu1~18.04-b07)
OpenJDK 64-Bit Server VM (build 25.342-b07, mixed mode)


In [2]:
# Install the BigDL Friesian pre-release (Spark 3 build, with the `train` extra) and TensorFlow.
# NOTE(review): neither package is pinned, so a fresh run may resolve to different
# versions than the ones shown in the recorded output of this cell — consider pinning.
!pip install --pre --upgrade bigdl-friesian-spark3[train]
!pip install tensorflow

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting bigdl-friesian-spark3[train]
  Downloading bigdl_friesian_spark3-2.1.0b20220820-py3-none-manylinux1_x86_64.whl (114 kB)
[K     |████████████████████████████████| 114 kB 5.1 MB/s 
[?25hCollecting bigdl-orca-spark3==2.1.0b20220820
  Downloading bigdl_orca_spark3-2.1.0b20220820-py3-none-manylinux1_x86_64.whl (21.8 MB)
[K     |████████████████████████████████| 21.8 MB 1.5 MB/s 
Collecting bigdl-tf==0.14.0.dev1
  Downloading bigdl_tf-0.14.0.dev1-py3-none-manylinux2010_x86_64.whl (71.0 MB)
[K     |████████████████████████████████| 71.0 MB 365 bytes/s 
[?25hCollecting bigdl-math==0.14.0.dev1
  Downloading bigdl_math-0.14.0.dev1-py3-none-manylinux2010_x86_64.whl (35.4 MB)
[K     |████████████████████████████████| 35.4 MB 471 kB/s 
Collecting bigdl-dllib-spark3==2.1.0b20220820
  Downloading bigdl_dllib_spark3-2.1.0b20220820-py3-none-manylinux1_x86_64.whl (50.0 MB)
[K     |███████

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Generate random data for 2021 Twitter Recsys Challenge

In [3]:
import random
from pyspark.sql.types import StructType, StructField, StringType, LongType, BooleanType
from bigdl.orca import init_orca_context, stop_orca_context, OrcaContext
from bigdl.friesian.feature import FeatureTable

# To display terminal's stdout and stderr in the Jupyter notebook.
OrcaContext.log_output = True

# Start the Orca context with 4 cores; init_ray_on_spark=True also asks for Ray to be
# started (presumably required by the TF2 Estimator used for training — TODO confirm).
sc = init_orca_context(cores=4, init_ray_on_spark=True)
# Spark session handle, used later to turn the generated RDD into a DataFrame.
spark = OrcaContext.get_spark_session()

In [4]:
# Alphabet for the random 32-character identifiers: digits followed by uppercase letters.
id_list = list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ")
media_list = ["Photo", "Video", "GIF"]
tweet_list = ["Retweet", "Quote", "TopLevel"]

# The 65 synthetic language ids are drawn from a dedicated, seeded generator so that
# the language vocabulary (and therefore the generated dataset) is identical on every
# run. A private Random instance is used so the module-level `random` state is untouched.
_language_rng = random.Random(2021)
language_list = ["".join(_language_rng.choices(id_list, k=32)) for _ in range(65)]

In [5]:
# Spark schema of one generated record. Every field is nullable, and the field order
# is the order in which generate_record returns its values.
# NOTE(review): "engagee_follows_engager" is declared as a string here while
# generate_record produces a Python bool for it — confirm this is intended.
schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        # Tweet content
        ("text_tokens", StringType()),
        ("hashtags", StringType()),
        ("tweet_id", StringType()),
        ("present_media", StringType()),
        ("present_links", StringType()),
        ("present_domains", StringType()),
        ("tweet_type", StringType()),
        ("language", StringType()),
        ("tweet_timestamp", LongType()),
        # Author of the tweet (the "engaged with" user)
        ("engaged_with_user_id", StringType()),
        ("engaged_with_user_follower_count", LongType()),
        ("engaged_with_user_following_count", LongType()),
        ("engaged_with_user_is_verified", BooleanType()),
        ("engaged_with_user_account_creation", LongType()),
        # User interacting with the tweet (spelling of "enaging" kept as-is:
        # the column names are referenced verbatim by later cells)
        ("enaging_user_id", StringType()),
        ("enaging_user_follower_count", LongType()),
        ("enaging_user_following_count", LongType()),
        ("enaging_user_is_verified", BooleanType()),
        ("enaging_user_account_creation", LongType()),
        ("engagee_follows_engager", StringType()),
        # Engagement timestamps
        ("reply_timestamp", LongType()),
        ("retweet_timestamp", LongType()),
        ("retweet_with_comment_timestamp", LongType()),
        ("like_timestamp", LongType()),
    )
])

In [6]:
def generate_record(random_seed):
    """Build one synthetic tweet-engagement record, deterministic in ``random_seed``.

    All values are drawn from a private ``random.Random(random_seed)`` instance, so
    calling this function does not reseed (and thereby clobber) the process-wide
    ``random`` module state. The draws happen in the same order as with the former
    ``random.seed(random_seed)`` implementation, so a given seed still yields the
    same record.

    Reads the module-level ``id_list``, ``media_list``, ``tweet_list`` and
    ``language_list`` vocabularies.

    :param random_seed: seed of the per-record generator.
    :return: a 24-tuple whose positions follow the field order of ``schema``.
        Each of the four trailing engagement timestamps is ``None`` when the
        corresponding coin flip comes up false.
    """
    rng = random.Random(random_seed)

    def random_id():
        # 32 symbols drawn with replacement from id_list.
        return "".join(rng.choices(id_list, k=32))

    def random_id_list(max_items):
        # The item count is drawn first, then the ids (tab-separated).
        return "\t".join([random_id() for _ in range(rng.randint(0, max_items))])

    def random_timestamp():
        # Epoch seconds; the bounds span roughly the years 2000 through 2020.
        return rng.randint(946656000, 1609430400)

    def random_flag():
        return bool(rng.getrandbits(1))

    def optional_timestamp():
        # Flip the coin first and draw a timestamp only on success, which keeps
        # the number of consumed random values identical to the original code.
        return random_timestamp() if random_flag() else None

    def random_user():
        # (id, follower_count, following_count, is_verified, account_creation);
        # tuple elements are evaluated left to right, preserving the draw order.
        return (random_id(), rng.randint(0, 10000), rng.randint(0, 10000),
                random_flag(), random_timestamp())

    text_tokens = "\t".join([str(rng.randint(1, 1000))
                             for _ in range(rng.randint(1, 10))])
    hashtags = random_id_list(50)
    tweet_id = random_id()
    media_count = rng.randint(0, 9)
    present_media = "\t".join(rng.choices(media_list, k=media_count))
    present_links = random_id_list(10)
    present_domains = random_id_list(10)
    # choices(...)[0] rather than choice(): the two consume the generator
    # differently, and the former keeps records identical to earlier runs.
    tweet_type = rng.choices(tweet_list)[0]
    language = rng.choices(language_list)[0]
    tweet_timestamp = random_timestamp()
    engaged_with_user = random_user()
    enaging_user = random_user()
    engagee_follows_engager = random_flag()
    reply_timestamp = optional_timestamp()
    retweet_timestamp = optional_timestamp()
    retweet_with_comment_timestamp = optional_timestamp()
    like_timestamp = optional_timestamp()

    return ((text_tokens, hashtags, tweet_id, present_media, present_links,
             present_domains, tweet_type, language, tweet_timestamp)
            + engaged_with_user
            + enaging_user
            + (engagee_follows_engager, reply_timestamp, retweet_timestamp,
               retweet_with_comment_timestamp, like_timestamp))

In [7]:
# One record per integer 0..49999: each RDD element is handed to generate_record as its seed.
rdd = sc.parallelize(range(50000))
dummy_data_rdd = rdd.map(generate_record)
# Apply the explicit schema and wrap the resulting Spark DataFrame in a FeatureTable.
df = FeatureTable(spark.createDataFrame(dummy_data_rdd, schema))

Initializing orca context
Current pyspark location is : /usr/local/lib/python3.7/dist-packages/pyspark/__init__.py
Start to getOrCreate SparkContext
pyspark_submit_args is:  --driver-class-path /usr/local/lib/python3.7/dist-packages/bigdl/share/dllib/lib/bigdl-dllib-spark_3.1.2-2.1.0-SNAPSHOT-jar-with-dependencies.jar:/usr/local/lib/python3.7/dist-packages/bigdl/share/orca/lib/bigdl-orca-spark_3.1.2-2.1.0-SNAPSHOT-jar-with-dependencies.jar:/usr/local/lib/python3.7/dist-packages/bigdl/share/friesian/lib/bigdl-friesian-spark_3.1.2-2.1.0-SNAPSHOT-jar-with-dependencies.jar:/usr/local/lib/python3.7/dist-packages/bigdl/share/core/lib/all-2.1.0-20220728.053003-14.jar pyspark-shell 
Successfully got a SparkContext


2022-08-22 09:41:38,191	INFO services.py:1340 -- View the Ray dashboard at [1m[32mhttp://172.28.0.2:8265[39m[22m


{'node_ip_address': '172.28.0.2', 'raylet_ip_address': '172.28.0.2', 'redis_address': '172.28.0.2:6379', 'object_store_address': '/tmp/ray/session_2022-08-22_09-41-33_644736_58/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-08-22_09-41-33_644736_58/sockets/raylet', 'webui_url': '172.28.0.2:8265', 'session_dir': '/tmp/ray/session_2022-08-22_09-41-33_644736_58', 'metrics_export_port': 49215, 'node_id': '3d8cef3a50caea4b9387379adda12a6b4a9fafe3eee1dc23813a898b'}


In [8]:
# 80/20 train/validation split. The seed is fixed so that the split — and the sizes
# below, which later determine steps_per_epoch and val_steps — repeat across runs.
train_tbl, valid_tbl = df.random_split([0.8, 0.2], seed=42)

train_size = train_tbl.size()
valid_size = valid_tbl.size()
print("Total number of train records: {}".format(train_size))
print("Total number of validation records: {}".format(valid_size))

Total number of train records: 40099
Total number of validation records: 9901


# Feature Engineering

In [9]:
# Column groups used by the feature-engineering cells. The user-level columns exist
# once per user role, so they are derived from the two role prefixes.
bool_cols = [role + "_is_verified"
             for role in ("engaged_with_user", "enaging_user")]

count_cols = [role + "_" + relation + "_count"
              for role in ("engaged_with_user", "enaging_user")
              for relation in ("follower", "following")]

cat_cols = ['present_media', 'tweet_type', 'language']

In [10]:
# Integer codes for the (at most two) leading media types of a tweet, joined by "_".
# Code 0 is the empty string (no media); the remaining keys are every single type and
# every ordered pair, enumerated in alphabetical order: GIF, GIF_GIF, GIF_Photo, ...
media_map = {
    media_key: code
    for code, media_key in enumerate(
        [''] + [first + second
                for first in ('GIF', 'Photo', 'Video')
                for second in ('', '_GIF', '_Photo', '_Video')]
    )
}

# Integer codes for the tweet type, in alphabetical order.
type_map = {
    tweet_kind: code
    for code, tweet_kind in enumerate(('Quote', 'Retweet', 'TopLevel'))
}

In [11]:
def preprocess(tbl):
    """Clean and integer-encode the raw columns of a table.

    In order: fill missing ``present_media`` with "", cast the boolean and count
    columns to int, bucket the count columns in place, reduce ``present_media`` to
    its first two tab-separated entries joined by "_", and finally map
    ``present_media`` and ``tweet_type`` to integers via ``media_map`` / ``type_map``.

    :param tbl: table holding the raw generated columns.
    :return: the transformed table.
    """
    def first_two_media(raw_media):
        # e.g. "Photo\tVideo\tGIF" -> "Photo_Video"
        return '_'.join(raw_media.split('\t')[:2])

    bucket_edges = [1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7]
    int_columns = bool_cols + count_cols  # cast bool and long to int

    return (
        tbl.fillna("", "present_media")
        .cast(int_columns, "int")
        .cut_bins(columns=count_cols, bins=bucket_edges, out_cols=count_cols)
        .apply("present_media", "present_media", first_two_media, "string")
        .encode_string("present_media", media_map)
        .encode_string("tweet_type", type_map)
    )


# Apply the same preprocessing to both splits.
# NOTE: this rebinds train_tbl / valid_tbl, so re-running this cell alone would feed
# already-preprocessed tables back into preprocess().
train_tbl = preprocess(train_tbl)
valid_tbl = preprocess(valid_tbl)

In [12]:
# Build the language index from the training split only, encoding the training column as well.
train_tbl, language_idx = train_tbl.category_encode("language")
# Encode the validation split with that same training-derived index.
valid_tbl = valid_tbl.encode_string("language", language_idx)
# Replace missing codes with 0 (presumably languages unseen in training come back
# as null from encode_string — TODO confirm).
valid_tbl = valid_tbl.fillna(0, "language")

print("The number of languages: {}".format(language_idx.size()))

The number of languages: 65


In [13]:
def generate_features(tbl):
    """Add the crossed feature and the three list-length features.

    Crosses ``present_media`` with ``language`` using dimension 600, then counts the
    tab-separated entries of ``hashtags``, ``present_domains`` and ``present_links``
    into ``len_hashtags``, ``len_domains`` and ``len_links``.

    :param tbl: preprocessed table.
    :return: the table with the four additional columns.
    """
    def count_entries(value):
        # A falsy value (empty string / None) holds no entries; otherwise the number
        # of entries is the number of tab separators plus one.
        if not value:
            return 0
        return str(value).count('\t') + 1

    # The resulting cross column will have name "present_media_language"
    crossed = tbl.cross_columns([['present_media', 'language']], [600])

    with_hashtags = crossed.apply("hashtags", "len_hashtags", count_entries, "int")
    with_domains = with_hashtags.apply("present_domains", "len_domains", count_entries, "int")
    return with_domains.apply("present_links", "len_links", count_entries, "int")


# Add the derived features to both splits (rebinds train_tbl / valid_tbl).
train_tbl = generate_features(train_tbl)
valid_tbl = generate_features(valid_tbl)

In [14]:
# The three list-length features created by generate_features.
len_cols = ['len_hashtags',
            'len_domains',
            'len_links']

# Scale on the training split and keep the returned min_max_dict, then apply that
# same dictionary to the validation split so both use the training statistics.
train_tbl, min_max_dict = train_tbl.min_max_scale(len_cols)
valid_tbl = valid_tbl.transform_min_max_scale(len_cols, min_max_dict)

  1. The dashboard might not display correct information on this node.
  2. Metrics on this node won't be reported.
  3. runtime_env APIs won't work.
Check out the `dashboard_agent.log` to see the detailed failure messages.


In [15]:
# One timestamp column per engagement type; together they define the label.
timestamp_cols = [engagement + '_timestamp'
                  for engagement in ('reply', 'retweet', 'retweet_with_comment', 'like')]

In [16]:
def transform_label(tbl):
    """Derive the binary ``label`` column from the engagement timestamps.

    The columns in ``timestamp_cols`` are cast to int and their missing values filled
    with 0; ``label`` is 1 when the largest of them is positive, else 0.

    :param tbl: table containing the columns listed in ``timestamp_cols``.
    :return: the table with the added integer ``label`` column.
    """
    def has_engagement(timestamps):
        if max(timestamps) > 0:
            return 1
        return 0

    filled = tbl.cast(timestamp_cols, "int").fillna(0, timestamp_cols)
    return filled.apply(in_col=timestamp_cols, out_col="label",
                        func=has_engagement, dtype="int")


# Attach the label to both splits (rebinds train_tbl / valid_tbl).
train_tbl = transform_label(train_tbl)
valid_tbl = transform_label(valid_tbl)

In [17]:
# Sanity check: first 5 rows of the integer-encoded boolean and categorical columns.
train_tbl.select(bool_cols + cat_cols).show(5)

+-----------------------------+------------------------+-------------+----------+--------+
|engaged_with_user_is_verified|enaging_user_is_verified|present_media|tweet_type|language|
+-----------------------------+------------------------+-------------+----------+--------+
|                            0|                       1|           12|         1|       4|
|                            1|                       1|            8|         1|      59|
|                            1|                       0|            6|         0|      15|
|                            1|                       0|            3|         2|      36|
|                            0|                       1|            2|         2|      44|
+-----------------------------+------------------------+-------------+----------+--------+
only showing top 5 rows



In [18]:
# Sanity check: first 5 rows of the bucketed count columns.
train_tbl.select(count_cols).show(5)

+--------------------------------+---------------------------------+---------------------------+----------------------------+
|engaged_with_user_follower_count|engaged_with_user_following_count|enaging_user_follower_count|enaging_user_following_count|
+--------------------------------+---------------------------------+---------------------------+----------------------------+
|                               3|                                3|                          3|                           3|
|                               3|                                3|                          3|                           2|
|                               3|                                3|                          3|                           3|
|                               3|                                3|                          3|                           3|
|                               3|                                3|                          3|                      

In [19]:
# Sanity check: first 5 rows of the scaled length features, the crossed feature and the label.
train_tbl.select(len_cols + ["present_media_language", "label"]).show(5)

+------------+-----------+---------+----------------------+-----+
|len_hashtags|len_domains|len_links|present_media_language|label|
+------------+-----------+---------+----------------------+-----+
|        0.94|        0.9|      0.6|                   221|    1|
|        0.68|        0.6|      1.0|                   558|    0|
|        0.38|        0.5|      0.3|                   471|    1|
|        0.38|        0.3|      0.6|                   221|    1|
|        0.86|        0.9|      1.0|                   402|    1|
+------------+-----------+---------+----------------------+-----+
only showing top 5 rows



# Wide & Deep Model Training

In [27]:
import math
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

from bigdl.orca.learn.tf2.estimator import Estimator

In [21]:
# ---- Wide part ---------------------------------------------------------------
# build_model one-hot encodes wide base and indicator columns with depth dim + 1,
# and wide cross columns with depth dim, so the dims below follow that convention.
wide_cols = ['engaged_with_user_is_verified', 'enaging_user_is_verified']
wide_dims = [1] * len(wide_cols)
cross_cols = ['present_media_language']
cross_dims = [600]

# No embedding features are configured; these two names are not referenced by
# column_info below, which passes its own empty lists.
embedding_cols = []
embedding_dims = []

# ---- Deep part ---------------------------------------------------------------
cat_cols = ['present_media', 'tweet_type', 'language']
cat_dims = [12, 2, 66]
count_cols = ['engaged_with_user_follower_count',
              'engaged_with_user_following_count',
              'enaging_user_follower_count',
              'enaging_user_following_count']
count_dims = [7] * len(count_cols)
indicator_cols = cat_cols + count_cols
indicator_dims = cat_dims + count_dims

continuous_cols = ['len_hashtags', 'len_domains', 'len_links']

# Everything the model builder and the estimator need to know about the columns.
column_info = dict(
    wide_base_cols=wide_cols,
    wide_base_dims=wide_dims,
    wide_cross_cols=cross_cols,
    wide_cross_dims=cross_dims,
    indicator_cols=indicator_cols,
    indicator_dims=indicator_dims,
    continuous_cols=continuous_cols,
    embed_cols=[],
    embed_in_dims=[],
    embed_out_dims=[],
    label="label",
)

In [22]:
def build_model(column_info, hidden_units=(100, 50, 25)):
    """Build the (uncompiled) Wide & Deep Keras model described by ``column_info``.

    Every feature gets its own scalar input (``shape=[]``). The returned model's
    inputs are ordered: wide base, wide cross, indicator, embedding, continuous.

    * Wide part: one-hot encoded base and cross columns, concatenated and fed to a
      single ``Dense(1)``.
    * Deep part: one-hot encoded indicator columns, flattened embeddings and the
      continuous columns, concatenated and fed through one
      ``Dense -> BatchNormalization -> ReLU -> Dropout(0.1)`` stack per entry of
      ``hidden_units``, followed by ``Dense(1)``.
    * Output: sigmoid of the sum of the wide and deep outputs.

    :param column_info: dict read for the keys ``wide_base_cols``/``wide_base_dims``,
        ``wide_cross_cols``/``wide_cross_dims``, ``indicator_cols``/``indicator_dims``,
        ``embed_in_dims``/``embed_out_dims`` and ``continuous_cols``. Base, indicator
        and embedding dims are used with ``+ 1``; cross dims are used as given.
        At least one wide and at least one deep column is required; otherwise the
        lookup of the single remaining layer raises ``IndexError``.
    :param hidden_units: sequence with the width of each hidden Dense layer of the
        deep part. The default is an immutable tuple so it cannot be shared and
        modified between calls; lists are still accepted.
    :return: a ``tf.keras.models.Model``.
    """
    # Wide base features: depth is dim + 1 so that indices 0..dim are representable.
    wide_base_input_layers = []
    wide_base_layers = []
    for i in range(len(column_info["wide_base_cols"])):
        base_input = tf.keras.layers.Input(shape=[], dtype="int32")
        wide_base_input_layers.append(base_input)
        wide_base_layers.append(
            tf.keras.backend.one_hot(base_input, column_info["wide_base_dims"][i] + 1))

    # Wide cross features: the dim is used directly as the one-hot depth.
    wide_cross_input_layers = []
    wide_cross_layers = []
    for i in range(len(column_info["wide_cross_cols"])):
        cross_input = tf.keras.layers.Input(shape=[], dtype="int32")
        wide_cross_input_layers.append(cross_input)
        wide_cross_layers.append(
            tf.keras.backend.one_hot(cross_input, column_info["wide_cross_dims"][i]))

    # Indicator (categorical) features of the deep part, one-hot with depth dim + 1.
    indicator_input_layers = []
    indicator_layers = []
    for i in range(len(column_info["indicator_cols"])):
        indicator_input = tf.keras.layers.Input(shape=[], dtype="int32")
        indicator_input_layers.append(indicator_input)
        indicator_layers.append(
            tf.keras.backend.one_hot(indicator_input, column_info["indicator_dims"][i] + 1))

    # Embedding features: vocabulary size dim + 1, flattened to a vector per example.
    embed_input_layers = []
    embed_layers = []
    for i in range(len(column_info["embed_in_dims"])):
        embed_input = tf.keras.layers.Input(shape=[], dtype="int32")
        embed_input_layers.append(embed_input)
        embedded = tf.keras.layers.Embedding(
            column_info["embed_in_dims"][i] + 1,
            output_dim=column_info["embed_out_dims"][i])(embed_input)
        embed_layers.append(tf.keras.layers.Flatten()(embedded))

    # Continuous features: reshape each scalar to a length-1 vector so it can be
    # concatenated with the other deep features.
    continuous_input_layers = []
    continuous_layers = []
    for _ in range(len(column_info["continuous_cols"])):
        continuous_input = tf.keras.layers.Input(shape=[])
        continuous_input_layers.append(continuous_input)
        continuous_layers.append(
            tf.keras.layers.Reshape(target_shape=(1,))(continuous_input))

    # concatenate() needs at least two tensors, so a single feature is used as-is.
    wide_layers = wide_base_layers + wide_cross_layers
    if len(wide_layers) > 1:
        wide_input = tf.keras.layers.concatenate(wide_layers, axis=1)
    else:
        wide_input = wide_layers[0]
    wide_out = tf.keras.layers.Dense(1)(wide_input)

    deep_layers = indicator_layers + embed_layers + continuous_layers
    if len(deep_layers) > 1:
        deep_concat = tf.keras.layers.concatenate(deep_layers, axis=1)
    else:
        deep_concat = deep_layers[0]

    linear = deep_concat
    for units in hidden_units:
        linear_mid = tf.keras.layers.Dense(units)(linear)
        bn = tf.keras.layers.BatchNormalization()(linear_mid)
        relu = tf.keras.layers.ReLU()(bn)
        linear = tf.keras.layers.Dropout(0.1)(relu)
    deep_out = tf.keras.layers.Dense(1)(linear)

    added = tf.keras.layers.add([wide_out, deep_out])
    out = tf.keras.layers.Activation("sigmoid")(added)
    model = tf.keras.models.Model(wide_base_input_layers +
                                  wide_cross_input_layers +
                                  indicator_input_layers +
                                  embed_input_layers +
                                  continuous_input_layers,
                                  out)

    return model

In [23]:
# Batch size handed to estimator.fit and used to derive the step counts.
batch_size = 2560

# Settings handed to the estimator; model_creator reads "lr" and "column_info".
# The two parallelism entries are not read in this notebook's own code —
# presumably they are consumed by the Orca TF2 backend (TODO confirm).
config = dict(
    lr=1e-4,
    column_info=column_info,
    inter_op_parallelism=4,
    intra_op_parallelism=24,
)

In [24]:
def model_creator(config):
    """Create and compile the Wide & Deep model for the estimator.

    :param config: dict providing ``column_info`` (forwarded to ``build_model``)
        and ``lr`` (learning rate of the Adam optimizer).
    :return: the compiled Keras model, using binary cross-entropy as the loss.
    """
    wnd_model = build_model(hidden_units=[1024, 1024],
                            column_info=config["column_info"])
    wnd_model.compile(
        loss='binary_crossentropy',
        optimizer=tf.keras.optimizers.Adam(config["lr"]),
        metrics=['binary_accuracy', 'binary_crossentropy', 'AUC', 'Precision', 'Recall'],
    )
    return wnd_model

In [25]:
# Create the Orca estimator from the model factory; `config` is the dict that
# model_creator receives, and two workers are requested per node.
estimator = Estimator.from_keras(
    model_creator=model_creator,
    verbose=True,
    config=config,
    workers_per_node=2)

[2m[36m(Worker pid=479)[0m Instructions for updating:
[2m[36m(Worker pid=479)[0m use distribute.MultiWorkerMirroredStrategy instead
[2m[36m(Worker pid=478)[0m Instructions for updating:
[2m[36m(Worker pid=478)[0m use distribute.MultiWorkerMirroredStrategy instead
[2m[36m(Worker pid=479)[0m 2022-08-22 09:44:21.854823: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
[2m[36m(Worker pid=478)[0m 2022-08-22 09:44:21.855390: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [29]:
# Round up so that a final, partially filled batch still counts as a step.
steps_per_epoch = math.ceil(train_size / batch_size)
epochs = 5
val_steps = math.ceil(valid_size / batch_size)

# Stop early once the validation AUC has not improved for 3 epochs in a row.
callbacks = [EarlyStopping(monitor='val_auc', mode='max', verbose=1, patience=3)]

In [30]:
def label_cols(column_info):
    """Return the label column name wrapped in a single-element list.

    :param column_info: dict with a ``"label"`` entry.
    :return: ``[column_info["label"]]``.
    """
    label_name = column_info["label"]
    return [label_name]

def feature_cols(column_info):
    """Return all feature column names as one new list.

    The groups are concatenated in the order wide base, wide cross, indicator,
    embedding, continuous — the same order in which build_model lines up the
    model inputs.

    :param column_info: dict holding the five ``*_cols`` lists.
    :return: the concatenation of those lists.
    """
    first_group, *other_groups = ("wide_base_cols", "wide_cross_cols", "indicator_cols",
                                  "embed_cols", "continuous_cols")
    merged = column_info[first_group]
    for group in other_groups:
        merged = merged + column_info[group]
    return merged

# Train on the underlying Spark DataFrames of the two tables (`.df`). The feature
# column order comes from feature_cols(column_info) and the label from label_cols.
estimator.fit(data=train_tbl.df,
              epochs=epochs,
              batch_size=batch_size,
              steps_per_epoch=steps_per_epoch,
              validation_data=valid_tbl.df,
              validation_steps=val_steps,
              callbacks=callbacks,
              feature_cols=feature_cols(column_info),
              label_cols=label_cols(column_info))

[2m[36m(Worker pid=479)[0m Instructions for updating:
[2m[36m(Worker pid=479)[0m rename to distribute_datasets_from_function
[2m[36m(Worker pid=478)[0m Instructions for updating:
[2m[36m(Worker pid=478)[0m rename to distribute_datasets_from_function
[2m[36m(Worker pid=479)[0m 2022-08-22 09:48:07.325853: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
[2m[36m(Worker pid=478)[0m 2022-08-22 09:48:07.325829: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[2m[36m(Worker pid=479)[0m Epoch 1/5
 1/16 [>.............................] - ETA: 2:17 - loss: 0.9040 - binary_accuracy: 0.3836 - binary_crossentropy: 0.9040 - auc: 0.5017 - precision: 0.9340 - recall: 0.3626
 2/16 [==>...........................] - ETA: 12s - loss: 0.8797 - binary_accuracy: 0.3885 - binary_crossentropy: 0.8797 - auc: 0.5188 - precision: 0.9381 - recall: 0.3685 
 3/16 [====>.........................] - ETA: 11s - loss: 0.8671 - binary_accuracy: 0.3970 - binary_crossentropy: 0.8671 - auc: 0.4947 - precision: 0.9329 - recall: 0.3820
[2m[36m(Worker pid=479)[0m Epoch 2/5
 1/16 [>.............................] - ETA: 12s - loss: 0.5297 - binary_accuracy: 0.7844 - binary_crossentropy: 0.5297 - auc: 0.5465 - precision: 0.9445 - recall: 0.8192
 2/16 [==>...........................] - ETA: 11s - loss: 0.5255 - binary_accuracy: 0.7936 - binary_crossentropy: 0.5255 - auc: 0.5212 - precision: 0.9428 - recall: 0.8311
 3/16 [====>.........................] - ETA: 10s - loss: 

[2m[36m(Worker pid=479)[0m 2022-08-22 09:48:47.417395: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
[2m[36m(Worker pid=478)[0m 2022-08-22 09:48:47.440609: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[2m[36m(Worker pid=479)[0m Epoch 3/5
 1/16 [>.............................] - ETA: 12s - loss: 0.3615 - binary_accuracy: 0.9238 - binary_crossentropy: 0.3615 - auc: 0.5022 - precision: 0.9370 - recall: 0.9850
 2/16 [==>...........................] - ETA: 11s - loss: 0.3604 - binary_accuracy: 0.9221 - binary_crossentropy: 0.3604 - auc: 0.5078 - precision: 0.9352 - recall: 0.9850
 3/16 [====>.........................] - ETA: 10s - loss: 0.3556 - binary_accuracy: 0.9240 - binary_crossentropy: 0.3556 - auc: 0.5137 - precision: 0.9363 - recall: 0.9858
[2m[36m(Worker pid=479)[0m Epoch 4/5
 1/16 [>.............................] - ETA: 12s - loss: 0.2964 - binary_accuracy: 0.9316 - binary_crossentropy: 0.2964 - auc: 0.5399 - precision: 0.9320 - recall: 0.9996
 2/16 [==>...........................] - ETA: 12s - loss: 0.2920 - binary_accuracy: 0.9311 - binary_crossentropy: 0.2920 - auc: 0.5527 - precision: 0.9318 - recall: 0.9992
 3/16 [====>.........................] - ETA: 11s - loss: 0.

[2m[36m(Worker pid=479)[0m 2022-08-22 09:49:20.137071: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
[2m[36m(Worker pid=478)[0m 2022-08-22 09:49:20.151022: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[2m[36m(Worker pid=479)[0m Epoch 5/5
 1/16 [>.............................] - ETA: 9s - loss: 0.2569 - binary_accuracy: 0.9332 - binary_crossentropy: 0.2569 - auc: 0.6120 - precision: 0.9336 - recall: 0.9996
 2/16 [==>...........................] - ETA: 9s - loss: 0.2517 - binary_accuracy: 0.9363 - binary_crossentropy: 0.2517 - auc: 0.6062 - precision: 0.9365 - recall: 0.9998
 3/16 [====>.........................] - ETA: 8s - loss: 0.2479 - binary_accuracy: 0.9387 - binary_crossentropy: 0.2479 - auc: 0.6009 - precision: 0.9388 - recall: 0.9999


[{'train_loss': 0.24382218718528748,
  'train_binary_accuracy': 0.9387451410293579,
  'train_binary_crossentropy': 0.24382218718528748,
  'train_auc': 0.587367057800293,
  'train_precision': 0.9388368129730225,
  'train_recall': 0.9998959898948669,
  'train_val_loss': 0.5580403208732605,
  'train_val_binary_accuracy': 0.9330078363418579,
  'train_val_binary_crossentropy': 0.5580403208732605,
  'train_val_auc': 0.5002649426460266,
  'train_val_precision': 0.9330078363418579,
  'train_val_recall': 1.0}]

In [31]:
# Fetch the trained Keras model from the estimator and export it in SavedModel
# format to the relative directory "recsys_wnd/".
model = estimator.get_model()
tf.saved_model.save(model, "recsys_wnd/")

# Shut the Orca context down once the model has been saved.
stop_orca_context()

Stopping orca context
