In [87]:
import os
import json
import yaml
import datetime
from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql.types import StructType,StructField,DateType,TimestampType,StringType,IntegerType,DecimalType
from azure.storage.blob import BlobServiceClient

In [88]:
# Load the Azure Blob Storage connection settings from the local YAML config.
# NOTE(review): FullLoader can build more than plain scalars/containers; if
# config.yml only holds plain strings, yaml.safe_load would be sufficient — confirm.
with open('config.yml') as config_file:
    azure_blob_conn_values = yaml.load(config_file, Loader=yaml.FullLoader)

# Each setting is read by key and whitespace-trimmed before use.
(storage_account_name,
 storage_account_access_key,
 storage_container_name,
 blob_output_dir,
 blob_conn_string) = [
    azure_blob_conn_values[setting].strip()
    for setting in (
        'storage_account_name',
        'storage_account_access_key',
        'storage_container_name',
        'blob_output_dir',
        'blob_conn_string',
    )
]

# Root wasbs:// URL of the container; blob names are appended to it directly.
blob_base_path = "".join([
    "wasbs://",
    storage_container_name,
    "@",
    storage_account_name,
    ".blob.core.windows.net/",
])

In [89]:
# List all files in the blob container

def ls_files(client, path, recursive=False):
    '''
    List the blob names under a path prefix, relative to that prefix.

    Parameters:
        client: container client exposing list_blobs(name_starts_with=...),
            which yields objects that have a "name" attribute.
        path: prefix to list under; the empty string lists from the container
            root. A trailing "/" is appended when it is missing.
        recursive: when False, only names directly under the prefix (no
            further "/" in the relative name) are returned.

    Returns:
        list of str: blob names with the prefix removed, always "/"-separated.
    '''
    if path != '' and not path.endswith('/'):
        path += '/'

    files = []
    for blob in client.list_blobs(name_starts_with=path):
        # Blob names are "/"-separated keys, not local filesystem paths, so the
        # prefix is stripped as plain text. os.path.relpath resolves against the
        # current working directory and emits backslashes on Windows, which
        # breaks the "/" depth check below and any URL built from the result.
        name = blob.name
        relative_path = name[len(path):] if name.startswith(path) else name
        if not relative_path:
            # A zero-length remainder is the prefix itself, not a file under it.
            continue
        if recursive or '/' not in relative_path:
            files.append(relative_path)
    return files

# Container whose blobs are listed below.
container_name = storage_container_name

# Open a client on that container using the account connection string.
blob_service_client = BlobServiceClient.from_connection_string(blob_conn_string)
client = blob_service_client.get_container_client(storage_container_name)

# Every blob name in the container, at any depth.
files = ls_files(client, '', recursive=True)

# Only ".txt" blobs are kept. Names containing "NYSE" go to the CSV list; the
# remaining names containing "NASDAQ" go to the JSON list.
txt_files = [str(name) for name in files if name.endswith(".txt")]
csv_file_list = [name for name in txt_files if "NYSE" in name]
json_file_list = [name for name in txt_files if "NYSE" not in name and "NASDAQ" in name]

In [90]:
# Define the common schema shared by every parsed record (CSV and JSON).
#
# DecimalType() with no arguments means precision 10 / scale 0, which rounds
# every price to a whole number when the rows are converted. The price columns
# therefore carry an explicit scale.
# NOTE(review): (18, 4) is an assumption about the feed's price range — confirm,
# and regenerate any Parquet output already written with the scale-0 columns.
price_type = DecimalType(18, 4)

# NOTE(review): both parsers fill the seventh column ("trade_pr") with the
# seventh input field, which parse_json reads from the 'exchange' key — confirm
# the intended column name.
common_event = StructType([
    StructField("trade_dt", StringType(), True),
    StructField("arrival_tm", StringType(), True),
    StructField("rec_type", StringType(), True),
    StructField("symbol", StringType(), True),
    StructField("event_tm", StringType(), True),
    StructField("event_seq_nb", IntegerType(), True),
    StructField("trade_pr", StringType(), True),
    StructField("bid_pr", price_type, True),
    StructField("bid_size", IntegerType(), True),
    StructField("ask_pr", price_type, True),
    StructField("ask_size", IntegerType(), True),
    StructField("partition", StringType(), True),
])

