### Import Packages

In [1]:
# import argparse
import gcsfs
import numpy as np

import tensorflow as tf

from datetime import datetime

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

from pprint import pprint

In [2]:
# ! gsutil mb -l us-central1 gs://spotify-beam-v3

In [12]:
# Project-wide settings
PROJECT_ID = 'hybrid-vertex'
BUCKET_NAME = 'spotify-beam-v3'  # bucket name; an earlier value was 'spotify-tfrecords-blog'
REGION = 'us-central1'  # region for Dataflow jobs
VERSION = 'v3'

# Cloud Storage layout: every location below sits under gs://<bucket>/<version>
ROOT = f'gs://{BUCKET_NAME}/{VERSION}'

DATA_DIR = f'{ROOT}/data/'  # where data is stored
STATS_DIR = f'{ROOT}/stats/'  # where stats are stored
STAGING_DIR = f'{ROOT}/job/staging/'  # Dataflow staging directory
TEMP_DIR = f'{ROOT}/job/temp/'  # Dataflow temporary directory
TF_RECORD_DIR = f'{ROOT}/tf-records/'
CANDIDATE_DIR = f'{ROOT}/candidates/'

In [13]:
# Estimate the TFRecord shard count needed: ceil(total_samples / samples_per_file)
total_samples = 2262292
samples_per_file = 300_000

# divmod gives the number of full files plus the leftover samples; a partial file is
# only needed when that leftover is non-zero. (The earlier check used
# `NUM_TF_RECORDS % total_samples`, which is non-zero for any count between 1 and
# total_samples - 1, so it also added a shard when the samples divided evenly, and it
# produced 0 shards when total_samples < samples_per_file.)
NUM_TF_RECORDS, leftover_samples = divmod(total_samples, samples_per_file)
if leftover_samples:
    NUM_TF_RECORDS += 1

print("Number of Expected TFRecords: {}".format(NUM_TF_RECORDS))  # 8

Number of Expected TFRecords: 8


#### Defining Custom DoFn’s

> A DoFn's `process()` method receives each value yielded by the previous pipeline step as its input argument.

## Candidates

