In [26]:
import avro.schema
from pyspark.sql.types import *
import pyspark.sql.functions as F
import pyspark.sql.window as window
from pyspark.sql import *

import pickle
import os.path
from datetime import datetime, timedelta, date
from dateutil.relativedelta import relativedelta

import pandas as pd
import numpy as np
import json
import re
import math
import h5py

import matplotlib
matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.dates import DateFormatter

In [2]:
import tempfile
import sys
import os

#RTFReport for python3
# The two directories must be on sys.path *before* the import below, otherwise
# the project-local RTFReport module (and the PyRTF package it uses) is not found.
sys.path.insert(0, '/nfsshare/RFORT/AuditLogExtracts/pgm//04_standard_reports_output/PyRTF')
sys.path.insert(0, '/nfsshare/RFORT/AuditLogExtracts/pgm//04_standard_reports_output/')
from RTFReport import RTFReport

#Need to set environment variables in order to make pyspark work in Python3.6
# SPARK_HOME is mandatory: a missing value raises KeyError here on purpose.
spark_home=os.environ['SPARK_HOME']
# Diagnostic only: show the inherited PYTHONPATH (empty string when it is unset,
# instead of failing with KeyError before anything has been configured).
print(os.environ.get('PYTHONPATH', ''))
#os.environ['LD_LIBRARY_PATH']="/usr/lib64/:/usr/lib/:/nfsshare/apps1/dsp_engine/dsp_engine_20160707/lib/"#+os.environ['LD_LIBRARY_PATH']
# PYTHONPATH is replaced (not extended) with the py4j archive shipped with Spark.
os.environ['PYTHONPATH']= spark_home+"/python/lib/py4j-0.10.4-src.zip"#+os.environ['PYTHONPATH']
# Driver and executors must run the same interpreter.
os.environ['PYSPARK_PYTHON']="/data/util/python3.6/bin/python3.6"
os.environ['PYSPARK_DRIVER_PYTHON']="/data/util/python3.6/bin/python3.6"

:/data/util/python3.6/site-packages:/usr/lib/python3.6/site-packages/


In [3]:
# Working directory that holds the extension hash files.
super_dir='/home/lijli06/Work/DFM_SPARK'
# Pickled mapping used by sal_process_orgname for ruleset mapping.
org_ruleset_hash_path=super_dir+"/extn_hashes/special_org_ruleset_hash.pickle"
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only point this at a trusted, project-owned pickle.
with open(org_ruleset_hash_path,'rb') as hash_file:
    org_ruleset_hashdict=pickle.load(hash_file)
  
def sal_process_orgname(s):
    """
        Map an organization name through the ruleset mapping.

        Parameters:
            s: the organization name to look up in org_ruleset_hashdict
        Return:
            The mapped value when s is a key of org_ruleset_hashdict,
            otherwise s itself, unchanged.
    """
    return org_ruleset_hashdict[s] if s in org_ruleset_hashdict else s

# Spark UDF wrapper around sal_process_orgname, declared to return a string column,
# so the ruleset mapping can be applied to a dataframe column.
sal_process_orgname_udf=F.udf(sal_process_orgname, StringType())

