In [2]:
# -------------------------------------------------------------------------
# MODIFY WITH CARE
# Standard libraries to be used in AWS Glue jobs
# -------------------------------------------------------------------------

import sys
from awsglue.transforms import *
from awsglue.utils import getResolvedOptions
from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.job import Job
from pyspark.sql import functions as f
from pyspark.sql.types import *
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, array, ArrayType, DateType
from pyspark.sql import Row, Column
import datetime
import json
import boto3
import logging
import calendar
import uuid
import time
from dateutil import relativedelta

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [26]:
def get_partition():
    '''
    function: get_partition
    description: Builds the partition value for the current run from today's local date.

    Returns:
        str: The current local date formatted as YYYY-MM-DD.
    '''
    today = datetime.datetime.now().date()
    return today.isoformat()

def get_raw_table_name(tables, i):
    '''
    function: get_raw_table_name
    description: Derives the raw table name from the storage location of a catalog table.

    Args:
        tables: Response of a Glue get_tables call (must contain 'TableList').
        i: Index of the table inside tables['TableList'].

    Returns:
        str: Everything that follows the last '.' of the table's storage location.

    Raises:
        ValueError: If the location does not contain a '.'.
    '''
    entry = tables['TableList'][i]
    location = entry['StorageDescriptor']['Location']
    last_dot = location.rindex(".")
    return location[last_dot + 1:]

def get_raw_table_columns(client, catalog_database, catalog_table):
    '''
    function: get_raw_table_columns
    description: Looks up a table in the Glue catalog and returns its column definitions.

    Args:
        client: Glue client used for the get_table call.
        catalog_database: Glue catalog database name.
        catalog_table: Glue catalog table name.

    Returns:
        The 'Columns' entry of the table's storage descriptor.
    '''
    response = client.get_table(Name=catalog_table, DatabaseName=catalog_database)
    storage_descriptor = response['Table']['StorageDescriptor']
    return storage_descriptor['Columns']

def handle_error(spark, log, message, region_name, log_bucket, job_log_dir, job_log_location, partition, sns_arn):
    '''
        function: handle_error
        description: Records the failure in the job log, shows and stores the log as CSV,
                     triggers the job notification and finally re-raises the error.

        Args:
            spark: Spark session object.
            log: List of log entries; a FAILURE entry is appended to it.
            message: Exception that has happened. It is raised again at the end.
            region_name: Name of the AWS region.
            log_bucket: Bucket name for logs.
            job_log_dir: Key of logs directory.
            job_log_location: Location the log file is written to.
            partition: Date by which logs are partitioned.
            sns_arn: ARN of SNS topic.

        Raises:
            The object passed as message, once logging and notification are done.
    '''
    failure_time = str(datetime.datetime.now())
    append_log(log, failure_time, str(message), 'FAILURE')

    # NOTE(review): create_log_df is not defined in the visible part of this notebook.
    failure_log_df = create_log_df(spark, log)
    failure_log_df.show()

    # An empty table name is passed; with a partition and no uid the CSV is
    # written straight to job_log_location/partition.
    write_s3_file(failure_log_df, job_log_location, '', partition=partition, format='CSV')
    job_notification_sns(region_name, log_bucket, job_log_dir, partition, sns_arn)

    raise message

def create_crawler(client, crawler_name, iam_role_arn, database_name):
    '''
    function: create_crawler
    description: Creates a Glue crawler with a single placeholder S3 target.

    Args:
        client: Glue client used for the create_crawler call.
        crawler_name: Name of the crawler to create.
        iam_role_arn: Role passed to the crawler.
        database_name: Glue catalog database the crawler is created for.

    Returns:
        The response of the create_crawler call.
    '''
    placeholder_targets = {'S3Targets': [{'Path': 's3://bucket/placeholder'}]}
    response = client.create_crawler(
        Name=crawler_name,
        Role=iam_role_arn,
        DatabaseName=database_name,
        Targets=placeholder_targets
    )
    return response

