First, import and load the image.

In [None]:
# Imports
from PIL import Image
import colorsys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps as cm

# Import the image
# Convert to RGB so every pixel is an (r, g, b) triple whatever the file's native
# mode is (grayscale, palette, CMYK, ...); the cells below unpack three channels.
IMAGE_PATH = "image324x324inside.jpg"
# IMAGE_PATH = "image324x324outside.jpg"
img = Image.open(IMAGE_PATH).convert("RGB")


# Define constants
# Hues are quantized as int(h * 255) in the cells below, so the hue-related
# constants (MAGIC_HUE, CLUSTER_RADIUS, CLUSTER_SELECTION_DIST) are expressed in
# those 0-255 hue bins, not in degrees. Saturation/value limits are compared
# against the 0.0-1.0 floats returned by colorsys.rgb_to_hsv.
MAGIC_HUE = 60                       # Hue for green plant
CLUSTER_RADIUS = 10                  # Radius for mean shift clustering
MAX_ITERATIONS = 20                  # Max iterations for search
MIN_PIXELS_OF_COLOR = 500            # Minimum pixels to cluster a color
MIN_SATURATION = 0.15                # Minimum saturation to consider a pixel
MIN_VALUE = 0.1                      # Minimum value to consider a pixel
MAX_VALUE = 0.93                     # Maximum value to consider a pixel
CLUSTER_SELECTION_DIST = 20          # Max distance from MAGIC_HUE to select a cluster

Now we will plot the image's frequently occurring colors in HSV space, drawing each point in its own RGB color, to get an idea of the color distribution.

In [None]:
# Fill a dict with rgb value counts
plot_rgb = dict()
for x in range(img.width):
    for y in range(img.height):
        p = img.getpixel((x, y))
        r, g, b = p[:3]
        if plot_rgb.get(f"{r},{g},{b}") is None:
            plot_rgb[f"{r},{g},{b}"] = 1
        else:
            plot_rgb[f"{r},{g},{b}"] += 1

# Convert dict to list of points
all_rgb = [list(map(int, k.split(","))) for k, v in plot_rgb.items() if v > 10]
plot_hsv = []
colors = []
for r, g, b in all_rgb:
    h, s, v = colorsys.rgb_to_hsv(r/255.0, g/255.0, b/255.0)
    if s < MIN_SATURATION or v < MIN_VALUE or v > MAX_VALUE:
        continue
    h = int(h * 255)
    s = int(s * 255)
    v = int(v * 255)
    plot_hsv.append((h, s, v))
    colors.append(f"#{r:02x}{g:02x}{b:02x}")

# Plot the RGB colors
%matplotlib qt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*zip(*plot_hsv), c=colors, marker='o')


<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x173f8c3cf50>

Now, use mean shift clustering to find clusters of certain hues.

In [None]:
# Start from an empty cluster list and grab direct pixel access to the image
means = []
pixels = img.load()

# Build a 256-bin histogram of hue: clustering works on hue alone, so pixel
# positions are irrelevant here and only the per-hue counts are kept.
pixels_hue = np.zeros(256)
for x in range(img.width):
    for y in range(img.height):
        r, g, b = pixels[x, y][:3]
        h, s, v = colorsys.rgb_to_hsv(r/255.0, g/255.0, b/255.0)
        # Count only pixels that are saturated enough and neither too dark nor too bright
        if s >= MIN_SATURATION and MIN_VALUE <= v <= MAX_VALUE:
            pixels_hue[int(h * 255)] += 1

# Zero out hue bins that hold fewer than MIN_PIXELS_OF_COLOR pixels
pixels_hue = [ 0 if n < MIN_PIXELS_OF_COLOR else n for n in pixels_hue ]

# Register one more starting point that converged on hue `x`
def insert_mean(x):
    """Bump the count of the cluster at hue `x`, creating it with count 1 if new.

    Operates on the notebook-level `means` list, whose entries are
    two-element lists of the form [hue, count]. Returns None.
    """
    existing = next((entry for entry in means if entry[0] == x), None)
    if existing is None:
        means.append([x, 1])
    else:
        existing[1] += 1

# List the discrete hue bins that lie within `radius` of `x` (a 1D interval)
def points_in_radius(pixels, x, radius):
    """Return the integers from x - radius to x + radius, clamped to 0..255.

    `pixels` is accepted but never read. The interval is cut off at the ends
    of the 0..255 range; it does not wrap around. The result is an ascending
    list, and is empty when the clamped bounds cross.
    """
    lowest = max(0, x - radius)
    highest = min(255, x + radius)
    return list(range(lowest, highest + 1))

# Keep only the points whose histogram bin is non-empty
def filter_points(pixels, points):
    """Return the entries of `points` for which `pixels[point] > 0`.

    `pixels` is indexed by each point; input order (and any duplicates) is
    preserved in the returned list.
    """
    return [px for px in points if pixels[px] > 0]

