# Build baseline tfrs model 

See `./two_tower_src/` for the data-parsing and model source code used in this notebook.

In [1]:
# Project identifier for this notebook (presumably the Google Cloud project — confirm before running).
PROJECT_ID = 'hybrid-vertex'  # <--- TODO: CHANGE THIS
# Location string; not referenced by the cells visible in this section — TODO confirm where it is consumed.
LOCATION = 'us-central1' 

In [2]:
# !pip install tensorflow-recommenders==0.6.0 --user

In [3]:
import json
import tensorflow as tf
import tensorflow_recommenders as tfrs

from google.cloud import storage

import numpy as np
import pickle as pkl
from pprint import pprint

from two_tower_src import two_tower as tt

2022-09-26 00:00:55.675504: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-26 00:00:56.305055: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38238 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


## Create Dataset for local training and testing

### Playlist dataset

In [4]:
# Get the feature dictionaries from the src module
cont_feats = tt.cont_feats
seq_feats = tt.seq_feats
candidate_features = tt.candidate_features

client = storage.Client()
# # gs://spotify-beam-v3/v3/candidates/*.tfrecords

BUCKET = 'spotify-beam-v3'
TRAIN_PREFIX = 'v3/train/'
VALID_PREFIX = 'v3/valid/'
# Kept so that any cell still reading the old name sees the same final value as before.
CANDIDATE_PREFIX = VALID_PREFIX


def list_gcs_uris(gcs_client, bucket_name, prefix):
    """Return the ``gs://`` URIs of the objects directly under ``prefix``.

    Args:
        gcs_client: a ``google.cloud.storage.Client``.
        bucket_name: name of the bucket to list.
        prefix: object-name prefix; ``delimiter="/"`` restricts the listing
            to objects directly under it (no nested "sub-directories").

    Returns:
        A list of ``gs://bucket/object`` strings, in listing order.
    """
    # Build the URI from the raw object name: ``blob.public_url`` percent-encodes
    # the name, so string-replacing its host yields a wrong path for any object
    # whose name contains characters that need quoting.
    return [
        f"gs://{bucket_name}/{blob.name}"
        for blob in gcs_client.list_blobs(bucket_name, prefix=prefix, delimiter="/")
        # Names ending in "/" are folder placeholders, not record files.
        if not blob.name.endswith("/")
    ]


# Training records
train_files = list_gcs_uris(client, BUCKET, TRAIN_PREFIX)
raw_dataset = tf.data.TFRecordDataset(train_files)

# Validation records
valid_files = list_gcs_uris(client, BUCKET, VALID_PREFIX)
raw_dataset_valid = tf.data.TFRecordDataset(valid_files)

In [5]:
# Parse every raw TFRecord with tt.parse_tfrecord (train and validation use the same parser)
parsed_dataset = raw_dataset.map(tt.parse_tfrecord)
parsed_dataset_valid = raw_dataset_valid.map(tt.parse_tfrecord)

# Pad the ragged tensors to the maximum playlist length (see tt.return_padded_tensors)
parsed_dataset_padded = parsed_dataset.map(tt.return_padded_tensors)   
parsed_dataset_padded_valid = parsed_dataset_valid.map(tt.return_padded_tensors)   

In [6]:
# for features in parsed_dataset_padded_valid.skip(3).take(1):
#     pprint(features)
#     print("_______________")

### Candidate Track dataset

In [7]:
# Peek at a single parsed candidate-track example to check feature names, dtypes and shapes
for features in tt.parsed_candidate_dataset.take(1):
    pprint(features)
    print("_______________")