def read_table_from_catalog(glueContext, database, table_name):
    '''
    function: read_table_from_catalog
    description: Loads a Glue catalog table as a dynamic frame and converts it with toDF().

    Args:
        glueContext: GlueContext class object.
        database: Glue catalog database name.
        table_name: Glue catalog table name.

    Returns:
        The result of toDF() on the dynamic frame created from the catalog table.
    '''
    dynamic_frame = glueContext.create_dynamic_frame.from_catalog(
        database=database,
        table_name=table_name
    )
    return dynamic_frame.toDF()
 
def write_s3_file(df, table_location, table, partition=None, uid=None, format='PARQUET', delimiter='\t', coalesce=1, header=False):
    '''
    function: write_s3_file
    description: Write a dataframe to S3 in either Parquet or CSV format.

    The destination is assembled from the pieces that were supplied:
        PARQUET:                  table_location/table[/uid][/partition]
        CSV without a partition:  table_location[/uid]/table
        CSV with a partition:     table_location[/uid]/partition  (table is not part of the path)

    Args:
        df: spark.sql.dataframe to be written to S3.
        table_location: Location of where the file should be stored.
        table: Name of the table.
        partition: Date by which the file should be stored in S3.
        uid: Optional run identifier added to the destination path.
        format: Format in which the file should be stored, 'PARQUET' or 'CSV'. (PARQUET is default)
        delimiter: How the file should be delimited. (Applicable on CSV files only)
        coalesce: Number of spark.sql.dataframe partitions. (Applicable on CSV files only)
        header: Write a header row. (Applicable on CSV files only) A CSV written
                without a partition always gets a header row, as it did before.

    Raises:
        ValueError: If format is neither 'PARQUET' nor 'CSV'.
        Any error raised by the Spark writer is propagated unchanged.
    '''
    def join_path(*parts):
        # Pieces that were not supplied are skipped, so a missing uid/partition
        # no longer ends in a "str + None" TypeError.
        return '/'.join(part for part in parts if part is not None)

    if format == 'PARQUET':
        df.write.parquet(join_path(table_location, table, uid, partition))
    elif format == 'CSV':
        if partition is None:
            destination = join_path(table_location, uid, table)
            write_header = True
        else:
            destination = join_path(table_location, uid, partition)
            write_header = bool(header)

        writer = (df.coalesce(coalesce).write
                  .option("delimiter", delimiter)
                  .option("quote", "\"")
                  .option("quoteAll", "true"))
        if write_header:
            writer = writer.option("header", "true")
        writer.csv(destination)
    else:
        # Previously an unknown format silently wrote nothing.
        raise ValueError('Unsupported format: ' + str(format))

