In [None]:
import pandas as pd

from pyspark.sql.types import * 
# from pyspark.sql.types import StringType, IntegerType, ArrayType, DataType 
from pyspark.sql import Column, functions as F, DataFrame, Window

from functools import wraps, reduce
from time import time
from typing import Callable

from itertools import chain
from datetime import date
import datetime as dt
from dateutil.relativedelta import relativedelta


class ExecTiming:
  """
  Groups a decorator that reports a function's wall-clock execution time.

  Usage:
    @ExecTiming.timing
    def my_func(...): ...
  """

  def __init__(self):
    # nothing to initialise; the class only namespaces the `timing` decorator
    pass

  @staticmethod
  def timing(f):
    """
    Decorator: print how long each call to `f` took.

    Parameters
    ----------
    f
      callable to wrap

    Returns
    -------
    wrapper (carrying f's metadata via functools.wraps) that calls `f`,
    prints "Function <name> took <seconds> seconds" and returns f's result.
    """
    @wraps(f)
    def wrap(*args, **kw):
      ts = time()
      result = f(*args, **kw)
      te = time()
      print(f'Function {f.__name__} took {te-ts:2.4f} seconds')
      return result

    return wrap


class PyDatetime:

  """
  class for Datetime manipulation

  Parameters
  ----------
  ymd
    date string in '%Y-%m-%d' format (e.g. '2022-01-31'); None means today

  Attributes
  ----------
  date_input
    the validated '%Y-%m-%d' string
  timestamp_col
    pyspark Column: the input date as a literal cast to TimestampType
  lday
    dict returned by `new_lday` for the input date
  """

  def __init__(self, ymd=None):
    self.date_input = ymd
    # read back through the property so a None input is already today's date
    self.timestamp_col = self.new_timestamp_col(self.date_input)
    self.lday = self.new_lday(self.date_input)

  @property
  def date_input(self):
    """
    Validated date string ('%Y-%m-%d').

    Setting None stores today's date; setting a string that does not parse
    with '%Y-%m-%d' raises ValueError.
    """
    return self._date_input

  @date_input.setter
  def date_input(self, value):
    if value is None:
      self._date_input = date.today().strftime('%Y-%m-%d')
      return
    try:
      dt.datetime.strptime(value, '%Y-%m-%d')
    except ValueError as err:
      raise ValueError('Expected a string with format %Y-%m-%d') from err
    self._date_input = value

  def new_timestamp_col(self, value: str) -> "Column":
    """
    Build a literal pyspark column for `value` ('%Y-%m-%d').

    return:
      pyspark Column of TimestampType (also stored on self._timestamp_col)
    """
    self._timestamp_col = F.to_date(F.lit(value)).cast(TimestampType())
    return self._timestamp_col

  def new_lday(self, value: str, num_months: int = 12) -> dict:
    """
    Month boundaries for the `num_months` months going back from `value`.

    return: dict with
      'first': first day of value's month, then of each earlier month
               (index i = i months back)
      'last' : the day before each 'first' entry, i.e. last[i] is the last
               day of the month *preceding* first[i]
      e.g. value '2022-03-15' -> first[0] = 2022-03-01, last[0] = 2022-02-28
    """
    clean_date = dt.datetime.strptime(value, '%Y-%m-%d').date()
    month_start = clean_date.replace(day=1)

    first_day_list = [month_start - relativedelta(months=i) for i in range(num_months)]
    last_day_list = [d - dt.timedelta(days=1) for d in first_day_list]

    return {'first': first_day_list, 'last': last_day_list}


class CustomUdf:
  """
  Decorator that lets one python function accept plain values or pyspark Columns.

  If the decorated function is called with at least one pyspark Column
  (positional or keyword), it is wrapped with F.udf(func, returnType) and
  applied to the arguments, producing a Column expression. Otherwise the
  plain python function is called directly.

  Parameters
  ----------
  returnType
    spark return type for the udf; None (the default) means StringType()

  Example
  -------
    @CustomUdf(returnType=IntegerType())
    def add_one(x):
      return x + 1

    add_one(1)                        # -> 2
    df.select(add_one(F.col('a')))    # applied as a udf
  """

  def __init__(self, returnType: "DataType | str | None" = None):
    # resolve the default lazily rather than building StringType() at class definition
    self.spark_udf_type = StringType() if returnType is None else returnType

  def __call__(self, func: Callable):
    @wraps(func)
    def wrapped_func(*args, **kwargs):
      has_column = any(isinstance(arg, Column) for arg in args) or \
        any(isinstance(vv, Column) for vv in kwargs.values())
      if has_column:
        # NOTE(review): keyword arguments to a udf call need a recent pyspark — confirm version
        return F.udf(func, self.spark_udf_type)(*args, **kwargs)
      return func(*args, **kwargs)
    return wrapped_func