In [91]:
import decimal

def parse_csv(line:str):
    """
    Parse one comma-separated record into a 12-element row.

    The row is positional and is later paired with the ``common_event``
    schema: fields 0-4 and 6 are kept as text, field 5, 8 and 10 are converted
    to ``int``, fields 7 and 9 to ``decimal.Decimal``, and the last element is
    the output partition label.

    Parameters:
        line: a single text line with comma-separated fields; the record type
            ("T" or "Q") is expected at index 2.

    Returns:
        list: the parsed row with partition "T" or "Q", or a "bad" row
        ``[None, None, <record type or None>, None, ..., "B"]`` when the line
        is too short, has an unrecognised record type, or a numeric field
        cannot be converted. The function never returns ``None``.
    """
    record_type_pos = 2
    record = line.split(",")

    # A blank/short line has no record type at all; indexing it blindly used to
    # raise inside the error handler and abort the whole job.
    record_type = record[record_type_pos] if len(record) > record_type_pos else None

    # "T" and "Q" records share the same layout; only the partition label differs.
    if record_type in ("T", "Q"):
        try:
            return [record[0],
                    record[1],
                    record[2],
                    record[3],
                    record[4],
                    int(record[5]),
                    record[6],
                    decimal.Decimal(record[7]),
                    int(record[8]),
                    decimal.Decimal(record[9]),
                    int(record[10]),
                    record_type]
        except (IndexError, ValueError, decimal.InvalidOperation):
            # Missing or non-numeric field: fall through to the bad-record row.
            pass

    # Anything that could not be parsed is routed to the "B" (bad) partition,
    # including record types other than "T"/"Q".
    return [None, None, record_type, None, None, None, None, None, None, None, None, "B"]
    
def parse_json(line:str):
    """
    Parse one JSON-encoded record into a 12-element row.

    The row is positional and is later paired with the ``common_event``
    schema. Values are taken from the keys ``trade_dt``, ``file_tm``,
    ``event_type``, ``symbol``, ``event_tm``, ``event_seq_nb``, ``exchange``,
    ``bid_pr``, ``bid_size``, ``ask_pr`` and ``ask_size`` (in that order),
    followed by the output partition label.

    Parameters:
        line: a single text line holding one JSON object whose ``event_type``
            is "T" or "Q".

    Returns:
        list: the parsed row with partition "T" or "Q", or a "bad" row
        ``[None, None, <event_type or None>, None, ..., "B"]`` when the line is
        not valid JSON, a key is missing, the event type is unrecognised, or a
        numeric value cannot be converted. The function never returns ``None``.
    """
    record_type = None

    try:
        # Decoding happens inside the try so one malformed line becomes a bad
        # row instead of failing the job. parse_float keeps JSON numbers exact;
        # Decimal(float) would expose the binary rounding error of the float.
        record = json.loads(line, parse_float=decimal.Decimal)
        record_type = record['event_type']

        # "T" and "Q" records share the same layout; only the partition label differs.
        if record_type in ("T", "Q"):
            # NOTE(review): 'exchange' lands in the seventh slot, which the
            # common_event schema names "trade_pr" — confirm the intended column.
            return [record['trade_dt'],
                    record['file_tm'],
                    record['event_type'],
                    record['symbol'],
                    record['event_tm'],
                    int(record['event_seq_nb']),
                    record['exchange'],
                    decimal.Decimal(record['bid_pr']),
                    int(record['bid_size']),
                    decimal.Decimal(record['ask_pr']),
                    int(record['ask_size']),
                    record_type]
    except (KeyError, TypeError, ValueError, decimal.InvalidOperation):
        # Invalid JSON (a ValueError subclass), non-object payload, missing key
        # or non-numeric value: fall through to the bad-record row.
        pass

    # Anything that could not be parsed is routed to the "B" (bad) partition,
    # including event types other than "T"/"Q".
    return [None, None, record_type, None, None, None, None, None, None, None, None, "B"]

In [92]:
# Ingest every file in csv_file_list and append it to the partitioned Parquet output.

