In [9]:
import os
import time
import json

from absl import app
from absl import flags
from absl import logging

import numpy as np
import tensorflow as tf
import tensorflow_recommenders as tfrs

from google.cloud import aiplatform as vertex_ai
from google.cloud import bigquery
from google.cloud import storage

In [10]:
# ---------------------------------------------------------------------------
# Notebook configuration: resource names, training hyper-parameters (exposed
# as absl flags so the same code can run as a script) and SDK clients.
# ---------------------------------------------------------------------------
PREFIX = 'two-tower'
#PREFIX = 'css_retail'
DISPLAY_NAME = f'{PREFIX}-tensorboard'
PROJECT = 'babrams-recai-demo-final'
REGION = 'us-central1'

# Bucket handed to the Vertex AI SDK as its staging location.
STAGING_BUCKET = f'gs://{PROJECT}_vertex_training'
#TENSORBOARD = 'projects/258043323883/locations/us-central1/tensorboards/4236655796332527616' #note really can only get this after gcloud beta ai tensorboards create...
#VERTEX_SA = 'vertex-tb@lowes-reccomendation.iam.gserviceaccount.com'

FLAGS = flags.FLAGS
# The learning rate must be non-negative; a negative default would make the
# optimizer step *up* the loss gradient.
flags.DEFINE_float("LR", 0.01, "Learning Rate", lower_bound=0.0)
flags.DEFINE_integer("EMBEDDING_DIM", 15, "Embedding dimension")
flags.DEFINE_integer("MAX_TOKENS", 15, "Max embeddings for query and last_n products")
flags.DEFINE_integer("NUM_EPOCHS", 29, "Number of epochs")
flags.DEFINE_string("MODEL_DIR", 'model-dirs-lowes', "GCS Bucket to store the model artifact")
flags.DEFINE_bool("DROPOUT", False, "Use Dropout - T/F bool type")
# Keras Dropout only accepts rates in [0, 1]; the bounds reject bad overrides
# at flag-parsing time instead of deep inside model construction.
flags.DEFINE_float(
    "DROPOUT_RATE", 0.4, "Dropout rate only works with DROPOUT=True",
    lower_bound=0.0, upper_bound=1.0
)
#flags.DEFINE_integer("N_PRODUCTS", 19999, "number of products considered for embedding")
flags.DEFINE_integer("BATCH_SIZE", 1023, "batch size")
flags.DEFINE_string("ARCH", '[127,64]', "deep architecture, expressed as a list of ints in string format - will be parsed into list")
flags.DEFINE_integer("SEED", 41781896, "random seed")
#flags.DEFINE_string("TF_RECORDS_DIR", "gs://tfrs-central-a", "source data in tfrecord format gcs location")

# initialize vertex sdk
vertex_ai.init(
    project=PROJECT,
    location=REGION,
    staging_bucket=STAGING_BUCKET
)

# BigQuery client used by the query cells below.
client = bigquery.Client()

In [None]:
#!pip install -r notebook_requirements.txt --user

In [15]:
# Product catalog query.
# The inner query unnests `category_hierarchies` and joins each hierarchy's
# categories with spaces; the outer query groups back to one row per
# (productId, title, description, price) and joins the per-hierarchy strings
# with commas into a single `categories` column.
product_catalog_sql = """
WITH inner_q AS (
    SELECT
        SAFE_CAST(id AS INT) AS productId,
        title,
        description,
        product_metadata.exact_price.original_price AS price,
        ARRAY_TO_STRING(cats.categories, ' ') AS categories
    FROM `babrams-recai-demo-final.css_retail.recommendation_ai_data` AS rad,
    UNNEST(category_hierarchies) AS cats
) SELECT
    productId,
    title,
    description,
    price,
    ARRAY_TO_STRING(ARRAY_AGG(categories), ",") AS categories
FROM inner_q
GROUP BY productId, title, description, price
"""