In [4]:
class GoLiveReport:
    def __init__(self, name, bins=None, description=None,
            decimals=2, dateType=False, tableColPercent=None,
            colNames=["bin lower bound", "Count", "Percent", "Cumulative", "Cum_Pct", "PctLeft"],
            sortByFreq=False, sortHighToLow=False,
            plot=True, plotCounts=False, plotCum=False, plotInverseCum=False,
            plotPoints=False, plotXLog=False, plotYLog=False,
            plotXRange=None, plotYRange=None,
            plotXLabel=None, plotYLabel=None
            ):
        """
            Description: A super class for generating a golive report. All the metrics appearing
                        in the report are subclasses to this class. This class has the method to
                        load the job configuration file, validate the given date range, load data
                        from HDFS for the given date range and orgs, bin field values in general,
                        and output rtf report
            Purpose: Define a high-level class to include general methods and attributes
            Parameters:
                name: string; used for the title of generated table and plots
                bins: None (default), list or iterable; the bins for binning a field
                description: None (default) or list of strings; the description for a table
                decimals: integer (default: 2); how many decimals for float in the table display
                dateType: False (default) or True; if True, the bins are date type and it requires special care
                colNames: list of strings (optional); column headers for the table
                tableColPercent: None (default) or list of integers; the percentage of width for each column of the table
                sortByFreq: False (default) or True (reserve for future use)
                sortHighToLow: False (default) or True; if True, sort the bins from high to low
                plot: True (default) or False; if True, include a plot in addition to the generated table
                plotCounts: False (default) or True; if True, plot the raw transaction counts
                plotCum: False (default) or True; if True, plot the cumulative percentage. When both plotCounts
                        and plotCum are False, plot the percentage
                plotInverseCum: False (default) or True; if an additional plot is made for inverse cumulative percentage
                plotPoints: False (default) or True; if True, use a dot as the marker; or else, plain line plots
                plotXLog: False (default) or True; if True, set the xaxis to log scale otherwise linear scale
                plotYLog: False (default) or True; if True, set the yaxis to log scale otherwise linear scale
                plotXRange: None (default), list of numbers, or iterable; use to determine range of xaxis and the xticks
                plotYRange: None (default), list of numbers, or iterable; use to determine range of yaxis and the yticks
                plotXLabel: None (default) or string; label for xaxis
                plotYLabel: None (default) or string; label for yaxis
        """

        self.value_dist = {}   # bin value -> count
        self.nCalls = None     # total count; computed lazily in rtfOutput
        self.name = name
        self.bin_lower_bound = bins

        # Copy the list arguments: the default colNames list is shared by every
        # call, so storing it directly would alias it across instances.
        self.colNames = list(colNames) if colNames is not None else None
        self.tableColPercent = list(tableColPercent) if tableColPercent is not None else None

        self.description = description
        self.decimals = decimals

        self.sortByFreq = sortByFreq
        self.sortHighToLow = sortHighToLow
        self.plot = plot
        self.plotPoints = plotPoints
        self.plotCum = plotCum
        self.plotInverseCum = plotInverseCum
        self.plotXLog = plotXLog
        self.plotXRange = plotXRange
        self.plotYLog = plotYLog
        self.plotYRange = plotYRange
        # The x-axis label falls back to the report name.
        self.plotXLabel = name if plotXLabel is None else plotXLabel
        self.plotYLabel = plotYLabel

        self.plotCounts = plotCounts
        self.dateType = dateType

        # Dates that still have to be loaded from the logs; None means "not decided yet".
        self.extra_dates = None

    def load_job_config(self, file_path):
        """
            Description: read the json file for the job/task configurations
            Purpose: parse the json file given by the file_path to get the date range,
                    data log, issuer groups, and issuer banks.
            Parameter:
                file_path: the path to the json file that has the job/task configuration
            Raises:
                FileNotFoundError (from open) when file_path does not exist;
                KeyError when a required entry of 'Job 1' is missing.
        """
        if not os.path.exists(file_path):
            # Keep the friendly message; open() below raises the actual error.
            print('The provided job configuration file does not exist!')
        with open(file_path) as config_file:
            tasks = json.load(config_file)
        job = tasks['Job 1']

        self.start_date = datetime.strptime(job['StartDate'], '%Y-%m-%d').date()
        self.end_date = datetime.strptime(job['EndDate'], '%Y-%m-%d').date()

        # For the golive report only sysauditlog is used, whatever log the job file names.
        self.log_label = 'sysauditlog'

        self.issuer_groups = job['BankGroups'] #e.g. ['NW', 'HSBC']
        self.issuer_sets = job['BankSets'] #e.g. [['NW-DEBIT', 'NW-CREDIT'], ['HSBCDEBIT', 'HSBCCREDIT']]

        self.dateInterval = job['DateInterval'] #integer
        self.dateUnit = job['DateUnit'] #one of 'days', 'weeks', 'months', 'quarters', or 'years'

    def validate_date_range(self):
        """
            Description: Examine the date range given the start date, date interval, date unit, and end date
            Purpose: Parse the date range determined by the given start date,
                      date interval, date unit, and end date. The data are only available
                      for certain time period (e.g. sysauditlog is available from 6 months
                      ago to 6 days ago).
                     The given start date and end date may not fit well the date interval and date unit,
                      the end date is adjusted if possible.
            Parameters:
                None
            Side effects:
                May change self.end_date, sets self.targeted_dates, and calls
                sys.exit (codes 1, 2, 3) on an unusable date combination.
        """
        today = date.today()
        latest_available = today - relativedelta(days=6) #the available data is 6 days ago

        #Sanity check of the date
        if self.end_date < self.start_date:
            print("End date is before start date!")
            sys.exit(1)
        if self.start_date < today - relativedelta(months=6):
            print("Start date is more than 6 months ago from today. Data of last 6 months are available!")
            sys.exit(2)

        if self.end_date > latest_available:
            print("End date cannot be within last 6 days!")
            self.end_date = latest_available
            print('Adjust the end date to 6 days ago, {:s}'.format(self.end_date.strftime('%Y-%m-%d')))

        #Make sure the start date, end date and date interval are appropriately set
        dateUnit = self.dateUnit.lower()
        dateInterval = self.dateInterval
        if dateUnit=='weeks':
            dateUnit = 'days'
            dateInterval = self.dateInterval*7
        if dateUnit=='quarters':
            dateUnit = 'months'
            dateInterval = self.dateInterval*3

        if dateUnit=='days': #has to be in days
            delta = self.end_date - self.start_date
            maxDelta = latest_available - self.start_date
        else:
            delta = relativedelta(self.end_date, self.start_date)
            maxDelta = relativedelta(latest_available, self.start_date)

        # Round the requested span up to a whole number of intervals, but never
        # beyond the largest whole number of intervals for which data exist.
        span = getattr(delta, dateUnit)
        reqInterval = np.ceil(span/dateInterval)*dateInterval if span>0 else dateInterval
        maxInterval = getattr(maxDelta, dateUnit)//dateInterval*dateInterval

        if reqInterval>maxInterval:
            reqInterval = int(maxInterval)
        else:
            reqInterval = int(reqInterval) #change to integers

        if dateUnit=='days':
            if dateInterval==1 and self.start_date != self.end_date:
                tmp_end_date = self.start_date + relativedelta(**{dateUnit: reqInterval})
            else:
                tmp_end_date = self.start_date + relativedelta(**{dateUnit: reqInterval-1})
        else:
            tmp_end_date = self.start_date + relativedelta(**{dateUnit: reqInterval, "days": -1})

        # A single-day range is only acceptable for an interval of exactly one day.
        single_day_ok = (dateInterval == 1 and dateUnit == 'days')
        if tmp_end_date<self.start_date or (tmp_end_date==self.start_date and not single_day_ok):
            print('Inappropriate start date, date interval, and end date combination')
            sys.exit(3)

        if tmp_end_date != self.end_date:
            self.end_date = tmp_end_date
            print('End date is changed to {:s} in order to accomodate with date interval of {:d} {:s}' \
                  .format(self.end_date.strftime('%Y-%m-%d'), self.dateInterval, self.dateUnit))

        print('Start date: {:s}, End date: {:s} with interval of {:d} {:s}' \
              .format(self.start_date.strftime('%Y-%m-%d'), self.end_date.strftime('%Y-%m-%d'), \
                      self.dateInterval, self.dateUnit))

        #Get the target dates for the reports
        date_delta = self.end_date - self.start_date
        self.targeted_dates = [self.start_date+relativedelta(days=i) for i in range(date_delta.days+1)]

    def check_n_load_data(self,spark_session,target_fields):
        """
            Description: Get the file paths and load the data from HDFS
            Purpose: Check if files exist and load the data within the desired date range for the specified orgs
            Parameters:
                spark_session: spark session instance
                target_fields: list of strings; the data fields to load from logs
                               (not modified; DATELOGGED and ORGNAME are added to a copy)
            Return:
                A spark dataframe for the date range and specified orgs
        """

        if self.log_label in ['aracslog', 'arvelog', 'areslog', 'aracsstat']:
            path_label = self.log_label + '/*'
        else:
            path_label = self.log_label

        if self.extra_dates is None:
            self.extra_dates = self.targeted_dates

        if self.end_date == date.today()-relativedelta(days=6):
            path_dates = self.extra_dates
        else: #extend one day later to include small amount of transactions appearing in the next day data
            path_dates = self.extra_dates + [self.end_date + relativedelta(days=1)]

        #Handles needed to check if a file exists in hdfs; created once, not per path.
        #this is needed because there are some data missing for some dates in the hdfs system
        jvm = spark_session.sparkContext._gateway.jvm
        HdfsPath = jvm.org.apache.hadoop.fs.Path
        hdfs = jvm.org.apache.hadoop.fs.FileSystem.get(jvm.java.net.URI("hdfs://dfm-cluster"),
                                                       jvm.org.apache.hadoop.conf.Configuration())

        tmp_paths = []
        for path_date in path_dates:
            tmp_path = "hdfs://dfm-cluster/DFM/"+path_label+"/" \
                        +path_date.strftime("%Y/%m/%d") \
                        +"/"+self.log_label+".avro"
            if hdfs.exists(HdfsPath(tmp_path)):
                tmp_paths.append(tmp_path)
            else:
                print('/'.join(tmp_path.split('/')[-4:]) + ' is missing in HDFS!')

        # Work on a copy so the caller's list is not silently extended.
        fields = list(target_fields)
        for required_field in ('DATELOGGED', 'ORGNAME'):
            if required_field not in fields:
                fields.append(required_field)

        issuers = [org for issuer_set in self.issuer_sets for org in issuer_set]

        #Date fields needed for generating the golive report
        spark_df=spark_session.read.format("com.databricks.spark.avro").load(tmp_paths) \
                    .filter(sal_process_orgname_udf(F.col("ORGNAME")).isin(issuers)) \
                    .withColumn('DATE', F.to_date('DATELOGGED')) \
                    .filter(F.col('DATE').isin(self.extra_dates)) \
                    .select(fields)

        #need to repartition the spark data set to make sure it is not on a single node
        numParts = spark_df.rdd.getNumPartitions()
        spark_df = spark_df.repartition(int(2*numParts))
        # This report used to sit after the return statement and never ran.
        print("Number of Partitions: " + str(spark_df.rdd.getNumPartitions()))
        return spark_df

    def binField(self, spark_df=None, field=None):
        """
            Description: A general method to bin a field
            Purpose: Group and get the count of specified field for each unique value
            Parameters:
                spark_df: the spark dataframe loaded from avro log files
                field: string; the data field the operation will perform on
            Side effects:
                On success sets self.value_dist (value -> count) and self.bin_lower_bound
                (list of the values). On bad input or a failed computation a message
                is printed and both attributes are left untouched.
        """
        # Validate up front so every diagnostic can be reported without crashing.
        if field is None:
            print('The field for binning is not given!')
        if spark_df is None:
            print('The spark dataframe is not given or empty!')
        if field is None or spark_df is None:
            return
        if field not in spark_df.columns:
            print('The given field {:s} is not presented in the spark dataframe'.format(field))
            return

        try:
            rows = spark_df.select(field).groupby(field).count().collect()
        except Exception as err:
            # Best effort, as before, but no longer silent about the cause.
            print('Binning of field {} failed: {}'.format(field, err))
            return

        self.value_dist = dict((row[0], row[1]) for row in rows)
        self.bin_lower_bound = list(self.value_dist.keys())

    def check_n_load_metric_data(self, org, metric, field=None):
        """
            Description: A general method to check and load the saved intermediate result
            Purpose: 1) Check the saved intermediate result and return the data that already exist;
                     2) Determine the extra dates to load data from spark
            Parameters:
                org: string; the issuer organization name
                metric: string; the metric name
                field: None (default) or string; the field name for the metric
            Return:
                The pandas dataframe of the existing data if exists otherwise None
            Side effects:
                Sets self.extra_dates to the targeted dates that are not stored yet
                (all targeted dates when the stored data cannot be read).
        """

        try:
            metrDset = pd.read_hdf('dfm_metrics.hdf5', '{:s}/{:s}'.format(org, metric))
            # assumes the 'Date' column holds 'YYYY-MM-DD' strings -- TODO confirm against the writer.
            # Convert to date objects so they compare equal to self.targeted_dates entries.
            stored_dates = set(datetime.strptime(dstr, '%Y-%m-%d').date()
                               for dstr in metrDset['Date'].unique())
            targeted = set(self.targeted_dates)

            self.extra_dates = sorted(targeted - stored_dates) #extra dates to load data
            existing_dates = targeted & stored_dates #metric data already available on these dates
            existing_strs = set(d.strftime('%Y-%m-%d') for d in existing_dates)
            in_range = metrDset['Date'].isin(existing_strs)

            if field is None:
                return metrDset.loc[in_range, :]
            return metrDset.loc[in_range, ['Date', field+'_value', field+'_count']]
        except Exception:
            # The store is only a cache: a missing file/key or unreadable content
            # simply means everything has to be recomputed.
            self.extra_dates = self.targeted_dates
            return None

    def save_n_update_metric_data(self, new_data, org, metric):
        """
            Description: Save (append) the metric data of an org to the hdf5 store
            Parameters:
                new_data: pandas dataframe; the rows to append
                org: string; the issuer organization name
                metric: string; the metric name
        """

        storeMetric = '{:s}/{:s}'.format(org, metric)
        # The context manager closes the store even when put() raises.
        with pd.HDFStore('dfm_metrics.hdf5', mode='a') as hdf:
            hdf.put(storeMetric, new_data, format='table', append=True)

    def aggregate_results(self, field, existing_data, new_data):
        """
            Description: Combine previously saved and newly computed counts of a field
            Purpose: Sum the '<field>_count' column per '<field>_value' over both inputs and
                     store the result in self.value_dist; when no bins are set yet,
                     use the aggregated values as bins
            Parameters:
                field: string; prefix of the '<field>_value' and '<field>_count' columns
                existing_data: pandas dataframe or None; data already saved
                new_data: pandas dataframe or None; data newly computed
        """
        value_col = field + '_value'
        count_col = field + '_count'

        frames = [frame for frame in (existing_data, new_data) if frame is not None]
        if frames:
            # Only the count column is summed; other columns (e.g. Date) are ignored.
            totals = pd.concat(frames, axis=0).groupby(value_col)[count_col].sum()
            # tolist() yields native Python keys and counts.
            self.value_dist = dict(zip(totals.index.tolist(), totals.tolist()))
        else:
            self.value_dist = {}

        if self.bin_lower_bound is None:
            self.bin_lower_bound = list(self.value_dist.keys())

    def rtfOutput(self, rtf):
        """
            Description: A centralized, complex, and powerful function to generate tables
                        and figures in rtf
            Purpose: Make tables and figures for the given attributes to the class instance
                    and the updated attributes from computations using spark dataframe
            Parameters:
                rtf: rtf instance from RTFReport
            Side effects:
                Sets self.nCalls when it is None, adds zero counts to self.value_dist for
                bins without data, may normalise self.tableColPercent and may set
                self.plotXRange / self.plotYRange. Nothing is written when the total count is 0.
        """

        if self.nCalls is None:
            self.nCalls = sum(self.value_dist.values())

        if not (self.nCalls > 0):
            return
        total = self.nCalls

        if self.bin_lower_bound is None:
            # No explicit bins: report every value that was counted.
            self.bin_lower_bound = list(self.value_dist.keys())

        rtf.AddHeading1 ("Table of " + self.name)

        if self.description:
            description = []
            for descrp in self.description:
                description += ["  " + descrp, rtf.newline]
            description.pop() # no newline after the last description line
            rtf.AddParagraph (*description)

        if self.tableColPercent is None:
            rtf.TableInit (self.colNames)
        else:
            numCols = len(self.colNames)
            widths = list(self.tableColPercent[:numCols])
            if sum(widths) > 100:
                # Unusable widths: fall back to equal columns.
                widths = [100./numCols for _ in range(numCols)]
            elif len(widths) < numCols:
                # Share the remaining width equally among the unspecified columns.
                missing = numCols - len(widths)
                widths = widths + [(100.-sum(widths))/missing for _ in range(missing)]
            self.tableColPercent = widths
            # NOTE(review): the widths are passed as the second positional argument,
            # where the code previously passed an undefined name -- confirm against
            # RTFReport.TableInit.
            rtf.TableInit (self.colNames, self.tableColPercent)

        cum = 0
        x,y,y2 = [], [], []
        num_positive = 0
        bn = None
        dateStr = None
        numBins = len(self.bin_lower_bound)
        if self.sortHighToLow:
            rng = range(numBins - 1, -1, -1)
        else:
            # sortByFreq is reserved for future use: bins keep their given order.
            rng = range(numBins)

        for i in rng:
            bn = self.bin_lower_bound[i]
            if bn not in self.value_dist:
                self.value_dist[bn] = 0
            count = self.value_dist[bn]

            pct = 100.*count/total
            cum += count
            cum_pct = 100.*cum/total
            stats = ['{0:,d}'.format(count), '{0:7.2f}%'.format(pct), '{0:,d}'.format(cum),
                     '{0:7.2f}%'.format(cum_pct), '{0:7.2f}%'.format(100-cum_pct)]

            if isinstance(bn, (int, float)):
                tabCols = [("{0:12." + str(self.decimals) + "f}").format(bn)] + stats
            elif self.dateType:
                if '+' in bn: #date str + values of given field
                    parts = bn.split('+')
                    dateStr = parts[0]
                    fieldValue = parts[1] if parts[1] != '' else 'Missing'
                    tabCols = [("{0:s}").format(dateStr), '{}'.format(fieldValue)] + stats
                else:
                    dateStr = bn
                    tabCols = [("{0:s}").format(dateStr)] + stats

                if ':' in dateStr and '-' in dateStr: #date and hours
                    dateFormat = '%Y-%m-%d %H:00:00'
                elif ':' in dateStr: #hours only
                    dateFormat = '%H:00:00'
                else: #date only
                    dateFormat = '%Y-%m-%d'
            else:
                tabCols = [("{0:s}").format(bn)] + stats

            # Only as many cells as there are column headers.
            rtf.TableAddRow (tabCols[:len(self.colNames)])

            if self.dateType:
                if '+' in bn: #date str + values of given field
                    x.append([datetime.strptime(dateStr, dateFormat), fieldValue])
                else:
                    x.append(datetime.strptime(dateStr, dateFormat))
            else:
                x.append(bn)
            if self.plotCum:
                y.append(cum_pct)
                y2.append(100-cum_pct)
            elif self.plotCounts:
                y.append(count)
            else:
                y.append(pct)
            if ((self.plotYLog is not None) and isinstance(bn, (int, float)) and (bn > 0)):
                num_positive += count

            # Every count is accounted for; the remaining bins would all be empty rows.
            if (cum == total):
                break

        rtf.TableEnd()

        if self.dateType and isinstance(bn, str) and '+' in bn: #the case for temporal distribution for given field
            # Pivot the (date, value, stat) triples into one column per field value.
            x = [[x[i][0], x[i][1], y[i]] for i in range(len(y))]
            tmpDF = pd.DataFrame(data=x, columns=['Date', 'Value', 'Stat'])
            values = tmpDF['Value'].unique()
            tmpDF = tmpDF.groupby('Value')

            for i in range(len(values)):
                value = values[i]

                if i==0:
                    y = tmpDF.get_group(value)[['Date', 'Stat']].set_index('Date')
                else:
                    tmp = tmpDF.get_group(value)[['Date', 'Stat']].set_index('Date')
                    y = pd.concat([y, tmp['Stat']], axis=1)

                value = '{}'.format(value)
                if value=='':
                    value = 'Missing'
                y.rename(columns={'Stat':value}, inplace=True)

            y = y.reset_index()
            x = list(y['Date'])
            y = y.iloc[:,1:]

        # Nothing to draw when no bin was processed.
        if self.plot and len(x) > 0:
            rtf.AddHeading2 ("Plot of " + self.name)

            fig = plt.figure(figsize=(6, 3.7), dpi=100)
            ax = fig.add_subplot(1,1,1)
            plt.gca().set_prop_cycle('color', ['blue', 'cyan', 'magenta', 'black', 'pink', 'red', 'green'])

            font0 = FontProperties()
            legend_font = font0.copy()
            legend_font.set_size('small')

            if self.plotPoints:
                if self.plotCum:
                    ll = plt.plot(x, y, 'o')
                else:
                    ll = plt.plot(x, y,  '-', marker='o')
            else:
                ll = plt.plot(x, y, '-')

            if self.dateType:
                #keep the number of xticks to 9 or less
                if len(x)>9:
                    inc = int(np.ceil(len(x)/8))
                    if (len(x)-1)%inc==0:
                        self.plotXRange = x[0::inc]
                    else:
                        self.plotXRange = x[0::inc] + [x[-1]]
                else:
                    self.plotXRange = x

                # fmt_xdata uses %m (month); %M would be minutes.
                if len(x)>48 or dateStr.find(':')==-1: #Too many entries or only date but not hours appear
                    ax.xaxis.set_major_formatter(DateFormatter ('%b-%d'))
                    ax.fmt_xdata = DateFormatter ('%m-%d')
                elif dateStr.find('-')!=-1: #when there is date in the datetime string
                    ax.xaxis.set_major_formatter(DateFormatter ('%b-%d:%H'))
                    ax.fmt_xdata = DateFormatter ('%m-%d:%H')
                else: #only hours in the datetime string
                    ax.xaxis.set_major_formatter(DateFormatter ('%H:00:00'))
                    ax.fmt_xdata = DateFormatter ('%H:00:00')
                plt.xticks(rotation='vertical')

                if len(ll)>1: #there are time series for more than one value, need legend
                    ax.legend(iter(ll), list(y.columns))

            plt.grid(True, linestyle=':')
            plt.xlabel(self.plotXLabel,  fontsize=12)

            if self.plotYLabel is not None:
                plt.ylabel(self.plotYLabel, fontsize=12)
            elif self.plotCum:
                if self.sortHighToLow:
                    plt.ylabel('Cumulative Percent at or Above', fontsize=12)
                else:
                    plt.ylabel('Cumulative Percent at or Below', fontsize=12)
            else:
                plt.ylabel('Percent of Trans', fontsize=12)

            if self.plotXLog:
                try:
                    plt.xscale('log')
                except Exception:
                    sys.stderr.write ('WARNING: {0} Could not log-scale the x-axis (pos 1)! {1}\n'.format(self.name,os.getcwd()))

            if self.plotXRange is not None:
                try:
                    plt.xlim(self.plotXRange[0], self.plotXRange[-1])
                    plt.xticks(self.plotXRange)
                except Exception:
                    sys.stderr.write ('WARNING: {0} Could not set limits for log-scale the x-axis! {1}\n'.format(self.name,os.getcwd()))

            if self.plotYLog:
                # description may be a list of strings (documented) or a plain string.
                if self.description is None:
                    desc_text = None
                elif isinstance(self.description, str):
                    desc_text = self.description
                else:
                    desc_text = ' '.join(self.description)

                if (num_positive > 0):
                    try:
                        plt.yscale('log')
                    except Exception:
                        sys.stderr.write ('WARNING: {0} Could not log-scale the y-axis (pos 0)! {1}\n'.format(self.name,os.getcwd()))
                        if desc_text is not None:
                            print("  " + desc_text)
                else:
                    sys.stderr.write ('WARNING: {0} Could not log-scale the y-axis (pos 1)! {1}\n'.format(self.name,os.getcwd()))
                    if desc_text is not None:
                        print("  " + desc_text)

            if self.plotYRange is None:
                # Derive y ticks from the data: unit steps below 10, otherwise five steps.
                maxY = np.nanmax(y)
                if maxY<10:
                    maxY = int(np.ceil(maxY))
                    self.plotYRange = range(0,maxY+1, 1)
                else:
                    inc = int(np.ceil(maxY/5))
                    maxY = inc*5
                    self.plotYRange = range(0, maxY+1, inc)

            try:
                plt.ylim(self.plotYRange[0], self.plotYRange[-1]*1.05)
                plt.yticks(self.plotYRange)
            except Exception:
                sys.stderr.write ('WARNING: {0} Could not log-scale the y-axis (pos 2)! {1}\n'.format(self.name,os.getcwd()))

            plt.tick_params(axis='both', which='major', labelsize=9)
            plt.subplots_adjust(bottom=0.20, left=0.20)  # Stop cutting off the bottom of the label
            # The temporary png is removed automatically once fd is closed or collected.
            fd = tempfile.NamedTemporaryFile(suffix='.png', prefix='tmp')
            fname = fd.name
            try:
                fig.savefig(fname, dpi=80)
                rtf.AddImage (fname)
            except Exception:
                sys.stderr.write ('WARNING: {0} Could not save figure to {1}! {2}\n'.format(self.name,fname,os.getcwd()))
            finally:
                plt.close(fig)  # release the figure; reports create many of them

            ####################################

            if self.plotCum and self.plotInverseCum:
                rtf.AddHeading2 ("Inverse Plot of " + self.name)

                fig = plt.figure(figsize=(6, 3.7), dpi=100)
                font0 = FontProperties()
                legend_font = font0.copy()
                legend_font.set_size('small')

                # Both variants draw the inverse cumulative series y2.
                if self.plotPoints:
                    ll = plt.plot(x, y2, 'b-', marker='o')
                else:
                    ll = plt.plot(x, y2, 'b-')

                plt.grid(True, linestyle=':')
                plt.xlabel(self.plotXLabel,  fontsize=12)

                if self.plotYLabel is not None:
                    plt.ylabel("Inverse of {0:s}".format(self.plotYLabel), fontsize=12)
                elif self.sortHighToLow:
                    plt.ylabel('Cumulative Percent Below', fontsize=12)
                else:
                    plt.ylabel('Cumulative Percent Above', fontsize=12)

                if self.plotXLog:
                    plt.xscale('log')

                if self.plotXRange is not None:
                    plt.xlim(self.plotXRange[0], self.plotXRange[-1])
                    plt.xticks(self.plotXRange)

                if self.plotYLog:
                    plt.yscale('log')

                if self.plotYRange is not None:
                    plt.ylim(self.plotYRange[0], self.plotYRange[-1])
                    plt.yticks(self.plotYRange)

                plt.tick_params(axis='both', which='major', labelsize=9)
                plt.subplots_adjust(bottom=0.20)  # Stop cutting off the bottom of the label
                fd = tempfile.NamedTemporaryFile(suffix='.png', prefix='tmp')
                fname = fd.name
                try:
                    fig.savefig(fname, dpi=80)
                    rtf.AddImage (fname)
                except Exception:
                    print('WARNING: {0} Could not save figure to {1}!'.format(self.name,fname))
                finally:
                    plt.close(fig)

