### GENERAL FUNCTIONS

###### **Data reading**
- ``read_path_data(data_file, dataframe_format, sep = ',', inferSchema = True, header = True, encoding = 'utf-8', schema = None)``: returns a Spark or pandas dataframe from a CSV or SAS (.sas7bdat) file, so the data scientist can work with it on Databricks without converting the data manually
- ``read_dwh_table(query, entorno="pro", column=None, lower_bound=None, upper_bound=None, num_partitions=50)``: reads tables from the Data Warehouse with an SQL query and lowercases the column names
- ``read_staging_table(table, entorno="pro", schema="VARCALC", lower_col_names=True, column=None, lower_bound=None, upper_bound=None, num_partitions=50)``: reads a table from STAGING
- ``write_staging_table(df, table, mode, entorno="pro", schema="VARCALC")``: writes a dataframe (df) to a STAGING table
- ``write_staging_table_truncate(df, table, mode, entorno="pro", schema="VARCALC")``: writes a dataframe (df) to a STAGING table with the JDBC truncate option, so an overwrite keeps the table schema already defined in SQL Server
- ``read_dlake_table(table_path, container, environment="dev", table_format="delta", header=True, sep=",")``: reads a table from the Data Lake
- ``write_dlake_table(df, table_path, container, mode = "overwrite", environment = "dev", table_format = "delta", partitioned_by = None, replace_where = None, header = True, sep = ",", delete_versions = True)``: saves a table in the Data Lake

###### Spark dataframes utils
- ``spark_dataframe_shape(df)``: returns a list with the number of rows and the number of columns of a Spark dataframe, in that order
- ``spark_lower_column_names(df)``: lowercases the column names
- ``spark_upper_column_names(df)``: uppercases the column names
- ``spark_numeric_variables(df)``: returns a list with the names of the columns of any numeric Spark SQL type
- ``spark_string_variables(df)``: returns a list with the names of the columns of StringType
- ``spark_datetime_variables(df)``: returns a list with the names of the columns of TimestampType or DateType
- ``spark_datetime_var_split(df, datetime_vars)``: splits each date variable of a Spark DataFrame into 3 columns (day, month and year) so they can be analyzed separately, and drops the original columns. The new columns are named after the original one with the suffix "_day", "_month" or "_year"
- ``convert_numeric_variables_to_float(df)``: converts the numeric columns of a dataframe to float
- ``columns_name_normalization(df, replacing_character = '_')``: replaces spaces, '.', ';', ',' and '-' in the column names with replacing_character

###### General utils
- ``spark_written_csv_rename(csv_path, desired_name)``: Spark (through Hadoop) writes a CSV as an auto-named part file plus several control files that clutter the Blob Storage; this function renames the part file to the desired name and deletes the rest
- ``delete_blob_folder(path)``: deletes all files and folders under a specified path
- ``get_lastday_lastmonth(fecha)``: calculates the last day of the previous month for a date in YYYYMMDD format; if the date is already the last day of its month, it is returned unchanged
- ``is_last_day_month(fecha)``: checks whether a date in YYYYMMDD format is the last day of its month
- ``incrementos_periodo(periodo_ini, unidad_temporal='dia', incremento=0, formato_salida='completo', cierre=None)``: adds or subtracts a time interval to/from an initial date in YYYYMMDD format
- ``unpersist_rdd(table_id=None, unpersist_all=False)``: unpersists cached RDDs from memory and disk
- ``calculate_business_days(input_date)``: calculates the number of business days elapsed in the year up to a given date, counting from 31 December of the previous year
- ``get_previous_day(fecha)``: given a date in yyyyMMdd format, returns the previous day
- ``get_today()``: returns today's date in yyyyMMdd format
- ``get_yesterday()``: returns yesterday's date in yyyyMMdd format
- ``spark_dim_window_filter(tabla, fecha_hub, claves, drop = True, registro_fecha = "fecha_inicio")``: filters the records of a period in the dimension table without using surrogate keys

In [0]:
import os
import pandas as pd
import numpy as np
import shutil
import datetime
from dateutil.relativedelta import relativedelta
from pandas.tseries.offsets import BDay

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from adal import AuthenticationContext

Data source connection information

In [0]:
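# Blob Storage connection information
storage_account_name = "3satoristgaccountmodelos"
# Read the access key from a Databricks secret scope instead of hardcoding it in the notebook
# (the scope and key names below are placeholders: replace them with the real ones)
storage_account_access_key = dbutils.secrets.get(scope="<secret-scope>", key="<storage-account-key>")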

spark.conf.set(
  "fs.azure.account.key." + storage_account_name + ".blob.core.windows.net", 
  storage_account_access_key)

Data reading

In [0]:
def _get_token(entorno="pro"):
  """
  Gets the token needed for Service Principal authentication.
  PRIVATE USAGE ONLY!
  
  Parameters:
  ------------
  entorno : string
    Indicates the SQL server environment. Valid options:
        - "pro" -> production (3)
        - "dev" -> development (6)

  Returns:
  --------
  Tuple with 3 items:
    - jdbc_url_sp : string containing the URL of the DWH database
    - jdbc_url_staging_sp : string containing the URL of the Staging database
    - properties_sp : dictionary containing the JDBC connection properties of the Service Principal (access token, driver and encryption settings)
    
  """
  
  if entorno not in ["pro", "dev"]:
    raise ValueError("Parameter 'entorno' must be either 'pro' or 'dev'.")
  
  entorno_code = {"pro": 3, "dev": 6}.get(entorno)
  
  authority_url = "https://login.microsoftonline.com/cbeb3ecc-6f45-4183-b5a8-088140deae5d"
  resource = "https://database.windows.net/"
  context = AuthenticationContext(authority_url)
  client_id = "4064fe44-fe06-4978-8e43-2a8fcf9d99cd"
  token_sp = (context
             .acquire_token_with_client_credentials(resource, 
                                                    client_id, 
                                                    dbutils.secrets.get(scope="key_prod", key="sp_prod")))
  jdbc_server_sp = '{0}-satori-dwh-server.database.windows.net'.format(entorno_code)
  jdbc_database_sp = 'SOVSAT{0}'.format(entorno_code)
  jdbc_database_staging_sp = 'STAGING'
  jdbc_url_sp = "jdbc:sqlserver://{0}:{1};database={2}".format(jdbc_server_sp, 1433, jdbc_database_sp) 
  jdbc_url_staging_sp = "jdbc:sqlserver://{0}:{1};database={2}".format(jdbc_server_sp, 1433, jdbc_database_staging_sp) 
  properties_sp = {
      "accessToken" : token_sp["accessToken"], 
      "driver" : "com.microsoft.sqlserver.jdbc.SQLServerDriver",
      "hostNameInCertificate" : "*.database.windows.net", 
      "encrypt": "true"
  }
  
  return jdbc_url_sp, jdbc_url_staging_sp, properties_sp
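
The URL construction in `_get_token` does not depend on the token, so it can be checked on its own. The sketch below uses a hypothetical helper, `build_jdbc_urls`, that only repeats the string formatting above (server `<code>-satori-dwh-server`, databases `SOVSAT<code>` and `STAGING`, port 1433); it does not authenticate or open a connection.

```python
def build_jdbc_urls(entorno="pro"):
  """Hypothetical helper: rebuild the DWH and Staging JDBC URLs used by _get_token."""
  codes = {"pro": 3, "dev": 6}
  if entorno not in codes:
    raise ValueError("Parameter 'entorno' must be either 'pro' or 'dev'.")
  code = codes[entorno]
  server = "{0}-satori-dwh-server.database.windows.net".format(code)
  template = "jdbc:sqlserver://{0}:{1};database={2}"
  dwh_url = template.format(server, 1433, "SOVSAT{0}".format(code))
  staging_url = template.format(server, 1433, "STAGING")
  return dwh_url, staging_url


dwh_url, staging_url = build_jdbc_urls("dev")
print(dwh_url)      # jdbc:sqlserver://6-satori-dwh-server.database.windows.net:1433;database=SOVSAT6
print(staging_url)  # jdbc:sqlserver://6-satori-dwh-server.database.windows.net:1433;database=STAGING
```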

In [0]:
def read_staging_table(table, entorno="pro", schema="VARCALC", lower_col_names=True,
                       column=None, lower_bound=None, upper_bound=None, num_partitions=50):
  
  """
  Read a table from the Staging database 
  
  Parameters:
  ------------
  - table : string
    Name of the table to read, without the schema prefix (it is taken from 'schema').
    
  - entorno : string
    Indicates the SQL server environment. Valid options:
        - "pro" / "produccion" -> production (3)
        - "dev" / "desarrollo" -> development (6)
        
  - schema : string
    Schema of the Staging database to be used
    
  - lower_col_names : Boolean
    Boolean that indicates if the column names of the table should be lowered
    
    
  If the user wants to parallelize the reading process, the following parameters should be specified:
  
  - column : string
    Reference column to be used during the partitioning process. It should be an Integer, Date or TimeStamp Type
    
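  - lower_bound : Integer / Date / TimeStamp
    Lower value of the reference column, used to compute the partition stride. It does not filter rows.
    
  - upper_bound : Integer / Date / TimeStamp
    Upper value of the reference column, used to compute the partition stride. It does not filter rows.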

  - num_partitions : Integer.
    Number of partitions to be used.
        
  
  Returns:
  --------
  A dataframe with the data of the Staging table.
  """  
  
  table = "{0}.{1}".format(schema, table)
  entorno_code = {"pro": "pro", "dev": "dev", "produccion": "pro", "desarrollo": "dev"}[entorno]
  jdbc_url_sp, jdbc_url_staging_sp, properties_sp = _get_token(entorno_code)
  
  if (column is None) and (lower_bound is None) and (upper_bound is None):
  
    df = spark.read.jdbc(jdbc_url_staging_sp, table, properties=properties_sp)
    
  else:
  
    df = spark.read.jdbc(url=jdbc_url_staging_sp, 
                         table=table, 
                         properties=properties_sp, 
                         column=column, 
                         lowerBound=lower_bound, 
                         upperBound=upper_bound, 
                         numPartitions=num_partitions)
    
  if lower_col_names:
    df = spark_lower_column_names(df)
  
  return df
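
When `column`, `lower_bound` and `upper_bound` are given, `read_staging_table` (like `read_dwh_table`) lets Spark split the read into `num_partitions` queries. The bounds only set the stride between partitions; Spark does not use them to filter rows. The function below is a simplified sketch of the per-partition WHERE clauses, assuming non-negative integer bounds; it is not Spark's actual implementation, which also handles dates, timestamps and ranges narrower than `numPartitions`.

```python
def jdbc_partition_predicates(column, lower_bound, upper_bound, num_partitions):
  """
  Simplified sketch (not Spark's code) of the WHERE clause each partition gets when
  spark.read.jdbc is called with column/lowerBound/upperBound/numPartitions.
  Assumes non-negative integer bounds and 2 <= num_partitions <= upper_bound - lower_bound.
  """
  if num_partitions < 2 or upper_bound - lower_bound < num_partitions:
    raise ValueError("This sketch only covers 2 <= num_partitions <= upper_bound - lower_bound")
  stride = upper_bound // num_partitions - lower_bound // num_partitions
  predicates = []
  current = lower_bound
  for i in range(num_partitions):
    lower_clause = "{0} >= {1}".format(column, current) if i != 0 else None
    current += stride
    upper_clause = "{0} < {1}".format(column, current) if i != num_partitions - 1 else None
    if lower_clause is None:
      # First partition: everything below the first split point, plus NULLs
      predicates.append("{0} or {1} is null".format(upper_clause, column))
    elif upper_clause is None:
      # Last partition: no upper limit, so rows above upper_bound are still read
      predicates.append(lower_clause)
    else:
      predicates.append("{0} AND {1}".format(lower_clause, upper_clause))
  return predicates


for predicate in jdbc_partition_predicates("id", 0, 100, 4):
  print(predicate)
```

With `lower_bound=0` and `upper_bound=100`, a row with `id = 250` is still read, by the last partition; a skewed reference column therefore leaves most rows in the first or last partition.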

In [0]:
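def read_path_data(data_file, dataframe_format, sep = ',', inferSchema = True, header = True, encoding = 'utf-8', schema = None):
  
  """
  Returns a Spark or pandas dataframe from a CSV or SAS (.sas7bdat) file, so the data scientist can work with it on Databricks without converting the data manually.
  
  Parameters:
  ------------
  data_file : string
    Path to the input data. The file format is taken from the extension: 'csv' or 'sas7bdat'
  dataframe_format : string
    Format of the output dataframe: 'spark' or 'pandas'
  sep : string
    Column separator character (CSV only), ',' by default
  inferSchema : boolean
    Only on Spark dataframes. Whether the schema is inferred or not. True by default
  header : boolean
    Whether the data includes a header or not (CSV only). True by default
  encoding : string
    Data encoding. utf-8 by default
  schema : pyspark.sql.types.StructType
    Only on Spark dataframes. Schema to apply instead of inferring it. None by default
  
  Returns:
  --------
  A dataframe of dataframe_format. On Spark dataframes the column names are lowercased, and spaces, '.', ';', ',', '-' and accented vowels are replaced.
  """
  
  file_format = data_file.split('.')[-1]
  # pandas reads through the local /dbfs mount instead of the dbfs:/ URI
  local_path = data_file.replace('dbfs:/', '/dbfs/')
  
  if file_format not in ['csv', 'sas7bdat']:
    raise ValueError("Unsupported file format '{0}'. Use 'csv' or 'sas7bdat'.".format(file_format))
  
  if dataframe_format == 'pandas':
    if file_format == 'csv':
      return pd.read_csv(local_path, sep=sep, header=0 if header else None, encoding=encoding)
    return pd.read_sas(local_path, encoding=encoding)
  
  if dataframe_format != 'spark':
    raise ValueError("'dataframe_format' must be either 'spark' or 'pandas'.")
  
  if file_format == 'csv':
    if schema is not None:
      df = spark.read.format(file_format).load(data_file, schema=schema, delimiter=sep, header=header, encoding=encoding)
    else:
      df = spark.read.format(file_format).load(data_file, inferSchema=inferSchema, delimiter=sep, header=header, encoding=encoding)
      
  else:
    sas_table = pd.read_sas(local_path, encoding=encoding)
    if schema is not None:
      df = spark.createDataFrame(sas_table, schema=schema)
    else:
      df = spark.createDataFrame(sas_table)
  
  # Build the final name first and rename once, so lowercasing and
  # character replacement are both applied to every column
  replacements = [(' ', '_'), ('.', '_'), (';', '_'), (',', '_'), ('-', '_'),
                  ('á', 'a'), ('é', 'e'), ('í', 'i'), ('ó', 'o'), ('ú', 'u')]
  for column in df.columns:
    new_name = column.lower()
    for old, new in replacements:
      new_name = new_name.replace(old, new)
    df = df.withColumnRenamed(column, new_name)
      
  return df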
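
The intended column-renaming rule of `read_path_data` (lowercase first, then replace separators and accented lowercase vowels) can be checked without a Spark session. `normalize_column_name` is a hypothetical helper that applies that rule to a single string; characters that are not in the list, such as 'ñ' or 'ü', are kept.

```python
def normalize_column_name(name):
  """Hypothetical helper: lowercase a column name, then replace separators and accented vowels."""
  replacements = [(' ', '_'), ('.', '_'), (';', '_'), (',', '_'), ('-', '_'),
                  ('á', 'a'), ('é', 'e'), ('í', 'i'), ('ó', 'o'), ('ú', 'u')]
  new_name = name.lower()
  for old, new in replacements:
    new_name = new_name.replace(old, new)
  return new_name


print(normalize_column_name("Fecha Alta.Cliente"))  # fecha_alta_cliente
print(normalize_column_name("Número-Póliza"))       # numero_poliza
```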

#-------------------------------------

def read_dwh_table(query, entorno="pro", 
                   column=None, lower_bound=None, upper_bound=None, num_partitions=50):
  
  """
  Reads tables from the Data Warehouse with an SQL query and lowercases the column names.
  
  Parameters:
  ------------
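  - query : string
    Table name or SQL query passed to the JDBC reader as its 'table' argument, so it ends up inside a FROM clause. A query must be wrapped in parentheses and given an alias, e.g. "(SELECT col1, col2 FROM schema.table) AS t".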
    
  - entorno : string
    Indicates the SQL server environment. Valid options:
        - "pro" / "produccion" -> production (3)
        - "dev" / "desarrollo" -> development (6)
        
  If the user wants to parallelize the reading process, the following parameters should be specified:
  
  - column : string
    Reference column to be used during the partitioning process. It should be an Integer, Date or TimeStamp Type
    
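  - lower_bound : Integer / Date / TimeStamp
    Lower value of the reference column, used to compute the partition stride. It does not filter rows.
    
  - upper_bound : Integer / Date / TimeStamp
    Upper value of the reference column, used to compute the partition stride. It does not filter rows.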

  - num_partitions : Integer.
    Number of partitions to be used.
    
  
  Returns:
  --------
  A dataframe with the data received from the DWH after applying the query.
  """
  
  entorno_code = {"pro": "pro", "dev": "dev", "produccion": "pro", "desarrollo": "dev"}[entorno]
  jdbc_url_sp, jdbc_url_staging_sp, properties_sp = _get_token(entorno_code)
  
  if (column is None) and (lower_bound is None) and (upper_bound is None):
    
    df = spark.read.jdbc(url=jdbc_url_sp, 
                         table=query, 
                         properties=properties_sp)
    
    
  else:
  
    df = spark.read.jdbc(url=jdbc_url_sp, 
                         table=query, 
                         properties=properties_sp, 
                         column=column, 
                         lowerBound=lower_bound, 
                         upperBound=upper_bound, 
                         numPartitions=num_partitions)
  
  df = spark_lower_column_names(df)
  
  return df  
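
Because `read_dwh_table` passes `query` as the `table` argument of `spark.read.jdbc`, Spark places it inside a FROM clause, so a plain SELECT has to be wrapped in parentheses and given an alias (SQL Server rejects derived tables without one). `as_jdbc_subquery` is a hypothetical convenience helper for that; the table name in the example is illustrative.

```python
def as_jdbc_subquery(query, alias="q"):
  """
  Hypothetical helper: wrap a SELECT statement so it can be passed as the JDBC
  'table' argument. A trailing ';' is removed because it is not valid inside parentheses.
  """
  query = query.strip().rstrip(";").strip()
  return "({0}) AS {1}".format(query, alias)


print(as_jdbc_subquery("SELECT id, importe FROM dbo.ventas;"))
# (SELECT id, importe FROM dbo.ventas) AS q
```

It would then be called as `read_dwh_table(as_jdbc_subquery("SELECT ..."))`.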

In [0]:
def write_staging_table(df, table, mode, entorno="pro", schema="VARCALC"):
  
  """
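  Writes df to a table of the Staging database.
  
  Parameters:
  ------------
  df : Spark DataFrame
    Dataframe to save
  
  table : string
    Name of the Staging table, without the schema prefix
  
  mode : string
    Spark save mode, e.g. 'overwrite' or 'append'. With 'overwrite', Spark drops the existing table and recreates it from the dataframe schema.
  
  entorno : string
    Indicates the SQL server environment. Valid options:
        - "pro" / "produccion" -> production (3)
        - "dev" / "desarrollo" -> development (6)

  schema : string
    Schema of the Staging database to be used
  
  Returns:
  --------
  Nothing, the dataframe is written to the Staging table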
  """  
  entorno_code = {"pro": "pro", "dev": "dev", "produccion": "pro", "desarrollo": "dev"}[entorno]
  table = "{0}.{1}".format(schema, table)
  jdbc_url_sp, jdbc_url_staging_sp, properties_sp = _get_token(entorno_code)
  df.write.mode(mode).jdbc(jdbc_url_staging_sp, table, properties=properties_sp)

In [0]:
def write_staging_table_truncate(df, table, mode, entorno="pro", schema="VARCALC"):
  
  """
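  Writes df to a table of the Staging database with the JDBC truncate option enabled.
  With mode 'overwrite', Spark empties the existing table (TRUNCATE TABLE) instead of dropping and recreating it,
  so the column types and constraints defined in the database are kept. With 'append' the option has no effect.
  
  Parameters:
  ------------
  df : Spark DataFrame
    Dataframe to save
  
  table : string
    Name of the Staging table, without the schema prefix
  
  mode : string
    Spark save mode, e.g. 'overwrite' or 'append'
  
  entorno : string
    Indicates the SQL server environment. Valid options:
        - "pro" / "produccion" -> production (3)
        - "dev" / "desarrollo" -> development (6)

  schema : string
    Schema of the Staging database to be used
  
  Returns:
  --------
  Nothing, the dataframe is written to the Staging table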
  """  
  entorno_code = {"pro": "pro", "dev": "dev", "produccion": "pro", "desarrollo": "dev"}[entorno]
  table = "{0}.{1}".format(schema, table)
  jdbc_url_sp, jdbc_url_staging_sp, properties_sp = _get_token(entorno_code)
  (df.write
   .mode(mode)
   .option("truncate",True)
   .jdbc(jdbc_url_staging_sp, table, properties=properties_sp))

Spark dataframes utils

In [0]:
def spark_dataframe_shape(df):
  
  """
  Returns a list with the shape of a Spark dataframe: the number of rows in the first element and the number of columns in the second.
  
  Parameters:
  ------------
  df : spark DataFrame
    The dataframe of which we want to calculate the shape
  
  Returns:
  --------
  A list with rows on the first element and columns on the second.
  """
  # count() launches a Spark job, so it is run only once
  n_rows, n_cols = df.count(), len(df.columns)
  print('The dataframe has {} rows and {} columns'.format(n_rows, n_cols))
  return [n_rows, n_cols]

#-------------------------------------

def spark_lower_column_names(df):
  
  """
  Lowercases the column names.
  
  Parameters:
  ------------
  df : spark DataFrame
    The dataframe whose column names we want to lowercase
  
  Returns:
  --------
  The same dataframe with lowercase column names.
  """
  
  for column in df.columns:
      
      df = df.withColumnRenamed(column,column.lower())
  
  return df

#-------------------------------------

def spark_upper_column_names(df):
  
  """
  Uppercases the column names.
  
  Parameters:
  ------------
  df : spark DataFrame
    The dataframe whose column names we want to uppercase
  
  Returns:
  --------
  The same dataframe with uppercase column names.
  """
  
  for column in df.columns:
      
      df = df.withColumnRenamed(column,column.upper())
  
  return df

#-------------------------------------

def spark_numeric_variables(df):
  
  """
  Returns a list with column names of those columns with any numeric spark sql type.
  
  Parameters:
  ------------
  df : spark DataFrame
    The dataframe that we want to analyze
  
  Returns:
  --------
  A list with the names of those columns that are numeric.
  """
  
  # create a list with all variables whose dataType is numeric (Byte, Short, Integer, Long, Float, Double or Decimal);
  # isinstance avoids matching arrays, maps or structs that merely contain a numeric type
  list_column_names = [variable.name for variable in df.schema.fields
                       if isinstance(variable.dataType, T.NumericType)]
  
  return list_column_names

#-------------------------------------

def spark_string_variables(df):
  
  """
  Returns a list with column names of those columns with a StringType spark sql type.
  
  Parameters:
  ------------
  df : spark DataFrame
    The dataframe that we want to analyze
  
  Returns:
  --------
  A list with the names of those columns that are StringType.
  """
  
  # create a list with all variables whose dataType is StringType
  list_column_names = [variable.name for variable in df.schema.fields if variable.dataType == T.StringType()]
  
  return list_column_names

#-------------------------------------

def spark_datetime_variables(df):
  
  """
  Returns a list with column names of those columns with a Datetime spark sql type.
  
  Parameters:
  ------------
  df : spark DataFrame
    The dataframe that we want to analyze
  
  Returns:
  --------
  A list with the names of those columns that are TimestampType or DateType.
  """
  
  # create a list with all variables whose dataType is a timestamp or a date
  list_column_names = [variable.name for variable in df.schema.fields if variable.dataType in 
                       (T.TimestampType(), T.DateType())]
  
  return list_column_names

#-------------------------------------

def spark_datetime_var_split(df, datetime_vars):
  
  """
  If a Spark DataFrame includes date variables, splits each of them into 3 columns (day, month and year) so they can be analyzed separately, and drops the original variables. The new columns are named after the original one with the suffix "_day", "_month" or "_year".
  
  Parameters:
  ------------
  df : spark DataFrame
    The dataframe that we want to analyze
  datetime_vars: list
    A list including all Spark Date type variables
  
  Returns:
  --------
  A dataframe with 3 new columns for each date type variable including its day, month and year and deleting the original variable.
  """
    
  for i in range(len(datetime_vars)):
    df = df.withColumn(datetime_vars[i] + '_day', F.dayofmonth(F.col(datetime_vars[i])))
    df = df.withColumn(datetime_vars[i] + '_month', F.month(F.col(datetime_vars[i])))
    df = df.withColumn(datetime_vars[i] + '_year', F.year(F.col(datetime_vars[i])))
    df = df.drop(datetime_vars[i])
    
  return df

#-------------------------------------

def columns_name_normalization(df, replacing_character = '_'):
  
  """
  Renames all columns of a dataframe, replacing the non-allowed characters (' ', '.', ';', ',' and '-').
  
  Parameters:
  ------------
  df : spark DataFrame
  
  replacing_character : String
    The character that replaces the non-allowed ones. '_' by default
  
  Returns:
  --------
  The same dataframe with the non-allowed characters in its column names replaced by replacing_character.
  """

  for column in df.columns:
    df = df.withColumnRenamed(column, column.replace(' ', replacing_character).replace('.', replacing_character).replace(';', replacing_character).replace(',', replacing_character).replace('-', replacing_character))
    
  return df

General utils

In [0]:
def spark_written_csv_rename(csv_path, desired_name):
  
  """
  Spark (through Hadoop) writes a CSV as a folder containing an auto-named part file plus several control files that clutter the Blob Storage. This function renames the part file to the desired name and deletes everything else in the folder.
  Assumes the dataframe was written as a single partition (e.g. with df.coalesce(1)); any other part files would be deleted.
  
  Parameters:
  -----------
  csv_path: Path of the folder where Spark saved the csv file
    string
  desired_name: Name the csv file should have. The .csv extension is added automatically.
    string
  
  Returns:
  --------
  Nothing, this function just runs a process
  """
  
  os_path = csv_path.replace('dbfs:','/dbfs')
  list_files = []
  
  # Only the top level of the folder is needed
  for (root, folders, files) in os.walk(os_path):
    list_files.extend(files)
    break
  
  for file in list_files:
    # os.path.join adds the separator whether or not csv_path ends with '/'
    if file.endswith('.csv') and file.startswith('part-00000'):
      os.rename(os.path.join(os_path, file), os.path.join(os_path, desired_name + '.csv'))
      
    else: 
      os.remove(os.path.join(os_path, file))
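
The cleanup done by `spark_written_csv_rename` can be tried on a local temporary folder. `rename_single_part_csv` is a hypothetical stdlib-only sketch of the same rule (keep the first part file as `<desired_name>.csv`, delete every other file), assuming the dataframe was written as a single partition, e.g. with `df.coalesce(1)`; the file names created in the demo are illustrative.

```python
import os
import tempfile

def rename_single_part_csv(folder, desired_name):
  """
  Hypothetical sketch: keep the first part file as <desired_name>.csv and delete
  every other file in the folder. Sub-folders are left untouched.
  """
  for file in os.listdir(folder):
    path = os.path.join(folder, file)
    if not os.path.isfile(path):
      continue
    if file.startswith("part-00000") and file.endswith(".csv"):
      os.rename(path, os.path.join(folder, desired_name + ".csv"))
    else:
      os.remove(path)


# Simulate a Spark output folder with a single part file and some marker files
with tempfile.TemporaryDirectory() as folder:
  for name in ["part-00000-abc-c000.csv", "_SUCCESS", "_committed_123", "_started_123"]:
    open(os.path.join(folder, name), "w").close()
  rename_single_part_csv(folder, "ventas")
  print(sorted(os.listdir(folder)))  # ['ventas.csv']
```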

#-------------------------------------

def delete_blob_folder(path):
  
  """
  Deletes all files and folders under a specified path.
  
  Parameters:
  -----------
  path: Path in the blob whose content we want to delete
    string
  
  Returns:
  --------
  Nothing; prints the folders / files that are being deleted under the given path
  """
  
  os_path = path.replace('dbfs:','/dbfs')
  
  # Only the top level is listed: rmtree removes each sub-folder recursively
  for (root, folders, files) in os.walk(os_path):
    for folder in folders:
      shutil.rmtree(os.path.join(root, folder))
      print(folder + ' removed')
    for file in files:
      os.remove(os.path.join(root, file))
      print(file + ' removed')
    break

#-------------------------------------
    
  
# Refactored in the next cell using the functions incrementos_periodo() and is_last_day_month()

# def get_lastday_lastmonth(fecha):
  
#   """
#   Calculates the last day of the last month to the given date in YYYYMMDD format 
  
#   Parameters:
#   -----------
#   fecha: Date in YYYYMMDD format 
#     string
  
#   Returns:
#   --------
#   last day of the last month to the given date in YYYYMMDD format
#   """
    
#   fecha_sep = datetime.datetime(int(fecha[:4]), int(fecha[4:6]), int(fecha[6:]))
#   first_day = fecha_sep.replace(day=1)
#   prev_month_lastday = first_day - datetime.timedelta(days=1)
#   ultimo_dia_ultimo_mes = str(prev_month_lastday.year) + str(prev_month_lastday.month).zfill(2) + str(prev_month_lastday.day).zfill(2)
  
#   today_plus_one = fecha_sep + datetime.timedelta(days=1)
#   if today_plus_one.month != fecha_sep.month:
#     ultimo_dia_ultimo_mes = fecha
  
#   return ultimo_dia_ultimo_mes

In [0]:
def incrementos_periodo(periodo_ini, unidad_temporal='dia', incremento=0, formato_salida='completo', cierre=None): 
  '''
    Adds or subtracts time periods to/from the given initial date. 
    With the default arguments (incremento=0) it returns the initial date unchanged. 
    
    Parameters: 
    ------------
    periodo_ini (str): 
    Starting date for the calculations, in 'YYYYMMDD' format. Ex: '20191031'.
    If it is an int, pass it as str(int).
    
    unidad_temporal (str): 
    Time unit to increment. 
    Accepted values: 'dia' (day), 'mes' (month), 'ano' (year).
    
    incremento (int): 
    Number of time units to add: a positive value adds, a negative value subtracts.
    
    formato_salida (str): 
    Desired output date format. 
    Accepted values: 'dia', 'mes', 'ano' (not zero-padded), 'completo' (YYYYMMDD), 'periodo' (YYYYMM), 'datetime'.
  
    cierre (str): None by default. 
    Returns the monthly or yearly close (last day of the month or of the year) of the calculated date. 
    Accepted values: 'mensual' (monthly), 'anual' (yearly).

    Returns:
    -----------
    fecha_salida (str): 
    Output date in the specified format (formato_salida). 
    If formato_salida is 'datetime', the returned value is a datetime.datetime.
    If an argument is not valid, a message is printed and None is returned.

    Usage examples: 
    ------------------
    Close of the current month, 'completo' output format: 
    >>> incrementos_periodo('20191226', incremento = 0, cierre = 'mensual')
    '20191231'
    
    Close of the previous year, 'datetime' output format: 
    >>> incrementos_periodo('20191226', unidad_temporal = 'ano', incremento = -1, formato_salida = 'datetime', cierre = 'anual')
    datetime.datetime(2018, 12, 31, 0, 0)

    Day after subtracting 15 days, 'dia' output format:
    >>> incrementos_periodo('20191226', unidad_temporal = 'dia', incremento = -15, formato_salida = 'dia')
    '11'

    Period of the previous month, 'periodo' output format: 
    >>> incrementos_periodo('20191226', unidad_temporal = 'mes', incremento = -1, formato_salida = 'periodo')
    '201911'
  '''
  
  # Check the values of the input parameters
  if len(periodo_ini) != 8:
    print('The date format is not valid. It must be YYYYMMDD, with a leading 0 if the month or the day is < 10 → Ex.: 20191105')
    return
    
  if unidad_temporal not in ['dia', 'mes', 'ano']:
    print("The time unit is not valid. Accepted values: 'dia', 'mes' or 'ano'")
    return    
  
  if formato_salida not in ['dia', 'mes', 'ano', 'completo', 'periodo', 'datetime']:
    print("The output date format is not valid. Accepted values: 'dia', 'mes', 'ano', 'completo', 'periodo' or 'datetime'")
    return    
  
  # Convert the initial date periodo_ini to datetime
  try: 
    datetime_object = datetime.datetime.strptime(periodo_ini, '%Y%m%d')
  except ValueError:
    print('The date format is not valid.')
    return

  # Increment according to unidad_temporal and incremento
  new_datetime = datetime_object
  incrementos_dict = {
                        'dia':datetime.timedelta(days=incremento), 
                        'mes':relativedelta(months=incremento), 
                        'ano':relativedelta(years=incremento)
                    }

  new_datetime += incrementos_dict[unidad_temporal]
  
  # Monthly or yearly close of the calculated date
  if cierre == 'anual':
    # Move to January of the following year...
    new_datetime = new_datetime.replace(month=1) + relativedelta(months=+12)
    # ...and step back one month, to December of the calculated year
    new_datetime = new_datetime - relativedelta(months=new_datetime.month)

  if (cierre == 'mensual') | (cierre == 'anual'): 
    # Day 1 plus 31 days always falls in the following month
    new_datetime = new_datetime.replace(day=1) + datetime.timedelta(days=31)
    # Subtracting that day number gives the last day of the previous month
    new_datetime = new_datetime - datetime.timedelta(days=new_datetime.day)
  
  # Return the full date, day, month, year, period or datetime as requested
  dict_salidas = {
                    'completo':new_datetime.strftime('%Y%m%d'), 
                    'dia':str(new_datetime.day), 
                    'mes':str(new_datetime.month), 
                    'ano':str(new_datetime.year), 
                    'periodo':new_datetime.strftime('%Y%m'), 
                    'datetime': new_datetime
                }

  fecha_salida =  dict_salidas[formato_salida]

  return fecha_salida

def is_last_day_month(fecha):
  """
  Checks whether the date is the last day of its month
  
  Parameters:
  -----------
  fecha: Date in YYYYMMDD format 
    string
    
  Returns:
  --------
  (boolean)
  """

  lastday_thismonth = incrementos_periodo(fecha, 'mes', 0, 'completo', 'mensual')
  return (lastday_thismonth == fecha)

def get_lastday_lastmonth(fecha):
  
  """
  Calculates the last day of the previous month for a date in YYYYMMDD format.
  If the given date is already the last day of its month, it is returned unchanged.
  
  Parameters:
  -----------
  fecha: Date in YYYYMMDD format 
    string
  
  Returns:
  --------
  last day of the last month to the given date in YYYYMMDD format
  """
  
  lastday_lastmonth = incrementos_periodo(fecha, 'mes', -1, 'completo', 'mensual')
  return fecha if (is_last_day_month(fecha)) else lastday_lastmonth
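
As a cross-check for `incrementos_periodo` and `get_lastday_lastmonth` that needs neither `dateutil` nor the month arithmetic above, the month-end rule can be written with the standard library's `calendar.monthrange`, which returns the weekday of the first day and the number of days of a month. The helper names are hypothetical; the expected values follow the rule stated in the `get_lastday_lastmonth` docstring.

```python
import calendar
import datetime

def last_day_of_month(fecha):
  """Hypothetical helper: last day of the month of a 'YYYYMMDD' date, as 'YYYYMMDD'."""
  date = datetime.datetime.strptime(fecha, "%Y%m%d").date()
  last_day = calendar.monthrange(date.year, date.month)[1]
  return date.replace(day=last_day).strftime("%Y%m%d")

def lastday_lastmonth_stdlib(fecha):
  """
  Hypothetical helper with the get_lastday_lastmonth rule: return fecha if it already
  closes its month, otherwise the last day of the previous month.
  """
  if last_day_of_month(fecha) == fecha:
    return fecha
  date = datetime.datetime.strptime(fecha, "%Y%m%d").date()
  # The day before the 1st is always the last day of the previous month
  return (date.replace(day=1) - datetime.timedelta(days=1)).strftime("%Y%m%d")


print(last_day_of_month("20200215"))         # 20200229
print(lastday_lastmonth_stdlib("20191226"))  # 20191130
print(lastday_lastmonth_stdlib("20191231"))  # 20191231
```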

In [0]:
def unpersist_rdd(table_id=None, unpersist_all=False):
  """ 
  Unpersists cached RDDs/DataFrames from memory and disk.
  Meant to be used when you no longer have a reference to the DataFrame/RDD to unpersist; otherwise, use df.unpersist() or rdd.unpersist().
  
  Parameters:
    table_id (int): 
    ID of the RDD to unpersist. It can be found in the Storage tab of the Spark UI.
    
    unpersist_all (Boolean, False by default):
    If True, unpersists all cached RDDs/DataFrames and ignores table_id.
  
  Returns: 
    Nothing, it just unpersists the selected RDD(s). 
  """
  
  if unpersist_all: 
    for (rdd_id, rdd) in spark.sparkContext._jsc.getPersistentRDDs().items():
      rdd.unpersist()
      print("Unpersisted {} rdd".format(rdd_id))
      
  else:    
    try: 
      spark.sparkContext._jsc.getPersistentRDDs()[table_id].unpersist()

    except AttributeError:
      print("RDD with {} id not found".format(table_id))
      if not isinstance(table_id, int):
        print("table_id must be an INTEGER")

In [0]:
def convert_numeric_variables_to_float(df):
  """ 
  Converts the numeric columns of a dataframe to float (Spark FloatType, single precision).
  
  Parameters:
  df (Spark DataFrame)

  Returns: 
    The same dataframe with its numeric columns cast to float 
  """

  # create a list with all variables whose dataType is numeric (Byte, Short, Integer, Long, Float, Double or Decimal)
  list_column_names = [variable.name for variable in df.schema.fields
                       if isinstance(variable.dataType, T.NumericType)]
  for col_name in list_column_names:
    df = df.withColumn(col_name, F.col(col_name).cast('float'))
            
  return df
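
Note that `cast('float')` produces Spark's `FloatType`, a 4-byte single-precision number, so not every integer above 2**24 can be represented and most decimal values are stored approximately. The snippet below reproduces that rounding in plain Python with `struct`; `to_float32` is a hypothetical helper. If the precision matters, `cast('double')` keeps 8-byte precision.

```python
import struct

def to_float32(value):
  """Hypothetical helper: round-trip a Python float through IEEE 754 single precision."""
  return struct.unpack("f", struct.pack("f", value))[0]


print(to_float32(16777217.0))  # 16777216.0 (2**24 + 1 is not representable)
print(to_float32(0.1))         # 0.10000000149011612
```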

In [0]:
def calculate_business_days(input_date):
  """
  Calculates the number of business days elapsed in the year up to a given date: the weekdays (Monday to Friday)
  from 31 December of the previous year, included, to input_date, excluded. Public holidays are not taken into account.
  
  Parameters:
  -----------
  input_date: Date in YYYYMMDD format, or a date / datetime object
    string/date
    
  Returns:
  --------
  integer with number of business days
  """
  
  # Accept 'YYYYMMDD' strings as well as datetime.date and datetime.datetime objects
  if isinstance(input_date, datetime.datetime):
    datetime_date = input_date.date()
  elif isinstance(input_date, datetime.date):
    datetime_date = input_date
  else:
    datetime_date = datetime.datetime.strptime(input_date, '%Y%m%d').date()
  
  previous_year_end = datetime.date(datetime_date.year - 1, 12, 31)
  
  # np.busday_count counts Monday-Friday days in [previous_year_end, datetime_date)
  return int(np.busday_count(previous_year_end, datetime_date))
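
For reference, the same year-to-date count can be written with the standard library only. `business_days_since_year_end` is a hypothetical helper that takes a `datetime.date` and counts the weekdays in [31 December of the previous year, input_date); this is the value `calculate_business_days` returns whenever 31 December falls on a weekday. Like `BDay`, it ignores public holidays.

```python
import datetime

def business_days_since_year_end(input_date):
  """
  Hypothetical helper: weekdays (Monday to Friday) from 31 December of the previous
  year, included, to input_date, excluded. input_date must be a datetime.date.
  """
  day = datetime.date(input_date.year - 1, 12, 31)
  count = 0
  while day < input_date:
    if day.weekday() < 5:  # 0 = Monday ... 4 = Friday
      count += 1
    day += datetime.timedelta(days=1)
  return count


print(business_days_since_year_end(datetime.date(2019, 1, 5)))  # 5
```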

In [0]:
def read_dlake_table(table_path, container, environment="dev", table_format="delta", header=True, sep=","):
  """
  Function that reads a table from the Data Lake.
  
  Args:
    - table_path (String): Path where the table is saved. Can't start with '/'.
    - container (String): Container of the Data Lake where the table is stored.
    - environment (String): 'dev' / 'desarrollo' or 'pro' / 'produccion'.
    - table_format (String): Format of the table (parquet, PARQUET, delta, DELTA, csv, CSV).
    - header (Boolean): Boolean that indicates the header option (only for CSV tables).
    - sep (String): Separator (only for CSV tables).
    
  Returns:
    - Spark DataFrame.
  """
  
  # Check function inputs
  if environment not in ["dev", "pro", "desarrollo", "produccion"]:
    raise ValueError("'environment' must be one of the following options: 'dev'/'desarrollo' or 'pro'/'produccion'. {0} found.".format(environment))
  
  if table_format not in ["parquet", "PARQUET", "delta", "DELTA", "csv", "CSV"]:
    raise ValueError("'table_format' must be one of the following options: 'parquet', 'PARQUET', 'delta', 'DELTA', 'csv', 'CSV'. {0} found.".format(table_format))
    
  if table_path.startswith("/"):
    raise ValueError("'table_path' can't start with the '/' character.")
  
  # Map the Spanish aliases once the value is known to be valid
  environment_code = {"pro": "pro", "dev": "dev", "produccion": "pro", "desarrollo": "dev"}[environment]
  
  # Get full_table_path
  envs = {"pro": 3, "dev": 6}
  full_table_path = os.path.join("abfss://{0}@{1}satdatalakegen2.dfs.core.windows.net".format(container, envs[environment_code]), table_path)
  print(full_table_path)
  # Read table
  if table_format in ["parquet", "PARQUET", "delta", "DELTA"]:
    df = (spark.read.format(table_format).load(full_table_path))
  else:
    df = (spark.read.format(table_format).load(full_table_path, header=header, sep=sep))
    
  return df
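The `abfss://` URI built by `read_dlake_table` follows the pattern `abfss://<container>@<code>satdatalakegen2.dfs.core.windows.net/<table_path>`, where the code is 3 for production and 6 for development. A minimal pure-Python sketch of that step, handy for checking a path before reading; `build_dlake_path` and the container/path values in the example are hypothetical, and `posixpath.join` is used so the separator does not depend on the driver's OS.

```python
import posixpath

# Environment aliases and storage-account prefixes, as used in read_dlake_table
ENV_ALIASES = {"pro": "pro", "dev": "dev", "produccion": "pro", "desarrollo": "dev"}
ENV_CODES = {"pro": 3, "dev": 6}

def build_dlake_path(container, table_path, environment="dev"):
    """Hypothetical helper: return the abfss:// URI for a Data Lake table."""
    if environment not in ENV_ALIASES:
        raise ValueError("'environment' must be 'dev'/'desarrollo' or 'pro'/'produccion'. {0} found.".format(environment))
    if table_path.startswith("/"):
        raise ValueError("'table_path' can't start with the '/' character.")
    code = ENV_CODES[ENV_ALIASES[environment]]
    root = "abfss://{0}@{1}satdatalakegen2.dfs.core.windows.net".format(container, code)
    return posixpath.join(root, table_path)

print(build_dlake_path("raw", "ventas/diario", "produccion"))
# abfss://raw@3satdatalakegen2.dfs.core.windows.net/ventas/diario
```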

In [0]:
def write_dlake_table(df, table_path, container, mode = "overwrite", environment = "dev", 
                      table_format = "delta", partitioned_by = None, replace_where = None, 
                      header = True, sep = ",", delete_versions=True):
  """
  Function to save a table in the Data Lake.
  
  Args:
    - df (DataFrame): Spark DataFrame to be persisted.
    - table_path (String): Path where the table is saved. Can't start with '/'.
    - container (String): Container of the Data Lake where the table is stored.
    - mode (String): Save mode ('error', 'errorifexists', 'append', 'overwrite' or 'ignore').
    - environment (String): 'dev' / 'desarrollo' or 'pro' / 'produccion'.
    - table_format (String): Format of the table (parquet, PARQUET, delta, DELTA, csv, CSV).
    - partitioned_by (String or List(String)): Columns by which the table should be partitioned.
    - replace_where (String): Expression for the 'replaceWhere' option (DELTA tables only).
    - header (Boolean): Boolean that indicates the header option (only for CSV tables).
    - sep (String): Separator (only for CSV tables).
    - delete_versions (Boolean): For DELTA tables, run VACUUM with a 0-hour retention after
      writing, which removes all previous versions of the table.
    
  Returns:
    - Nothing, the table is persisted in the Data Lake.
  """
  

  # Check functions inputs
  if environment not in ["dev", "pro", "desarrollo", "produccion"]:
    raise ValueError("'environment' must be one of the following options: 'dev'/'desarrollo' or 'pro'/'produccion'. {0} found.".format(environment))
  
  if not table_format in ["parquet", "PARQUET", "delta", "DELTA", "csv", "CSV"]:
    raise ValueError("'table_format' must be one of the following options: 'parquet', 'PARQUET', 'delta', 'DELTA', 'csv', 'CSV'. {0} found.".format(table_format))
    
  if not mode in ["error", "errorifexists", "append", "overwrite", "ignore"]:
    raise ValueError("'mode' must be one of the following options: 'error', 'errorifexists', 'append', 'overwrite' or 'ignore'. {0} found.".format(mode))
  
  if replace_where is not None and table_format.lower() != "delta":
    raise ValueError("'replace_where' option can only be set if table_format is 'delta' or 'DELTA'. {0} found.".format(table_format))
  
  if partitioned_by is not None and not (type(partitioned_by) is list or type(partitioned_by) is str):
    raise ValueError("'partitioned_by' must be str or list type. {0} found.".format(type(partitioned_by)))
    
  if replace_where is not None and not type(replace_where) is str:
    raise ValueError("'replace_where' must be str type. {0} found.".format(type(replace_where)))
    
  if table_path.startswith("/"):
    raise ValueError("'table_path' can't start with the '/' character.")
  
  # Get full_table_path
  environment_code = {"pro": "pro", "dev": "dev", "produccion": "pro", "desarrollo": "dev"}[environment]
  envs = {"pro": 3, "dev": 6}
  full_table_path = os.path.join("abfss://{0}@{1}satdatalakegen2.dfs.core.windows.net".format(container, envs[environment_code]), table_path)
    
  # Save table: build the writer step by step, adding only the requested options
  writer = df.write.format(table_format).mode(mode)
  
  if partitioned_by is not None:
    writer = writer.partitionBy(partitioned_by)
  
  if replace_where is not None:
    writer = writer.option("replaceWhere", replace_where)
  
  if table_format.lower() == "csv":
    writer = writer.option("header", header).option("sep", sep)
  
  writer.save(full_table_path)
      
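  if table_format in ["delta", "DELTA"] and delete_versions:
    # VACUUM with 0 hours deletes every file not referenced by the latest version,
    # so time travel to older versions is no longer possible. Delta rejects such a
    # short retention unless its safety check is disabled, so disable it only for
    # this call and restore the previous setting afterwards.
    check_conf = "spark.databricks.delta.retentionDurationCheck.enabled"
    previous_check = spark.conf.get(check_conf, "true")
    spark.conf.set(check_conf, "false")
    try:
      spark.sql("VACUUM delta.`{}` RETAIN 0 HOURS".format(full_table_path))
    finally:
      spark.conf.set(check_conf, previous_check)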

In [0]:
def get_previous_day(fecha):
  """
  Given a date in yyyyMMdd format, returns the previous day
  """

  if fecha[6:8]=='01':
    fecha_ant = get_lastday_lastmonth(fecha)
  else:
    fecha_ant=str(int(fecha)-1)
  return fecha_ant
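`get_previous_day` relies on `get_lastday_lastmonth` (defined elsewhere in the notebook) for the first day of a month. For comparison, a minimal sketch of the same yyyyMMdd-in, yyyyMMdd-out contract using only `datetime.timedelta`, which handles month, year and leap-year boundaries in one step; `get_previous_day_dt` is a hypothetical name.

```python
import datetime

def get_previous_day_dt(fecha):
    """Given a date string in yyyyMMdd format, return the previous day in the same format."""
    # Parse, step back one calendar day, and format again
    day = datetime.datetime.strptime(fecha, "%Y%m%d") - datetime.timedelta(days=1)
    return day.strftime("%Y%m%d")

print(get_previous_day_dt("20200301"))  # 20200229
```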

In [0]:
def get_today():
  """
  Returns today's date in yyyyMMdd format.
  """
  
  return datetime.datetime.today().strftime('%Y%m%d')

In [0]:
def get_yesterday():
  """
  Returns yesterday's date in yyyyMMdd format.
  """  
  
  return get_previous_day(get_today())

In [0]:
# [lower, upper] bounds per DWH column; "periodo" holds bounds per table code
LIMITS_COLS_DWH = {
  "sk_crpa": [1, 2.5e8],
  "sk_mediador": [1, 1e6],
  "cod_mediador": [0, 100000],
  "sk_nivcod": [1, 5000],
  "sk_division": [1, 20],
  "sk_territorial": [1, 120],
  "sk_sucursal": [1, 1300],
  "sk_centro": [1, 4000],
  "sk_comercial": [1, 5000],
  "id_familia": [1, 8e6],   
  "sk_siniestro": [1, 100e6],
  "sk_stro_gtia_part": [1, 120e6],
  "sk_cliente": [1, 100e6],
  "antiguedad_cliente": [-100, 1250],
  "sk_encuesta": [1, 3e6],
  "default": [1, 2147483640],
  "sk_apunte": [1, 2.2e8],
  "periodo": {"default": [20011231, int(get_yesterday())], 
              "FP001": [20021231, int(get_yesterday())], 
              "FP002": [20011231, int(get_yesterday())], 
              "HP001": [20181231, int(get_yesterday())], 
              "FA001": [20011231, int(get_yesterday())], 
              "HA001": [20190228, int(get_yesterday())]}
}
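`LIMITS_COLS_DWH` mixes two shapes: most keys map to a `[lower, upper]` pair, while `periodo` maps to a dictionary keyed by table code with its own `default`. The bounds look like candidates for the `lower_bound`/`upper_bound` arguments of `read_dwh_table`, although that use is an assumption. The hypothetical helper `get_col_limits` below sketches one way to resolve a column's bounds with fallbacks, using a small stand-in dictionary with illustrative values rather than the notebook's own.

```python
# Stand-in with the same shape as LIMITS_COLS_DWH (values illustrative only)
LIMITS = {
    "sk_cliente": [1, 100e6],
    "default": [1, 2147483640],
    "periodo": {"default": [20011231, 20201231], "HP001": [20181231, 20201231]},
}

def get_col_limits(column, table_code=None, limits=LIMITS):
    """Hypothetical helper: return integer [lower, upper] bounds for a column."""
    # Unknown columns fall back to the global default bounds
    bounds = limits.get(column, limits["default"])
    if isinstance(bounds, dict):
        # Per-table bounds (e.g. 'periodo'): fall back to the entry's own default
        bounds = bounds.get(table_code, bounds["default"])
    # Cast float literals such as 100e6 to int
    return [int(bounds[0]), int(bounds[1])]

print(get_col_limits("periodo", "HP001"))  # [20181231, 20201231]
```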

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W


def spark_dim_window_filter(tabla, fecha_hub, claves, drop = True, registro_fecha = "fecha_inicio"):
  """
  Filters the records of a period in a dimension table without using surrogate keys.
  Useful when the hub table is not available, but computationally more expensive.

  Parameters:
  -------------
  tabla : dataframe
    Dimension table.
  fecha_hub : string
    Date (yyyyMMdd) on which the table is to be queried.
  claves : list of strings or string (single key)
    Name(s) of the business key field(s) used to build the windows.
  drop : Boolean
    By default, drops duplicates over the subset of keys.
  registro_fecha : string
    Name of the field holding the date on which the record was stored.
    Default: "fecha_inicio".

  Returns
  -------------
  A dataframe filtered for the period given by fecha_hub.
  
  Example
  -------------
  df001 = spark_dim_window_filter(DF001, "20201231", ["cod_encargo"])
  """
  # Accept a single key given as a string
  if isinstance(claves, str):
    claves = [claves]
  # Unbounded frame: the max is taken over all records sharing the same keys
  w_table = W.partitionBy(claves).rowsBetween(W.unboundedPreceding, W.unboundedFollowing)
  tabla = (tabla
           # Keep only records stored on or before the query date
           .filter(F.col(registro_fecha) <= fecha_hub)
           .withColumn(registro_fecha, F.col(registro_fecha).cast("int"))
           # Keep, per key, the records with the latest storage date
           .withColumn("max_fecha", F.max(registro_fecha).over(w_table))
           .filter(F.col(registro_fecha) == F.col("max_fecha"))
           .drop("max_fecha"))
  if drop:
    tabla = tabla.dropDuplicates(subset=claves)
  return tabla
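The window logic of `spark_dim_window_filter` (per business key, keep the record with the latest `registro_fecha` that is not after `fecha_hub`) can be illustrated without Spark. A minimal plain-Python sketch on a list of dictionaries, assuming dates stored as yyyyMMdd strings; `dim_window_filter_rows` and the sample rows are hypothetical.

```python
def dim_window_filter_rows(rows, fecha_hub, claves, drop=True, registro_fecha="fecha_inicio"):
    """Hypothetical plain-Python analogue of spark_dim_window_filter on a list of dicts."""
    if isinstance(claves, str):
        claves = [claves]

    def key_of(row):
        # Business key of a record, as a hashable tuple
        return tuple(row[c] for c in claves)

    # Keep only records stored on or before the query date
    # (yyyyMMdd strings of equal length compare in date order)
    candidates = [r for r in rows if r[registro_fecha] <= fecha_hub]

    # Latest stored date per business key (what the window max computes)
    max_fecha = {}
    for r in candidates:
        k = key_of(r)
        max_fecha[k] = max(max_fecha.get(k, r[registro_fecha]), r[registro_fecha])
    latest = [r for r in candidates if r[registro_fecha] == max_fecha[key_of(r)]]

    if drop:
        # Keep the first record seen per key (Spark's dropDuplicates guarantees no particular one)
        seen, unique = set(), []
        for r in latest:
            if key_of(r) not in seen:
                seen.add(key_of(r))
                unique.append(r)
        latest = unique
    return latest

rows = [
    {"cod_encargo": "A", "fecha_inicio": "20200101", "valor": 1},
    {"cod_encargo": "A", "fecha_inicio": "20200601", "valor": 2},
    {"cod_encargo": "A", "fecha_inicio": "20210101", "valor": 3},
    {"cod_encargo": "B", "fecha_inicio": "20190101", "valor": 4},
]
print([r["valor"] for r in dim_window_filter_rows(rows, "20201231", "cod_encargo")])  # [2, 4]
```

Key A keeps its June 2020 version because the January 2021 one is after the query date, and key B keeps its only record.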