In [0]:
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

class TestSupermarketsNotebook(unittest.TestCase):
    """Tests for the supermarkets notebook's column clean-up and postal-code indexing.

    A three-row in-memory DataFrame stands in for the raw supermarkets input. Both
    tests run the same preparation steps (``_prepare``) so they exercise a single
    transformation instead of two hand-copied variants of it.
    """

    @classmethod
    def setUpClass(cls):
        """Start a local SparkSession and build the shared sample DataFrame ``cls.df``."""
        cls.spark = (
            SparkSession.builder.master("local[2]")
            .appName("SupermarketsTest")
            # One shuffle partition is plenty for three rows and keeps the jobs fast.
            .config("spark.sql.shuffle.partitions", "1")
            .getOrCreate()
        )
        # Sample data mimicking raw supermarkets input.
        data = [
            (1001, "Supermarket A", "46220"),
            (1002, "Supermarket B", None),  # missing postal code, exercises the null fill
            (1003, "Supermarket C", "46222"),
        ]
        cols = ["supermarketNo", "supermarket", "postal-code"]
        cls.df = cls.spark.createDataFrame(data, cols)

    @classmethod
    def tearDownClass(cls):
        """Stop the SparkSession created in ``setUpClass``."""
        cls.spark.stop()

    @staticmethod
    def _prepare(df):
        """Return *df* with the column renames, cast and null fill under test.

        ``supermarketNo`` is renamed to ``supermarket``; ``postal-code`` is replaced by a
        string-typed ``postalcode`` column whose nulls are filled with ``"UNKNOWN"``.
        """
        # NOTE(review): the sample input already has a ``supermarket`` column, so this
        # rename leaves two columns with that name -- confirm this mirrors the notebook.
        df = df.withColumnRenamed("supermarketNo", "supermarket")
        # withColumn only *adds* ``postalcode``; drop the source column so it is a real rename.
        df = df.withColumn("postalcode", col("postal-code").cast("string")).drop("postal-code")
        return df.fillna("UNKNOWN", subset=["postalcode"])

    def test_column_renames_and_casts(self):
        """Renamed columns exist, old names are gone, and null postal codes are filled."""
        df = self._prepare(self.df)

        # Check column renames.
        self.assertIn("supermarket", df.columns)
        self.assertIn("postalcode", df.columns)
        self.assertNotIn("supermarketNo", df.columns)
        self.assertNotIn("postal-code", df.columns)

        # Check the cast to string.
        self.assertEqual(dict(df.dtypes)["postalcode"], "string")

        # Existing codes are kept and the single missing one becomes "UNKNOWN".
        postalcodes = [row["postalcode"] for row in df.select("postalcode").collect()]
        self.assertEqual(sorted(postalcodes), ["46220", "46222", "UNKNOWN"])

    def test_postalcode_string_indexing(self):
        """StringIndexer adds a numeric ``postalcode_indexed`` column with one index per label."""
        # Local import: pyspark.ml is only needed by this test.
        from pyspark.ml.feature import StringIndexer

        df = self._prepare(self.df)

        indexer = StringIndexer(inputCol="postalcode", outputCol="postalcode_indexed", handleInvalid="keep")
        df_indexed = indexer.fit(df).transform(df)

        # Check that postalcode_indexed is added and is numeric.
        self.assertIn("postalcode_indexed", df_indexed.columns)
        self.assertEqual(dict(df_indexed.dtypes)["postalcode_indexed"], "double")

        # Three distinct labels (including the filled "UNKNOWN") map to indices 0, 1 and 2.
        indices = {row["postalcode_indexed"] for row in df_indexed.select("postalcode_indexed").collect()}
        self.assertEqual(indices, {0.0, 1.0, 2.0})

if __name__ == "__main__":
    # argv=[""] stops unittest from parsing the host process's arguments (e.g. the Jupyter
    # kernel's "-f <connection file>"); exit=False avoids sys.exit() killing an interactive
    # session. The tests run only here, never as a side effect of importing this module.
    unittest.main(argv=[""], exit=False)
