# In [0]:
import unittest
from pyspark.sql.functions import col

class TestItemsNotebook(unittest.TestCase):
    """Data-quality checks for the Parquet dataset stored at ``/mnt/gold/gold/``.

    The dataset is loaded once per class and shared by every test. Each test
    collects a single column to the driver and validates its values in plain
    Python, printing the offending values before failing.
    """

    # Valid unit encodings; update if your mapping changes.
    VALID_UNIT_CODES = frozenset({0, 1, 2, 3})

    @classmethod
    def setUpClass(cls):
        """Load the DataFrame from the Parquet path once for all tests."""
        # NOTE(review): `spark` is not defined in this file; it is presumably
        # the session object injected by the hosting notebook -- confirm.
        parquet_path = "/mnt/gold/gold/"
        cls.df = spark.read.parquet(parquet_path)

    def _column_values(self, column_name):
        """Return all values of ``column_name`` as one flat Python list."""
        # Selecting a single column yields one-field rows; flatMap unwraps
        # each row so the result is a list of bare values.
        return self.df.select(column_name).rdd.flatMap(lambda row: row).collect()

    @staticmethod
    def _is_positive_number(value):
        """Return True when ``value`` converts to a float greater than zero."""
        try:
            return float(value) > 0
        except (TypeError, ValueError, OverflowError):
            # Not convertible to float at all, so not a positive number.
            return False

    def test_brand_and_type_indexed(self):
        """Every ``brand_indexed`` and ``type_indexed`` value must be a float."""
        for column_name in ("brand_indexed", "type_indexed"):
            # subTest lets the second column still be checked and reported
            # when the first one fails.
            with self.subTest(column=column_name):
                non_floats = [
                    value
                    for value in self._column_values(column_name)
                    if not isinstance(value, float)
                ]
                if non_floats:
                    print(f"Non-float entries in {column_name}:", non_floats)
                self.assertEqual(
                    len(non_floats), 0, f"Not all {column_name} values are float"
                )

    def test_size_value_positive(self):
        """Every non-null ``size_value`` must convert to a positive float."""
        # None entries are deliberately excluded from the offenders, so null
        # values do not fail this test.
        offenders = [
            value
            for value in self._column_values("size_value")
            if value is not None and not self._is_positive_number(value)
        ]
        if offenders:
            print("Non-numeric or non-positive size_value entries:", offenders)
        self.assertEqual(
            len(offenders),
            0,
            "size_value contains non-convertible or non-positive entries",
        )

    def test_size_unit_consistency(self):
        """Every ``size_unit_encoded`` value must be one of the valid codes."""
        invalid_codes = {
            unit
            for unit in self._column_values("size_unit_encoded")
            if unit not in self.VALID_UNIT_CODES
        }
        if invalid_codes:
            print("Invalid encoded units found in size_unit_encoded:", invalid_codes)
        self.assertEqual(
            len(invalid_codes), 0, f"Invalid unit codes present: {invalid_codes}"
        )

if __name__ == "__main__":
    # A placeholder argv keeps unittest from parsing the host process's own
    # command line; exit=False stops unittest.main from calling sys.exit().
    unittest.main(exit=False, argv=["first-arg-is-ignored"])