def compare_row_count(spark, glue, glueContext, source_database, destination_database):
    '''
    function: compare_row_count
    description: Compares the row count of every source catalog table with the row
                 count of its destination table (all listed partitions combined).

    Every table of every get_tables page is checked, starting with the first one.
    Source tables without rows are left out of the result. A failure while reading
    a destination table is printed and that table is skipped.

    Args:
        spark: Spark session object.
        glue: Glue client used for get_tables and get_partitions.
        glueContext: GlueContext class object.
        source_database: Glue catalog database holding the source tables.
        destination_database: Glue catalog database holding the destination tables.

    Returns:
        spark.sql.dataframe: One row per compared table with the columns
        'Source', 'Source row count', 'Destination', 'Destination row count', 'Status'.
    '''
    row_count_list = []
    next_token = None

    while True:
        # Follow NextToken so catalogs with more than one page of tables are fully covered.
        if next_token is None:
            source_tables = glue.get_tables(DatabaseName=source_database)
        else:
            source_tables = glue.get_tables(DatabaseName=source_database, NextToken=next_token)

        # Start at index 0: the first table of a page must be compared as well.
        for i in range(len(source_tables['TableList'])):
            source_table = source_tables['TableList'][i]['Name']
            destination_table = get_raw_table_name(source_tables, i)

            source_df = read_table_from_catalog(glueContext, database=source_database, table_name=source_table)
            source_row_count = source_df.count()
            try:
                if source_row_count > 0:
                    # NOTE(review): only the first get_partitions page is read here.
                    destination = glue.get_partitions(DatabaseName=destination_database, TableName=destination_table)
                    destination_df = spark.read.load(destination['Partitions'][0]['StorageDescriptor']['Location'])

                    for partition in destination['Partitions'][1:]:
                        destination_location = partition['StorageDescriptor']['Location']
                        destination_df = destination_df.union(spark.read.load(destination_location))

                    destination_row_count = destination_df.count()

                    if source_row_count == destination_row_count:
                        matches = 'Rows count match'
                    else:
                        matches = '**** Rows count mismatch ****'

                    row_count_list.append([source_table, source_row_count, destination_table, destination_row_count, matches])
            except Exception as e:
                print(e, source_table, destination_table)

        next_token = source_tables.get('NextToken')
        if not next_token:
            break

    row_count_df = spark.createDataFrame(row_count_list, schema=['Source', 'Source row count', 'Destination', 'Destination row count', 'Status'])
    return row_count_df

def append_path_to_list(list, location, table_name):
    '''
    function: append_path_to_list
    description: Appends a {'Path': location/table_name} entry to the given list.

    Args:
        list: List that receives the new entry (modified in place).
        location: Base location of the tables.
        table_name: Name of the table appended to the location.
    '''
    target = {'Path': location + '/' + table_name}
    list.append(target)
    
def update_crawler(client, crawler_name, s3targets):
    '''
    function: update_crawler
    description: Replaces the S3 targets of an existing Glue crawler.

    Args:
        client: Glue client used for the update_crawler call.
        crawler_name: Name of the crawler to update.
        s3targets: Value sent as the crawler's 'S3Targets'.
    '''
    targets = dict(S3Targets=s3targets)
    client.update_crawler(Name=crawler_name, Targets=targets)
    
def start_crawler(client, crawler_name):
    '''
    function: start_crawler
    description: Starts a Glue crawler and blocks until the crawler is READY again.

    The crawler state is polled before and after the start. Every state change seen
    after the start is printed; a regular run prints READY, RUNNING, STOPPING, READY.
    The wait ends as soon as READY is observed, so a crawl that passes through
    RUNNING or STOPPING between two polls can no longer block this function forever.

    Args:
        client: Glue client used for get_crawler and start_crawler.
        crawler_name: Name of the crawler to run.
    '''
    def crawler_state():
        return client.get_crawler(Name=crawler_name)['Crawler']['State']

    print(crawler_name + ' started.')

    # Getting PRE-RUN READY status.
    while True:
        time.sleep(1)
        state = crawler_state()
        if state == 'READY':
            print(state)
            break

    client.start_crawler(Name=crawler_name)

    # Same initial delay as before the first post-start check.
    time.sleep(15)

    # Follow the run until the crawler is READY again, printing each new state once.
    last_state = None
    while True:
        state = crawler_state()
        if state != last_state:
            print(state)
            last_state = state
        if state == 'READY':
            break
        time.sleep(1)

def delete_crawler(client, crawler_name):
    '''
    function: delete_crawler
    description: Waits until the crawler reports READY, then deletes it.

    Args:
        client: Glue client used for get_crawler and delete_crawler.
        crawler_name: Name of the crawler to delete.
    '''
    # Poll once per second until READY so a crawler that is still busy is never deleted.
    state = None
    while state != 'READY':
        time.sleep(1)
        state = client.get_crawler(Name=crawler_name)['Crawler']['State']
    print(state)

    client.delete_crawler(Name=crawler_name)

    print(crawler_name + ' deleted.')