{'album_name_can': <tf.Tensor: shape=(), dtype=string, numpy=b'The Sound of Everything Rmx'>,
 'album_uri_can': <tf.Tensor: shape=(), dtype=string, numpy=b'spotify:album:4a8tMD6qq6GUuUwNae38VI'>,
 'artist_followers_can': <tf.Tensor: shape=(), dtype=float32, numpy=277649.0>,
 'artist_genres_can': <tf.Tensor: shape=(), dtype=string, numpy=b"'downtempo', 'electronica', 'funk', 'latin alternative', 'nu jazz', 'nu-cumbia', 'trip hop', 'world'">,
 'artist_name_can': <tf.Tensor: shape=(), dtype=string, numpy=b'Quantic'>,
 'artist_pop_can': <tf.Tensor: shape=(), dtype=float32, numpy=64.0>,
 'artist_uri_can': <tf.Tensor: shape=(), dtype=string, numpy=b'spotify:artist:5ZMwoAjeDtLJ0XRwRTgaK8'>,
 'duration_ms_can': <tf.Tensor: shape=(), dtype=float32, numpy=267130.0>,
 'track_name_can': <tf.Tensor: shape=(), dtype=string, numpy=b'The Sound of Everything - Watch TV & Se\xc3\xb1orlobo Remix'>,
 'track_pop_can': <tf.Tensor: shape=(), dtype=float32, numpy=53.0>,
 'track_uri_can': <tf.Tensor: shape=(),

# Two-Tower Model Testing and Demonstration

## Playlist (query) Tower

### Test playlist tower

In [7]:
# Layer sizes handed to the tower (presumably the dense-layer widths; the test output
# below has 32 columns, matching the last entry — TODO confirm in two_tower_src)
layer_sizes=[64,32]

# Build the playlist (query) tower with the pre-computed vocabularies from the src module
test_playlist_model = tt.Playlist_Model(layer_sizes,tt.vocab_dict_load)

# Optional: re-adapt the playlist-name text vectorizer on the training data instead of
# relying on the loaded vocabulary
# test_playlist_model.pl_name_text_embedding.layers[0].adapt(parsed_dataset_padded.map(lambda x: x['name']).batch(1000))

# print("Adapts complete for name")

#### Some examples below of how to access an adapted vocabulary and set to a new key in the vocab_dict

In [8]:
# test_playlist_model.pl_name_text_embedding.layers[0].get_vocabulary()
# vocab_dict_load['name_voacb'] = test_playlist_model.pl_name_text_embedding.layers[0].get_vocabulary()

In [9]:
# Test with the source batched dataset: run one batch of two playlists through the tower
batched_dataset = parsed_dataset_padded.batch(2)
for x in batched_dataset.take(1):
    print(test_playlist_model(x))

tf.Tensor(
[[-0.24514134 -0.28178468 -0.14416333 -0.4173075  -0.20545025  0.03124828
  -0.01103428  0.17916968  0.12727895  0.20036177 -0.0101915  -0.01609189
   0.10526141  0.24144886  0.07728486 -0.32014552 -0.0897892  -0.06304988
   0.2856185  -0.12049487 -0.22993435  0.15289545  0.08797454 -0.20615126
  -0.0870543  -0.19611506 -0.1266518  -0.06482445 -0.12822834 -0.09181462
   0.13239121 -0.11096024]
 [-0.14029537  0.06033436  0.09287861 -0.28724906 -0.09712798  0.5612625
   0.20269902  0.06799392 -0.00701721 -0.10050222  0.19178475  0.0269598
  -0.04040827  0.10064789 -0.080539   -0.40841147 -0.05291994 -0.25351888
   0.12378392  0.08894099 -0.25947258  0.05428258 -0.23103    -0.13233204
  -0.0062121  -0.01586104  0.00975485  0.0135821  -0.01611765  0.13543643
  -0.1688953   0.05923532]], shape=(2, 32), dtype=float32)


2022-09-25 23:43:54.131216: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


In [11]:
# Remove the name binding for the playlist test tower (presumably to release its memory — TODO confirm)
del test_playlist_model

In [27]:
# `test_playlist_model` is deleted by the preceding cell, so running the notebook
# top to bottom would raise NameError here. Re-create the tower when it is missing.
if "test_playlist_model" not in globals():
    test_playlist_model = tt.Playlist_Model(layer_sizes, tt.vocab_dict_load)
    # Call the model once on a small batch so it is built before summary() is asked
    # for output shapes and parameter counts.
    for playlist_batch in parsed_dataset_padded.batch(2).take(1):
        test_playlist_model(playlist_batch)

test_playlist_model.summary(expand_nested=True)

Model: "playlist__model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 pl_name_emb_model (Sequenti  (None, 32)               2368928   
 al)                                                             
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| pl_name_txt_vectorizer (Tex  (None, None)           0         |
| tVectorization)                                               |
|                                                               |
| pl_name_emb_layer (Embeddin  (None, None, 32)       2368928   |
| g)                                                            |
|                                                               |
| pl_name_pooling (GlobalAver  (None, 32)             0         |
| agePooling1D)                                                 |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
 pl_collaborative_emb_model   (10, 32)             

