In [1]:
import os
import json
import yaml
import datetime
from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql.types import StructType,StructField,DateType,TimestampType,StringType,IntegerType,DecimalType
from azure.storage.blob import BlobServiceClient
from pyspark.sql.functions import year, month, dayofmonth,to_timestamp,to_date,split,substring,col,when
from pyspark.sql.functions import udf
import logging

In [2]:
class sparkExecutionLogger:
    """Thin wrapper around the shared ``'dev'`` logger that writes to ``spark_log.log``.

    ``logging.getLogger('dev')`` returns the same logger object on every call,
    so creating several instances of this class must not attach several file
    handlers to it; otherwise every message is written once per instance.
    The constructor therefore re-uses a handler that already targets the same
    log file and only creates one when none exists yet.

    Attributes:
        logger: the shared ``logging.Logger`` named ``'dev'`` (level INFO).
        fileHandler: the ``logging.FileHandler`` writing to ``spark_log.log``.
        formatter: the ``logging.Formatter`` describing the record layout.
    """

    def __init__(self):
        # Adding information related to logger
        self.logger = logging.getLogger('dev')
        self.logger.setLevel(logging.INFO)
        self.formatter = logging.Formatter('%(asctime)s  %(name)s  %(levelname)s: %(message)s')

        # FileHandler stores its target as an absolute path in ``baseFilename``;
        # compare against that to find a handler added by an earlier instance.
        log_path = os.path.abspath('spark_log.log')
        handler = next(
            (h for h in self.logger.handlers
             if isinstance(h, logging.FileHandler) and h.baseFilename == log_path),
            None,
        )

        if handler is None:
            handler = logging.FileHandler('spark_log.log')
            handler.setLevel(logging.INFO)
            handler.setFormatter(self.formatter)
            self.logger.addHandler(handler)

        self.fileHandler = handler

    def logMessage(self,info):
        """Write ``info`` to the log at INFO level."""
        self.logger.info(info)

class pipelineStage0:
    """Stage 0 of the pipeline: read the Azure blob settings from a YAML file."""

    def __init__(self,configFile):
        self.configFile = configFile
        self.log = sparkExecutionLogger()

    def getConfig(self):
        """Load the YAML config and return the stripped connection settings.

        Returns:
            A 6-tuple ``(storage_account_name, storage_account_access_key,
            storage_container_name, blob_output_dir, blob_base_path,
            blob_conn_string)`` where ``blob_base_path`` is the
            ``wasbs://<container>@<account>.blob.core.windows.net/`` prefix,
            or ``False`` if anything goes wrong (the exception text is logged).
        """
        try:
            self.log.logMessage("Fetching configuration values from config file")

            # NOTE(review): FullLoader can build more than plain scalars/maps;
            # yaml.safe_load would be enough if the config is plain key/value.
            with open(self.configFile) as cfg_stream:
                settings = yaml.load(cfg_stream, Loader=yaml.FullLoader)

            # Keys are read in a fixed order so the first missing one is the
            # one reported in the log.
            wanted_keys = (
                'storage_account_name',
                'storage_account_access_key',
                'storage_container_name',
                'blob_output_dir',
                'blob_conn_string',
            )
            account, access_key, container, output_dir, conn_str = (
                settings[key].strip() for key in wanted_keys
            )

            base_path = "wasbs://" + container + "@" + account + ".blob.core.windows.net/"

            return account, access_key, container, output_dir, base_path, conn_str

        except Exception as err:
            self.log.logMessage("Exception " + str(err))
            return False

class pipelineStage1:
    """Stage 1 of the pipeline: discover the input files in a blob container."""

    def __init__(self,contName,blobConnStr):
        self.contName = contName
        self.blobConnStr = blobConnStr
        self.log = sparkExecutionLogger()

    def ls_files(self,client,path,recursive=False):
        """List blob names under ``path``, relative to ``path``.

        Args:
            client: object exposing ``list_blobs(name_starts_with=...)`` whose
                items have a ``name`` attribute (here: a container client).
            path: blob-name prefix to list; ``''`` lists the whole container.
                A trailing ``'/'`` is added when missing.
            recursive: when False, names that still contain ``'/'`` after the
                prefix is removed (i.e. nested entries) are left out.

        Returns:
            List of names with the ``path`` prefix removed.
        """
        path = path or ''
        if path and not path.endswith('/'):
            path += '/'

        prefix_len = len(path)
        files = []

        for blob in client.list_blobs(name_starts_with=path):
            # Blob names always use '/' and are not filesystem paths, so the
            # prefix is stripped textually instead of via os.path.relpath
            # (which depends on the working directory and the OS separator).
            relative_path = blob.name[prefix_len:]
            if not relative_path:
                # A blob named exactly like the prefix has no relative name.
                continue
            if recursive or '/' not in relative_path:
                files.append(relative_path)
        return files

    @staticmethod
    def _select_files(files,fileType,fileFilter):
        """Keep the names that end with ``fileType`` and contain ``fileFilter``."""
        return [str(name) for name in files
                if name.endswith(fileType) and fileFilter in name]

    def getFileList(self,fileType,fileFilter):
        """Return the container's blob names matching an extension and a substring.

        Args:
            fileType: required name suffix, e.g. ``".csv"`` (any length).
            fileFilter: substring that must occur in the blob name.

        Returns:
            List of matching blob names (possibly empty), or ``False`` if
            listing the container fails (the exception text is logged).
        """
        try:

            self.log.logMessage("Fetching list of files from the blob container")

            blob_service_client = BlobServiceClient.from_connection_string(self.blobConnStr)
            client = blob_service_client.get_container_client(self.contName)

            files = self.ls_files(client, '', recursive=True)

            return self._select_files(files, fileType, fileFilter)

        except Exception as e:
            self.log.logMessage("Exception " + str(e))
            return False
    
