In [0]:
##########################################
#  TODO                                  #
#  @author: Mahesh Madhusoodanan Pillai  #
#  @email: mahesh.pillai@databricks.com  #
##########################################

In [0]:
from py4j.java_gateway import java_import
# Register a custom Spark JDBC dialect through the py4j gateway to the driver JVM.
# NOTE(review): "dbsqlDialectClass" reads like a placeholder — java_import is normally
# given the fully-qualified name of a class on the cluster classpath (here presumably a
# JdbcDialect implementation); confirm the real class name before running.
gw = spark.sparkContext._gateway
# Import the class into the gateway's JVM view so it can be referenced via gw.jvm below.
java_import(gw.jvm, "dbsqlDialectClass")
# Instantiate the dialect and hand it to Spark's JdbcDialects.registerDialect.
gw.jvm.org.apache.spark.sql.jdbc.JdbcDialects.registerDialect(
  gw.jvm.dbsqlDialectClass())

In [0]:
def runSimpleHiveSQL(sql, jdbc_options):
  """Run one SQL statement on HiveServer2 through PyHive and return the rows as pandas.

  Args:
    sql: SQL text executed verbatim on the Hive connection.
    jdbc_options: dict with "hostname", "port" and "auth". When auth is not
      'none' (case-insensitive) it must also hold "user" and "password"
      entries, each a dict with "secret_scope" and "key" that are resolved
      through dbutils.secrets.

  Returns:
    pandas.DataFrame with every fetched row, columns named from the cursor
    description; an empty DataFrame when the statement yields no result set.
  """
  print(jdbc_options)
  hostname = jdbc_options["hostname"]
  port = jdbc_options["port"]
  auth = jdbc_options["auth"]

  # Open exactly one connection. Previously an unauthenticated connection was
  # opened first and then leaked when credentials were required.
  if auth.lower() == 'none':
    conn = hive.connect(host=hostname, port=port, username='')
  else:
    user = dbutils.secrets.get(
        scope=jdbc_options["user"]["secret_scope"], key=jdbc_options["user"]["key"]
    )
    password = dbutils.secrets.get(
        scope=jdbc_options["password"]["secret_scope"],
        key=jdbc_options["password"]["key"],
    )
    # PyHive accepts a password only with auth='LDAP' or auth='CUSTOM'; the
    # fetched user is passed so the server authenticates that account.
    conn = hive.connect(host=hostname, port=port, auth=auth, username=user, password=password)

  # Release the cursor and connection even if execution or fetching fails.
  try:
    cursor = conn.cursor()
    try:
      cursor.execute(sql)
      if cursor.description is None:
        # Statement produced no result set (e.g. DDL); nothing to fetch.
        return pd.DataFrame()
      results = cursor.fetchall()
      return pd.DataFrame(results, columns=[c[0] for c in cursor.description])
    finally:
      cursor.close()
  finally:
    conn.close()

In [0]:
from pyspark.sql.functions import lit, to_timestamp, coalesce
from pyhive import hive
import pandas as pd


def captureHiveSchema(table, jdbc_options):
    """Capture a Hive table's column list (from DESC) as a Spark DataFrame.

    Args:
        table: "db.table" name; exactly one "." is expected.
        jdbc_options: connection options forwarded to runSimpleHiveSQL.

    Returns:
        Spark DataFrame with db_name, table_name, original_order, col_name,
        data_type and comment. Rows whose data_type is 'data_type' or '' are
        dropped. When DESC lists a column more than once (e.g. partitioning
        or clustering columns), only the first occurrence is kept.

    Raises:
        ValueError: if `table` is not of the form "db.table".
    """
    print(table)
    parts = table.split(".")
    if len(parts) != 2:
        raise ValueError(f"Expected a 'db.table' name, got: {table}")
    db_name, tbl_name = parts

    query = f'DESC {table}'
    print(f"Capturing Hive Schema for the table: {table}")
    df = runSimpleHiveSQL(query, jdbc_options)
    tbl_schema = spark.createDataFrame(df)

    # DataFrame.show() prints the table itself and returns None, so it must not
    # be interpolated into an f-string; print the label, then the table.
    print(f"{table}:")
    tbl_schema.show()
    tbl_schema.createOrReplaceTempView(f"tbl_schema_{tbl_name}")
    # monotonically_increasing_id() records the DESC row order (assumes
    # createDataFrame preserves the pandas row order — TODO confirm), and
    # ROW_NUMBER() per col_name keeps each column's first occurrence.
    tbl_schema_with_row_num = spark.sql(
        f"""select '{db_name}' as db_name, '{tbl_name}' as table_name, original_order, col_name, data_type, comment from (SELECT
            *,
            ROW_NUMBER() OVER (
                PARTITION BY col_name
                ORDER BY
                original_order
            ) AS rn
            FROM
            (
                select
                row_number() over (
                    order by
                    a
                ) as original_order,
                col_name as col_name,
                data_type as data_type,
                `comment` as
                comment
                from
                (
                    select
                    monotonically_increasing_id() a,
                    *
                    from
                    tbl_schema_{tbl_name}
                )
                where
                -- to address partitioning and clustering columns
                data_type not in ('data_type', '')
            ))x where rn = 1"""
    )
    tbl_schema_with_row_num.show()

    return tbl_schema_with_row_num