## pivottable class

class PivotTable:
  """
  Sum/count pivot tables over a spark DataFrame, with a 'Total' row appended.

  notes: name columns without special chars (e.g. '.') else spark cannot
  resolve the column name.

  Parameters
  ----------
  df
    source spark DataFrame
  row
    list of row-key column names; pivot_1 selects all of them, both pivots
    group by row[0] only
  value
    name of the column to sum
  col
    list of pivot column names, required by pivot_2:
      col[0] 1st level (pivoted first)
      col[1] 2nd level
  """

  def __init__(self, df: DataFrame, row: list = None, value=None, col: list = None):
    self._df = df
    self._row = row
    self._value = value
    self._col = col

  def add_totals(self, df) -> DataFrame:
    """
    Append a totals row to `df`.

    Every column is aggregated with F.sum, then the first column of that
    row is overwritten with the literal 'Total'; rows are combined by
    position with DataFrame.union.
    """
    total_row = df.agg(*(F.sum(c).alias(c) for c in df.columns))
    total_row = total_row.withColumn(df.columns[0], F.lit('Total'))
    return df.union(total_row)

  def pivot_1(self) -> DataFrame:
    """
    Group by row[0]: '$' = sum of value, '#' = row count, sorted by '$'
    descending, with a 'Total' row appended.
    """
    df = self._df \
      .select(*self._row, self._value) \
      .groupBy(self._row[0]) \
      .agg(F.sum(self._value).alias('$'), F.count(F.lit(1)).alias('#')) \
      .sort(F.desc('$'))
    return self.add_totals(df)

  def pivot_2(self) -> DataFrame:
    """
    Two-level pivot.

    First groups by (row[0], col[1]) and pivots on col[0] computing sum ('$')
    and count ('#') of value; then groups by row[0], pivots on col[1] and sums
    every pivoted column. A 'Total' row is appended.

    Prints 'column list required' and returns None when `col` was not given.
    """
    if self._col is None:
      print('column list required')
      return None

    first_level, second_level = self._col[0], self._col[1]

    df = self._df \
      .groupBy(self._row[0], second_level) \
      .pivot(first_level) \
      .agg(F.sum(self._value).alias('$'), F.count(F.lit(1)).alias('#'))

    # the pivoted measure columns: everything except the grouping keys
    agg_cols = [x for x in df.columns if x not in [self._row[0], *self._col]]

    df = df \
      .groupBy(self._row[0]) \
      .pivot(second_level) \
      .agg(*[F.sum(F.col(x)).alias(x) for x in agg_cols])

    return self.add_totals(df)

  def pivot_all(self):
    """Return (pivot_1(), pivot_2())."""
    return self.pivot_1(), self.pivot_2()

class Utilities:
  """
  Assorted spark DataFrame helpers.

  map_create_col(df, map_dict, new_col_name, old_col_name)
    add `new_col_name` holding map_dict[value of old_col_name] per row.
      map_dict: dict      mapping of old value -> new value
      new_col_name: str   name of the column to create
      old_col_name: str   column whose values are looked up
    NOTE(review): values absent from map_dict presumably become null
    (default, non-ANSI spark settings) — confirm for your config.

  neg_to_zero(df, col_name)
    replace values of `col_name` that are < 0 with 0.

  create_index(df, index_name)            (static)
    add a 0-based row-number column.

  string_to_list(long_string)             (static)
    split a whitespace-separated string (e.g. a SAS column list) into a list.
  """

  def map_create_col(self, df: DataFrame, map_dict: dict, new_col_name: str, old_col_name: str) -> DataFrame:
    # flatten {k1: v1, k2: v2} into lit(k1), lit(v1), lit(k2), lit(v2) for create_map
    mapping_expr = F.create_map([F.lit(x) for x in chain.from_iterable(map_dict.items())])
    return df.withColumn(new_col_name, mapping_expr[F.col(old_col_name)])

  def neg_to_zero(self, df: DataFrame, col_name: str) -> DataFrame:
    return df.withColumn(
      col_name,
      F.when(F.col(col_name) < 0, 0).otherwise(F.col(col_name))
    )

  @staticmethod
  def create_index(df: DataFrame, index_name: str) -> DataFrame:
    # NOTE: a Window with orderBy but no partitionBy moves all rows into a
    # single partition; fine for small frames, slow for large ones.
    window = Window.orderBy(F.monotonically_increasing_id())
    return df.withColumn(index_name, F.row_number().over(window) - 1)

  @staticmethod
  def string_to_list(long_string: str) -> list:
    """
    purpose:
      copy paste select columns from SAS
    """
    return long_string.split()