## Track (candidate) Tower 

### Candidate model test
Note: the `adapt` calls are run on the unique candidate dataset (rather than the full training data) to save time.

In [13]:
# Same layer sizes as the playlist tower test, so both towers emit vectors of equal width
layer_sizes=[64,32]

# Build the track (candidate) tower with the pre-computed vocabularies from the src module
test_can_track_model = tt.Candidate_Track_Model(layer_sizes, tt.vocab_dict_load)

# Optional adapts: re-fit each text vectorizer on the candidate dataset
# (left disabled; the loaded vocabularies are used instead)
# test_can_track_model.artist_name_can_text_embedding.layers[0].adapt(parsed_candidate_dataset.map(lambda x: x['artist_name_can']).batch(1000))
# test_can_track_model.track_name_can_text_embedding.layers[0].adapt(parsed_candidate_dataset.map(lambda x: x['track_name_can']).batch(1000))
# test_can_track_model.album_name_can_text_embedding.layers[0].adapt(parsed_candidate_dataset.map(lambda x: x['album_name_can']).batch(1000))
# test_can_track_model.artist_genres_can_text_embedding.layers[0].adapt(parsed_candidate_dataset.map(lambda x: x['artist_genres_can']).batch(1000))

# can_result = test_can_track_model([candidate_test_instance])

# print(f"Shape of can_result: {can_result.shape}")
# can_result

## Validate the output in a small batch 

In [15]:
# Test with the candidate dataset: run one batch of two tracks through the candidate tower.
# NOTE: this rebinds `batched_dataset`, which previously held the playlist batches.
batched_dataset = tt.parsed_candidate_dataset.batch(2)
for x in batched_dataset.take(1):
    print(test_can_track_model(x))

tf.Tensor(
[[-0.11201185 -0.206755   -0.18045759  0.07737618  0.03264336 -0.33545116
  -0.1591977   0.19581121 -0.06903844  0.01181059  0.4259455   0.15603209
  -0.18393841 -0.16382658 -0.18161613 -0.21947345  0.48698536 -0.02518937
   0.12081616 -0.4297281   0.2016943  -0.52300894  0.04528558  0.40296668
  -0.42080963  0.01119863  0.11597577  0.05094406  0.05474314  0.18036497
   0.07655917 -0.1995453 ]
 [ 0.08474436 -0.2626423  -0.15325345  0.02726585  0.07601295  0.00363033
   0.14329386  0.07862413 -0.00256308  0.21880352  0.32578433 -0.07587124
  -0.09890726  0.00391698  0.05119768 -0.20095554 -0.04316837 -0.05883677
   0.14658357 -0.17711747 -0.39837694 -0.44184572 -0.26477706  0.19551444
  -0.27569064 -0.00155214 -0.30036432  0.03524041 -0.19633287  0.24641578
   0.7029978   0.04300903]], shape=(2, 32), dtype=float32)


In [None]:
# test_can_track_model.summary(expand_nested=True)

## Combined Model

# Local Training

In [8]:
# Layer sizes shared by both towers of the combined model
layer_sizes=[64,32]

# Combined two-tower retrieval model defined in two_tower_src
model = tt.TheTwoTowers(layer_sizes)

# Adagrad with a learning rate of 0.01; no loss is passed to compile() here
# (presumably the model computes its own loss — TODO confirm in two_tower_src)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.01))

In [9]:
## Quick look at the layers
# Lists the index and name of every layer held by the query tower
print("Playlist (query) Tower:")

for i, l in enumerate(model.query_tower.layers):
    print(i, l.name)

