(This setup block is taken from the SAM parameter-tuning tutorial. Some of it isn't needed, but we keep all of it while testing so that nothing breaks.)

## Environment Setup

If running locally using jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'.

In [None]:
# Set to True when running on Google Colab: the next cell then installs opencv-python, matplotlib
# and segment_anything, and downloads the ViT-H SAM checkpoint (sam_vit_h_4b8939.pth).
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

PyTorch version: 2.6.0+cu124
Torchvision version: 0.21.0+cu124
CUDA is available: True
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-8lmtpvue
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-8lmtpvue
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=dda4df50248291d7276bcacd2e1cfe456acd233ab3b47273f6cd948f6e1ac8c0
  Stored in directory: /tmp/pip-ephem-wheel-cache-azscqm52/wheels/15/d7/bd/05f5f23b7dcbe70cbc6783b06f12143b0cf1a5da5c7b52dcc5
Successful

(End of tutorial copy)

# More Setup
Populate the folders created below with the respective images (raw images in `batch_N`, masks in `batch_N_masks`).

We assume that:
* Original raw images are named `name.jpg`.
* Mask images are named `name_gtFine_color.png`, where `name` matches exactly one raw image (a 1-to-1 correspondence). Any other files that do not fit this structure are ignored.
* All masks include leaf, stem, and root, but may or may not include seed.

In [None]:
# Create the per-batch input folders: raw images go in batch_N, ground-truth colour masks in batch_N_masks.
# os.makedirs(..., exist_ok=True) keeps this cell safe to re-run, whereas `!mkdir` reports an error
# for every folder that already exists.
import os

for folder in ["batch_1", "batch_2", "batch_3", "batch_1_masks", "batch_2_masks", "batch_3_masks"]:
  os.makedirs(folder, exist_ok=True)

In [None]:
import cv2
import numpy as np
import os
from glob import glob


# Percentile margin trimmed from EACH end of every HSV channel, per training image: only the band
# [PERCENT_MARGIN, 100 - PERCENT_MARGIN] is kept, i.e. the middle 10% of pixel values at 45.
# Smaller margins (5%, 20%) were tried and picked up a lot of background.
PERCENT_MARGIN = 45


# Get lower and upper color bounds for the mask across all images in the training folders
def get_color_bounds(train_img_folder, train_mask_folder, mask_bgr, percent_margin=PERCENT_MARGIN):
  """Learn an HSV range for one mask colour from the training images.

  For every `<name>.jpg` in `train_img_folder` with a matching `<name>_gtFine_color.png` in
  `train_mask_folder`, the image pixels whose mask pixel equals `mask_bgr` exactly are converted to
  HSV (OpenCV ranges: H 0-179, S/V 0-255). Per image, the `percent_margin` and
  `100 - percent_margin` percentiles of each channel are taken. The final bounds are the
  element-wise minimum of the lower percentiles and the maximum of the upper percentiles over all
  images.

  Args:
    train_img_folder: folder of training `.jpg` images.
    train_mask_folder: folder of `<name>_gtFine_color.png` colour masks.
    mask_bgr: the mask colour to learn, in OpenCV BGR order.
    percent_margin: percentile trimmed from each end of each channel (defaults to PERCENT_MARGIN).

  Returns:
    (lower, upper): uint8 arrays of shape (3,) in (H, S, V) order, usable with cv2.inRange on an
    HSV image.

  Raises:
    ValueError: if no training image contributed any pixel of `mask_bgr`.
  """
  all_hsv_min = []
  all_hsv_max = []

  img_paths = sorted(glob(os.path.join(train_img_folder, "*.jpg")))
  mask_paths = sorted(glob(os.path.join(train_mask_folder, "*_gtFine_color.png")))

  # Check 1-to-1 correspondence
  if len(img_paths) != len(mask_paths):
    print(f"Warning: Mismatch between image ({len(img_paths)}) and mask ({len(mask_paths)}) count")

  for img_path in img_paths:
    filename = os.path.splitext(os.path.basename(img_path))[0]
    mask_path = os.path.join(train_mask_folder, f"{filename}_gtFine_color.png")

    if not os.path.exists(mask_path):
      print(f"Warning: No matching mask found for {filename}")
      continue

    image = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    if image is None or mask is None:
      print(f"Warning: Could not read image or mask for {filename}")
      continue
    if image.shape[:2] != mask.shape[:2]:
      print(f"Warning: Image and mask sizes differ for {filename}")
      continue

    # No need to convert to RGB: we work in BGR since the masks and target color are both BGR
    match_mask = cv2.inRange(mask, mask_bgr, mask_bgr)

    if np.count_nonzero(match_mask) == 0:
      print(f"Warning: No matching pixels for {filename}")
      continue

    hsv_pixels = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)[match_mask > 0]  # (N, 3)

    # Per-channel percentiles (axis=0): row 0 is the lower (H, S, V), row 1 the upper (H, S, V)
    lower, upper = np.percentile(hsv_pixels, [percent_margin, 100 - percent_margin], axis=0)
    all_hsv_min.append(lower)
    all_hsv_max.append(upper)

  if not all_hsv_min:
    raise ValueError(
      f"No training pixels matching {mask_bgr} found in {train_img_folder} / {train_mask_folder}"
    )

  final_lower = np.round(np.min(all_hsv_min, axis=0)).astype(np.uint8)
  final_upper = np.round(np.max(all_hsv_max, axis=0)).astype(np.uint8)

  # TODO (FUTURE): handle hue wrapping when final_lower[0] > final_upper[0] across the hue circle (0-179 in OpenCV)
  # For example, if one mask gives H bounds [170, ...] to [179, ...] and another gives [0, ...] to [10, ...],
  # we may want to combine them as (170–179) OR (0–10), which requires a dual-range representation.

  return final_lower, final_upper

In [None]:
import cv2
import numpy as np