# Submit the query with the module-level BigQuery client and pull the whole
# result set into a DataFrame.
product_catalog = client.query(product_catalog_sql)
product_catalog_df = product_catalog.to_dataframe()
#product_catalog_df.describe()
#product_categoricals = ['productId', 'title', 'description', 'categories']
# product_catalog_df.dtypes

class ProductModel(tf.keras.Model):
    """Candidate tower: maps product features to an L2-normalized embedding.

    The model embeds ``productId``, ``title``, ``description`` and
    ``categories``, normalizes ``price``, concatenates everything along axis 1
    and feeds the result through a stack of dense layers whose output is
    L2-normalized.

    Args:
        layer_sizes: sizes of the dense layers; all but the last use ReLU
            (optionally followed by dropout when ``FLAGS.DROPOUT`` is set),
            the last has no activation.
        adapt_data: unbatched ``tf.data.Dataset`` of feature dicts with the
            keys listed above; used to adapt the preprocessing layers.
    """

    def __init__(self, layer_sizes, adapt_data):
        super().__init__()

        #preprocess stuff
        # Count the distinct product ids once; the count (an int, not the
        # array of ids) sizes the vocabularies and embedding tables below.
        unique_skus = np.unique(
            np.concatenate(
                list(
                    adapt_data.map(lambda x: x["productId"]).batch(1000)
                )
            )
        )
        self.sku_count = len(unique_skus)

        # Honor the N_PRODUCTS flag when some configuration defines it;
        # otherwise fall back to the data-derived count so that a missing
        # flag does not abort model construction.
        if "N_PRODUCTS" in FLAGS:
            category_max_tokens = FLAGS.N_PRODUCTS
        else:
            category_max_tokens = self.sku_count

        #categorical: sku
        self.sku_lookup = tf.keras.layers.experimental.preprocessing.StringLookup(
            name="sku_monotic"
        )
        self.title_vectorizor = tf.keras.layers.TextVectorization(
            max_tokens=self.sku_count
            , name="title_vectorizor"
        )
        self.description_vectorizor = tf.keras.layers.TextVectorization(
            max_tokens=self.sku_count
            , name="description_vectorizor"
        )
        self.category_vectorizor = tf.keras.layers.TextVectorization(
            max_tokens=category_max_tokens
            , name="category_vectorizor"
        )
        self.price_normalization = tf.keras.layers.experimental.preprocessing.Normalization(axis=None)


        #adapt stuff
        self.category_vectorizor.adapt(adapt_data.map(lambda x: x['categories']))
        self.title_vectorizor.adapt(adapt_data.map(lambda x: x['title']))
        self.description_vectorizor.adapt(adapt_data.map(lambda x: x['description']))
        self.sku_lookup.adapt(adapt_data.map(lambda x: x['productId']))
        self.price_normalization.adapt(adapt_data.map(lambda x: x['price']))

        #embed stuff
        self.sku_embedding = tf.keras.Sequential(
            [
                self.sku_lookup,
                tf.keras.layers.Embedding(
                    self.sku_count+1,
                    FLAGS.EMBEDDING_DIM,
                    mask_zero=True,
                    name="sku_emb"
                ),
                tf.keras.layers.GlobalAveragePooling1D(
                    name="sku_flat"
                )
            ],
            name="sku_embedding"
        )
        self.title_embedding = tf.keras.Sequential(
            [
                self.title_vectorizor,
                tf.keras.layers.Embedding(
                    self.sku_count+1,
                    FLAGS.EMBEDDING_DIM,
                    mask_zero=True,
                    name="title_emb"
                ),
                tf.keras.layers.GlobalAveragePooling1D(
                    name="title_flatten"
                )
            ],
            name="title_embedding"
        )
        self.description_embedding = tf.keras.Sequential(
            [
                self.description_vectorizor,
                tf.keras.layers.Embedding(
                    self.sku_count+1,
                    FLAGS.EMBEDDING_DIM,
                    mask_zero=True,
                    name="desc_emb"),
                tf.keras.layers.GlobalAveragePooling1D(
                    name="desc_flatten"
                )
            ],
            name="description_embedding"
        )
        self.category_embedding = tf.keras.Sequential(
            [
                self.category_vectorizor,
                tf.keras.layers.Embedding(
                    # Size the table from the vocabulary learned by adapt().
                    self.category_vectorizor.vocabulary_size(),
                    FLAGS.EMBEDDING_DIM,
                    mask_zero=True,
                    name="category_emb"),
                tf.keras.layers.GlobalAveragePooling1D(
                    name="category_flatten"
                )
            ],
            name="category_embedding"
        )


        # Then construct the layers.
        self.dense_layers = tf.keras.Sequential(name="dense_layers_product")

        # Adding weight initialzier
        initializer = tf.keras.initializers.GlorotUniform(seed=FLAGS.SEED)
        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(tf.keras.layers.Dense(
                layer_size,
                activation="relu",
                kernel_initializer=initializer
            ))
            if FLAGS.DROPOUT:
                self.dense_layers.add(tf.keras.layers.Dropout(
                    FLAGS.DROPOUT_RATE
                ))
        # No activation for the last layer
        for layer_size in layer_sizes[-1:]:
            self.dense_layers.add(tf.keras.layers.Dense(
                layer_size,
                kernel_initializer=initializer
            ))
        ### ADDING L2 NORM AT THE END
        self.dense_layers.add(tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(
                x,
                1,
                epsilon=1e-12,
                name="normalize_dense"
            )
        ))

    def call(self, data):
        """Embed a batch of product feature dicts.

        Args:
            data: dict with keys ``price``, ``description``, ``productId``,
                ``categories`` and ``title``.

        Returns:
            The output of the dense stack, L2-normalized along axis 1.
        """
        all_embs = tf.concat(
            [
                # Price is a single scalar per row; reshape it to one column.
                tf.reshape(self.price_normalization(data["price"]), (-1, 1)),
                self.description_embedding(data['description']),
                self.sku_embedding(data['productId']),
                self.category_embedding(data['categories']),
                self.title_embedding(data['title'])
            ], axis=1)
        return self.dense_layers(all_embs)  #last plus for number continuous + 1 if you add other(s) 2048 for visual