In [5]:
#A subclass of GoLiveReport to make statistics on date or hours for a given field
class TemporalDist(GoLiveReport):
    """Transaction counts binned by date period or by hour of day.

    When ``name`` contains "hour" (case-insensitive) the report bins on the
    hour parsed from the DATELOGGED string column; otherwise it bins the DATE
    column into periods of ``dateInterval`` ``dateUnit`` counted from
    ``start_date``.
    """

    def __init__ (self, name='Dates', start_date=None, dateInterval=1, dateUnit='days', 
                  tableColPercent = [28, 18, 18, 18, 18], plot=False, plotXLabel=None,
                  plotCounts=True
            ):
        """Configure the report.

        name:            report name; a name containing 'hour' selects hour-of-day mode.
        start_date:      origin of the date bins (defaults to today's date).
        dateInterval:    integer number of ``dateUnit`` per bin.
        dateUnit:        'days', 'weeks', 'months', 'quarters' or 'years'.
        tableColPercent: column widths forwarded to GoLiveReport.
        plot:            forwarded to GoLiveReport.
        plotXLabel:      x-axis label; forwarded to GoLiveReport only when given.
        plotCounts:      selects the y-axis label (counts vs. percent) and is
                         forwarded to GoLiveReport.

        An unsupported ``dateUnit`` prints a warning and calls ``sys.exit(1)``.
        """
        if plotCounts:
            plotYLabel = 'Transaction count'
        else:
            plotYLabel = 'Percent of transactions'

        #Copy the widths so the shared list default is never handed to the base class.
        if tableColPercent is not None:
            tableColPercent = list(tableColPercent)

        init_kwargs = dict(dateType=True, plot=plot, plotPoints=True, plotCounts=plotCounts,
                           plotCum=False, tableColPercent=tableColPercent, plotYLabel=plotYLabel)
        #plotXLabel used to be silently dropped; forward it only when supplied so the
        #base-class default still applies otherwise.
        if plotXLabel is not None:
            init_kwargs['plotXLabel'] = plotXLabel
        GoLiveReport.__init__(self, name, **init_kwargs)

        if start_date is None:
            self.start_date = datetime.today().date()
        else:
            self.start_date = start_date

        if 'HOUR' in name.upper(): #time series related with hour of the day
            self.colNames = ['Hour of Day bin (UTC)', "Count", "Percent", "Cumulative", "Cum_Pct"]
            self.dateInterval = dateInterval
            self.dateUnit = 'hours'
            self.aggUnit = 'hours'
        else:
            unit = dateUnit.lower()
            self.dateUnit = unit
            if unit=='weeks':
                #weeks are binned as multiples of seven days
                self.name = 'Transactions in every {:d} week(s)'.format(dateInterval)
                self.dateInterval = int(7*dateInterval)
                self.aggUnit = 'weeks'
                self.dateUnit = 'days'
            elif unit=='months':
                self.name = 'Transactions in every {:d} month(s)'.format(dateInterval)
                self.dateInterval = int(dateInterval)
                self.aggUnit = 'months'
            elif unit=='quarters':
                #quarters are binned as multiples of three months
                self.name = 'Transactions in every {:d} quarter(s)'.format(dateInterval)
                self.dateInterval = 3*dateInterval
                self.aggUnit = 'quarters'
                self.dateUnit = 'months'
            elif unit=='years':
                self.name = 'Transactions in every {:d} year(s)'.format(dateInterval)
                self.dateInterval = dateInterval
                self.aggUnit = 'years'
            elif unit=='days':
                self.name = 'Transactions in every {:d} day(s)'.format(dateInterval)
                self.dateInterval = dateInterval
                self.aggUnit = 'days'
            else:
                print('WARNING: Cannot parse the date unit. Pass days, weeks, months, quarters, and years only.')
                sys.exit(1)
            self.colNames = ['Date bin lower bound (UTC)', "Count", "Percent", "Cumulative", "Cum_Pct"]


    #get the date string from the datetime string in the DATELOGGED field
    @staticmethod
    def get_date_str_udf(start_date, interval=1, unit='days'):
        """Return a Spark UDF mapping a date to the lower bound of its bin.

        The bin lower bound is ``start_date`` plus the largest multiple of
        ``interval`` ``unit`` that does not exceed the elapsed time, formatted
        as 'YYYY-MM-DD'.
        """
        unit = unit.lower()

        def get_date_str(s, start_date, interval, unit):
            if unit=='days':
                elapsed = (s - start_date).days
            else:
                delta = relativedelta(s, start_date)
                if unit=='months':
                    #relativedelta keeps months in -11..11 and carries the rest into
                    #years, so the total must be rebuilt or every date more than a
                    #year away from start_date lands in the wrong bin
                    elapsed = delta.years*12 + delta.months
                else:
                    elapsed = getattr(delta, unit)
            offset = elapsed//interval*interval
            bin_start = start_date + relativedelta(**{unit: offset})
            return datetime.strftime(bin_start, '%Y-%m-%d')

        return F.udf(lambda x: get_date_str(x, start_date, interval, unit), StringType())

    #get the hour string from the datetime string in the DATELOGGED field
    @staticmethod
    def get_hour_str_udf():
        """Return a Spark UDF mapping a timestamp string to 'HH:00:00'.

        The input is expected to look like 'YYYY-MM-DD HH:MM:SS' or
        'YYYY-MM-DDTHH:MM:SS', optionally followed by a '+...' suffix which
        is discarded.
        """
        def get_hour_str(s):
            if '+' in s: #'+' may appear to show the timezone info but it is not necessary as normally it is "+0000"
                s = s[:s.find('+')]
            parsed = datetime.strptime(s.replace('T', ' '), '%Y-%m-%d %H:%M:%S')
            return parsed.strftime('%H:00:00')
        return F.udf(get_hour_str, StringType())

    def binField(self, spark_df, field=None):
        """Count rows per time bin and store the result on the instance.

        spark_df: DataFrame holding a DATE column (and DATELOGGED in hour mode).
        field:    optional extra column; when given, counts are kept per
                  (bin, field value) pair and the table gets a sixth column.

        Sets ``description``, ``bin_lower_bound`` and ``value_dist`` (and, when
        ``field`` is given, ``colNames`` and ``tableColPercent``).
        """
        if self.dateUnit=='hours':
            hour_str_udf = self.get_hour_str_udf()

            spark_df = spark_df.withColumn('datestr', hour_str_udf(F.col('DATELOGGED')))
            #when the data covers a single date, prefix the hour with that date
            if spark_df.select('DATE').distinct().count()==1:
                spark_df = spark_df.withColumn('datestr', F.concat(F.date_format('DATE', 'yyyy-MM-dd'), F.lit(' '), F.col('datestr')))
            description = ["N.B.: Hour bin is the lower bin boundary.", \
                           "The date and time are in UTC timeszone, and the total number of {1:s} is: {0:8,d}"
                          ]
        else:
            date_str_udf = self.get_date_str_udf(self.start_date, interval=self.dateInterval, unit=self.dateUnit)
            spark_df = spark_df.withColumn('datestr', date_str_udf(F.col('DATE')))

            description = ["N.B.: Date bin is the lower bin boundary.", \
                           "The date is based upon UTC timeszone, and the total number of {1:s} is: {0:8,d}"
                          ]
        if field is not None:
            tmp_list = spark_df.select('datestr', field).groupby(['datestr', field]).count().rdd.map(lambda x: [x[0], x[1], x[2]]).collect()
            self.colNames = ["Bin of {:s} lower bound (UTC)".format(self.aggUnit), field, "Count", "Percent", "Cumulative", "Cum_Pct"]
            self.tableColPercent = [26, 22, 12, 13, 14, 13]
        else:
            tmp_list = spark_df.select('datestr').groupby('datestr').count().rdd.map(lambda x: [x[0], x[1]]).collect()

        #the grouped rows already contain every distinct bin, so count them locally
        #instead of launching another Spark job
        numDates = len({row[0] for row in tmp_list})
        description[1] = description[1].format(numDates, self.aggUnit)
        self.description = description

        #parse the bin strings back into datetimes so they sort chronologically
        allList = []
        dateFormat = None
        for tmp in tmp_list:
            dt = tmp[0]
            if len(dt)>8:
                if dt.find(':')==-1: #only date
                    dateFormat = "%Y-%m-%d"
                else: #date and time
                    dateFormat = "%Y-%m-%d %H:00:00"
            else: #only time
                dateFormat = "%H:00:00"

            bn=datetime.strptime(dt, dateFormat)
            if field is not None:
                allList.append([bn, tmp[1], tmp[2]])
            else:
                allList.append([bn, tmp[1]])
        allList = sorted(allList, key=lambda x: (x[0], x[1]))

        #create the bins and dictionary to store the counts for each bin
        self.bin_lower_bound = []
        binned_counts = []
        for dst in allList:
            if field is not None:
                key = "{0}+{1}".format(dst[0].strftime(dateFormat),dst[1])
                binned_counts.append([key, dst[2]])
            else:
                key = dst[0].strftime(dateFormat)
                binned_counts.append([key, dst[1]])
            self.bin_lower_bound.append(key)
        self.value_dist = dict(binned_counts)


