# Multi-Tower Encoders with TFRS

In [1]:
# Notebook-wide settings: the GCP project passed to the BigQuery read session,
# the BigQuery location, and the global random seed.
PROJECT_ID = "hybrid-vertex"
BQ_LOCATION = "us-central1"
SEED = 41781897

### Pip

In [18]:
# !pip install tensorflow-recommenders -q --user
# !pip install -U tensorflow-io==0.15.0 --user

### Import packages

In [4]:
import warnings
warnings.filterwarnings("ignore") #do this b/c there's an info-level bug that can safely be ignored

from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
from tensorflow_io.bigquery import BigQueryClient
from tensorflow_io.bigquery import BigQueryReadSession

import json
import tensorflow as tf
# import tensorflow_recommenders as tfrs
import datetime
from tensorflow.python.lib.io import file_io
from tensorflow.train import BytesList, Feature, FeatureList, Int64List, FloatList
from tensorflow.train import SequenceExample, FeatureLists

import os

import numpy as np

## Prep Train Data

* Use tensorflow-io to read the training table from BigQuery
* Create a [TF Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#methods_2) object from the read session
* Vocabulary files for the `TextVectorization` and `StringLookup` layers are pre-computed and saved to GCS


See [end-to-end guide BQ -> TF Dataset](https://www.tensorflow.org/io/tutorials/bigquery#load_census_data_in_tensorflow_dataset_using_bigquery_reader) for tips

In [12]:
# ! rm -rf data_prep
# ! mkdir data_prep

In [5]:
# Schema handed to the BigQuery read session: for every selected column, the
# BigQuery field mode and the TensorFlow dtype it is decoded into.
_NULLABLE = BigQueryClient.FieldMode.NULLABLE
_REPEATED = BigQueryClient.FieldMode.REPEATED

# (column name, field mode, output dtype) -- order is preserved in the dict below.
_bq_columns = [
    ('name', _NULLABLE, dtypes.string),
    ('collaborative', _NULLABLE, dtypes.string),
    ('pid', _NULLABLE, dtypes.int64),
    # currently not selected: 'duration_ms_playlist' (int64), 'pid_pos_id' (string)
    # candidate features
    ('pos_can', _NULLABLE, dtypes.int64),
    ('artist_name_can', _NULLABLE, dtypes.string),
    ('track_uri_can', _NULLABLE, dtypes.string),
    ('artist_uri_can', _NULLABLE, dtypes.string),
    ('track_name_can', _NULLABLE, dtypes.string),
    ('album_uri_can', _NULLABLE, dtypes.string),
    ('duration_ms_can', _NULLABLE, dtypes.float64),
    ('album_name_can', _NULLABLE, dtypes.string),
    ('track_pop_can', _NULLABLE, dtypes.float64),
    ('artist_pop_can', _NULLABLE, dtypes.float64),
    ('artist_genres_can', _NULLABLE, dtypes.string),
    ('artist_followers_can', _NULLABLE, dtypes.float64),
    # seed track features
    ('pos_seed_track', _NULLABLE, dtypes.int64),
    ('artist_name_seed_track', _NULLABLE, dtypes.string),
    ('artist_uri_seed_track', _NULLABLE, dtypes.string),
    ('track_name_seed_track', _NULLABLE, dtypes.string),
    ('track_uri_seed_track', _NULLABLE, dtypes.string),
    ('album_name_seed_track', _NULLABLE, dtypes.string),
    ('album_uri_seed_track', _NULLABLE, dtypes.string),
    ('duration_seed_track', _NULLABLE, dtypes.float64),
    ('track_pop_seed_track', _NULLABLE, dtypes.float64),
    ('artist_pop_seed_track', _NULLABLE, dtypes.float64),
    ('artist_genres_seed_track', _NULLABLE, dtypes.string),
    ('artist_followers_seed_track', _NULLABLE, dtypes.float64),
    # playlist features (one value per row)
    ('duration_ms_seed_pl', _NULLABLE, dtypes.float64),
    ('n_songs_pl', _NULLABLE, dtypes.float64),
    ('num_artists_pl', _NULLABLE, dtypes.float64),
    ('num_albums_pl', _NULLABLE, dtypes.float64),
    ('description_pl', _NULLABLE, dtypes.string),
    # playlist features (REPEATED columns)
    # currently not selected: 'pos_pl' (REPEATED int64)
    ('artist_name_pl', _REPEATED, dtypes.string),
    ('track_uri_pl', _REPEATED, dtypes.string),
    ('track_name_pl', _REPEATED, dtypes.string),
    ('duration_ms_songs_pl', _REPEATED, dtypes.float64),
    ('album_name_pl', _REPEATED, dtypes.string),
    ('artist_pop_pl', _REPEATED, dtypes.float64),
    ('artists_followers_pl', _REPEATED, dtypes.float64),
    ('track_pop_pl', _REPEATED, dtypes.int64),
    ('artist_genres_pl', _REPEATED, dtypes.string),
]

bq_2_tf_dict = {
    column: {'mode': field_mode, 'output_type': output_type}
    for column, field_mode, output_type in _bq_columns
}

In [6]:
BQ_TABLE_TRAIN = 'train_flatten'
BQ_DATASET_TRAIN = 'spotify_train_3'
io_batch_size = 1

bq_client = BigQueryClient()

# Open a read session on the training table, selecting the columns (and their
# modes / dtypes) declared in `bq_2_tf_dict`.
bqsession = bq_client.read_session(
    "projects/" + PROJECT_ID,
    PROJECT_ID,
    BQ_TABLE_TRAIN,
    BQ_DATASET_TRAIN,
    bq_2_tf_dict,
    requested_streams=2,
)

dataset = bqsession.parallel_read_rows()

# Shuffle with the notebook seed so the example order is reproducible, and keep
# `prefetch` as the final transformation so that preparing the next batch
# overlaps with the consumer of the current one.
# NOTE(review): the REPEATED columns presumably vary in length from row to row;
# if so, io_batch_size > 1 will need `padded_batch` (or ragged batching) -- confirm.
dataset = (
    dataset
    .shuffle(io_batch_size * 10, seed=SEED)
    .batch(io_batch_size)
    .prefetch(1)
)

2022-07-05 20:48:29.884955: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64
2022-07-05 20:48:29.884995: W tensorflow/stream_executor/cuda/cuda_driver.cc:312] failed call to cuInit: UNKNOWN ERROR (303)
2022-07-05 20:48:29.885017: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (spotify-tf-2-9-jt-dev): /proc/driver/nvidia/version does not exist
2022-07-05 20:48:29.885379: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-0

# Multi-Tower Model

Create a DNN nominator based on a 3-tower architecture, with `playlist`, `seed_track`, and `candidate_track` towers (TODO: insert diagram)

* In this model, we have **two** kinds of query embeddings: `playlist` and `seed_track`
* Within `playlist`, we include the sequence of tracks in each playlist
* We propose an attention-based playlist model to summarize the track sequence

In [12]:
import pickle as pkl
from google.cloud import storage

BUCKET_NAME = 'spotify-v1'
FILE_PATH = 'vocabs/v1_string_vocabs'
FILE_NAME = 'string_vocabs_v1_20220705-202905.txt'
DESTINATION_FILE = 'downloaded_vocabs.txt'

client = storage.Client()

# Copy the pre-computed vocabulary file from GCS to the local working directory.
vocab_blob_uri = f'gs://{BUCKET_NAME}/{FILE_PATH}/{FILE_NAME}'
with open(DESTINATION_FILE, 'wb') as local_copy:
    client.download_blob_to_file(vocab_blob_uri, local_copy)

# Despite the .txt suffix the file is read as a pickle.
# NOTE(review): pickle.load can execute arbitrary code -- only load this from a
# bucket whose contents are trusted.
with open(DESTINATION_FILE, 'rb') as local_copy:
    vocab_dict_load = pkl.load(local_copy)


In [14]:
# vocab_dict_load

## Playlist Tower
* TODO: Add sequence model for playlist_seed_sequence features

In [None]:
# Hyper-parameters read (as module-level globals) by the tower classes below.
OUTPUT_DIM = 32      # width of every per-feature embedding
PROJECTION_DIM = 5   # projection_dim passed to the optional DCN cross layer
# NOTE(review): this rebinds the SEED from the first configuration cell, so
# anything executed after this cell sees 1234 -- confirm that is intended.
SEED = 1234

# The towers also read the three settings below, which were not defined anywhere
# in the visible notebook; without them, constructing a tower raises NameError.
# NOTE(review): the values are conservative placeholders (cross layer and dropout
# disabled) -- confirm the intended configuration.
USE_CROSS_LAYER = False
DROPOUT = False
DROPOUT_RATE = 0.2   # only used when DROPOUT is truthy

class Playlist_Model(tf.keras.Model):
    """Playlist (query) tower: embeds playlist features into one L2-normalised vector.

    Each feature gets its own embedding sub-model. The ``*_pl`` sequence features
    are embedded per element and summarised by a GRU. All per-feature vectors are
    concatenated on axis 1, optionally passed through a DCN cross layer, and then
    through the dense stack, whose output is L2-normalised.

    The module-level settings ``OUTPUT_DIM``, ``PROJECTION_DIM``, ``SEED``,
    ``USE_CROSS_LAYER``, ``DROPOUT`` and ``DROPOUT_RATE`` are read when the model
    is constructed.

    Args:
        layer_sizes: Sliceable sequence of dense-layer widths. Every layer but
            the last uses ReLU (followed by dropout when ``DROPOUT`` is truthy);
            the last layer is linear.
        vocab_dict: Mapping with the pre-computed vocabularies and min/max
            statistics that ``__init__`` looks up by key.
    """

    # Number of hash buckets for `track_uri_pl`; the embedding table that follows
    # the Hashing layer needs exactly one row per bucket.
    _TRACK_URI_HASH_BINS = 200_000

    def __init__(self, layer_sizes, vocab_dict):
        super().__init__()

        # ========================================
        # non-sequence playlist features
        # ========================================

        # Feature: playlist name
        self.pl_name_text_embedding = self._text_embedding(
            vocab_dict['name'],
            vectorizer_name="pl_name_txt_vectorizer",
            embedding_name="pl_name_emb_layer",
            pooling_name="pl_name_pooling",
            model_name="pl_name_emb_model",
        )

        # Feature: collaborative
        collaborative_vocab = np.array([b'false', b'true'])
        self.pl_collaborative_embedding = self._lookup_embedding(
            collaborative_vocab,
            embedding_name="pl_collaborative_emb_layer",
            model_name="pl_collaborative_emb_model",
            lookup_name="pl_collaborative_lookup",
        )

        # Feature: pid
        # NOTE(review): the BigQuery schema reads `pid` as int64 while this is a
        # StringLookup; confirm `pid` is converted to string upstream (or switch
        # to IntegerLookup if `unique_pids` holds integers).
        self.pl_pid_embedding = self._lookup_embedding(
            vocab_dict['unique_pids'],
            embedding_name="pl_pid_emb_layer",
            model_name="pl_pid_emb_model",
            lookup_name="pl_pid_lookup",
        )

        # Feature: description_pl
        self.pl_description_text_embedding = self._text_embedding(
            vocab_dict['description_pl'],
            vectorizer_name="description_pl_vectorizer",
            embedding_name="description_pl_emb_layer",
            pooling_name="description_pl_pooling",
            model_name="pl_description_emb_model",
        )

        # Feature: duration_ms_seed_pl
        self.duration_ms_seed_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_duration_ms_seed_pl'],
            vocab_dict['max_duration_ms_seed_pl'],
            num_boundaries=1000,
            feature="duration_ms_seed_pl",
        )

        # Feature: n_songs_pl
        self.n_songs_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_n_songs_pl'],
            vocab_dict['max_n_songs_pl'],
            num_boundaries=100,
            feature="n_songs_pl",
        )

        # Feature: num_artists_pl
        self.n_artists_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_n_artists_pl'],
            vocab_dict['max_n_artists_pl'],
            num_boundaries=100,
            feature="n_artists_pl",
        )

        # Feature: num_albums_pl
        self.n_albums_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_n_albums_pl'],
            vocab_dict['max_n_albums_pl'],
            num_boundaries=100,
            feature="n_albums_pl",
        )

        # ========================================
        # sequence playlist features
        # ========================================

        # Feature: pos_pl -- not embedded yet

        # Feature: artist_name_pl
        self.artist_name_pl_embedding = self._lookup_embedding(
            vocab_dict['artist_name_pl'],
            embedding_name="artist_name_pl_emb_layer",
            model_name="artist_name_pl_emb_model",
            recurrent=True,
        )

        # Feature: track_uri_pl
        # Track URIs are hashed rather than looked up, so the embedding table is
        # sized by the number of hash buckets, not by a vocabulary length.
        self.track_uri_pl_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.Hashing(num_bins=self._TRACK_URI_HASH_BINS),
                tf.keras.layers.Embedding(
                    input_dim=self._TRACK_URI_HASH_BINS,
                    output_dim=OUTPUT_DIM,
                    name="track_uri_pl_emb_layer",
                ),
                tf.keras.layers.GRU(OUTPUT_DIM),
            ], name="track_uri_pl_emb_model"
        )

        # Feature: track_name_pl
        self.track_name_pl_embedding = self._lookup_embedding(
            vocab_dict['track_name_pl'],
            embedding_name="track_name_pl_emb_layer",
            model_name="track_name_pl_emb_model",
            recurrent=True,
        )

        # Feature: duration_ms_songs_pl
        self.duration_ms_songs_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_duration_ms_songs_pl'],
            vocab_dict['max_duration_ms_songs_pl'],
            num_boundaries=1000,
            feature="duration_ms_songs_pl",
            recurrent=True,
        )

        # Feature: album_name_pl
        self.album_name_pl_embedding = self._lookup_embedding(
            vocab_dict['album_name_pl'],
            embedding_name="album_name_pl_emb_layer",
            model_name="album_name_pl_emb_model",
            recurrent=True,
        )

        # Feature: artist_pop_pl
        self.artist_pop_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_artist_pop'],
            vocab_dict['max_artist_pop'],
            num_boundaries=10,
            feature="artist_pop_pl",
            recurrent=True,
        )

        # Feature: artists_followers_pl
        self.artists_followers_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_artist_followers'],
            vocab_dict['max_artist_followers'],
            num_boundaries=10,
            feature="artists_followers_pl",
            recurrent=True,
        )

        # Feature: track_pop_pl
        self.track_pop_pl_embedding = self._bucketized_embedding(
            vocab_dict['min_track_pop'],
            vocab_dict['max_track_pop'],
            num_boundaries=10,
            feature="track_pop_pl",
            recurrent=True,
        )

        # Feature: artist_genres_pl
        self.artist_genres_pl_embedding = self._lookup_embedding(
            vocab_dict['artist_genres_pl'],
            embedding_name="artist_genres_pl_emb_layer",
            model_name="artist_genres_pl_emb_model",
            recurrent=True,
        )

        # ========================================
        # Cross Layers
        # ========================================
        if USE_CROSS_LAYER:
            # Imported here because the notebook-level import of
            # tensorflow_recommenders is commented out; the package is only
            # required when the cross layer is enabled.
            import tensorflow_recommenders as tfrs

            self._cross_layer = tfrs.layers.dcn.Cross(
                projection_dim=PROJECTION_DIM,
                kernel_initializer="glorot_uniform",
                name="pl_cross_layer"
            )
        else:
            self._cross_layer = None

        # ========================================
        # Dense Layers
        # ========================================
        self.dense_layers = tf.keras.Sequential(name="pl_dense_layers")
        initializer = tf.keras.initializers.GlorotUniform(seed=SEED)

        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(
                tf.keras.layers.Dense(
                    layer_size,
                    activation="relu",
                    kernel_initializer=initializer,
                )
            )
            if DROPOUT:
                self.dense_layers.add(tf.keras.layers.Dropout(DROPOUT_RATE))

        # No activation for the last layer (slicing keeps an empty
        # `layer_sizes` from raising).
        for layer_size in layer_sizes[-1:]:
            self.dense_layers.add(
                tf.keras.layers.Dense(
                    layer_size,
                    kernel_initializer=initializer
                )
            )

        # L2-normalise the tower output along axis 1.
        self.dense_layers.add(
            tf.keras.layers.Lambda(
                lambda x: tf.nn.l2_normalize(
                    x, 1, epsilon=1e-12, name="normalize_dense"
                )
            )
        )

    @staticmethod
    def _text_embedding(vocab, vectorizer_name, embedding_name, pooling_name, model_name):
        """Build bigram TextVectorization -> Embedding -> average pooling."""
        # TextVectorization reserves index 0 (padding) and index 1 (OOV) on top
        # of the supplied vocabulary, so both the token cap and the embedding
        # table need two extra slots. If `vocab` already starts with those two
        # tokens the extra rows are simply unused.
        table_size = len(vocab) + 2
        return tf.keras.Sequential(
            [
                tf.keras.layers.TextVectorization(
                    max_tokens=table_size,
                    vocabulary=vocab,
                    name=vectorizer_name,
                    ngrams=2,
                ),
                tf.keras.layers.Embedding(
                    input_dim=table_size,
                    output_dim=OUTPUT_DIM,
                    mask_zero=False,
                    name=embedding_name,
                ),
                tf.keras.layers.GlobalAveragePooling1D(name=pooling_name),
            ], name=model_name
        )

    @staticmethod
    def _lookup_embedding(vocab, embedding_name, model_name, lookup_name=None, recurrent=False):
        """Build StringLookup -> Embedding, optionally followed by a GRU."""
        layers = [
            tf.keras.layers.StringLookup(
                vocabulary=vocab,
                mask_token=None,
                name=lookup_name,
            ),
            tf.keras.layers.Embedding(
                # +1: with mask_token=None the lookup still emits one OOV index.
                input_dim=len(vocab) + 1,
                output_dim=OUTPUT_DIM,
                name=embedding_name,
            ),
        ]
        if recurrent:
            layers.append(tf.keras.layers.GRU(OUTPUT_DIM))
        return tf.keras.Sequential(layers, name=model_name)

    @staticmethod
    def _bucketized_embedding(min_value, max_value, num_boundaries, feature, recurrent=False):
        """Build Discretization -> Embedding, optionally followed by a GRU."""
        boundaries = np.linspace(min_value, max_value, num=num_boundaries)
        layers = [
            tf.keras.layers.Discretization(boundaries.tolist()),
            tf.keras.layers.Embedding(
                # N boundaries split the number line into N + 1 buckets.
                input_dim=len(boundaries) + 1,
                output_dim=OUTPUT_DIM,
                name=f"{feature}_emb_layer",
            ),
        ]
        if recurrent:
            layers.append(tf.keras.layers.GRU(OUTPUT_DIM))
        return tf.keras.Sequential(layers, name=f"{feature}_emb_model")

    def call(self, data):
        '''
        Embed every playlist feature, concatenate the results on axis 1 and
        run them through the (optional) cross layer and the dense stack.

        `data` is indexed by the feature names used below.
        '''

        all_embs = tf.concat(
            [
                self.pl_name_text_embedding(data['name']),
                self.pl_collaborative_embedding(data['collaborative']),
                self.pl_pid_embedding(data["pid"]),
                # data["pid_pos_id"],
                self.pl_description_text_embedding(data['description_pl']),
                self.duration_ms_seed_pl_embedding(data["duration_ms_seed_pl"]),
                self.n_songs_pl_embedding(data["n_songs_pl"]),
                self.n_artists_pl_embedding(data["num_artists_pl"]),
                self.n_albums_pl_embedding(data["num_albums_pl"]),
                # sequence features
                # data["pos_pl"],
                self.artist_name_pl_embedding(data["artist_name_pl"]),
                self.track_uri_pl_embedding(data["track_uri_pl"]),
                self.track_name_pl_embedding(data["track_name_pl"]),
                self.duration_ms_songs_pl_embedding(data["duration_ms_songs_pl"]),
                self.album_name_pl_embedding(data["album_name_pl"]),
                self.artist_pop_pl_embedding(data["artist_pop_pl"]),
                self.artists_followers_pl_embedding(data["artists_followers_pl"]),
                self.track_pop_pl_embedding(data["track_pop_pl"]),
                self.artist_genres_pl_embedding(data["artist_genres_pl"]),
            ], axis=1)

        # Build Cross Network
        if self._cross_layer is not None:
            cross_embs = self._cross_layer(all_embs)
            return self.dense_layers(cross_embs)
        else:
            return self.dense_layers(all_embs)