Playlist (query) Tower:
0 pl_name_emb_model
1 pl_collaborative_emb_model
2 pl_track_uri_emb_model
3 n_songs_pl_emb_model
4 n_artists_pl_emb_model
5 n_albums_pl_emb_model
6 artist_name_pl_emb_model
7 track_uri_pl_emb_model
8 track_name_pl_emb_model
9 duration_ms_songs_pl_emb_model
10 album_name_pl_emb_model
11 artist_pop_pl_emb_model
12 artists_followers_pl_emb_model
13 track_pop_pl_emb_model
14 artist_genres_pl_emb_model
15 pl_cross_layer
16 pl_dense_layers


In [10]:
# Lists the index and name of every layer held by the candidate tower
print("Track (candidate) Tower:")
for i, l in enumerate(model.candidate_tower.layers):
    print(i, l.name)

Track (candidate) Tower:
0 artist_name_can_emb_model
1 track_name_can_emb_model
2 album_name_can_emb_model
3 artist_uri_can_emb_model
4 track_uri_can_emb_model
5 album_uri_can_emb_model
6 normalization
7 normalization_1
8 normalization_2
9 normalization_3
10 artist_genres_can_emb_model
11 can_cross_layer
12 candidate_dense_layers


### Set the shuffle buffer and batch size for train and validate
The train/validation split occurred in BQ via a random digit; see the BQ notebook for details.

In [64]:
# Fix the global TensorFlow seed so the run is repeatable
tf.random.set_seed(42)
# Shuffle with a 10,000-element buffer; reshuffle_each_iteration=False keeps the same
# shuffled order on every pass over the dataset
shuffle_train = parsed_dataset_padded.shuffle(10_000, seed=42, reshuffle_each_iteration=False)

# Batch size 2048 for both splits; the validation data is not shuffled
TRAIN = shuffle_train.batch(2048)
VALID = parsed_dataset_padded_valid.batch(2048)

In [66]:
%%timeit -n 1 -r 1


NUM_EPOCHS = 1
start_time = time.time()
layer_history = model.fit(
    TRAIN,
    validation_data=VALID,
    # validation_freq=5,
    epochs=NUM_EPOCHS,
    # callbacks=tensorboard_cb,
    # verbose=0
)

print(f"Training for {NUM_EPOCHS} epoch")
accuracy = layer_history.history["factorized_top_k/top_100_categorical_accuracy"][-1]
print(f"Top 100 categorical accuracy: {accuracy}")

   8230/Unknown - 5542s 673ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 1990.2237 - regularization_loss: 0.0000e+00 - total_loss: 1990.2237

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



  10669/Unknown - 7236s 678ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 1798.4900 - regularization_loss: 0.0000e+00 - total_loss: 1798.4900

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



  16400/Unknown - 11026s 672ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 1512.1358 - regularization_loss: 0.0000e+00 - total_loss: 1512.1358

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



  21107/Unknown - 14224s 674ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 1361.3310 - regularization_loss: 0.0000e+00 - total_loss: 1361.3310

NotFoundError: Graph execution error:

Detected at node 'IteratorGetNext' defined at (most recent call last):
    File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/conda/lib/python3.7/site-packages/traitlets/config/application.py", line 978, in launch_instance
      app.start()
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/opt/conda/lib/python3.7/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/conda/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
      self._run_once()
    File "/opt/conda/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
      handle._run()
    File "/opt/conda/lib/python3.7/asyncio/events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/ipkernel.py", line 387, in do_execute
      cell_id=cell_id,
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
      raw_cell, store_history, silent, shell_futures, cell_id
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3029, in _run_cell
      return runner(coro)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
      interactivity=interactivity, compiler=compiler, result=result)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3472, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3552, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_28514/124030209.py", line 1, in <module>
      get_ipython().run_cell_magic('timeit', '', '\nNUM_EPOCHS = 1\nlayer_history = model.fit(\n    TRAIN,\n    validation_data=VALID,\n    # validation_freq=5,\n    epochs=NUM_EPOCHS,\n    steps_per_epoch=30,\n    # callbacks=tensorboard_cb,\n    # verbose=1\n)\n\nprint(f"Training for {NUM_EPOCHS} epoch")\naccuracy = layer_history.history["factorized_top_k/top_100_categorical_accuracy"][-1]\nprint(f"Top 100 categorical accuracy: {accuracy}")\n')
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2472, in run_cell_magic
      result = fn(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/decorator.py", line 232, in fun
      return caller(func, *(extras + args), **kw)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/magic.py", line 187, in <lambda>
      call = lambda f, *a, **k: f(*a, **k)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/magics/execution.py", line 1180, in timeit
      time_number = timer.timeit(number)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/magics/execution.py", line 169, in timeit
      timing = self.inner(it, self.timer)
    File "<magic-timeit>", line 7, in inner
    File "/opt/conda/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1431, in fit
      _use_cached_eval_dataset=True)
    File "/opt/conda/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1716, in evaluate
      tmp_logs = self.test_function(iterator)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1525, in test_function
      return step_function(self, iterator)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1513, in step_function
      data = next(iterator)
