In [None]:
%run /OEA_py

In [None]:
oea.set_workspace(workspace)

In [1]:
import pyspark.sql.functions as f
import pandas as pd
import numpy as np

class OEATestKit:
    #---------------------------- Dataframe Generators ----------------------------
    """
    Gets a dataframe from a raw data file

    Arguments:
        primary_key(string): the primary key for the entity
        entity_path (string): the path to the entity
        options (object): options about how to infer the schema
    """
    def get_raw_dataframe(self, primary_key, entity_path, options):
        primary_key = oea.fix_column_name(primary_key) # fix the column name, in case it has a space in it or some other invalid character
        ingested_path = f'stage2/Ingested/{entity_path}'
        raw_path = f'stage1/Transactional/{entity_path}'
        batch_type, source_data_format = oea.get_batch_info(raw_path)
        source_url = oea.to_url(f'{raw_path}/{batch_type}_batch_data')

        if batch_type == 'snapshot'or batch_type=='additive': source_url = f'{source_url}/{oea.get_latest_folder(source_url)}'

        if options == None: options = {}
        options['format'] = source_data_format # eg, 'csv', 'json'
        if source_data_format == 'csv' and (not 'header' in options or options['header'] == None): options['header'] = True  # default to expecting a header in csv files

        spark.sql("set spark.sql.streaming.schemaInference=true")
        return spark.read.format('delta').load(oea.to_url(source_url), **options)

    """
    Gets a dataframe from a lake database

    Arguments:
        item (string): the name of the entity to retrieve from the lake database
        stage (int): the number of the stage to get the dataframe from
        type_id (string): The type identifier.  i for ingest or r for refine
        collection: The name of the collection, aka the name of the parentmost folder
        version (string): The version number
    """
    def get_lake_dataframe(self,item, stage, type_id, collection, version):
        version_delimiter = "p"
        version_split = version.split(".")
        namespace = f"ldb_{oea.workspace}_s{stage}{type_id.lower()}_{collection.lower()}_v{version_delimiter.join(version_split)}"
        return spark.sql(f"SELECT * FROM {namespace}.{item.lower()}")

    #------------------------------ Utility Methods --------------------------------
    """
    This method checks if all of the rows in the first dataframe exist in the second dataframe, by
    comparing the dataframes on the primary key.  Returns True if all the entities in df1 appear in df2

    Arguments:
        df1 (Spark.Dataframe): Dataframe of the first set of data
        df2 (Spark.Dataframe): Dataframe of the second set of data
        primary_key (string | string[]): the primary key(s) used for comparison 
    """
    def is_subset(self, df1, df2, primary_key):
        merged_df = pd.merge(df1.toPandas(), df2.toPandas(), on=primary_key, how="left", indicator="exists")
        merged_df['exists'] = np.where(merged_df.exists == 'both', True, False)
        return merged_df[merged_df.exists == False].shape[0] == 0
    
    """
    This method checks if the df has any duplicates using the primary_key as a comparison.  Returns True if there are duplicates
    else it returns false

    Arguments:
        df (Spark.Dataframe): the dataframe to test
        primary_key (string | string[]): The primary key column(s)
    """
    def has_duplicates(self, df, primary_key):
        grouped_df = df.groupBy(primary_key).count()
        return grouped_df.where(f.col('count') > 1).count() > 0
    
    #------------------------- Test Cases --------------------------------------

    """
    Tests to confirm that the raw dataset is a subset of the lake database, thus ensuring all the entities in the raw data
    exist in the lake database
    """
    def test_raw_data_is_subset_of_lake(self, collection, version, item, primary_key, options):
        #Get the lake database as a dataframe
        lake_df = self.get_lake_dataframe(item, 2, "i", collection, version)

        #get the raw data as a dataframe
        entity_path = f'{collection}/v{version}/{item}'
        raw_df = self.get_raw_dataframe(primary_key, entity_path, options)

        #check if the raw data exists in the lake database
        assert self.is_subset(raw_df, lake_df, primary_key)
    
    def test_has_duplicates(self, collection, version, item, primary_key):
        #Get the lake database as a dataframe
        lake_df = self.get_lake_dataframe(item, 2, "i", collection, version)
        assert self.has_duplicates(lake_df, primary_key) != True
        
oea_test_kit = OEATestKit()

StatementMeta(spark3p3sm, 91, 1, Finished, Available)

IndentationError: expected an indented block (2641926224.py, line 84)