# Start (or reuse) the Spark session and register the storage account key.
spark = SparkSession.builder.appName("stockExchange").getOrCreate()
sc = spark.sparkContext
key_str = "".join(["fs.azure.account.key.", storage_account_name, ".blob.core.windows.net"])
spark.conf.set(key_str, storage_account_access_key)

# Every file is written to the same output directory, so resolve it once.
out_path = blob_base_path + blob_output_dir

for file in csv_file_list:
    filepath = blob_base_path + file

    # One row per text line, parsed by parse_csv and typed by common_event.
    raw = sc.textFile(filepath)
    parsed = raw.map(parse_csv)
    data = spark.createDataFrame(parsed, common_event)

    # Appending keeps the output of previously processed files in place.
    data.write.partitionBy("partition").mode("append").parquet(out_path)

spark.stop()

In [93]:
# Ingest every file in json_file_list and append it to the partitioned Parquet output.

# Start (or reuse) the Spark session and register the storage account key.
spark = SparkSession.builder.appName("stockExchange").getOrCreate()
sc = spark.sparkContext
key_str = "".join(["fs.azure.account.key.", storage_account_name, ".blob.core.windows.net"])
spark.conf.set(key_str, storage_account_access_key)

# Every file is written to the same output directory, so resolve it once.
out_path = blob_base_path + blob_output_dir

for file in json_file_list:
    filepath = blob_base_path + file

    # One row per text line, parsed by parse_json and typed by common_event.
    raw = sc.textFile(filepath)
    parsed = raw.map(parse_json)
    data = spark.createDataFrame(parsed, common_event)

    # Appending keeps the output of previously processed files in place.
    data.write.partitionBy("partition").mode("append").parquet(out_path)

spark.stop()

In [94]:
# Re-open the Spark session and register the storage account key.
spark = SparkSession.builder.appName("stockExchange").getOrCreate()
key_str = "".join(["fs.azure.account.key.", storage_account_name, ".blob.core.windows.net"])
spark.conf.set(key_str, storage_account_access_key)

# Read back only the "T" partition directory of the Parquet output.
trade_path = "".join([blob_base_path, blob_output_dir, '/partition=T'])
trade_common = spark.read.parquet(trade_path)

In [95]:
# Keep only the columns used to identify a trade record and its arrival time.
trade_columns = ["trade_dt", "symbol", "event_tm", "event_seq_nb", "arrival_tm", "trade_pr"]
trade = trade_common.select(*trade_columns)

In [96]:
# Latest arrival time per (trade_dt, symbol, event_tm, event_seq_nb, trade_pr) group.
latest_arrival = func.max('arrival_tm').alias("max_arrival_dtm")
trade_cnt = (
    trade
    .groupBy('trade_dt', 'symbol', 'event_tm', 'event_seq_nb', "trade_pr")
    .agg(latest_arrival)
)

In [97]:
# Join and filter the records down to the latest arrival per trade key.

trade.createOrReplaceTempView("trade_tbl")
trade_cnt.createOrReplaceTempView("trade_cnt_tbl")

# trade_cnt_tbl holds one row per key with max_arrival_dtm. The EXISTS must also
# compare arrival_tm against that maximum: matching on the key columns alone
# keeps every row of trade_tbl, because each row's own key is always present in
# the grouped table, so no superseded record was ever dropped.
# Rows with a NULL in any compared column never satisfy "=" and are excluded.
join_trade_df = spark.sql("""
    select *
    from trade_tbl t1
    where EXISTS (
            select 1 from trade_cnt_tbl t2
            WHERE t1.trade_dt = t2.trade_dt
              and t1.symbol = t2.symbol
              and t1.event_tm = t2.event_tm
              and t1.event_seq_nb = t2.event_seq_nb
              and t1.trade_pr = t2.trade_pr
              and t1.arrival_tm = t2.max_arrival_dtm
        )
""")

In [98]:
# Persist the filtered trades as Parquet under a trade_dt=<date> directory.
# NOTE(review): the date is fixed here rather than taken from the data — confirm
# it matches the trade_dt values being written.
trade_date = "2020-07-29"
trade_partition_dir = "/trade/trade_dt={}".format(trade_date)
join_trade_df.write.parquet(blob_base_path + blob_output_dir + trade_partition_dir)

In [99]:
# Shut down the Spark session now that the trade output has been written.
spark.stop()