In [14]:
  
# Customer query: one row per customer with demographic and location columns,
# plus `customer_lifetime_days`, the number of days between `created_at` and
# the moment the query runs (so the value changes from run to run).
customer_data_sql = """
SELECT
    id AS userId,
    age,
    gender,
    latitude,
    longitude,
    zip,
    traffic_source,
    TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), created_at, DAY) AS customer_lifetime_days
FROM `babrams-recai-demo-final.css_retail.customers` AS customers
"""
# Submit the query and pull the whole result set into a DataFrame.
customer_data = client.query(customer_data_sql)
customer_data_df = customer_data.to_dataframe()
#customer_data_df.describe()

class UserModel(tf.keras.Model):
    """Query tower: maps customer features to an L2-normalized embedding.

    Embeds ``userId``, ``traffic_source``, ``zip`` and ``gender`` through
    string lookups, buckets ``age`` and ``customer_lifetime_days`` into
    evenly spaced bins before embedding them, concatenates everything along
    axis 1 and feeds the result through a dense stack whose output is
    L2-normalized.

    Args:
        layer_sizes: sizes of the dense layers; all but the last use ReLU
            (optionally followed by dropout when ``FLAGS.DROPOUT`` is set),
            the last has no activation.
        adapt_data: ``tf.data.Dataset`` of feature dicts with the keys listed
            above; used to adapt the lookups and to find the bucket ranges.
    """

    @staticmethod
    def _feature_range(dataset, key):
        """Return ``(min, max)`` of the int64 feature ``key`` over ``dataset``.

        Every element is collapsed to a scalar before being folded into the
        running state, so the reduction works whether the dataset yields
        single examples or batches (including a smaller final batch).
        """
        values = dataset.map(lambda x: x[key])
        smallest = values.reduce(
            np.int64(1e9),
            lambda state, value: tf.minimum(state, tf.reduce_min(value))
        ).numpy()
        largest = values.reduce(
            tf.cast(0, tf.int64),
            lambda state, value: tf.maximum(state, tf.reduce_max(value))
        ).numpy()
        return smallest, largest

    def __init__(self, layer_sizes, adapt_data):
        super().__init__()

        #preprocess stuff
        self.user_lookup = tf.keras.layers.experimental.preprocessing.StringLookup()
        self.min_age, self.max_age = self._feature_range(adapt_data, 'age')
        self.age_buckets = np.linspace(self.min_age, self.max_age, num=20)
        self.min_lifetime, self.max_lifetime = self._feature_range(
            adapt_data, 'customer_lifetime_days'
        )
        self.lifetime_buckets = np.linspace(self.min_lifetime, self.max_lifetime, num=100)
        self.traffic_source_lookup = tf.keras.layers.experimental.preprocessing.StringLookup()
        self.zip_lookup = tf.keras.layers.experimental.preprocessing.StringLookup()
        self.gender_lookup = tf.keras.layers.experimental.preprocessing.StringLookup()

        #adapt stuff
        # Each lookup is adapted on the same feature that call() feeds it.
        self.user_lookup.adapt(adapt_data.map(lambda x: x['userId']))
        self.traffic_source_lookup.adapt(adapt_data.map(lambda x: x['traffic_source']))
        self.zip_lookup.adapt(adapt_data.map(lambda x: x['zip']))
        self.gender_lookup.adapt(adapt_data.map(lambda x: x['gender']))


        #embed stuff
        self.user_embedding = tf.keras.Sequential([
            self.user_lookup,
            tf.keras.layers.Embedding(
                self.user_lookup.vocabulary_size() + 1,
                FLAGS.EMBEDDING_DIM,
                mask_zero=True,
                name="user_emb"
            ),
            tf.keras.layers.GlobalAveragePooling1D(
                name="user_flat"
            )
        ], name="user_embedding")
        self.age_embedding = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.Discretization(
                self.age_buckets.tolist()
            ),
            tf.keras.layers.Embedding(
                len(self.age_buckets) + 1,
                32
            )
        ], name="age_embedding")
        self.lifetime_embedding = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.Discretization(
                self.lifetime_buckets.tolist(),
                name="lifetime_disc"
            ),
            tf.keras.layers.Embedding(
                len(self.lifetime_buckets) + 1,
                32
            )
        ], name="customer_lifetime_embedding")
        self.traffic_source_embedding = tf.keras.Sequential([
            self.traffic_source_lookup,
            tf.keras.layers.Embedding(
                self.traffic_source_lookup.vocabulary_size() + 1,
                FLAGS.EMBEDDING_DIM,
                mask_zero=True,
                name="traffic_source_emb"
            ),
            tf.keras.layers.GlobalAveragePooling1D(
                name="traffic_source_flat"
            )
        ], name="traffic_source_embedding")
        self.zip_embedding = tf.keras.Sequential([
            self.zip_lookup,
            tf.keras.layers.Embedding(
                self.zip_lookup.vocabulary_size() + 1,
                FLAGS.EMBEDDING_DIM,
                mask_zero=True,
                name="zip_emb"),
            tf.keras.layers.GlobalAveragePooling1D(name="zip_flat")
        ], name="zip_embedding")
        self.gender_embedding = tf.keras.Sequential([
            self.gender_lookup,
            tf.keras.layers.Embedding(
                self.gender_lookup.vocabulary_size() + 1,
                FLAGS.EMBEDDING_DIM,
                mask_zero=True,
                name="gender_emb"
            ),
            tf.keras.layers.GlobalAveragePooling1D(
                name="gender_flat"
            )
        ], name="gender_embedding")

        # Then construct the layers.
        self.dense_layers = tf.keras.Sequential(name="dense_layers_product")

        # Adding weight initialzier
        initializer = tf.keras.initializers.GlorotUniform(seed=FLAGS.SEED)
        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(tf.keras.layers.Dense(
                layer_size,
                activation="relu",
                kernel_initializer=initializer
            ))
            if FLAGS.DROPOUT:
                self.dense_layers.add(tf.keras.layers.Dropout(
                    FLAGS.DROPOUT_RATE
                ))
        # No activation for the last layer
        for layer_size in layer_sizes[-1:]:
            self.dense_layers.add(tf.keras.layers.Dense(
                layer_size,
                kernel_initializer=initializer
            ))
        ### ADDING L2 NORM AT THE END
        self.dense_layers.add(tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(
                x,
                1,
                epsilon=1e-12,
                name="normalize_dense"
            )
        ))


    def call(self, data):
        """Embed a batch of customer feature dicts.

        Args:
            data: dict with keys ``userId``, ``age``,
                ``customer_lifetime_days``, ``traffic_source``, ``zip`` and
                ``gender``.

        Returns:
            The output of the dense stack, L2-normalized along axis 1.
        """
        # NOTE(review): the lookup towers end in GlobalAveragePooling1D while
        # the age/lifetime towers do not pool; confirm the rank of each input
        # feature so that every piece has the same rank at the concat.
        all_embs = tf.concat(
            [
                self.user_embedding(data['userId']),
                self.age_embedding(data['age']),
                self.lifetime_embedding(data['customer_lifetime_days']),
                self.traffic_source_embedding(data['traffic_source']),
                self.zip_embedding(data['zip']),
                self.gender_embedding(data['gender'])
            ], axis=1)
        return self.dense_layers(all_embs)  #last plus for number continuous + 1 if you add other(s) 2048 for visual



