In [0]:
%run "./WalmartSales-modules-dev"

In [0]:
try:
  import pytest
except ImportError as error:
  %pip install pytest

In [0]:
from pyspark import SparkContext
from pyspark.sql import SQLContext
import pytest
import pandas as pd

In [0]:
def get_sorted_data_frame(data_frame, columns_list):
    """Return ``data_frame`` ordered by ``columns_list`` with a fresh 0..n-1 index.

    Puts two pandas frames into a deterministic row order so they can be
    compared with ``pd.testing.assert_frame_equal``. The input is not modified.

    Args:
        data_frame: pandas DataFrame to order.
        columns_list: column names to sort by, in priority order.

    Returns:
        A new DataFrame, sorted, whose previous index has been discarded.
    """
    ordered = data_frame.sort_values(by=columns_list)
    return ordered.reset_index(drop=True)
  
def test_replaceNullCustomerID(sql_context):
    """Check that replaceNullCustomerID fills a null "Customer ID" with "Guest".

    Rows whose "Customer ID" is already populated are expected to pass
    through unchanged.

    Args:
        sql_context: Spark entry point providing ``createDataFrame(data, schema)``.
    """
    columns = ['Invoice', 'StockCode', 'Description', 'Quantity',
               'InvoiceDate', 'Price', 'Customer ID', 'Country']

    # Named input_df so the builtin `input` is not shadowed.
    input_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13085", "United Kingdom"),
         ('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, None, "United Kingdom")],
        columns,
    )

    expected_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13085", "United Kingdom"),
         ('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "Guest", "United Kingdom")],
        columns,
    )

    input_df.show(5)

    real_df = replaceNullCustomerID(input_df)

    real_df.show(5)

    # Spark row order is not something to rely on, so both sides are sorted on
    # every column (with a fresh index) before comparing; check_like=True makes
    # the comparison ignore column order.
    real_pdf = get_sorted_data_frame(real_df.toPandas(), columns)
    expected_pdf = get_sorted_data_frame(expected_df.toPandas(), columns)

    pd.testing.assert_frame_equal(expected_pdf, real_pdf, check_like=True)

In [0]:
def test_replaceNullDescription(sql_context):
    """Check that replaceNullDescription fills a null "Description" with "Unlisted".

    Rows whose "Description" is already populated are expected to pass
    through unchanged.

    Args:
        sql_context: Spark entry point providing ``createDataFrame(data, schema)``.
    """
    columns = ['Invoice', 'StockCode', 'Description', 'Quantity',
               'InvoiceDate', 'Price', 'Customer ID', 'Country']

    # Named input_df so the builtin `input` is not shadowed.
    input_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13085", "United Kingdom"),
         ('489434', 85048, None, 12, "12/1/2009 7:45", 6.95, "13086", "United Kingdom")],
        columns,
    )

    expected_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13085", "United Kingdom"),
         ('489434', 85048, "Unlisted", 12, "12/1/2009 7:45", 6.95, "13086", "United Kingdom")],
        columns,
    )

    input_df.show(5)

    real_df = replaceNullDescription(input_df)

    real_df.show(5)

    # Sort both sides on every column (fresh index) so row order cannot cause a
    # false failure; check_like=True ignores column order.
    real_pdf = get_sorted_data_frame(real_df.toPandas(), columns)
    expected_pdf = get_sorted_data_frame(expected_df.toPandas(), columns)

    pd.testing.assert_frame_equal(expected_pdf, real_pdf, check_like=True)

In [0]:
def test_addcolumnQuarter(sql_context):
    """Check that addcolumnQuarter appends a "Qtr" column derived from InvoiceDate.

    Per the fixtures, "12/1/2009 7:45" is expected to map to "Qtr4" and
    "6/1/2009 7:45" to "Qtr2" (i.e. the leading field is read as the month).

    Args:
        sql_context: Spark entry point providing ``createDataFrame(data, schema)``.
    """
    input_columns = ['Invoice', 'StockCode', 'Description', 'Quantity',
                     'InvoiceDate', 'Price', 'Customer ID', 'Country']
    output_columns = input_columns + ["Qtr"]

    # Named input_df so the builtin `input` is not shadowed.
    input_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13085", "United Kingdom"),
         ('489434', 85048, "test product1", 12, "6/1/2009 7:45", 6.95, "13086", "United Kingdom")],
        input_columns,
    )

    expected_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13085", "United Kingdom", "Qtr4"),
         ('489434', 85048, "test product1", 12, "6/1/2009 7:45", 6.95, "13086", "United Kingdom", "Qtr2")],
        output_columns,
    )

    input_df.show(5)

    real_df = addcolumnQuarter(input_df)

    real_df.show(5)
    expected_df.show(5)

    # Sort both sides on every column (fresh index) so row order cannot cause a
    # false failure; check_like=True ignores column order.
    real_pdf = get_sorted_data_frame(real_df.toPandas(), output_columns)
    expected_pdf = get_sorted_data_frame(expected_df.toPandas(), output_columns)

    pd.testing.assert_frame_equal(expected_pdf, real_pdf, check_like=True)