class Joins:
  """
  Row-wise stacking of several spark DataFrames.

  note:
  https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.DataFrame.union.html
  In pyspark, DataFrame.union and DataFrame.unionAll behave the same: neither
  eliminates duplicate rows, and both match columns by position.

  NOTE(review): the naming here is the reverse of SQL — `union` below keeps
  duplicates, while `union_all` removes them with distinct().
  """

  @staticmethod
  def union(*df) -> DataFrame:
    """
    Stack all given DataFrames, keeping duplicate rows.

    df:
      DataFrames to stack; column order must match (else use unionByName).
    """
    return reduce(DataFrame.union, df)

  @staticmethod
  def union_all(*df) -> DataFrame:
    """
    Stack all given DataFrames, then drop duplicate rows via distinct().

    df:
      DataFrames to stack; column order must match (else use unionByName).
    """
    return reduce(DataFrame.unionAll, df).distinct()



class EDA:
  """
  Quick exploratory checks on a spark DataFrame.

  Parameters
  ----------
  df
    spark DataFrame to inspect
  """

  def __init__(self, df):
    self._df = df

  def summary_stat(self):
    """Return df.summary() restricted to count, min, 25%, 75% and max."""
    return self._df.summary("count", "min", "25%", "75%", "max")

  def _non_null(self, col_name):
    """Spark DataFrame with only `col_name`, null values removed."""
    return self._df \
      .select(col_name) \
      .where(F.col(col_name).isNotNull())

  def col_zoom(self, col_name):
    """
    Non-null values of `col_name` as a pandas DataFrame.

    NOTE: toPandas() collects every selected row onto the driver.
    """
    return self._non_null(col_name).toPandas()

  def col_zoom_nozero(self, col_name):
    """Like col_zoom, but keeps only values strictly greater than 0."""
    return self._non_null(col_name) \
      .where(F.col(col_name) > 0) \
      .toPandas()


class SasPandas:
  """
  purpose:

  convert sas or excel datasets to pandas, and pandas DataFrames to spark.

  Parameters
  ----------
  spark_instance
    object whose `.spark` attribute is used as the SparkSession
    (pandas_to_spark and get_parquet call `self._spark.spark`)
    NOTE(review): presumably a wrapper around a SparkSession — confirm.
  """
  def __init__(self, spark_instance):
    self._spark = spark_instance

  def excel_to_pandas(self, file_path: str):
    """
    Read an excel file into a pandas DataFrame.

    file_path:
      e.g. '/MAS606 - with formulas v2.xls'

    The reader engine is picked from the extension: '.xls' -> xlrd,
    '.xlsx' -> openpyxl; any other extension passes engine=None and lets
    pandas choose.

    errors:
      if xlrd engine not installed, ensure installed via pip3
      if openpyxl engine not installed, ensure installed via pip3
    """
    if file_path.endswith('.xls'):
      print('xls file recognized')
      engine_option = 'xlrd'
    elif file_path.endswith('.xlsx'):
      print('xlsx file recognized')
      engine_option = 'openpyxl'
    else:
      engine_option = None

    return pd.read_excel(file_path, engine=engine_option)

  def sas_to_pandas(self, file_path: str):
    """
    Read a SAS dataset (e.g. '.sas7bdat') into a pandas DataFrame,
    decoding text as ISO-8859-1.
    """
    return pd.read_sas(file_path, encoding='ISO-8859-1')

  def pandas_to_spark(self, df) -> DataFrame:
    """
    Convert a pandas DataFrame to spark using an explicit schema built from
    its dtypes (see equivalent_type).

    df:
      pandas dataframe
    """
    p_schema = StructType([
      self.define_structure(column, typo)
      for column, typo in zip(df.columns, df.dtypes)
    ])
    return self._spark.spark.createDataFrame(df, p_schema)

  def sasxls_to_spark(self, file_path) -> DataFrame:
    """
    Load an excel or SAS file straight into a spark DataFrame.

    file_path:
      either excel or sas extension
      eg. '/dataset.sas7bdat' or 'excel.xls'

    Returns None for any other extension.
    """
    if file_path.endswith(('.xls', '.xlsx')):
      df = self.excel_to_pandas(file_path)
      return self.pandas_to_spark(df)

    if file_path.endswith('.sas7bdat'):
      df = self.sas_to_pandas(file_path)
      result = self.pandas_to_spark(df)
      print('sas file recognized')
      return result

    return None

  def equivalent_type(self, f):
    """
    Map a pandas dtype to a spark type; unrecognised dtypes become StringType().

    Relies on numpy dtypes comparing equal to their string names (e.g. 'int64').
    """
    if f == 'datetime64[ns]': return TimestampType()
    elif f == 'int64': return LongType()
    elif f == 'int32': return IntegerType()
    # NOTE(review): FloatType is 32-bit; DoubleType would keep float64 precision
    elif f == 'float64': return FloatType()
    else: return StringType()

  def define_structure(self, string, format_type):
    """Build a StructField for column `string`, falling back to StringType()."""
    try:
      typo = self.equivalent_type(format_type)
    except Exception:
      # best effort: dtype comparisons that cannot be evaluated fall back to string
      typo = StringType()
    return StructField(string, typo)

  def get_parquet(self, yyyymmdd: str, table: str) -> DataFrame:
    """
    purpose:
      get parquet quick: read <path><table>20220131.parquet for a supported table.
    issue:
      insufficient parquet files now. use 1 month first.
      NOTE(review): `yyyymmdd` is currently ignored; the date is hard-coded.
    returns:
      spark DataFrame, or None (after printing the valid names) for an
      unsupported table.
    """
    table_list = ['AMBS', 'AMNA', 'AMPS']

    if table not in table_list:
      print('Please input table name ass one of deze', table_list)
      return None

    path = "alluxio:///rmgcbgcredit/credit/vplus/"
    ext = ".parquet"
    return self._spark.spark.read.parquet(path + table + '20220131' + ext)