def append_log(log, logdate, id, message):
    '''
    function: append_log
    description: Appends a Row(logdate, id, message) entry to the log.

    Args:
        log: List of log entries; it is modified in place.
        logdate: Date on which the log was created.
        id: Id to identify the logs.
        message: Message of the log.

    Returns:
        list: The same log list, now containing the new entry.
    '''
    log.append(Row(logdate, id, message))
    # list.append returns None, so hand back the list itself as documented.
    return log

def delete_catalog_table(client, database, table):
    '''
    function: delete_catalog_table
    description: Deletes a table from the Glue catalog; a table that is not there is only reported.

    Args:
        client: Glue client used for the delete_table call.
        database: Glue catalog database name.
        table: Glue catalog table name.

    Raises:
        Any delete_table error other than EntityNotFoundException, so that failures
        such as missing permissions are no longer reported as "does not exist".
    '''
    try:
        client.delete_table(DatabaseName=database, Name=table)
    except client.exceptions.EntityNotFoundException:
        print(table+' does not exist in glue catalog')

def job_notification_sns(region_name, log_bucket, job_log_dir, partition, sns_arn):
    '''
    function: job_notification_sns
    description: Creates SNS and S3 clients for the given region and lists the objects
                 stored under job_log_dir/partition in the log bucket.

    NOTE(review): the body visible here ends after the S3 listing - nothing is published
    to sns_arn and nothing is returned. The cell looks truncated; confirm against the
    complete notebook before relying on a notification being sent.

    Args:
        region_name: AWS region used for the SNS and S3 clients.
        log_bucket: Name of the log bucket.
        job_log_dir: Folder in which the logs are placed.
        partition: Folder below job_log_dir by which the logs are partitioned.
        sns_arn: ARN of SNS topic to which the notification is to be published
                 (not used by the visible code).
    '''

    # Clients/resources for the notification; sns and s3r are not used in the visible code.
    sns = boto3.client('sns', region_name=region_name)
    s3 = boto3.client('s3' , region_name=region_name)
    s3r = boto3.resource('s3' , region_name=region_name)

    # Objects whose key starts with job_log_dir/partition.
    response = s3.list_objects(Bucket = log_bucket, Prefix = job_log_dir+'/'+partition)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [15]:

# -------------------------------------------------------------------------
#   Turn DEBUG ON if temporary tables are desired. DEBUG creates additional temporary tables for every step in the script.
#    DEBUG: If 'True', the script will store all dataframes, including the temporary ones.
#    INTERACTIVE: Set 'True' for Jupyter Notebook.
#                 Set 'False' for AWS Glue ETL job.
# -------------------------------------------------------------------------

INTERACTIVE = True
DEBUG = True

# -------------------------------------------------------------------------
# MODIFY AS PER JOB
#    These variables drive the whole job.
#    Adjust each of them for the job at hand.
# -------------------------------------------------------------------------

# Job identity: name, region, a fresh id per run and the date partition.
JOB_NAME = 'ingest_raw_fossil'
REGION_NAME = 'us-east-1'
UID = uuid.uuid4().hex
PARTITION = 'dt={}'.format(get_partition())

# Job log: bucket, key prefix and full S3 location, plus the SNS topic.
LOG_BUCKET = 'ds-operations-1111-raw'
JOB_LOG_DIR = '/'.join(['log', JOB_NAME])
JOB_LOG_LOCATION = 's3://{}/{}'.format(LOG_BUCKET, JOB_LOG_DIR)
SNS_ARN = 'arn:aws:sns:us-east-1:175908995626:monitor-alert'

# Glue catalog databases and the S3 location of the RAW tables for this run.
CATALOG_DATABASE = 'fossil'
RAW_DATABASE = 'raw_fossil'
RAW_BUCKET = 'ds-operations-1111-raw'
RAW_TAB_LOCATION = 's3://{}/{}/{}'.format(RAW_BUCKET, CATALOG_DATABASE, UID)


