In [None]:
# Read/write helpers


def file_reader(source, file_path, header=True, schema=True, multiline=True):
  '''
  Generic File Reader. Reads a parquet or csv source and returns the spark dataframe.

  source    - source format, matched case-insensitively: "parquet" or "csv"
  file_path - path handed to the reader's load()
  header    - csv only: value for the reader's "header" option
  schema    - csv only: value for the reader's "inferSchema" option
  multiline - csv only: value for the reader's "multiLine" option

  Raises ValueError (an Exception subclass) for any other source type.
  '''
  source_type = source.lower()
  if source_type == "parquet":
    return spark.read.format("parquet").load(file_path)
  if source_type == "csv":
    # multiline used to be accepted and silently dropped; it is now forwarded to the reader.
    csv_reader = (
      spark.read.format("csv")
      .option("header", header)
      .option("inferSchema", schema)
      .option("multiLine", multiline)
    )
    return csv_reader.load(file_path)
  raise ValueError("Invalid source type. Please specify a valid source type.")



def file_writer(target, df, file_path, save_mode="overwrite", part_cols=""):
  '''
  Generic File Writer. Writes the dataframe to file_path in the requested format.

  target    - output format name handed to the writer's format() (e.g. "parquet", "csv")
  df        - dataframe to write
  file_path - destination path handed to the writer's save()
  save_mode - value handed to the writer's mode(); defaults to "overwrite"
  part_cols - a partition column name, or a list/tuple of names;
              empty or None means the output is not partitioned
  '''
  # A bare string is one column name, not an iterable of single-character names.
  if isinstance(part_cols, str):
    partition_names = [part_cols] if part_cols else []
  else:
    partition_names = list(part_cols or [])

  writer = df.write.format(target).mode(save_mode)
  if partition_names:
    # partitionBy expects column names, so the names are passed through as-is
    # rather than being wrapped in Column objects.
    writer = writer.partitionBy(*partition_names)
  writer.save(file_path)




# Merge table logic

def execute_merge(df, db_params, primary_keys):
  '''
  Merge df into the delta table identified by db_params.

  The target table is looked up with get_delta_table under
  Constant.SILVER_CONTAINER/<database>/<table>. Matched rows are updated with all
  source columns and unmatched source rows are inserted; the match condition is
  whatever get_merge_on_columns(primary_keys) returns.

  df           - source dataframe (aliased "b" in the merge)
  db_params    - object exposing get_database() and get_table()
  primary_keys - passed to get_merge_on_columns to build the merge condition

  Raises Exception if anything in the merge fails; the original error is printed
  and attached as the cause.
  '''
  database = db_params.get_database()
  table = db_params.get_table()
  try:
    target_table = get_delta_table(f"{Constant.SILVER_CONTAINER}/{database}/{table}")
    merge_builder = target_table.alias("a").merge(df.alias("b"), get_merge_on_columns(primary_keys))
    merge_builder.whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
  except Exception as err:
    traceback.print_exc()
    # Chain explicitly so the root cause stays attached to the generic error.
    raise Exception(f"Error occured when performing merge operation for {database}.{table}") from err


# Casting dataframe columns

# Builds castDF from curveDF: parses the two date columns, casts the three price
# columns to double, renames the cut-off column and adds effective_date / year_diff.
# The statement starts at column 0; a leading space here is an unexpected indent.
castDF = (
  curveDF
  # Parse the string dates using their respective patterns.
  .withColumn("date", to_date(curveDF["date"], "MM/dd/yyyy"))
  .withColumn("cut_off_date", to_date(curveDF["cut_off_date"], "MMM-EEE-dd-yyyy"))
  # Price columns become doubles.
  .withColumn("peak", curveDF["peak"].cast(DoubleType()))
  .withColumn("off_peak", curveDF["off_peak"].cast(DoubleType()))
  .withColumn("atc", curveDF["atc"].cast(DoubleType()))
  .withColumnRenamed("cut_off_date", "short_term_cut_off_date")
  # effective_date is the same literal (ifc.eff_date) on every row.
  .withColumn("effective_date", lit(ifc.eff_date))
  .withColumn("year_diff", year("date") - year("effective_date"))
)





# Date conversion logic

def columns_to_unix_timestamp(df, table_name, table_attributes):
  '''
  Run both column conversions for table_name and return the resulting dataframe:
  first datetime_columns_to_unix_timestamp, then
  datetimeoffset_columns_to_unix_timestamp on its output.
  '''
  with_datetimes = datetime_columns_to_unix_timestamp(df, table_name, table_attributes)
  return datetimeoffset_columns_to_unix_timestamp(with_datetimes, table_name, table_attributes)

# COMMAND ----------

def datetime_columns_to_unix_timestamp(df, table_name, table_attributes):
  '''
  Apply sql_datetime_to_unix_timestamp to every entry that
  table_attributes.get_tablename_to_timestamp_columns() lists for table_name.

  Entries are applied in order, each on the output of the previous one.
  The dataframe is returned untouched when the table has no (or an empty) entry.
  '''
  timestamp_columns = table_attributes.get_tablename_to_timestamp_columns().get(table_name)
  if not timestamp_columns:
    return df

  converted_df = df
  for column_zoneid_tuple in timestamp_columns:
    converted_df = sql_datetime_to_unix_timestamp(converted_df, column_zoneid_tuple)
  return converted_df

# COMMAND ----------

def datetimeoffset_columns_to_unix_timestamp(df, table_name, table_attributes):
  '''
  Apply sql_datetimeoffset_to_unix_timestamp to every entry that
  table_attributes.get_tablename_to_datetimeoffset_columns() lists for table_name.

  Entries are applied in order, each on the output of the previous one.
  The dataframe is returned untouched when the table has no (or an empty) entry.
  '''
  offset_columns = table_attributes.get_tablename_to_datetimeoffset_columns().get(table_name)
  if not offset_columns:
    return df

  converted_df = df
  for column_format_tuple in offset_columns:
    converted_df = sql_datetimeoffset_to_unix_timestamp(converted_df, column_format_tuple)
  return converted_df

# COMMAND ----------

def sql_datetime_to_unix_timestamp(source_df, column_zoneid_tuple):
  '''
  Overwrite one column with to_utc_timestamp(column, zone id) and return the new dataframe.
  column_zoneid_tuple - (column name, zone id), e.g. ("EnterDate", "EST")
  NOTE(review): the result is whatever to_utc_timestamp yields, which looks like a
  timestamp column rather than epoch seconds despite the function name - confirm.
  '''
  column_name = column_zoneid_tuple[0]
  zone_id = column_zoneid_tuple[1]
  utc_column = to_utc_timestamp(column_name, zone_id)
  return source_df.withColumn(column_name, utc_column)

# COMMAND ----------

def sql_datetimeoffset_to_unix_timestamp(source_df, column_format_tuple):
  '''
  Overwrite one column with to_timestamp(column, format) and return the new dataframe.
  column_format_tuple - (column name, datetime pattern),
                        e.g. ("ModifiedDate", "yyyy-MM-dd HH:mm:ss z")
  '''
  column_name = column_format_tuple[0]
  datetime_pattern = column_format_tuple[1]
  parsed_column = to_timestamp(column_name, datetime_pattern)
  return source_df.withColumn(column_name, parsed_column)