Open question: should the read step use `apache_beam.io.gcp.bigquery.ReadFromBigQuery`?
* [docs](https://beam.apache.org/releases/pydoc/2.24.0/apache_beam.io.gcp.bigquery.html?highlight=readfrombigquery)

In [38]:
def beam_pipe(args):
    """
    Build and run the Beam pipeline that exports candidate rows from BigQuery to TFRecords.

    Args:
        args (dict): pipeline configuration. Keys read directly by this function:
            'source_query'            - SQL text handed to the BigQuery source
            'candidate_sink'          - location the TFRecord shards are written under
            'runner'                  - runner passed to ``beam.Pipeline``
            'num_candidate_tfrecords' - number of output shards
            The whole dict is also forwarded as keyword options to ``GoogleCloudOptions``.

    Returns:
        None. The pipeline runs when the ``with`` block exits.

    NOTE(review): another cell in this notebook defines a function with the same name;
    whichever definition ran last in the kernel is the one that gets called.
    """
    # Drop any trailing slash so the joined prefix does not contain '//'
    # (the notebook's CANDIDATE_DIR ends with '/', which previously produced
    # '.../candidates//candidate-tracks-...').
    candidate_sink = args['candidate_sink'].rstrip('/')
    runner = args['runner']
    num_shards = args['num_candidate_tfrecords']

    # Convert rows to tf-example
    # NOTE(review): `candidates_to_tfexample` is not defined or imported in the visible
    # notebook cells - it must already exist in the kernel; TODO confirm its origin.
    _to_tf_example = candidates_to_tfexample(mode='candidates')

    # Write serialized examples to sharded TFRecord files
    write_to_tf_record = beam.io.WriteToTFRecord(
        file_path_prefix=f'{candidate_sink}/candidate-tracks',
        file_name_suffix=".tfrecords",
        num_shards=num_shards
    )

    pipeline_options = beam.options.pipeline_options.GoogleCloudOptions(**args)
    # Pickle the notebook's main session so workers can resolve names defined in it
    pipeline_options.view_as(SetupOptions).save_main_session = True
    print(pipeline_options)

    with beam.Pipeline(runner, options=pipeline_options) as pipeline:
        # NOTE(review): Beam documents BigQuerySource as deprecated in favour of
        # ReadFromBigQuery; left unchanged here - TODO confirm before migrating.
        (pipeline
         | "Read from BigQuery" >> beam.io.Read(beam.io.BigQuerySource(query=args['source_query'], flatten_results=True))
         | 'Convert to tf Example' >> beam.ParDo(_to_tf_example)
         | 'Serialize to String' >> beam.Map(lambda example: example.SerializeToString(deterministic=True))
         | "Write as TFRecords to GCS" >> write_to_tf_record
        )

In [39]:
# Project-wide settings
PROJECT_ID = 'hybrid-vertex'
BUCKET_NAME = 'spotify-beam-v3'  # bucket name; an earlier value was 'spotify-tfrecords-blog'
REGION = 'us-central1'  # region for Dataflow jobs
VERSION = 'v3'

# Dataflow job parameters
TIMESTAMP = datetime.utcnow().strftime('%y%m%d-%H%M%S')
JOB_NAME = f'spotify-bq-tfrecords-{VERSION}-{TIMESTAMP}'
MAX_WORKERS = '20'  # defined here but not placed in `args` below
RUNNER = 'DataflowRunner'
NETWORK = 'ucaip-haystack-vpc-network'

# BigQuery source
BQ_TABLE = 'candidates'
BQ_DATASET = 'mdp_eda_test'
TABLE_SPEC = f'{PROJECT_ID}:{BQ_DATASET}.{BQ_TABLE}'  # ':' separates project from dataset
QUERY = f"SELECT * FROM {PROJECT_ID}.{BQ_DATASET}.{BQ_TABLE}"

# Cloud Storage layout: every location below sits under gs://<bucket>/<version>
ROOT = f'gs://{BUCKET_NAME}/{VERSION}'

DATA_DIR = f'{ROOT}/data/'  # where data is stored
STATS_DIR = f'{ROOT}/stats/'  # where stats are stored
STAGING_DIR = f'{ROOT}/job/staging/'  # Dataflow staging directory
TEMP_DIR = f'{ROOT}/job/temp/'  # Dataflow temporary directory
TF_RECORD_DIR = f'{ROOT}/tf-records/'
CANDIDATE_DIR = f'{ROOT}/candidates/'

NUM_TF_RECORDS = 8

# Everything handed to beam_pipe()
args = dict(
    job_name=JOB_NAME,
    runner=RUNNER,
    source_query=QUERY,
    bq_source_table=TABLE_SPEC,
    network=NETWORK,
    candidate_sink=CANDIDATE_DIR,
    num_candidate_tfrecords=NUM_TF_RECORDS,
    project=PROJECT_ID,
    region=REGION,
    staging_location=STAGING_DIR,
    temp_location=TEMP_DIR,
    save_main_session=True,
    setup_file='beam-candidates/setup.py',
)
print("Pipeline args are set to:")
pprint(args)

Pipeline args are set to:
{'bq_source_table': 'hybrid-vertex:mdp_eda_test.candidates',
 'candidate_sink': 'gs://spotify-beam-v3/v3/candidates/',
 'job_name': 'spotify-bq-tfrecords-v3-220920-163929',
 'network': 'ucaip-haystack-vpc-network',
 'num_candidate_tfrecords': 8,
 'project': 'hybrid-vertex',
 'region': 'us-central1',
 'runner': 'DataflowRunner',
 'save_main_session': True,
 'setup_file': 'beam-candidates/setup.py',
 'source_query': 'SELECT * FROM hybrid-vertex.mdp_eda_test.candidates',
 'staging_location': 'gs://spotify-beam-v3/v3/job/staging/',
 'temp_location': 'gs://spotify-beam-v3/v3/job/temp/'}


/home/jupyter/spotify_mpd_two_tower


In [40]:
# Run the candidates export with the `args` dict built in the preceding cell.
# NOTE(review): `beam_pipe` is defined twice in this notebook; this call uses whichever
# definition was executed last in the kernel.
# NOTE(review): the saved output of this run ends in a worker-side
# "Can't get attribute '_create_code'" AttributeError raised while dill unpickles the
# DoFn - presumably a dill/Beam version mismatch between this environment and the
# Dataflow workers; TODO confirm.
beam_pipe(args)

GoogleCloudOptions(create_from_snapshot=None, dataflow_endpoint=https://dataflow.googleapis.com, dataflow_kms_key=None, dataflow_service_options=None, enable_artifact_caching=False, enable_hot_key_logging=False, enable_streaming_engine=False, flexrs_goal=None, impersonate_service_account=None, job_name=spotify-bq-tfrecords-v3-220920-163929, labels=None, no_auth=False, project=hybrid-vertex, region=us-central1, service_account_email=None, staging_location=gs://spotify-beam-v3/v3/job/staging/, temp_location=gs://spotify-beam-v3/v3/job/temp/, template_location=None, transform_name_mapping=None, update=False)





ERROR:apache_beam.runners.dataflow.dataflow_runner:Console URL: https://console.cloud.google.com/dataflow/jobs/<RegionId>/2022-09-20_09_39_43-3352881796352466712?project=<ProjectId>


DataflowRuntimeException: Dataflow pipeline failed. State: FAILED, Error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/apache_beam/internal/dill_pickler.py", line 285, in loads
    return dill.loads(s)
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 275, in loads
    return load(file, ignore, **kwds)
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 270, in load
    return Unpickler(file, ignore=ignore, **kwds).load()
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 472, in load
    obj = StockUnpickler.load(self)
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 462, in find_class
    return StockUnpickler.find_class(self, module, name)
AttributeError: Can't get attribute '_create_code' on <module 'dill._dill' from '/usr/local/lib/python3.7/site-packages/dill/_dill.py'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/sdk_worker.py", line 284, in _execute
    response = task()
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/sdk_worker.py", line 357, in <lambda>
    lambda: self.create_worker().do_instruction(request), request)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/sdk_worker.py", line 598, in do_instruction
    getattr(request, request_type), request.instruction_id)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/sdk_worker.py", line 629, in process_bundle
    instruction_id, request.process_bundle_descriptor_id)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/sdk_worker.py", line 462, in get
    self.data_channel_factory)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 871, in __init__
    self.ops = self.create_execution_tree(self.process_bundle_descriptor)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 928, in create_execution_tree
    descriptor.transforms, key=topological_height, reverse=True)])
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 927, in <listcomp>
    get_operation(transform_id))) for transform_id in sorted(
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 814, in wrapper
    result = cache[args] = func(*args)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 909, in get_operation
    pcoll_id in descriptor.transforms[transform_id].outputs.items()
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 909, in <dictcomp>
    pcoll_id in descriptor.transforms[transform_id].outputs.items()
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 907, in <listcomp>
    tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 814, in wrapper
    result = cache[args] = func(*args)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 909, in get_operation
    pcoll_id in descriptor.transforms[transform_id].outputs.items()
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 909, in <dictcomp>
    pcoll_id in descriptor.transforms[transform_id].outputs.items()
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 907, in <listcomp>
    tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 814, in wrapper
    result = cache[args] = func(*args)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 909, in get_operation
    pcoll_id in descriptor.transforms[transform_id].outputs.items()
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 909, in <dictcomp>
    pcoll_id in descriptor.transforms[transform_id].outputs.items()
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 907, in <listcomp>
    tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 814, in wrapper
    result = cache[args] = func(*args)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 912, in get_operation
    transform_id, transform_consumers)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 1206, in create_operation
    return creator(self, transform_id, transform_proto, payload, consumers)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 1560, in create_par_do
    parameter)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 1596, in _create_pardo_operation
    dofn_data = pickler.loads(serialized_fn)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/internal/pickler.py", line 52, in loads
    encoded, enable_trace=enable_trace, use_zlib=use_zlib)
  File "/usr/local/lib/python3.7/site-packages/apache_beam/internal/dill_pickler.py", line 289, in loads
    return dill.loads(s)
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 275, in loads
    return load(file, ignore, **kwds)
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 270, in load
    return Unpickler(file, ignore=ignore, **kwds).load()
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 472, in load
    obj = StockUnpickler.load(self)
  File "/usr/local/lib/python3.7/site-packages/dill/_dill.py", line 462, in find_class
    return StockUnpickler.find_class(self, module, name)
