In [1]:
from plantclef.spark import get_spark
from pathlib import Path
from pyspark.sql import functions as F

spark = get_spark(cores=4, app_name="masking-notebook")
root = Path("~/shared/plantclef/data").expanduser().as_posix()
masks = spark.read.parquet(f"{root}/masking/test_2024_v2/data")
masks = (
    masks.unpivot(
        "image_name", [c for c in masks.columns if "mask" in c], "mask_type", "mask"
    )
    .repartition(96)
    .cache()
)
masks.printSchema()
masks.show(5)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/02 23:51:46 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/03/02 23:51:46 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
25/03/02 23:51:47 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

root
 |-- image_name: string (nullable = true)
 |-- mask_type: string (nullable = false)
 |-- mask: binary (nullable = true)



                                                                                

+--------------------+----------+--------------------+
|          image_name| mask_type|                mask|
+--------------------+----------+--------------------+
|CBN-Pla-B2-201807...|plant_mask|[78 9C EC DD CF A...|
|CBN-PdlC-E2-20190...| wood_mask|[78 9C EC DD CB A...|
|OPTMix-083-P3-186...|plant_mask|[78 9C EC DD BD 8...|
|CBN-PdlC-C2-20190...| sand_mask|[78 9C EC DD CF A...|
|OPTMix-0598-P1-14...| rock_mask|[78 9C EC DD BD D...|
+--------------------+----------+--------------------+
only showing top 5 rows



In [2]:
masks.select("mask_type").distinct().show()

+---------------+
|      mask_type|
+---------------+
|      tree_mask|
|     plant_mask|
|    flower_mask|
|      sand_mask|
|      rock_mask|
|vegetation_mask|
|      tape_mask|
|      wood_mask|
|      leaf_mask|
+---------------+



In [7]:
import numpy as np
from plantclef.serde import deserialize_mask, serialize_mask


@F.udf("binary")
def merge_masks(masks: list[bytearray]) -> bytearray:
    masks = [deserialize_mask(m) for m in masks]
    merged = np.bitwise_or.reduce(masks)
    return serialize_mask(merged)


combined_mask = (
    masks.where(F.col("mask_type").isin(["plant_mask", "flower_mask", "leaf_mask"]))
    .groupBy("image_name")
    .agg(F.collect_list("mask").alias("masks"))
    .select("image_name", merge_masks(F.col("masks")).alias("mask"))
).cache()
combined_mask.printSchema()
combined_mask.show(5)

root
 |-- image_name: string (nullable = true)
 |-- mask: binary (nullable = true)





+--------------------+--------------------+
|          image_name|                mask|
+--------------------+--------------------+
|CBN-PdlC-E1-20150...|[78 9C EC DD BD A...|
|CBN-PdlC-F6-20130...|[78 9C EC D7 B1 A...|
|CBN-PdlC-E6-20180...|[78 9C EC DD BB 8...|
|RNNB-4-1-20240117...|[78 9C EC D6 CF A...|
|CBN-PdlC-D3-20200...|[78 9C EC DD CB C...|
+--------------------+--------------------+
only showing top 5 rows



                                                                                

In [9]:
# let's check out some statistics
@F.udf("struct<total:int, covered:int>")
def mask_stats(mask: bytearray) -> tuple[int, int]:
    mask = deserialize_mask(mask)
    return {
        "total": int(mask.size),
        "covered": int(mask.sum()),
    }


stats = (
    combined_mask.withColumn("stats", mask_stats(F.col("mask")))
    .select(
        F.sum("stats.total").alias("total"), F.sum("stats.covered").alias("covered")
    )
    .withColumn("coverage", F.col("covered") / F.col("total"))
).cache()

stats.printSchema()
stats.show()

root
 |-- total: long (nullable = true)
 |-- covered: long (nullable = true)
 |-- coverage: double (nullable = true)





+-----------+-----------+------------------+
|      total|    covered|          coverage|
+-----------+-----------+------------------+
|14093264657|11545247752|0.8192032174933703|
+-----------+-----------+------------------+



                                                                                