# Boundary Detection

```{note}
This post extends concepts discussed in [a previous post about pixels as vectors](images).
You don't *need* to read that one first.
```

## Overview

There are plenty of tools available for performing image processing,
but let's say that we want to implement our own *boundary detection* algorithm.
How would you go about that?

Now, I'm not an image processing wizard, but my guess is that you would
need to compare each pixel's value with the values of its neighbours.
If the values are similar enough, then the pixels might belong
to the same object; if they are different enough, then they might belong
to different objects. This seems simple, but in order to really
solve this problem we need to specify a few more things:

* What is our method for *comparing* pixel values?
* How do we define a *threshold* for deciding whether pixels belong to the same object?
* How do we determine a pixel's *neighbours*?
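To make the idea concrete before bringing in vectors, here is a one-dimensional toy version: a single row of greyscale intensities, split into segments wherever two neighbouring values differ by more than a threshold. The function name `segment_row` and the sample values are hypothetical, and a plain absolute difference stands in for the comparison method we have not chosen yet.

```python
def segment_row(values, threshold):
    """Group consecutive pixel values into segments (hypothetical helper).

    A new segment starts whenever two neighbours differ by more than
    `threshold`; otherwise the pixel joins the current segment.
    """
    if not values:
        return []
    segments = [[values[0]]]
    for prev, curr in zip(values, values[1:]):
        if abs(curr - prev) > threshold:
            segments.append([curr])  # too different: start a new "object"
        else:
            segments[-1].append(curr)  # similar enough: same "object"
    return segments


row = [0.10, 0.12, 0.11, 0.80, 0.82, 0.15]
print(segment_row(row, threshold=0.1))
```

Raising the threshold merges segments; with a threshold of 1.0 the whole row becomes a single "object".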


## Comparing Pixel Values

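Imagine that we treat every pixel in an image as a vector (for example, of its red, green and blue values).
Vectors have both _magnitude_ and _direction_.
The [dot product](https://en.wikipedia.org/wiki/Dot_product)
of two vectors is a single number: the product of their lengths and the cosine of the angle between them.
If we first scale both vectors to unit length, the dot product is simply the cosine of that angle,
which is 1 when the vectors point the same way and 0 when they are perpendicular.
Could we use this simple measurement as a means for comparing pixel values?
Because it ignores length, we will also want to compare the pixels' magnitudes.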
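A quick check of the dot product on a few RGB colours. This is a sketch assuming pixel vectors with values in [0, 1]; the `cosine_similarity` helper and the colour values are hypothetical. Note the grey/white pair: they point in exactly the same direction, so the dot product of unit vectors cannot tell them apart, and only their magnitudes differ.

```python
import numpy as np


def cosine_similarity(p1, p2):
    """Dot product of two pixel vectors after scaling each to unit length.

    This equals cos(theta): 1.0 for parallel vectors, 0.0 for perpendicular ones.
    """
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    return float(np.dot(p1 / np.linalg.norm(p1), p2 / np.linalg.norm(p2)))


red = [1.0, 0.0, 0.0]
dark_red = [0.5, 0.0, 0.0]
green = [0.0, 1.0, 0.0]
grey = [0.5, 0.5, 0.5]
white = [1.0, 1.0, 1.0]

print(cosine_similarity(red, dark_red))  # 1.0: same direction, different length
print(cosine_similarity(red, green))     # 0.0: perpendicular
print(cosine_similarity(grey, white))    # ~1.0: greys all share one direction

# Recover the angle itself (clip guards against rounding just outside [-1, 1]).
angle_red_grey = np.degrees(np.arccos(np.clip(cosine_similarity(red, grey), -1.0, 1.0)))
print(angle_red_grey)
```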

## Threshold 

The *dot product* returns a single value (i.e. a *float*).
Could our logic be as simple as comparing this value against a threshold?
We would, of course, need a way to *tune* that threshold for a given image.
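One way to phrase this, sketched under the assumption that we threshold the *direction distance* 1 - cos(theta) (0 for identical directions), is shown below. The names `direction_distance` and `same_object` and the two orange shades are hypothetical. Notice how small the distance is for similar colours, which is why useful thresholds tend to be small numbers.

```python
import numpy as np


def direction_distance(p1, p2):
    """1 - cos(theta) between two pixel vectors.

    0.0 when they point the same way, growing as the angle widens.
    """
    u1 = np.asarray(p1, dtype=float) / np.linalg.norm(p1)
    u2 = np.asarray(p2, dtype=float) / np.linalg.norm(p2)
    return 1.0 - float(np.dot(u1, u2))


def same_object(p1, p2, dot_threshold):
    # Treat the pixels as one object while the distance stays under the threshold.
    return direction_distance(p1, p2) <= dot_threshold


orange = [1.0, 0.5, 0.0]
red_orange = [1.0, 0.4, 0.0]

# Tuning the threshold changes the verdict for the same pair of pixels.
for thr in (0.001, 0.01, 0.1):
    print(thr, same_object(orange, red_orange, dot_threshold=thr))
```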

## Finding Neighbouring Pixels

For this we will use one of my favourite techniques: [graph processing with networkx](https://networkx.org/documentation/stable/index.html)!
Each pixel becomes a node, connected by an edge to its neighbours above, below, to the left and to the right.
If we determine that two pixels belong to the same object,
then we keep the edge between them.
If we determine that they belong to different objects,
then we sever it. The groups of pixels that remain connected are our objects.
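A minimal sketch of the keep-or-sever idea on a hypothetical 2×3 image of single intensities. It uses networkx's `grid_2d_graph`, which connects each `(row, col)` node to its vertical and horizontal neighbours, and `connected_components` to read off the groups that survive. The 0.5 cut-off is arbitrary.

```python
import networkx as nx

# A tiny 2x3 "image" of single intensities (hypothetical data).
pixels = {
    (0, 0): 0.1, (0, 1): 0.1, (0, 2): 0.9,
    (1, 0): 0.1, (1, 1): 0.1, (1, 2): 0.9,
}

# One node per (row, col), with edges to up/down/left/right neighbours.
g_demo = nx.grid_2d_graph(2, 3)

# Sever every edge whose endpoints differ by more than 0.5.
# The edges are collected first so the graph is not modified while iterating.
to_remove = [(a, b) for a, b in g_demo.edges if abs(pixels[a] - pixels[b]) > 0.5]
g_demo.remove_edges_from(to_remove)

# Each connected component is one "object".
components = [sorted(c) for c in nx.connected_components(g_demo)]
print(components)
```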


```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import networkx as nx
```

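```python
import io
from urllib.request import urlopen

img_url = "https://github.com/scikit-image/scikit-image/blob/main/skimage/data/phantom.png?raw=true"
# Newer matplotlib versions no longer accept a URL in imread,
# so download the bytes first and pass them as a file-like object.
with urlopen(img_url) as response:
    img = mpimg.imread(io.BytesIO(response.read()), format="png")
# Uncomment to keep every 5th pixel in each direction (faster graph building and drawing).
# img = img[::5, ::5, :]
plt.imshow(img)
plt.show()
```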

```python
def create_graph(img: np.ndarray) -> nx.Graph:
    """Build a 4-connected lattice: one node per pixel and one edge per
    pair of horizontally or vertically adjacent pixels."""
    nrows, ncols, _ = img.shape
    g = nx.Graph(nrows=nrows, ncols=ncols, img=img)
    for r in range(nrows):
        for c in range(ncols):
            # Store the pixel vector; "group" is filled in later.
            g.add_node((r, c), color=img[r, c], group=None)

    for r in range(nrows):
        for c in range(ncols):
            # Linking each pixel to its right and lower neighbour
            # covers every adjacent pair exactly once.
            if c + 1 < ncols:
                g.add_edge((r, c), (r, c + 1))
            if r + 1 < nrows:
                g.add_edge((r, c), (r + 1, c))

    return g


def draw_graph(g, node_size=1):
    # Place node (row, col) at x=col, y=-row so the image is drawn upright.
    pos = {n: (n[1], -n[0]) for n in g.nodes}
    # Colour by group once groups are assigned, otherwise by pixel colour.
    node_color = [g.nodes[n]["group"] for n in g.nodes]
    if any(nc is None for nc in node_color):
        node_color = [g.nodes[n]["color"] for n in g.nodes]
    nx.draw_networkx(
        g,
        pos=pos,
        node_size=node_size,
        node_color=node_color,
        with_labels=False,
    )


g = create_graph(img)
draw_graph(g, node_size=10)
```
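For images at least 3×3 pixels, the loops in `create_graph` produce the standard 4-neighbour lattice, so an `nrows × ncols` image should give `nrows*(ncols-1) + ncols*(nrows-1)` edges. As a cross-check, networkx ships `nx.grid_2d_graph`, which builds the same lattice. Below is a sketch of an equivalent constructor; `create_graph_grid` is a hypothetical name, and the random 4×5 RGB array stands in for a real image.

```python
import numpy as np
import networkx as nx


def create_graph_grid(img: np.ndarray) -> nx.Graph:
    """Hypothetical alternative constructor: the same 4-neighbour lattice,
    built with networkx's grid_2d_graph helper."""
    nrows, ncols = img.shape[:2]
    g = nx.grid_2d_graph(nrows, ncols)  # nodes are (row, col) tuples
    g.graph.update(nrows=nrows, ncols=ncols, img=img)
    # Attach each pixel vector as the node's "color" attribute.
    nx.set_node_attributes(g, {(r, c): img[r, c] for r, c in g.nodes}, "color")
    # A non-dict value is assigned to every node.
    nx.set_node_attributes(g, None, "group")
    return g


toy = np.random.default_rng(0).random((4, 5, 3))
g_toy = create_graph_grid(toy)
print(g_toy.number_of_nodes(), g_toy.number_of_edges())
```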

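```python
def compute_vector_metrics(v):
    """Return (unit vector, magnitude); a zero vector is returned unchanged."""
    mag = np.linalg.norm(v)
    if mag == 0.0:
        return v, mag
    unit = v / mag
    return unit, mag


def are_pixels_connected(
    p1, p2, mag_threshold: float, dot_threshold: float
) -> bool:
    p1u, p1m = compute_vector_metrics(p1)
    p2u, p2m = compute_vector_metrics(p2)

    # For unit vectors the dot product is cos(theta): 1 means same direction.
    dot = np.dot(p1u, p2u)

    # Disconnect only if the pixels differ in BOTH magnitude and direction.
    if mag_threshold < abs(p1m - p2m) and dot_threshold < abs(1 - dot):
        return False

    return True


def determine_connected_pixels(g: nx.Graph, **kwargs):
    g = g.copy()
    # Collect the edges to cut first: removing edges while iterating
    # over g.edges would modify the graph mid-iteration.
    to_remove = []
    for s, t in g.edges:
        ps = g.nodes[s]["color"]
        pt = g.nodes[t]["color"]
        if not are_pixels_connected(ps, pt, **kwargs):
            to_remove.append((s, t))
    g.remove_edges_from(to_remove)

    # Label every pixel with the index of its connected component.
    for i, nodes in enumerate(nx.connected_components(g)):
        for n in nodes:
            g.nodes[n]["group"] = i

    return g


# Infinite thresholds never cut an edge, so the whole image is one component.
g_conn = determine_connected_pixels(
    g, dot_threshold=np.inf, mag_threshold=np.inf
)
print(nx.number_connected_components(g_conn))
draw_graph(g_conn, node_size=5)
```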
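Because `are_pixels_connected` cuts an edge only when two pixels differ in *both* magnitude and direction, it is worth checking what the direction test sees in a greyscale image such as the phantom. The sketch below (helper names and grey values are hypothetical, and it measures direction as 1 - cos(theta)) compares a dark and a light grey. As plain RGB vectors they share a direction, so only the magnitude differs; with a constant fourth (alpha) channel appended, the directions no longer coincide. Whether this matters here depends on how many channels `mpimg.imread` returns for the image, which `img.shape` will tell you.

```python
import numpy as np


def unit_and_norm(v):
    """Return (unit vector, magnitude); zero vectors are left unscaled."""
    v = np.asarray(v, dtype=float)
    mag = np.linalg.norm(v)
    return (v / mag if mag else v), mag


def describe_pair(p1, p2):
    """Return (magnitude difference, 1 - cos(theta)) for two pixel vectors."""
    u1, m1 = unit_and_norm(p1)
    u2, m2 = unit_and_norm(p2)
    return abs(m1 - m2), 1.0 - float(np.dot(u1, u2))


dark_grey, light_grey = [0.2, 0.2, 0.2], [0.8, 0.8, 0.8]

# RGB only: direction distance is ~0, so a direction threshold never fires.
print(describe_pair(dark_grey, light_grey))

# RGBA with alpha = 1: the constant channel makes the directions differ.
print(describe_pair(dark_grey + [1.0], light_grey + [1.0]))
```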

```python
g_conn = determine_connected_pixels(
    g, dot_threshold=0.0005, mag_threshold=0.05
)
print(nx.number_connected_components(g_conn))
draw_graph(g_conn, node_size=5)
```

```python
g_conn = determine_connected_pixels(
    g, dot_threshold=0.0001, mag_threshold=0.05
)
print(nx.number_connected_components(g_conn))
draw_graph(g_conn, node_size=5)
```
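The loop in `determine_connected_pixels` makes one Python call per edge, which can be slow on a full-size image. Below is a hypothetical vectorised sketch (`connected_pixel_groups` is my name, not a library function) that applies a rule of the same shape to whole arrays at once: an edge between 4-neighbours is cut only when both the magnitude difference and the direction distance 1 - cos(theta) exceed their thresholds. It is exercised on a made-up 4×6 image whose left half is dark grey and right half is red.

```python
import numpy as np
import networkx as nx


def connected_pixel_groups(img, mag_threshold, dot_threshold):
    """Vectorised sketch: build the 4-neighbour graph, keeping only edges
    whose pixels are similar in magnitude OR in direction."""
    nrows, ncols = img.shape[:2]
    mag = np.linalg.norm(img, axis=-1)
    # Unit vectors; zero-length pixels stay as zeros (avoid dividing by 0).
    safe = np.where(mag == 0, 1.0, mag)
    unit = img / safe[..., None]

    g = nx.Graph()
    g.add_nodes_from((r, c) for r in range(nrows) for c in range(ncols))
    # (0, 1) compares each pixel with its right neighbour, (1, 0) with the one below.
    for dr, dc in ((0, 1), (1, 0)):
        a = (slice(0, nrows - dr), slice(0, ncols - dc))
        b = (slice(dr, nrows), slice(dc, ncols))
        mag_diff = np.abs(mag[a] - mag[b])
        dir_dist = 1.0 - np.sum(unit[a] * unit[b], axis=-1)
        # Cut only when BOTH differences exceed their thresholds.
        keep = ~((mag_diff > mag_threshold) & (dir_dist > dot_threshold))
        rows, cols = np.nonzero(keep)
        g.add_edges_from(
            ((r, c), (r + dr, c + dc)) for r, c in zip(rows.tolist(), cols.tolist())
        )
    return g


# Hypothetical test image (RGB, values in [0, 1]).
toy = np.zeros((4, 6, 3))
toy[:, :3] = 0.2
toy[:, 3:] = [0.9, 0.1, 0.1]

for dot_threshold in (np.inf, 0.01):
    g_toy = connected_pixel_groups(toy, mag_threshold=0.05, dot_threshold=dot_threshold)
    print(dot_threshold, nx.number_connected_components(g_toy))
```

With an infinite direction threshold nothing is cut and the image is one component; with 0.01 the grey and red halves separate.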