# In [None]:
import sys
from concurrent import futures
import numpy as np
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime,timedelta
import copy
import json
import pytz
import hmac
from hashlib import sha256 
from pyspark.sql.functions import col
from pyspark.sql.functions import lit
from pyspark.sql.types import StringType
from pyspark.sql import types as T
# Shared Spark session; the Hive metastore client is the AWS Glue Data Catalog.
spark = (
    SparkSession.builder
    .appName("Python Demo")
    .config("hive.metastore.client.factory.class", "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")
    .config("spark.driver.maxResultSize", "4g")
    .enableHiveSupport()
    .getOrCreate()
)
# Runtime settings, applied in this order:
#   * nonstrict dynamic partitioning and "dynamic" partition-overwrite mode;
#   * 20 shuffle partitions;
#   * spark.files.overwrite;
#   * Apache Arrow: a cross-platform, in-memory columnar data layer that speeds up
#     big-data analysis by moving larger blocks of data at once. PySpark already ships
#     the module, but it has to be switched on with this setting.
#     NOTE(review): in Spark 3.x this key is deprecated in favour of
#     "spark.sql.execution.arrow.pyspark.enabled" -- confirm the Spark version.
# Original trailing note (no setting here corresponds to it): "automatically delete an
# already-existing directory when a table is created".
for _key, _value in (
    ("hive.exec.dynamic.partition.mode", "nonstrict"),
    ("spark.sql.sources.partitionOverwriteMode", "dynamic"),
    ("spark.sql.shuffle.partitions", 20),
    ("spark.files.overwrite", "true"),
    ("spark.sql.execution.arrow.enabled", "true"),
):
    spark.conf.set(_key, _value)
del _key, _value
# Source registry keyed by "dbtype". Per entry:
#   increment_database - database whose tables hold the change rows read by sqltemp
#   database           - database whose tables hold the current full snapshot
#   realTime_path      - folder under s3://rupiahplus-data-warehouse/aliyun/ for the output
dbmap={
    "banda":{
        "increment_database":"banda_stream_etl",
        "database":"banda-etl-s3",
        "realTime_path":"banda"
    }
}
# The increment window is [yesterday, nowdate], both as 'YYYY-MM-DD' strings.
# Both values come from a single clock reading so they always differ by exactly one
# day. Two separate datetime.now() calls could straddle midnight.
# NOTE(review): this is the driver's local time -- confirm it matches the timezone
# used for the year/month/day partitions.
_run_started = datetime.now()
nowdate = _run_started.strftime('%Y-%m-%d')
yesterday = (_run_started - timedelta(days=1)).strftime('%Y-%m-%d')
sqltemp="""with {update_snapshot_table} as (
  select *
    from (
      SELECT * ,
         row_number() OVER (PARTITION BY id ORDER BY date_format(etldate,'yyyy-MM-dd HH:mm:ss.SSS') DESC, if(etlindex is NULL, 0, etlindex) desc) row_num
         FROM `{increment_database}`.`{tableNm}`
         where date(year || '-' || month || '-' || day) <=date('{nowdate}') and date(year || '-' || month || '-' || day) >= date('{yesterday}')
      )
    where row_num = 1
  )
select  {select_sql} 
from `{database}`.`{tableNm}` l
left join {update_snapshot_table}
on l.id = {update_snapshot_table}.id
where if({update_snapshot_table}.kind is not null,{update_snapshot_table}.kind,'') <> 'delete'
union
select {update_snapshot_sql} from (
  select *
  from `{increment_database}`.`{tableNm}`
  where
    date(year || '-' || month || '-' || day) <=date('{nowdate}') and date(year || '-' || month || '-' || day) >= date('{yesterday}')
    and kind = 'insert'
  ) new
left join {update_snapshot_table}
on new.id = {update_snapshot_table}.id where if({update_snapshot_table}.kind is not null,{update_snapshot_table}.kind,'') <> 'delete' """
def get_secret_obj(path="s3://rupiahplus-configs/etl/data_secrt/col.json"):
    """Load the column-hashing configuration JSON via Spark.

    The file is read with ``spark.read.text`` (one row per line). The line values
    are concatenated without any separator and parsed as one JSON document.

    Args:
        path: location of the JSON file; defaults to the original hard-coded path.

    Returns:
        The parsed JSON value (used by this module as a dict).
    """
    lines = spark.read.text(path).collect()
    # join instead of repeated += (quadratic); the result is the same string as before.
    return json.loads("".join(row['value'] for row in lines))
