In [0]:
from datetime import datetime
from decimal import Decimal
from typing import List
from pyspark.sql.types import DateType, DecimalType, IntegerType, StructType, StructField, StringType, TimestampType
from pyspark.sql import SparkSession
import json

In [0]:
def get_dir_content(ls_path):
  dir_paths = dbutils.fs.ls(ls_path)
  subdir_paths = [get_dir_content(p.path) for p in dir_paths if p.isDir() and p.path != ls_path]
  flat_subdir_paths = [p for subdir in subdir_paths for p in subdir]
  return list(map(lambda p: p.path, dir_paths)) + flat_subdir_paths

In [0]:
def common_event(trade_dt: DateType, rec_type: StringType, symbol: StringType, exchange: StringType,
                 event_tm: TimestampType, event_seq_nb: IntegerType, arrival_tm: TimestampType,
                 trade_pr: DecimalType(30, 15), trade_size: IntegerType, bid_pr: DecimalType(30, 15),
                 bid_size: IntegerType, ask_pr: DecimalType(30, 15), ask_size: IntegerType, partition: StringType,
                 line: StringType):
    """Returns common event schema

    Args:
        ... field and data type for common schema
        partition: partition key for trade quote or bad T,Q, or B
        line: used to return bad line
    Returns:
        either the bad line or good record as list of values for each field"""

    if partition == "B":
        return line
    else:
        return [trade_dt, rec_type, symbol, exchange,
                event_tm, event_seq_nb, arrival_tm,
                trade_pr, trade_size, bid_pr, bid_size, ask_pr, ask_size, partition, line]

In [0]:
def parse_csv(line: str):
    """CSV parser to be used in Spark transformation process"""
    record_type_pos = 2
    record = line.split(',')
    try:
        # logic to parse records
        if record[record_type_pos] == 'T':
            event = common_event(datetime.strptime(record[0], "%Y-%m-%d"), record[2], record[3], record[6],
                                 datetime.strptime(record[4], '%Y-%m-%d %H:%M:%S.%f'), int(record[5]),
                                 datetime.strptime(record[1], '%Y-%m-%d %H:%M:%S.%f'), Decimal(record[7]),
                                 int(record[8]), None, None, None, None, 'T', None)
            return event
        elif record[record_type_pos] == 'Q':
            event = common_event(datetime.strptime(record[0], "%Y-%m-%d"), record[2], record[3], record[6],
                                 datetime.strptime(record[4], '%Y-%m-%d %H:%M:%S.%f'), int(record[5]),
                                 datetime.strptime(record[1], '%Y-%m-%d %H:%M:%S.%f'), None, None, Decimal(record[7]),
                                 int(record[8]), Decimal(record[9]), int(record[10]), 'Q', None)
            return event
    except Exception as e:
        # save record to dummy event in bad partition
        # fill in the fields as None or empty string
        print(e)
        return common_event(None, None, None, None, None, None, None, None, None, None, None, None, None, "B",line)

In [0]:
def parse_json(line: str):
    record = json.loads(line)
    record_type = record['event_type']
    try:
        # logic to parse records
        if record_type == "T":
            event = common_event(datetime.strptime(record["trade_dt"], "%Y-%m-%d"), record["event_type"],
                                 record["symbol"], record["exchange"],
                                 datetime.strptime(record["event_tm"], '%Y-%m-%d %H:%M:%S.%f'),
                                 int(record["event_seq_nb"]),
                                 datetime.strptime(record["file_tm"], '%Y-%m-%d %H:%M:%S.%f'), Decimal(record["price"]),
                                 int(record["size"]), None, None, None, None, "T", None)
            return event
        elif record_type == 'Q':
            event = common_event(datetime.strptime(record["trade_dt"], "%Y-%m-%d"), record["event_type"],
                                 record["symbol"], record["exchange"],
                                 datetime.strptime(record["event_tm"], '%Y-%m-%d %H:%M:%S.%f'),
                                 int(record["event_seq_nb"]),
                                 datetime.strptime(record["file_tm"], '%Y-%m-%d %H:%M:%S.%f'), None, None,
                                 Decimal(record["bid_pr"]), int(record["bid_size"]), Decimal(record["ask_pr"]),
                                 int(record["ask_size"]), "Q", None)
            return event
    except Exception as e:
        print(e)
        return common_event(None, None, None, None, None, None, None, None, None, None, None, None, None, "B", line)

In [0]:
def process_file(path, schem):
    raw = spark.sparkContext.textFile(path)
    if 'csv' in path:
        parsed = raw.map(lambda line: parse_csv(line))
    else:
        parsed = raw.map(lambda line: parse_json(line))
    data = spark.createDataFrame(parsed, schema=schem)
    return data

In [0]:
schema = StructType().add("trade_dt", DateType()).add("rec_type", StringType()).add("symbol", StringType())\
    .add("exchange", StringType()).add("event_tm", TimestampType()).add("event_seq_nb", IntegerType())\
    .add("arrival_tm", TimestampType()).add("trade_pr", DecimalType()).add("trade_size", IntegerType())\
    .add("bid_pr", DecimalType()).add("bid_size", IntegerType()).add("ask_pr", DecimalType())\
    .add("ask_size", IntegerType()).add("partition", StringType()).add("line", StringType())

In [0]:
dbutils.fs.mount(
  source = "wasbs://gcap@gcblobcf2022.blob.core.windows.net",
  mount_point = "/mnt/azuremountgc",
  extra_configs = {"fs.azure.account.key.gcblobcf2022.blob.core.windows.net":"h3lPjueXOs8RE37xrWUD70ZNsHp4wUo/BeGWX6FlLieE/8RZlRWn1VGMT/QoqcF7h2KwXS34RoyVYoUIqS7MeA=="}
)

[0;31m---------------------------------------------------------------------------[0m
[0;31mExecutionError[0m                            Traceback (most recent call last)
[0;32m<command-4497323250226909>[0m in [0;36m<module>[0;34m[0m
[0;32m----> 1[0;31m dbutils.fs.mount(
[0m[1;32m      2[0m   [0msource[0m [0;34m=[0m [0;34m"wasbs://gcap@gcblobcf2022.blob.core.windows.net"[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m      3[0m   [0mmount_point[0m [0;34m=[0m [0;34m"/mnt/azuremountgc"[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m      4[0m   [0mextra_configs[0m [0;34m=[0m [0;34m{[0m[0;34m"fs.azure.account.key.gcblobcf2022.blob.core.windows.net"[0m[0;34m:[0m[0;34m"h3lPjueXOs8RE37xrWUD70ZNsHp4wUo/BeGWX6FlLieE/8RZlRWn1VGMT/QoqcF7h2KwXS34RoyVYoUIqS7MeA=="[0m[0;34m}[0m[0;34m[0m[0;34m[0m[0m
[1;32m      5[0m )

[0;32m/databricks/python_shell/dbruntime/dbutils.py[0m in [0;36mf_with_exception_handling[0;34m(*args, **kwargs)[0m
[1;32m    3

In [0]:
paths = [p for p in get_dir_content("/mnt/azuremountgc") if '.txt' in p]

In [0]:
for path in paths:
    df = process_file(path, schema)
    df.write.partitionBy("partition").mode("overwrite").parquet("/mnt/azuremountgc/output")