AttributeError: Can't get attribute '_create_code' on <module 'dill._dill' from '/usr/local/lib/python3.7/site-packages/dill/_dill.py'>


## Full Train Data

In [27]:
# setup
PROJECT_ID = 'hybrid-vertex'
# NOTE(review): the other setup cells in this notebook use 'spotify-beam-v3' - confirm
# that 'spotify-beam-v1' is intended here.
BUCKET_NAME = 'spotify-beam-v1' # 'spotify-tfrecords-blog' # Set your Bucket name
REGION = 'us-central1' # Set the region for Dataflow jobs
VERSION = 'v3'

# storage
ROOT = f'gs://{BUCKET_NAME}/{VERSION}'

DATA_DIR = ROOT + '/data/' # Location to store data
STATS_DIR = ROOT +'/stats/' # Location to store stats 
STAGING_DIR = ROOT + '/job/staging/' # Dataflow staging directory on GCP
TEMP_DIR =  ROOT + '/job/temp/' # Dataflow temporary directory on GCP
TF_RECORD_DIR = ROOT + '/tf-records/'
CANDIDATE_DIR = ROOT + "/candidates/"

# Estimate the TFRecord shard count needed: ceil(total_samples / samples_per_file)
total_samples = 65_346_428
samples_per_file = 12_800