In [0]:
def test_addcolumnInvoiceType(sql_context):
    """Check that addcolumnInvoiceType appends an "InvoiceType" column.

    Per the fixtures, a row with Quantity 0 / Price 0.00 and a row with
    Quantity -1 / Price -1.00 are both expected to be tagged "Return", and a
    row with positive Quantity and Price is expected to be tagged "Purchase".

    Args:
        sql_context: Spark entry point providing ``createDataFrame(data, schema)``.
    """
    input_columns = ['Invoice', 'StockCode', 'Description', 'Quantity',
                     'InvoiceDate', 'Price', 'Customer ID', 'Country']
    output_columns = input_columns + ["InvoiceType"]

    # Named input_df so the builtin `input` is not shadowed.
    input_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 0, "12/1/2009 7:45", 0.00, "13085", "United Kingdom"),
         ('489434', 85048, "test product1", -1, "12/1/2009 7:45", -1.00, "13086", "United Kingdom"),
         ('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13086", "United Kingdom")],
        input_columns,
    )

    expected_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 0, "12/1/2009 7:45", 0.00, "13085", "United Kingdom", "Return"),
         ('489434', 85048, "test product1", -1, "12/1/2009 7:45", -1.00, "13086", "United Kingdom", "Return"),
         ('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13086", "United Kingdom", "Purchase")],
        output_columns,
    )

    input_df.show(5)

    real_df = addcolumnInvoiceType(input_df)

    real_df.show(5)

    # Sort both sides on every column (fresh index) so row order cannot cause a
    # false failure. check_like=True (as in the other tests) ignores column
    # order, so the test does not depend on where the new column is placed.
    real_pdf = get_sorted_data_frame(real_df.toPandas(), output_columns)
    expected_pdf = get_sorted_data_frame(expected_df.toPandas(), output_columns)

    pd.testing.assert_frame_equal(expected_pdf, real_pdf, check_like=True)

In [0]:
def test_filterDf(sql_context):
    """Check that filterDf splits rows into (United Kingdom rows, other rows).

    ``filterDf`` is expected to return a 2-tuple. Per the fixtures, the first
    element holds the "United Kingdom" rows (including ones with Price 0.00 and
    -1.00), and the second holds the remaining rows (here, "India").

    Args:
        sql_context: Spark entry point providing ``createDataFrame(data, schema)``.
    """
    columns = ['Invoice', 'StockCode', 'Description', 'Quantity',
               'InvoiceDate', 'Price', 'Customer ID', 'Country']

    # Named input_df so the builtin `input` is not shadowed.
    input_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 0.00, "13085", "United Kingdom"),
         ('489434', 85048, "test product1", 12, "12/1/2009 7:45", -1.00, "13086", "United Kingdom"),
         ('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13086", "India")],
        columns,
    )

    expected_uk_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 0.00, "13085", "United Kingdom"),
         ('489434', 85048, "test product1", 12, "12/1/2009 7:45", -1.00, "13086", "United Kingdom")],
        columns,
    )

    expected_others_df = sql_context.createDataFrame(
        [('489434', 85048, "test product1", 12, "12/1/2009 7:45", 6.95, "13086", "India")],
        columns,
    )

    input_df.show(5)

    real_uk_df, real_others_df = filterDf(input_df)

    real_uk_df.show(5)
    real_others_df.show(5)

    # Compare each half after sorting on every column (fresh index), so row
    # order cannot cause a false failure; check_like=True ignores column order.
    # The UK half is checked first, as before.
    for expected_df, real_df in ((expected_uk_df, real_uk_df),
                                 (expected_others_df, real_others_df)):
        pd.testing.assert_frame_equal(
            get_sorted_data_frame(expected_df.toPandas(), columns),
            get_sorted_data_frame(real_df.toPandas(), columns),
            check_like=True,
        )

In [0]:
# SQLContext is deprecated since Spark 3.0 (constructing one emits a
# FutureWarning). The tests only call createDataFrame(data, schema), which the
# SparkSession `spark` provides directly, so the session is passed as the
# tests' `sql_context` argument. `sc` is kept for any code that expects it.
sc = spark.sparkContext
sql_context = spark

In [0]:
test_replaceNullCustomerID(sql_context)

In [0]:
test_replaceNullDescription(sql_context)

In [0]:
test_addcolumnQuarter(sql_context)

In [0]:
test_addcolumnInvoiceType(sql_context)

In [0]:
test_filterDf(sql_context)