# Return the innermost point of the contour using distance transform
def get_contour_point(contour):
  """Return the [x, y] image coordinate inside `contour` that lies farthest from its boundary.

  The contour is rasterised (filled) onto a small local canvas around its bounding box. The
  maximum of the L2 distance transform is the most "interior" pixel, which makes a robust SAM
  click point even for concave shapes where the centroid can fall outside.

  Args:
    contour: an OpenCV contour as returned by cv2.findContours.

  Returns:
    [x, y] as a list of ints, in the coordinate frame of the original image.
  """
  x, y, w, h = cv2.boundingRect(contour)

  # Pad the canvas by one pixel on every side. cv2.distanceTransform measures distance to the nearest
  # zero pixel and does not treat pixels outside the image as zero. Without padding, a shape touching
  # the edge of its own bounding box (every shape does) gets inflated distances along that edge. A
  # filled rectangle would have no zero pixels at all, so the result would be meaningless.
  pad = 1
  mask = np.zeros((h + 2 * pad, w + 2 * pad), dtype=np.uint8)

  # Shift the contour into local canvas coordinates (cast back to int32, the dtype OpenCV expects)
  shifted_contour = (contour - [x - pad, y - pad]).astype(np.int32)
  cv2.drawContours(mask, [shifted_contour], -1, 255, thickness=cv2.FILLED)

  dist_transform = cv2.distanceTransform(mask, distanceType=cv2.DIST_L2, maskSize=5)
  _, _, _, max_loc = cv2.minMaxLoc(dist_transform)

  # max_loc is (col, row) in canvas coordinates; undo the shift and the padding
  return [max_loc[0] + x - pad, max_loc[1] + y - pad]

In [None]:
import cv2
import numpy as np


# Return color_mask and contours in the image within the given HSV-format color bounds
def get_color_contours(image_rgb, lower_color_bound, upper_color_bound):
  """Threshold an RGB image in HSV space and find the outer contours of the result.

  Args:
    image_rgb: image in RGB channel order (converted with cv2.COLOR_RGB2HSV).
    lower_color_bound, upper_color_bound: inclusive HSV bounds for cv2.inRange.

  Returns:
    (color_mask, contours): the 0/255 uint8 mask after a 5x5 morphological opening (which removes
    specks smaller than the kernel), and its external contours.
  """
  opening_kernel = np.ones((5, 5), np.uint8)
  in_range = cv2.inRange(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV), lower_color_bound, upper_color_bound)
  opened = cv2.morphologyEx(in_range, cv2.MORPH_OPEN, opening_kernel)

  found_contours, _hierarchy = cv2.findContours(
    opened.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
  )
  return opened, found_contours

In [None]:
import cv2
import numpy as np


# Maximum distance (pixels) of a contour's interior point from the largest contour's point to avoid pruning
DISTANCE_THRESHOLD = 100


# Return the color mask and the contour points for the subset of contours that are close enough and large enough
def filter_contours(color_mask, contours, max_click_points, percent_click_points, distance_threshold=DISTANCE_THRESHOLD):
  """Keep the largest contours that sit near the largest one, and return their SAM click points.

  Contours are ranked by area. The top `int(len(contours) * percent_click_points)` are kept, capped
  at `max_click_points` and never fewer than 1. A kept contour survives only if its innermost point
  (get_contour_point) lies within `distance_threshold` pixels of the largest contour's innermost
  point.

  Args:
    color_mask: the thresholded mask the contours came from (only its shape/dtype are reused).
    contours: contours as returned by cv2.findContours (may be empty).
    max_click_points: upper bound on the number of contours considered.
    percent_click_points: fraction of contours considered.
    distance_threshold: maximum allowed distance from the largest contour's point.

  Returns:
    (mask, points): a boolean mask of the surviving filled contours, and their [x, y] points.
  """
  if not contours:
    # Return a boolean mask here too, so both paths return the same dtype
    return np.asarray(color_mask) > 0, []

  # Sort contours by area, largest first (sorted() is stable, so equal areas keep their order)
  contours_sorted = sorted(contours, key=cv2.contourArea, reverse=True)

  # Determine how many contours to keep (ensuring at least 1 contour is kept if present)
  num_to_keep = max(1, min(int(len(contours_sorted) * percent_click_points), max_click_points))
  contours_to_check = contours_sorted[:num_to_keep]

  # Compute each interior point once; the first (largest contour) is the distance reference
  points = [np.array(get_contour_point(contour)) for contour in contours_to_check]
  center_point = points[0]

  color_mask_filtered = np.zeros_like(color_mask)
  contour_points = []

  for contour, contour_point in zip(contours_to_check, points):
    if np.linalg.norm(contour_point - center_point) <= distance_threshold:
      cv2.drawContours(color_mask_filtered, [contour], -1, 255, thickness=cv2.FILLED)
      contour_points.append(contour_point.tolist())

  return color_mask_filtered > 0, contour_points

In [None]:
from segment_anything import SamPredictor


# Click on the given points in the image and return the SAM mask
def get_sam_mask(image_rgb, click_points):
  """Run SAM once per positive click and return the union of the predicted masks.

  Each point is a separate single-point prompt (not one multi-point prompt), so every click
  produces its own mask. Uses the module-level `sam` model.

  Args:
    image_rgb: RGB image passed to SamPredictor.set_image.
    click_points: sequence of [x, y] points (may be empty).

  Returns:
    boolean mask of shape image_rgb.shape[:2]; all False when there are no click points.
  """
  if len(click_points) == 0:
    # Nothing to prompt with: skip the expensive image embedding in set_image entirely
    return np.zeros(image_rgb.shape[:2], dtype=bool)

  predictor = SamPredictor(sam)
  predictor.set_image(image_rgb)

  # Run SAM using each click
  sam_masks = []
  for point in click_points:
    masks_i, _, _ = predictor.predict(
      point_coords=np.array([point]),
      point_labels=np.array([1]),
      multimask_output=False,
    )
    sam_masks.append(masks_i[0])

  # Combine all masks into one group
  return np.any(np.stack(sam_masks), axis=0)

In [None]:
# Combine the color mask and sam mask
def get_combined_mask(color_mask, sam_mask):
  """Return the pixel-wise union (logical OR) of the two masks as a boolean array.

  Inputs may be boolean or numeric (any non-zero pixel counts as set) and must have
  broadcast-compatible shapes.
  """
  union = np.logical_or(sam_mask, color_mask)
  return union

In [None]:
from skimage.measure import label, regionprops


# Prune the mask to only keep the largest segment
def get_pruned_mask(mask):
  """Return a boolean mask containing only the largest connected region of `mask`.

  Regions are found with skimage.measure.label. On ties, the region with the lowest label (the
  first found) wins. An input with no regions yields an all-False mask of the same shape.
  """
  labelled = label(mask)
  props = regionprops(labelled)
  if len(props) == 0:
    return np.zeros_like(mask, dtype=bool)

  biggest = props[0]
  for region in props[1:]:
    if region.area > biggest.area:
      biggest = region
  return labelled == biggest.label