# Hashing configuration, loaded once when this module runs. Keys read in this file:
# "hmac_key", "mobileType", "extraColumn", plus one entry per dbtype mapping table
# names to their secret-column lists.
colmap = get_secret_obj()
hmac_key, mobileType = colmap["hmac_key"], colmap["mobileType"]
def getRelmobile(colNm,str):
    """Normalise a phone number when ``colNm`` is listed in ``mobileType``.

    Values of any other column are returned untouched. For a phone column:
    surrounding whitespace is stripped; every '+', '-', ' ' and '"' is removed; a
    leading "62" is replaced by "0"; and a "0" is prepended if the result still does
    not start with one.

    The parameter name ``str`` shadows the builtin inside this function. It is kept
    unchanged to preserve the call signature.
    """
    if colNm not in mobileType:
        return str
    digits = str.strip().translate({ord(ch): None for ch in '+- "'})
    if digits.startswith("62"):
        digits = "0" + digits[2:]
    if not digits.startswith("0"):
        digits = "0" + digits
    return digits
# HMAC-SHA256 hashing, also exposed to Spark SQL as the "hmac_sha256" UDF.
def hmac_sha256(key,colNm,value):
    """Return the hex HMAC-SHA256 digest of ``value`` keyed with ``key``.

    The value first goes through ``getRelmobile`` so phone-number columns are
    normalised before hashing. ``None`` and the empty string yield ``None``.
    """
    if value is None or value == '':
        return None
    normalised = getRelmobile(colNm, value)
    mac = hmac.new(key.encode('utf-8'), normalised.encode('utf-8'), sha256)
    return mac.hexdigest()
spark.udf.register("hmac_sha256", hmac_sha256, returnType=T.StringType())
def getTableColum(b,update_snapshot_table,dbtype,tableNm):
    """Build the two select lists that are spliced into ``sqltemp``.

    Args:
        b: rows collected from ``desc`` on the increment table; each row is read
            via ``["col_name"]`` and ``["data_type"]``.
        update_snapshot_table: name of the CTE holding the latest change per id.
        dbtype: key into ``colmap`` for the secret-column lookup.
        tableNm: table key inside ``colmap[dbtype]``.

    Returns:
        ``(colum, snapshot_colum)``: comma-separated expressions for the full-table
        branch and for the insert branch of ``sqltemp``, with the trailing separator
        cut off. ``colum`` can still contain literal ``{update_snapshot_table}``
        placeholders. The ``setDef`` results are passed as format *arguments* below,
        so they are not expanded here; the caller formats them afterwards.
    """
    colum=""
    snapshot_colum=""
    # "<full-table expr> when no change row matched, else <change-row expr>", aliased
    # with the column name.
    sqlstrtemp="if({update_snapshot_table}.id is null, {colNm}     , {snapshot_colNm})  {ascolNm} ,"
    # Same shape for a secret column, except that the change-row value goes through the
    # hmac_sha256 UDF and the result is aliased <column>_x.
    sqlstr_secret_temp=" if({update_snapshot_table}.id is null, {colNm}     , hmac_sha256('{hmac_key}','{ascolNm}',{snapshot_colNm})) {ascolNm}_x , "
    for index in range(len(b)):
        # Only index 1 and indexes 3 .. len(b)-8 are projected. Index 0, index 2 and
        # the last 7 DESC rows are skipped.
        # NOTE(review): these offsets presumably match a fixed layout of the increment
        # tables (bookkeeping columns plus DESC's partition-information rows).
        # Confirm against the schema; they break silently if it changes.
        if(index==1 or (index>2 and index<len(b)-7)):
            # Original note: "add the encrypted columns in front of year". They are
            # emitted just before the row at index len(b)-8, presumably `year`.
            if(index==len(b)-8)and  colmap.get(dbtype)!=None and colmap[dbtype]!=None and colmap[dbtype].get(tableNm)!=None:
                for secret_col in colmap[dbtype][tableNm]:
                    sqlstr=copy.copy(sqlstr_secret_temp)
                    colum=colum+sqlstr.format(colNm=setDef("string",secret_col,False),snapshot_colNm=setDef("string",secret_col,True),ascolNm=secret_col,update_snapshot_table=update_snapshot_table,hmac_key=hmac_key)
                    snapshot_colum=snapshot_colum+setsecret_SnapshotDef(secret_col).format(update_snapshot_table=update_snapshot_table,hmac_key=hmac_key)+" ,"
            # Regular column. For a secret column, isX redirects the full-table side to
            # l.<column>_x.
            sqlstr=copy.copy(sqlstrtemp)
            colum=colum+sqlstr.format(colNm=isX(dbtype,tableNm,setDef(b[index]["data_type"],b[index]["col_name"],False)),snapshot_colNm=setDef(b[index]["data_type"],b[index]["col_name"],True),ascolNm=b[index]["col_name"],update_snapshot_table=update_snapshot_table)
            snapshot_colum=snapshot_colum+setSnapshotDef(b[index]["data_type"],b[index]["col_name"]).format(update_snapshot_table=update_snapshot_table)+" ,"
    # Cut the trailing separator (last two characters) from both lists.
    return colum[0:len(colum)-2],snapshot_colum[0:len(snapshot_colum)-2]