In [6]:
# Purchase events query.
# The inner query unnests `productEventDetail.productDetails` so each purchased
# product becomes its own row; the outer query replaces the timestamp with its
# Unix-millisecond value and adds hour / day / month / day-of-week columns
# extracted from it.
purchase_data_sql = """
WITH inner_q AS (
    SELECT
        SAFE_CAST(userInfo.userId AS INT) AS userId,
        SAFE_CAST(eventTime AS TIMESTAMP) AS eventTime,
        productEventDetail.cartId,
        productEventDetail.purchaseTransaction.revenue,
        products.id AS productId,
        products.quantity,
        products.displayPrice AS price
    FROM `babrams-recai-demo-final.css_retail.purchase_complete` AS purchase,
    UNNEST(productEventDetail.productDetails) AS products
) SELECT
    inner_q.* EXCEPT (eventTime),
    UNIX_MILLIS(eventTime) AS eventTime,
    EXTRACT(HOUR FROM eventTime) AS hour,
    EXTRACT(DAY FROM eventTime) AS day,
    EXTRACT(MONTH FROM eventTime) AS month,
    EXTRACT(DAYOFWEEK FROM eventTime) AS dow
FROM inner_q;
"""

# Submit the query and pull the whole result set into a DataFrame.
purchase_data = client.query(purchase_data_sql)
purchase_data_df = purchase_data.to_dataframe()
#purchase_categoricals = ['userId', 'cartId', 'productId']
# purchase_data_df.dtypes

