In [0]:
import unittest
from pyspark.sql.functions import col

class TestSupermarketsNotebook(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        parquet_path = "/mnt/silver/supermarkets/"
        cls.df = spark.read.parquet(parquet_path)

    def test_column_renames_and_casts(self):
        df = self.df

        df = df.withColumnRenamed("supermarketNo", "supermarket")
        df = df.withColumn("postalcode", col("postal_code").cast("string"))
        # Explicitly remove old columns that are no longer needed after rename/cast
        df = df.drop("supermarketNo", "postal_code")

        # Check that the new column names exist in the DataFrame
        self.assertTrue("supermarket" in df.columns)
        self.assertTrue("postalcode" in df.columns)
        # Check that the old column names no longer exist
        self.assertFalse("supermarketNo" in df.columns)
        self.assertFalse("postal_code" in df.columns)

        # Retrieve all values from 'postalcode' as a list
        postalcodes = df.select("postalcode").rdd.flatMap(lambda x: x).collect()

    def test_postalcode_string_indexing(self):
        from pyspark.ml.feature import StringIndexer

        df = self.df
        # Consistent renaming and string casting as above
        df = df.withColumnRenamed("supermarketNo", "supermarket")
        df = df.withColumn("postalcode", col("postal_code").cast("string"))

        # Create a StringIndexer to encode 'postalcode' strings into numeric indices
        indexer = StringIndexer(
            inputCol="postalcode",
            outputCol="postalcode_indexed",
            handleInvalid="keep"
        )
        # Fit the indexer on the DataFrame and apply the transformation
        model = indexer.fit(df)
        df_indexed = model.transform(df)

        # Verify that the new indexed column exists
        self.assertIn("postalcode_indexed", df_indexed.columns)
        # Check the data type of the indexed column is 'double' (the default output type)
        types = dict(df_indexed.dtypes)
        self.assertEqual(types["postalcode_indexed"], "double")

if __name__ == "__main__":
    # Inside a notebook kernel, sys.argv carries the kernel's own arguments,
    # which unittest would try to parse; pass a placeholder argv instead
    # (argv[0] is treated as the program name and ignored).
    _notebook_argv = ["first-arg-is-ignored"]
    # exit=False stops unittest from calling sys.exit() and ending the session.
    unittest.main(argv=_notebook_argv, exit=False)
