In [None]:
import random
import time
from collections import deque
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [2]:
class Object:
    """A detected blob: an identifier, a display color and its pixel set."""

    def __init__(self, identifier, color):
        """Create an empty object.

        Args:
            identifier: value stored as ``self.id``.
            color: sequence of channel values; stored as a ``uint8`` array.
        """
        self.color = np.array(color, dtype=np.uint8)
        self.points = set()  # (x, y) tuples belonging to this object
        self.id = identifier

    def get_color(self):
        """Return the stored color array (the same array, not a copy)."""
        return self.color

    def add_point(self, x, y):
        """Record ``(x, y)`` as part of this object; duplicates are ignored."""
        point = (x, y)
        self.points.add(point)

class NumpyImage:
    """Pixel array wrapper that also carries a per-pixel visited mask.

    Coordinate convention: every method takes ``(x, y)`` where ``x`` is the
    row index (``0 <= x < height``) and ``y`` is the column index
    (``0 <= y < width``), i.e. a pixel is addressed as ``img[x, y]``.
    """

    def __init__(self, img):
        """Wrap ``img`` and allocate an all-False visited mask.

        Args:
            img: array whose first two axes are (rows, columns). The color
                helpers compare ``img[x, y]`` against 3-element RGB lists.
        """
        self.img = img
        self.shape = img.shape
        self.width = img.shape[1]
        self.height = img.shape[0]
        # Boolean mask with the same (rows, columns) layout as ``img``.
        # It is a public attribute: callers index it directly.
        self.visited = np.zeros((self.height, self.width), dtype=bool)

    def is_visited(self, x, y):
        """Return whether pixel ``(x, y)`` is marked in the visited mask.

        This accessor used to be named ``visited``, but that name is taken by
        the mask attribute set in ``__init__``, which made the method
        impossible to call on an instance.
        """
        return self.visited[x, y]

    def is_background(self, x, y):
        """Check if pixel is background (255,255,255)"""
        return np.array_equal(self.img[x, y], [255, 255, 255])

    def is_red(self, x, y):
        """Check if pixel is red (255,0,0)"""
        return np.array_equal(self.img[x, y], [255, 0, 0])

    def get_neighbors(self, x, y):
        """Return the in-bounds 4-connected neighbors of ``(x, y)``.

        ``x`` is bounded by the image height and ``y`` by its width, matching
        the ``img[x, y]`` addressing used by the other methods. Order is
        up, down, left, right (``x-1``, ``x+1``, ``y-1``, ``y+1``).
        """
        neighbors = []
        if x > 0:
            neighbors.append((x - 1, y))
        if x < self.height - 1:
            neighbors.append((x + 1, y))
        if y > 0:
            neighbors.append((x, y - 1))
        if y < self.width - 1:
            neighbors.append((x, y + 1))
        return neighbors

    def paint_pixel(self, x, y, color):
        """Overwrite pixel ``(x, y)`` of the wrapped array with ``color``."""
        self.img[x, y] = color