class EventModel(tf.keras.Model):
    """Tower over purchase-event context: calendar fields, event time, price.

    ``month``, ``dow``, ``day`` and ``hour`` are count-encoded against fixed
    vocabularies of zero-padded two-digit strings; ``eventTime`` and ``price``
    are normalized. All pieces are concatenated along axis 1 and passed
    through a dense stack whose output is L2-normalized.

    Args:
        layer_sizes: sizes of the dense layers; every layer except the last
            uses ReLU (plus dropout when ``FLAGS.DROPOUT`` is set).
        adapt_data: ``tf.data.Dataset`` of feature dicts used to adapt the
            two normalization layers.
    """

    @staticmethod
    def _count_encoder(vocab, lookup_name, wrapper_name):
        """Wrap a count-mode StringLookup over `vocab` in a named Sequential."""
        lookup = tf.keras.layers.StringLookup(
            vocabulary=vocab,
            mask_token=None,
            name=lookup_name,
            output_mode='count')
        return tf.keras.Sequential([lookup], name=wrapper_name)

    def __init__(self, layer_sizes, adapt_data):
        super().__init__()

        # -- fixed vocabularies -------------------------------------------
        self.month_vocab = tf.constant(
            ["{:02d}".format(m) for m in range(1, 13)], name="month_vocab")
        self.day_vocab = tf.constant(
            ["{:02d}".format(d) for d in range(1, 32)], name='day_vocab')
        self.dow_vocab = tf.constant(
            ["{:02d}".format(d) for d in range(1, 8)], name="dow_vocab")
        self.hour_vocab = tf.constant(
            ["{:02d}".format(h) for h in range(24)], name="hour_vocab")

        # -- continuous features ------------------------------------------
        normalization_cls = tf.keras.layers.experimental.preprocessing.Normalization
        self.eventtime_normalization = normalization_cls(axis=None)
        self.price_normalization = normalization_cls(axis=None)

        event_times = adapt_data.map(lambda x: x['eventTime']).batch(1024)
        self.eventtime_normalization.adapt(event_times)
        prices = adapt_data.map(lambda x: x['price'])
        self.price_normalization.adapt(prices)

        # -- categorical encoders -----------------------------------------
        self.month_embedding = self._count_encoder(
            self.month_vocab, "month_lookup", "month_emb")
        self.hour_embedding = self._count_encoder(
            self.hour_vocab, "hour_lookup", "hour_emb")
        self.day_embedding = self._count_encoder(
            self.day_vocab, "day_lookup", "day_emb")
        self.dow_embedding = self._count_encoder(
            self.dow_vocab, "dow_lookup", "dow_emb")

        # -- dense tower --------------------------------------------------
        self.dense_layers = tf.keras.Sequential(name="dense_layers_query")
        tower = self.dense_layers
        glorot = tf.keras.initializers.GlorotUniform(seed=FLAGS.SEED)

        hidden_sizes, output_sizes = layer_sizes[:-1], layer_sizes[-1:]
        for units in hidden_sizes:
            # Hidden layers: ReLU, optionally followed by dropout.
            tower.add(tf.keras.layers.Dense(
                units, activation="relu", kernel_initializer=glorot))
            if FLAGS.DROPOUT:
                tower.add(tf.keras.layers.Dropout(FLAGS.DROPOUT_RATE))
        for units in output_sizes:
            # Output layer: linear.
            tower.add(tf.keras.layers.Dense(units, kernel_initializer=glorot))

        # Finish with an L2 normalization along axis 1.
        tower.add(tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(
                x, axis=1, epsilon=1e-12, name="normalize_dense")
        ))

    def call(self, data):
        """Encode a batch of event feature dicts and run the dense tower."""
        month = self.month_embedding(data['month'])
        dow = self.dow_embedding(data['dow'])
        day = self.day_embedding(data['day'])
        hour = self.hour_embedding(data['hour'])
        # Normalized scalars become single columns so they can be concatenated.
        event_time = tf.reshape(
            self.eventtime_normalization(data['eventTime']), (-1, 1))
        price = tf.reshape(self.price_normalization(data['price']), (-1, 1))

        features = tf.concat([month, dow, day, hour, event_time, price], axis=1)
        return self.dense_layers(features)


