# Unit Testing


Unit testing refers to testing individual functions to check they are working as expected. Little quirks in the inputs that we weren't expecting, or changes made to the function later on, can make it behave in unexpected ways.

The function below is a simple PySpark function that calculates the average value of a given column in a DataFrame.

In [None]:
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, DoubleType
# Initialize Spark session
spark = SparkSession.builder.appName("PySparkUnitTesting").getOrCreate()

In [None]:
def calculate_average(df, column_name):
    """
    Calculate the average of a specified column in a DataFrame.

    :param df: Input DataFrame
    :param column_name: Name of the column to calculate the average for
    :return: The average value of the column
    """
    # Calculate the average value
    avg_value = df.agg(F.avg(F.col(column_name))).collect()[0][0]
    
    return avg_value

Looks simple enough, right? Let's give it some data to try it out.

This DataFrame represents the names and purchase totals of customers of our shop. We'll calculate the average purchase amount across all customers.

In [None]:
# Example DataFrame
data = [("Alice", 34.50), ("Bob", 45.50), ("Cathy", 29.50)]
df = spark.createDataFrame(data, ["Name", "Total"])

# Calculate the average total
average_total = calculate_average(df, "Total")
print(f"Average Total: {average_total}")


Looks good to me! Now we can put this in our pipeline and completely forget about it. It works on today's data, so I'm sure nothing will ever go wrong.

### Empty DataFrame

What if one day no one bought anything? This has never happened before, so we didn't think about it when building the function. We still get a DataFrame, but it's completely empty.

In [None]:
# Create an empty DataFrame
empty_df = spark.createDataFrame([], StructType([]))

# Calculate the average purchase total
average_total = calculate_average(empty_df, "Total")
print(f"Average Total: {average_total}")

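Oh dear - we got an error. If this function were part of a larger pipeline, that error would grind everything to a halt. (Strictly speaking, Spark complains here because this schema-less DataFrame has no "Total" column at all; an empty DataFrame that still had its columns would quietly return None instead.)
We need to go back, fix the function, and decide what should happen in this situation.
You can choose what happens: you might want it to return 0, None, or some other value, depending on why you're calculating this average.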
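One way to keep that decision explicit is to make the fallback a parameter. The sketch below is plain Python, with a list standing in for the DataFrame column; `average_or_default` and its `default` argument are hypothetical names for illustration, not part of the notebook's function.

```python
def average_or_default(values, default=None):
    """Return the mean of `values`, or `default` if there is nothing to average.

    Hypothetical plain-Python illustration: `values` is a list standing in
    for a DataFrame column.
    """
    if not values:
        # Empty input: hand back whatever the caller decided makes sense.
        return default
    return sum(values) / len(values)


print(average_or_default([34.5, 45.5, 29.5]))  # 36.5
print(average_or_default([]))                  # None
print(average_or_default([], default=0))       # 0
```

Whichever fallback you choose, a unit test can pin it down so that a later edit doesn't quietly change it.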

In [None]:
def calculate_average(df, column_name):
    """
    Calculate the average of a specified column in a DataFrame.

    :param df: Input DataFrame
    :param column_name: Name of the column to calculate the average for
    :return: The average value of the column
    """
    # Check if the DataFrame is empty
    if df.count() == 0:
        print("Error: The DataFrame is empty.")
        return None
    
    # Calculate the average value
    avg_value = df.agg(F.avg(F.col(column_name))).collect()[0][0]
    
    return avg_value

In [None]:
# Create an empty DataFrame
empty_df = spark.createDataFrame([], StructType([]))

# Calculate the average purchase total
average_total = calculate_average(empty_df, "Total")
print(f"Average Total: {average_total}")

This is definitely better than an error. Hopefully nothing else will go wrong!

### Non-existent column

You want to do some market research on your customer base and find out the average age of your customers. You're pretty sure you have that information, so you can just use your handy calculate_average() function.

In [None]:
# Calculate the average age
average_age = calculate_average(df, "Age")
print(f"Average Age: {average_age}")


Another error, and one that would halt our whole pipeline. It says there's no "Age" column, but you're sure you collect that information, so it must be getting dropped somewhere upstream. You'll need to update your function to handle this case.

In [None]:
def calculate_average(df, column_name):
    """
    Calculate the average of a specified column in a DataFrame.

    :param df: Input DataFrame
    :param column_name: Name of the column to calculate the average for
    :return: The average value of the column
    """
    # Check if the DataFrame is empty
    if df.count() == 0:
        print("Error: The DataFrame is empty.")
        return None
    
    # Check if column exists
    if column_name not in df.columns:
        print(f"Error: Column '{column_name}' not found in the DataFrame.")
        return None
    
    # Calculate the average value
    avg_value = df.agg(F.avg(F.col(column_name))).collect()[0][0]
    
    return avg_value


In [None]:
# Calculate the average age
average_age = calculate_average(df, "Age")
print(f"Average Age: {average_age}")

Perfect - now that's handled and we no longer get a long error message.

### Non-numeric data

It's a new day with new data. Let's take a look.

In [None]:
# Example DataFrame
data = [("Daniel", "£12.00"), ("Erica", "£4.50"), ("Frankie", "£175.50")]
df = spark.createDataFrame(data, ["Name", "Total"])

# Calculate the average total
average_total = calculate_average(df, "Total")
print(f"Average Total: {average_total}")


Hmm...

Not an error, but you know that's wrong. A new staff member has entered the totals with a currency symbol instead of as plain numbers, and now the function isn't working as intended. You'll need to make sure this case is handled.

In [None]:
def calculate_average(df, column_name):
    """
    Calculate the average of a specified column in a DataFrame.

    :param df: Input DataFrame
    :param column_name: Name of the column to calculate the average for
    :return: The average value of the column
    """
    # Check if the DataFrame is empty
    if df.count() == 0:
        print("Error: The DataFrame is empty.")
        return None
    
    # Check if column exists
    if column_name not in df.columns:
        print(f"Error: Column '{column_name}' not found in the DataFrame.")
        return None
    
    # Remove non-numeric characters (like currency symbols) and cast to DoubleType
    df = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "[^0-9.]", "").cast(DoubleType()))
    df = df.filter(F.col(column_name).isNotNull())
    
    # Calculate the average value
    avg_value = df.agg(F.avg(F.col(column_name))).collect()[0][0]
    
    return avg_value


In [None]:
# Example DataFrame
data = [("Daniel", "£12.00"), ("Erica", "£4.50"), ("Frankie", "£175.50")]
df = spark.createDataFrame(data, ["Name", "Total"])

# Calculate the average total
average_total = calculate_average(df, "Total")
print(f"Average Total: {average_total}")


And done! Hopefully nothing else goes wrong now, and hopefully the edits we made won't affect the normal functionality...
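It's worth checking what the cleaning pattern `[^0-9.]` actually does to a few awkward inputs. The sketch below mimics that step in plain Python with the built-in `re` module, which treats this simple character class the same way as the Java regex engine behind `F.regexp_replace`. `clean_total` is a hypothetical helper, and returning None stands in for Spark casting an unparseable string to null (its behaviour when ANSI mode is off).

```python
import re


def clean_total(raw):
    """Roughly mimic the cleaning step in calculate_average (hypothetical helper).

    Strips every character that is not a digit or a dot, then tries to parse
    a float. Returns None where Spark's cast would produce null.
    """
    stripped = re.sub(r"[^0-9.]", "", raw)
    try:
        return float(stripped)
    except ValueError:
        return None


samples = [
    "£12.00",     # 12.0: currency symbol stripped, as intended
    "-£5.00",     # 5.0: the minus sign is stripped too, so a refund turns positive
    "£1,200.50",  # 1200.5: thousands separator removed
    "N/A",        # None: nothing left to parse, so the row is silently filtered out
    "1.2.3",      # None: more than one decimal point
]
for raw in samples:
    print(raw, "->", clean_total(raw))
```

The second and fourth cases are exactly the kind of behaviour a unit test should pin down, whichever way you decide they ought to work.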

All of these issues could have been avoided if we used **Unit Tests** when developing the function in the first place. Unit tests allow us to check the behaviour of our functions with both expected and unexpected inputs outside the production environment. If they fail, they won't crash your whole pipeline and they only take a few seconds to a few minutes to run.

In [None]:
%%writefile average_column_function.py
# %%writefile saves this cell to a plain Python file; this is the final version of the function we are testing
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, DoubleType

def calculate_average(df, column_name):
    """
    Calculate the average of a specified column in a DataFrame.

    :param df: Input DataFrame
    :param column_name: Name of the column to calculate the average for
    :return: The average value of the column
    """
    # Check if the DataFrame is empty
    if df.count() == 0:
        print("Error: The DataFrame is empty.")
        return None
    
    # Check if column exists
    if column_name not in df.columns:
        print(f"Error: Column '{column_name}' not found in the DataFrame.")
        return None
    
    # Remove non-numeric characters (like currency symbols) and cast to DoubleType
    df = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "[^0-9.]", "").cast(DoubleType()))
    df = df.filter(F.col(column_name).isNotNull())
    
    # Calculate the average value
    avg_value = df.agg(F.avg(F.col(column_name))).collect()[0][0]
    
    return avg_value


In [None]:
%%writefile test_average_column.py 
# Very important: start your file and function names with "test", as this is how pytest finds them!
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, DoubleType
from average_column_function import calculate_average # You will need to import your function to test it!


def test_calculate_average():
    """
    Test the calculate average function for expected behaviour
    """
    # Arrange
    # Initialize Spark session
    spark = SparkSession.builder.appName("PySparkUnitTesting").getOrCreate()
    data = [("100",), ("200",), ("300",)]
    input_df = spark.createDataFrame(data, ["value"])

    expected = 200.0

    # Act
    actual = calculate_average(input_df, "value")

    # Assert
    assert actual == expected, f"Expected {expected} but got {actual}"

def test_calculate_average_empty_df():
    """
    Test the calculate average function for an empty df
    """
    # Arrange
    # Initialize Spark session
    spark = SparkSession.builder.appName("PySparkUnitTesting").getOrCreate()
    data = []
    input_df = spark.createDataFrame(data, StructType([]))

    expected = None

    # Act
    actual = calculate_average(input_df, "value")

    # Assert
    assert actual == expected, f"Expected {expected} but got {actual}"

def test_calculate_average_no_column():
    """
    Test the calculate average function for a column that doesn't exist
    """
    # Arrange
    # Initialize Spark session
    spark = SparkSession.builder.appName("PySparkUnitTesting").getOrCreate()
    data = [("100",), ("200",), ("300",)]
    input_df = spark.createDataFrame(data, ["value"])

    expected = None

    # Act
    actual = calculate_average(input_df, "age")

    # Assert
    assert actual == expected, f"Expected {expected} but got {actual}"

def test_calculate_average_currency_inputs():
    """
    Test the calculate average function for currency-formatted inputs
    """
    # Arrange
    # Initialize Spark session
    spark = SparkSession.builder.appName("PySparkUnitTesting").getOrCreate()
    data = [("£100",), ("£200",), ("£300",)]
    input_df = spark.createDataFrame(data, ["value"])

    expected = 200

    # Act
    actual = calculate_average(input_df, "value")

    # Assert
    assert actual == expected, f"Expected {expected} but got {actual}"


In [None]:
!pytest

If everything went as expected, we should have four passing tests!

Try forcing one to fail by editing the expected output, and see what happens.
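The tests above compare floats with `==`, which works for 100, 200 and 300 but can be brittle for values like pennies, because most decimal fractions aren't exactly representable in binary floating point. A minimal standard-library sketch of the difference, using a plain-Python `mean` as a hypothetical stand-in for `calculate_average`; inside a pytest file, `assert actual == pytest.approx(expected)` serves the same purpose as `math.isclose`.

```python
import math


def mean(values):
    """Hypothetical plain-Python stand-in for calculate_average."""
    return sum(values) / len(values)


actual = mean([0.1, 0.2])
expected = 0.15

# Exact equality fails: 0.1 + 0.2 is not exactly 0.3 in binary floating point.
print(actual == expected)              # False
# A tolerance-based comparison passes.
print(math.isclose(actual, expected))  # True
```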

# Your Turn

Now it's your turn: try writing a simple Python function, then write some test cases for it that consider both expected and unexpected inputs. The cells below have the "cell magic" to create the files for you.
For the tests, use the Arrange, Act, Assert pattern to structure them.
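As a template, here is a minimal sketch for a hypothetical `apply_discount` function (not one of the suggestions below), shown in one block for brevity. In the notebook, the function would go in `my_new_function.py` and the `test_` functions in `test_my_new_function.py`. It returns None for bad input to mirror how `calculate_average` behaves; raising an exception and checking for it with `pytest.raises` is an equally valid design.

```python
def apply_discount(total, percent):
    """Hypothetical example: reduce `total` by the given percentage.

    Returns None for non-numeric input or a percentage outside 0-100,
    mirroring how calculate_average returns None for bad input.
    """
    numeric = (int, float)
    if not isinstance(total, numeric) or not isinstance(percent, numeric):
        return None
    if not 0 <= percent <= 100:
        return None
    return total * (100 - percent) / 100


def test_apply_discount():
    # Arrange
    total, percent = 80.0, 25
    expected = 60.0

    # Act
    actual = apply_discount(total, percent)

    # Assert
    assert actual == expected, f"Expected {expected} but got {actual}"


def test_apply_discount_string_input():
    # Arrange: a total typed in with a currency symbol
    total, percent = "£80", 25

    # Act
    actual = apply_discount(total, percent)

    # Assert
    assert actual is None, f"Expected None but got {actual}"


def test_apply_discount_percent_out_of_range():
    # Arrange
    total, percent = 80.0, 150

    # Act
    actual = apply_discount(total, percent)

    # Assert
    assert actual is None, f"Expected None but got {actual}"
```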

Suggestions:

- A function to multiply two numbers
    - What happens if you input a string?
    - Does it work for big numbers? Small numbers? Negative numbers?
- A function to test if a number is even
    - Does it work as intended and return True for even and False for odd?
    - What about negative numbers? 0?
- A function to return the longest word in a list of words
    - What if the list is empty?
    - What if two words have the same length? What would you *want* to happen?
    - What if the list was full of numbers?
- A function to count the number of vowels in a string
    - What if the string is empty?
    - What if there are no vowels? Or only vowels?
    - What if there's a space in the word?

In [None]:
%%writefile my_new_function.py

In [None]:
%%writefile test_my_new_function.py

from my_new_function import my_function

Hint: running pytest on its own runs ALL the tests it finds. If you only want to run one file, put the filename after it (including the .py), for example !pytest test_my_new_function.py

In [None]:
!pytest 