In [18]:
from plantclef.spark import get_spark
from pathlib import Path
from pyspark.sql import functions as F

# Start a local Spark session and load the per-image segmentation masks
# (one binary column per mask type, see schema below).
spark = get_spark(cores=4, app_name="masking-notebook")
root = Path("~/shared/plantclef/data").expanduser().as_posix()
masks = spark.read.parquet(f"{root}/masking/test_2024_v2/data")
masks.printSchema()

# Select mask columns by their `_mask` suffix rather than a bare "mask"
# substring: `unpivot` requires every value column to share one type, so an
# unrelated non-binary column whose name merely contains "mask" would break it.
MASK_COLUMNS = [c for c in masks.columns if c.endswith("_mask")]
assert MASK_COLUMNS, f"no *_mask columns found in {masks.columns}"

# NOTE(review): 96 is an untuned partition count, presumably chosen to spread the
# per-row UDF work in the next cell across executors — adjust for the cluster.
NUM_PARTITIONS = 96

# Long format: one row per (image_name, mask_type) so later cells can group by
# mask type. Cached because it is read by show() below and by the stats cell.
unpivoted = (
    masks.unpivot("image_name", MASK_COLUMNS, "mask_type", "mask")
    .repartition(NUM_PARTITIONS)
    .cache()
)
unpivoted.printSchema()
unpivoted.show(5)

                                                                                

root
 |-- image_name: string (nullable = true)
 |-- leaf_mask: binary (nullable = true)
 |-- flower_mask: binary (nullable = true)
 |-- plant_mask: binary (nullable = true)
 |-- sand_mask: binary (nullable = true)
 |-- wood_mask: binary (nullable = true)
 |-- tape_mask: binary (nullable = true)
 |-- tree_mask: binary (nullable = true)
 |-- rock_mask: binary (nullable = true)
 |-- vegetation_mask: binary (nullable = true)
 |-- sample_id: integer (nullable = true)

root
 |-- image_name: string (nullable = true)
 |-- mask_type: string (nullable = false)
 |-- mask: binary (nullable = true)





+--------------------+----------+--------------------+
|          image_name| mask_type|                mask|
+--------------------+----------+--------------------+
|CBN-Pla-B2-201807...|plant_mask|[78 9C EC DD CF A...|
|CBN-PdlC-E2-20190...| wood_mask|[78 9C EC DD CB A...|
|OPTMix-083-P3-186...|plant_mask|[78 9C EC DD BD 8...|
|CBN-PdlC-C2-20190...| sand_mask|[78 9C EC DD CF A...|
|OPTMix-0598-P1-14...| rock_mask|[78 9C EC DD BD D...|
+--------------------+----------+--------------------+
only showing top 5 rows



                                                                                

In [27]:
from plantclef.serde import deserialize_mask


# Per-mask pixel counts; the next cell turns these into the fraction of pixels
# *covered* by each mask type.
@F.udf("struct<total:int, covered:int>")
def mask_stats_udf(mask: "bytearray | None") -> "dict[str, int] | None":
    """Return pixel counts for one serialized mask.

    Args:
        mask: serialized mask bytes from one of the ``*_mask`` columns, or None
            when that column is null (the schema marks it nullable).

    Returns:
        ``{"total": <number of elements>, "covered": <sum of element values>}``,
        or None for a null mask so Spark emits a null struct instead of failing
        the task inside ``deserialize_mask``. ``covered`` equals the number of
        covered pixels only if the decoded mask is boolean/0-1 valued —
        NOTE(review): confirm against ``deserialize_mask``.
    """
    if mask is None:
        return None
    decoded = deserialize_mask(mask)
    return {
        "total": int(decoded.size),
        "covered": int(decoded.sum()),
    }


# Keep the UDF and the resulting DataFrame under distinct names so the UDF stays
# callable after this cell runs (reusing one name would rebind it to the frame).
mask_stats = (
    unpivoted.withColumn("stats", mask_stats_udf("mask")).drop("mask").cache()
)
mask_stats.printSchema()

root
 |-- image_name: string (nullable = true)
 |-- mask_type: string (nullable = false)
 |-- stats: struct (nullable = true)
 |    |-- total: integer (nullable = true)
 |    |-- covered: integer (nullable = true)



In [28]:
# Coverage per mask type as a ratio of sums: total covered pixels over total
# pixels across all images, so images with more pixels carry more weight than
# they would in a per-image mean.
summed = mask_stats.groupBy("mask_type").agg(
    F.sum("stats.total").alias("total"),
    F.sum("stats.covered").alias("covered"),
)
ratio = F.round(F.col("covered") / F.col("total"), 3)
calculated = summed.withColumn("pct_covered", ratio).orderBy(F.desc("pct_covered"))
calculated.printSchema()
calculated.show()

root
 |-- mask_type: string (nullable = false)
 |-- total: long (nullable = true)
 |-- covered: long (nullable = true)
 |-- pct_covered: double (nullable = true)





+---------------+-----------+-----------+-----------+
|      mask_type|      total|    covered|pct_covered|
+---------------+-----------+-----------+-----------+
|vegetation_mask|14093264657|11752390514|      0.834|
|     plant_mask|14093264657|11492809355|      0.815|
|      tree_mask|14093264657|10324181206|      0.733|
|      rock_mask|14093264657| 8640829625|      0.613|
|      sand_mask|14093264657| 8517531962|      0.604|
|      wood_mask|14093264657| 8273386945|      0.587|
|      tape_mask|14093264657| 6194593548|       0.44|
|    flower_mask|14093264657| 4347562371|      0.308|
|      leaf_mask|14093264657|  395652821|      0.028|
+---------------+-----------+-----------+-----------+



                                                                                