In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table align="left">

  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/master/notebooks/official/matching_engine/two-tower-model-introduction.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/tree/master/notebooks/official/matching_engine/two-tower-model-introduction.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
</table>

## Overview

This tutorial demonstrates how to use the Two-Tower built-in algorithm on the Vertex AI platform.

Two-tower models learn to represent two types of items (such as user profiles, search queries, web documents, answer passages, or images) in the same vector space, so that similar or related items are close to each other. The two items are called the query object and the candidate object: when the model is paired with a nearest neighbor search service such as Vertex Matching Engine, it can retrieve candidate objects related to an input query object. A query encoder and a candidate encoder (the two "towers") encode these objects respectively, and both encoders are trained on pairs of relevant items. This built-in algorithm exports the trained query and candidate encoders as model artifacts, which you can deploy in Vertex Prediction for use in a recommendation system.
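
To make the retrieval idea concrete, here is a minimal plain-Python sketch. It assumes toy "towers" (fixed random linear projections followed by L2 normalization) in place of trained encoders, and a brute-force scan in place of Vertex Matching Engine. `make_tower`, `retrieve_top_k`, and the feature vectors are hypothetical and not part of the built-in algorithm.

```python
import math
import random


def make_tower(input_dim, output_dim, seed):
    """Return a toy encoder: a fixed random linear map followed by L2 normalization."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 1.0) for _ in range(input_dim)] for _ in range(output_dim)]

    def encode(features):
        emb = [sum(w * x for w, x in zip(row, features)) for row in weights]
        norm = math.sqrt(sum(v * v for v in emb)) or 1.0  # avoid dividing by zero
        return [v / norm for v in emb]

    return encode


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def retrieve_top_k(query_emb, candidate_embs, k):
    """Brute-force nearest-neighbor search: rank candidates by dot product."""
    ranked = sorted(candidate_embs, key=lambda name: dot(query_emb, candidate_embs[name]), reverse=True)
    return ranked[:k]


# Different input features per tower, but the same 8-dimensional output space.
query_tower = make_tower(input_dim=4, output_dim=8, seed=0)
candidate_tower = make_tower(input_dim=3, output_dim=8, seed=1)

# Hypothetical candidate features; candidates are encoded once, ahead of time.
candidates = {"movie_a": [1.0, 0.0, 0.0], "movie_b": [0.0, 1.0, 0.0], "movie_c": [0.0, 0.0, 1.0]}
candidate_embs = {name: candidate_tower(f) for name, f in candidates.items()}

# The query (user) is encoded at request time and matched against the candidates.
user_emb = query_tower([0.5, 1.0, 0.0, 0.2])
print(retrieve_top_k(user_emb, candidate_embs, k=2))
```

Because both towers emit normalized vectors of the same dimension, the dot product acts as cosine similarity; candidate embeddings can be computed and indexed once, while only the query is encoded per request.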

### Dataset

This tutorial uses the `movielens_100k` sample dataset in the public bucket `gs://cloud-samples-data/vertex-ai/matching-engine/two-tower`, which was generated from the [MovieLens movie rating dataset](https://grouplens.org/datasets/movielens/100k/). For simplicity, the data for this tutorial includes only the user ID feature for users, and the movie ID and movie title features for movies. In this example, the user is the query object and the movie is the candidate object. Each training example in the dataset contains a user and a movie that the user rated; only positive ratings are included. The two-tower model embeds users and movies in the same embedding space so that, given a user, the model can recommend movies that the user is likely to enjoy.
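
As a rough sketch of how such positive (user, movie) pairs could be turned into training data, the code below filters ratings and writes JSONL, following the `query`/`candidate` layout of the example record shown later in this notebook. The ratings, the threshold of 4, and the feature names `user_id`, `movie_id`, and `movie_title` are assumptions for illustration; the actual sample dataset and its schema may use different names.

```python
import json
import os
import tempfile

# Hypothetical ratings: (user_id, movie_id, movie_title, rating on a 1-5 scale).
ratings = [
    ("u1", "m10", "Movie A", 5),
    ("u1", "m11", "Movie B", 2),
    ("u2", "m10", "Movie A", 4),
    ("u3", "m12", "Movie C", 3),
]

POSITIVE_THRESHOLD = 4  # Assumption: a rating of 4 or higher counts as positive.


def to_training_example(user_id, movie_id, movie_title):
    """One (query, candidate) pair; every feature value is wrapped in a list."""
    return {
        "query": {"user_id": [user_id]},
        "candidate": {"movie_id": [movie_id], "movie_title": [movie_title]},
    }


def write_positive_pairs(ratings, path):
    """Write one JSON object per line (JSONL), keeping only positive ratings."""
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        for user_id, movie_id, title, rating in ratings:
            if rating >= POSITIVE_THRESHOLD:
                f.write(json.dumps(to_training_example(user_id, movie_id, title)) + "\n")
                count += 1
    return count


out_path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
print(write_positive_pairs(ratings, out_path), "examples written to", out_path)
```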

### Objective

In this notebook, you will learn how to run the two-tower model.
The tutorial covers the following steps:
1. **Setup**: Importing the required libraries and setting your global variables.
2. **Configure parameters**: Setting the appropriate parameter values for the training job.
3. **Train on Vertex Training**: Submitting a training job.
4. **Deploy on Vertex Prediction**: Importing and deploying the trained model to a callable endpoint.
5. **Predict**: Calling the deployed endpoint using online or batch prediction.
6. **Hyperparameter tuning**: Running a hyperparameter tuning job.
7. **Cleaning up**: Deleting resources created by this tutorial.


### Costs 


This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage


Learn about [Vertex AI
pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage
pricing](https://cloud.google.com/storage/pricing), and use the [Pricing
Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

### Set up your local development environment

**If you are using Colab or Google Cloud Notebooks**, your environment already meets
all the requirements to run this notebook. You can skip this step.

**Otherwise**, make sure your environment meets this notebook's requirements.
You need the following:

* The Google Cloud SDK
* Git
* Python 3
* virtualenv
* Jupyter notebook running in a virtual environment with Python 3

The Google Cloud guide to [Setting up a Python development
environment](https://cloud.google.com/python/setup) and the [Jupyter
installation guide](https://jupyter.org/install) provide detailed instructions
for meeting these requirements. The following steps provide a condensed set of
instructions:

1. [Install and initialize the Cloud SDK.](https://cloud.google.com/sdk/docs/)

1. [Install Python 3.](https://cloud.google.com/python/setup#installing_python)

1. [Install
   virtualenv](https://cloud.google.com/python/setup#installing_and_using_virtualenv)
   and create a virtual environment that uses Python 3. Activate the virtual environment.

1. To install Jupyter, run `pip3 install jupyter` on the
command-line in a terminal shell.

1. To launch Jupyter, run `jupyter notebook` on the command-line in a terminal shell.

1. Open this notebook in the Jupyter Notebook Dashboard.

### Install additional packages


In [None]:
import os

# The Google Cloud Notebook product has specific requirements
IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists("/opt/deeplearning/metadata/env_version")

# Google Cloud Notebook requires dependencies to be installed with '--user'
USER_FLAG = ""
if IS_GOOGLE_CLOUD_NOTEBOOK:
    USER_FLAG = "--user"

In [None]:
! pip3 install {USER_FLAG} --upgrade tensorflow
! pip3 install {USER_FLAG} --upgrade google-cloud-aiplatform tensorboard-plugin-profile
! gcloud components update --quiet

### Restart the kernel

After you install the additional packages, you need to restart the notebook kernel so it can find the packages.

In [None]:
import os

if not os.getenv("IS_TESTING"):
    # Automatically restart the kernel after installs
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

1. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

1. If you are running this notebook locally, you will need to install the [Cloud SDK](https://cloud.google.com/sdk).

1. Enter your project ID in the cell below. Then run the cell to make sure the
Cloud SDK uses the right project for all the commands in this notebook.

**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` (or wrapped in `{}`) into these commands.

#### Set your project ID

**If you do not know your project ID**, you may be able to get your project ID using `gcloud`.

In [1]:
import os

PROJECT_ID = 'hybrid-vertex'

# Get your Google Cloud project ID from gcloud
if not os.getenv("IS_TESTING"):
    shell_output = ! gcloud config list --format 'value(core.project)' 2>/dev/null
    PROJECT_ID = shell_output[0]
    print("Project ID: ", PROJECT_ID)

Project ID:  hybrid-vertex


Otherwise, set your project ID here.

In [2]:
if PROJECT_ID == "" or PROJECT_ID is None:
    PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

! gcloud config set project {PROJECT_ID}

Updated property [core/project].


#### Timestamp

If you are in a live tutorial session, you might be using a shared test account or project. To avoid name collisions between resources that different users create, generate a timestamp for each session and append it to the names of the resources you create in this tutorial.

In [4]:
from datetime import datetime

TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
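
Resource names in this notebook, such as the `experiment-{TIMESTAMP}` output directory, follow this pattern. The hypothetical helper below makes the naming explicit and reproducible by accepting a fixed time:

```python
from datetime import datetime


def timestamped_name(prefix, now=None):
    """Append a %Y%m%d%H%M%S timestamp (the TIMESTAMP format above) to a prefix."""
    now = now or datetime.now()
    return f"{prefix}-{now.strftime('%Y%m%d%H%M%S')}"


# A fixed time makes the output reproducible: experiment-20220629210427
print(timestamped_name("experiment", datetime(2022, 6, 29, 21, 4, 27)))
```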

### Authenticate your Google Cloud account

**If you are using Google Cloud Notebooks**, your environment is already
authenticated. Skip this step.

**If you are using Colab**, run the cell below and follow the instructions
when prompted to authenticate your account through OAuth.

**Otherwise**, follow these steps:

1. In the Cloud Console, go to the [**Create service account key**
   page](https://console.cloud.google.com/apis/credentials/serviceaccountkey).

2. Click **Create service account**.

3. In the **Service account name** field, enter a name, and
   click **Create**.

4. In the **Grant this service account access to project** section, click the **Role** drop-down list. Type "Vertex AI"
   into the filter box, and select **Vertex AI Administrator**. Type "Storage Object Admin" into the filter box, and select **Storage Object Admin**.

5. Click **Create**. A JSON file that contains your key downloads to your
   local environment.

6. Enter the path to your service account key as the
`GOOGLE_APPLICATION_CREDENTIALS` variable in the cell below and run the cell.

In [None]:
import os
import sys

# If you are running this notebook in Colab, run this cell and follow the
# instructions to authenticate your GCP account. This provides access to your
# Cloud Storage bucket and lets you submit training jobs and prediction
# requests.

# The Google Cloud Notebook product has specific requirements
IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists("/opt/deeplearning/metadata/env_version")

# If on Google Cloud Notebooks, then don't execute this code
if not IS_GOOGLE_CLOUD_NOTEBOOK:
    if "google.colab" in sys.modules:
        from google.colab import auth as google_auth

        google_auth.authenticate_user()

    # If you are running this notebook locally, replace the string below with the
    # path to your service account key and run this cell to authenticate your GCP
    # account.
    elif not os.getenv("IS_TESTING"):
        %env GOOGLE_APPLICATION_CREDENTIALS ''

### Create a Cloud Storage bucket

**The following steps are required, regardless of your notebook environment.**

Before you submit a training job for the two-tower model, you need to upload your training data and schema to Cloud Storage. Vertex AI trains the model using this input data. In this tutorial, the Two-Tower built-in algorithm also saves the trained model that results from your job in the same bucket. Using this model artifact, you can then create Vertex AI model and endpoint resources in order to serve online predictions.

Set the name of your Cloud Storage bucket below. It must be unique across all
Cloud Storage buckets.

You may also change the `REGION` variable, which is used for operations
throughout the rest of this notebook. Make sure to [choose a region where Vertex AI services are
available](https://cloud.google.com/vertex-ai/docs/general/locations#available_regions). You may
not use a Multi-Regional Storage bucket for training with Vertex AI.
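
Because `gsutil mb` fails on malformed names, a quick local check can catch typos early. This is a simplified sketch, assuming a dot-free bucket name; it does not cover every Cloud Storage naming rule and cannot check global uniqueness. `check_bucket_uri` is a hypothetical helper.

```python
import re

# Simplified rules for bucket names without dots (an assumption of this sketch):
# 3-63 characters; lowercase letters, digits, dashes, and underscores only;
# must start and end with a letter or digit.
_BUCKET_NAME_RE = re.compile(r"[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]")


def check_bucket_uri(uri):
    """Return the bare bucket name from a gs:// URI, or raise ValueError."""
    if not uri.startswith("gs://"):
        raise ValueError(f"expected a gs:// URI, got {uri!r}")
    name = uri[len("gs://"):].rstrip("/")
    if not _BUCKET_NAME_RE.fullmatch(name):
        raise ValueError(f"{name!r} does not look like a valid bucket name")
    return name


print(check_bucket_uri("gs://spotify-builtin-2t"))  # spotify-builtin-2t
```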

In [8]:
BUCKET_NAME = "gs://spotify-builtin-2t"  # @param {type:"string"}
REGION = "us-central1"  # @param {type:"string"}

In [None]:
if BUCKET_NAME == "" or BUCKET_NAME is None or BUCKET_NAME == "gs://[your-bucket-name]":
    BUCKET_NAME = "gs://" + PROJECT_ID + "-aip-" + TIMESTAMP

**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket.

In [None]:
! gsutil mb -l $REGION $BUCKET_NAME

Finally, validate access to your Cloud Storage bucket by examining its contents:

In [None]:
! gsutil ls -al $BUCKET_NAME

### Import libraries and define constants

In [5]:
import os
import re
import time

from google.cloud import aiplatform

%load_ext tensorboard

## Configure parameters

The following table shows parameters that are common to all Vertex Training jobs created using the `gcloud ai custom-jobs create` command. See the [official documentation](https://cloud.google.com/sdk/gcloud/reference/ai/custom-jobs/create) for all the possible arguments.

| Parameter | Data type | Description | Required |
|--|--|--|--|
| `display-name` | string | Name of the job. | Yes |
| `worker-pool-spec` | string | Comma-separated list of arguments specifying a worker pool configuration (see below). | Yes |
| `region` | string | Region to submit the job to. | No |

The `worker-pool-spec` flag can be specified multiple times, one for each worker pool. The following table shows the arguments used to specify a worker pool.

| Parameter | Data type | Description | Required |
|--|--|--|--|
| `machine-type` | string | Machine type for the pool. See the [official documentation](https://cloud.google.com/vertex-ai/docs/training/configure-compute) for supported machines. | Yes |
| `replica-count` | int | The number of replicas of the machine in the pool. | No |
| `container-image-uri` | string | Docker image to run on each worker. | No |
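
Each `--worker-pool-spec` value is a comma-separated list of `key=value` pairs from the table above. A hypothetical helper that assembles one such value, limited to the three keys listed here:

```python
def worker_pool_spec(machine_type, replica_count=None, container_image_uri=None):
    """Build the value of one --worker-pool-spec flag as comma-separated key=value pairs."""
    parts = [f"machine-type={machine_type}"]
    if replica_count is not None:
        parts.append(f"replica-count={replica_count}")
    if container_image_uri is not None:
        parts.append(f"container-image-uri={container_image_uri}")
    return ",".join(parts)


spec = worker_pool_spec(
    "n1-standard-8",
    replica_count=1,
    container_image_uri="us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/two-tower",
)
print(f"--worker-pool-spec={spec}")
```

The printed flag matches the one used in the CPU training command later in this notebook.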

The following table shows the parameters for the two-tower model training job:

| Parameter | Data type | Description | Required |
|--|--|--|--|
| `training_data_path` | string | Cloud Storage pattern where training data is stored. | Yes |
| `input_schema_path` | string | Cloud Storage path where the JSON input schema is stored. | Yes |
| `input_file_format` | string | The file format of the input. Currently supports `jsonl` and `tfrecord`. | No - default is `jsonl`. |
| `job_dir` | string | Cloud Storage directory where the model output files are stored. | Yes |
| `eval_data_path` | string | Cloud Storage pattern where evaluation data is stored. | No |
| `candidate_data_path` | string | Cloud Storage pattern where candidate data is stored. Only used for the `top_k_categorical_accuracy` metric. If not set, it's generated from the training and evaluation data. | No |
| `train_batch_size` | int | Batch size for training. | No - default is 100. |
| `eval_batch_size` | int | Batch size for evaluation. | No - default is 100. |
| `eval_split` | float | Fraction of the data to use for the evaluation dataset, if `eval_data_path` is not provided. | No - default is 0.2. |
| `optimizer` | string | Training optimizer. The lowercase name of any TF 2.3 Keras optimizer is supported ('sgd', 'nadam', 'ftrl', etc.). See the [TensorFlow documentation](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers). | No - default is 'adagrad'. |
| `learning_rate` | float | Learning rate for training. | No - default is the default learning rate of the specified optimizer. |
| `momentum` | float | Momentum for the optimizer, if specified. | No - default is the default momentum value of the specified optimizer. |
| `metrics` | string | Metric used to evaluate the model: `auc`, `top_k_categorical_accuracy`, or `precision_at_1`. | No - default is `auc`. |
| `num_epochs` | int | Number of epochs for training. | No - default is 10. |
| `num_hidden_layers` | int | Number of hidden layers. | No |
| `num_nodes_hidden_layer{index}` | int | Number of nodes in hidden layer `{index}`, where `{index}` ranges from 1 to 20. | No |
| `output_dim` | int | Output embedding dimension of each encoder tower of the two-tower model. | No - default is 64. |
| `training_steps_per_epoch` | int | Number of training steps per epoch. Only needed if you use more than one machine, or a master machine with more than one GPU. | No - default is None. |
| `eval_steps_per_epoch` | int | Number of evaluation steps per epoch. Only needed if you use more than one machine, or a master machine with more than one GPU. | No - default is None. |
| `gpu_memory_alloc` | int | Amount of memory allocated per GPU (in MB). | No - default is no limit. |
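
The training parameters above reach the container as repeated `--args=--name=value` flags, as in the training command below. The sketch that follows, built on hypothetical helpers, shows that mapping and expands a list of hidden-layer sizes into `num_hidden_layers` plus `num_nodes_hidden_layer{index}`. It assumes layer indices start at 1 and follow the list order; the bucket paths are placeholders.

```python
def hidden_layer_params(layer_sizes):
    """Expand e.g. [128, 64] into num_hidden_layers and num_nodes_hidden_layer{index}."""
    if not 1 <= len(layer_sizes) <= 20:
        raise ValueError("num_nodes_hidden_layer{index} supports index 1 to 20")
    params = {"num_hidden_layers": len(layer_sizes)}
    for index, size in enumerate(layer_sizes, start=1):
        params[f"num_nodes_hidden_layer{index}"] = size
    return params


def to_gcloud_args(params):
    """Format each parameter as one --args=--name=value flag for gcloud."""
    return [f"--args=--{name}={value}" for name, value in params.items()]


params = {
    "training_data_path": "gs://my-bucket/train_data/*",  # hypothetical bucket
    "input_schema_path": "gs://my-bucket/schema.json",
    "optimizer": "adam",
    "learning_rate": 0.01,
    "num_epochs": 3,
    "output_dim": 64,
    **hidden_layer_params([128, 64]),
}
for flag in to_gcloud_args(params):
    print(flag)
```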

In [16]:

# Change these to your own data and schema paths. Here they point to training
# data and a JSON schema stored in BUCKET_NAME.
TRAINING_DATA_PATH = f"{BUCKET_NAME}/train_data/*"
INPUT_SCHEMA_PATH = f"{BUCKET_NAME}/schema.json"

# URI of the two-tower training Docker image.
LEARNER_IMAGE_URI = "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/two-tower"

# Change to your output location.
OUTPUT_DIR = f"{BUCKET_NAME}/experiment-{TIMESTAMP}/output"

TRAIN_BATCH_SIZE = 100  # Batch size for training.
NUM_EPOCHS = 3  # Number of epochs for training.

print(f"Training data path: {TRAINING_DATA_PATH}")
print(f"Input schema path: {INPUT_SCHEMA_PATH}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Train batch size: {TRAIN_BATCH_SIZE}")
print(f"Number of epochs: {NUM_EPOCHS}")

Training data path: gs://spotify-builtin-2t/train_data/*
Input schema path: gs://spotify-builtin-2t/schema.json
Output directory: gs://spotify-builtin-2t/experiment-20220629210427/output
Train batch size: 100
Number of epochs: 3


In [32]:
# Print the input schema stored in Cloud Storage.
output = ! gsutil cat {INPUT_SCHEMA_PATH}
print(output[0])

{"query": {"name": {"feature_type": "Text"}, "collaborative": {"feature_type": "Text"}, "duration_ms_playlist": {"feature_type": "Numeric"}, "artist_name_seed_track": {"feature_type": "Text"}, "artist_uri_seed_track": {"feature_type": "Id", "config": {"num_buckets": 1000}}, "track_name_seed_track": {"feature_type": "Text"}, "track_uri_seed_track": {"feature_type": "Id", "config": {"num_buckets": 10000}}, "album_name_seed_track": {"feature_type": "Text"}, "album_uri_seed_track": {"feature_type": "Id", "config": {"num_buckets": 1000}}, "duration_seed_track": {"feature_type": "Numeric"}, "track_pop_seed_track": {"feature_type": "Numeric"}, "artist_pop_seed_track": {"feature_type": "Numeric"}, "artist_followers_seed_track": {"feature_type": "Numeric"}, "duration_ms_seed_pl": {"feature_type": "Numeric"}, "n_songs_pl": {"feature_type": "Numeric"}, "num_artists_pl": {"feature_type": "Numeric"}, "num_albums_pl": {"feature_type": "Numeric"}, "artist_genres_seed_track": {"feature_type": "Text"},
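
The schema maps each tower (only the `query` tower is visible in the truncated output above) to its features. Each feature has a `feature_type` such as `Text`, `Numeric`, or `Id`, and optionally a `config` like `num_buckets`. The sketch below summarizes a hand-copied excerpt of that output; `summarize_schema` is hypothetical. In the notebook you could parse the full file with `json.loads("\n".join(output))` instead, assuming it is valid JSON.

```python
import json
from collections import defaultdict

# Excerpt of the query-tower schema printed above (four of its features).
schema_text = """
{"query": {
  "name": {"feature_type": "Text"},
  "duration_ms_playlist": {"feature_type": "Numeric"},
  "artist_uri_seed_track": {"feature_type": "Id", "config": {"num_buckets": 1000}},
  "track_uri_seed_track": {"feature_type": "Id", "config": {"num_buckets": 10000}}
}}
"""


def summarize_schema(schema):
    """Group each tower's features by feature_type, noting Id bucket counts."""
    summary = {}
    for tower, features in schema.items():
        by_type = defaultdict(list)
        for name, spec in features.items():
            buckets = spec.get("config", {}).get("num_buckets")
            label = f"{name} (num_buckets={buckets})" if buckets is not None else name
            by_type[spec["feature_type"]].append(label)
        summary[tower] = dict(by_type)
    return summary


for tower, groups in summarize_schema(json.loads(schema_text)).items():
    print(tower)
    for feature_type, names in groups.items():
        print(f"  {feature_type}: {', '.join(names)}")
```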

## Train on Vertex Training

Submit the two-tower training job to Vertex Training. The following command trains on a single CPU machine (`n1-standard-8`). For single-node training, you don't need to set `training_steps_per_epoch` or `eval_steps_per_epoch`.

In [11]:
DATASET_NAME = "spotify"
learning_job_name = f"two_tower_cpu_{DATASET_NAME}_{TIMESTAMP}"

CREATION_LOG = ! gcloud ai custom-jobs create \
  --display-name={learning_job_name} \
  --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri={LEARNER_IMAGE_URI} \
  --region={REGION} \
  --args=--training_data_path={TRAINING_DATA_PATH} \
  --args=--input_schema_path={INPUT_SCHEMA_PATH} \
  --args=--job-dir={OUTPUT_DIR} \
  --args=--train_batch_size={TRAIN_BATCH_SIZE} \
  --args=--num_epochs={NUM_EPOCHS}

print(CREATION_LOG)

['Using endpoint [https://us-central1-aiplatform.googleapis.com/]', 'CustomJob [projects/934903580331/locations/us-central1/customJobs/4890078239411666944] is submitted successfully.', '', 'Your job is still active. You may view the status of your job with the command', '', '  $ gcloud ai custom-jobs describe projects/934903580331/locations/us-central1/customJobs/4890078239411666944', '', 'or continue streaming the logs with the command', '', '  $ gcloud ai custom-jobs stream-logs projects/934903580331/locations/us-central1/customJobs/4890078239411666944']
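
The creation log contains the job's full resource name, which the `describe` and `stream-logs` commands in the output expect. A minimal regex-based sketch for pulling it out of `CREATION_LOG`, assuming gcloud keeps the `CustomJob [...]` line format shown above; `extract_custom_job_name` is hypothetical:

```python
import re

_JOB_RE = re.compile(r"CustomJob \[(projects/[^\]]+/customJobs/\d+)\]")


def extract_custom_job_name(log_lines):
    """Return the first CustomJob resource name found in gcloud output, or None."""
    for line in log_lines:
        match = _JOB_RE.search(line)
        if match:
            return match.group(1)
    return None


# Two lines copied from the output above; in the notebook, pass CREATION_LOG instead.
sample_log = [
    "Using endpoint [https://us-central1-aiplatform.googleapis.com/]",
    "CustomJob [projects/934903580331/locations/us-central1/customJobs/4890078239411666944] is submitted successfully.",
]
job_name = extract_custom_job_name(sample_log)
print(job_name)
print("Job ID:", job_name.rsplit("/", 1)[-1])
```

You could then pass `job_name` to `gcloud ai custom-jobs describe` or `gcloud ai custom-jobs stream-logs`.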


If you want to train with GPUs, write the job configuration to a YAML file:

In [None]:
      # - --training_steps_per_epoch=1500
      # - --eval_steps_per_epoch=1500
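
A GPU job configuration can be sketched as follows, assuming the `CustomJobSpec` YAML layout (`workerPoolSpecs`, `machineSpec`, `containerSpec`) that `gcloud ai custom-jobs create --config` accepts. The machine type, the `NVIDIA_TESLA_T4` accelerator, the bucket paths, and the `gpu_job_config` helper are illustrative assumptions; check GPU availability in your region.

```python
import json
import os
import tempfile


def gpu_job_config(image_uri, args, machine_type="n1-standard-8",
                   accelerator_type="NVIDIA_TESLA_T4", accelerator_count=1):
    """Render a single-worker-pool CustomJobSpec as YAML text."""
    lines = [
        "workerPoolSpecs:",
        "  - machineSpec:",
        f"      machineType: {machine_type}",
        f"      acceleratorType: {accelerator_type}",
        f"      acceleratorCount: {accelerator_count}",
        "    replicaCount: 1",
        "    containerSpec:",
        f"      imageUri: {image_uri}",
        "      args:",
    ]
    # json.dumps yields double-quoted strings, which are also valid YAML scalars.
    lines += [f"        - {json.dumps(arg)}" for arg in args]
    return "\n".join(lines) + "\n"


config = gpu_job_config(
    "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/two-tower",
    [
        "--training_data_path=gs://my-bucket/train_data/*",  # hypothetical paths
        "--input_schema_path=gs://my-bucket/schema.json",
        "--job-dir=gs://my-bucket/output",
        "--num_epochs=3",
    ],
)
config_path = os.path.join(tempfile.mkdtemp(), "two_tower_gpu.yaml")
with open(config_path, "w", encoding="utf-8") as f:
    f.write(config)
print(config)
```

You could then submit the file with `gcloud ai custom-jobs create --region={REGION} --display-name={learning_job_name} --config={config_path}`. With more than one GPU or machine, the parameter table says to also set `--training_steps_per_epoch` and `--eval_steps_per_epoch`.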

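### Example Record

The following JSON shows one training example from the Spotify dataset used in this notebook: a `query` (a playlist and its seed track) and a `candidate` (a track), where each feature maps to a list of values.

```json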
{
   "candidate":{
      "artist_name_can":[
         "The Head Assembly"
      ],
      "track_uri_can":[
         "spotify:track:5Q2LDZsjtscyce03gyG08Q"
      ],
      "track_name_can":[
         "Tickle My... - Julissa Veloz Club Mix"
      ],
      "duration_ms_can":[
         354573
      ],
      "album_name_can":[
         "Tickle My..."
      ],
      "track_pop_can":[
         "0"
      ],
      "artist_pop_can":[
         0
      ],
      "artist_followers_can":[
         0
      ],
      "artist_genres_can":[
         ""
      ]
   },
   "query":{
      "id_pl":[
         "986484-124"
      ],
      "name":[
         "New Music"
      ],
      "collaborative":[
         "false"
      ],
      "artist_name_seed_track":[
         "Jason Aldean"
      ],
      "track_name_seed_track":[
         "Burnin' It Down"
      ],
      "album_name_seed_track":[
         "Old Boots, New Dirt"
      ],
      "duration_seed_track":[
         219146
      ],
      "track_pop_seed_track":[
         "50"
      ],
      "artist_pop_seed_track":[
         78
      ],
      "artist_followers_seed_track":[
         5139127
      ],
      "duration_ms_seed_pl":[
         29578370
      ],
      "n_songs_pl":[
         "124"
      ],
      "num_artists_pl":[
         "85"
      ],
      "num_albums_pl":[
         "99"
      ],
      "artist_genres_seed_track":[
         "'contemporary country', 'country', 'country road'"
      ],
      "description_pl":[
         ""
      ],
      "artist_name_pl":[
         "Tove Lo",
         "Ellie Goulding",
         "Kanye West",
         "Becky G",
         "Coldplay",
         "MAGIC!",
         "OneRepublic",
         "Alex Clare",
         "Sampha",
         "Hackman",
         "Ellie Goulding",
         "OneRepublic",
         "Jennifer Lopez",
         "Jason Derulo",
         "Sam Smith",
         "Sia",
         "Sia",
         "Christina Perri",
         "Jhene Aiko",
         "Jhene Aiko",
         "Jason Derulo",
         "Frank Ocean",
         "Miley Cyrus",
         "Miley Cyrus",
         "Tito \"El Bambino\"",
         "Demi Lovato",
         "Selena Gomez \u0026 The Scene",
         "Usher",
         "Usher",
         "Maroon 5",
         "DEV",
         "Manufactured Superstars",
         "Tiësto",
         "Diplo",
         "Steve Aoki",
         "Steve Aoki",
         "Issues",
         "Chiodos",
         "Clean Bandit",
         "You Me At Six",
         "NERO",
         "NERO",
         "NERO",
         "NERO",
         "Farruko",
         "Juanes",
         "Prince Royce",
         "Paramore",
         "Kiesza",
         "Snootie Wild",
         "Hoobastank",
         "Zedd",
         "Future",
         "Kat Dahlia",
         "Naughty Boy",
         "Nico \u0026 Vinz",
         "Disclosure",
         "Ariana Grande",
         "Jessie Ware",
         "Miley Cyrus",
         "Amy Winehouse",
         "Gemini",
         "Vance Joy",
         "Ed Sheeran",
         "Frank Ocean",
         "Tonight Alive",
         "Ariana Grande",
         "Ty Dolla $ign",
         "The Weeknd",
         "The Weeknd",
         "The Weeknd",
         "The Weeknd",
         "Jeremih",
         "Breathe Carolina",
         "Sam Smith",
         "Rita Ora",
         "Muse",
         "The Used",
         "Austin Mahone",
         "Seven Lions",
         "Seven Lions",
         "Seven Lions",
         "Seven Lions",
         "Skrillex",
         "AJR",
         "Niykee Heaton",
         "Chris Brown",
         "Skrillex",
         "Tonight Alive",
         "Shelco Garcia \u0026 Teenwolf",
         "Gareth Emery",
         "Lights",
         "Charli XCX",
         "Pharrell Williams",
         "Trinidad James",
         "Tonight Alive",
         "Tonight Alive",
         "Milkman",
         "Maroon 5",
         "Iggy Azalea",
         "The Chainsmokers",
         "Diplo",
         "Diplo",
         "Yellow Claw",
         "Steve Aoki",
         "Diplo",
         "Diplo",
         "Diplo",
         "Diplo",
         "Diplo",
         "Steve Aoki",
         "Diplo",
         "Diplo",
         "K CAMP",
         "K CAMP",
         "Jessie J",
         "Drake",
         "PARTYNEXTDOOR",
         "Tiësto",
         "Beyoncé",
         "John Legend",
         "Nicki Minaj",
         "Echosmith",
         "Jason Aldean"
      ],
      "track_name_pl":[
         "Habits (Stay High) - Hippie Sabotage Remix",
         "Beating Heart",
         "Lost In The World",
         "Shower",
         "Magic",
         "Don't Kill the Magic",
         "Something I Need",
         "Relax My Beloved",
         "Too Much",
         "Forgotten Notes",
         "Tessellate - Bonus Track",
         "Love Runs Out",
         "First Love",
         "Stupid Love",
         "Money On My Mind",
         "Chandelier",
         "Elastic Heart - From \"The Hunger Games: Catching Fire\"/Soundtrack",
         "human",
         "To Love \u0026 Die",
         "The Worst",
         "Trumpets",
         "Novacane",
         "Someone Else",
         "Drive",
         "El Gran Perdedor",
         "Really Don't Care",
         "I Won't Apologize",
         "Trading Places",
         "Good Kisser",
         "Maps",
         "Naked",
         "Like Satellites - Extended Mix",
         "Wasted",
         "Express Yourself - feat. Nicky Da B",
         "Rage the Night Away",
         "Freak - feat. Steve Bays",
         "Never Lose Your Flames",
         "Hey Zeus! the Dungeon",
         "Rather Be (feat. Jess Glynne)",
         "Stay With Me",
         "Me And You",
         "Guilt",
         "My Eyes",
         "Satisfy",
         "Passion Whine - Remastered Version",
         "La Luz",
         "Soy el Mismo",
         "Ain't It Fun",
         "Hideaway",
         "Yayo",
         "The Letter",
         "Find You",
         "I Won",
         "Crazy",
         "La La La",
         "Am I Wrong",
         "Latch",
         "Problem",
         "Running",
         "Maybe You're Right",
         "Back To Black",
         "Blue",
         "Riptide",
         "Don't",
         "Lost",
         "Little Lion Man",
         "Break Free",
         "Or Nah (feat. The Weeknd, Wiz Khalifa and DJ Mustard) - Remix",
         "Adaptation",
         "Pretty",
         "Belong To The World",
         "Wicked Games",
         "Don't Tell 'Em",
         "Blackout",
         "Stay With Me",
         "I Will Never Let You Down",
         "Madness",
         "The Bird And The Worm",
         "Mmm Yeah (feat. Pitbull)",
         "Don’t Leave",
         "Worlds Apart",
         "Strangers",
         "Keep It Close",
         "Recess",
         "I'm Ready",
         "Bad Intentions",
         "New Flame",
         "Dirty Vibe",
         "The Edge (From the Motion Picture \"The Amazing Spider-Man 2\")",
         "That's My Jam - Original Mix",
         "Concrete Angel",
         "Up We Go",
         "Boom Clap",
         "Come Get It Bae",
         "Female$ Welcomed",
         "Lonely Girl",
         "Say Please",
         "Somebody Find Me (feat. Kait Weston)",
         "It Was Always You",
         "Black Widow",
         "Kanye",
         "Revolution - feat. Faustix \u0026 Imanos and Kai",
         "6th Gear - feat. Kstylis",
         "Techno - feat. Wacka Flocka Flame",
         "Freak - feat. Steve Bays",
         "Boy Oh Boy",
         "Biggie Bounce - feat. Angger Dimas \u0026 Travis Porte",
         "Express Yourself - feat. Nicky Da B",
         "Revolution (Danny Diggz Remix) - feat. Faustix \u0026 Imanos and Kai",
         "Boy Oh Boy - Thugli Remix",
         "Freak (Rickyxsan Remix) - feat. Steve Bays",
         "Biggie Bounce (Tony Romera Remix) - feat. Angger Dimas \u0026 Travis Porter",
         "Express Yourself (Party Favor Extended Remix) - feat. Nicky Da B",
         "Money Baby",
         "Cut Her Off",
         "Bang Bang",
         "From Time",
         "Wus Good / Curious",
         "Who Wants To Be Alone",
         "XO",
         "You \u0026 I (Nobody in the World)",
         "Anaconda",
         "Cool Kids",
         "Burnin' It Down"
      ],
      "duration_ms_songs_pl":[
         258933,
         212125,
         256586,
         206166,
         285014,
         217226,
         240080,
         211986,
         177795,
         268511,
         236960,
         224853,
         215786,
         214493,
         192670,
         216120,
         257986,
         250706,
         203453,
         254493,
         217306,
         302346,
         288333,
         255213,
         182896,
         201600,
         186546,
         268240,
         249626,
         189959,
         236933,
         340157,
         190013,
         277566,
         285186,
         281250,
         221279,
         264160,
         227833,
         194973,
         247520,
         284120,
         289666,
         284516,
         213200,
         176986,
         223053,
         296520,
         251986,
         220240,
         234853,
         204346,
         239733,
         211320,
         222200,
         247520,
         255631,
         193893,
         268746,
         213386,
         240440,
         316266,
         204280,
         219840,
         234093,
         233146,
         214840,
         242983,
         283933,
         375400,
         307173,
         323746,
         266840,
         210200,
         172723,
         203466,
         281040,
         225040,
         231624,
         362040,
         376242,
         204240,
         312506,
         237680,
         227311,
         198167,
         244133,
         206680,
         180600,
         270000,
         236733,
         171293,
         169866,
         201933,
         190080,
         191133,
         194013,
         225584,
         239919,
         209423,
         229946,
         263716,
         211900,
         208012,
         281250,
         174398,
         224769,
         277566,
         193000,
         197635,
         311999,
         270000,
         195728,
         219120,
         243693,
         199320,
         322160,
         212880,
         276226,
         215946,
         252653,
         260240,
         237626,
         219146
      ],
      "album_name_pl":[
         "Queen Of The Clouds",
         "Halcyon Days",
         "My Beautiful Dark Twisted Fantasy",
         "Shower",
         "Ghost Stories",
         "Don't Kill the Magic",
         "Native",
         "The Lateness Of The Hour",
         "Too Much / Happens",
         "Forgotten Notes",
         "Halcyon Days",
         "Native",
         "A.K.A.",
         "Tattoos",
         "In The Lonely Hour",
         "1000 Forms Of Fear",
         "Elastic Heart",
         "head or heart",
         "Souled Out",
         "Sail Out",
         "Tattoos",
         "Novacane",
         "Bangerz (Deluxe Version)",
         "Bangerz (Deluxe Version)",
         "El Gran Perdedor",
         "Demi",
         "Kiss \u0026 Tell",
         "Here I Stand (Deluxe Version)",
         "Hard II Love",
         "V",
         "The Night The Sun Came Up",
         "Like Satellites [Remixes]",
         "A Town Called Paradise",
         "Random White Dude Be Everywhere",
         "Neon Future I",
         "Random White Dude Be Everywhere",
         "Issues",
         "Illuminaudio",
         "I Cry When I Laugh",
         "Hold Me Down",
         "Welcome Reality",
         "Welcome Reality",
         "Welcome Reality",
         "Satisfy",
         "Farruko Presents Los Menores",
         "Loco De Amor",
         "Soy el Mismo",
         "Paramore",
         "Sound Of A Woman",
         "Yayo",
         "FOR(N)EVER",
         "Clarity",
         "Honest",
         "My Garden",
         "Hotel Cabana",
         "Black Star Elephant",
         "Settle",
         "My Everything",
         "Devotion",
         "Bangerz (Deluxe Version)",
         "Back To Black",
         "Blue",
         "Dream Your Life Away",
         "x",
         "channel ORANGE",
         "Punk Goes Pop, Vol. 4",
         "My Everything",
         "Or Nah (feat. The Weeknd, Wiz Khalifa and DJ Mustard)",
         "Kiss Land",
         "Kiss Land",
         "Kiss Land",
         "Trilogy",
         "Late Nights: The Album",
         "Hell Is What You Make It: Reloaded",
         "In The Lonely Hour",
         "I Will Never Let You Down",
         "The 2nd Law",
         "Lies For The Liars",
         "The Secret",
         "Worlds Apart",
         "Worlds Apart",
         "Strangers",
         "Worlds Apart",
         "Recess",
         "Living Room",
         "Bad Intentions",
         "X (Deluxe Version)",
         "Recess",
         "The Edge (From the Motion Picture \"The Amazing Spider-Man 2\")",
         "That's My Jam",
         "Concrete Angel",
         "Little Machines",
         "SUCKER",
         "G I R L",
         "Don't Be S.A.F.E.",
         "The Other Side",
         "The Other Side",
         "Reboot",
         "V",
         "The New Classic",
         "Kanye",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Random White Dude Be Everywhere",
         "Money Baby",
         "Cut Her Off",
         "My Everything",
         "Nothing Was The Same",
         "PARTYNEXTDOOR",
         "The Best of Nelly Furtado",
         "BEYONCÉ [Platinum Edition]",
         "Love In The Future",
         "The Pinkprint",
         "Talking Dreams",
         "Old Boots, New Dirt"
      ],
      "artist_pop_pl":[
         79,
         81,
         96,
         86,
         92,
         70,
         86,
         59,
         64,
         32,
         81,
         86,
         81,
         85,
         86,
         88,
         88,
         75,
         82,
         82,
         85,
         87,
         86,
         86,
         71,
         84,
         72,
         83,
         83,
         89,
         66,
         25,
         87,
         84,
         78,
         78,
         55,
         52,
         80,
         65,
         60,
         60,
         60,
         60,
         90,
         78,
         80,
         80,
         60,
         43,
         69,
         80,
         92,
         59,
         65,
         67,
         77,
         93,
         66,
         86,
         78,
         39,
         79,
         95,
         87,
         55,
         93,
         87,
         97,
         97,
         97,
         97,
         80,
         61,
         86,
         79,
         79,
         66,
         65,
         65,
         65,
         65,
         65,
         81,
         78,
         60,
         90,
         81,
         55,
         35,
         63,
         63,
         82,
         83,
         62,
         55,
         55,
         40,
         89,
         75,
         85,
         84,
         84,
         69,
         78,
         84,
         84,
         84,
         84,
         84,
         78,
         84,
         84,
         68,
         68,
         76,
         98,
         80,
         87,
         87,
         81,
         90,
         63,
         78
      ],
      "artists_followers_pl":[
         2979355,
         10616880,
         16780345,
         11124521,
         34753939,
         1252851,
         13351301,
         642317,
         545398,
         9565,
         10616880,
         13351301,
         11559186,
         10917261,
         19080875,
         22715182,
         22715182,
         3243327,
         5334113,
         5334113,
         10917261,
         9142911,
         17680286,
         17680286,
         2987540,
         21918370,
         7426513,
         9537420,
         9537420,
         35660772,
         255226,
         4430,
         6259946,
         2457277,
         3638728,
         3638728,
         320870,
         263185,
         4717503,
         687147,
         814804,
         814804,
         814804,
         814804,
         12543291,
         3704387,
         6825828,
         6781282,
         283160,
         186274,
         2040593,
         5710363,
         11270410,
         284539,
         317269,
         484909,
         2006723,
         77395215,
         650384,
         17680286,
         7770093,
         48393,
         2571148,
         94437255,
         9142911,
         316592,
         77395215,
         4168033,
         42984888,
         42984888,
         42984888,
         42984888,
         5611842,
         339954,
         19080875,
         7312799,
         6919697,
         985769,
         1943504,
         419066,
         419066,
         419066,
         419066,
         7984163,
         2151505,
         455793,
         15618321,
         7984163,
         316592,
         3706,
         281099,
         278431,
         2589435,
         3864604,
         280347,
         316592,
         316592,
         16966,
         35660772,
         5928632,
         18906144,
         2457277,
         2457277,
         994597,
         3638728,
         2457277,
         2457277,
         2457277,
         2457277,
         2457277,
         3638728,
         2457277,
         2457277,
         1434353,
         1434353,
         9737642,
         62021773,
         3688081,
         6259946,
         30713126,
         6131889,
         24503334,
         576379,
         5139127
      ],
      "track_pop_pl":[
         "69",
         "55",
         "69",
         "82",
         "76",
         "46",
         "0",
         "35",
         "0",
         "0",
         "44",
         "68",
         "0",
         "51",
         "0",
         "82",
         "0",
         "67",
         "0",
         "75",
         "71",
         "79",
         "58",
         "55",
         "0",
         "74",
         "44",
         "0",
         "0",
         "86",
         "46",
         "0",
         "0",
         "19",
         "49",
         "28",
         "47",
         "31",
         "54",
         "51",
         "0",
         "0",
         "0",
         "0",
         "0",
         "56",
         "60",
         "77",
         "69",
         "28",
         "0",
         "58",
         "64",
         "49",
         "75",
         "81",
         "1",
         "0",
         "32",
         "51",
         "0",
         "47",
         "85",
         "54",
         "82",
         "0",
         "0",
         "82",
         "57",
         "59",
         "55",
         "76",
         "79",
         "60",
         "0",
         "70",
         "72",
         "64",
         "71",
         "44",
         "25",
         "0",
         "34",
         "60",
         "66",
         "55",
         "0",
         "55",
         "48",
         "26",
         "50",
         "45",
         "54",
         "0",
         "45",
         "52",
         "31",
         "27",
         "61",
         "70",
         "62",
         "20",
         "2",
         "4",
         "28",
         "41",
         "1",
         "19",
         "1",
         "1",
         "0",
         "0",
         "0",
         "0",
         "0",
         "0",
         "0",
         "70",
         "0",
         "67",
         "0",
         "71",
         "75",
         "50"
      ],
      "artist_genres_pl":[
         "'art pop', 'dance pop', 'electropop', 'metropopolis', 'pop', 'swedish electropop', 'swedish pop', 'swedish synthpop'",
         "'dance pop', 'edm', 'electropop', 'indietronica', 'metropopolis', 'pop', 'uk pop'",
         "'chicago rap', 'rap'",
         "'dance pop', 'latin', 'latin pop', 'latin viral pop', 'pop', 'rap latina', 'reggaeton', 'trap latino'",
         "'permanent wave', 'pop'",
         "'reggae fusion'",
         "'dance pop', 'piano rock', 'pop', 'pop rock'",
         "'modern alternative rock'",
         "'alternative r\u0026b', 'indie soul'",
         "'future garage', 'uk bass'",
         "'dance pop', 'edm', 'electropop', 'indietronica', 'metropopolis', 'pop', 'uk pop'",
         "'dance pop', 'piano rock', 'pop', 'pop rock'",
         "'dance pop', 'pop', 'pop rap', 'urban contemporary'",
         "'dance pop', 'pop', 'pop rap', 'post-teen pop'",
         "'dance pop', 'pop', 'uk pop'",
         "'australian dance', 'australian pop', 'pop'",
         "'australian dance', 'australian pop', 'pop'",
         "'dance pop', 'neo mellow', 'pop', 'pop rock', 'post-teen pop', 'viral pop'",
         "'alternative r\u0026b', 'dance pop', 'pop', 'r\u0026b', 'urban contemporary'",
         "'alternative r\u0026b', 'dance pop', 'pop', 'r\u0026b', 'urban contemporary'",
         "'dance pop', 'pop', 'pop rap', 'post-teen pop'",
         "'alternative r\u0026b', 'hip hop', 'lgbtq+ hip hop', 'neo soul', 'pop'",
         "'dance pop', 'pop', 'post-teen pop'",
         "'dance pop', 'pop', 'post-teen pop'",
         "'latin', 'latin hip hop', 'reggaeton', 'trap latino'",
         "'dance pop', 'pop', 'post-teen pop'",
         "'dance pop', 'electropop', 'pop', 'post-teen pop', 'viral pop'",
         "'atl hip hop', 'contemporary r\u0026b', 'dance pop', 'pop', 'r\u0026b', 'south carolina hip hop', 'urban contemporary'",
         "'atl hip hop', 'contemporary r\u0026b', 'dance pop', 'pop', 'r\u0026b', 'south carolina hip hop', 'urban contemporary'",
         "'pop', 'pop rock'",
         "'dance pop', 'electropop'",
         "",
         "'big room', 'brostep', 'dance pop', 'dutch edm', 'edm', 'house', 'pop', 'pop dance', 'slap house', 'trance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'pop dance', 'pop rap', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'pop dance', 'pop rap', 'tropical house'",
         "'metalcore', 'nu-metalcore', 'trancecore'",
         "'metalcore', 'pop punk', 'post-hardcore', 'progressive post-hardcore', 'screamo'",
         "'dance pop', 'edm', 'pop', 'pop dance', 'tropical house', 'uk dance', 'uk funky'",
         "'modern alternative rock', 'modern rock', 'neon pop punk', 'pop emo', 'pop punk', 'rock'",
         "'drum and bass', 'edm', 'electro house', 'melodic dubstep'",
         "'drum and bass', 'edm', 'electro house', 'melodic dubstep'",
         "'drum and bass', 'edm', 'electro house', 'melodic dubstep'",
         "'drum and bass', 'edm', 'electro house', 'melodic dubstep'",
         "'latin', 'latin hip hop', 'reggaeton', 'trap latino'",
         "'colombian pop', 'latin', 'latin pop', 'mexican pop', 'rock en espanol', 'tropical'",
         "'bachata', 'latin', 'latin hip hop', 'latin pop'",
         "'candy pop', 'pixie', 'pop emo', 'pop punk'",
         "'electropop', 'pop edm'",
         "'memphis hip hop', 'trap'",
         "'alternative metal', 'funk metal', 'nu metal', 'pop rock', 'post-grunge', 'rap rock', 'rock'",
         "'complextro', 'dance pop', 'edm', 'electro house', 'electropop', 'german techno', 'pop', 'pop dance', 'pop rap', 'tropical house'",
         "'atl hip hop', 'hip hop', 'pop rap', 'rap', 'southern hip hop', 'trap'",
         "'hip pop', 'miami hip hop'",
         "'tropical house', 'uk contemporary r\u0026b'",
         "'pop rap'",
         "'edm', 'house', 'pop', 'uk dance'",
         "'dance pop', 'pop'",
         "'art pop', 'british soul', 'dance pop', 'electropop', 'neo soul', 'pop', 'pop soul', 'tropical house'",
         "'dance pop', 'pop', 'post-teen pop'",
         "'british soul', 'indie r\u0026b', 'neo soul'",
         "'uk dance'",
         "'folk-pop', 'modern rock', 'pop', 'pop rock'",
         "'pop', 'uk pop'",
         "'alternative r\u0026b', 'hip hop', 'lgbtq+ hip hop', 'neo soul', 'pop'",
         "'candy pop', 'neon pop punk', 'pixie', 'pop emo', 'pop punk'",
         "'dance pop', 'pop'",
         "'hip hop', 'pop', 'pop rap', 'r\u0026b', 'southern hip hop', 'trap', 'trap soul'",
         "'canadian contemporary r\u0026b', 'canadian pop', 'pop'",
         "'canadian contemporary r\u0026b', 'canadian pop', 'pop'",
         "'canadian contemporary r\u0026b', 'canadian pop', 'pop'",
         "'canadian contemporary r\u0026b', 'canadian pop', 'pop'",
         "'chicago rap', 'dance pop', 'pop', 'pop rap', 'r\u0026b', 'southern hip hop', 'trap', 'urban contemporary'",
         "'electropowerpop', 'neon pop punk', 'pop punk', 'screamo', 'trancecore'",
         "'dance pop', 'pop', 'uk pop'",
         "'dance pop', 'edm', 'electropop', 'pop', 'pop rap', 'post-teen pop', 'tropical house', 'uk pop'",
         "'modern rock', 'permanent wave', 'rock'",
         "'post-hardcore', 'rock', 'screamo'",
         "'dance pop', 'pop', 'post-teen pop', 'viral pop'",
         "'dubstep', 'edm', 'electro house', 'future bass', 'melodic dubstep', 'pop dance', 'pop edm', 'progressive trance'",
         "'dubstep', 'edm', 'electro house', 'future bass', 'melodic dubstep', 'pop dance', 'pop edm', 'progressive trance'",
         "'dubstep', 'edm', 'electro house', 'future bass', 'melodic dubstep', 'pop dance', 'pop edm', 'progressive trance'",
         "'dubstep', 'edm', 'electro house', 'future bass', 'melodic dubstep', 'pop dance', 'pop edm', 'progressive trance'",
         "'brostep', 'complextro', 'edm', 'electro', 'pop rap'",
         "'modern rock'",
         "'alt z', 'pop'",
         "'dance pop', 'pop', 'pop rap', 'r\u0026b', 'rap'",
         "'brostep', 'complextro', 'edm', 'electro', 'pop rap'",
         "'candy pop', 'neon pop punk', 'pixie', 'pop emo', 'pop punk'",
         "",
         "'edm', 'pop dance', 'progressive electro house', 'progressive house', 'progressive trance', 'trance', 'uplifting trance'",
         "'canadian folk', 'canadian pop', 'electropop', 'indie poptimism'",
         "'art pop', 'candy pop', 'dance pop', 'electropop', 'metropopolis', 'pop', 'uk pop'",
         "'pop rap'",
         "'atl hip hop', 'rap', 'southern hip hop', 'trap'",
         "'candy pop', 'neon pop punk', 'pixie', 'pop emo', 'pop punk'",
         "'candy pop', 'neon pop punk', 'pixie', 'pop emo', 'pop punk'",
         "",
         "'pop', 'pop rock'",
         "'australian hip hop', 'dance pop', 'hip pop', 'pop', 'pop rap', 'post-teen pop'",
         "'dance pop', 'edm', 'electropop', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'bass trap', 'edm', 'electro house', 'electronic trap', 'pop dance'",
         "'dance pop', 'edm', 'electro house', 'pop dance', 'pop rap', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'pop dance', 'pop rap', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'dance pop', 'edm', 'electro house', 'house', 'moombahton', 'ninja', 'pop', 'pop dance', 'tropical house'",
         "'atl hip hop', 'hip hop', 'melodic rap', 'pop r\u0026b', 'pop rap', 'r\u0026b', 'rap', 'southern hip hop', 'trap'",
         "'atl hip hop', 'hip hop', 'melodic rap', 'pop r\u0026b', 'pop rap', 'r\u0026b', 'rap', 'southern hip hop', 'trap'",
         "'dance pop', 'pop', 'pop rap', 'post-teen pop'",
         "'canadian hip hop', 'canadian pop', 'hip hop', 'rap', 'toronto rap'",
         "'pop', 'pop rap', 'r\u0026b', 'rap', 'toronto rap', 'trap', 'urban contemporary'",
         "'big room', 'brostep', 'dance pop', 'dutch edm', 'edm', 'house', 'pop', 'pop dance', 'slap house', 'trance', 'tropical house'",
         "'dance pop', 'pop', 'r\u0026b'",
         "'neo soul', 'pop', 'pop soul', 'r\u0026b', 'urban contemporary'",
         "'dance pop', 'hip pop', 'pop', 'pop rap', 'queens hip hop'",
         "'dance pop', 'indie poptimism', 'pop rock'",
         "'contemporary country', 'country', 'country road'"
      ]
   }
}
```

In [35]:
from datetime import datetime

TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

learning_job_name = f"two_tower_gpu_{DATASET_NAME}_{TIMESTAMP}"

config = f"""workerPoolSpecs:
  -
    machineSpec:
      machineType: n1-highmem-8
      acceleratorType: NVIDIA_TESLA_K80
      acceleratorCount: 1
    replicaCount: 1
    containerSpec:
      imageUri: {LEARNER_IMAGE_URI}
      args:
      - --training_data_path={TRAINING_DATA_PATH}
      - --input_schema_path={INPUT_SCHEMA_PATH}
      - --job-dir={OUTPUT_DIR}
"""

!echo $'{config}' > ./config.yaml

CREATION_LOG = ! gcloud ai custom-jobs create \
  --display-name={learning_job_name} \
  --region={REGION} \
  --config=config.yaml

print(CREATION_LOG)

['Using endpoint [https://us-central1-aiplatform.googleapis.com/]', 'CustomJob [projects/934903580331/locations/us-central1/customJobs/624465704850030592] is submitted successfully.', '', 'Your job is still active. You may view the status of your job with the command', '', '  $ gcloud ai custom-jobs describe projects/934903580331/locations/us-central1/customJobs/624465704850030592', '', 'or continue streaming the logs with the command', '', '  $ gcloud ai custom-jobs stream-logs projects/934903580331/locations/us-central1/customJobs/624465704850030592']
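The cell above builds the YAML with an f-string and writes it through `!echo $'…'`, which depends on shell quoting. A minimal alternative sketch, assuming you prefer to stay in Python: build the same `workerPoolSpecs` structure as a dict and write it as JSON. Since JSON is, for practical purposes, valid YAML, `gcloud ai custom-jobs create --config=config.yaml` should read it the same way. The helper name `write_worker_pool_config` is hypothetical, and its defaults simply mirror the machine settings used above.

```python
import json
from typing import Any, Dict


def write_worker_pool_config(
    path: str,
    image_uri: str,
    training_data_path: str,
    input_schema_path: str,
    job_dir: str,
    machine_type: str = "n1-highmem-8",
    accelerator_type: str = "NVIDIA_TESLA_K80",
    accelerator_count: int = 1,
) -> Dict[str, Any]:
    """Write a single-worker custom-job config to `path` and return it as a dict."""
    config_dict = {
        "workerPoolSpecs": [
            {
                "machineSpec": {
                    "machineType": machine_type,
                    "acceleratorType": accelerator_type,
                    "acceleratorCount": accelerator_count,
                },
                "replicaCount": 1,
                "containerSpec": {
                    "imageUri": image_uri,
                    "args": [
                        f"--training_data_path={training_data_path}",
                        f"--input_schema_path={input_schema_path}",
                        f"--job-dir={job_dir}",
                    ],
                },
            }
        ]
    }
    # JSON output is also parseable as YAML, so the file can keep its .yaml name.
    with open(path, "w") as f:
        json.dump(config_dict, f, indent=2)
    return config_dict
```

In the notebook you would call it as `write_worker_pool_config("config.yaml", LEARNER_IMAGE_URI, TRAINING_DATA_PATH, INPUT_SCHEMA_PATH, OUTPUT_DIR)` and then run the same `gcloud ai custom-jobs create` command. Passing the values as arguments avoids overwriting the notebook's own variables.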


If you want to use the TFRecord input file format instead, you can run the following command:

In [None]:
# Use a separate variable so that TRAINING_DATA_PATH, which later cells reuse without
# --input_file_format=tfrecord, keeps pointing at the default input files.
TFRECORD_TRAINING_DATA_PATH = f"gs://cloud-samples-data/vertex-ai/matching-engine/two-tower/{DATASET_NAME}/tfrecord/*"

learning_job_name = f"two_tower_cpu_tfrecord_{DATASET_NAME}_{TIMESTAMP}"

CREATION_LOG = ! gcloud ai custom-jobs create \
  --display-name={learning_job_name} \
  --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri={LEARNER_IMAGE_URI} \
  --region={REGION} \
  --args=--training_data_path={TFRECORD_TRAINING_DATA_PATH} \
  --args=--input_schema_path={INPUT_SCHEMA_PATH} \
  --args=--job-dir={OUTPUT_DIR} \
  --args=--train_batch_size={TRAIN_BATCH_SIZE} \
  --args=--num_epochs={NUM_EPOCHS} \
  --args=--input_file_format=tfrecord

print(CREATION_LOG)

After the job is submitted successfully, you can view its details and logs:

In [14]:
import re

JOB_ID = re.search(r"(?<=/customJobs/)\d+", CREATION_LOG[1]).group(0)
print(JOB_ID)

8188402016507133952
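The cell above assumes the job resource name is on line 1 of `CREATION_LOG`, which only holds while `gcloud` prints exactly one line before it. A slightly more defensive sketch scans every line for the first `/customJobs/<id>` match. The helper name `extract_custom_job_id` is hypothetical.

```python
import re
from typing import List, Optional

CUSTOM_JOB_ID_PATTERN = re.compile(r"/customJobs/(\d+)")


def extract_custom_job_id(log_lines: List[str]) -> Optional[str]:
    """Return the numeric custom job ID from `gcloud ai custom-jobs create` output.

    Scans all lines instead of relying on a fixed line index; returns None if
    no custom job resource name is found.
    """
    for line in log_lines:
        match = CUSTOM_JOB_ID_PATTERN.search(line)
        if match:
            return match.group(1)
    return None
```

Applied to the creation log printed earlier, `extract_custom_job_id(CREATION_LOG)` returns the same ID as the regex in the cell above, but keeps working if `gcloud` adds or removes informational lines.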


In [31]:
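import time

# View the job's configuration and poll its state until the job finishes.
TERMINAL_STATES = ["state: JOB_STATE_SUCCEEDED", "state: JOB_STATE_FAILED", "state: JOB_STATE_CANCELLED"]
STATE = "state: JOB_STATE_PENDING"

while STATE not in TERMINAL_STATES:
    DESCRIPTION = ! gcloud ai custom-jobs describe {JOB_ID} --region={REGION}
    # Use the top-level `state:` line rather than relying on its position in the output.
    STATE = next((line for line in DESCRIPTION if line.startswith("state:")), STATE)
    print(STATE)
    if STATE not in TERMINAL_STATES:
        time.sleep(60)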

state: JOB_STATE_FAILED
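A polling loop with no upper bound can run indefinitely if the job gets stuck. A minimal sketch of the same idea with an explicit timeout, assuming `describe` is any callable that returns the `describe` output as a list of lines. The helpers `parse_state` and `wait_for_job` are hypothetical, and the `sleep` parameter exists only so the loop can be exercised without waiting.

```python
import time
from typing import Callable, Iterable, List

TERMINAL_STATES = {"JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"}


def parse_state(describe_lines: Iterable[str]) -> str:
    """Return the top-level `state:` value from `describe` YAML output."""
    for line in describe_lines:
        # Top-level YAML keys have no indentation, so indented (nested) keys are skipped.
        if line.startswith("state:"):
            return line.split(":", 1)[1].strip()
    return "UNKNOWN"


def wait_for_job(
    describe: Callable[[], List[str]],
    poll_seconds: float = 60,
    timeout_seconds: float = 6 * 60 * 60,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll until a terminal state is reached; raise TimeoutError after timeout_seconds of sleeping."""
    waited = 0.0
    while True:
        state = parse_state(describe())
        print(state)
        if state in TERMINAL_STATES:
            return state
        if waited >= timeout_seconds:
            raise TimeoutError(f"job still {state} after {waited:.0f} s")
        sleep(poll_seconds)
        waited += poll_seconds
```

In the notebook, `describe` could wrap `subprocess.run(["gcloud", "ai", "custom-jobs", "describe", JOB_ID, f"--region={REGION}"], capture_output=True, text=True).stdout.splitlines()`.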


When the training starts, you can view the logs in TensorBoard. Colab users can use the TensorBoard widget below:

In [None]:
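# Load the TensorBoard notebook extension (harmless if it is already loaded).
%load_ext tensorboard

TENSORBOARD_DIR = os.path.join(OUTPUT_DIR, "tensorboard")
%tensorboard --logdir {TENSORBOARD_DIR}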

For Google Cloud Notebooks users, the TensorBoard widget above won't work. We recommend launching TensorBoard through Cloud Shell instead.

1. In Cloud Shell, launch TensorBoard on port 8080. Replace `gs://xxxxx/tensorboard` with the `tensorboard` subdirectory of your `OUTPUT_DIR`:

    ```
    export TENSORBOARD_DIR=gs://xxxxx/tensorboard
    tensorboard --logdir=${TENSORBOARD_DIR} --port=8080 --load_fast=false
    ```

2. Click the "Web Preview" button at the top-right of the Cloud Shell window (looks like an eye in a rectangle). 

3. Select "Preview on port 8080". This should launch the TensorBoard webpage in a new tab in your browser.

After the job finishes successfully, you can view the output directory:

In [None]:
! gsutil ls {OUTPUT_DIR}

## Deploy on Vertex Prediction

### Import the model

The training job exports two TensorFlow SavedModels, under `gs://<job_dir>/query_model` and `gs://<job_dir>/candidate_model`. You can use these exported models for online or batch prediction in Vertex Prediction. First, import the query (or candidate) model:

In [None]:
# The following imports the query (user) encoder model.
MODEL_TYPE = "query"
# Use the following instead to import the candidate (movie) encoder model.
# MODEL_TYPE = 'candidate'

DISPLAY_NAME = f"{DATASET_NAME}_{MODEL_TYPE}"  # The display name of the model.
MODEL_NAME = f"{MODEL_TYPE}_model"  # Used by the deployment container.
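The cell above switches between the two towers by editing `MODEL_TYPE`. If you want the names for both encoders at once, for example to import both in a loop, a small sketch derives them with the same naming scheme and pairs each with the SavedModel directory the training job writes (`<job_dir>/query_model` and `<job_dir>/candidate_model`, per the text above). The helper `encoder_names` is hypothetical.

```python
from typing import Dict


def encoder_names(dataset_name: str, job_dir: str) -> Dict[str, Dict[str, str]]:
    """Return display name, serving model name, and SavedModel dir for each tower."""
    names: Dict[str, Dict[str, str]] = {}
    for model_type in ("query", "candidate"):
        model_name = f"{model_type}_model"
        names[model_type] = {
            "display_name": f"{dataset_name}_{model_type}",
            "model_name": model_name,
            # The exported SavedModel lives in a subdirectory named after the model.
            "saved_model_dir": f"{job_dir.rstrip('/')}/{model_name}",
        }
    return names
```

Each entry's `display_name` and `model_name` could then be passed to `aiplatform.Model.upload` with the same arguments as in the next cell.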

In [None]:
aiplatform.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=BUCKET_NAME,
)

model = aiplatform.Model.upload(
    display_name=DISPLAY_NAME,
    artifact_uri=OUTPUT_DIR,
    serving_container_image_uri="us-central1-docker.pkg.dev/cloud-ml-algos/two-tower/deploy",
    serving_container_health_route=f"/v1/models/{MODEL_NAME}",
    serving_container_predict_route=f"/v1/models/{MODEL_NAME}:predict",
    serving_container_environment_variables={
        "MODEL_BASE_PATH": "$(AIP_STORAGE_URI)",
        "MODEL_NAME": MODEL_NAME,
    },
)

### Deploy the model

After importing the model, you must deploy it to an endpoint so that you can get online predictions. More information about this process can be found in the [official documentation](https://cloud.google.com/vertex-ai/docs/predictions/deploy-model-api).

In [None]:
! gcloud ai models list --region={REGION} --filter={DISPLAY_NAME}

Create a model endpoint:

In [None]:
endpoint = aiplatform.Endpoint.create(display_name=DATASET_NAME)

Deploy the model to the endpoint:

In [None]:
model.deploy(
    endpoint=endpoint,
    machine_type="n1-standard-4",
    traffic_split={"0": 100},
    deployed_model_display_name=DISPLAY_NAME,
)

## Predict

Now that you have deployed the query/candidate encoder model on Vertex Prediction, you can call the model to calculate embeddings for live data. There are two methods of getting predictions, online and batch, which are shown below.

### Online prediction

[Online prediction](https://cloud.google.com/vertex-ai/docs/predictions/online-predictions-custom-models) is used to synchronously query a model on a small batch of instances with minimal latency. The following cell calls the deployed model endpoint using the Vertex AI SDK for Python.

Provide the input data you want predictions on as a JSON-encoded string in the `data` field. You should also provide a unique `key` field (of type `str`) for each input instance, so that you can associate each output embedding with its corresponding input.

In [None]:
# Input items for the query model:
input_items = [
    {"data": '{"user_id": ["1"]}', "key": "key1"},
    {"data": '{"user_id": ["2"]}', "key": "key2"},
]

# Input items for the candidate model:
# input_items = [{
#     'data' : '{"movie_id": ["1"], "movie_title": ["fake title"]}',
#     'key': 'key1'
# }]

encodings = endpoint.predict(input_items)
print(f"Number of encodings: {len(encodings.predictions)}")
print(encodings.predictions[0]["encoding"])
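Writing the `data` strings by hand is error-prone once there are many instances. A minimal sketch that builds instances in the `{"data": <JSON string>, "key": <str>}` format from plain feature dicts, and indexes returned embeddings by key. The helper names are hypothetical, and `embeddings_by_key` assumes each prediction echoes the input `key` next to its `encoding`; check `encodings.predictions` in your own output before relying on that.

```python
import json
from typing import Any, Dict, List, Mapping


def make_instances(features_by_key: Mapping[str, Dict[str, List[Any]]]) -> List[Dict[str, str]]:
    """Wrap each feature dict as a JSON string under `data`, tagged with its `key`."""
    return [
        {"data": json.dumps(features), "key": str(key)}
        for key, features in features_by_key.items()
    ]


def embeddings_by_key(predictions: List[Dict[str, Any]]) -> Dict[str, List[float]]:
    """Map each prediction's `key` to its `encoding` (assumes `key` is returned)."""
    return {prediction["key"]: prediction["encoding"] for prediction in predictions}
```

For example, `make_instances({"key1": {"user_id": ["1"]}, "key2": {"user_id": ["2"]}})` produces the same `input_items` as the query-model cell above.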

You can also do online prediction using the gcloud CLI, as shown below:

In [None]:
import json
request = json.dumps({"instances": input_items})
with open("request.json", "w") as writer:
    writer.write(f"{request}\n")

ENDPOINT_ID = endpoint.resource_name

! gcloud ai endpoints predict {ENDPOINT_ID} \
  --region={REGION} \
  --json-request=request.json

### Batch prediction

[Batch prediction](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions) is used to asynchronously make predictions on a batch of input data. It is the better choice when you have a large amount of input data and do not need an immediate response, for example when computing embeddings for candidate objects in order to build an index for a nearest neighbor search service such as [Vertex Matching Engine](https://cloud.google.com/vertex-ai/docs/matching-engine/overview).

The input data needs to be on Cloud Storage and in JSONL format. You can use the sample query object file provided below. Like with online prediction, it's recommended to have the `key` field so that you can associate each output embedding with its corresponding input.

In [None]:
QUERY_SAMPLE_PATH = f"gs://cloud-samples-data/vertex-ai/matching-engine/two-tower/{DATASET_NAME}/query_sample.jsonl"

! gsutil cat {QUERY_SAMPLE_PATH}
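To run batch prediction on your own objects, you need a JSONL file (one JSON object per line) in Cloud Storage. A minimal sketch for writing and reading such files locally, assuming each line has the same `{"data": ..., "key": ...}` shape as the online prediction instances; compare with the sample file printed above before using it. The helper names are hypothetical.

```python
import json
from typing import Any, Dict, Iterable, List


def write_jsonl(path: str, instances: Iterable[Dict[str, Any]]) -> int:
    """Write one JSON object per line and return the number of lines written."""
    count = 0
    with open(path, "w") as f:
        for instance in instances:
            f.write(json.dumps(instance) + "\n")
            count += 1
    return count


def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list of dicts, skipping blank lines."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
```

After writing the file, you could copy it to your bucket with `gsutil cp` and pass its `gs://` path as `gcs_source`. `read_jsonl` is also handy for inspecting downloaded batch prediction result files, which are written in JSON Lines format.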

The following cell runs a batch prediction job with the imported model, using the sample query object file as input. Note that it uses the model resource directly and doesn't require a deployed endpoint. Once you start the job, you can track its status in the [Cloud Console](https://console.cloud.google.com/vertex-ai/batch-predictions).

In [None]:
model.batch_predict(
    job_display_name=f"batch_predict_{DISPLAY_NAME}",
    gcs_source=[QUERY_SAMPLE_PATH],
    gcs_destination_prefix=OUTPUT_DIR,
    machine_type="n1-standard-4",
    starting_replica_count=1,
)
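Once you have query and candidate embeddings, retrieval means finding the candidates closest to a query. Matching Engine does this with approximate nearest neighbor search at scale; for a small candidate set, an exact brute-force search is a useful sanity check. This sketch swaps in plain exact search and assumes dot-product similarity; if your encoders are meant to be compared by cosine similarity, normalize the vectors first. The function names are hypothetical.

```python
import heapq
from typing import Dict, List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of two equal-length embedding vectors."""
    if len(a) != len(b):
        raise ValueError(f"dimension mismatch: {len(a)} vs {len(b)}")
    return sum(x * y for x, y in zip(a, b))


def top_k_candidates(
    query: Sequence[float],
    candidates: Dict[str, Sequence[float]],
    k: int = 10,
) -> List[Tuple[str, float]]:
    """Exact top-k search: score every candidate against the query and keep the k best."""
    scored = ((key, dot(query, embedding)) for key, embedding in candidates.items())
    return heapq.nlargest(k, scored, key=lambda item: item[1])
```

This scales linearly with the number of candidates, which is exactly the cost an approximate index is designed to avoid for large catalogs.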

## Hyperparameter tuning

After successfully training your model, deploying it, and calling it to make predictions, you may want to optimize the hyperparameters used during training to improve your model's accuracy and performance. See the Vertex AI documentation for an [overview of hyperparameter tuning](https://cloud.google.com/vertex-ai/docs/training/hyperparameter-tuning-overview) and [how to use it in your Vertex Training jobs](https://cloud.google.com/vertex-ai/docs/training/using-hyperparameter-tuning).

For this example, the following cell runs a Vertex AI hyperparameter tuning job with 8 trials (up to 4 running in parallel) that attempts to maximize the validation AUC metric (`val_auc`). It tunes the number of hidden layers, the number of nodes in each hidden layer, and the learning rate.

In [None]:
import json

PARALLEL_TRIAL_COUNT = 4
MAX_TRIAL_COUNT = 8
METRIC = "val_auc"
hyper_tune_job_name = f"hyper_tune_{DATASET_NAME}_{TIMESTAMP}"

config = json.dumps(
    {
        "displayName": hyper_tune_job_name,
        "studySpec": {
            "metrics": [{"metricId": METRIC, "goal": "MAXIMIZE"}],
            "parameters": [
                {
                    "parameterId": "num_hidden_layers",
                    "scaleType": "UNIT_LINEAR_SCALE",
                    "integerValueSpec": {"minValue": 0, "maxValue": 2},
                    "conditionalParameterSpecs": [
                        {
                            "parameterSpec": {
                                "parameterId": "num_nodes_hidden_layer1",
                                "scaleType": "UNIT_LOG_SCALE",
                                "integerValueSpec": {"minValue": 1, "maxValue": 128},
                            },
                            "parentIntValues": {"values": [1, 2]},
                        },
                        {
                            "parameterSpec": {
                                "parameterId": "num_nodes_hidden_layer2",
                                "scaleType": "UNIT_LOG_SCALE",
                                "integerValueSpec": {"minValue": 1, "maxValue": 128},
                            },
                            "parentIntValues": {"values": [2]},
                        },
                    ],
                },
                {
                    "parameterId": "learning_rate",
                    "scaleType": "UNIT_LOG_SCALE",
                    "doubleValueSpec": {"minValue": 0.0001, "maxValue": 1.0},
                },
            ],
            "algorithm": "ALGORITHM_UNSPECIFIED",
        },
        "maxTrialCount": MAX_TRIAL_COUNT,
        "parallelTrialCount": PARALLEL_TRIAL_COUNT,
        "maxFailedTrialCount": 3,
        "trialJobSpec": {
            "workerPoolSpecs": [
                {
                    "machineSpec": {
                        "machineType": "n1-standard-4",
                    },
                    "replicaCount": 1,
                    "containerSpec": {
                        "imageUri": LEARNER_IMAGE_URI,
                        "args": [
                            f"--training_data_path={TRAINING_DATA_PATH}",
                            f"--input_schema_path={INPUT_SCHEMA_PATH}",
                            f"--job-dir={OUTPUT_DIR}",
                        ],
                    },
                }
            ]
        },
    }
)


! curl -X POST -H "Authorization: Bearer "$(gcloud auth print-access-token) \
 -H "Content-Type: application/json; charset=utf-8"  \
 -d '{config}' https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/hyperparameterTuningJobs
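Two parts of the study spec are easy to misread. First, the layer-size parameters are conditional: they are only sampled for certain values of `num_hidden_layers`. Second, `UNIT_LOG_SCALE` conceptually spreads values uniformly in log space, so a learning-rate range of 0.0001 to 1.0 gives each order of magnitude equal weight. A minimal sketch of both ideas; the helper names are hypothetical, and the actual sampling depends on the search algorithm Vertex AI uses.

```python
import math
from typing import Any, Dict, List


def active_conditional_params(study_spec: Dict[str, Any], parent_id: str, parent_value: int) -> List[str]:
    """List the conditional parameter IDs that apply when `parent_id` == `parent_value`."""
    for param in study_spec["parameters"]:
        if param["parameterId"] == parent_id:
            return [
                cond["parameterSpec"]["parameterId"]
                for cond in param.get("conditionalParameterSpecs", [])
                if parent_value in cond.get("parentIntValues", {}).get("values", [])
            ]
    raise KeyError(parent_id)


def log_scale_value(u: float, min_value: float, max_value: float) -> float:
    """Map u in [0, 1] to [min_value, max_value], uniformly in log space."""
    if not 0.0 <= u <= 1.0:
        raise ValueError("u must be in [0, 1]")
    log_min, log_max = math.log(min_value), math.log(max_value)
    return math.exp(log_min + u * (log_max - log_min))
```

With the notebook's config, `active_conditional_params(json.loads(config)["studySpec"], "num_hidden_layers", 1)` should list only `num_nodes_hidden_layer1`, and `log_scale_value(0.5, 0.0001, 1.0)` is 0.01, the geometric midpoint of the learning-rate range.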

## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud
project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial:

In [None]:
# Delete endpoint resource
endpoint.delete(force=True)

# Delete model resource
model.delete()

# Delete Cloud Storage objects that were created
! gsutil -m rm -r $OUTPUT_DIR