# divmod gives the number of full files plus the leftover samples; a partial file is
# only needed when that leftover is non-zero. (The earlier check used
# `NUM_TF_RECORDS % total_samples`, which is non-zero for any count between 1 and
# total_samples - 1, so it also added a shard when the samples divided evenly, and it
# produced 0 shards when total_samples < samples_per_file.)
NUM_TF_RECORDS, leftover_samples = divmod(total_samples, samples_per_file)
if leftover_samples:
    NUM_TF_RECORDS += 1

print("Number of Expected TFRecords: {}".format(NUM_TF_RECORDS))  # 5106

Number of Expected TFRecords: 5106


In [9]:
class TrainTfSeqExampleDoFn(beam.DoFn):
    """
    Convert one training row into a ``tf.train.SequenceExample``.

    Playlist, seed-track and candidate columns are written to the example's context;
    the per-playlist list columns (``*_pl``) are written to its feature lists.
    Columns that were commented out in the original implementation (the URI columns
    and 'duration_ms_seed_pl') remain excluded.
    """

    # List columns stored as UTF-8 encoded bytes in the feature lists.
    RAGGED_BYTES_KEYS = (
        'track_name_pl',
        'artist_name_pl',
        'album_name_pl',
        # 'track_uri_pl',
        'artist_genres_pl',
    )

    # List columns stored as floats in the feature lists.
    RAGGED_FLOAT_KEYS = (
        'duration_ms_songs_pl',
        'artist_pop_pl',
        'artists_followers_pl',
        'track_pop_pl',
    )

    # Context columns stored as bytes.
    CONTEXT_BYTES_KEYS = (
        # playlist
        'name',
        'collaborative',
        'description_pl',
        # seed track
        'track_name_seed_track',
        'artist_name_seed_track',
        'album_name_seed_track',
        'artist_genres_seed_track',
        # candidate
        'track_name_can',
        'artist_name_can',
        'album_name_can',
        'track_uri_can',
        'artist_genres_can',
    )

    # Context columns stored as floats.
    CONTEXT_FLOAT_KEYS = (
        # playlist
        'n_songs_pl',
        'num_artists_pl',
        'num_albums_pl',
        # seed track
        'duration_seed_track',
        'track_pop_seed_track',
        'artist_pop_seed_track',
        'artist_followers_seed_track',
        # candidate
        'duration_ms_can',
        'track_pop_can',
        'artist_pop_can',
        'artist_followers_can',
    )

    def __init__(self, task):
        """
        Store the task label.

        Args:
            task: label supplied by the caller; it is stored on the instance but not
                read by any method of this class.
        """
        self.task = task

    @staticmethod
    def _bytes_feature(value):
        """
        Wrap a single bytes value in a bytes_list Feature.
        """
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _int64_feature(value):
        """
        Wrap a single integer value in an int64_list Feature.
        """
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    @staticmethod
    def _string_array(value):
        """
        Build a bytes_list Feature from a string or an iterable of strings.

        A bare ``str``/``bytes`` is treated as one value; iterating it directly would
        split a string into single characters. ``str`` items are UTF-8 encoded and
        ``bytes`` items are passed through unchanged.
        """
        if isinstance(value, (str, bytes)):
            value = [value]
        encoded = [v.encode('utf-8') if isinstance(v, str) else v for v in value]
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))

    @staticmethod
    def _float_feature(value):
        """
        Build a float_list Feature from a number or an iterable of numbers.

        A bare ``int``/``float`` is treated as one value, since a scalar cannot be
        iterated.
        """
        if isinstance(value, (int, float)):
            value = [value]
        return tf.train.Feature(float_list=tf.train.FloatList(value=[float(v) for v in value]))

    def process(self, data):
        """
        Convert one row to a ``tf.train.SequenceExample``.

        Args:
            data: mapping from column name to value; it must contain every key listed
                in the four key tuples on this class (a missing key raises KeyError).
                # assumes the ``*_pl`` columns are lists of str / numbers - TODO confirm
                # against the BigQuery table schema.

        Yields:
            One ``tf.train.SequenceExample`` per input row.
        """
        # The proto classes live under tf.train; only `tf` is imported by the notebook.
        train = tf.train

        # ===============================
        # Ragged Features - Query
        # ===============================
        # Each list column becomes a FeatureList holding a single Feature that carries
        # the whole list, as in the original implementation.
        feature_lists = {}
        for key in self.RAGGED_BYTES_KEYS:
            encoded = [x.encode('utf8') for x in data[key]]
            feature_lists[key] = train.FeatureList(
                feature=[train.Feature(bytes_list=train.BytesList(value=encoded))]
            )
        for key in self.RAGGED_FLOAT_KEYS:
            values = list(data[key])
            feature_lists[key] = train.FeatureList(
                feature=[train.Feature(float_list=train.FloatList(value=values))]
            )

        # ===============================
        # Create Context Features
        # ===============================
        # The helpers are static methods, so they have to be reached through `self`.
        context_features = {
            key: self._string_array(data[key]) for key in self.CONTEXT_BYTES_KEYS
        }
        context_features.update(
            (key, self._float_feature(data[key])) for key in self.CONTEXT_FLOAT_KEYS
        )

        # ===============================
        # Create Sequence
        # ===============================
        seq = train.SequenceExample(
            context=train.Features(feature=context_features),
            feature_lists=train.FeatureLists(feature_list=feature_lists),
        )

        yield seq