# Find the mean for a given point
# This involves finding all points within CLUSTER_RADIUS, averaging them, and repeating until convergence
def find_mean(pixels, x):
    """Run a 1D mean shift over the hue histogram, starting at bin `x`.

    Each iteration takes the occupied bins (pixels[bin] > 0) inside the window
    of CLUSTER_RADIUS around the current position, moves to the integer
    (floor) average of their indices, and repeats until the position stops
    changing or MAX_ITERATIONS is reached.

    The average is over bin indices only; it is not weighted by the bin counts.
    The window comes from points_in_radius, which clamps at 0 and 255, so the
    search does not wrap across the ends of the hue range.

    Parameters:
        pixels: sequence indexed by hue bin, holding a count per bin.
        x: starting hue bin.

    Returns:
        The hue bin the search settled on. If the window around the current
        position holds no occupied bin, the current position is returned.
    """
    # Start with the initial point
    mx = x

    for iteration in range(MAX_ITERATIONS):
        # Re-centre the window on the current mean every iteration; keeping the
        # initial window would stop the search after a single step.
        in_radius = points_in_radius(pixels, mx, CLUSTER_RADIUS)
        colors_in_radius = filter_points(pixels, in_radius)

        # Nothing to average (would divide by zero): stay where we are
        if not colors_in_radius:
            break

        # Calculate the new mean
        nmx = sum(colors_in_radius) // len(colors_in_radius)

        # If the mean hasn't changed, we're done
        if nmx == mx:
            break

        mx = nmx

    return mx

# Tie it all together and do the mean shift clustering
count = 0
for x, v in enumerate(pixels_hue):
    count += 1
    if count % 100 == 0:
        print(f"Processed {count} of {len(pixels_hue)} colors")

    if v == 0:
        continue

    mx = find_mean(pixels_hue, x)
    insert_mean(mx)
    
%matplotlib qt
plt.figure()
plt.plot(range(256), pixels_hue)

# print(means)

Processed 100 of 256 colors
Processed 200 of 256 colors


[<matplotlib.lines.Line2D at 0x173f8cb9a90>]

In [None]:
# Fill a dict with rgb value counts
plot_hsv = [(mean[0], 128, 128) for mean in means]
plot_rgb = [tuple(map(lambda c: int(c * 255), colorsys.hsv_to_rgb(h/255.0, s/255.0, v/255.0))) for h, s, v in plot_hsv]
colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in plot_rgb]
sizes = [mean[1]*10 for mean in means]

# Plot the RGB colors
%matplotlib qt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*zip(*plot_hsv), c=colors, marker='o', s=sizes)
ax.set_xlabel('Hue')
ax.set_ylabel('Saturation')
ax.set_zlabel('Value')

max_mean = max([ mean[1] for mean in means ])
print(f"{len(means)} clusters, max of {max_mean} points in a cluster")

10 clusters, max of 6 points in a cluster


Now we select green clusters.

In [None]:
# Keep the clusters whose mean hue lies within CLUSTER_SELECTION_DIST of MAGIC_HUE.
# The distance is a plain difference of hue bins (no wrap-around at 0/255).
selected_clusters = list(
    filter(lambda cluster: abs(cluster[0] - MAGIC_HUE) <= CLUSTER_SELECTION_DIST, means)
)
print(f"Selected {len(selected_clusters)} clusters near hue {MAGIC_HUE}")

Selected 4 clusters near hue 60


Now re-run the clustering for every pixel, this time recording in a boolean map which pixels fall into one of the selected clusters.

In [None]:
# Boolean mask indexed [x, y]; True where the pixel belongs to a selected cluster
plant_pixels = np.zeros((img.width, img.height), dtype=bool)

# Which cluster a pixel lands in depends only on its hue bin, so resolve each
# bin once up front instead of re-running the mean shift (and rebuilding the
# list of selected hues) for every pixel.
selected_hues = {mean[0] for mean in selected_clusters}
hue_is_plant = [
    bool(n != 0 and find_mean(pixels_hue, hue) in selected_hues)
    for hue, n in enumerate(pixels_hue)
]

for x in range(img.width):
    for y in range(img.height):
        r, g, b = pixels[x, y][:3]
        h, s, v = colorsys.rgb_to_hsv(r/255.0, g/255.0, b/255.0)
        # Same saturation/value limits as when the histogram was built
        if s < MIN_SATURATION or v < MIN_VALUE or v > MAX_VALUE:
            continue
        # Bins removed by the low-count filter are already False in the table
        plant_pixels[x, y] = hue_is_plant[int(h * 255)]

Now display the results and overlay the found pixels onto the original image.

In [None]:
# Show results
# Turn the boolean mask into a 0/255 image. The mask is indexed [x, y], so it is
# transposed before being handed to PIL.
first_pass_img_array = plant_pixels.astype(np.uint8) * 255
img_first_pass = Image.fromarray(first_pass_img_array.transpose())

# Paint every detected plant pixel red on a copy of the original image
overlay_img = img.copy()
for x, y in np.argwhere(first_pass_img_array):
    overlay_img.putpixel((int(x), int(y)), (255, 0, 0))

img_first_pass.show()
overlay_img.show()