In [1]:
# Based on the walkthrough at https://docs.microsoft.com/en-us/azure/azure-databricks/databricks-extract-load-sql-data-warehouse#load-data-into-azure-sql-data-warehouse

In [2]:
from pyspark.sql import functions as F
from pyspark.sql.types import *
from datetime import date
from pyspark.sql.utils import AnalysisException
import sys
import logging
import re

In [3]:
STORAGE_KEY = dbutils.secrets.get("billing", "storage_key")
DB_CONNECTION = dbutils.secrets.get("billing", "db_connection")
STORAGE_NAME = "blxbillingstorage"
CONTAINER_NAME_SOURCE = "billingfiles"
CONTAINER_NAME_DB_TEMP = "billingfilesfordw"
# Blob location handed to the SQL DW writer as its `tempDir` (see add_to_db_temp).
DB_TEMP_STORAGE = "wasbs://" + CONTAINER_NAME_DB_TEMP + "@" + STORAGE_NAME + ".blob.core.windows.net/tmp"
SOURCE_FOLDER = '/mnt/' + CONTAINER_NAME_SOURCE + '/'
# File names to process; group 1 captures the YYYY-MM-DD date in the name.
# Raw strings: "\d" and "\." are invalid escape sequences in a normal string
# literal (DeprecationWarning, and SyntaxWarning from Python 3.12 on).
FILE_PATTERN = r"usage-(\d{4}-\d{2}-\d{2})T\d{2}-\d{2}-\d{2}-twoweeks_block\.csv$"
# # Testing
#FILE_PATTERN = r"usage-(\d{4}-\d{2}-\d{2})T\d{2}-\d{2}-\d{2}-december\.csv$"

_LOGGER = logging.getLogger(__name__)
# Drop handlers added by earlier runs of this cell so log lines are not duplicated.
_LOGGER.handlers = []
_LOGGER.setLevel(logging.INFO)

# Named log_formatter (not `format`) so the builtin format() is not shadowed
# in the shared notebook namespace.
log_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(log_formatter)
_LOGGER.addHandler(stream_handler)

In [4]:
# Give Spark the storage-account key so wasbs:// paths on STORAGE_NAME
# (e.g. DB_TEMP_STORAGE) can be read and written.
storage_key_setting = "fs.azure.account.key.{}.blob.core.windows.net".format(STORAGE_NAME)
spark.conf.set(storage_key_setting, STORAGE_KEY)

In [5]:
# Explicit schema for the usage CSV exports (every column nullable).
# `_corrupt_record` is Spark's default corrupt-record column: in the default
# PERMISSIVE read mode it receives the raw text of rows that fail to parse.
BILLING_SCHEMA = StructType([
                  StructField('AccountId', IntegerType(), True),
                  StructField('AccountName', StringType(), True),
                  StructField('AccountOwnerEmail', StringType(), True),
                  StructField('AdditionalInfo', StringType(), True),
                  # DecimalType() defaults to (10, 0), which rounds fractional
                  # quantities to whole numbers; use the same precision as Cost.
                  # NOTE(review): an existing DW column created from the old
                  # (integer-scale) type may need widening to keep the decimals.
                  StructField('ConsumedQuantity', DecimalType(38, 18), True),
                  StructField('ConsumedService', StringType(), True),
                  StructField('ConsumedServiceId', IntegerType(), True),
                  StructField('Cost', DecimalType(38, 18), True),
                  StructField('CostCenter', StringType(), True),
                  StructField('Date', DateType(), True),
                  StructField('DepartmentId', StringType(), True),
                  StructField('DepartmentName', StringType(), True),
                  StructField('InstanceId', StringType(), True),
                  StructField('MeterCategory', StringType(), True),
                  StructField('MeterId', StringType(), True),
                  StructField('MeterName', StringType(), True),
                  StructField('MeterRegion', StringType(), True),
                  StructField('MeterSubCategory', StringType(), True),
                  StructField('Product', StringType(), True),
                  StructField('ProductId', StringType(), True),
                  StructField('ResourceGroup', StringType(), True),
                  StructField('ResourceLocation', StringType(), True),
                  StructField('ResourceLocationId', IntegerType(), True),
                  StructField('ResourceRate', StringType(), True),
                  StructField('ServiceAdministratorId', StringType(), True),
                  StructField('ServiceInfo1', StringType(), True),
                  StructField('ServiceInfo2', StringType(), True),
                  StructField('StoreServiceIdentifier', StringType(), True),
                  StructField('SubscriptionGuid', StringType(), True),
                  StructField('SubscriptionId', StringType(), True),
                  StructField('SubscriptionName', StringType(), True),
                  StructField('Tags', StringType(), True),
                  StructField('UnitOfMeasure', StringType(), True),
                  StructField('PartNumber', StringType(), True),
                  StructField('ResourceGuid', StringType(), True),
                  StructField('OfferId', StringType(), True),
                  StructField('ChargesBilledSeparately', StringType(), True),
                  StructField('Location', StringType(), True),
                  StructField('ServiceName', StringType(), True),
                  StructField('ServiceTier', StringType(), True),
                  StructField("_corrupt_record", StringType(), True)])