In [36]:
def beam_pipe(args):
    """
    Build and run the Beam pipeline that exports training rows from BigQuery to TFRecords.

    Args:
        args (dict): pipeline configuration. Keys read directly by this function:
            'source_query'            - SQL text handed to the BigQuery source
            'runner'                  - runner passed to ``beam.Pipeline``
            'num_candidate_tfrecords' - number of output shards
            'folder'                  - name appended to TF_RECORD_DIR to form the
                                        output file prefix
            The whole dict is also forwarded as keyword options to ``GoogleCloudOptions``.

    Returns:
        None. The pipeline runs when the ``with`` block exits.

    Reads the module-level ``TF_RECORD_DIR`` for the output location.

    NOTE(review): another cell in this notebook defines a function with the same name;
    whichever definition ran last in the kernel is the one that gets called.
    """
    runner = args['runner']
    num_shards = args['num_candidate_tfrecords']

    # Drop any trailing slash so the joined prefix does not contain '//'
    # (TF_RECORD_DIR ends with '/', which previously produced '.../tf-records//train-...').
    tf_record_root = TF_RECORD_DIR.rstrip('/')
    output_prefix = f'{tf_record_root}/{args["folder"]}'

    # Convert rows to tf-example
    _to_tf_example = TrainTfSeqExampleDoFn(task="train")

    # Write serialized examples to sharded TFRecord files
    write_to_tf_record = beam.io.WriteToTFRecord(
        file_path_prefix=output_prefix,
        file_name_suffix=".tfrecords",
        num_shards=num_shards
    )

    pipeline_options = beam.options.pipeline_options.GoogleCloudOptions(**args)
    # Pickle the notebook's main session so workers can resolve names defined in it
    pipeline_options.view_as(SetupOptions).save_main_session = True
    print(pipeline_options)

    with beam.Pipeline(runner, options=pipeline_options) as pipeline:
        # NOTE(review): Beam documents BigQuerySource as deprecated in favour of
        # ReadFromBigQuery; left unchanged here - TODO confirm before migrating.
        (pipeline
         | "Read from BigQuery" >> beam.io.Read(beam.io.BigQuerySource(query=args['source_query'], flatten_results=True))
         | 'Convert to tf Example' >> beam.ParDo(_to_tf_example)
         | 'Serialize to String' >> beam.Map(lambda example: example.SerializeToString(deterministic=True))
         | "Write as TFRecords to GCS" >> write_to_tf_record
        )