## Seed Track Tower

In [None]:
class Seed_Track_Model(tf.keras.Model):
    """Seed-track (query) tower: embeds seed-track features into one L2-normalised vector.

    Text features go through bigram TextVectorization, an embedding and average
    pooling; URI features go through StringLookup and an embedding; numeric
    features go through Normalization layers. The pieces are concatenated on
    axis 1, optionally passed through a DCN cross layer, and then through the
    dense stack, whose output is L2-normalised.

    The module-level settings ``OUTPUT_DIM``, ``PROJECTION_DIM``, ``SEED``,
    ``USE_CROSS_LAYER``, ``DROPOUT`` and ``DROPOUT_RATE`` are read when the model
    is constructed.

    Args:
        layer_sizes: Sliceable sequence of dense-layer widths. Every layer but
            the last uses ReLU (followed by dropout when ``DROPOUT`` is truthy);
            the last layer is linear.
        vocab_dict: Mapping with the pre-computed vocabularies that ``__init__``
            looks up by key.
    """

    def __init__(self, layer_sizes, vocab_dict):
        super().__init__()

        # ========================================
        # seed track features
        # ========================================

        # Feature: pos_seed_track -- no layer; `call` feeds the raw value in.

        # Feature: artist_name_seed_track
        self.artist_name_seed_track_text_embedding = self._text_embedding(
            vocab_dict["artist_name_seed_track"], "artist_name_seed_track"
        )

        # Feature: artist_uri_seed_track
        self.artist_uri_seed_track_embedding = self._uri_embedding(
            vocab_dict['artist_uri_seed_track'], "artist_uri_seed_track"
        )

        # Feature: track_name_seed_track
        self.track_name_seed_track_text_embedding = self._text_embedding(
            vocab_dict["track_name_seed_track"], "track_name_seed_track"
        )

        # Feature: track_uri_seed_track
        self.track_uri_seed_track_embedding = self._uri_embedding(
            vocab_dict['track_uri_seed_track'], "track_uri_seed_track"
        )

        # Feature: album_name_seed_track
        self.album_name_seed_track_text_embedding = self._text_embedding(
            vocab_dict["album_name_seed_track"], "album_name_seed_track"
        )

        # Feature: album_uri_seed_track
        self.album_uri_seed_track_embedding = self._uri_embedding(
            vocab_dict['album_uri_seed_track'], "album_uri_seed_track"
        )

        # NOTE(review): the Normalization layers below are created without
        # mean/variance and nothing in the visible notebook calls adapt() on
        # them -- confirm they are adapted before training.

        # Feature: duration_seed_track
        self.duration_seed_track_normalized = tf.keras.layers.Normalization(axis=None)

        # Feature: track_pop_seed_track
        self.track_pop_seed_track_normalized = tf.keras.layers.Normalization(axis=None)

        # Feature: artist_pop_seed_track
        self.artist_pop_seed_track_normalized = tf.keras.layers.Normalization(axis=None)

        # Feature: artist_genres_seed_track
        self.artist_genres_seed_track_text_embedding = self._text_embedding(
            vocab_dict["artist_genres_seed_track"], "artist_genres_seed_track"
        )

        # Feature: artist_followers_seed_track
        self.artist_followers_seed_track_normalized = tf.keras.layers.Normalization(axis=None)

        # ========================================
        # seed track cross_layers
        # ========================================
        if USE_CROSS_LAYER:
            # Imported here because the notebook-level import of
            # tensorflow_recommenders is commented out; the package is only
            # required when the cross layer is enabled.
            import tensorflow_recommenders as tfrs

            self._cross_layer = tfrs.layers.dcn.Cross(
                projection_dim=PROJECTION_DIM,
                kernel_initializer="glorot_uniform",
                name="seed_track_cross_layer"
            )
        else:
            self._cross_layer = None

        # ========================================
        # Dense Layers
        # ========================================
        self.dense_layers = tf.keras.Sequential(name="seed_track_dense_layers")
        initializer = tf.keras.initializers.GlorotUniform(seed=SEED)

        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(
                tf.keras.layers.Dense(
                    layer_size,
                    activation="relu",
                    kernel_initializer=initializer,
                )
            )
            if DROPOUT:
                self.dense_layers.add(tf.keras.layers.Dropout(DROPOUT_RATE))

        # No activation for the last layer (slicing keeps an empty
        # `layer_sizes` from raising).
        for layer_size in layer_sizes[-1:]:
            self.dense_layers.add(
                tf.keras.layers.Dense(
                    layer_size,
                    kernel_initializer=initializer
                )
            )

        # L2-normalise the tower output along axis 1.
        self.dense_layers.add(
            tf.keras.layers.Lambda(
                lambda x: tf.nn.l2_normalize(
                    x, 1, epsilon=1e-12, name="normalize_dense"
                )
            )
        )

    @staticmethod
    def _text_embedding(vocab, feature):
        """Build bigram TextVectorization -> Embedding -> average pooling for `feature`."""
        # TextVectorization reserves index 0 (padding) and index 1 (OOV) on top
        # of the supplied vocabulary, so both the token cap and the embedding
        # table need two extra slots. If `vocab` already starts with those two
        # tokens the extra rows are simply unused.
        table_size = len(vocab) + 2
        return tf.keras.Sequential(
            [
                tf.keras.layers.TextVectorization(
                    max_tokens=table_size,
                    vocabulary=vocab,
                    name=f"{feature}_txt_vectorizer",
                    ngrams=2,
                ),
                tf.keras.layers.Embedding(
                    input_dim=table_size,
                    output_dim=OUTPUT_DIM,
                    mask_zero=False,
                    name=f"{feature}_emb_layer",
                ),
                tf.keras.layers.GlobalAveragePooling1D(name=f"{feature}_pooling"),
            ], name=f"{feature}_emb_model"
        )

    @staticmethod
    def _uri_embedding(vocab, feature):
        """Build StringLookup -> Embedding for the identifier column `feature`."""
        return tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(vocabulary=vocab, mask_token=None),
                tf.keras.layers.Embedding(
                    # +1: with mask_token=None the lookup still emits one OOV index.
                    input_dim=len(vocab) + 1,
                    output_dim=OUTPUT_DIM,
                    name=f"{feature}_emb_layer",
                ),
            ], name=f"{feature}_emb_model"
        )

    def call(self, data):
        '''
        Embed every seed-track feature, concatenate the results on axis 1 and
        run them through the (optional) cross layer and the dense stack.

        `data` is indexed by the feature names used below.
        '''
        artist_name_emb = self.artist_name_seed_track_text_embedding(
            data['artist_name_seed_track']
        )

        # tf.concat needs a common dtype and rank: cast the integer position to
        # the embeddings' dtype and give it an explicit feature axis.
        position = tf.reshape(
            tf.cast(data['pos_seed_track'], artist_name_emb.dtype), (-1, 1)
        )

        all_embs = tf.concat(
            [
                position,
                artist_name_emb,
                self.artist_uri_seed_track_embedding(data['artist_uri_seed_track']),
                self.track_name_seed_track_text_embedding(data['track_name_seed_track']),
                self.track_uri_seed_track_embedding(data["track_uri_seed_track"]),
                self.album_name_seed_track_text_embedding(data["album_name_seed_track"]),
                self.album_uri_seed_track_embedding(data["album_uri_seed_track"]),
                # Each numeric feature goes through its own normalizer.
                tf.reshape(self.duration_seed_track_normalized(data["duration_seed_track"]), (-1, 1)),
                tf.reshape(self.track_pop_seed_track_normalized(data["track_pop_seed_track"]), (-1, 1)),
                tf.reshape(self.artist_pop_seed_track_normalized(data["artist_pop_seed_track"]), (-1, 1)),
                self.artist_genres_seed_track_text_embedding(data["artist_genres_seed_track"]),
                tf.reshape(self.artist_followers_seed_track_normalized(data["artist_followers_seed_track"]), (-1, 1)),
            ], axis=1)

        # Build Cross Network
        if self._cross_layer is not None:
            cross_embs = self._cross_layer(all_embs)
            return self.dense_layers(cross_embs)
        else:
            return self.dense_layers(all_embs)