class Detections:
    """Groups connected red pixels of an image into ``Object`` instances.

    Coordinates follow the wrapped image's ``img[x, y]`` addressing: ``x`` is
    the row index and ``y`` the column index. The visited mask is indexed the
    same way, so images of any aspect ratio are handled.
    """

    def __init__(self, img: NumpyImage):
        """Start with no detected objects for the given image wrapper."""
        self.objs_img = img
        self.objects: dict[int, Object] = {}  # object ID -> Object
        self.points_of_objects: dict[tuple, Object] = {}  # (x, y) -> owning Object

    def get_objects_size(self):
        """Return how many objects have been registered so far."""
        return len(self.objects)

    def get_object(self, x, y):
        """Return the object owning pixel ``(x, y)``, or ``None`` if unowned."""
        return self.points_of_objects.get((x, y), None)

    def add_object(self, x, y):
        """Register a new randomly colored object seeded at ``(x, y)``.

        The identifier is the current object count plus one.

        Returns:
            The newly created ``Object``.
        """
        count = self.get_objects_size() + 1
        obj = Object(identifier=count, color=[random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)])
        self.objects[count] = obj
        self.points_of_objects[(x, y)] = obj
        return obj

    def _neighbors(self, x, y):
        """Return in-bounds 4-connected neighbors of row ``x``, column ``y``.

        Bounds come from the image's own height (rows) and width (columns) so
        that they agree with the ``[x, y]`` indexing used for both the pixel
        array and the visited mask.
        """
        height = self.objs_img.height
        width = self.objs_img.width
        neighbors = []
        if x > 0:
            neighbors.append((x - 1, y))
        if x < height - 1:
            neighbors.append((x + 1, y))
        if y > 0:
            neighbors.append((x, y - 1))
        if y < width - 1:
            neighbors.append((x, y + 1))
        return neighbors

    def iterative_flood_fill(self, x, y): # BFS
        """Create an object at ``(x, y)`` and grow it over connected red pixels.

        Each pixel taken from the queue is marked visited, painted with the
        object's color and recorded as belonging to the object. Only
        unvisited neighbors for which ``is_red`` holds are enqueued.
        """
        image = self.objs_img
        obj = self.add_object(x, y)

        # deque gives O(1) pops from the front; list.pop(0) is O(n) per pop.
        queue = deque([(x, y)])

        while queue:
            x, y = queue.popleft()

            # A pixel can be enqueued by several neighbors before it is
            # processed; only the first dequeue does any work.
            if image.visited[x, y]:
                continue

            image.visited[x, y] = True

            image.paint_pixel(x=x, y=y, color=obj.color)
            self.points_of_objects[(x, y)] = obj
            obj.add_point(x, y)

            for neighbor_x, neighbor_y in self._neighbors(x, y):
                if not image.visited[neighbor_x, neighbor_y] and image.is_red(neighbor_x, neighbor_y):
                    queue.append((neighbor_x, neighbor_y))

    def traverse_img(self):
        """Scan the image row by row and flood-fill every unvisited seed.

        A seed is any pixel that is not background and has not been visited
        yet; the fill itself only spreads into red pixels.

        Returns:
            The wrapped pixel array, painted in place with object colors.
        """
        image = self.objs_img
        for x in range(image.height):
            for y in range(image.width):
                # Check the mask first: it is a cheap lookup, while the
                # background test compares a whole pixel.
                if not image.visited[x, y] and not image.is_background(x, y):
                    self.iterative_flood_fill(x, y)

        return image.img

In [15]:

filename = "blob_output.png"
# Load the image
img_path = f"/home/lrn/Repos/flood-fill-cuda/{filename}"
# The pixel checks compare against 3-element RGB values, so force RGB here
# (an RGBA or palette PNG would otherwise never match). The context manager
# releases the underlying file once the pixels are copied out.
with Image.open(img_path) as img:
    img_array = np.array(img.convert("RGB"))  # Convert to numpy array


In [None]:
# Create NumpyImage on a copy: the traversal paints pixels in place, and
# `img_array` is still needed unmodified for the "before" plot.
numpy_img = NumpyImage(img_array.copy())

# Then create Detections with the NumpyImage
detections = Detections(numpy_img)

# perf_counter() is monotonic and high resolution, so the measured duration
# cannot be skewed by system clock adjustments the way time.time() can.
start_time = time.perf_counter()
# Now call traverse_img
result = detections.traverse_img()
elapsed_seconds = time.perf_counter() - start_time
print(f"Time elapsed: {elapsed_seconds} seconds")

In [17]:
# Create a figure with two subplots side by side
fig, (ax_before, ax_after) = plt.subplots(1, 2, figsize=(20, 10))

# Plot image before flood fill
ax_before.imshow(img_array)
ax_before.set_title('Image before flood fill')
ax_before.axis('off')

# Plot image after flood fill
ax_after.imshow(result)
ax_after.set_title('Image after flood fill')
ax_after.axis('off')

fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
results_dir = Path('./results')
results_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(results_dir / f'objects_{filename}')
plt.close(fig)