class SparkDatatype:
  """
  purpose:
    convert spark datatypes of columns in a spark dataframe you already have.

  i.e.
    int_cast = []

    decimal_cast = [
      'total_bal', 'AMPS_USER_6_BNP',
      'credit_limit', 'AMPS_USER_5_BNP']

    char_cast = [
      'block_code_1', 'block_code_2']

    date_cast = ['AMBS_DATE_OPENED', 'AMBS_DATE_CGOFF']

    a = SparkDatatype().col_cast(final_all, date_cast=date_cast,
                                 int_cast=int_cast,
                                 decimal_cast=decimal_cast,
                                 char_cast=char_cast)
  """

  def caster(self, col_to_cast: list, df: DataFrame, cast_type: str) -> DataFrame:
    """
    Cast each column in `col_to_cast` and return the resulting DataFrame.

    cast_type 'date' parses the column with the 'yyyy-MM-dd' pattern and
    converts it to a date; any other value is passed to Column.cast
    (e.g. 'int', 'decimal(15,2)', 'string').
    """
    for name in col_to_cast:
      if cast_type == 'date':
        converted = F.to_date(F.unix_timestamp(F.col(name), 'yyyy-MM-dd').cast("timestamp"))
      else:
        converted = F.col(name).cast(cast_type)
      df = df.withColumn(name, converted)
    return df

  def col_cast(self, df: DataFrame, **cast_lists) -> DataFrame:
    """
    Apply the casts given by keyword, in this order:
      date_cast    -> date ('yyyy-MM-dd')
      int_cast     -> 'int'
      decimal_cast -> 'decimal(15,2)'
      char_cast    -> 'string'
    Any keyword that is omitted is treated as an empty list.
    """
    df = self.caster(cast_lists.get('date_cast', []), df, 'date')
    df = self.caster(cast_lists.get('int_cast', []), df, 'int')
    df = self.caster(cast_lists.get('decimal_cast', []), df, 'decimal(15,2)')
    df = self.caster(cast_lists.get('char_cast', []), df, 'string')
    return df


class String0p:
  """
  purpose:

    functions for common pyspark string manipulation

  i.e.
    c = (fi_keys.withColumn('a', F.lit('abiaw oifwefoiw'))
                .withColumn('b', F.col('Financial Institutions'))
                .withColumn('c', F.concat('a', 'b', 'Financial Institutions')))
    concat_list = sorted(c.columns)
    d = G.StringOp.concat_no_space(c, 'e', concat_list)
  """

  @staticmethod
  def concat_no_space(df: DataFrame, col_name: str, concat_list: list) -> DataFrame:
    """
    Add `col_name`: the columns in `concat_list` concatenated in order, with
    every space character removed.

    NOTE: F.concat yields null when any of its inputs is null.
    """
    joined = F.concat(*[F.col(c) for c in concat_list])
    return df.withColumn(col_name, F.regexp_replace(joined, " ", ""))

  @staticmethod
  def acc_concat(non_col: list, full_list: list, delimiter: str) -> Column:
    """
    purpose:
    -----------
    ez way to concat string literals and columns, to make a new column
    expression (returns a Column, not a DataFrame).

    parameters:
    ----------
    full_list
      input full list in order; entries also present in non_col are used as
      string literals, all others as column names

    non_col:
      string literals to join

    delimiter:
      regex pattern passed to F.regexp_replace; every match is removed from
      the concatenated result (escape regex metacharacters yourself)

    assume:
      full_list is a superset of non_col
    """
    literals = set(non_col)
    parts = [F.lit(item) if item in literals else F.col(item) for item in full_list]
    return F.regexp_replace(F.concat(*parts), delimiter, "")


# the class docstring example calls this class `StringOp`; expose that name too
StringOp = String0p