<p style='
  color: #3b4045; 
  text-align: center;
  font-weight: bold;
  font-family: -apple-system,BlinkMacSystemFont, "Segoe UI Adjusted","Segoe UI","Liberation Sans",sans-serif;     font-size: 2.07692308rem; '> 
    Vascular Graph Matching and Comparison
</p>

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from fundus_vessels_toolkit.seg2graph import RetinalVesselSeg2Graph

In [3]:
# Path to the raw fundus image
RAW_PATH = "/home/gaby/These/Data/Fundus/Vessels/FIVES/test/downsampled-1024/1-images/110_G.jpg"

# Path to a vessel ground truth (for comparison purposes)
VESSELS_PATH = "/home/gaby/These/Data/Fundus/Vessels/FIVES/test/downsampled-1024/2-vessels/110_G.png"


### Load the raw Image and the Vascular Segmentation

In [4]:
import cv2

def load_img(path, binarize=False):
    img = cv2.imread(path)
    img = img.astype(float)/255
    return img.mean(axis=2)>.5 if binarize else img


Load image and segmentation mask

In [5]:
raw = load_img(RAW_PATH)
vessels = load_img(VESSELS_PATH, binarize=True)

Perform segmentation using a pretrained model

In [6]:
from fundus_vessels_toolkit.models import segment
vessels = segment(raw)

ModuleNotFoundError: No module named 'timm.models.layers.activations'

---

### Graph Matching

First let's perform the graph extraction from both the ground-truth and the predicted segmentation.

In [7]:
max_vessel_diameter = raw.shape[1]//70
seg2graph = RetinalVesselSeg2Graph(max_vessel_diameter)

vgraph_pred = seg2graph(vessels)
vgraph_true = seg2graph(vessels)

Then we use `simple_graph_matching` to find a simple match between the nodes of the two graph simply based on their positions. (The algorithm will try to minimize the total distance between matched nodes).

In [8]:
from fundus_vessels_toolkit.vgraph.matching import simple_graph_matching

D = max_vessel_diameter

(argmatch_pred, argmatch_true), dist = simple_graph_matching(vgraph_pred.nodes_yx_coord, vgraph_true.nodes_yx_coord, 
                                                             max_matching_distance=D*7, min_matching_distance=D/2, density_sigma=D, 
                                                             return_distance=True)
nmatch = len(argmatch_pred)

vgraph_pred.shuffle_nodes(argmatch_pred)
vgraph_true.shuffle_nodes(argmatch_true)

f"{nmatch} / {vgraph_true.nodes_count} nodes from the prediction segmentation were matched!"

'133 / 133 nodes from the prediction segmentation were matched!'

Lets have a look at the matched graph. (The paired nodes share the same color and ID on both image. Unpaired nodes appear in grey.)

_(The prediction is on the left, the ground truth on the right.)_

In [9]:
import numpy as np
from jppype.view2d import imshow, View2D, sync_views
from jppype.utilities.color import colormap_by_name
from ipywidgets import GridspecLayout

def create_view(vessels, vgraph, edge_map=False, edge_labels=False, node_labels=True):
    v = View2D()
    v.add_image(raw, 'raw')
    v.add_label(vessels, 'vessel', 'white', options={'opacity': 0.2})
    v['vessel graph'] = vgraph.jppype_layer(edge_map=edge_map, node_labels=node_labels, edge_labels=edge_labels)
    v['vessel graph'].nodes_cmap = {None: colormap_by_name()} | {_: "#444" for _ in range(nmatch, vgraph.nodes_count+1)}
    return v
    

In [14]:
v = View2D()
v.add_image(raw)

View2D()

In [17]:
v._transform

(0, 0, 1e-08)

In [18]:
grid = GridspecLayout(1,2, height='700px')
grid[0,0] = create_view(vessels, vgraph_pred, edge_map=True, node_labels=False)
grid[0,1] = create_view(vessels, vgraph_pred, edge_map=True)
sync_views(grid[0,0], grid[0,1])
grid

TypeError: super(type, obj): obj must be an instance or subtype of type

---

### Graph Edit Distance
(WIP)

In [None]:
from fundus_vessels_toolkit.vgraph.matching import naive_edit_distance
pred_diff, true_diff, (nmatch, pred_labels, true_labels) = naive_edit_distance(vgraph_pred, vgraph_true, max_matching_distance=D*7, min_matching_distance=D/2, density_matching_sigma=D, return_labels=True)
true_n = vgraph_true.branches_count
pred_n = vgraph_pred.branches_count

precision = (true_n - true_diff) / true_n
recall = (pred_n - pred_diff) / pred_n
f1_topo = 2/(1/precision  + 1/recall)

print(f"{true_diff} / {true_n} target branches were missed in the prediction, {pred_diff} / {pred_n} predicted branches are not present in the target.")
print(f"F1 Topo: {f1_topo:.3f}")

In [None]:
 grid = GridspecLayout(1,2, height='700px')
grid[0,0] = create_view(vessels, vgraph_pred, node_labels=False)
grid[0,1] = create_view(true_vessels, vgraph_true, node_labels=True)

#grid[0,0]['vessel graph'].edges_cmap = {edge_id: 
#                {0: 'green', # False Positive
#                 1: 'white', # True Positive
#                 2: 'LightGreen', # Split edges
#                 3: 'orange', # Fused edges
#              }[c] for edge_id, c in enumerate(pred_labels)}

#grid[0,1]['vessel graph'].edges_cmap = {edge_id: 
#                {0: 'red', # False Negative
#                 1: 'white', # True Positive
#                 2: 'LightGreen', # Fused edges
#                 3: 'orange', # Split edges
#              }[c] for edge_id, c in enumerate(true_labels)}

grid[0,0]['vessel graph'].edges_cmap = {edge_id: 
                {0: 'white', # False Positive
                 1: 'red', # True Positive
              }[c] for edge_id, c in enumerate(pred_labels)}

grid[0,1]['vessel graph'].edges_cmap = {edge_id: 
                {0: 'white', # False Positive
                 1: 'red', # True Positive
              }[c] for edge_id, c in enumerate(true_labels)}

sync_views(grid[0,0], grid[0,1])
grid