In [None]:
class TheTwoTowers(tfrs.models.Model):
    """Two-tower retrieval model pairing a UserModel with a ProductModel.

    Args:
        layer_sizes: dense layer sizes shared by both towers.
        query_adapt_data: dataset used to adapt the query (user) tower.
        cat_adapt_data: product catalog dataset; adapts the candidate tower
            and supplies the candidates for the top-K metric.
    """

    def __init__(self, layer_sizes, query_adapt_data, cat_adapt_data):
        super().__init__()
        self.query_model = UserModel(layer_sizes, query_adapt_data)
        self.candidate_model = ProductModel(layer_sizes, cat_adapt_data)

        # Candidate embeddings the FactorizedTopK metric scores queries against.
        candidate_embeddings = (
            cat_adapt_data
            .batch(128)
            .cache()
            .map(self.candidate_model)
        )
        top_k_metric = tfrs.metrics.FactorizedTopK(
            candidates=candidate_embeddings,
        )
        self.task = tfrs.tasks.Retrieval(metrics=top_k_metric)

    def compute_loss(self, data, training=False):
        """Return the retrieval loss for one batch of combined feature dicts."""
        query_vectors = self.query_model(data)
        candidate_vectors = self.candidate_model(data)

        # Metrics are only computed outside of training to save time.
        wants_metrics = not training
        return self.task(
            query_vectors,
            candidate_vectors,
            compute_metrics=wants_metrics
        )


