# Spotify TFRS - Environment Setup

### TODO:
* Pin package versions in the pip install cells
* Update the environment/config variables for this project
* Decide whether to use a helper like the `run_bq_query` function for BigQuery queries

## Install additional packages

#### Install the packages required to run this notebook (the install commands below are commented out; uncomment them as needed).

In [None]:
import os

# Vertex AI Workbench notebooks need special handling for pip installs.
# A set DL_ANACONDA_HOME variable is treated as "running on Workbench";
# the env_version metadata file is used to detect a user-managed Workbench notebook.
IS_WORKBENCH_NOTEBOOK = os.getenv("DL_ANACONDA_HOME")
IS_USER_MANAGED_WORKBENCH_NOTEBOOK = os.path.exists("/opt/deeplearning/metadata/env_version")

# On Workbench, installs must use '--user'; elsewhere no extra flag is added.
USER_FLAG = "--user" if IS_WORKBENCH_NOTEBOOK else ""

# Pinned installs (currently disabled). {USER_FLAG} is left unquoted so that an
# empty flag adds no argument, instead of an empty '' requirement that pip rejects.
# !pip install --upgrade --no-warn-conflicts {USER_FLAG} -q \
#     google-cloud-pubsub==2.13.6 \
#     google-api-core==2.8.2 \
#     google-apitools==0.5.32 \
#     plotly==5.10.0 \
#     itables==1.2.0 \
#     xgboost==1.6.2 \
#     apache_beam==2.40.0 \
#     google-cloud-pipeline-components \
#     kfp

In [None]:
# Additional, unpinned installs (currently disabled; uncomment to use).
# NOTE(review): the apache-beam line always passes --user instead of $USER_FLAG —
# confirm that is intended when not running on Workbench.
# ! pip3 install --upgrade google-cloud-aiplatform -q \
#                          google-cloud-pipeline-components -q \
#                          google-cloud-logging -q \
#                          pyarrow -q \
#                          google-cloud-storage $USER_FLAG \
#                          kfp $USER_FLAG -q
# !pip install --upgrade 'apache-beam[gcp]' --user



# ! pip3 install jsonobject

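After installing, restart the notebook kernel so the new packages can be imported. The next cell restarts it automatically unless the `IS_TESTING` environment variable is set.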

In [None]:
# Restart the kernel so packages installed above become importable.
# Set IS_TESTING (to any non-empty value) to skip the restart, e.g. in automated runs.
import os

SKIP_KERNEL_RESTART = bool(os.getenv("IS_TESTING"))

if not SKIP_KERNEL_RESTART:
    import IPython

    # Ask the running kernel to shut down with restart=True.
    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

## Create a Google Cloud Storage bucket and save the config data.
Next, we define the environment configuration, create a Google Cloud Storage (GCS) bucket, and save the config to that bucket as `config/notebook_env.py`. Once the cells finish, you can open Cloud Storage in the console to see the bucket.

In [None]:
import random
import string
import subprocess
from typing import Union

# Unique 5-character suffix (lowercase letters + digits) to help with unique naming
# of resources created by this run; intentionally unseeded so each run differs.
ID = "".join(random.choices(string.ascii_lowercase + string.digits, k=5))

# Read the active gcloud project from stdout only. IPython's `!cmd` capture can mix
# stderr lines (gcloud notices, "(unset)") into the list, which would then be taken
# as the project ID. Fail loudly instead of building resource names from garbage.
GCP_PROJECTS = subprocess.run(
    ["gcloud", "config", "get-value", "project"],
    capture_output=True,
    text=True,
    check=True,
).stdout.splitlines()
PROJECT_ID = GCP_PROJECTS[0].strip() if GCP_PROJECTS else ""
if PROJECT_ID in ("", "(unset)"):
    raise RuntimeError(
        "No active gcloud project found; run `gcloud config set project <PROJECT_ID>` "
        "and re-run this cell."
    )

# GCS locations
DATA_SOURCE_BUCKET = 'spotify-million-playlist-dataset'              # where MPD gzip stored
BUCKET_NAME = f"{PROJECT_ID}-tfrs-retrieval"                         # store repo artifacts: models, pipelines, indexes
REGION = "us-central1"
BQ_DATASET = "mdp_eda_test"                                          # BQ destination from gzip, BQ source elsewhere
BQ_TABLE_TRAIN = 'train'                                             # train table
BQ_TABLE_VALID = 'valid'                                             # valid table
BQ_TABLE_CANDIDATES = 'candidates'                                   # candidates table
VPC_NETWORK = 'ucaip-haystack-vpc-network'                           # VPC network (required to interact with Matching Engine)
# NOTE(review): hard-coded to project number 934903580331's compute service account;
# update this when running in a different project.
VERTEX_SA = '934903580331-compute@developer.gserviceaccount.com'

# MAX_PLAYLIST_LENGTH = 5 # 15

In [None]:
import subprocess

# Environment settings rendered as a small Python module of string constants,
# so the same values can be read back from GCS later.
config = f"""
DATA_SOURCE_BUCKET   = "{DATA_SOURCE_BUCKET}"
BUCKET_NAME          = "{BUCKET_NAME}"
PROJECT              = "{PROJECT_ID}"
REGION               = "{REGION}"
ID                   = "{ID}"
BQ_DATASET           = "{BQ_DATASET}"
BQ_TABLE_TRAIN       = "{BQ_TABLE_TRAIN}"
BQ_TABLE_VALID       = "{BQ_TABLE_VALID}"
BQ_TABLE_CANDIDATES  = "{BQ_TABLE_CANDIDATES}"
VPC_NETWORK          = "{VPC_NETWORK}"
VERTEX_SA            = "{VERTEX_SA}"
"""

BUCKET_URI = f"gs://{BUCKET_NAME}"

# `gsutil mb` errors out if the bucket already exists, so only create it when
# `gsutil ls -b` cannot find it. This keeps the cell safe to re-run.
bucket_exists = (
    subprocess.run(["gsutil", "ls", "-b", BUCKET_URI], capture_output=True).returncode == 0
)
if not bucket_exists:
    subprocess.run(["gsutil", "mb", "-l", REGION, BUCKET_URI], check=True)

# Stream the config through stdin (`cp -`) rather than `echo '...' |`, so values
# containing quotes cannot break the shell command.
subprocess.run(
    ["gsutil", "cp", "-", f"{BUCKET_URI}/config/notebook_env.py"],
    input=config,
    text=True,
    check=True,
)

## Inspect the saved `train_job_dict` from a training run

In [3]:
# GCS path of the pickled training-job dict from experiment run run-20230125-172025.
TRAIN_JOB_DICT_PICKLE = (
    'gs://jt-tfrs-central-v2/8m-tfrs-v1-jtv15/run-20230125-172025/train_job_dict.pkl'
)

In [4]:
!gsutil cp $TRAIN_JOB_DICT_PICKLE .

Copying gs://jt-tfrs-central-v2/8m-tfrs-v1-jtv15/run-20230125-172025/train_job_dict.pkl...
/ [1 files][  1.9 KiB/  1.9 KiB]                                                
Operation completed over 1 objects/1.9 KiB.                                      


In [6]:
import os
import pickle as pkl

# Local copy downloaded by the `gsutil cp` cell: same basename as the GCS object,
# so changing TRAIN_JOB_DICT_PICKLE cannot leave this cell reading a stale file.
TRAIN_JOB_DICT_LOCAL = os.path.basename(TRAIN_JOB_DICT_PICKLE)

# NOTE: unpickling can execute arbitrary code; only load pickles from a trusted bucket.
# `with` guarantees the file is closed even if unpickling fails.
with open(TRAIN_JOB_DICT_LOCAL, 'rb') as filehandler:
    train_job_dict = pkl.load(filehandler)

train_job_dict

{'name': 'projects/934903580331/locations/us-central1/customJobs/3241873888052772864',
 'displayName': 'train-sp-2tower-tfrs-jtv15-v1-80gb',
 'jobSpec': {'workerPoolSpecs': [{'machineSpec': {'machineType': 'a2-highgpu-1g',
     'acceleratorType': 'NVIDIA_TESLA_A100',
     'acceleratorCount': 1},
    'replicaCount': '1',
    'diskSpec': {'bootDiskType': 'pd-ssd', 'bootDiskSizeGb': 100},
    'containerSpec': {'imageUri': 'gcr.io/hybrid-vertex/sp-2tower-tfrs-jtv15-v1-80gb-training',
     'command': ['python', 'two_tower_jt/task.py'],
     'args': ['--project=hybrid-vertex',
      '--train_output_gcs_bucket=jt-tfrs-central-v2',
      '--train_dir=spotify-data-regimes',
      '--train_dir_prefix=jtv14-8m/train_v14',
      '--valid_dir=spotify-data-regimes',
      '--valid_dir_prefix=jtv14-8m/valid_v14',
      '--candidate_file_dir=spotify-data-regimes',
      '--candidate_files_prefix=jtv14-8m/candidates',
      '--experiment_name=8m-tfrs-v1-jtv15',
      '--experiment_run=run-20230125-1720

In [8]:
# Worker pool configuration of the job: machine/accelerator, disk, container image and args.
worker_pool_specs = train_job_dict['jobSpec']['workerPoolSpecs']
worker_pool_specs

[{'machineSpec': {'machineType': 'a2-highgpu-1g',
   'acceleratorType': 'NVIDIA_TESLA_A100',
   'acceleratorCount': 1},
  'replicaCount': '1',
  'diskSpec': {'bootDiskType': 'pd-ssd', 'bootDiskSizeGb': 100},
  'containerSpec': {'imageUri': 'gcr.io/hybrid-vertex/sp-2tower-tfrs-jtv15-v1-80gb-training',
   'command': ['python', 'two_tower_jt/task.py'],
   'args': ['--project=hybrid-vertex',
    '--train_output_gcs_bucket=jt-tfrs-central-v2',
    '--train_dir=spotify-data-regimes',
    '--train_dir_prefix=jtv14-8m/train_v14',
    '--valid_dir=spotify-data-regimes',
    '--valid_dir_prefix=jtv14-8m/valid_v14',
    '--candidate_file_dir=spotify-data-regimes',
    '--candidate_files_prefix=jtv14-8m/candidates',
    '--experiment_name=8m-tfrs-v1-jtv15',
    '--experiment_run=run-20230125-172025',
    '--num_epochs=5',
    '--batch_size=8192',
    '--embedding_dim=128',
    '--projection_dim=50',
    '--layer_sizes=[512,128]',
    '--learning_rate=0.01',
    '--valid_frequency=35',
    '--valid

## Check data in BigQuery

In [None]:
from typing import Optional

import pandas as pd
from google.cloud import bigquery

# Wrapper around the BigQuery client: validate with a dry run, then run the query
# and return its result as a DataFrame.
def run_bq_query(
    sql: str, bq_client: Optional[bigquery.Client] = None
) -> pd.DataFrame:
    """
    Run a BigQuery query and return the result as a DataFrame.

    A dry run is issued first so that problems with the query surface before the
    real query job is started.

    Args:
        sql: SQL query, as a string, to execute in BigQuery
        bq_client: optional client to reuse; when omitted, a new
            ``bigquery.Client()`` is created for this call
    Returns:
        df: DataFrame of results from the query
    Raises:
        Errors from the BigQuery client (e.g. for an invalid query) propagate to
        the caller; they are not returned.
    """
    client = bq_client if bq_client is not None else bigquery.Client()

    # Dry run (query cache disabled) to catch errors before executing the query
    dry_run_config = bigquery.QueryJobConfig(dry_run=True, use_query_cache=False)
    client.query(sql, job_config=dry_run_config)

    # Dry run succeeded: run the query for real
    query_job = client.query(sql, job_config=bigquery.QueryJobConfig())

    # Wait for the job to finish, then convert the result to a DataFrame
    df = query_job.result().to_arrow().to_pandas()
    print(f"Finished job_id: {query_job.job_id}")
    return df