In [None]:
import matplotlib.pyplot as plt


# Visualize the results and intermediary steps
def visualize(image_rgb, color_mask, filtered_color_mask, sam_click_points, sam_mask, combined_mask, pruned_mask, final_mask, show=False):
  """Plot the intermediate steps of one find_part run side by side (8 panels).

  Args:
    image_rgb: RGB image drawn under every panel.
    color_mask, filtered_color_mask, sam_mask, combined_mask, pruned_mask, final_mask: masks
      overlaid at 50% opacity, one per panel.
    sam_click_points: iterable of [x, y] points drawn as red dots.
    show: plotting is off by default because one wide figure per image overwhelms Colab's output on
      big batches. Pass True to inspect individual images.
  """
  if not show:
    return

  # (panel number, title, mask overlaid on the image)
  overlays = [
    (2, "Color-based Mask", color_mask),
    (3, "Filtered Color Mask", filtered_color_mask),
    (5, "SAM-based Mask", sam_mask),
    (6, "Combined Mask (pre-pruning)", combined_mask),
    (7, "Pruned mask", pruned_mask),
    (8, "Final mask - post intersection removal if applicable", final_mask),
  ]

  fig, axes = plt.subplots(1, 8, figsize=(30, 8))
  for ax in axes:
    ax.imshow(image_rgb)

  axes[0].set_title("Original Image")

  for point in sam_click_points:
    axes[3].plot(point[0], point[1], 'ro')
  axes[3].set_title("SAM Click Points")

  for panel, title, overlay in overlays:
    axes[panel - 1].imshow(overlay, alpha=0.5, cmap="Greens")
    axes[panel - 1].set_title(title)

  fig.tight_layout()
  plt.show()
  # Release the figure; this is called once per image inside batch loops
  plt.close(fig)

In [None]:
import cv2
import numpy as np


# Find and visualize the steps toward finding the part for the given input params
def find_part(image_rgb, lower_color_bound, upper_color_bound, max_click_points, percent_click_points, mask_to_remove=None):
  """Segment one plant part: colour threshold -> contour filtering -> SAM clicks -> union.

  Args:
    image_rgb: image in RGB order (it is converted with COLOR_RGB2HSV and passed to SAM as RGB).
    lower_color_bound, upper_color_bound: HSV bounds for the colour threshold.
    max_click_points, percent_click_points: limit how many contours become SAM clicks (see
      filter_contours).
    mask_to_remove: optional mask with the same height/width as the image. Every non-zero pixel is
      excluded from the result (e.g. other plant parts when segmenting the leaf).

  Returns:
    boolean mask of the part.
  """
  color_mask, contours = get_color_contours(image_rgb, lower_color_bound, upper_color_bound)
  print(f"num contours = {len(contours)}")

  filtered_color_mask, sam_click_points = filter_contours(color_mask, contours, max_click_points, percent_click_points)
  print(f"num sam click points = {len(sam_click_points)}")

  sam_mask = get_sam_mask(image_rgb, sam_click_points)
  combined_mask = get_combined_mask(filtered_color_mask, sam_mask)

  # TODO: this is a bandaid for combo result enhancement
  # pruned_mask = get_pruned_mask(combined_mask)
  pruned_mask = combined_mask

  # remove using mask_to_remove (for leaf we remove the other parts)
  final_mask = pruned_mask
  if mask_to_remove is not None:
    # Use logical ops on a binarised mask rather than cv2.bitwise_*. pruned_mask is boolean, which
    # OpenCV's bitwise functions do not reliably accept. bitwise_not of a grayscale mask also only
    # zeroes pixels that were exactly 255, so other part pixels would not be removed.
    keep = np.asarray(mask_to_remove) == 0
    final_mask = np.logical_and(pruned_mask, keep)

  visualize(image_rgb, color_mask, filtered_color_mask, sam_click_points, sam_mask, combined_mask, pruned_mask, final_mask)

  return final_mask

In [None]:
import cv2
import numpy as np

# Get empty masks: a placeholder for parts without a real implementation, so the testing framework still runs
def get_empty_masks(image_paths):
    """Return one all-zero uint8 mask per image, sized to that image's height and width.

    Each image is read with cv2.imread only to learn its size. A warning is printed because calling
    this means the real mask extraction for the part is not implemented.
    """
    print("WARNING: get_empty_masks was called meaning mask getting for this part is not implemented")
    return [np.zeros(cv2.imread(path).shape[:2], dtype=np.uint8) for path in image_paths]

In [None]:
# Union the given masks or return None if there are none
def union_masks(masks):
  """Load the colour mask images referenced by `masks` and union them into one 0/255 uint8 mask.

  Args:
    masks: iterable of 2-tuples whose first element is a mask image path (the second is ignored).

  Returns:
    uint8 mask where 255 marks any non-black pixel in any loaded mask, or None if no mask could be
    loaded. Unreadable paths are reported and skipped.
  """
  mask_to_remove = None
  for mask_path, _ in masks:
    mask_bgr = cv2.imread(mask_path, cv2.IMREAD_COLOR)
    if mask_bgr is None:
      print(f"Failed to load mask: {mask_path}")
      continue

    # Binarise: any non-black pixel belongs to a part. OR-ing raw grayscale levels would give a mask
    # of arbitrary values whose bitwise inverse does not zero out the part pixels.
    mask_binary = np.where(np.any(mask_bgr > 0, axis=-1), 255, 0).astype(np.uint8)

    # If it's the first mask, initialize mask_to_remove
    if mask_to_remove is None:
      mask_to_remove = mask_binary
    else:
      # Perform union: combine the masks (logical OR)
      mask_to_remove = cv2.bitwise_or(mask_to_remove, mask_binary)

  return mask_to_remove

In [None]:
# Don't restrict colors much
COMBO_CLICK_POINT_MAX = 10
COMBO_CLICK_POINT_PERCENT = 1


