# Two-Tower Retrieval Model

### Key resources:
* Many pages [here](https://www.tensorflow.org/recommenders/examples/deep_recommenders) cover useful techniques for building custom TFRS models

### Goals:
* Show how to model most data types:
  * Strings
  * Existing embeddings (vectors)
  * Floats (normalized)
  * Categorical with a vocabulary
  * High-dimensional categorical (embedded)
* Leverage class templates to create custom two-tower models quickly and easily

## Spotify: create the tensorflow-io interface for the event and product tables in BigQuery
Best practices from Google are in this blog post

In [1]:
# set variables
DROPOUT = False
DROPOUT_RATE = 0.2
EMBEDDING_DIM = 64
MAX_TOKENS = 100_000
BATCH_SIZE = 256
ARCH = [128, 64]
NUM_EPOCHS = 1
SEED = 41781897
PROJECT_ID = 'jtotten-project'
DROP_FIELDS = ['pid', 'track_uri', 'artist_uri', 'album_uri']
N_RECORDS_PER_TFRECORD_FILE = 300 * 50  # 15,000 records per file (~140 MiB each in the run below)
TF_RECORDS_DIR = 'gs://spotify-tfrecords'

#### Quick counts on the playlist training records

In [2]:
%%bigquery TOTAL_PLAYLISTS
select count(1) from jtotten-project.spotify_mpd.playlists_track_string



In [3]:
TOTAL_PLAYLISTS = TOTAL_PLAYLISTS.values[0][0]
TOTAL_PLAYLISTS

1032000

#### Quick counts on the track records

In [4]:
%%bigquery TOTAL_TRACKS
select count(1) from jtotten-project.spotify_mpd.track_audio



In [5]:
TOTAL_TRACKS = TOTAL_TRACKS.values[0][0]
TOTAL_TRACKS

2261490

### Set up the tensorflow-io pipeline function for BigQuery

[Great blog post here on it](https://towardsdatascience.com/how-to-read-bigquery-data-from-tensorflow-2-0-efficiently-9234b69165c8)

In [6]:
import datetime
import json
import warnings

import tensorflow as tf
import tensorflow_recommenders as tfrs
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import file_io
from tensorflow.train import BytesList, Feature, FeatureList, FeatureLists, Int64List, SequenceExample
from tensorflow_io.bigquery import BigQueryClient, BigQueryReadSession

warnings.filterwarnings("ignore")  # tensorflow-io has an info-level bug whose messages can safely be ignored


def bq_to_tfdata(client, row_restriction, table_id, col_names, col_types, dataset, batch_size=BATCH_SIZE):
    """Read a BigQuery table into a shuffled, batched tf.data.Dataset."""
    bqsession = client.read_session(
        "projects/" + PROJECT_ID,  # parent resource for the read session
        PROJECT_ID, table_id, dataset,
        col_names, col_types,
        requested_streams=2,
        row_restriction=row_restriction)
    ds = bqsession.parallel_read_rows()
    # shuffle rows, batch them, then prefetch batches so reading overlaps with training
    return ds.shuffle(batch_size * 10).batch(batch_size).prefetch(tf.data.AUTOTUNE)

## Get the table metadata

To build a pipeline we need each table's schema along with the table information. The helper functions below convert that metadata into the proper types for `tf`.


For each table ID, programmatically get:
* Column names
* Column types

In [7]:
%%bigquery schema
SELECT * FROM jtotten-project.spotify_mpd.INFORMATION_SCHEMA.TABLES
where table_name in ('track_audio', 'playlists_track_string');



In [8]:
schema # we will get the fields out of the ddl field

| | table_catalog | table_schema | table_name | table_type | is_insertable_into | is_typed | creation_time | base_table_catalog | base_table_schema | base_table_name | snapshot_time_ms | ddl |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | jtotten-project | spotify_mpd | track_audio | BASE TABLE | YES | NO | 2022-04-06 17:46:25.801000+00:00 | | | | NaT | CREATE TABLE `jtotten-project.spotify_mpd.trac... |
| 1 | jtotten-project | spotify_mpd | playlists_track_string | BASE TABLE | YES | NO | 2022-04-22 22:50:46.601000+00:00 | | | | NaT | CREATE TABLE `jtotten-project.spotify_mpd.play... |


## Helper functions to pull metadata from DDL statements

From the DDL we extract the column names and types needed to create a `BigQueryReadSession` from `tensorflow_io.bigquery`.

In [9]:
# Function to convert a BigQuery type string to a tf dtype

def conv_dtype_to_tf(dtype_str):
    if dtype_str == 'FLOAT64':
        return dtypes.float64
    elif dtype_str == 'INT64':
        return dtypes.int64
    else:
        return dtypes.string  # STRING (and any other type) maps to tf.string

def get_metadata_from_ddl(ddl, drop_field=None):
    """Return column names and tf dtypes from a one-row Series holding a CREATE TABLE DDL."""
    drop_field = drop_field or []  # avoid `in None` when no fields are dropped
    fields = []
    types = []
    ddl = ddl.values[0]
    for line in ddl.splitlines():
        if line[:1] == ' ':  # only indented lines describe columns
            # e.g. "  name STRING," -> name='name', bq_type='STRING'
            name, bq_type = line.strip().rstrip(',').split()[:2]
            if name not in drop_field:
                fields.append(name)
                types.append(conv_dtype_to_tf(bq_type))
    return fields, types


track_audio_fields, track_audio_types = get_metadata_from_ddl(schema.ddl[schema.table_name == 'track_audio'], DROP_FIELDS)
playlist_fields, playlist_types = get_metadata_from_ddl(schema.ddl[schema.table_name == 'playlists_track_string'], DROP_FIELDS) 
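The parsing in `get_metadata_from_ddl` depends on the layout of BigQuery's DDL text, so it helps to check it against a known string. Below is a stdlib-only sketch, assuming column lines are indented like `  name STRING,`; the sample DDL and the `parse_ddl_columns` helper are hypothetical, not part of the notebook.

```python
# Sketch: pull (column_name, bigquery_type) pairs out of CREATE TABLE DDL text.
# SAMPLE_DDL is hypothetical; the real text comes from INFORMATION_SCHEMA.TABLES.ddl.
SAMPLE_DDL = """CREATE TABLE `my-project.my_dataset.playlists`
(
  pid INT64,
  name STRING,
  num_tracks INT64,
  danceability FLOAT64,
  tracks STRING
);"""


def parse_ddl_columns(ddl_text, drop_fields=()):
    """Return [(column_name, bigquery_type), ...] for indented column lines."""
    columns = []
    for line in ddl_text.splitlines():
        if not line.startswith(' '):
            continue  # skip the CREATE TABLE line and the parentheses
        parts = line.strip().rstrip(',').split()
        if len(parts) < 2:
            continue  # not a "name TYPE" line
        name, bq_type = parts[0], parts[1]
        if name not in drop_fields:
            columns.append((name, bq_type))
    return columns


cols = parse_ddl_columns(SAMPLE_DDL, drop_fields=['pid'])
print(cols)
# [('name', 'STRING'), ('num_tracks', 'INT64'), ('danceability', 'FLOAT64'), ('tracks', 'STRING')]
```

Splitting on any whitespace (rather than on single spaces at fixed positions) keeps the parse working if the indentation width changes.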

In [10]:
# Quick check on data
for a, b in zip(playlist_fields, playlist_types):
    print(a +" : " + str(b))

name : <dtype: 'string'>
collaborative : <dtype: 'string'>
modified_at : <dtype: 'int64'>
num_tracks : <dtype: 'int64'>
num_albums : <dtype: 'int64'>
num_followers : <dtype: 'int64'>
tracks : <dtype: 'string'>
num_edits : <dtype: 'int64'>
duration_ms : <dtype: 'int64'>
num_artists : <dtype: 'int64'>
description : <dtype: 'string'>


In [11]:
# Quick check on data
for a, b in zip(track_audio_fields, track_audio_types):
    print(a + " : " + str(b))

artist_name : <dtype: 'string'>
track_name : <dtype: 'string'>
album_name : <dtype: 'string'>
name : <dtype: 'string'>
danceability : <dtype: 'float64'>
energy : <dtype: 'float64'>
key : <dtype: 'float64'>
loudness : <dtype: 'float64'>
mode : <dtype: 'float64'>
speechiness : <dtype: 'float64'>
acousticness : <dtype: 'float64'>
instrumentalness : <dtype: 'float64'>
liveness : <dtype: 'float64'>
valence : <dtype: 'float64'>
tempo : <dtype: 'float64'>
type : <dtype: 'string'>
id : <dtype: 'string'>
uri : <dtype: 'string'>
track_href : <dtype: 'string'>
analysis_url : <dtype: 'string'>
time_signature : <dtype: 'float64'>
artist_pop : <dtype: 'int64'>
track_pop : <dtype: 'string'>
genres : <dtype: 'string'>
duration_ms : <dtype: 'int64'>


### With the helper functions set, create the tf.data pipelines from BigQuery

In [12]:
track_train_pipeline = bq_to_tfdata(BigQueryClient(), row_restriction=None, table_id = 'track_audio'
                                    , col_names=track_audio_fields, col_types=track_audio_types, dataset='spotify_mpd', batch_size=1) #we will change to BATCH_SIZE after we test 


In [13]:
### Validate we are getting records

for line in track_train_pipeline.take(1):
    print(line) #should come out based on batch size



OrderedDict([('acousticness', <tf.Tensor: shape=(1,), dtype=float64, numpy=array([0.00612])>), ('album_name', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Kitsun\xc3\xa9 Maison Compilation 6'], dtype=object)>), ('analysis_url', <tf.Tensor: shape=(1,), dtype=string, numpy=
array([b'https://api.spotify.com/v1/audio-analysis/34C6BtgHosjtPjzsCLhB7Z'],
      dtype=object)>), ('artist_name', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Ted & Francis'], dtype=object)>), ('artist_pop', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>), ('danceability', <tf.Tensor: shape=(1,), dtype=float64, numpy=array([0.512])>), ('duration_ms', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([324040])>), ('energy', <tf.Tensor: shape=(1,), dtype=float64, numpy=array([0.782])>), ('genres', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'unknown'], dtype=object)>), ('id', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'34C6BtgHosjtPjzsCLhB7Z'], dtype=object)>), ('instrumentalnes

### The track audio pipeline is ready for training: it has no nested elements, so no preprocessing is needed

In [14]:
playlist_types[playlist_fields.index('tracks')] = dtypes.string  # make sure the nested tracks column is read as a string

In [15]:
## Validate playlist data
playlist_train_pipeline = bq_to_tfdata(BigQueryClient(), row_restriction=None, table_id = 'playlists_track_string'
                                    , col_names=playlist_fields
                                       , col_types=playlist_types
                                       , dataset='spotify_mpd', batch_size=1)
for line in playlist_train_pipeline.take(1):
    print(line) #should come out based on batch size



OrderedDict([('collaborative', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'false'], dtype=object)>), ('description', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b''], dtype=object)>), ('duration_ms', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([27942453])>), ('modified_at', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([1462320000])>), ('name', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'SS'], dtype=object)>), ('num_albums', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([94])>), ('num_artists', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([80])>), ('num_edits', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>), ('num_followers', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>), ('num_tracks', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([110])>), ('tracks', <tf.Tensor: shape=(1,), dtype=string, numpy=
array([b'[{\'pos\': 0, \'artist_name\': \'John Grant\', \'track_uri\': \'spotify:track:1fCSSXalgNdB6h8Y6Hf9f2\', \'artist_uri\': \

## Pulling one record shows the rows parse correctly, but `tracks` arrives as a single string holding a list of track dicts

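# First attempt at wrangling the nested text data. This does not work: tf.io.serialize_tensor
# produces a serialized TensorProto, not a serialized SequenceExample, so parse_sequence_example
# cannot decode it. (feature_description and context_features are defined in the Tests cell at the end.)
# tf.train.Example(features=tf.train.Features(feature=feature))
for _ in playlist_train_pipeline.map(lambda x: tf.io.parse_sequence_example(tf.io.serialize_tensor(x['tracks'][0]), sequence_features=feature_description, context_features=context_features, name='tracks')).take(1):
    tensor = _
    print(_)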

Since `tracks` is stored as the text of a Python list of dicts, we execute eagerly, pull the value out with `.numpy()`, and `eval` the string, e.g.:
`eval("{'pos': 0, 'artist_name': 'King Crimson', 'track_uri': 'spotify:track:173gp7NIXqk0MEo8K7Av4a', 'artist_uri': 'spotify:artist:7M1FPw29m5FbicYzS2xdpi', 'track_name': '21st Century Schizoid Man', 'album_uri': 'spotify:album:0ga8Q4tTXaFf9q3LvT8hrC', 'duration_ms': 657517, 'album_name': 'Radical Action To Unseat the Hold of Monkey Mind (Live)'}")`
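Because `eval` executes whatever expression is stored in the column, `ast.literal_eval` is a safer substitute when the text contains only Python literals, as in the example above. The minimal sketch below also pivots the list of track dicts into one list per key, which is the shape the `SequenceExample` feature lists need. `tracks_to_columns` is a hypothetical helper, and the sample is the single dict above wrapped in a list.

```python
import ast

# Sample `tracks` value: the single track dict shown above, wrapped in a list
RAW_TRACKS = (
    "[{'pos': 0, 'artist_name': 'King Crimson', "
    "'track_uri': 'spotify:track:173gp7NIXqk0MEo8K7Av4a', "
    "'artist_uri': 'spotify:artist:7M1FPw29m5FbicYzS2xdpi', "
    "'track_name': '21st Century Schizoid Man', "
    "'album_uri': 'spotify:album:0ga8Q4tTXaFf9q3LvT8hrC', 'duration_ms': 657517, "
    "'album_name': 'Radical Action To Unseat the Hold of Monkey Mind (Live)'}]"
)


def tracks_to_columns(raw_tracks, keys):
    """Pivot a stringified list of track dicts into {key: [one value per track]}."""
    if isinstance(raw_tracks, bytes):  # tensor.numpy() yields bytes
        raw_tracks = raw_tracks.decode('utf8')
    # literal_eval only accepts Python literals, so no arbitrary code runs
    tracks = ast.literal_eval(raw_tracks)
    return {key: [track[key] for track in tracks] for key in keys}


columns = tracks_to_columns(RAW_TRACKS.encode('utf8'), ['pos', 'duration_ms', 'track_name'])
print(columns)
# {'pos': [0], 'duration_ms': [657517], 'track_name': ['21st Century Schizoid Man']}
```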

## This function parses the playlist data and breaks the nested fields down to conform to `SequenceExample`
The 'flat' features become the `context` of the `SequenceExample`, and the per-track fields become its `feature_lists`. One more helper writes the serialized examples to the destination `gs://` path.

In [16]:
import ast

# Keys present in every element of the nested `tracks` list
TRACK_KEYS = ['pos', 'artist_name', 'track_uri', 'artist_uri', 'track_name',
              'album_uri', 'duration_ms', 'album_name']
INT_TRACK_KEYS = {'pos', 'duration_ms'}

# Flat playlist columns written to the SequenceExample context
BYTES_CONTEXT = ['name', 'collaborative', 'description']
INT_CONTEXT = ['modified_at', 'num_tracks', 'num_albums', 'num_followers',
               'num_edits', 'duration_ms', 'num_artists']


def get_tensor_from_tracks(tensor):
    """Convert one playlist row (batch_size=1) into a tf.train.SequenceExample.

    Plain Python (no @tf.function): it needs .numpy() and Python parsing
    of the stringified `tracks` column, so it must run eagerly.
    """
    # `tracks` holds the text of a list of dicts; literal_eval only evaluates literals
    tracks = ast.literal_eval(tensor['tracks'][0].numpy().decode('utf8'))

    # context: one value per playlist
    context = {k: Feature(bytes_list=BytesList(value=tensor[k].numpy())) for k in BYTES_CONTEXT}
    context.update({k: Feature(int64_list=Int64List(value=tensor[k].numpy())) for k in INT_CONTEXT})

    # feature lists: a single Feature per key holding the values for every track
    feature_list = {}
    for key in TRACK_KEYS:
        values = [track[key] for track in tracks]
        if key in INT_TRACK_KEYS:
            feature = Feature(int64_list=Int64List(value=values))
        else:
            feature = Feature(bytes_list=BytesList(value=[v.encode('utf8') for v in values]))
        feature_list[key] = FeatureList(feature=[feature])

    # create the sequence
    return SequenceExample(context=tf.train.Features(feature=context),
                           feature_lists=FeatureLists(feature_list=feature_list))


def write_a_tfrec(lns, n_records_per_file, file_counter, subfolder):
    #next write to a tfrecord
    with tf.io.TFRecordWriter(
        TF_RECORDS_DIR + "/" + subfolder +"/file_%.2i-%i.tfrec" % (n_records_per_file, file_counter)
    ) as writer:
        for example in lns:
            writer.write(example.SerializeToString())

## Now iterate over the pipeline
Write the serialized examples to files of `N_RECORDS_PER_TFRECORD_FILE` records each.

Expect this to take roughly 30 to 45 minutes on a machine with 64 vCPUs and 57.6 GB RAM (the run below took 45:37).
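The batching logic (append each record, write a file when the buffer is full, then write whatever is left) can be checked in plain Python before starting the long run. This sketch only counts records; `chunk_sizes` is a hypothetical helper. With 1,032,000 playlists and 15,000 records per file the arithmetic gives 68 full files plus a final file of 12,000.

```python
def chunk_sizes(total_records, records_per_file):
    """Return how many records land in each output file."""
    sizes = []
    buffered = 0
    for _ in range(total_records):
        buffered += 1
        if buffered == records_per_file:  # buffer is full: write a file and reset
            sizes.append(buffered)
            buffered = 0
    if buffered:  # write the remaining partial file after the loop
        sizes.append(buffered)
    return sizes


sizes = chunk_sizes(1_032_000, 300 * 50)
print(len(sizes), sizes[0], sizes[-1])  # 69 15000 12000
```

Flushing on buffer length, plus a final flush after the loop, avoids depending on the exact row count reported by BigQuery.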

In [17]:
tf.config.run_functions_eagerly(True)
from tqdm import tqdm
# using datetime module

# ct stores current time
ct = str(datetime.datetime.now()).replace(" ",":")
print(f"Timestamp for folder expected: {ct}")
records = []
file_count = 0

for line in tqdm(playlist_train_pipeline, total=TOTAL_PLAYLISTS):
    records.append(get_tensor_from_tracks(line))
    if len(records) == N_RECORDS_PER_TFRECORD_FILE:  # write a full file and reset the buffer
        write_a_tfrec(records, n_records_per_file=N_RECORDS_PER_TFRECORD_FILE, subfolder=ct, file_counter=file_count)
        file_count += 1
        records = []

if records:  # write the remaining partial file
    write_a_tfrec(records, n_records_per_file=N_RECORDS_PER_TFRECORD_FILE, subfolder=ct, file_counter=file_count)
    file_count += 1

Timestamp for folder expected: 2022-04-25:00:29:05.543360


100%|██████████| 1032000/1032000 [45:37<00:00, 376.98it/s] 


### Move the files, then parse the records to ensure this worked

In [18]:
#move the files around

!gsutil cp gs://spotify-tfrecords/$ct/* gs://spotify-tfrecords
!gsutil rm gs://spotify-tfrecords/$ct/*

Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-0.tfrec...
Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-1.tfrec... 
Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-10.tfrec...
Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-11.tfrec...
/ [4 files][566.4 MiB/566.4 MiB]                                                
==> NOTE: You are performing a sequence of gsutil operations that may
run significantly faster if you instead use gsutil -m cp ... Please
see the -m section under "gsutil help options" for further information
about when gsutil -m can be advantageous.

Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-12.tfrec...
Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-13.tfrec...
Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-14.tfrec...
Copying gs://spotify-tfrecords/2022-04-25:00:29:05.543360/file_15000-15.tfrec...
Copying gs://spotify

In [19]:
import tensorflow as tf

sequence_features = {'pos': tf.io.RaggedFeature(tf.int64), 
                     'artist_name':  tf.io.RaggedFeature(tf.string), 
                     'track_uri':  tf.io.RaggedFeature(tf.string), 
                     'artist_uri': tf.io.RaggedFeature(tf.string), 
                     'track_name': tf.io.RaggedFeature(tf.string), 
                     'album_uri': tf.io.RaggedFeature(tf.string),
                     'duration_ms': tf.io.RaggedFeature(tf.int64), 
                     'album_name': tf.io.RaggedFeature(tf.string)
                    }
context_features = {"name" : tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
                    "collaborative" : tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
                    "modified_at" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_tracks" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_albums" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_followers" :tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_edits" :tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "duration_ms" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_artists" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "description" : tf.io.FixedLenFeature(dtype=tf.string, shape=(1))
                   }

def parse_tfrecord_fn(example):
    example = tf.io.parse_single_sequence_example(example, sequence_features=sequence_features, context_features=context_features)
    return example

### Parse the TFRecord dataset

In [20]:
from google.cloud import storage

BUCKET = 'spotify-tfrecords'
client = storage.Client()
# build gs:// URIs directly instead of rewriting blob.public_url (an https URL)
files = [f"gs://{BUCKET}/{blob.name}" for blob in client.list_blobs(BUCKET) if blob.name.endswith('.tfrec')]

In [22]:
raw_dataset = tf.data.TFRecordDataset(files)

tf_record_pipeline = raw_dataset.map(parse_tfrecord_fn)

for _ in tf_record_pipeline.take(1):
    print(_)

({'collaborative': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'true'], dtype=object)>, 'description': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b''], dtype=object)>, 'duration_ms': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([3506582])>, 'modified_at': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([1425772800])>, 'name': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'wedding songs'], dtype=object)>, 'num_albums': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([16])>, 'num_artists': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([15])>, 'num_edits': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>, 'num_followers': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>, 'num_tracks': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([16])>}, {'album_name': <tf.RaggedTensor [[b'In Between Dreams', b'Love In The Future', b'Overexposed Track By Track', b'Adventures In Real Time', b'because the internet', b"I'm Wide Awake, It's Morning", b'Oh, What A

# Model Draft Stuff

In [None]:
# %%writefile -a vertex_train/trainer/task.py

class PlaylistsModel(tf.keras.Model):
    def __init__(self, layer_sizes, adapt_data):
        super().__init__()
        
        # start with lookups on low-cardinality categorical items
        colab_vocab = tf.constant(['true','false'], name='colab_vocab', dtype='string')
        
        self.colab = tf.keras.Sequential([
            tf.keras.layers.StringLookup(
                vocabulary=colab_vocab, mask_token=None, name="colab_lookup", output_mode='count')
        ], name="colab")
        
        # create text vectorizers to be fed to an embedding layer
        self.artist_vectorizor = tf.keras.layers.TextVectorization(
            max_tokens=MAX_TOKENS, name="artist_tv", ngrams=2)
        
        self.album_vectorizor = tf.keras.layers.TextVectorization(
            max_tokens=MAX_TOKENS, name="album_tv", ngrams=2)
        
        self.description_vectorizor = tf.keras.layers.TextVectorization(
            max_tokens=MAX_TOKENS, name="description_tv", ngrams=2)
        
        # attribute names must match the ones used in call()
        self.album_embedding = tf.keras.Sequential([
            self.album_vectorizor,
            tf.keras.layers.Embedding(MAX_TOKENS+1, EMBEDDING_DIM, mask_zero=True, name="album_emb"),
            tf.keras.layers.GlobalAveragePooling1D()
        ], name="album_embedding_model")
        
        self.artist_embedding = tf.keras.Sequential([
            self.artist_vectorizor,
            tf.keras.layers.Embedding(MAX_TOKENS+1, EMBEDDING_DIM, mask_zero=True, name="artist_emb"),
            tf.keras.layers.GlobalAveragePooling1D()
        ], name="artist_embedding")
        
        self.description_embedding = tf.keras.Sequential([
            self.description_vectorizor,
            tf.keras.layers.Embedding(MAX_TOKENS+1, EMBEDDING_DIM, mask_zero=True, name="description_emb"),
            tf.keras.layers.GlobalAveragePooling1D()
        ], name="description_embedding")
        
        ###############
        ### adapt stuff
        ###############
        
        self.artist_vectorizor.adapt(adapt_data.map(lambda x: x['artist_name']))
        self.album_vectorizor.adapt(adapt_data.map(lambda x: x['album_name']))
        self.description_vectorizor.adapt(adapt_data.map(lambda x: x['description']))
        
        # Then construct the layers.
        self.dense_layers = tf.keras.Sequential(name="dense_layers_query")
        
        initializer = tf.keras.initializers.GlorotUniform(seed=SEED)
        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(tf.keras.layers.Dense(layer_size, activation="relu", kernel_initializer=initializer))
            if DROPOUT:
                self.dense_layers.add(tf.keras.layers.Dropout(DROPOUT_RATE))
        # No activation for the last layer
        for layer_size in layer_sizes[-1:]:
            self.dense_layers.add(tf.keras.layers.Dense(layer_size, kernel_initializer=initializer))
        ### ADDING L2 NORM AT THE END
        self.dense_layers.add(tf.keras.layers.Lambda(lambda x: tf.nn.l2_normalize(x, 1, epsilon=1e-12, name="normalize_dense")))


    def call(self, data):
        all_embs = tf.concat(
                [
                    self.album_embedding(data['album_name']),
                    self.artist_embedding(data['artist_name']),
                    self.colab(data['collaborative']),
                    self.description_embedding(data['description'])
                ], axis=1)
        return self.dense_layers(all_embs)
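The last dense layer L2-normalizes the tower output. In a two-tower retrieval setup the query and candidate tower outputs are scored with a dot product (`tfrs.tasks.Retrieval` computes scores this way), and on unit-length vectors that dot product equals cosine similarity, so scores stay in [-1, 1]. Below is a stdlib-only check with made-up vectors; the `l2_normalize` here is a hypothetical helper that follows the formula documented for `tf.nn.l2_normalize`, `x / sqrt(max(sum(x**2), epsilon))`.

```python
import math


def l2_normalize(vec, epsilon=1e-12):
    """Scale vec to unit L2 norm, guarding against division by zero."""
    norm = math.sqrt(max(sum(v * v for v in vec), epsilon))
    return [v / norm for v in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


query = [0.2, -1.0, 3.0]      # made-up query tower output
candidate = [1.5, 0.5, 2.0]   # made-up candidate tower output
q, c = l2_normalize(query), l2_normalize(candidate)

# dot product of normalized vectors == cosine similarity of the raw vectors
print(abs(dot(q, c) - cosine(query, candidate)) < 1e-9)  # True
```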

## Use the example output to decide how to process each feature

```
OrderedDict([('album_name', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'The Helm'], dtype=object)>), ('artist_name', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Carrot Green'], dtype=object)>), ('collaborative', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'false'], dtype=object)>), ('description', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b''], dtype=object)>), ('duration_ms', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'358500'], dtype=object)>), ('modified_at', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([1505692800])>), ('name', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'FeSTa'], dtype=object)>), ('num_albums', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([82])>), ('num_artists', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([66])>), ('num_edits', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([48])>), ('num_followers', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>), ('num_tracks', <tf.Tensor: shape=(1,), dtype=int64, numpy=array([85])>), ('pos', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'45'], dtype=object)>), ('track_name', <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'The Helm - Carrot Green Remix'], dtype=object)>)])
```
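The `collaborative` field above only takes `true` or `false`, which is why the model's `colab` layer uses `StringLookup` with `output_mode='count'` instead of an embedding. Assuming our reading of the Keras docs is right (with `mask_token=None` and the default single OOV index, index 0 is the OOV bucket and vocabulary terms follow in order), the layer's output can be emulated in plain Python. `count_encode` is a hypothetical helper, not a Keras API.

```python
def count_encode(values, vocabulary):
    """Emulate StringLookup(output_mode='count', mask_token=None, num_oov_indices=1).

    Index 0 counts out-of-vocabulary values; vocabulary terms follow in order.
    """
    index = {term: i + 1 for i, term in enumerate(vocabulary)}
    counts = [0] * (len(vocabulary) + 1)
    for value in values:
        counts[index.get(value, 0)] += 1
    return counts


print(count_encode(['false'], ['true', 'false']))  # [0, 0, 1]
print(count_encode(['maybe'], ['true', 'false']))  # [1, 0, 0]
```

The resulting width-3 vector is what gets concatenated with the 64-dimensional text embeddings in `call`.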

In [None]:
#### Tests

In [None]:
# data[0].keys()#originally got the values from this
feature_description = {'pos': tf.io.RaggedFeature(tf.int64), 
                     'artist_name':  tf.io.RaggedFeature(tf.string), 
                     'track_uri':  tf.io.RaggedFeature(tf.string), 
                     'artist_uri': tf.io.RaggedFeature(tf.string), 
                     'track_name': tf.io.RaggedFeature(tf.string), 
                     'album_uri': tf.io.RaggedFeature(tf.string),
                     'duration_ms': tf.io.RaggedFeature(tf.int64), 
                     'album_name': tf.io.RaggedFeature(tf.string)
                    }
context_features = {"name" : tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
                    "collaborative" : tf.io.FixedLenFeature(dtype=tf.string, shape=(1)),
                    "modified_at" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_tracks" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_albums" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_followers" :tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_edits" :tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "duration_ms" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "num_artists" : tf.io.FixedLenFeature(dtype=tf.int64, shape=(1)),
                    "description" : tf.io.FixedLenFeature(dtype=tf.string, shape=(1))
                   }