In [None]:
def _is_chief(task_type, task_id):
    '''Helper function. Determines if machine is chief.'''

    return task_type == 'chief'


def _get_temp_dir(dirpath, task_id):
    '''Create and return the per-task scratch directory under `dirpath`.

    The directory is named ``workertemp_<task_id>`` and is created through
    ``tf.io.gfile`` before its path is returned.
    '''

    worker_dir = os.path.join(dirpath, 'workertemp_{}'.format(str(task_id)))
    tf.io.gfile.makedirs(worker_dir)
    return worker_dir


def write_filepath(filepath, task_type, task_id):
    '''Return the path this task should save the model to.

    The chief keeps `filepath` (re-joined from its directory and base name);
    every other task gets the same base name inside its own temporary
    directory next to it.
    '''

    parent = os.path.dirname(filepath)
    leaf = os.path.basename(filepath)
    if _is_chief(task_type, task_id):
        return os.path.join(parent, leaf)
    scratch = _get_temp_dir(parent, task_id)
    return os.path.join(scratch, leaf)

In [None]:
def runner(argv):
    """Instantiate a MultiWorkerMirroredStrategy and discard it.

    `argv` is accepted but not used, and nothing is returned.
    """
    tf.distribute.MultiWorkerMirroredStrategy()

def _list_gcs_uris(storage_client, bucket_name):
    """Return the ``gs://`` URI of every blob in `bucket_name`.

    URIs are assembled from the bucket and blob names instead of being
    rewritten from ``blob.public_url``, because ``public_url`` percent-encodes
    the object name and would not match objects whose names contain
    characters such as spaces.
    """
    return [
        f"gs://{bucket_name}/{blob.name}"
        for blob in storage_client.list_blobs(bucket_name)
    ]