In [18]:
# Number of distinct objects registered by the traversal (one per flood-fill seed).
detections.get_objects_size()

3183

In [1]:
import numpy as np
from PIL import Image
import random
from collections import deque

# Constants
RED_THRESHOLD = 200  # Red channel must be strictly above this to count as red
NON_RED_MAX = 50  # Green and blue channels must be strictly below this
WHITE_BG = (255, 255, 255)

def is_red(pixel, red_min=RED_THRESHOLD, other_max=NON_RED_MAX):
    """Return whether an RGB pixel counts as red.

    Args:
        pixel: iterable of exactly three channel values ``(r, g, b)``.
        red_min: the red channel must be strictly greater than this.
        other_max: green and blue must both be strictly less than this.

    Returns:
        Truthy when ``r > red_min`` and both ``g`` and ``b`` are below
        ``other_max``.
    """
    r, g, b = pixel
    return r > red_min and g < other_max and b < other_max

def bfs_flood_fill(image, visited, start_x, start_y, new_color):
    """Recolor the 4-connected red region that contains the start pixel.

    ``image`` and ``visited`` are both modified in place and indexed as
    ``[start_x, start_y]``, so ``start_x`` runs along the first axis and
    ``start_y`` along the second. The start pixel is painted unconditionally;
    neighbors join the region only if ``is_red`` accepts them.

    Args:
        image: pixel array; its first two axes give the bounds.
        visited: boolean mask indexed like ``image``; set to True for every
            pixel that gets queued.
        start_x: first-axis index of the seed pixel.
        start_y: second-axis index of the seed pixel.
        new_color: value written to every pixel of the region.
    """
    n_rows, n_cols = image.shape[:2]
    steps = ((0, 1), (1, 0), (0, -1), (-1, 0))

    # Pixels are flagged when queued, so each one enters the queue once.
    visited[start_x, start_y] = True
    pending = deque()
    pending.append((start_x, start_y))

    while len(pending) > 0:
        row, col = pending.popleft()
        image[row, col] = new_color

        for d_row, d_col in steps:
            n_row = row + d_row
            n_col = col + d_col
            if n_row < 0 or n_row >= n_rows:
                continue
            if n_col < 0 or n_col >= n_cols:
                continue
            if visited[n_row, n_col]:
                continue
            if not is_red(image[n_row, n_col]):
                continue
            visited[n_row, n_col] = True
            pending.append((n_row, n_col))

def process_image(image_path):
    """Recolor every red blob in an image file and count the blobs.

    The image is converted to RGB, scanned along its first axis then its
    second, and each unvisited red pixel seeds a ``bfs_flood_fill`` with a
    freshly drawn random color.

    Args:
        image_path: path passed to ``Image.open``.

    Returns:
        ``(image, blob_count)``: a new PIL image built from the recolored
        pixels, and the number of flood fills that were started.
    """
    # Load image; the context manager releases the file handle as soon as
    # the RGB pixels have been copied into the array.
    with Image.open(image_path) as img:
        output = np.array(img.convert('RGB'), dtype=np.uint8)
    height, width = output.shape[:2]

    # Initialize visited array
    visited = np.zeros((height, width), dtype=bool)
    blob_count = 0

    # Scan for blobs
    for x in range(height):
        for y in range(width):
            if is_red(output[x, y]) and not visited[x, y]:
                # Channels drawn from 0-254 can never produce pure white.
                # The color may still satisfy is_red; that is harmless
                # because filled pixels are marked visited and never reseed.
                new_color = (random.randint(0, 254), random.randint(0, 254), random.randint(0, 254))
                bfs_flood_fill(output, visited, x, y, new_color)
                blob_count += 1

    return Image.fromarray(output), blob_count

# Run the CPU pipeline on a sample image and save the recolored result.
if __name__ == "__main__":
    # Replace with your image path
    image_path = "/home/lrn/Repos/flood-fill-cuda/blob_output.png"

    result_img, blob_count = process_image(image_path)

    print(f"Detected {blob_count} blobs")
    result_img.save("colored_blobs_cpu_mvp.png")

Detected 3183 blobs