def isX(type,table,column):
    """Map a column expression to its ``_x`` twin when it names a secret column.

    ``column`` is an SQL expression such as ``"l.phone"``. Its text after the last
    '.' is looked up in the secret-column list ``colmap[type][table]``. On a hit,
    ``column + "_x"`` is returned; otherwise ``column`` is returned unchanged.
    """
    last_part = column.split(".")[-1]
    per_table = colmap.get(type)
    if per_table is None:
        return column
    secret_cols = per_table.get(table)
    if secret_cols is None or last_part not in secret_cols:
        return column
    return column + "_x"
def setDef(type,table_col,is_snapshot):
    """Return the SQL expression that reads ``table_col`` from one side of the join.

    ``is_snapshot`` picks the source:
      * when true, the ``{update_snapshot_table}`` placeholder, which the caller
        fills in later with str.format;
      * otherwise, the full-table alias ``l``.

    Columns whose type string starts with 'decimal' are wrapped in ``ifnull(..., 0)``.
    The trailing padding of those forms is kept exactly: two spaces for the snapshot
    side, one for ``l``.
    """
    source = "{update_snapshot_table}." if is_snapshot else "l."
    if not type.startswith('decimal'):
        return "".join((source, table_col))
    return "".join(("ifnull(", source, table_col, ",0)  " if is_snapshot else ",0) "))
def setSnapshotDef(type,table_col):
    """Return the insert-branch select expression for ``table_col``.

    The value is read from the ``{update_snapshot_table}`` placeholder, which the
    caller formats later.
      * Decimal columns become ``ifnull(<snapshot>.<col>,0) <col>``, so NULL turns
        into 0 and the column keeps its name.
      * Any other column is returned as ``<snapshot>.<col>`` without an alias.
    """
    qualified = "{update_snapshot_table}." + table_col
    if type.startswith('decimal'):
        return "ifnull(" + qualified + ",0) " + table_col
    return qualified
def setsecret_SnapshotDef(table_col):
    """Return the insert-branch expression that hashes secret column ``table_col``.

    Produces ``hmac_sha256('{hmac_key}','<col>', {update_snapshot_table}.<col>) <col>_x``.
    Both brace placeholders are left for the caller's str.format call.
    """
    pieces = (
        "hmac_sha256('{hmac_key}','", table_col,
        "', {update_snapshot_table}.", table_col,
        ") ", table_col, "_x",
    )
    return "".join(pieces)