# Bespoke to combo strategy
def get_combo_masks(image_paths, lower_color_bound, upper_color_bound, args):
  """Segment leaf + stem + seed together in each image with find_part.

  Args:
    image_paths: image file paths, processed in order.
    lower_color_bound, upper_color_bound: HSV bounds (as produced by get_color_bounds).
    args: unused; accepted to match the calling convention get_results uses for every
      get_*_masks function.

  Returns:
    list of masks, one per readable image. Unreadable images are skipped with a message, so the
    list can be shorter than image_paths.
  """
  masks = []

  for image_path in image_paths:
    print(image_path)

    # cv2.imread's second argument is an IMREAD_* flag, not a colour-conversion code
    # (cv2.COLOR_BGR2RGB happens to equal cv2.IMREAD_ANYCOLOR, which still yields BGR pixels).
    # Load as BGR and convert explicitly so find_part's RGB->HSV threshold and SAM really get RGB.
    image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image_bgr is None:
      print(f"Failed to load image: {image_path}")
      continue
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    mask = find_part(image_rgb, lower_color_bound, upper_color_bound, COMBO_CLICK_POINT_MAX, COMBO_CLICK_POINT_PERCENT)
    masks.append(mask)

  return masks

In [None]:
# Take top 50% largest leaf contours to a maximum of 5 (finding green on the leaf is splotchier)
LEAF_CLICK_POINT_MAX = 5
LEAF_CLICK_POINT_PERCENT = 0.5


# Get leaf masks given bounds
# Args include list of other masks by filename to remove from the leaf mask
def get_leaf_masks(image_paths, lower_color_bound, upper_color_bound, args):
  """Segment the leaf in each image with find_part.

  Args:
    image_paths: image file paths, processed in order.
    lower_color_bound, upper_color_bound: HSV bounds (as produced by get_color_bounds).
    args: intended to be the other parts' *_masks_by_filename dicts to subtract from the leaf
      (see the leaf results cell). Currently unused, because the removal code below is disabled.

  Returns:
    list of masks, one per readable image. Unreadable images are skipped with a message, so the
    list can be shorter than image_paths.
  """
  masks = []

  for image_path in image_paths:
    print(image_path)

    # cv2.imread's second argument is an IMREAD_* flag, not a colour-conversion code
    # (cv2.COLOR_BGR2RGB happens to equal cv2.IMREAD_ANYCOLOR, which still yields BGR pixels).
    # Load as BGR and convert explicitly so find_part's RGB->HSV threshold and SAM really get RGB.
    image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image_bgr is None:
      print(f"Failed to load image: {image_path}")
      continue
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # Disabled: subtract the other parts' masks from the leaf.
    # NOTE(review): before re-enabling, reconcile the types. union_masks iterates (mask_path, _)
    # tuples, but these dict values are mask arrays, and get_results keys its dicts by full image
    # path, not by the bare filename used here.
    # filename = os.path.splitext(os.path.basename(image_path))[0]
    # matching_masks = [mask_dict[filename] for mask_dict in args if filename in mask_dict]
    # if not matching_masks:
    #   print(f"Warning: No masks found for {filename}")
    # else:
    #   # We expect 2-3 (might be missing seed, should not be missing anything else)
    #   print(f"Found {len(matching_masks)} matching masks for {filename}")

    # masks_to_remove = union_masks(matching_masks)

    mask = find_part(image_rgb, lower_color_bound, upper_color_bound, LEAF_CLICK_POINT_MAX, LEAF_CLICK_POINT_PERCENT)
    masks.append(mask)

  return masks

In [None]:
# Take the largest 1 seed contour (we only have one seed)
SEED_CLICK_POINT_MAX = 1
SEED_CLICK_POINT_PERCENT = 1


# Get seed masks given bounds
def get_seed_masks(image_paths, lower_color_bound, upper_color_bound, args):
  """Segment the seed in each image with find_part, using a single SAM click.

  Args:
    image_paths: image file paths, processed in order.
    lower_color_bound, upper_color_bound: HSV bounds (as produced by get_color_bounds).
    args: unused; accepted to match the calling convention get_results uses for every
      get_*_masks function.

  Returns:
    list of masks, one per readable image. Unreadable images are skipped with a message, so the
    list can be shorter than image_paths.
  """
  masks = []

  for image_path in image_paths:
    print(image_path)

    # cv2.imread's second argument is an IMREAD_* flag, not a colour-conversion code
    # (cv2.COLOR_BGR2RGB happens to equal cv2.IMREAD_ANYCOLOR, which still yields BGR pixels).
    # Load as BGR and convert explicitly so find_part's RGB->HSV threshold and SAM really get RGB.
    image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image_bgr is None:
      print(f"Failed to load image: {image_path}")
      continue
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    mask = find_part(image_rgb, lower_color_bound, upper_color_bound, SEED_CLICK_POINT_MAX, SEED_CLICK_POINT_PERCENT)
    masks.append(mask)

  return masks

In [None]:
import cv2


def get_sam_mask_union(image_rgb, click_points):
  """Return the union of SAM's single-mask predictions, one per click point.

  Same contract as get_sam_mask, to which this delegates so the two implementations cannot drift
  apart. It is kept as a separate name because run_root calls it.
  """
  return get_sam_mask(image_rgb, click_points)


# Remove any found background artifacts included
def clean_mask_largest_region(mask):
    """Keep only the largest connected region of `mask` and return it as a boolean mask.

    Same contract as get_pruned_mask, to which this delegates instead of duplicating the
    labelling code.
    """
    return get_pruned_mask(mask)


def get_white_mask(image_bgr, lower_color_bound, upper_color_bound, show_debug=True):
    """Return a boolean mask of the pixels whose HSV value lies within the given bounds.

    Args:
        image_bgr: image in BGR order as returned by cv2.imread. It is converted with
            COLOR_BGR2HSV, the same conversion get_color_bounds uses to learn the bounds.
        lower_color_bound, upper_color_bound: inclusive HSV bounds for cv2.inRange.
        show_debug: when True (the default, matching the previous behaviour), plot the raw mask.
            Pass False in batch runs, where one figure per image floods the notebook output.
    """
    hsv = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV)
    white_mask = cv2.inRange(hsv, lower_color_bound, upper_color_bound)
    if show_debug:
        plt.imshow(white_mask, cmap='gray')
        plt.title("Raw White Mask")
        plt.show()
    return white_mask > 0


