In [2]:
from plantclef.spark import get_spark
from pathlib import Path
from pyspark.sql import functions as F

spark = get_spark(cores=4, app_name="masking-notebook")

# Shared data directory, with "~" expanded, as a POSIX-style path string.
root = Path("~/shared/plantclef/data").expanduser().as_posix()

# Wide table as stored on disk; every column whose name contains "mask"
# is treated as one mask type.
masks_wide = spark.read.parquet(f"{root}/masking/test_2024_v2/data")
mask_columns = [name for name in masks_wide.columns if "mask" in name]

# Long format: one (image_name, mask_type, mask) row per mask column.
# Repartitioned and cached because later cells reuse this frame.
masks_long = masks_wide.unpivot("image_name", mask_columns, "mask_type", "mask")
masks = masks_long.repartition(96).cache()

masks.printSchema()
masks.show(5)

root
 |-- image_name: string (nullable = true)
 |-- mask_type: string (nullable = false)
 |-- mask: binary (nullable = true)

+--------------------+----------+--------------------+
|          image_name| mask_type|                mask|
+--------------------+----------+--------------------+
|CBN-Pla-B2-201807...|plant_mask|[78 9C EC DD CF A...|
|CBN-PdlC-E2-20190...| wood_mask|[78 9C EC DD CB A...|
|OPTMix-083-P3-186...|plant_mask|[78 9C EC DD BD 8...|
|CBN-PdlC-C2-20190...| sand_mask|[78 9C EC DD CF A...|
|OPTMix-0598-P1-14...| rock_mask|[78 9C EC DD BD D...|
+--------------------+----------+--------------------+
only showing top 5 rows



In [None]:
import numpy as np
from functools import partial
from plantclef.serde import deserialize_mask


# Split each mask into a grid of tiles, then count how many of those tiles are covered.
def split_into_tiles(mask: np.ndarray, grid_size: int) -> np.ndarray:
    """Cut a 2-D mask into a ``grid_size`` x ``grid_size`` grid of equal tiles.

    Each tile spans ``mask.shape[0] // grid_size`` rows and
    ``mask.shape[1] // grid_size`` columns. When a dimension is not an exact
    multiple of ``grid_size``, the trailing remainder rows/columns are not
    included in any tile.

    Args:
        mask: Two-dimensional array to split.
        grid_size: Number of tiles along each axis.

    Returns:
        Array of shape ``(grid_size ** 2, tile_rows, tile_cols)``. Tiles are
        ordered with the axis-0 block index varying slowest, i.e. tile
        ``i * grid_size + j`` is block ``(i, j)``.
    """
    n_rows, n_cols = mask.shape
    tile_rows = n_rows // grid_size
    tile_cols = n_cols // grid_size
    return np.array(
        [
            mask[
                r * tile_rows : (r + 1) * tile_rows,
                c * tile_cols : (c + 1) * tile_cols,
            ]
            for r in range(grid_size)
            for c in range(grid_size)
        ]
    )


def count_tiles(mask: bytearray, grid_size: int = 3, threshold: float = 0.5) -> int:
    """Count the grid tiles of a serialized mask whose mean exceeds ``threshold``.

    Args:
        mask: Serialized mask bytes, decoded with ``deserialize_mask``.
        grid_size: Number of tiles along each axis of the grid.
        threshold: A tile is counted when its mean value is strictly greater
            than this.

    Returns:
        Number of tiles (out of ``grid_size ** 2``) that pass the threshold.
    """
    decoded = deserialize_mask(mask)
    tiles = split_into_tiles(decoded, grid_size)
    # Flatten every tile to one row so a single axis-1 mean yields one
    # average per tile.
    per_tile_mean = tiles.reshape(tiles.shape[0], -1).mean(axis=1)
    return int(np.count_nonzero(per_tile_mean > threshold))


# Grid resolutions to evaluate (3x3 through 10x10) and the matching column
# labels; both the UDF columns and the unpivot below are driven by these.
grid_sizes = list(range(3, 11))
grid_labels = [f"{n}x{n}" for n in grid_sizes]

# One int-returning UDF column per grid size, each applied to the raw mask bytes.
tile_count_columns = [
    F.udf(partial(count_tiles, grid_size=n), "int")("mask").alias(label)
    for n, label in zip(grid_sizes, grid_labels)
]

plant_masks = masks.where(F.col("mask_type") == "plant_mask")

# Wide (one column per grid size) -> long (image_name, grid_size, tile_count).
grid_count = (
    plant_masks.select("image_name", *tile_count_columns)
    .unpivot("image_name", grid_labels, "grid_size", "tile_count")
    .cache()
)
grid_count.show()



+--------------------+---------+----------+
|          image_name|grid_size|tile_count|
+--------------------+---------+----------+
|CBN-Pla-B2-201807...|      3x3|         9|
|CBN-Pla-B2-201807...|      4x4|        11|
|CBN-Pla-B2-201807...|      5x5|        19|
|CBN-Pla-B2-201807...|      6x6|        29|
|CBN-Pla-B2-201807...|      7x7|        37|
|CBN-Pla-B2-201807...|      8x8|        49|
|CBN-Pla-B2-201807...|      9x9|        61|
|CBN-Pla-B2-201807...|    10x10|        77|
|OPTMix-083-P3-186...|      3x3|         9|
|OPTMix-083-P3-186...|      4x4|        16|
|OPTMix-083-P3-186...|      5x5|        25|
|OPTMix-083-P3-186...|      6x6|        36|
|OPTMix-083-P3-186...|      7x7|        49|
|OPTMix-083-P3-186...|      8x8|        64|
|OPTMix-083-P3-186...|      9x9|        81|
|OPTMix-083-P3-186...|    10x10|       100|
|RNNB-5-2-20240118...|      3x3|         9|
|RNNB-5-2-20240118...|      4x4|        16|
|RNNB-5-2-20240118...|      5x5|        23|
|RNNB-5-2-20240118...|      6x6|

                                                                                

In [None]:
# Persist the per-image tile counts to scratch storage, replacing any
# previous output at the same location.
scratch_root = Path("~/scratch/plantclef/data").expanduser().as_posix()
grid_count_path = f"{scratch_root}/masking/grid_count"
grid_count.write.mode("overwrite").parquet(grid_count_path)

                                                                                

In [25]:
# Side length n parsed from the "NxN" label. Multiplying the int column by
# itself keeps total_tiles an integer; `** 2` goes through Spark's pow, which
# returns a double and made the counts render as e.g. 15255.0.
grid_side = F.split("grid_size", "x").getItem(0).cast("int")

# Per grid size: total tiles over all images, tiles that passed the coverage
# threshold, and their ratio.
(
    grid_count.withColumn("total_tiles", grid_side * grid_side)
    .groupBy("grid_size")
    .agg(
        F.sum("total_tiles").alias("total_tiles"),
        F.sum("tile_count").alias("covered_tiles"),
    )
    # pct_covered is the fraction covered / total (not a percentage),
    # rounded to 3 decimals.
    .withColumn(
        "pct_covered", F.round(F.col("covered_tiles") / F.col("total_tiles"), 3)
    )
    # Ordering by total_tiles lists the grids from coarsest to finest.
    .orderBy("total_tiles")
).show()

+---------+-----------+-------------+-----------+
|grid_size|total_tiles|covered_tiles|pct_covered|
+---------+-----------+-------------+-----------+
|      3x3|    15255.0|        13855|      0.908|
|      4x4|    27120.0|        24336|      0.897|
|      5x5|    42375.0|        37667|      0.889|
|      6x6|    61020.0|        53798|      0.882|
|      7x7|    83055.0|        72887|      0.878|
|      8x8|   108480.0|        94489|      0.871|
|      9x9|   137295.0|       119041|      0.867|
|    10x10|   169500.0|       146303|      0.863|
+---------+-----------+-------------+-----------+