def save_file(df, file_path):
  """Write ``df`` as Parquet to ``file_path``; skip quietly if it already exists.

  Args:
    df: Spark DataFrame to write.
    file_path: destination path for the Parquet output.

  Raises:
    AnalysisException: for any write failure other than "already exists".
  """
  try:
    df.write.parquet(file_path)
  except AnalysisException as aex:
    # `desc` is an internal attribute of PySpark's captured exceptions, not a
    # stable API; fall back to the exception text, which carries the message.
    message = getattr(aex, "desc", None) or str(aex)
    if "already exists" in message:
      _LOGGER.info("Skipping '{}'. File already exists.".format(file_path))
      return
    # Bare raise re-raises the active exception with its original traceback.
    raise


def add_to_db_temp(df, table_name):
  """Append ``df`` to the Azure SQL Data Warehouse table ``table_name``.

  Writes through the Databricks SQL DW connector ("com.databricks.spark.sqldw"),
  staging data under DB_TEMP_STORAGE and forwarding Spark's storage-account
  credentials. The save mode is "Append"; the other modes the writer accepts
  are Overwrite, ErrorIfExists and Ignore.
  """
  _LOGGER.info("Saving to " + table_name)

  writer = df.write.format("com.databricks.spark.sqldw")
  writer = writer.options(
    url=DB_CONNECTION,
    forwardSparkAzureStorageCredentials="true",
    dbTable=table_name,
    tempDir=DB_TEMP_STORAGE,
  )
  writer.mode("Append").save()


def _drop_corrupt_rows(billing_df):
  """Log how many rows failed to parse and return only the well-formed rows."""
  # A real count (the old head(1) could only ever report 0 or 1). The caller
  # must have cached billing_df: Spark >= 2.3 rejects queries over a raw CSV
  # scan that reference only the corrupt-record column.
  num_bad_rows = billing_df.filter(billing_df._corrupt_record.isNotNull()).count()
  _LOGGER.info("Skipped {} row(s).".format(num_bad_rows))
  # Keep the result: previously this filter was computed and then discarded.
  return billing_df.filter(billing_df._corrupt_record.isNull())


def _add_derived_columns(billing_df):
  """Add Year, Month and the CostCenter breakdown columns.

  CostCenter is split on '-' into a code and a space-separated remainder whose
  first four tokens become SBG, SBU, SBE and Region. Missing parts yield nulls.
  NOTE(review): the "<code>-<SBG> <SBU> <SBE> <Region>" layout is inferred from
  this split logic — confirm against real data.
  """
  split_col = F.split(billing_df.CostCenter, '-')
  split_more = F.split(F.trim(split_col.getItem(1)), ' ')
  return (billing_df
          .withColumn('Year', F.year(billing_df.Date))
          # 'yyyy-MM' cast back to DATE gives the first day of the month.
          .withColumn('Month', F.date_format(billing_df.Date, 'yyyy-MM').cast('Date'))
          .withColumn('CostCenterCode', split_col.getItem(0))
          .withColumn('CostCenterSBG', split_more.getItem(0))
          .withColumn('CostCenterSBU', split_more.getItem(1))
          .withColumn('CostCenterSBE', split_more.getItem(2))
          .withColumn('CostCenterRegion', split_more.getItem(3)))