# Visualize results
def _visualize(image_rgb, green_mask_bool, sam_mask_union, raw_combined_mask, refined_mask_clean, chosen_strategy):
  """Plot the root pipeline's stages in a 1x5 figure: the image, then four masks overlaid on it.

  `chosen_strategy` is appended to the titles of the last two panels.
  """
  overlay_panels = [
    (green_mask_bool, "Greens", "Color-based Mask"),
    (sam_mask_union, "Reds", "SAM-based Mask"),
    (raw_combined_mask, "Oranges", f"Combined Mask (pre-cleaning) [{chosen_strategy}]"),
    (refined_mask_clean, "Purples", f"Final Mask (post-cleaning) [{chosen_strategy}]"),
  ]

  plt.figure(figsize=(22, 6))

  plt.subplot(1, 5, 1)
  plt.imshow(image_rgb)
  plt.title("Original Image")

  for position, (overlay, colormap, heading) in enumerate(overlay_panels, start=2):
    plt.subplot(1, 5, position)
    plt.imshow(image_rgb)
    plt.imshow(overlay, alpha=0.5, cmap=colormap)
    plt.title(heading)

  plt.tight_layout()
  plt.show()


def get_root_click(white_mask, max_width=100, min_height_divisor=3):
    """Choose one SAM click point on the root and clip `white_mask` above the root.

    The tallest contour that is taller than `image_height // min_height_divisor` and narrower than
    `max_width` is taken as the root. Its foreground width is sampled every 5 rows from the top.
    The first sampled row wider than twice the narrowest non-empty row is treated as the top of the
    root (the widening at leaves/seed). The click goes at the horizontal centre of the bounding
    box, halfway between that row and the bottom of the contour.

    Args:
        white_mask: 2-D mask (bool or 0/255 uint8) of candidate root pixels. NOTE: modified IN
            PLACE. Rows above the detected top are zeroed, and run_root relies on this when it
            combines the mask with SAM's output.
        max_width: contours whose bounding box is at least this wide are rejected.
        min_height_divisor: contours must be taller than image_height // min_height_divisor.

    Returns:
        np.array([[x, y]]) holding the click point, or an empty array if no root-like contour exists.
    """
    contours, _ = cv2.findContours(white_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    print(f"Found {len(contours)} contours")
    if not contours:
        print("No root found.")
        return np.array([])

    # Filter for tall, relatively thin contours
    min_height = white_mask.shape[0] // min_height_divisor
    valid_contours = []
    for c in contours:
        cx, cy, cw, ch = cv2.boundingRect(c)
        if ch > min_height and cw < max_width:
            valid_contours.append(c)
            print(f"Contour: x={cx}, y={cy}, w={cw}, h={ch}")
    if not valid_contours:
        print("No tall, thin root contours found.")
        return np.array([])

    # Pick tallest contour
    root_contour = max(valid_contours, key=lambda c: cv2.boundingRect(c)[3])
    x, y, w, h = cv2.boundingRect(root_contour)

    # Foreground width of every 5th row inside the bounding box. count_nonzero works for both bool
    # and 0/255 masks; sum/255 under-counted bool masks, though only ratios are compared below.
    widths = [(row_y, np.count_nonzero(white_mask[row_y, x:x + w])) for row_y in range(y, y + h, 5)]

    # Detect widening (default=0 guards against an all-empty sample, leaving top_y at the top)
    min_width = min((width for _, width in widths if width > 0), default=0)
    top_y = y
    for row_y, width in widths:
        if width > min_width * 2:
            top_y = row_y
            break

    # Click at midpoint of root below widening
    root_click_y = (top_y + (y + h)) // 2
    root_click_x = x + w // 2
    root_click = np.array([root_click_x, root_click_y])
    print(f"Root centroid: {root_click}, Top Y: {top_y}")

    # Clip mask to below top_y (exclude leaves/seed); intentionally mutates the caller's array
    white_mask[:top_y, :] = 0
    return np.array([root_click])

def run_root(image_path, lower_color_bound, upper_color_bound):
    """Segment the root in one image file and return a boolean mask of its largest region.

    Steps: HSV threshold, then a single SAM click on the root, then union of the two masks, then
    keep the largest connected region. Intermediate results are plotted.
    """
    image_bgr = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    white_mask_bool = get_white_mask(image_bgr, lower_color_bound, upper_color_bound)
    # get_root_click zeroes the rows above the root in white_mask_bool in place; everything
    # below deliberately works on that clipped mask.
    clicks = get_root_click(white_mask_bool)
    click_points = clicks.tolist() if clicks.size else []

    sam_mask_union = get_sam_mask_union(image_rgb, click_points)
    raw_combined_mask = get_combined_mask(white_mask_bool, sam_mask_union)
    refined_mask_clean = clean_mask_largest_region(raw_combined_mask)

    _visualize(image_rgb, white_mask_bool, sam_mask_union, raw_combined_mask, refined_mask_clean, "both")
    return refined_mask_clean

# Get root masks given bounds
def get_root_masks(image_paths, lower_color_bound, upper_color_bound, args):
  """Segment the root in each image with run_root.

  Args:
    image_paths: image file paths, processed in order.
    lower_color_bound, upper_color_bound: HSV bounds (as produced by get_color_bounds).
    args: unused; accepted to match the calling convention get_results uses for every
      get_*_masks function.

  Returns:
    list of boolean masks, one per readable image. Unreadable images are skipped with a message,
    so the list can be shorter than image_paths.
  """
  masks = []

  for image_path in image_paths:
    print(image_path)

    # Only probe that the file decodes; run_root loads the image itself. cv2.imread takes an
    # IMREAD_* flag here, not a colour-conversion code such as cv2.COLOR_BGR2RGB.
    if cv2.imread(image_path, cv2.IMREAD_COLOR) is None:
      print(f"Failed to load image: {image_path}")
      continue

    mask = run_root(image_path, lower_color_bound, upper_color_bound)
    masks.append(mask)

  return masks

In [None]:
# Get stem masks given bounds
def get_stem_masks(image_paths, lower_color_bound, upper_color_bound, args):
  """Placeholder stem segmentation: one all-zero mask per image, via get_empty_masks.

  The colour bounds and args are accepted only to match the get_*_masks calling convention used by
  get_results.
  TODO (FUTURE): Stem masks are not yet implemented.
  """
  placeholder_masks = get_empty_masks(image_paths)
  return placeholder_masks

In [None]:
import numpy as np
import cv2

# Function to compute IoU between two binary masks
def compute_iou(mask1, mask2):
  """Return |mask1 AND mask2| / |mask1 OR mask2| (non-zero pixels count as set).

  Returns 0 when both masks are empty (the union is zero), so two empty masks score 0, not 1.
  Raises ValueError if the shapes cannot be broadcast together.
  """
  intersection = np.sum(np.logical_and(mask1, mask2))
  union = np.sum(np.logical_or(mask1, mask2))

  if union == 0:
    return 0  # Avoid division by zero if the union is zero

  return intersection / union


# Evaluate function returning best, worst, q1, q2, q3, mean, std_dev
def evaluate(real_masks, part_masks):
  """Compute IoU statistics between ground-truth and predicted masks, paired by position.

  Pairs that fail to evaluate (typically a shape mismatch, or a missing prediction) are reported
  and skipped. If the two lists differ in length, a warning is printed and the extra entries are
  ignored.

  Returns:
    (best, worst, q1, median, q3, mean, std_dev) of the per-pair IoUs. All values are NaN when no
    pair could be evaluated.
  """
  if len(real_masks) != len(part_masks):
    print(f"WARNING: {len(real_masks)} real masks vs {len(part_masks)} predicted masks; extra entries are ignored")

  ious = []
  skipped = 0  # counted across ALL pairs, so it must be initialised outside the loop

  for real_mask, part_mask in zip(real_masks, part_masks):
    try:
      # Best-effort: binarise and score inside the try so one bad pair doesn't abort the batch
      real_mask_bin = (real_mask > 0).astype(np.uint8)
      part_mask_bin = (part_mask > 0).astype(np.uint8)
      ious.append(compute_iou(real_mask_bin, part_mask_bin))
    except Exception as e:
      skipped += 1
      print(f"Evaluation error: {e}")

  if skipped:
    print(f"WARNING: SKIPPED = {skipped}")

  if not ious:
    print("WARNING: no mask pairs could be evaluated")
    return (float("nan"),) * 7

  ious = np.array(ious)
  best = np.max(ious)
  worst = np.min(ious)
  q1 = np.percentile(ious, 25)
  q2 = np.percentile(ious, 50)  # This is the median
  q3 = np.percentile(ious, 75)
  mean = np.mean(ious)
  std_dev = np.std(ious)

  return best, worst, q1, q2, q3, mean, std_dev

In [None]:
import os
import cv2
import shutil
import numpy as np
import shutil
import random
from glob import glob


# Number of image-style batches
NUM_BATCHES = 3

# We use 20% of our images for color-bound training and the remaining 80% for testing
TRAINING_PERCENT = 0.2
TESTING_PERCENT = 0.8

# Provided mask definitions, as RGB triples (converted to OpenCV's BGR order below)
BACKGROUND = np.array([0, 0, 0])
LEAF = np.array([142, 146, 189])
SEED = np.array([202, 240, 25])
ROOT = np.array([243, 131, 142])
STEM = np.array([181, 104, 222])


# Convert RGB to BGR for OpenCV
def rgb_to_bgr(rgb):
    """Return a new NumPy array holding the first three entries of `rgb` in reverse order.

    OpenCV loads images in B, G, R channel order, so the RGB mask colours above must be flipped
    before they are compared with cv2.imread output.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return np.array([blue, green, red])


LEAF_BGR, SEED_BGR, ROOT_BGR, STEM_BGR = (rgb_to_bgr(color) for color in (LEAF, SEED, ROOT, STEM))

# Seed for the training split below. Sorting the file names and seeding the shuffle makes the split
# reproducible, so re-running this cell selects the same training images.
SPLIT_SEED = 0

# Copy out seeded images (for seed testing)
for i in range(1, NUM_BATCHES+1):
  img_folder = f"batch_{i}"
  mask_folder = f"batch_{i}_masks"
  seeded_img_folder = f"batch_{i}_seeded"
  seeded_mask_folder = f"batch_{i}_seeded_masks"

  os.makedirs(seeded_img_folder, exist_ok=True)
  os.makedirs(seeded_mask_folder, exist_ok=True)

  image_paths = sorted(glob(f"{img_folder}/*.jpg"))

  for image_path in image_paths:
    # Same name extraction as the rest of the notebook (os.path.splitext)
    filename = os.path.splitext(os.path.basename(image_path))[0]
    mask_filename = f"{filename}_gtFine_color.png"
    mask_path = os.path.join(mask_folder, mask_filename)

    if not os.path.exists(mask_path):
      print(f"Mask not found for {image_path}, expected to find {mask_path}")
      continue

    # Read the mask in color (OpenCV loads as BGR)
    mask = cv2.imread(mask_path)
    if mask is None:
      print(f"Could not read mask {mask_path}, skipping")
      continue

    # Check if SEED color is present in the mask
    if np.any(np.all(mask == SEED_BGR, axis=-1)):
      # Copy image and mask to new folders
      shutil.copy(image_path, os.path.join(seeded_img_folder, os.path.basename(image_path)))
      shutil.copy(mask_path, os.path.join(seeded_mask_folder, os.path.basename(mask_path)))
      print(f"Copied seeded image and mask: {filename}")


# Separate out training samples from the seeded data for color finding
for i in range(1, NUM_BATCHES+1):
  img_folder = f"batch_{i}_seeded"
  mask_folder = f"batch_{i}_seeded_masks"

  # Create training folders if they don't exist
  training_folder = f"batch_{i}_training"
  training_folder_masks = f"batch_{i}_training_masks"
  os.makedirs(training_folder, exist_ok=True)
  os.makedirs(training_folder_masks, exist_ok=True)

  # Find all raw image files (JPG) and infer their names. sorted() matters: os.listdir order is
  # arbitrary, so even a seeded shuffle of it would not be reproducible.
  filenames = sorted(os.path.splitext(f)[0] for f in os.listdir(img_folder) if f.endswith(".jpg"))

  # Shuffle and split
  # We take training % from seeded images, which is ≤ training % of total images.
  # A per-batch RNG keeps each split independent of the other batches and of the global `random` state.
  random.Random(SPLIT_SEED + i).shuffle(filenames)
  num_training = int(len(filenames) * TRAINING_PERCENT)
  training_filenames = filenames[:num_training]

  for name in training_filenames:
    img_src = os.path.join(img_folder, f"{name}.jpg")
    img_dst = os.path.join(training_folder, f"{name}.jpg")

    mask_src = os.path.join(mask_folder, f"{name}_gtFine_color.png")
    mask_dst = os.path.join(training_folder_masks, f"{name}_gtFine_color.png")

    if os.path.exists(img_src) and os.path.exists(mask_src):
      shutil.move(img_src, img_dst)
      shutil.move(mask_src, mask_dst)
    else:
      print(f"Skipping {name} (missing image or mask)")


# TODO (FUTURE): if we're doing line-removal we should do that here on testing images to only do it once per each
#  ^ try border noise flood fill
#  ^ try anti prompt SAM? Probably don't have time for that right now

Copied seeded image and mask: Rep1_0%Sucrose_Col-0_1
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_10
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_11
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_12
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_13
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_2
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_3
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_4
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_5
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_6
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_7
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_8
Copied seeded image and mask: Rep1_0%Sucrose_Col-0_9
Copied seeded image and mask: Rep1_0%Sucrose_gaut10-3+gaut11-3_1
Copied seeded image and mask: Rep1_0%Sucrose_gaut10-3+gaut11-3_10
Copied seeded image and mask: Rep1_0%Sucrose_gaut10-3+gaut11-3_11
Copied seeded image and mask: Rep1_0%Sucrose_gaut10-3+gaut11-3_12
Copied seeded image and mask: Rep1_0%Sucrose

In [None]:
import os
import cv2
import numpy as np


# Deterministically get ordered file paths from folder for comparison matching
def get_file_paths(folder):
  """Return the paths of the regular files directly inside `folder`, sorted by file name.

  Sub-directories are excluded. Each path is `folder` joined with the entry name.
  """
  paths = []
  for entry_name in sorted(os.listdir(folder)):
    candidate = os.path.join(folder, entry_name)
    if os.path.isfile(candidate):
      paths.append(candidate)
  return paths


# Get the real masks given the paths for them
# color_bgrs usually consist of just one color bgr, but for the bespoke leaf/stem case for eval there's 2
def get_real_masks(image_paths, color_bgrs):
  """Build one boolean ground-truth mask per mask image.

  A pixel is True when its colour exactly equals any colour in `color_bgrs`.

  Args:
    image_paths: paths of colour mask images (e.g. `<name>_gtFine_color.png`), processed in order.
    color_bgrs: iterable of (B, G, R) colours to select.

  Returns:
    list of boolean arrays of shape (H, W), one per path.

  Raises:
    FileNotFoundError: if a mask image cannot be read (rather than an opaque cv2 error later).
  """
  masks = []

  for path in image_paths:
    mask_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if mask_bgr is None:
      raise FileNotFoundError(f"Could not read mask image: {path}")

    # Compare directly in OpenCV's native BGR order. Converting both the image and the target
    # colours to RGB first, as before, selects exactly the same pixels with extra work.
    combined_mask = np.zeros(mask_bgr.shape[:2], dtype=bool)
    for color_bgr in color_bgrs:
      combined_mask |= np.all(mask_bgr == np.asarray(color_bgr), axis=-1)

    masks.append(combined_mask)

  return masks


# Crash protection: per-part, per-batch storage of predicted and ground-truth masks, filled in by
# get_results so results from completed batches are still available if a later batch crashes.
from collections import defaultdict


def _mask_store_inner_level():
  return defaultdict(list)


def _mask_store_batch_level():
  return defaultdict(_mask_store_inner_level)


part_masks_by_batch_by_part = defaultdict(_mask_store_batch_level)
real_masks_by_batch_by_part = defaultdict(_mask_store_batch_level)


# Get masks and print results
# folder_qualifier should be "" if it is not present, otherwise "_<qualifier>"
def get_results(plant_part, folder_qualifier, color_bgr, eval_bgrs, get_part_masks, args=None):
  """Learn colour bounds, segment every test image of every batch, and print IoU statistics.

  For each batch i in 1..NUM_BATCHES:
    * colour bounds for `color_bgr` are learned from batch_{i}_training(_masks);
    * test images are the `.jpg` files in batch_{i}{folder_qualifier}. Images that are also in the
      training folder are excluded, and each remaining image is paired BY NAME with
      `<name>_gtFine_color.png` in batch_{i}{folder_qualifier}_masks;
    * get_part_masks(test_paths, lower, upper, args) produces predictions, which are scored
      against the ground truth for `eval_bgrs`.

  Args:
    plant_part: label used in output and as the key in the crash-protection stores.
    folder_qualifier: "" or "_<qualifier>" selecting the test folders.
    color_bgr: mask colour (BGR) whose HSV range is learned.
    eval_bgrs: mask colours (BGR) that count as ground truth.
    get_part_masks: one of the get_*_masks functions.
    args: forwarded unchanged to get_part_masks (defaults to an empty list).

  Returns:
    dict mapping each test-image path, across all batches, to its predicted mask.
  """
  if args is None:
    args = []

  print(f"Plant part: {plant_part}")
  mask_by_filename = {}

  for i in range(1, NUM_BATCHES+1):
    print(f"Batch_{i}")

    train_img_folder = f"batch_{i}_training"
    train_mask_folder = f"batch_{i}_training_masks"
    testing_img_folder = f"batch_{i}{folder_qualifier}"
    testing_mask_folder = f"batch_{i}{folder_qualifier}_masks"

    num_training = len([f for f in os.listdir(train_img_folder) if os.path.isfile(os.path.join(train_img_folder, f))])
    print(f"Number training batch_{i}: {num_training}")

    test_image_paths, mask_image_paths = _get_paired_test_paths(testing_img_folder, testing_mask_folder, train_img_folder)
    print(f"Number testing batch_{i}: {len(test_image_paths)}")

    # Get color bounds for our plant_part
    lower_color_bound, upper_color_bound = get_color_bounds(train_img_folder, train_mask_folder, color_bgr)
    print(f"Lower color bound: {lower_color_bound}")
    print(f"Upper color bound: {upper_color_bound}")

    part_masks = get_part_masks(test_image_paths, lower_color_bound, upper_color_bound, args)
    real_masks = get_real_masks(mask_image_paths, eval_bgrs)

    # Crash protection
    part_masks_by_batch_by_part[plant_part][i] = part_masks
    real_masks_by_batch_by_part[plant_part][i] = real_masks

    best, worst, q1, q2, q3, mean, std_dev = evaluate(real_masks, part_masks)
    print(f"Mean IoU: {mean:.4f}")
    print(f"Standard Deviation: {std_dev:.4f}")
    print(f"Best IoU: {best:.4f}")
    print(f"Worst IoU: {worst:.4f}")
    print(f"25th percentile (Q1): {q1:.4f}")
    print(f"Median (Q2): {q2:.4f}")
    print(f"75th percentile (Q3): {q3:.4f}")

    # zip() rather than indexing: if get_part_masks skipped unreadable images, its list is shorter
    # and indexing would raise IndexError
    mask_by_filename.update(zip(test_image_paths, part_masks))

  # Return only after the loop so every batch is evaluated, not just the first
  return mask_by_filename


# Return (image_paths, mask_paths) for one test folder, paired by name and excluding training images
def _get_paired_test_paths(testing_img_folder, testing_mask_folder, train_img_folder):
  # Pair by name: sorting the image and mask folders independently does not line them up.
  # For example, "x_1.jpg" sorts before "x_10.jpg", but "x_1_gtFine_color.png" sorts after
  # "x_10_gtFine_color.png".
  # Training images are excluded so they are never also scored as test images. The non-seeded
  # batch_N folders still contain the images that were moved from the seeded folders into training.
  training_names = {os.path.splitext(f)[0] for f in os.listdir(train_img_folder) if f.endswith(".jpg")}

  image_paths = []
  mask_paths = []
  for image_path in sorted(glob(os.path.join(testing_img_folder, "*.jpg"))):
    name = os.path.splitext(os.path.basename(image_path))[0]
    if name in training_names:
      continue
    mask_path = os.path.join(testing_mask_folder, f"{name}_gtFine_color.png")
    if not os.path.exists(mask_path):
      print(f"Warning: No matching mask found for {name}")
      continue
    image_paths.append(image_path)
    mask_paths.append(mask_path)

  return image_paths, mask_paths

In [None]:
# Get combo masks (leaf + stem + seed) - independently they don't do well, let's try them together
# Use SEED_BGR, that works well, and then just "_seeded" images because we use the seeded color
# combo_masks_by_filename = get_results("combo", "_seeded", SEED_BGR, [SEED_BGR, LEAF_BGR, STEM_BGR], get_combo_masks)

In [None]:
# Get seed results: learn the seed colour and evaluate only on the images whose masks contain a seed ("_seeded" folders)
seed_masks_by_filename = get_results(
  plant_part="seed",
  folder_qualifier="_seeded",
  color_bgr=SEED_BGR,
  eval_bgrs=[SEED_BGR],
  get_part_masks=get_seed_masks,
)

Plant part: seed
Batch_1
Number training batch_1: 73
Number testing batch_1: 292
Lower color bound: [ 27  10 120]
Upper color bound: [125  67 158]
batch_1_seeded/Rep1_0%Sucrose_Col-0_1.jpg
num contours = 0
num sam click points = 0
batch_1_seeded/Rep1_0%Sucrose_Col-0_10.jpg
num contours = 5
num sam click points = 1
batch_1_seeded/Rep1_0%Sucrose_Col-0_12.jpg
num contours = 1
num sam click points = 1
batch_1_seeded/Rep1_0%Sucrose_Col-0_13.jpg
num contours = 2
num sam click points = 1
batch_1_seeded/Rep1_0%Sucrose_Col-0_2.jpg
num contours = 6
num sam click points = 1
batch_1_seeded/Rep1_0%Sucrose_Col-0_3.jpg
num contours = 0
num sam click points = 0
batch_1_seeded/Rep1_0%Sucrose_Col-0_4.jpg
num contours = 1
num sam click points = 1
batch_1_seeded/Rep1_0%Sucrose_Col-0_5.jpg
num contours = 2
num sam click points = 1
batch_1_seeded/Rep1_0%Sucrose_Col-0_6.jpg
num contours = 4
num sam click points = 1
batch_1_seeded/Rep1_0%Sucrose_Col-0_8.jpg
num contours = 3
num sam click points = 1
batch_1_se

In [None]:
# Get root results
# root_masks_by_filename = get_results("root", "", ROOT_BGR, [ROOT_BGR], get_root_masks)

In [None]:
# Get stem results
# stem_masks_by_filename = get_results("stem", "", STEM_BGR, [STEM_BGR], get_stem_masks)

In [None]:
# IMPORTANT: the commented-out variant below uses seed_masks_by_filename, root_masks_by_filename and
# stem_masks_by_filename, so run the seed, root and stem cells first if you enable it.

# Get leaf results
# The commented-out variant passes the other parts' masks as args so they can be removed from the
# leaf masks, and also evaluates against LEAF_BGR + STEM_BGR together (STEM_BGR as the extra eval colour).
# leaf_masks = get_results("leaf", "", LEAF_BGR, [LEAF_BGR, STEM_BGR], get_leaf_masks, args=[seed_masks_by_filename, root_masks_by_filename, stem_masks_by_filename])
leaf_masks = get_results(
  plant_part="leaf",
  folder_qualifier="",
  color_bgr=LEAF_BGR,
  eval_bgrs=[LEAF_BGR],
  get_part_masks=get_leaf_masks,
)

Plant part: leaf
Batch_1
Number training batch_1: 73
Number testing batch_1: 367
Lower color bound: [ 13  11 130]
Upper color bound: [ 43  71 178]
batch_1/Rep1_0%Sucrose_Col-0_1.jpg
num contours = 1
num sam click points = 1
batch_1/Rep1_0%Sucrose_Col-0_10.jpg
num contours = 2
num sam click points = 1
batch_1/Rep1_0%Sucrose_Col-0_11.jpg
num contours = 0
num sam click points = 0
batch_1/Rep1_0%Sucrose_Col-0_12.jpg
num contours = 0
num sam click points = 0
batch_1/Rep1_0%Sucrose_Col-0_13.jpg
num contours = 1
num sam click points = 1
batch_1/Rep1_0%Sucrose_Col-0_2.jpg
num contours = 0
num sam click points = 0
batch_1/Rep1_0%Sucrose_Col-0_3.jpg
num contours = 0
num sam click points = 0
batch_1/Rep1_0%Sucrose_Col-0_4.jpg
num contours = 0
num sam click points = 0
batch_1/Rep1_0%Sucrose_Col-0_5.jpg
num contours = 0
num sam click points = 0
batch_1/Rep1_0%Sucrose_Col-0_6.jpg
num contours = 0
num sam click points = 0
batch_1/Rep1_0%Sucrose_Col-0_7.jpg
num contours = 0
num sam click points = 0
ba

In [None]:
# TODO: visualize the real mask next to the predicted mask, together with the statistics, for an individual image too