In [37]:
# Project-wide settings
PROJECT_ID = 'hybrid-vertex'
BUCKET_NAME = 'spotify-beam-v3'  # bucket name; an earlier value was 'spotify-tfrecords-blog'
REGION = 'us-central1'  # region for Dataflow jobs
VERSION = 'v3'

# Dataflow job parameters
TIMESTAMP = datetime.utcnow().strftime('%y%m%d-%H%M%S')
JOB_NAME = f'spotify-bq-tfrecords-{VERSION}-{TIMESTAMP}'
MAX_WORKERS = '20'  # defined here but not placed in `args` below
RUNNER = 'DataflowRunner'
NETWORK = 'ucaip-haystack-vpc-network'

# BigQuery source
BQ_TABLE = 'train_flatten'
BQ_DATASET = 'mdp_eda_test'
TABLE_SPEC = f'{PROJECT_ID}:{BQ_DATASET}.{BQ_TABLE}'  # ':' separates project from dataset
QUERY = f"SELECT * FROM {PROJECT_ID}.{BQ_DATASET}.{BQ_TABLE}"

# Cloud Storage layout: every location below sits under gs://<bucket>/<version>
ROOT = f'gs://{BUCKET_NAME}/{VERSION}'

DATA_DIR = f'{ROOT}/data/'  # where data is stored
STATS_DIR = f'{ROOT}/stats/'  # where stats are stored
STAGING_DIR = f'{ROOT}/job/staging/'  # Dataflow staging directory
TEMP_DIR = f'{ROOT}/job/temp/'  # Dataflow temporary directory
TF_RECORD_DIR = f'{ROOT}/tf-records/'
CANDIDATE_DIR = f'{ROOT}/candidates/'