In [6]:
#A subclass of GoLiveReport to make Score Distribution
class ScoreDist(GoLiveReport):
    """Distribution of transaction scores over fixed-width score bins."""

    def __init__ (self, name='Score', bins=range(0,1001,50), description=None, 
                plotXRange=range(0,1001,200), plotYRange=range(0,101,20), plotYLabel=None
                ):
        """Configure the report.

        name:        report name; None falls back to a default when bins start above 0.
        bins:        lower bin boundaries; bins starting above 0 select the
                     high-score variant (no inverse cumulative plot).
        description: two-line description; a default is chosen when None.
        plotXRange:  x ticks; defaults to ``bins`` when None.
        plotYRange:  y ticks, forwarded to GoLiveReport.
        plotYLabel:  y label; defaults to 'Transaction Outsort Rate (%)'.
        """
        #identity tests: '== None' is ambiguous for array-like bins/ranges
        if bins[0]>0:
            plotInverseCum = False
            if name is None:
                name = 'Score (low outsort range)'
            if description is None:
                description = ["N.B.: bin is the lower bin boundary", \
                           "The score distribution for the high score and low outsort range."]
        else:
            plotInverseCum = True
            if description is None:
                description = ["N.B.: bin is the lower bin boundary", \
                           "The score distribution for the entire score range."]

        if plotXRange is None:
            plotXRange = bins

        if plotYLabel is None:
            plotYLabel='Transaction Outsort Rate (%)'

        GoLiveReport.__init__(self, name, bins=bins, description=description, sortHighToLow=True, 
                              decimals=0, plotCum=True, plotInverseCum=plotInverseCum, plotXRange=plotXRange, 
                              plotYRange=plotYRange, plotXLabel='Score Threshold', plotYLabel=plotYLabel
                             )
        self.bins = bins

    def binField(self, spark_df, card_max=False):
        """Count scores per bin and store the counts in ``value_dist``.

        SCORE is PREDICTIVE_SCORE, except that a PREDICTIVE_SCORE of -999 is
        replaced by the integer-cast MODEL_PRIMING_SCORE when that is non-empty.
        With ``card_max`` the maximum SCORE per HASHCN is binned instead of
        every transaction.

        Bin keys are the bin lower bounds, -1 for scores below the lowest bin
        and -999 for the "no score" sentinel.
        """
        increment = (max(self.bins) - min(self.bins))//(len(self.bins)-1)
        lowest = min(self.bins)

        def binned(fd, increment=increment, lowest=lowest):
            #A null SCORE reaches the UDF as None and 'None < lowest' raises
            #TypeError; fold it into the existing -999 bucket instead.
            #NOTE(review): confirm null scores should be reported as -999.
            if fd is None or fd==-999:
                return -999

            if fd<lowest:
                return -1
            return int(fd//increment*increment)
        binned_udf = F.udf(binned, IntegerType())

        spark_df=spark_df.select('HASHCN', 'PREDICTIVE_SCORE', 'MODEL_PRIMING_SCORE') \
                    .withColumn('SCORE', F.when(F.col('PREDICTIVE_SCORE')==-999, \
                                            F.when(F.col('MODEL_PRIMING_SCORE')!='', \
                                            F.col('MODEL_PRIMING_SCORE').astype('integer')) \
                                            .otherwise(-999)).otherwise(F.col('PREDICTIVE_SCORE')))

        if card_max: #get the max score distribution for the card
            spark_df = spark_df.groupby('HASHCN').agg(F.max('SCORE').alias('SCORE'))

        tmp_list = spark_df.select('SCORE').withColumn('binned', binned_udf(F.col('SCORE'))) \
                           .select('binned').groupby('binned') \
                           .count().rdd.map(lambda x: [x[0], x[1]]).collect()
        self.value_dist = dict(tmp_list)
    

In [7]:
class NumTxnDist(GoLiveReport):
    """Distribution of the number of transactions per card (HASHCN)."""

    def __init__(self, name='Number of Transactions Per Card (full score range)', bins=range(1,101)
                ):
        """Configure the report; ``name`` and ``bins`` are forwarded to GoLiveReport."""
        GoLiveReport.__init__(self, name, bins=bins, decimals=0, plot=True, plotPoints=True,
                              plotXLabel='Number of transactions per card', plotYLabel='Percent of cards (%)'
                             )

    def binField(self, spark_df):
        """Count cards per transaction count.

        Sets ``bin_lower_bound`` (1..max), ``description``, ``value_dist``
        (transactions-per-card -> number of cards) and ``plotXRange``.
        """
        card_df = spark_df.groupby('HASHCN').agg(F.count('HASHCN').alias('NUMTXN'))
        maxTxns = card_df.select(F.max('NUMTXN')).rdd.map(lambda x: x[0]).collect()[0]
        if maxTxns is None:
            #max() over an empty DataFrame is null; treat that as "no transactions"
            maxTxns = 0

        self.bin_lower_bound = range(1,maxTxns+1)
        self.description = ["N.B.: bin is the lower bin boundary", \
                            "The number of transactions per card during the time of this report.  Max number of transactions on one card is {0:,d}.".format(maxTxns)
                           ]

        tmp_list = card_df.select('NUMTXN').groupby('NUMTXN') \
                           .count().rdd.map(lambda x: [x[0], x[1]]).collect()
        self.value_dist = dict(tmp_list)

        #need to give xticks for plotting, only take eight ticks
        #(the step is kept at 1 or more so range() never receives a zero step)
        increment = max(1, int(np.ceil(maxTxns/7)))
        maxTxns = int(increment*7)
        self.plotXRange = range(0, maxTxns+1, increment)

In [8]:
class CardScoreDist(ScoreDist):
    """Score distribution of the maximum score seen on each card."""

    def __init__(self, name='Per-Card Max-Score Distribution (full score range)', bins=range(0,1001,50), 
                plotXRange=range(0,1001,200), plotYRange=None
                ): 
        """Configure the report.

        Bins starting above 0 select the high-score / low-outsort variant;
        otherwise the full score range is reported and ``plotYRange`` is forced
        to 0..100 in steps of 20. Everything else is forwarded to ScoreDist.
        """
        #The two descriptions (and the fallback name) were attached to the wrong
        #branches: bins starting above 0 are the high-score range, as in ScoreDist.
        if bins[0]>0:
            if name is None:
                name = 'Per-Card Max-Score Distribution (low outsort range)'
            description = ["N.B.: bin is the lower bin boundary", \
                           "The max-score-per-card distribution for the high score and low outsort range."
                          ]
        else:
            plotYRange = range(0,101, 20)
            description = ["N.B.: bin is the lower bin boundary", \
                           "The max-score-per-card distribution."
                          ]

        ScoreDist.__init__(self, name=name, bins=bins, description=description, 
                plotXRange=plotXRange, plotYRange=plotYRange, plotYLabel='Card Outsort Rate (%)')

    def binField(self, spark_df):
        """Bin the per-card maximum score (ScoreDist.binField with card_max=True)."""
        ScoreDist.binField(self, spark_df, card_max=True)

In [9]:
class AmountDist(GoLiveReport):
    """Distribution of transaction amounts over logarithmically spaced bins."""

    def __init__(self, name='Amount Bin', bins=None, plotXLabel='Amount (Base Currency)', 
                 amountField='BASE_CURR_AMOUNT'                
                ):
        """Configure the report.

        name, bins:  forwarded to GoLiveReport.
        plotXLabel:  kept for interface compatibility; the label actually used
                     is derived from ``amountField``.
        amountField: column to bin; 'BASE_CURR_AMOUNT' gets base-currency
                     wording, any other value gets USD wording.
        """
        if amountField=='BASE_CURR_AMOUNT':
            description = ["N.B.: bin is the lower bin boundary", \
                           "The binned amount in the Base Currency of the portfolio."]
            plotXLabel = "Amount (Base Currency)"
        else:
            description = ["N.B.: bin is the lower bin boundary", \
                           "The binned amount in USD."]
            plotXLabel = "Amount (USD)"

        self.amountField = amountField

        GoLiveReport.__init__(self, name, bins=bins, description=description, decimals=2, plot=True, plotXLog=True, 
                              plotXLabel=plotXLabel, plotYLabel='Percent of Trans (%)')

    #transform the amount to desired log scale
    @staticmethod
    def log_range(min, max, nval):
        """Return ``nval`` values spaced evenly in log space from ``min`` to ``max``."""
        log_min = math.log(min)
        log_max = math.log(max)
        return [math.exp(log_min + x*(log_max-log_min)/(nval-1)) for x in range(0, nval)]

    @staticmethod
    def amount_transform_udf(rng):
        """Return a Spark UDF mapping an amount to the lower bound of its bin in ``rng``.

        Amounts below 0.1, below ``rng[0]``, missing, or not comparable (NaN)
        map to -inf; amounts at or above ``rng[-1]`` map to ``rng[-1]``.
        """
        def amount_transform(amt, rng):
            #A null amount arrives as None and cannot be compared; put it in the
            #same -inf bucket that already receives NaN and sub-0.1 amounts.
            if amt is None or amt<0.1:
                return float("-inf")
            for i in range(len(rng)-1):
                if amt>=rng[i] and amt<rng[i+1]:
                    return rng[i]
            if amt>=rng[-1]:
                return rng[-1]
            return float("-inf")

        return F.udf(lambda amt: amount_transform(amt, rng), FloatType())

    def binField(self, spark_df):
        """Count rows per amount bin.

        Sets ``value_dist`` (rounded bin lower bound -> count) and trims
        ``bin_lower_bound`` so it ends at the highest populated bin.
        """
        rng = self.log_range(0.1, 100000, 25)
        if self.bin_lower_bound is None:
            self.bin_lower_bound = list(set([float("-inf")] + rng))
        #sorted() also accepts bins supplied as a tuple or range, which have no .sort()
        self.bin_lower_bound = sorted(self.bin_lower_bound)
        amount_transform_udf = self.amount_transform_udf(rng)

        tmp_list = spark_df.select(self.amountField) \
                           .withColumn('AMOUNT', amount_transform_udf(F.col(self.amountField))) \
                           .groupby('AMOUNT').count().rdd.map(lambda x: [x[0], x[1]]).collect()

        #need to deal with the precision difference between python and spark
        self.bin_lower_bound = list(np.round(self.bin_lower_bound, 2))
        tmp_list = [[np.round(x[0],2), x[1]] for x in tmp_list]

        self.value_dist = dict(tmp_list)

        #drop the empty bins above the highest populated one; with no data at all
        #there is nothing to trim (max() of an empty dict would raise)
        if self.value_dist:
            idx = self.bin_lower_bound.index(max(self.value_dist))
            self.bin_lower_bound = self.bin_lower_bound[:idx+1]

In [10]:
class ErrorDist(GoLiveReport):
    """Summary of MODEL_ERROR values: error / warning / other / clean counts."""

    def __init__(self, name='Error and Warnings'):
        """Configure a three-column table report without a plot."""
        GoLiveReport.__init__(self, name, colNames=['Value', 'Count', 'Percent'], 
                              tableColPercent=[60, 20, 20], plot=False)

    #define a udf to parse error code
    @staticmethod
    def error_parse_udf():
        """Return a Spark UDF classifying a MODEL_ERROR string.

        Empty or null -> 'Clean'; starts with 'error' or 'Timeout' -> 'Error';
        starts with 'warn' -> 'Warning'; anything else -> 'Other'
        (all matches case-insensitive).
        """
        def error_parse(s):
            errorStr = re.compile(r"^error", re.I)
            warnStr = re.compile(r"^warn", re.I)
            timeStr = re.compile(r"Timeout", re.I)

            #a null MODEL_ERROR arrives as None; len(None) would raise, and
            #binField already excludes nulls from the error table
            if not s:
                return 'Clean'
            if re.match(errorStr, s) or re.match(timeStr, s):
                return 'Error'
            if re.match(warnStr, s):
                return 'Warning'
            return 'Other'
        return F.udf(error_parse, StringType()) 

    def get_error_stat(self, spark_df):
        """Compute the per-category counts and build ``description``.

        Sets ``nCalls`` (total number of rows) and ``description``.
        """
        error_parse_udf = self.error_parse_udf()
        spark_df = spark_df.select('MODEL_ERROR').withColumn('ERROR_TYPE', error_parse_udf(F.col('MODEL_ERROR')))

        #One collect instead of four filtered jobs; a missing category is simply
        #absent from the dict, so no blanket 'except' is needed and genuine Spark
        #failures are no longer reported as zero counts.
        rows = spark_df.groupby('ERROR_TYPE').agg(F.count('ERROR_TYPE').alias('ERR_COUNT')).collect()
        counts = {row[0]: row[1] for row in rows}
        nErrors = counts.get('Error', 0)
        nWarnings = counts.get('Warning', 0)
        nOther = counts.get('Other', 0)
        nClean = counts.get('Clean', 0)

        total = nErrors + nWarnings + nOther
        nCalls = total + nClean

        #Need to pass the total to super class
        self.nCalls = nCalls

        #avoid dividing by zero when the input has no rows
        if nCalls:
            basisPoints = 100*100*(1. - nClean*1./nCalls)
        else:
            basisPoints = 0.0

        self.description = ["Total errors or warnings: {0:,d}".format (total),
                           "Total calls: {0:,d}".format (nCalls),
                           "Transactions with errors:  \t {0:13,d}".format (nErrors),
                           "Transactions with warnings:\t {0:13,d}".format (nWarnings),
                           "Transactions with other:   \t {0:13,d}".format (nOther), 
                           "Basis-points of errors and warnings: {0:7.2f}".format (basisPoints)
                           ]

    def binField(self, spark_df):
        """Build the summary, then tabulate the non-empty MODEL_ERROR values."""
        self.get_error_stat(spark_df)
        GoLiveReport.binField(self, spark_df.filter(F.col('MODEL_ERROR')!=''), 'MODEL_ERROR')

In [11]:
class BaseCurrDist(GoLiveReport):
    """Table of transaction counts per BASE_CURR_CODE value (no plot)."""

    def __init__(self, name='BaseCurrCode'):
        """Configure a five-column table with integer formatting."""
        header = ["Value", "Count", "Percent", "Cumulative", "Cum_Pct"]
        GoLiveReport.__init__(
            self, name,
            decimals=0,
            colNames=header,
            tableColPercent=[36],
            plot=False,
        )

    def binField(self, spark_df):
        """Record the distinct currency codes, then delegate the counting to GoLiveReport."""
        distinct_codes = spark_df.select('BASE_CURR_CODE').distinct()
        # NOTE(review): collect() yields Row objects, not bare codes - confirm
        # that GoLiveReport.binField expects (or overwrites) this.
        self.bin_lower_bound = distinct_codes.collect()
        GoLiveReport.binField(self, spark_df, 'BASE_CURR_CODE')

In [12]:
class GroupDist(GoLiveReport):
    """Table of transaction counts per ORGNAME value (no plot)."""

    def __init__(self, name='Group'):
        """Configure a five-column table."""
        header = ["Value", "Count", "Percent", "Cumulative", "Cum_Pct"]
        GoLiveReport.__init__(
            self, name,
            colNames=header,
            tableColPercent=[36],
            plot=False,
        )

    def binField(self, spark_df):
        """Delegate to GoLiveReport.binField on the ORGNAME column."""
        GoLiveReport.binField(self, spark_df, 'ORGNAME')

In [13]:
class ModelDist(GoLiveReport):
    """Table of transaction counts per model label '<MODEL_ID> v<MODEL_VER>' (no plot)."""

    def __init__(self, name='Model'):
        """Configure a five-column table."""
        header = ["Value", "Count", "Percent", "Cumulative", "Cum_Pct"]
        GoLiveReport.__init__(
            self, name,
            colNames=header,
            tableColPercent=[36],
            plot=False,
        )

    def binField(self, spark_df):
        """Add a MODEL label column and delegate the counting to GoLiveReport."""
        # '<id> v<version>'; when both parts are empty strings the label is ' v'
        model_label = F.concat(F.col('MODEL_ID'), F.lit(' v'), F.col('MODEL_VER'))
        labelled_df = spark_df.withColumn('MODEL', model_label)
        # show that empty label as 'N/A'
        readable_label = F.when(F.col('MODEL')==' v', 'N/A').otherwise(F.col('MODEL'))
        labelled_df = labelled_df.withColumn('MODEL', readable_label)
        GoLiveReport.binField(self, labelled_df, 'MODEL')

In [14]:
class ScoreTypeDist(GoLiveReport):
    """Table of transaction counts per scoring type: 'Blank', 'Priming' or 'Normal'."""

    def __init__(self, name='Scoring Type'):
        """Configure a five-column table without a plot."""
        header = ["Value", "Count", "Percent", "Cumulative", "Cum_Pct"]
        GoLiveReport.__init__(self, name, colNames=header, plot=False)

    def binField(self, spark_df):
        """Derive the SCORETYPE column and delegate the counting to GoLiveReport."""
        predictive = F.col('PREDICTIVE_SCORE')
        # no predictive score and an empty priming score
        no_score = (predictive.isNull()) & (F.col('MODEL_PRIMING_SCORE')=='')
        score_type = (
            F.when(no_score, F.lit('Blank'))
             .when(predictive==-999, F.lit('Priming'))
             .otherwise(F.lit('Normal'))
        )
        typed_df = spark_df.withColumn('SCORETYPE', score_type)
        GoLiveReport.binField(self, typed_df, 'SCORETYPE')

In [15]:
# Session settings and the submit options they correspond to:
#   spark.executor.instances -> num-executor (also look at spark.dynamicAllocation.initialExecutors)
#   spark.executor.cores     -> executor-cores
#   spark.executor.memory    -> executor-memory
#   spark.submit.deployMode  -> deploy-mode
spark = (
    SparkSession.builder
    .master("yarn")
    .appName("linhai_golive_rept")
    .config("spark.executor.instances", 4)
    .config("spark.executor.cores", 4)
    .config("spark.executor.memory", '4G')
    .config("spark.submit.deployMode", 'client')
    .config("spark.yarn.executor.memoryOverhead", '2G')
    .getOrCreate()
)
            

In [16]:
# Local time of day at which the report run starts (used for the elapsed-time printout).
start_time = datetime.now().time()

In [17]:
# Load the job configuration and check the requested date range.
job_filepath = '/home/lijli06/Work/DFM_SPARK/golive_report/job_config.json'
glr = GoLiveReport('Super_Metric')
glr.load_job_config(job_filepath)
glr.validate_date_range()

# Columns requested from the data load.
target_fields = [
    'ORGNAME', 'HASHCN', 'PREDICTIVE_SCORE', 'MODEL_PRIMING_SCORE', 'DATE', 'AMOUNTUSD', 'TXNSTATUS',
    'DATELOGGED', 'BASE_CURR_CODE', 'BASE_CURR_AMOUNT', 'MODEL_ERROR', 'MODEL_ID', 'MODEL_VER',
]
all_df = glr.check_n_load_data(spark, target_fields)
# Every per-organisation report below filters this frame, so keep it cached.
all_df.cache()

Start date: 2018-05-11, End date: 2018-05-11 with interval of 1 days


DataFrame[ORGNAME: string, HASHCN: string, PREDICTIVE_SCORE: bigint, MODEL_PRIMING_SCORE: string, DATE: date, AMOUNTUSD: double, TXNSTATUS: string, DATELOGGED: string, BASE_CURR_CODE: bigint, BASE_CURR_AMOUNT: double, MODEL_ERROR: string, MODEL_ID: string, MODEL_VER: string]

In [18]:
#generate report for each org
currdate = datetime.today().date().strftime("%B %d, %Y")

for numIssGrp in range(len(glr.issuer_groups)):
    for orgName in glr.issuer_sets[numIssGrp]:
        rtfReport = RTFReport.RTFReport(glr.issuer_groups[numIssGrp]+'_'+orgName)
        if (orgName.upper() == "ALL_BANKS"):
            rtfReport.AddTitle ('Report for All Organizations')
        else:
            rtfReport.AddTitle ('Report for Organization {0:s}'.format(orgName))
        rtfReport.AddSubTitle ('Prepared on ' + currdate)
        if orgName.upper() == 'ALL_BANKS':
            rtfReport.AddParagraph ('This report documents the performance of all portfolios, examining all transactions in Risk Analytics which are eligible to receive a score.')
        else:
            rtfReport.AddParagraph ('This report documents the performance of the {0:s} portfolio, examining all transactions in Risk Analytics which are eligible to receive a score.'.format(orgName))
            banks = [orgName]
        
        #obtain the data for the specified orgs in the group
        tmp_df = all_df.filter(sal_process_orgname_udf(F.col("ORGNAME")).isin(banks))
        #numParts = tmp_df.rdd.getNumPartitions()
        #print("Number of Partitions: " + str(numParts)) 
        #tmp_df = tmp_df.repartition(int(2*numParts))
        
        #Get the statistics for the table of dates
        dateDist = TemporalDist(name='Dates', start_date=glr.start_date, plot=False,
                                dateInterval=glr.dateInterval, dateUnit=glr.dateUnit);
        dateDist.binField(tmp_df)
        
        #Get the statistics for score type
        scoreTypeDist = ScoreTypeDist()
        scoreTypeDist.binField(tmp_df)
       
        #Get the score distribution
        scoreDist = ScoreDist()
        scoreDist.binField(tmp_df)

        #Get the high score distribution
        highScoreDist = ScoreDist(name="Score (low outsort range)", bins=range(400,1001,10),
                         plotXRange=range(400,1001,100), plotYRange=range(0,6,1))
        highScoreDist.binField(tmp_df)
        
        #Get the number of transactions per card
        numTxnDist = NumTxnDist()
        numTxnDist.binField(tmp_df)
        
        #Get the Per-Card Max-Score Distribution (full score range)
        maxScoreDist = CardScoreDist()
        maxScoreDist.binField(tmp_df)

        #Get the Per-Card Max-Score Distribution (high score range) 
        maxHighScoreDist = CardScoreDist(name='Per-Card Max-Score Distribution (low outsort range)', 
                                           bins=range(400,1001,10), plotXRange=range(400,1001,100))
        maxHighScoreDist.binField(tmp_df)
        
        #Get the statistics of errors and warnings
        errorDist = ErrorDist()
        errorDist.binField(tmp_df)
        
        #Get the statistics of BaseCurrCode
        baseCurrDist = BaseCurrDist()
        baseCurrDist.binField(tmp_df)

        #Get the statistics of amount
        if orgName.upper()=='ALL_BANKS':
            amountField = 'AMOUNTUSD'
        else:
            amountField = 'BASE_CURR_AMOUNT'
        amountDist = AmountDist(amountField=amountField)
        amountDist.binField(tmp_df)
        
        #Get the hour of the day statistics       
        hourDist = TemporalDist(name='Hour of Day', plot=True, plotXLabel='Hour of Day')
        hourDist.binField(tmp_df)
        
        #Get the statistics of the groups
        groupDist = GroupDist()
        groupDist.binField(tmp_df)
        
        #Get the statistics of model
        modelDist = ModelDist()
        modelDist.binField(tmp_df)
        
        dateDist.rtfOutput(rtfReport)
        scoreTypeDist.rtfOutput(rtfReport)
        scoreDist.rtfOutput(rtfReport)
        highScoreDist.rtfOutput(rtfReport)
        numTxnDist.rtfOutput(rtfReport)
        maxScoreDist.rtfOutput(rtfReport)
        maxHighScoreDist.rtfOutput(rtfReport)
        errorDist.rtfOutput(rtfReport)
        baseCurrDist.rtfOutput(rtfReport)
        amountDist.rtfOutput(rtfReport)
        hourDist.rtfOutput(rtfReport)
        groupDist.rtfOutput(rtfReport)
        modelDist.rtfOutput(rtfReport)
        
        start_date = glr.start_date
        end_date = glr.end_date
        if start_date==end_date:
            filename = 'Report_{0:s}_{1:s}.rtf'.format(start_date.strftime("%Y%m%d"), orgName.strip())
        else:
            filename = 'Report_{0:s}_{1:s}_{2:s}.rtf'.format(start_date.strftime("%Y%m%d"), end_date.strftime("%Y%m%d"), orgName.strip())
                
        rtfReport.Output(filename=filename)



In [19]:
# All reports are written; drop the cached copy of the loaded data.
all_df.unpersist()

DataFrame[ORGNAME: string, HASHCN: string, PREDICTIVE_SCORE: bigint, MODEL_PRIMING_SCORE: string, DATE: date, AMOUNTUSD: double, TXNSTATUS: string, DATELOGGED: string, BASE_CURR_CODE: bigint, BASE_CURR_AMOUNT: double, MODEL_ERROR: string, MODEL_ID: string, MODEL_VER: string]

In [20]:
# Report the wall-clock duration of the run. start_time holds only a time of day,
# so both ends are anchored to the same calendar date before subtracting.
end_moment = datetime.today()
end_time = end_moment.time()
elapse_time = datetime.combine(end_moment.date(), end_time) - datetime.combine(end_moment.date(), start_time)
if elapse_time < timedelta(0):
    # the run crossed midnight: the end time of day is earlier than the start
    elapse_time += timedelta(days=1)
print('Take total {:.2f} seconds to finish'.format(elapse_time.total_seconds()))

Take total 65.71 seconds to finish


In [21]:
# Scratch data: a small two-column frame for trying out groupby/count.
df = spark.createDataFrame(
    [
        ('a', 'aaa'), ('a', 'bbb'), ('a', 'aaa'),
        ('b', 'aaa'), ('b', 'bbb'), ('b', 'bbb'),
    ],
    ['x', 'y'],
)
df.show()

In [22]:
# Scratch check: number of rows for each value of x.
df.groupby('x').count().show()

In [23]:
# Scratch check: number of rows for each (x, y) pair.
df.groupby(['x', 'y']).count().show()

In [24]:
# Shut down the Spark session created earlier in this notebook.
spark.stop()