In [0]:
def processHiveColNames(table, col_mapping, mismatch_exclude_fields, jdbc_options):
  """Build the ordered column lists used for row hashing on the Hive side.

  Args:
    table: table name or parenthesised, aliased sub-query. It is wrapped as
      "(SELECT * FROM {table} where 1=0)a" so only the schema is fetched.
    col_mapping: dict mapping a Hive column name to the name it should sort
      by; unmapped columns sort by their own name. None means no mapping.
    mismatch_exclude_fields: collection of column names to leave out
      (membership tested with `in`). None means exclude nothing.
    jdbc_options: connection options forwarded to readHive.

  Returns:
    (col_list, col_cast_list): comma-separated column names, and matching
    "COALESCE(CAST(c AS STRING),'') as c" expressions. Both are ordered by
    the mapped column name.
  """
  import re

  cm = col_mapping or {}
  excluded = mismatch_exclude_fields if mismatch_exclude_fields is not None else ()
  # NOTE(review): when `table` is a sub-query (captureHiveTableHash passes one),
  # this is the text before the first "." rather than a database name, and it
  # becomes the database in readHive's JDBC URL — confirm this is intended.
  db_name = table.split(".")[0]

  column_sql = f"(SELECT * FROM {table} where 1=0)a"
  columns = readHive(column_sql, db_name, jdbc_options, additional_options={}).columns
  # The JDBC result may name columns after the sub-query alias ("a.<col>"), just
  # as other cells strip a "b." prefix. Drop it so names match the exclude list
  # and col_mapping and are valid as SQL aliases; this is a no-op when absent.
  columns = [re.sub(r'^a\.', '', c) for c in columns]
  print(f"insideHive {mismatch_exclude_fields}")
  columns = [c for c in columns if c not in excluded]
  print(columns)
  # Replace each column with the user-mapped name for ordering purposes only;
  # the emitted SQL still uses the Hive column names (the dict keys).
  col_dict = {c: cm.get(c, c) for c in columns}
  sorted_col = dict(sorted(col_dict.items(), key=lambda item: item[1]))
  col_cast_list = ", ".join([f"COALESCE(CAST({key} AS STRING),'') as {key}" for key in sorted_col])
  col_list = ", ".join(sorted_col)
  return col_list, col_cast_list


In [0]:
def generateHiveSqls(table, primary_keys_string, table_mapping, jdbc_options, sql_override, data_load_filter):
  """Build the two Hive sub-query strings that captureHiveTable reads over JDBC.

  Args:
    table: table name used in "select * from {table}" when sql_override is None.
    primary_keys_string: comma-separated key column list, inserted verbatim
      into hash(...).
    table_mapping: object providing quick_validation, col_mapping and
      mismatch_exclude_fields attributes.
    jdbc_options: not used directly (but visible to sql_override placeholders,
      see below).
    sql_override: optional SQL text that replaces "select * from {table}". It is
      passed through str.format(**locals()), so it may use {placeholders} for
      any local defined before that call (e.g. {table}, {primary_keys_string},
      {load_filter}).
    data_load_filter: optional WHERE predicate; None means "1=1".

  Returns:
    (boundary_sql, read_sql_compiled):
      read_sql_compiled -- "(...)b" sub-query selecting hash(<keys>) as
        pk_hash_mmp plus every column of the source query, filtered by the
        load filter.
      boundary_sql -- "(...)c" sub-query returning min/max of that hash as
        lower_bound / upper_bound.
  """
  # pk_columns = table_mapping.tgt_primary_keys
  # primary_keys_string = pk_columns.replace("|", ",")
  # print (primary_keys_string)
  # These three locals are not used below, but they are exposed to sql_override
  # through format(**locals()); removing or renaming them changes which
  # placeholders an override may use.
  quick_validation = table_mapping.quick_validation
  col_mapping = table_mapping.col_mapping
  mismatch_exclude_fields_string = table_mapping.mismatch_exclude_fields

  # quick_validation = True


  
  # NOTE(review): only None falls back to "1=1"; an empty string would produce
  # "where " and invalid SQL.
  load_filter = data_load_filter if (not data_load_filter is None) else "1=1"
  read_sql = sql_override if (not sql_override is None) else f"select * from {table}"

  # read_sql_compiled = f"({read_sql.format(**locals())})a"
  # pk_hash_mmp is the numeric column captureHiveTable uses as the JDBC
  # partitionColumn ("b.pk_hash_mmp").
  read_sql_compiled = f"(SELECT hash({primary_keys_string}) as pk_hash_mmp, a.* FROM  ({read_sql.format(**locals())})a where {load_filter})b"
  # Range of the partition column, read first to derive lowerBound/upperBound.
  boundary_sql = f"(SELECT min(hash({primary_keys_string})) lower_bound, max(hash({primary_keys_string})) upper_bound  FROM {read_sql_compiled})c"

  # if quick_validation:
  #   processColNames(table, col_mapping, primary_keys_string, mismatch_exclude_fields_string, jdbc_options)

   
  print(f"{read_sql}\n{read_sql_compiled}")
  return boundary_sql, read_sql_compiled