# NOTE(review): the shard estimate cell for the full training data prints 5106, but the
# value used here is 8 - confirm this is intentional.
NUM_TF_RECORDS = 8

# Everything handed to beam_pipe()
args = dict(
    job_name=JOB_NAME,
    runner=RUNNER,
    source_query=QUERY,
    bq_source_table=TABLE_SPEC,
    network=NETWORK,
    candidate_sink=CANDIDATE_DIR,
    num_candidate_tfrecords=NUM_TF_RECORDS,
    project=PROJECT_ID,
    region=REGION,
    staging_location=STAGING_DIR,
    temp_location=TEMP_DIR,
    save_main_session=True,
    setup_file='beam-training/setup.py',
    folder='train',
)
print("Pipeline args are set to:")
pprint(args)

Pipeline args are set to:
{'bq_source_table': 'hybrid-vertex:mdp_eda_test.train_flatten',
 'candidate_sink': 'gs://spotify-beam-v3/v3/candidates/',
 'folder': 'train',
 'job_name': 'spotify-bq-tfrecords-v3-220920-163847',
 'network': 'ucaip-haystack-vpc-network',
 'num_candidate_tfrecords': 8,
 'project': 'hybrid-vertex',
 'region': 'us-central1',
 'runner': 'DataflowRunner',
 'save_main_session': True,
 'setup_file': 'beam-training/setup.py',
 'source_query': 'SELECT * FROM hybrid-vertex.mdp_eda_test.train_flatten',
 'staging_location': 'gs://spotify-beam-v3/v3/job/staging/',
 'temp_location': 'gs://spotify-beam-v3/v3/job/temp/'}


In [34]:
# Run the training-data export with the current `args` dict.
# NOTE(review): this cell's execution count (34) is lower than those of the cells that
# define `beam_pipe` (36) and `args` (37) above, and the saved output ends in
# KeyboardInterrupt - re-run top to bottom to confirm it works on a fresh kernel.
beam_pipe(args)

GoogleCloudOptions(create_from_snapshot=None, dataflow_endpoint=https://dataflow.googleapis.com, dataflow_kms_key=None, dataflow_service_options=None, enable_artifact_caching=False, enable_hot_key_logging=False, enable_streaming_engine=False, flexrs_goal=None, impersonate_service_account=None, job_name=spotify-bq-tfrecords-v3-220920-163539, labels=None, no_auth=False, project=hybrid-vertex, region=us-central1, service_account_email=None, staging_location=gs://spotify-beam-v3/v3/job/staging/, temp_location=gs://spotify-beam-v3/v3/job/temp/, template_location=None, transform_name_mapping=None, update=False)







KeyboardInterrupt: 

In [None]:
# Point the pipeline at the validation table and write under the 'valid' prefix.
BQ_TABLE = 'train_flatten_valid'

QUERY = f"SELECT * FROM {PROJECT_ID}.{BQ_DATASET}.{BQ_TABLE}"

# Updates `args` in place, so this cell must run after the cell that builds `args`.
# 'bq_source_table' is refreshed as well so the dict does not keep pointing at the
# training table; every other entry (including 'job_name') is left unchanged.
args.update({
    'source_query': QUERY,
    'bq_source_table': f'{PROJECT_ID}:{BQ_DATASET}.{BQ_TABLE}',
    'folder': 'valid',
})

print("Pipeline args are set to:")
pprint(args)

In [None]:
# Run the export again with the `args` updated for the validation table in the
# preceding cell (this cell has no recorded execution).
beam_pipe(args)