Node: 'IteratorGetNext'
Detected at node 'IteratorGetNext' defined at (most recent call last):
    File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/conda/lib/python3.7/site-packages/traitlets/config/application.py", line 978, in launch_instance
      app.start()
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/opt/conda/lib/python3.7/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/conda/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
      self._run_once()
    File "/opt/conda/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
      handle._run()
    File "/opt/conda/lib/python3.7/asyncio/events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/ipkernel.py", line 387, in do_execute
      cell_id=cell_id,
    File "/opt/conda/lib/python3.7/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
      raw_cell, store_history, silent, shell_futures, cell_id
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3029, in _run_cell
      return runner(coro)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
      interactivity=interactivity, compiler=compiler, result=result)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3472, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3552, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_28514/124030209.py", line 1, in <module>
      get_ipython().run_cell_magic('timeit', '', '\nNUM_EPOCHS = 1\nlayer_history = model.fit(\n    TRAIN,\n    validation_data=VALID,\n    # validation_freq=5,\n    epochs=NUM_EPOCHS,\n    steps_per_epoch=30,\n    # callbacks=tensorboard_cb,\n    # verbose=1\n)\n\nprint(f"Training for {NUM_EPOCHS} epoch")\naccuracy = layer_history.history["factorized_top_k/top_100_categorical_accuracy"][-1]\nprint(f"Top 100 categorical accuracy: {accuracy}")\n')
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2472, in run_cell_magic
      result = fn(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/decorator.py", line 232, in fun
      return caller(func, *(extras + args), **kw)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/magic.py", line 187, in <lambda>
      call = lambda f, *a, **k: f(*a, **k)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/magics/execution.py", line 1180, in timeit
      time_number = timer.timeit(number)
    File "/opt/conda/lib/python3.7/site-packages/IPython/core/magics/execution.py", line 169, in timeit
      timing = self.inner(it, self.timer)
    File "<magic-timeit>", line 7, in inner
    File "/opt/conda/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1431, in fit
      _use_cached_eval_dataset=True)
    File "/opt/conda/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1716, in evaluate
      tmp_logs = self.test_function(iterator)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1525, in test_function
      return step_function(self, iterator)
    File "/opt/conda/lib/python3.7/site-packages/keras/engine/training.py", line 1513, in step_function
      data = next(iterator)