## Candidate Track Tower

In [None]:
class Candidate_Track_Model(tf.keras.Model):
    """Candidate-track tower (work in progress).

    Only the ``artist_name_can`` and ``track_uri_can`` embedding sub-models
    are built so far. The remaining candidate features are listed as
    placeholders, and neither a dense stack nor a ``call`` method is defined
    yet.

    Args:
        layer_sizes: Accepted for parity with the other towers; not used by
            this constructor yet.
        vocab_dict: Mapping of feature name to its vocabulary. The
            ``"artist_name_can"`` and ``"track_uri_can"`` entries are read.
    """

    def __init__(self, layer_sizes, vocab_dict):
        super().__init__()

        # ========================================
        # candidate track features
        # ========================================

        # Feature: pos_can

        # Feature: artist_name_can
        artist_name_vocab = vocab_dict["artist_name_can"]
        self.artist_name_can_text_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.TextVectorization(
                    max_tokens=len(artist_name_vocab),
                    vocabulary=artist_name_vocab,
                    name="artist_name_can_txt_vectorizer",
                    ngrams=2,
                ),
                tf.keras.layers.Embedding(
                    # input_dim must be an integer vocabulary size, not the
                    # vocabulary itself; +1 mirrors the track_uri_can
                    # embedding below.
                    input_dim=len(artist_name_vocab) + 1,
                    output_dim=OUTPUT_DIM,
                    mask_zero=False,
                    name="artist_name_can_emb_layer",
                ),
                # Averages the per-token embeddings into one vector per example.
                tf.keras.layers.GlobalAveragePooling1D(name="artist_name_can_pooling"),
            ], name="artist_name_can_emb_model"
        )

        # Feature: track_uri_can
        track_uri_vocab = vocab_dict['track_uri_can']
        self.track_uri_can_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(
                    vocabulary=track_uri_vocab, mask_token=None),
                tf.keras.layers.Embedding(
                    # +1 leaves room for the lookup layer's out-of-vocabulary index.
                    input_dim=len(track_uri_vocab) + 1,
                    output_dim=OUTPUT_DIM,
                    name="track_uri_can_emb_layer",
                ),
            ], name="track_uri_can_emb_model"
        )

        # TODO: the candidate features below still need their layers.

        # Feature: artist_uri_can

        # Feature: track_name_can

        # Feature: album_uri_can

        # Feature: duration_ms_can

        # Feature: album_name_can

        # Feature: track_pop_can

        # Feature: artist_pop_can

        # Feature: artist_genres_can

        # Feature: artist_followers_can