# Crawler: name is unique per run; CRAWLER_ARN is the IAM role handed to the crawler.
CRAWLER_NAME = '{}_{}'.format(JOB_NAME, UID)
CRAWLER_ARN = 'arn:aws:iam::175908995626:role/glue-role'

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
# -------------------------------------------------------------------------
# MODIFY WITH CARE
# -------------------------------------------------------------------------

# Reuse the running Spark context if there is one, otherwise create it,
# and wrap it in a GlueContext.
sc = SparkContext.getOrCreate()
glueContext = GlueContext(sc)

# Only a non-interactive run binds `spark` here. Later cells use `spark` in
# both modes - presumably the notebook kernel already provides it when
# INTERACTIVE is True; TODO confirm.
if not INTERACTIVE:
    spark = glueContext.spark_session
    
# Glue client shared by the cells below, and the in-memory job log.
client = boto3.client('glue', region_name=REGION_NAME)
log = []

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [18]:
# -------------------------------------------------------------------------
# Script to fetch data from the EDW RDS instance and save results to S3 
# -------------------------------------------------------------------------

try:
    # S3 paths of every table written in this run; the crawler cell reads this list.
    s3pathlist = []
    next_token = None
    while True:
        # Walk through all get_tables pages of the source catalog database.
        if next_token is None:
            tables = client.get_tables(DatabaseName=CATALOG_DATABASE)
        else:
            tables = client.get_tables(DatabaseName=CATALOG_DATABASE, NextToken=next_token)

        for i in range(len(tables['TableList'])):
            raw_table = get_raw_table_name(tables, i)
            catalog_table = tables['TableList'][i]['Name']
            # The former per-table column lookup was dropped: its result was never
            # used and it cost one extra Glue call per table.
            table_df = read_table_from_catalog(glueContext, database=CATALOG_DATABASE, table_name=catalog_table)
            row_count = table_df.count()

            # Replace blanks in column names with underscores before writing.
            for column in table_df.columns:
                table_df = table_df.withColumnRenamed(column, column.replace(' ', '_'))

            if row_count > 0:
                write_s3_file(table_df, RAW_TAB_LOCATION, raw_table, PARTITION)
                append_path_to_list(s3pathlist, RAW_TAB_LOCATION, raw_table)
                append_log(log, str(datetime.datetime.now()), raw_table+' - Row Count: ', row_count)
            else:
                # Empty tables are logged but neither written nor crawled.
                append_log(log, str(datetime.datetime.now()), raw_table+' - Table Ignored: ','0 rows')

        if 'NextToken' not in tables.keys():
            break
        next_token = tables['NextToken']
    print(s3pathlist)

except Exception as e:
    if INTERACTIVE:
        # Re-raise as is so the notebook shows the original traceback.
        raise
    else:
        handle_error(spark, log, e, REGION_NAME, LOG_BUCKET, JOB_LOG_DIR+'/'+UID, JOB_LOG_LOCATION, PARTITION, SNS_ARN)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[{'Path': 's3://ds-operations-1111-raw/fossil/d72ac211a70b43c58e9da4fa46e74f65/coal_prod'}, {'Path': 's3://ds-operations-1111-raw/fossil/d72ac211a70b43c58e9da4fa46e74f65/fossil_capita'}, {'Path': 's3://ds-operations-1111-raw/fossil/d72ac211a70b43c58e9da4fa46e74f65/gas_prod'}, {'Path': 's3://ds-operations-1111-raw/fossil/d72ac211a70b43c58e9da4fa46e74f65/oil_prod'}]

In [21]:
# -------------------------------------------------------------------------
# Script to delete previously catalogued tables 
# -------------------------------------------------------------------------