def execute(tablerow,dbtype,only_table='t_auto_review_loan'):
    """Merge one table's recent increments into its snapshot and publish the result.

    Steps:
      1. Refresh both catalog entries and describe the increment table.
      2. Build the merge query from ``sqltemp``.
      3. Swap every configured secret column with its ``_x`` twin.
      4. Append the configured extra columns.
      5. Stage the data as the ORC table ``etl_s3_temp.<dbtype>_<table>``,
         partitioned by ``year``.
      6. Overwrite ``s3://rupiahplus-data-warehouse/aliyun/<realTime_path>/<table>/``
         with the staged copy.

    Args:
        tablerow: a row of ``show tables``; only its ``tableName`` field is used.
        dbtype: key into ``dbmap`` and ``colmap``.
        only_table: when not None, every other table is skipped. The default keeps
            the original hard-coded single-table filter; pass None to process all
            tables.

    Returns:
        None.
    """
    tableName=tablerow["tableName"]
    # NOTE(review): this filter (kept as the default) looks like a debugging leftover;
    # it restricts the job to a single table.
    if only_table is not None and tableName != only_table:
        return None
    conf=dbmap[dbtype]
    full_table="`"+conf["database"]+"`."+tableName
    increment_table="`"+conf["increment_database"]+"`."+tableName
    update_snapshot_table="update_snapshot_"+dbtype+"_"+tableName
    print(full_table)
    print(increment_table)
    spark.catalog.refreshTable(full_table)
    spark.catalog.refreshTable(increment_table)
    # Named desc_rows (not `col`) so the imported pyspark `col` function is not shadowed.
    desc_rows=spark.sql("desc "+increment_table).collect()
    # Projections for the full-table branch and for the insert-only branch.
    real_col,snapshot_col=getTableColum(desc_rows,update_snapshot_table,dbtype,tableName)
    real_sql=sqltemp.format(
        increment_database=conf["increment_database"],
        tableNm=tableName,
        nowdate=nowdate,
        select_sql=real_col.format(update_snapshot_table=update_snapshot_table),
        update_snapshot_sql=snapshot_col.format(update_snapshot_table=update_snapshot_table),
        database=conf["database"],
        yesterday=yesterday,
        update_snapshot_table=update_snapshot_table,
        hmac_key=hmac_key,
    )
    # The query embeds the HMAC key as a SQL literal; mask it before it reaches the logs.
    print(real_sql.replace(str(hmac_key),"******") if hmac_key else real_sql)
    # Fixed: this line used to be over-indented (IndentationError).
    df=spark.sql(real_sql).drop("etlindex")
    # Swap each configured secret column with its <col>_x twin, using the scratch
    # column "col_x" (dropped before the staging write).
    secret_cols=(colmap.get(dbtype) or {}).get(tableName)
    if secret_cols is not None:
        for i in secret_cols:
            df=df.withColumn("col_x",F.col(i+"_x")).withColumn(i+"_x",F.col(i)).withColumn(i,F.col("col_x"))
    # .get() so a missing "extraColumn" section no longer raises KeyError.
    extra_by_db=colmap.get("extraColumn") or {}
    extra_cols=(extra_by_db.get(dbtype) or {}).get(tableName)
    if extra_cols:
        # Build a new list instead of inserting "*" into colmap's own list. The old
        # code mutated the shared config, so a repeated call would select "*" twice.
        df=df.selectExpr(["*"]+list(extra_cols))
        # Re-create year so it becomes the last column.
        df=df.withColumn("year_new",F.col("year")).drop("year").withColumnRenamed("year_new","year")
    staging_name=dbtype+"_"+tableName
    temp_table="etl_s3_temp."+staging_name
    df.drop("col_x").write.option("path", "s3://rupiahplus-data-warehouse/stream/etl_s3_temp/"+staging_name).mode("overwrite").partitionBy("year").format("orc").saveAsTable(temp_table)
    tablepath='s3://rupiahplus-data-warehouse/aliyun/'+conf["realTime_path"]+'/'+tableName+"/"
    # The final write reads the staged table, presumably so it does not re-run the merge.
    staged=spark.sql("select  * from  "+temp_table)
    print(staged.count())
    staged.write.mode("overwrite").partitionBy("year").orc(tablepath)
if __name__ == "builtins":
    print("start-----",datetime.now())
    for dbtype in dbmap:
        databasesql="show tables in "+dbmap[dbtype]["increment_database"]
        tables=spark.sql(databasesql)
        tablelist=tables.collect();
    #         for row in tablelist:
        executor=None
        with futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures_result=futures.wait([executor.submit(execute, table,dbtype) for table in tablelist])
            for  future in futures_result[0]:
                print(future.result())
                print("end-----",datetime.now())