In [0]:
def readHive(query, db_name, jdbc_options, additional_options=None):
    """Load a Hive table or sub-query into a Spark DataFrame over JDBC.

    Args:
        query: value for Spark's JDBC "dbtable" option -- a table name or a
            parenthesised, aliased sub-query such as "(SELECT ...)b".
        db_name: database placed in the path of the jdbc:hive2 URL.
        jdbc_options: dict with "hostname", "port" and "auth". Unless auth is
            'none' (case-insensitive) it also needs "user" and "password"
            entries ({"secret_scope", "key"}) resolved via dbutils.secrets.
            An optional "driver" entry overrides the JDBC driver class, which
            defaults to the EMR driver "com.amazon.hive.jdbc.HS2Driver".
        additional_options: extra Spark JDBC reader options; they take
            precedence over the options built here. None means no extras.

    Returns:
        The DataFrame from spark.read.format("jdbc")...load().
    """
    hostname = jdbc_options["hostname"]
    port = jdbc_options["port"]
    auth = jdbc_options["auth"]
    needs_credentials = auth.lower() != 'none'
    # Other drivers that have been used with this notebook:
    #   "org.apache.hive.jdbc.HiveDriver" (Apache / from EMR cluster),
    #   "com.cloudera.hive.jdbc4.HS2Driver" (Cloudera).
    driver = jdbc_options.get("driver", "com.amazon.hive.jdbc.HS2Driver")

    # Hive/Tez session settings are appended to the URL after the database.
    url = f"jdbc:hive2://{hostname}:{port}/{db_name};hive.exec.dynamic.partition.mode=nonstrict;hive.tez.container.size=8192;hive.tez.java.opts=-Xmx6g;tez.runtime.io.sort.mb=4096;tez.runtime.unordered.output.buffer.size-mb=1024;hive.exec.max.dynamic.partitions=10000;hive.exec.max.dynamic.partitions.pernode=500;"

    user = dbutils.secrets.get(
        scope=jdbc_options["user"]["secret_scope"], key=jdbc_options["user"]["key"]
    ) if needs_credentials else ""
    password = dbutils.secrets.get(
        scope=jdbc_options["password"]["secret_scope"],
        key=jdbc_options["password"]["key"],
    ) if needs_credentials else ""

    reader_options = {
        "url": url,
        # Presumably forwarded to the driver as a connection property -- confirm.
        "auth": auth,
        "user": user,
        "password": password,
    }

    final_reader_options = {**reader_options, **(additional_options or {})}
    print(f'jdbc_query:{query}')
    return (
        spark.read.format("jdbc")
        .option("driver", driver)
        .option("dbtable", query)
        .options(**final_reader_options)
        .load())

