In [2]:
from plantclef.spark import get_spark
from pathlib import Path
from pyspark.sql import functions as F

spark = get_spark(cores=4, app_name="masking-notebook")
root = Path("~/shared/plantclef/data").expanduser().as_posix()

# Wide layout as stored on disk: each column whose name contains "mask"
# holds one mask type, keyed by image_name.
masks_wide = spark.read.parquet(f"{root}/masking/test_2024_v2/data")
mask_columns = [col_name for col_name in masks_wide.columns if "mask" in col_name]

# Long layout: one row per (image_name, mask_type), the mask bytes in `mask`.
masks = (
    masks_wide.unpivot("image_name", mask_columns, "mask_type", "mask")
    .repartition(96)
    .cache()
)
masks.printSchema()
masks.show(5)

root
 |-- image_name: string (nullable = true)
 |-- mask_type: string (nullable = false)
 |-- mask: binary (nullable = true)

+--------------------+----------+--------------------+
|          image_name| mask_type|                mask|
+--------------------+----------+--------------------+
|CBN-Pla-B2-201807...|plant_mask|[78 9C EC DD CF A...|
|CBN-PdlC-E2-20190...| wood_mask|[78 9C EC DD CB A...|
|OPTMix-083-P3-186...|plant_mask|[78 9C EC DD BD 8...|
|CBN-PdlC-C2-20190...| sand_mask|[78 9C EC DD CF A...|
|OPTMix-0598-P1-14...| rock_mask|[78 9C EC DD BD D...|
+--------------------+----------+--------------------+
only showing top 5 rows



In [4]:
# Per-tile embeddings of the test images; the path indicates a 4x4 grid
# (presumably tiles 0-15 per image — confirm against the embedding job).
grid_emb_path = f"{root}/embeddings/test/test_2024/grid=4x4"
grid_emb = spark.read.parquet(grid_emb_path)
grid_emb.printSchema()
grid_emb.show(5)

root
 |-- image_name: string (nullable = true)
 |-- tile: integer (nullable = true)
 |-- cls_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- sample_id: integer (nullable = true)

+--------------------+----+--------------------+---------+
|          image_name|tile|       cls_embedding|sample_id|
+--------------------+----+--------------------+---------+
|CBN-Pla-D5-201607...|  11|[0.079906486, 1.3...|        0|
|CBN-PdlC-F2-20160...|   8|[2.3675528, 1.911...|        0|
|CBN-PdlC-C4-20190...|  10|[0.2052727, 1.130...|        0|
|CBN-PdlC-C3-20190...|   7|[0.6042919, 1.624...|        0|
|CBN-Pla-D3-201508...|   4|[-0.2694691, 1.74...|        0|
+--------------------+----+--------------------+---------+
only showing top 5 rows



In [10]:
import numpy as np
from functools import partial
from plantclef.serde import deserialize_mask, serialize_mask


# Split each mask into a grid of tiles, then compute how much of each tile is covered (the mean mask value per tile).
def split_into_tiles(mask: np.ndarray, grid_size: int) -> np.ndarray:
    """Split an image-like array into a ``grid_size x grid_size`` grid of tiles.

    Tiles are returned in row-major order: index ``r * grid_size + c`` is the
    tile in block-row ``r`` and block-column ``c``.

    Args:
        mask: array whose first two axes are (rows, cols); any trailing axes
            (e.g. channels) are carried through unchanged.
        grid_size: number of tiles along each of the first two axes; must be >= 1.

    Returns:
        Array of shape ``(grid_size**2, rows // grid_size, cols // grid_size, *rest)``.
        When rows/cols are not divisible by ``grid_size``, the leftover
        rows/columns past the last full tile are dropped.

    Raises:
        ValueError: if ``grid_size`` is less than 1.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")
    # numpy shapes are (rows, cols, ...), i.e. (height, width, ...).
    n_rows, n_cols = mask.shape[:2]
    tile_rows, tile_cols = n_rows // grid_size, n_cols // grid_size
    trailing = mask.shape[2:]
    trimmed = mask[: tile_rows * grid_size, : tile_cols * grid_size]
    # (R, C, ...) -> (block_row, tile_row, block_col, tile_col, ...)
    #             -> (block_row, block_col, tile_row, tile_col, ...)
    # then flatten the two block axes into a single row-major tile index.
    blocks = trimmed.reshape(grid_size, tile_rows, grid_size, tile_cols, *trailing)
    return blocks.swapaxes(1, 2).reshape(
        grid_size * grid_size, tile_rows, tile_cols, *trailing
    )


def tile_mask_percentage(mask: bytearray, grid_size: int = 3) -> "list[float] | None":
    """Mean mask value of each tile in a ``grid_size x grid_size`` grid.

    Args:
        mask: serialized mask bytes (in this notebook, the output of
            ``merge_masks`` / ``serialize_mask``). ``None`` — a SQL NULL that
            Spark passes through to Python UDFs — yields ``None``.
        grid_size: number of tiles along each axis; the result has
            ``grid_size ** 2`` entries.

    Returns:
        One value per tile, in the row-major tile order of ``split_into_tiles``.
        For a 0/1 (or boolean) mask this is the covered fraction in [0, 1],
        not a 0-100 percentage despite the function name.
    """
    if mask is None:
        return None
    decoded = deserialize_mask(mask)
    tiles = split_into_tiles(decoded, grid_size)
    # Flatten each tile so the mean is taken over all of its pixels.
    per_tile_means = tiles.reshape(len(tiles), -1).mean(axis=1)
    return per_tile_means.tolist()


@F.udf("binary")
def merge_masks(masks: list[bytearray]) -> "bytearray | None":
    """Union several serialized masks into one serialized mask (Spark UDF).

    Applied to ``collect_list("mask")``. ``collect_list`` skips NULLs, so an
    image whose selected masks are all NULL arrives as an empty list. That
    case returns ``None`` (SQL NULL) instead of letting
    ``np.bitwise_or.reduce([])`` raise inside the executor.

    Assumes every mask deserializes to an integer/boolean array of the same
    shape — TODO confirm against ``plantclef.serde``.
    """
    if not masks:
        return None
    decoded = [deserialize_mask(m) for m in masks]
    # Element-wise OR across all masks: a pixel is set if any input sets it.
    merged = np.bitwise_or.reduce(decoded)
    return serialize_mask(merged)


# Per-tile coverage UDF for a 4x4 grid; presumably this must match the grid of
# `grid_emb` (path "grid=4x4") so the two can be joined on (image_name, tile).
tile_pct_udf = F.udf(
    partial(tile_mask_percentage, grid_size=4),
    returnType="array<float>",
)

# Step 1: OR together the plant-related masks of each image.
combined_masks = (
    masks.where(F.col("mask_type").isin(["plant_mask", "flower_mask", "leaf_mask"]))
    .groupBy("image_name")
    .agg(F.collect_list("mask").alias("masks"))
    .select("image_name", merge_masks(F.col("masks")).alias("mask"))
)

# Step 2: one row per (image_name, tile) with that tile's coverage.
tile_mask_info = combined_masks.select(
    "image_name",
    F.posexplode(tile_pct_udf(F.col("mask"))).alias("tile", "pct_covered"),
).cache()

tile_mask_info.printSchema()
tile_mask_info.show(5)

root
 |-- image_name: string (nullable = true)
 |-- tile: integer (nullable = false)
 |-- pct_covered: float (nullable = true)





+--------------------+----+-----------+
|          image_name|tile|pct_covered|
+--------------------+----+-----------+
|CBN-PdlC-E1-20150...|   0|  0.9646807|
|CBN-PdlC-E1-20150...|   1|  0.8233373|
|CBN-PdlC-E1-20150...|   2| 0.53400785|
|CBN-PdlC-E1-20150...|   3|  0.5673638|
|CBN-PdlC-E1-20150...|   4| 0.80243677|
+--------------------+----+-----------+
only showing top 5 rows



                                                                                

In [11]:
# Sanity check: expect exactly one coverage row per embedding tile.
n_emb_tiles, n_mask_tiles = grid_emb.count(), tile_mask_info.count()
assert n_emb_tiles == n_mask_tiles, (
    f"row count mismatch: {n_emb_tiles} embedding tiles vs {n_mask_tiles} mask tiles"
)
n_emb_tiles, n_mask_tiles

(27120, 27120)