def main(argv):
    """Train the two-tower model under a multi-worker strategy and save it.

    Steps: list the TFRecord files, build the training / adapt / candidate
    pipelines, build and compile ``TheTwoTowers`` inside the strategy scope,
    fit with a TensorBoard callback, evaluate on the held-out split and save
    a ``BruteForce`` retrieval layer wrapping the query tower.

    ``parse_tfrecord_fn``, ``parse_tfrecord_catalog`` and
    ``get_arch_from_string`` are expected to be defined elsewhere.

    Args:
        argv: positional arguments forwarded by ``absl.app.run``; unused.

    Raises:
        KeyError: if ``AIP_TENSORBOARD_LOG_DIR`` is not set in the environment.
    """

    logging.set_verbosity(logging.INFO)

    ### SET STRATEGY
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    ## get product, query data files
    # Named storage_client so it does not shadow the module-level BigQuery client.
    storage_client = storage.Client()
    files = _list_gcs_uris(storage_client, 'tfrs-tf-records')
    files_cat = _list_gcs_uris(storage_client, 'prod-catalog-central')

    ## establish the pipelines
    # Set dev dataset CHANGE THIS LATER TO THE WHOLE DIR
    raw_dataset = tf.data.TFRecordDataset(files)
    cat_dataset = tf.data.TFRecordDataset(files_cat)

    #See `pipeline-opts.ipynb` for more info on tuning options
    parsed_dataset = raw_dataset.map(
        parse_tfrecord_fn,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Set AutoShardPolicy
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    parsed_dataset = parsed_dataset.with_options(options)

    # Doing another pipeline for the adapts to get startup to run much faster
    parsed_dataset_adapt = raw_dataset.map(
        parse_tfrecord_fn,
        num_parallel_calls=tf.data.AUTOTUNE
    ).batch(FLAGS.BATCH_SIZE)

    # loading de-duplicated product catalog
    parsed_dataset_candidates = cat_dataset.map(
        parse_tfrecord_catalog,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Seed the global generator before any weights are created so that
    # layers without an explicit initializer seed are reproducible as well.
    tf.random.set_seed(FLAGS.SEED)

    logging.info('Setting model adapts and compiling the model')
    # Wrap variable creation within strategy scope
    with strategy.scope():
        model = TheTwoTowers(
            get_arch_from_string(FLAGS.ARCH),
            query_adapt_data=parsed_dataset_adapt,
            cat_adapt_data=parsed_dataset_candidates
        )
        model.compile(optimizer=tf.keras.optimizers.Adagrad(FLAGS.LR))
    logging.info('Adapts finish - training next')

    tensorboard_cb = tf.keras.callbacks.TensorBoard(
        log_dir=os.environ['AIP_TENSORBOARD_LOG_DIR'], #this sends to the log dir for vertex tensorboard
        histogram_freq=0)

    # reshuffle_each_iteration=False keeps the take/skip split below stable
    # across epochs, so test records never leak into training.
    shuffled = parsed_dataset.shuffle(200_000, seed=FLAGS.SEED, reshuffle_each_iteration=False)

    # set up the global batch size
    global_batch_size = FLAGS.BATCH_SIZE * strategy.num_replicas_in_sync

    # train_records
    train = shuffled.take(3_000_000)
    test = shuffled.skip(3_000_000).take(200_000)

    cached_train = train.batch(global_batch_size)
    cached_test = test.batch(global_batch_size * 2).cache()
    logging.info('Training starting')
    model.fit(
        cached_train,
        validation_data=cached_test,
        validation_freq=5,
        callbacks=[tensorboard_cb],
        epochs=FLAGS.NUM_EPOCHS,
        verbose=0)

    # Determine type and task of the machine from the strategy cluster resolver
    task_type, task_id = (strategy.cluster_resolver.task_type,
                          strategy.cluster_resolver.task_id)
    write_model_path = write_filepath(FLAGS.MODEL_DIR, task_type, task_id)

    logging.info('Getting evaluation metrics')
    val_metrics = model.evaluate(cached_test, return_dict=True) #check performance
    logging.info('Validation metrics below:')
    # Log the metrics themselves; print() returns None.
    logging.info('%s', val_metrics)

    # Retrieval layer that wraps the trained query tower.
    # NOTE(review): no candidates are indexed into this layer before it is
    # saved; confirm whether it should first be populated from
    # `parsed_dataset_candidates` and `model.candidate_model`.
    logging.info('Creating trained index')
    index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)

    logging.info(f'Saving model to {write_model_path}')
    tf.saved_model.save(
        index,
        write_model_path
    )
    print('model saved!')
    logging.info('All done - model saved')

# Entry point: hand control to absl, which parses the command-line flags and
# then calls main with the remaining arguments.
if __name__ == "__main__":
    app.run(main)