class pipelineStage2:
    """Stage 2 of the pipeline: enrich, split and write one CSV file with Spark."""

    def __init__(self,sparkSession):
        self.spark = sparkSession
        self.log = sparkExecutionLogger()

    def processData(self,filepath):
        """Read a CSV file and add the derived columns.

        Null ``Customer ID`` / ``Description`` values are replaced with
        ``'Guest'`` / ``'Unlisted'`` and the columns ``Year``, ``month``,
        ``Qtr`` and ``InvoiceType`` are added. ``Year`` and ``month`` are taken
        from ``InvoiceDate`` split on ``'/'`` (item 0 is the month, the first
        four characters of item 2 are the year).

        Args:
            filepath: location of the CSV file to load.

        Returns:
            The enriched DataFrame, or ``False`` on any error (logged).
        """
        try:

            self.log.logMessage("Processing file " + filepath)

            df = self.spark.read.format("csv").option("header","true").option("inferSchema", "true").load(filepath)

            # DataFrames are immutable: na.fill returns a new DataFrame, so the
            # result has to be assigned back or the replacement is lost.
            # NOTE(review): a string fill value only affects string columns; if
            # inferSchema types "Customer ID" as numeric, 'Guest' is not
            # applied -- confirm against the real data.
            df = df.na.fill('Guest',subset = ["Customer ID"])
            df = df.na.fill('Unlisted',subset = ["Description"])

            split_col = split(df['InvoiceDate'], '/')
            df = df.withColumn("Year",substring(split_col.getItem(2),1,4))
            df = df.withColumn('month',split_col.getItem(0).cast("int"))

            # Any month outside 1-9 (including a null month) falls into Qtr4.
            month_col = col("month")
            df = df.withColumn('Qtr',(when(month_col.isin(1, 2, 3),"Qtr1")
                                        .when(month_col.isin(4, 5, 6),"Qtr2")
                                        .when(month_col.isin(7, 8, 9),"Qtr3")
                                        .otherwise("Qtr4")))

            # Non-positive quantities are treated as returns.
            df = df.withColumn('InvoiceType',(when( (col("Quantity") <= 0),"Return")
                                                .otherwise("Purchase")))

            return df

        except Exception as e:
            self.log.logMessage("Exception " + str(e))
            return False

    def splitFiles(self,df,filepath):
        """Split ``df`` into United Kingdom rows and all other rows.

        Args:
            df: DataFrame with a ``Country`` column.
            filepath: source file name, used only for the log message.

        Returns:
            Tuple ``(df_uk, df_others)``, or ``False`` on any error (logged).
        """
        try:

            self.log.logMessage("Split files " + filepath)

            # NOTE(review): rows whose Country is null satisfy neither
            # comparison and therefore appear in neither output -- confirm
            # this is intended.
            df_uk = df.filter(df.Country == 'United Kingdom')
            df_others = df.filter(df.Country != 'United Kingdom')

            return df_uk,df_others

        except Exception as e:
            self.log.logMessage("Exception " + str(e))
            return False

    def writeFiles(self,outPath,df):
        """Append ``df`` as CSV under ``outPath``.

        Returns:
            ``True`` when the write succeeds, ``False`` on any error (logged).
        """
        try:
            self.log.logMessage("Writing file to blob output directory " + outPath)
            df.write.mode("append").csv(outPath)
            return True

        except Exception as e:
            self.log.logMessage("Exception " + str(e))
            return False

In [3]:
def main():
    """Run the whole pipeline: config -> file discovery -> process/split/write.

    Each stage reports failure by returning ``False`` (after logging); the
    pipeline then stops, or for per-file stages skips to the next file.

    Returns:
        ``None`` when the run finishes (including when a stage reported
        failure), ``False`` when an unexpected exception escapes a stage.
    """
    log = sparkExecutionLogger()

    try:
        fileType = ".csv"
        fileFilter = 'input'
        configFile = 'config.yml'

        # Getting all the config values from the config file
        stg0 = pipelineStage0(configFile)
        out0 = stg0.getConfig()
        if out0 is False:
            return

        acct_name,acct_key,cont_name,blob_output,blob_path,blob_conn = out0

        # Getting the list of new files in the input folder
        stg1 = pipelineStage1(cont_name,blob_conn)
        fileList = stg1.getFileList(fileType,fileFilter)
        if fileList is False:
            return

        # Creating a new spark session
        spark = SparkSession.builder.appName("stockExchange").getOrCreate()
        try:
            key_str = "fs.azure.account.key." + acct_name + ".blob.core.windows.net"
            spark.conf.set(key_str,acct_key)

            stg2 = pipelineStage2(spark)
            outPath = blob_path + blob_output

            for file in fileList:
                filepath =  blob_path + file

                # Processing the file data
                df = stg2.processData(filepath)
                if df is False:
                    continue

                #Splitting the files
                out3 = stg2.splitFiles(df,filepath)
                if out3 is False:
                    continue
                df_uk,df_others = out3

                # Writing the split files to the output blob
                # NOTE(review): both halves are appended to the same outPath,
                # so the split is not visible in the output -- confirm intent.
                if stg2.writeFiles(outPath,df_uk):
                    stg2.writeFiles(outPath,df_others)
        finally:
            # Always release the session, even if something above raised.
            spark.stop()
    except Exception as e:
        log.logMessage("Exception " + str(e))
        return False

In [4]:
# Run the pipeline only when this code is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()