In [2]:
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import Window

# Example arguments for pillar1, kept for reference (intentionally left commented out):
#transv_pr_gr_psea_path = 'data/sas_402/transv_pr_gr_psea.parquet'
#transv_polhistory_psea_path = 'data/sas_401/transv_polhistory_psea.parquet'
#riskpf_path = '/group/axa_malaysia/data/adm_riskpf'
#acc_yrm = 201707

def pillar1(transv_pr_gr_psea_path, transv_polhistory_psea_path, riskpf_path, acc_yrm, output_folder='data/sas_404'):
    """Build the ``transv_p1`` table and write it to ``<output_folder>/transv_p1.parquet``.

    Pipeline:

    1. Sum ``gwp`` / ``cwp`` of the premium table for rows with
       ``yrm <= acc_yrm`` and a non-zero ``gwp`` or ``cwp``.
    2. Left-join the policy history on ``(chdrnum, tranno)`` to obtain
       ``zrenno`` and split the net premium (``gwp - cwp``) and the risk
       counters into ``*_nb`` (``zrenno == 0``) and ``*_re`` (``zrenno > 0``).
    3. Keep, for every ``(chdrno, tranno, rskno)`` of the risk table, the row
       with the greatest ``datime`` and derive ``rsktabl`` as ``recformat``
       without its last three characters.
    4. Premium rows that match a risk row on
       ``(chdrnum, tranno, rskno, zrenno)`` are taken as is; the remaining
       rows are attributed the risk row with the highest ``tranno`` that is
       ``<=`` the premium ``tranno`` for the same ``(chdrnum, rskno, zrenno)``.
    5. Both sets are unioned, sorted and written as parquet.

    Relies on a ``spark`` session being available as a global (as in a
    notebook / pyspark shell).

    Args:
        transv_pr_gr_psea_path: Parquet path of the premium table.
        transv_polhistory_psea_path: Parquet path of the policy history table.
        riskpf_path: Parquet path of the risk table.
        acc_yrm: Upper bound (inclusive) applied to the ``yrm`` column.
        output_folder: Folder in which ``transv_p1.parquet`` is written.

    Returns:
        None. The result is written to disk.
    """
    # DataFrames cached along the way; released in the ``finally`` below so
    # they do not keep occupying executor storage after the function returns.
    cached_frames = []
    try:
        transv_pr_gr_psea = spark.read.parquet(transv_pr_gr_psea_path)
        transv_polhistory_psea = spark.read.parquet(transv_polhistory_psea_path)
        riskpf = spark.read.parquet(riskpf_path)

        # Summarize GWP records from PR_GR_PSEA
        premium = (
            transv_pr_gr_psea
            .filter((col('yrm') <= acc_yrm) & ((col('gwp') != 0) | (col('cwp') != 0)))
            .groupBy('chdrnum', 'tranno', 'rskno', 'yrm', 'agentid', 'batcbrn', 'trantype')
            .sum('gwp', 'cwp')
            .withColumnRenamed('sum(gwp)', 'gwp')
            .withColumnRenamed('sum(cwp)', 'cwp')
        )

        # Merge Premium Summary with Pol history, by Tranno - This is to get the Zrenno
        polhist = transv_polhistory_psea[['chdrnum', 'tranno', 'zrenno']]
        polhist.cache()
        cached_frames.append(polhist)

        # polhist only contributes zrenno, so (chdrnum, tranno) is the full join key.
        # Rows without a policy-history match get a null zrenno and therefore
        # 0 in all four derived columns.
        net_premium = col('gwp') - col('cwp')
        is_nb_or_rn = col('trantype').isin('NB', 'RN')
        premium2 = (
            premium.join(polhist, on=['chdrnum', 'tranno'], how='left')
            .withColumn('all_gwp_nb', when(col('zrenno') == 0, net_premium).otherwise(lit(0)))
            .withColumn('all_gwp_re', when(col('zrenno') > 0, net_premium).otherwise(lit(0)))
            .withColumn('nbrisk_nb', when(is_nb_or_rn & (col('zrenno') == 0), lit(1)).otherwise(lit(0)))
            .withColumn('nbrisk_re', when(is_nb_or_rn & (col('zrenno') > 0), lit(1)).otherwise(lit(0)))
        )

        premium3 = (
            premium2
            .groupBy('chdrnum', 'tranno', 'rskno', 'zrenno', 'yrm', 'agentid', 'batcbrn')
            .sum('all_gwp_nb', 'nbrisk_nb', 'all_gwp_re', 'nbrisk_re')
            .withColumnRenamed('sum(all_gwp_nb)', 'all_gwp_nb')
            .withColumnRenamed('sum(all_gwp_re)', 'all_gwp_re')
            .withColumnRenamed('sum(nbrisk_nb)', 'nbrisk_nb')
            .withColumnRenamed('sum(nbrisk_re)', 'nbrisk_re')
        )
        premium3.cache()
        cached_frames.append(premium3)

        # Merge by Tranno with Riskpf, and output the errors if mismatch (Cancellations & Reinstatements)
        # Keep the row with the greatest datime per (chdrno, tranno, rskno).
        # A window + row_number is used instead of orderBy().dropDuplicates()
        # because dropDuplicates does not guarantee which duplicate survives.
        latest_first = Window.partitionBy('chdrno', 'tranno', 'rskno').orderBy(col('datime').desc())
        riskpf = (
            riskpf[['chdrno', 'tranno', 'rskno', 'datime', 'recformat']]
            .withColumn('_rn', row_number().over(latest_first))
            .filter(col('_rn') == 1)
            .drop('_rn', 'datime')
            .withColumnRenamed('chdrno', 'chdrnum')
        )

        # rsktabl = recformat without its last three characters. Column.substr
        # accepts Column arguments, so no Python UDF is needed; it yields null
        # for a null recformat and '' when recformat has three characters or fewer.
        riskpf2 = (
            riskpf.join(polhist, on=['chdrnum', 'tranno'], how='inner')
            .withColumn('rsktabl', col('recformat').substr(lit(1), length(col('recformat')) - 3))
            .drop('recformat')
        )
        riskpf2.cache()
        cached_frames.append(riskpf2)

        # zrenno exists on both sides (each got it from polhist), so it is part of the join key.
        risk_key = ['chdrnum', 'tranno', 'rskno', 'zrenno']
        p1 = premium3.join(riskpf2, on=risk_key, how='inner')
        # Anti-join: premium rows with no risk row on the same key.
        error = (
            premium3.join(riskpf2.withColumn('InB', lit(1)), on=risk_key, how='left')
            .filter(isnull(col('InB')))
            .drop('InB', 'rsktabl')
        )

        # Amend the errors, by attributing the latest tranno in the same POI, just before the missing tranno
        riskpf3 = riskpf2.withColumnRenamed('tranno', 'tranno_risk')

        cond = [error['chdrnum'] == riskpf3['chdrnum'],
                error['rskno'] == riskpf3['rskno'],
                error['zrenno'] == riskpf3['zrenno'],
                error['tranno'] >= riskpf3['tranno_risk']]

        candidates = error.join(riskpf3, cond, how='left').select(
            [error[name] for name in error.columns] + [riskpf3['tranno_risk'], riskpf3['rsktabl']]
        )
        # One row per (chdrnum, rskno, tranno): the candidate with the highest
        # tranno_risk, picked deterministically with row_number.
        # NOTE(review): as in the original dropDuplicates, error rows that share
        # (chdrnum, rskno, tranno) but differ in yrm/agentid/batcbrn collapse to
        # a single row here, whereas p1 keeps them all - confirm this is intended.
        closest_first = Window.partitionBy('chdrnum', 'rskno', 'tranno').orderBy(col('tranno_risk').desc())
        error2 = (
            candidates
            .withColumn('_rn', row_number().over(closest_first))
            .filter(col('_rn') == 1)
            .drop('_rn')
        )

        # Exact matches use their own tranno as tranno_risk, so p1 and error2 share a schema.
        p1 = p1.withColumn('tranno_risk', col('tranno'))
        transv_p1 = p1.unionAll(error2.select(p1.columns)).sort('chdrnum', 'rskno', 'tranno')

        transv_p1.write.parquet('{}/transv_p1.parquet'.format(output_folder))
    finally:
        for frame in cached_frames:
            frame.unpersist()

In [3]:
# Input locations and accounting period for this run.
transv_pr_gr_psea_path = 'data/sas_402/transv_pr_gr_psea.parquet'
transv_polhistory_psea_path = 'data/sas_401/transv_polhistory_psea.parquet'
riskpf_path = '/group/axa_malaysia/data/adm_riskpf'
acc_yrm = 201707

# Build transv_p1 and write it to the default output folder.
pillar1(
    transv_pr_gr_psea_path=transv_pr_gr_psea_path,
    transv_polhistory_psea_path=transv_polhistory_psea_path,
    riskpf_path=riskpf_path,
    acc_yrm=acc_yrm,
)

In [3]:
# Row count of the parquet output written by pillar1 (default output folder).
(
    spark.read
    .parquet('data/sas_404/transv_p1.parquet')
    .count()
)

15055393