In [0]:
def captureHiveTableHash(table, primary_keys_string, mismatch_exclude_fields, sql_override, data_load_filter, table_mapping, jdbc_options):
  """Read one (primary key, row hash) pair per Hive row for hash-based comparison.

  Args:
    table: "db.table" name; the text before the first "." is used as the
      database in the JDBC URL.
    primary_keys_string: comma-separated key columns, joined with ":" into p_keys.
    mismatch_exclude_fields: column names left out of the row hash
      (forwarded to processHiveColNames).
    sql_override: optional SQL replacing "select * from {table}". It is passed
      through str.format(**locals()), so it may use {placeholders} for any
      local defined before that call (e.g. {table}, {db_name}, {load_filter}).
    data_load_filter: optional WHERE predicate; None means "1=1".
    table_mapping: object whose col_mapping attribute is forwarded to
      processHiveColNames.
    jdbc_options: connection options forwarded to readHive / processHiveColNames.

  Returns:
    Spark DataFrame with columns p_keys and row_hash (sha2-256 over the
    ":"-joined column expressions built by processHiveColNames).
  """
  # The local names below are visible to sql_override via format(**locals());
  # renaming or removing them changes which placeholders an override may use.
  db_name = table.split(".")[0]
  print (primary_keys_string)
  col_mapping = table_mapping.col_mapping

  load_filter = data_load_filter if data_load_filter is not None else "1=1"
  read_sql = sql_override if sql_override is not None else f"select * from {table}"

  read_sql_compiled = f"(SELECT a.* FROM ({read_sql.format(**locals())})a where {load_filter})b"

  col_list, col_cast_list = processHiveColNames(read_sql_compiled, col_mapping, mismatch_exclude_fields, jdbc_options)

  # Debug variant that also returns the concatenated values:
  # sql = f"""(SELECT concat_ws(":",{primary_keys_string}) as p_keys, sha2(concat_ws(":",{col_list}),256) as row_hash, concat_ws(":",{col_list}) as val from (SELECT {col_cast_list} from {read_sql_compiled})a)b"""

  sql = f"""(SELECT concat_ws(":",{primary_keys_string}) as p_keys, sha2(concat_ws(":",{col_list}),256) as row_hash from (SELECT {col_cast_list} from {read_sql_compiled})a)b"""
  print(sql)

  df = readHive(sql, db_name, jdbc_options)
  # Result columns are named "b.<name>" after the outer alias; strip that prefix.
  # Raw string: "\." inside a plain literal is an invalid escape sequence.
  df_renamed = df.select([col(f"`{c}`").alias(re.sub(r'^b\.', '', c)) for c in df.columns])

  df_renamed.show()
  return df_renamed


In [0]:
from pyspark.sql.types import StringType    
from pyspark.sql.functions import col

import re


def captureHiveTable(table, primary_keys_string, sql_override, data_load_filter, table_mapping, jdbc_options):
    """Read a Hive table (or its SQL override) into Spark with a partitioned JDBC read.

    The read is split on hash(<primary keys>): the hash range is fetched first
    and used as lowerBound/upperBound for the pk_hash_mmp partition column.

    Args:
        table: "db.table" name; the text before the first "." is the JDBC database.
        primary_keys_string: comma-separated key columns, forwarded to generateHiveSqls.
        sql_override: optional SQL replacing "select * from {table}" (see generateHiveSqls).
        data_load_filter: optional WHERE predicate (see generateHiveSqls).
        table_mapping: object with a src_cast_to_string flag (and whatever
            generateHiveSqls reads from it).
        jdbc_options: connection options for readHive. An optional
            "additional_options" dict is merged over the partitioning options;
            a missing or None entry means no extra options.

    Returns:
        Spark DataFrame with the source columns (alias prefix removed,
        pk_hash_mmp dropped), every column cast to string when
        table_mapping.src_cast_to_string is truthy.
    """
    db_name = table.split(".")[0]
    src_cast_to_string = table_mapping.src_cast_to_string
    extra_options = jdbc_options.get("additional_options") or {}

    boundary_sql, read_sql_compiled = generateHiveSqls(table, primary_keys_string, table_mapping, jdbc_options, sql_override, data_load_filter)
    print("Retrieveing the boundaries")

    boundary_row = readHive(boundary_sql, db_name, jdbc_options).collect()[0]
    # Column names carry the "c." alias of the boundary sub-query.
    lowerbound, upperbound = boundary_row["c.lower_bound"], boundary_row["c.upper_bound"]
    print(f"Capturing Hive Contents for the table: {table}")

    # An empty result (e.g. no anomalies left to re-read) gives NULL bounds;
    # partitioning options would then fail, so read without them.
    print(f'{lowerbound},{upperbound}')
    if lowerbound is not None and upperbound is not None:
        partition_options = {
            "partitionColumn": "b.pk_hash_mmp",
            "numPartitions": 5,
            "lowerBound": lowerbound,
            "upperBound": upperbound,
            "fetchSize": 1000000,
        }
        final_reader_options = {**partition_options, **extra_options}
    else:
        final_reader_options = extra_options

    df = readHive(read_sql_compiled, db_name, jdbc_options, final_reader_options)

    # Remove the "b." alias prefix from the column names (raw string: "\." in a
    # plain literal is an invalid escape sequence), then drop the partition column.
    df_renamed = df.select([col(f"`{c}`").alias(re.sub(r'^b\.', '', c)) for c in df.columns])
    df_renamed_dropped = df_renamed.drop("pk_hash_mmp")

    to_str = df_renamed_dropped.columns
    if src_cast_to_string and to_str:
        # One projection casts every column; the previous per-column loop
        # rebuilt this identical projection once per column.
        df_renamed_dropped = df_renamed_dropped.select(
            [df_renamed_dropped[c].cast(StringType()).alias(c) for c in to_str])
    return df_renamed_dropped