def _summarize_usage(billing_df):
  """Sum quantity and cost and count instances per day/subscription/cost center/meter/product."""
  group_cols = ['Date', 'Year', 'Month', 'SubscriptionGuid', 'SubscriptionName', 'CostCenter',
                'CostCenterSBG', 'CostCenterSBU', 'CostCenterSBE', 'CostCenterRegion',
                'MeterCategory', 'MeterSubCategory', 'Product', 'UnitOfMeasure']
  # No orderBy: the result is only appended to a DW table, so sorting is wasted work.
  return (billing_df
          .groupby(group_cols)
          .agg(F.sum(billing_df.ConsumedQuantity).alias("ConsumedQuantity"),
               F.sum(billing_df.Cost).alias("Cost"),
               # count(col) counts rows with a non-null InstanceId, not distinct instances.
               F.count(billing_df.InstanceId).alias("InstanceCount")))


def _summarize_monthly(billing_df):
  """Total cost per month (rounded to 2 decimals), excluding rows without a Month."""
  # Filtering before grouping drops the same rows as removing the null-Month
  # group afterwards, without referring back to the pre-aggregation frame.
  return (billing_df
          .filter(F.col('Month').isNotNull())
          .groupby('Month')
          .agg(F.round(F.sum(billing_df.Cost), 2).alias("Cost")))


def process_file(file_name, date_part):
  """Summarize one billing CSV into the DW staging tables, then mark it processed.

  Args:
    file_name: path of the source CSV, as listed by dbutils.fs.ls.
    date_part: the YYYY-MM-DD date captured from the file name by FILE_PATTERN.
      Currently unused; kept for the caller's signature.

  On success the source file is renamed to "<file_name>.processed".
  NOTE(review): if the monthly write fails after the summary write succeeded,
  the file is not renamed and a re-run appends the summary rows again.
  """
  _LOGGER.info("Processing " + file_name)

  # Cache the parsed rows: required for the corrupt-record count (see
  # _drop_corrupt_rows) and reused by both summaries and both DB writes.
  raw_df = spark.read.load(file_name, format="csv", header="true", schema=BILLING_SCHEMA, escape="\"").cache()
  try:
    billing_df = _add_derived_columns(_drop_corrupt_rows(raw_df))

    _LOGGER.info("...Summarizing data...")
    summary_df = _summarize_usage(billing_df)

    _LOGGER.info("...Summarizing monthly...")
    monthly_agg_df = _summarize_monthly(billing_df)

    # Add tracking columns
    summary_df = (summary_df
                  .withColumn('_FileName', F.lit(file_name))
                  .withColumn('_JobRunTime', F.current_timestamp())
                  .withColumn('_MergeDate', F.lit('')))
    monthly_agg_df = (monthly_agg_df
                      .withColumn('_FileName', F.lit(file_name))
                      .withColumn('_JobRunTime', F.current_timestamp()))

    _LOGGER.info("...save to db")
    add_to_db_temp(summary_df, 'fact_UsageSummary_Temp')
    add_to_db_temp(monthly_agg_df, 'fact_UsageMonthly_Temp')
  finally:
    raw_df.unpersist()

  # Mark as done only after both writes succeeded.
  _LOGGER.info("...rename file")
  dbutils.fs.mv(file_name, file_name + ".processed")


def main():
  """Process every file in SOURCE_FOLDER whose name matches FILE_PATTERN.

  Each match is handed to process_file along with the date captured by the
  pattern's first group. Non-matching entries (including already renamed
  ".processed" files) are skipped.
  """
  name_re = re.compile(FILE_PATTERN)
  for entry in dbutils.fs.ls(SOURCE_FOLDER):
    found = name_re.match(entry.name)
    if found is None:
      continue
    process_file(entry.path, found.group(1))

In [6]:
# Entry point: run the pipeline over every matching file in SOURCE_FOLDER.
_LOGGER.info("starting main.")
main()