Node: 'IteratorGetNext'
2 root error(s) found.
  (0) NOT_FOUND:  Error executing an HTTP request: HTTP response code 404 with body '<?xml version='1.0' encoding='UTF-8'?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message><Details>No such object: spotify-beam-v3/v3/valid/-01793-of-05106.tfrecords</Details></Error>'
	 when reading gs://spotify-beam-v3/v3/valid/-01793-of-05106.tfrecords
	 [[{{node IteratorGetNext}}]]
	 [[candidate__track__model/album_name_can_emb_model/hashing_9/hash/_22]]
  (1) NOT_FOUND:  Error executing an HTTP request: HTTP response code 404 with body '<?xml version='1.0' encoding='UTF-8'?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message><Details>No such object: spotify-beam-v3/v3/valid/-01793-of-05106.tfrecords</Details></Error>'
	 when reading gs://spotify-beam-v3/v3/valid/-01793-of-05106.tfrecords
	 [[{{node IteratorGetNext}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_79310]

### Testing on finding the optimal batch size

In [51]:
import time

def time_steps(model, train_data, n_steps=10) -> float:
    """Measure how long ``model.fit`` takes to run ``n_steps`` training steps.

    Args:
        model: object exposing a Keras-style ``fit`` method.
        train_data: batched training data passed straight to ``fit``.
        n_steps: number of steps to run (``steps_per_epoch``) in the single epoch.

    Returns:
        Elapsed time in seconds for the whole ``fit`` call, including any
        per-call overhead, not only the training steps themselves.
    """
    num_epochs = 1
    # perf_counter is monotonic and high-resolution, so the measurement cannot be
    # skewed by system clock adjustments the way time.time() can.
    started = time.perf_counter()
    model.fit(
        train_data,
        # validation_data=VALID,
        # validation_freq=5,
        epochs=num_epochs,
        steps_per_epoch=n_steps,
        # callbacks=tensorboard_cb,
        verbose=0,
    )
    return time.perf_counter() - started

In [52]:
# time a step to find optimal time / record for gpu
def get_run_times_by_batch_size(model, batch_size, n_steps=10, n_repeats=10):
    """Return the average training time per record, in milliseconds, for a batch size.

    The notebook-level ``shuffle_train`` dataset is re-batched with ``batch_size`` and
    ``time_steps`` is run ``n_repeats`` times on it; the timings are averaged.

    Args:
        model: model to time (passed through to ``time_steps``).
        batch_size: number of records per batch.
        n_steps: training steps per timed run.
        n_repeats: number of timed runs to average over (defaults to the previously
            hard-coded 10).

    Returns:
        Average milliseconds per record: ``avg_seconds / n_steps / batch_size * 1000``.

    Raises:
        ValueError: if ``n_repeats`` is smaller than 1 (there would be nothing to average).
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")

    train = shuffle_train.batch(batch_size)
    # Every repetition is a separate fit() call, so each timing includes that call's
    # start-up cost (presumably iterator creation and shuffle-buffer fill — TODO confirm).
    run_times = [time_steps(model, train, n_steps) for _ in range(n_repeats)]
    avg_time = np.average(run_times)
    ms_per_record = avg_time / n_steps / batch_size * 1000
    return ms_per_record

In [61]:
# Batch sizes to benchmark (powers of two from 16 to 8192)
batch_sizes = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# Maps batch size -> average milliseconds per record
run_time_data = {}

for b in batch_sizes:
    # Third argument is n_steps: each timed run covers 3 training steps
    avg_ms_per_rec = get_run_times_by_batch_size(model, b, 3)
    run_time_data[b] = avg_ms_per_rec
    print(b, avg_ms_per_rec)

### Running on a100 40gb
# pprint(run_time_data)

16 39.10556336243947
32 19.96664231022199
64 10.527420168121656
128 5.5340541526675215
256 3.3056994279225664
512 1.9949027802795174
1024 1.49439366068691
2048 1.1579723019773762
4096 1.112759350022922
8192 1.0598300247996424


In [62]:
# Two larger batch sizes (16384 and 32768) to extend the sweep
batch_sizes = [ 8192*2, 8192*4]

# Left commented out so the results are added to the existing run_time_data dict
# run_time_data = {}

for b in batch_sizes:
    # Third argument is n_steps: each timed run covers 3 training steps
    avg_ms_per_rec = get_run_times_by_batch_size(model, b, 3)
    run_time_data[b] = avg_ms_per_rec
    print(b, avg_ms_per_rec)

### Running on a100 40gb
# pprint(run_time_data)

16384 0.9729407339667281
32768 0.9885529014961016


### End of section on finding the optimal batch size

## Loading SavedModels

In [526]:
# GCS location of the exported candidate tower SavedModel from a previous training run
candidate_tower_uri = 'gs://spotify-tfrs-dir/v2/run-20220920-210334/candidate_tower'
# Load the SavedModel; inference goes through its signatures below
loaded_candidate_model = tf.saved_model.load(candidate_tower_uri)

### Candidate Model

In [527]:
# Show which serving signatures the SavedModel exposes
print(list(loaded_candidate_model.signatures.keys()))

['serving_default']


In [528]:
# Grab the default serving signature and inspect the structure of its outputs
infer = loaded_candidate_model.signatures["serving_default"]
print(infer.structured_outputs)

{'output_1': TensorSpec(shape=(None, 32), dtype=tf.float32, name='output_1')}


In [530]:
# Display the signature map; its repr lists the keyword arguments the signature expects
loaded_candidate_model.signatures

_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, album_name_can, album_uri_can, artist_followers_can, artist_genres_can, artist_name_can, artist_pop_can, artist_uri_can, duration_ms_can, track_name_can, track_pop_can, track_uri_can) at 0x7F1A85BA2C10>})


In [531]:
# Same 'serving_default' signature as `infer` above, bound to the name used by the embedding cells
predict2 = loaded_candidate_model.signatures['serving_default']
predict2.output_shapes

{'output_1': TensorShape([None, 32])}

In [533]:
# Display the element spec of the candidate dataset. It is referenced through `tt`,
# as in the earlier cells, because no bare `parsed_candidate_dataset` name is defined here.
tt.parsed_candidate_dataset

<MapDataset element_spec={'album_name_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'album_uri_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_followers_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'artist_genres_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_name_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_pop_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'artist_uri_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'duration_ms_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_name_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'track_pop_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_uri_can': TensorSpec(shape=(), dtype=tf.string, name=None)}>

In [534]:
# Embed every candidate track with the SavedModel serving signature.
# The dataset is referenced through `tt`, as in the earlier cells, because no bare
# `parsed_candidate_dataset` name is defined in this notebook.
# batch(1) is kept on purpose: the clean-up cell below takes row 0 of every result.
embs_iter = tt.parsed_candidate_dataset.batch(1).map(
    lambda data: predict2(
        artist_name_can = data["artist_name_can"],
        track_name_can = data['track_name_can'],
        album_name_can = data['album_name_can'],
        track_uri_can = data['track_uri_can'],
        artist_uri_can = data['artist_uri_can'],
        album_uri_can = data['album_uri_can'],
        duration_ms_can = data['duration_ms_can'],
        track_pop_can = data['track_pop_can'],
        artist_pop_can = data['artist_pop_can'],
        artist_followers_can = data['artist_followers_can'],
        artist_genres_can = data['artist_genres_can']
    )
)

# Materialise all results: one {'output_1': tensor} dict per candidate
embs = list(embs_iter)

In [540]:
# Sanity check: number of embedded candidates and the structure of the first result
print(f"Length of embs: {len(embs)}")
embs[0]

Length of embs: 166827


{'output_1': <tf.Tensor: shape=(1, 32), dtype=float32, numpy=
 array([[ 0.32745814,  0.15835902, -0.62232316, -0.18282643,  0.05425396,
          0.08093273, -0.36878368, -0.6777717 , -0.28926852, -0.44578883,
          0.03769037,  0.2013096 ,  0.01650199, -0.38547963, -0.2734376 ,
          0.16123655,  0.12652676, -0.09455037,  0.3709236 , -0.04336764,
         -0.20310198,  0.21418957, -0.44912496, -0.16798183, -0.21492694,
          0.04033846,  0.14083184,  0.25597924,  0.20490994,  0.18837534,
         -0.15261272,  0.06471566]], dtype=float32)>}

In [541]:
# Unwrap each result: take the 'output_1' tensor, convert it to NumPy and drop the
# leading batch axis (each batch holds a single candidate)
cleaned_embs = [x['output_1'].numpy()[0] for x in embs] #clean up the output

print(f"Length of cleaned_embs: {len(cleaned_embs)}")
cleaned_embs[0]

Length of cleaned_embs: 166827


array([ 0.32745814,  0.15835902, -0.62232316, -0.18282643,  0.05425396,
        0.08093273, -0.36878368, -0.6777717 , -0.28926852, -0.44578883,
        0.03769037,  0.2013096 ,  0.01650199, -0.38547963, -0.2734376 ,
        0.16123655,  0.12652676, -0.09455037,  0.3709236 , -0.04336764,
       -0.20310198,  0.21418957, -0.44912496, -0.16798183, -0.21492694,
        0.04033846,  0.14083184,  0.25597924,  0.20490994,  0.18837534,
       -0.15261272,  0.06471566], dtype=float32)

In [553]:
# clean product IDs
track_uris = [x['track_uri_can'].numpy() for x in parsed_candidate_dataset]
track_uris[0]

b'spotify:track:6KhJeYLg1AimCQjH6ii1Al'

In [554]:
# First raw URI (same value the previous cell already displays)
track_uris[0]

b'spotify:track:6KhJeYLg1AimCQjH6ii1Al'

In [557]:
# Decode the bytes URIs to text. Decoding replaces the earlier approach of editing the
# bytes repr, which would strip any apostrophe that is part of the value and would
# leave escape sequences in place for non-ASCII bytes.
track_uris_cleaned = [z.decode("utf-8") for z in track_uris]
track_uris_cleaned[0]

'spotify:track:6KhJeYLg1AimCQjH6ii1Al'

In [558]:
# Both lists should have the same length: one entry per candidate
print(f"Length of track_uris: {len(track_uris)}")
print(f"Length of track_uris_cleaned: {len(track_uris_cleaned)}")

Length of track_uris: 166827
Length of track_uris_cleaned: 166827


### Write Index Json File

In [559]:
VERSION = 'local-v1'
TIMESTAMP = '092022'

embeddings_index_filename = f'candidate_embeddings_{VERSION}_{TIMESTAMP}.json'

# zip() stops at the shorter input, so a length mismatch would silently drop
# candidates from the index; fail loudly instead.
if len(track_uris_cleaned) != len(cleaned_embs):
    raise ValueError(
        f"Got {len(track_uris_cleaned)} ids but {len(cleaned_embs)} embeddings"
    )

# Write one JSON object per line: {"id": ..., "embedding": [...]}
with open(embeddings_index_filename, 'w') as f:
    for prod, emb in zip(track_uris_cleaned, cleaned_embs):
        # json.dumps quotes and escapes the id, so the line stays valid JSON even if
        # an id contains a quote or backslash; plain ids are written exactly as before.
        id_json = json.dumps(str(prod))
        # str() on each value keeps the same number formatting as before.
        embedding_csv = ",".join(str(x) for x in emb)
        f.write('{"id":' + id_json + ',"embedding":[' + embedding_csv + ']}\n')

In [564]:
# NOTE: the index file holds one JSON object per line, so json.load() on the whole
# file would fail; read it line by line with json.loads() instead.
# import json

# with open('candidate_embeddings_local-v1_092022.json', 'r') as f:
#     data = json.load(f)

### Query Model

In [565]:
# GCS location of the exported query tower SavedModel (same run directory as the candidate tower)
query_tower_uri = 'gs://spotify-tfrs-dir/v2/run-20220920-210334/query_tower'
loaded_query_model = tf.saved_model.load(query_tower_uri)

In [566]:
# Display the signature map; its repr lists the keyword arguments the query signature expects
loaded_query_model.signatures

_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, album_name_can, album_name_pl, album_name_seed_track, album_uri_can, album_uri_seed_track, artist_followers_can, artist_followers_seed_track, artist_genres_can, artist_genres_pl, artist_genres_seed_track, artist_name_can, artist_name_pl, artist_name_seed_track, artist_pop_can, artist_pop_pl, artist_pop_seed_track, artist_uri_can, artist_uri_seed_track, artists_followers_pl, collaborative, description_pl, duration_ms_can, duration_ms_seed_pl, duration_ms_songs_pl, duration_seed_track, n_songs_pl, name, num_albums_pl, num_artists_pl, pid, pos_seed_track, track_name_can, track_name_pl, track_name_seed_track, track_pop_can, track_pop_pl, track_pop_seed_track, track_uri_can, track_uri_pl, track_uri_seed_track) at 0x7F1935582D50>})