In [0]:
from lakefed_ingest.main import *

In [0]:
# Read job parameters from Databricks widgets. Widget values come back as strings;
# partition_size_mb is converted to int because it is used arithmetically below.
src_type = dbutils.widgets.get('src_type')
src_catalog = dbutils.widgets.get('src_catalog')
src_schema = dbutils.widgets.get('src_schema')
src_table = dbutils.widgets.get('src_table')
partition_col = dbutils.widgets.get('partition_col')
partition_size_mb = int(dbutils.widgets.get('partition_size_mb'))
tgt_catalog = dbutils.widgets.get('tgt_catalog')
tgt_schema = dbutils.widgets.get('tgt_schema')
tgt_table = dbutils.widgets.get('tgt_table')

# Fail fast, before querying the source system: the partition count is computed as
# table_size_mb / partition_size_mb, so zero would raise ZeroDivisionError later and a
# negative value would silently collapse to the 2-partition minimum.
if partition_size_mb <= 0:
    raise ValueError(f'partition_size_mb must be a positive integer, got {partition_size_mb}')

In [0]:
# Size of the source table (in MB, per the `_mb` naming). Together with the requested
# partition_size_mb it determines how many queries the source read is split into.
table_size_mb = get_table_size(
    src_catalog,
    src_schema,
    src_table,
    src_type,
)

In [0]:
# Fetch the lower and upper bound values of the partition column from the source
# table; both are passed to get_partition_list below to build the partition ranges.
lower_bound, upper_bound = get_partition_boundaries(src_catalog, src_schema, src_table, partition_col)

# Label order matches value order (lower first, then upper).
print(f'Lower and upper bound: {lower_bound}, {upper_bound}')

In [0]:
# Partition count = table size / target partition size, truncated toward zero by int(),
# with a floor of 2 partitions.
# NOTE(review): truncation lets each partition exceed partition_size_mb
# (e.g. 290 / 100 -> 2 partitions of ~145 MB); math.ceil would cap it — confirm intent.
num_partitions = max(int(table_size_mb / partition_size_mb), 2)

print(f'Number of partitions: {num_partitions}')

In [0]:
# Build the partition definitions for partition_col between lower_bound and upper_bound,
# presumably num_partitions of them (the splitting logic lives in get_partition_list).
partition_list = get_partition_list(partition_col, lower_bound, upper_bound, num_partitions)

# Fully qualified name of the table that stores the generated partitions.
partitions_tbl = f'{tgt_catalog}.{tgt_schema}.{tgt_table}_partitions'

# Materialize the partitions and replace the table contents and schema.
# NOTE(review): the meaning of the literal 1000 is defined by get_partition_df (not visible
# here); presumably a batch size used to assign batch_id — confirm and consider a named parameter.
partition_df = get_partition_df(partition_list, num_partitions, 1000)
(
    partition_df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(partitions_tbl)
)

In [0]:
# Summarize the partitions table: the sorted distinct batch ids and the number of
# partitions (count of rows with a non-null where_clause).
partitions_qry = f"""\
    select
      sort_array(array_agg(distinct batch_id)) as batch_ids,
      count(where_clause) as cnt_partitions
    from {partitions_tbl}
"""

partitions_df = spark.sql(partitions_qry)
# A global aggregate (no GROUP BY) yields exactly one row. Collect it once and read both
# columns from it, rather than calling collect() per column (which re-runs the query).
partitions_row = partitions_df.collect()[0]
batch_id_list = partitions_row['batch_ids']
cnt_partitions = partitions_row['cnt_partitions']

print(f'Count of batches: {len(batch_id_list)}')
print(f'Count of partitions: {cnt_partitions}')

# Assign id list to job task value to make it available to a for each task.
# NOTE(review): Databricks caps the JSON size of a task value (documented as 48 KiB);
# a very large number of batches could exceed it — verify for the largest expected tables.
dbutils.jobs.taskValues.set(key="batch_id_list", value=batch_id_list)