## Combined Model

In [None]:
# context_features_ = {
#     # playlist features
#     "name": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "collaborative": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "pid": tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
#     "duration_ms_seed_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "n_songs_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "num_artists_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "num_albums_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "description_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),

#     # candidate features
#     "pos_can": tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
#     "artist_name_can": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "track_uri_can": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "artist_uri_can": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "track_name_can": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "album_uri_can": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "duration_ms_can": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "album_name_can": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "track_pop_can": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "artist_pop_can": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "artist_genres_can": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "artist_followers_can": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),

#     # seed track features
#     "pos_seed_track": tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
#     "artist_name_seed_track": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "artist_uri_seed_track": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "track_name_seed_track": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "track_uri_seed_track": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "album_name_seed_track": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "album_uri_seed_track": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "duration_seed_track": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "track_pop_seed_track": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "artist_pop_seed_track": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
#     "artist_genres_seed_track": tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
#     "artist_followers_seed_track": tf.io.FixedLenFeature(dtype=tf.float32, shape=(1)),
# }

# sequence_features_ = {
#     "pos_pl": tf.io.RaggedFeature(tf.int64),
#     "artist_name_pl": tf.io.RaggedFeature(tf.string),
#     "track_uri_pl": tf.io.RaggedFeature(tf.string),
#     "track_name_pl": tf.io.RaggedFeature(tf.string),
#     "duration_ms_songs_pl": tf.io.RaggedFeature(tf.float32),
#     "album_name_pl": tf.io.RaggedFeature(tf.string),
#     "artist_pop_pl": tf.io.RaggedFeature(tf.float32),
#     "artists_followers_pl": tf.io.RaggedFeature(tf.float32),
#     "track_pop_pl": tf.io.RaggedFeature(tf.float32),
#     "artist_genres_pl": tf.io.RaggedFeature(tf.string),
# }

#### References
* [Multi-head and torso: Dugtrio code](https://source.corp.google.com/piper///depot/google3/learning/brain/models/recommendations/dugtrio/torso/multi_head_torso.py)
* [CPU-optimized sparse implementation of attention](https://source.corp.google.com/piper///depot/google3/ads/ubaq/brain_embeddings/tensorflow/layers.py;l=1234;rcl=202730546?q=layers.py%20f:brain_embeddings&dr=)

#### Notes
* `tf.feature_column.sequence_categorical_column_with_identity` - "Pass this to embedding_column or indicator_column to convert sequence categorical data into dense representation for input to sequence NN, such as RNN." [docs](https://www.tensorflow.org/api_docs/python/tf/feature_column/sequence_categorical_column_with_identity)
* Implementing a sequential model with Two Tower [docs](https://www.tensorflow.org/recommenders/examples/sequential_retrieval#implementing_a_sequential_model)