try:
    page_token = None
    while True:
        # Request the next page of source tables; the first request carries no token.
        request = {'DatabaseName': CATALOG_DATABASE}
        if page_token is not None:
            request['NextToken'] = page_token
        tables = client.get_tables(**request)

        # Remove the RAW catalog entry that corresponds to each source table.
        for i, _ in enumerate(tables['TableList']):
            raw_table = get_raw_table_name(tables, i)
            delete_catalog_table(client, RAW_DATABASE, raw_table)

        if 'NextToken' not in tables.keys():
            break
        page_token = tables['NextToken']
except Exception as e:
    if INTERACTIVE:
        raise e
    else:
        handle_error(spark, log, e, REGION_NAME, LOG_BUCKET, JOB_LOG_DIR+'/'+UID, JOB_LOG_LOCATION, PARTITION, SNS_ARN) 

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

coal_prod does not exist in glue catalog
fossil_capita does not exist in glue catalog
gas_prod does not exist in glue catalog
oil_prod does not exist in glue catalog

In [22]:
# -------------------------------------------------------------------------
# Crawl tables
# -------------------------------------------------------------------------

try:
    # Create a crawler for this run, point it at the written tables, run it, remove it.
    create_crawler(client, CRAWLER_NAME, CRAWLER_ARN, RAW_DATABASE)
    update_crawler(client, CRAWLER_NAME, s3pathlist)
    start_crawler(client, CRAWLER_NAME)
    delete_crawler(client, CRAWLER_NAME)
    append_log(log, str(datetime.datetime.now()), 'Crawling of AMS tables : ','SUCCESS')
except Exception as e:
    # Best-effort cleanup. The crawler may never have been created or may already
    # be deleted; a failing cleanup must not hide the error that brought us here.
    try:
        delete_crawler(client, CRAWLER_NAME)
    except Exception as cleanup_error:
        print('Crawler cleanup failed: ' + str(cleanup_error))

    if INTERACTIVE:
        raise e
    else:
        handle_error(spark, log, e, REGION_NAME, LOG_BUCKET, JOB_LOG_DIR+'/'+UID, JOB_LOG_LOCATION, PARTITION, SNS_ARN)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

{'ResponseMetadata': {'RequestId': '511a17b7-e47d-11e9-bf61-29ce2eb39ff3', 'HTTPStatusCode': 200, 'HTTPHeaders': {'date': 'Tue, 01 Oct 2019 18:57:29 GMT', 'content-type': 'application/x-amz-json-1.1', 'content-length': '2', 'connection': 'keep-alive', 'x-amzn-requestid': '511a17b7-e47d-11e9-bf61-29ce2eb39ff3'}, 'RetryAttempts': 0}}
ingest_raw_fossil_d72ac211a70b43c58e9da4fa46e74f65 started.
READY
RUNNING
STOPPING
READY
READY
ingest_raw_fossil_d72ac211a70b43c58e9da4fa46e74f65 deleted.

In [23]:
# -------------------------------------------------------------------------
# Perform count comparisons
# -------------------------------------------------------------------------

# Row counts of the source catalog tables versus the tables registered in the RAW database.
compare_df = compare_row_count(
    spark,
    client,
    glueContext,
    source_database=CATALOG_DATABASE,
    destination_database=RAW_DATABASE
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [28]:
# -------------------------------------------------------------------------
# save logs and send logs to SNS topic for distribution
# -------------------------------------------------------------------------

# PARTITION travels in the `table` argument and `partition` stays unset, so the
# comparison CSV is written to JOB_LOG_LOCATION/UID/PARTITION - the same key
# prefix that is handed to job_notification_sns below.
write_s3_file(
    compare_df,
    JOB_LOG_LOCATION,
    PARTITION,
    uid=UID,
    format='CSV',
    delimiter=',',
    coalesce=1,
    header=True
)
job_notification_sns(REGION_NAME, LOG_BUCKET, '/'.join([JOB_LOG_DIR, UID]), PARTITION, SNS_ARN)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…