#### import

In [None]:
%load_ext autoreload
%autoreload 2

In [2]:
import cv2
import numpy as np
import os
import skimage.morphology as skmorph

In [3]:
from ipyprotopypes import ImageViewer
# Two independent image viewers, so that two maps (e.g. a prediction and a reference)
# can be compared in the cells below.
viewer1 = ImageViewer()
viewer2 = ImageViewer()

## Setup

#### Segmentation Model

In [None]:
def load_model_from_wandb(artifact, source_exp=None, alias='latest'):
    """Download a checkpoint logged to Weights & Biases and rebuild its training task.

    Parameters
    ----------
    artifact : str
        W&B artifact path. If it has no ":<version>" suffix, ``alias`` is appended.
    source_exp : str, optional
        Path to an experiment YAML. When given, the config is parsed from this file, with
        the run's logged 'experiment', 'model' and 'task' sections passed as overrides.
        Otherwise the run's logged config is re-parsed in place.
    alias : str
        Artifact alias. It is used both as the version suffix and as the checkpoint file
        stem ("<alias>.ckpg") inside the artifact.

    Returns
    -------
    The task object holding the checkpoint weights.
    """
    import json
    import wandb
    from nntemplate import Cfg

    if ':' not in artifact:
        artifact = artifact + ':' + alias

    api = wandb.Api(timeout=19)
    model_artifact = api.artifact(artifact)
    run = model_artifact.logged_by()
    logged_cfg = json.loads(run.json_config)
    # W&B stores each config entry as {"value": ..., ...}; keep only the values.
    cfg = Cfg.Dict.from_dict({k: v['value'] for k, v in logged_cfg.items()})
    if source_exp:
        cfg = Cfg.parse(source_exp, override={k: cfg[k] for k in ('experiment', 'model', 'task')})[0]
    else:
        cfg.Parser.parse_registered_cfg(cfg, inplace=True)

    task_cfg = cfg.root()['task']

    print(cfg['model'])

    model = cfg['model'].create()
    task = task_cfg.create_task(model)
    path = model_artifact.get_path(alias + '.ckpg').download()
    print(path)
    # PyTorch Lightning's load_from_checkpoint is a classmethod: it builds and RETURNS a new
    # module holding the checkpoint weights instead of filling `task` in place. The return
    # value used to be discarded. Fall back to `task` only if an override loads in place
    # and returns None.
    loaded = task.load_from_checkpoint(path, cfg=cfg['task'], model=model)
    return task if loaded is None else loaded
    

# Pretrained baseline checkpoint tagged 'best-val-dice' in W&B. The experiment YAML gives the
# base config; the run's logged experiment/model/task sections override it.
# NOTE(review): absolute local path — consider deriving it from a configurable root.
segmentation_model  = load_model_from_wandb("liv4d/Fundus Vessels Segmentation/Baseline_Pretrained.models",
                                 '/home/gaby/These/src/nn-template/experimentations/exp.yaml',
                                  alias='best-val-dice').cuda().eval()

def sup_multiple(n, m=32):
    """Return the smallest multiple of ``m`` that is greater than or equal to ``n``.

    Intended for integer ``n`` and positive ``m`` (e.g. padding an image side to a
    multiple of 32).
    """
    # Ceiling division via negated floor division, then scale back up.
    return -(-n // m) * m

def segment(x, threshold=0.5, model=segmentation_model):
    """Predict a binary vessel map for one image.

    x: image as an (H, W, C) numpy array (channels last, as returned by load_img).
    threshold: value at or above which channel 1 of the model output is labelled vessel.
    model: network called as ``model(x=...)`` on a (1, C, H', W') float tensor.

    Returns a boolean (H, W) numpy array.
    """
    import torch
    from nntemplate.torch_utils import crop_pad

    final_shape = x.shape[:2]
    # Pad the spatial dims up to a multiple of 32 (presumably required by the encoder's
    # downsampling), then crop the prediction back to the original size.
    padded_shape = tuple(sup_multiple(side, 32) for side in final_shape)

    # Run on the device holding the model's weights instead of hard-coding CUDA.
    # Models exposing no parameters keep the previous behaviour (CUDA).
    get_params = getattr(model, 'parameters', None)
    first_param = next(iter(get_params()), None) if callable(get_params) else None
    device = first_param.device if first_param is not None else torch.device('cuda')

    with torch.no_grad():
        t = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0).float()
        t = crop_pad(t, padded_shape).to(device)
        y = model(x=t)
        y = crop_pad(y, final_shape)
        pred = y.squeeze(0)[1] >= threshold
        return pred.cpu().numpy()


CfgDict[model]:
type: smp
architecture: Unet
encoder_name: resnet34
encoder_weights: imagenet
n-scale: 3
depth: 1
n-features: 8
norm: switchable

./artifacts/Baseline_Pretrained.models:v16/best-val-dice.ckpg


  rank_zero_warn(


In [5]:
import torch
from collections import OrderedDict
# Export the bare network weights from the Lightning checkpoint, so they can be loaded into a
# plain smp.Unet in the next cell.
# NOTE(review): torch.load unpickles the file — only use it on trusted checkpoints.
d = torch.load("./artifacts/Baseline_Pretrained.models:v16/best-val-dice.ckpg", map_location='cpu')['state_dict']
# Drop the first dotted component of every key (presumably the attribute name of the network
# inside the task), so the keys match the bare network.
params = OrderedDict([(k.split('.',1)[1],v) for k, v in d.items()])
torch.save(params, "./artifacts/resnet34.pt")

In [7]:
import segmentation_models_pytorch as smp
# Re-create the network with plain segmentation_models_pytorch and load the exported weights.
# By default segment() thresholds channel 1 of this model's output at 0.5, so the
# sigmoid activation matters here.
model = smp.Unet('resnet34', classes=2, activation='sigmoid')
model.load_state_dict(torch.load("./artifacts/resnet34.pt"))
model = model.eval().cuda()

In [18]:
from fundus_vessels_toolkit.models import segment as fvt_segment

#### Image

In [8]:
def load_img(path, binarize=False):
    """Read an image with OpenCV and scale it to [0, 1].

    path: file path passed to cv2.imread (default flags, i.e. a 3-channel BGR image).
    binarize: if True, return a boolean (H, W) mask of the pixels whose channel mean
        exceeds 0.5. Otherwise return the (H, W, 3) float image.

    Raises FileNotFoundError if OpenCV cannot read the file. cv2.imread returns None
    instead of raising, which used to surface later as an opaque AttributeError.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    img = img.astype(float) / 255
    return img.mean(axis=2) > .5 if binarize else img

def inplace_post_process(*vessel_maps):
    """Clean vessel maps in place and return them.

    Each map first has its small holes filled (area threshold 5, connectivity 2). It then
    gets a morphological opening with a 3x3 square footprint. The arrays are modified in
    place, and the same tuple is returned for convenience.
    """
    square3 = np.ones((3, 3))
    for vmap in vessel_maps:
        skmorph.remove_small_holes(vmap, 5, 2, out=vmap)
        skmorph.opening(vmap, square3, out=vmap)
    return vessel_maps

In [36]:
ID = 30
# NOTE(review): absolute local path — consider making HRF_ROOT configurable.
HRF_ROOT = "/home/gaby/These/Data/Fundus/Vessels/HRF/downsampled-1024/"
RAW_PATH = HRF_ROOT + "1-images/"
VESSELS_PATH = HRF_ROOT + "2-vessels/"
MASK_PATH = HRF_ROOT + "mask/"

# NOTE: os.listdir order is filesystem-dependent, so ID may select a different image on
# another machine. Consider indexing sorted(os.listdir(RAW_PATH)) instead.
raw_filename = os.listdir(RAW_PATH)[ID]
print('READING: ')

# Ground-truth vessels and FOV mask are matched to the raw image by filename stem.
raw_filename, suffix = raw_filename.rsplit('.', 1)
vessel_filename = next(f for f in os.listdir(VESSELS_PATH) if f.startswith(raw_filename))
mask_filename = next((f for f in os.listdir(MASK_PATH) if f.startswith(raw_filename)), None)

raw = load_img(RAW_PATH+raw_filename+'.'+suffix)
vessels_true = load_img(VESSELS_PATH+vessel_filename, binarize=True)
mask = None if mask_filename is None else load_img(MASK_PATH+mask_filename, binarize=True)

# NOTE(review): RetinalVesselSeg2Graph is not imported anywhere in this notebook.
seg2graph = RetinalVesselSeg2Graph(raw.shape[0]//100)

# Print the path actually read (the extension was previously hard-coded to '.jpg').
print('raw: ', RAW_PATH+raw_filename+'.'+suffix)
print('vessels: ', VESSELS_PATH+vessel_filename)
print('mask: ', MASK_PATH+mask_filename if mask_filename is not None else None)


# Predictions of the W&B task model (vessels) and of the plain smp model (vessels2),
# restricted to the field-of-view mask when one exists.
vessels = segment(raw)
vessels2 = segment(raw, model=model)
if mask is not None:
    vessels *= mask
    vessels_true *= mask
    vessels2 *= mask
#inplace_post_process(vessels_true, vessels)


READING: 
raw:  /home/gaby/These/Data/Fundus/Vessels/HRF/downsampled-1024/1-images/09_dr.jpg
vessels:  /home/gaby/These/Data/Fundus/Vessels/HRF/downsampled-1024/2-vessels/09_dr.png
mask:  /home/gaby/These/Data/Fundus/Vessels/HRF/downsampled-1024/mask/09_dr.png


In [29]:
# Segmentation from fundus_vessels_toolkit, for comparison with vessels / vessels2.
vessels3 = fvt_segment(raw)

Consider resizing the image to a size close to 1024x1024.
  final_shape = x.shape


In [26]:
# Restrict the fvt prediction to the field-of-view mask when one exists. This uses the same
# guard as the other maps; multiplying by None would raise.
if mask is not None:
    vessels3 *= mask

#### Displays

In [32]:
def display(vessels, skel, skel_dist=None, view=viewer1):
    """Overlay a vessel map and its skeleton on the global ``raw`` image and show it in ``view``.

    NOTE: this name shadows IPython's ``display`` in the notebook namespace.

    vessels: vessel map, added as a faint (x0.07) overlay on every channel.
    skel: skeleton map; non-zero pixels are drawn. Values 3, 4 and 1 get dedicated colours,
        whose meaning is defined by the skeletonizer (not visible here).
    skel_dist: optional map; when given, skeleton pixels are shaded by
        skel_dist / skel_dist.max() instead of a flat grey.
    """
    on_skel = skel > 0
    overlay = raw.copy()
    overlay += np.expand_dims(vessels * .07, 2)

    if skel_dist is not None:
        overlay[on_skel] = np.expand_dims(skel_dist[on_skel] / skel_dist.max(), 1)
    else:
        overlay[on_skel] = .5

    for skel_value, rgb in ((3, [1, 1, 0]), (4, [0, 1, 1]), (1, [0, 0, 1])):
        overlay[skel == skel_value] = rgb

    view.image = overlay.transpose((2, 0, 1))
    
def display_label(vessels, label, jonctions=None, label_jonctions=True, label_branch=True, view=viewer1):
    """Show a branch-label map (and optionally junction nodes) over the global ``raw`` image.

    Parameters
    ----------
    vessels : array (H, W)
        Vessel map, added as a faint (x0.07) overlay on every channel.
    label : int array (H, W)
        Branch labels; 0 is background. Label value ``i`` is annotated as branch ``i-1``.
    jonctions : pair of arrays (y, x), optional
        Node coordinates, rounded to pixels and drawn in white.
    label_jonctions : bool
        Write ":<index>" next to each node.
    label_branch : bool
        Write each branch index 10 px down-right of the mean position of its pixels.
    view : ImageViewer
        Viewer whose ``image`` receives the (3, H, W) result.
    """
    import colorsys

    overlay = raw.copy()
    overlay += np.expand_dims(vessels * .07, 2)

    # 28 hues x 2 saturations x 2 values = 112 colours. They are interleaved so that
    # consecutive labels usually get well-separated hues.
    palette = np.asarray([colorsys.hsv_to_rgb(h, s, v)
                          for h in np.linspace(0, 1, 28) for s in [.85, 1] for v in [1, .85]])
    palette = palette.reshape((7, -1, 3)).transpose((1, 0, 2)).reshape((-1, 3))
    n_colors = len(palette)

    labelled = label != 0
    for i, c in enumerate(palette):
        overlay[labelled & (label % n_colors == i)] = c

    if label_branch:
        for i in range(1, label.max() + 1):
            ys, xs = np.where(label == i)
            if len(ys):
                y = int(np.mean(ys) + 10)
                x = int(np.mean(xs) + 10)
                # cv2 expects the colour as a plain tuple of Python numbers.
                color = tuple(float(ch) for ch in palette[i % n_colors])
                overlay = cv2.putText(overlay, str(i-1), (x, y), cv2.FONT_HERSHEY_PLAIN, .6, color)
            else:
                print(f'Branch {i-1} was not found')

    if jonctions is not None:
        jy, jx = np.round(jonctions).astype(np.int64)
        overlay[jy, jx] = [1, 1, 1]
        if label_jonctions:
            for i, (y, x) in enumerate(zip(jy, jx)):
                overlay = cv2.putText(overlay, f":{i}", (int(x) + 3, int(y) + 3),
                                      cv2.FONT_HERSHEY_PLAIN, .6, (.6, .6, .6))
    view.image = overlay.transpose((2, 0, 1))

In [27]:
# Faint (x0.07) overlay of the fundus_vessels_toolkit prediction on the raw image.
raw_display = raw + vessels3[..., None] * .07
viewer1.image = raw_display.transpose((2, 0, 1))
viewer1

ImageViewer()

In [12]:
# Faint (x0.07) overlay of the smp-model prediction on the raw image.
raw_display = raw + vessels2[..., None] * .07
viewer2.image = raw_display.transpose((2, 0, 1))
viewer2

ImageViewer()

### Skeletonize and graph

In [34]:
# Skeletonize the prediction (viewer1) and the ground truth (viewer2).
# NOTE: `skel` is overwritten; after this cell it holds the ground-truth skeleton.
skel = seg2graph.skeletonize(vessels)
display(vessels, skel, view=viewer1)
skel = seg2graph.skeletonize(vessels_true)
display(vessels_true, skel, view=viewer2)

In [None]:
# Build the branch/node adjacency of the prediction (graph1) and of the ground truth (graph2).
# NOTE(review): seg2graph is re-created with raw.shape[0]//80, while the loading cell used //100.
# NOTE(review): RetinalVesselSeg2Graph and branches_by_nodes_to_node_graph are not imported
# anywhere in this notebook.
seg2graph = RetinalVesselSeg2Graph(raw.shape[0]//80)
conn1, branch_labels, node_yx1 = seg2graph.seg2adjacency(vessels, return_label=True)
graph1 = branches_by_nodes_to_node_graph(conn1, node_yx1)
display_label(vessels, branch_labels, node_yx1, view=viewer1)

# `branch_labels` is overwritten here; after this cell it belongs to the ground truth.
conn2, branch_labels, node_yx2 = seg2graph.seg2adjacency(vessels_true, return_label=True)
graph2 = branches_by_nodes_to_node_graph(conn2, node_yx2)
display_label(vessels_true, branch_labels, node_yx2, view=viewer2)

In [None]:
# Sanity checks on both adjacency matrices (rows = branches, columns = nodes).
# Uses the explicit graph-1/graph-2 variables: `conn` / `node_yx` are not defined in this notebook.
for _name, _conn, _yx in (('pred', conn1, node_yx1), ('true', conn2, node_yx2)):
    print(f'[{_name}] Empty nodes: {[i for i, n in enumerate(_conn.T) if n.sum()==0]}')
    print(f'[{_name}] Irregular branches: {[i for i, b in enumerate(_conn) if b.sum()!=2]}')
    if len(_yx[0]) != _conn.shape[1]:
        print(f"[{_name}] Invalid number of nodes coordindates: {len(_yx[0])} instead of {_conn.shape[1]}.")
# `branch_labels` was last assigned by the vessels_true call, so it must match conn2.
if np.max(branch_labels) != conn2.shape[0]:
    print(f"Invalid number of branch labels: {np.max(branch_labels)} instead of {conn2.shape[0]}.")

In [None]:
# Node graph of the prediction; list the node pair joined by each branch.
graph = branches_by_nodes_to_node_graph(conn1, node_yx1)
edges = {}
branches = set()
for edge in graph.edges():
    branch_id = graph.edges[edge]['branch']
    edges[f"Branch {branch_id:03}"] = f"n{edge[0]}-{edge[1]}"
    branches.add(branch_id)

for b in sorted(edges):
    print(b, edges[b])

In [None]:
# Branches incident to each node of graph 1, with the node's (y, x) position.
# Uses conn1 / node_yx1: `conn` / `node_yx` are not defined in this notebook.
for i, b in enumerate(conn1.T):
    print(f'Node {i} ({node_yx1[0][i]:.0f}, {node_yx1[1][i]:.0f}): ', np.where(b)[0])

In [None]:
# Nodes touched by each branch of graph 1 (uses conn1: `conn` is not defined in this notebook).
for i, b in enumerate(conn1):
    print(f'Branch {i}: ', np.where(b)[0])

In [None]:
# Number of nodes touched by each branch of graph 1 (2 for a regular branch).
conn1.sum(axis=1).astype(int)

In [None]:
# Number of branches of graph 1 touching more than two nodes.
np.sum(conn1.sum(axis=1) > 2)

In [None]:
# (number of branches in graph 1,)
conn1.sum(axis=1).shape

# Euclidean Minimal Matching

In [None]:
from pygmtools.linear_solvers import hungarian

# Largest distance (in pixels of `raw`) at which two graph nodes may still be paired.
# It sets the "unmatched" score of the Hungarian matching below.
MAX_PAIRING_DISTANCE = 18


In [None]:
# Node coordinates as (N, 2) arrays of (y, x).
yx1 = np.stack(node_yx1, axis=1)
yx2 = np.stack(node_yx2, axis=1)
N1 = len(yx1)
N2 = len(yx2)

# (N1, N2) pairwise distances between the nodes of both graphs.
euclidian_distance = np.linalg.norm(yx1[:, None]-yx2[None, :], axis=2)
# Closer nodes get a higher score; 1e-8 avoids dividing by zero for coincident nodes.
weight = 1/(1e-8+euclidian_distance)
# Score of leaving one node unmatched. Pairing i with j replaces two unmatched nodes,
# worth 2*min_weight = 1/(1e-8+MAX_PAIRING_DISTANCE). So a pair should only be kept when it is
# closer than MAX_PAIRING_DISTANCE (assuming pygmtools adds the unmatch score of every
# unassigned node — confirm against its docs).
min_weight = .5/(1e-8+MAX_PAIRING_DISTANCE)
# Batched call with a batch of one: weight[None] is (1, N1, N2); unmatch scores are (1, N1) and (1, N2).
matched_nodes = hungarian(weight[None, ...], [N1], [N2], np.repeat([[min_weight]], N1, axis=1), np.repeat([[min_weight]], N2, axis=1))

In [None]:
# Indices of the matched node pairs (the first axis is the batch of one).
_, dim1, dim2 = np.where(matched_nodes)
import matplotlib.pyplot as plt

matched_distances = euclidian_distance[dim1, dim2]
# .max() raises on an empty selection, so report the "nothing matched" case explicitly.
print(matched_distances.max() if matched_distances.size else 'No node pair was matched.')

fig, ax = plt.subplots()
ax.hist(matched_distances)
ax.set(title='Distance between matched nodes', xlabel='Euclidean distance (px)', ylabel='Matched pairs');

In [None]:
# 3-D matching matrix, presumably (batch=1, N1, N2) like the weight passed to hungarian.
matched_nodes.shape

In [None]:
# Print candidate node pairs with their weight.
# FIXME(review): nodes1, nodes2 and near_nodes_weight are not defined anywhere in this
# notebook (leftover from a deleted cell), so this cell fails on a fresh kernel.
for n1, n2, w in zip(nodes1, nodes2, near_nodes_weight):
    print(n1, n2, f"{w:.8}")

In [None]:
# FIXME(review): maximum_bipartite_matching is not imported in this notebook. It is presumably
# scipy.sparse.csgraph.maximum_bipartite_matching, which expects a sparse biadjacency matrix
# rather than the node graph stored in `graph`. Fails on a fresh kernel.
# p[k] == -1 marks an entry left unmatched.
p = maximum_bipartite_matching(graph, 'column')
print('Matched node: ', np.sum(p!=-1), ' / ', len(p))
for p2, p1 in enumerate(p):
    print(p2, '->', p1)

In [None]:
# Same as the previous cell with the default perm_type.
# FIXME(review): maximum_bipartite_matching is not imported and `graph` is a node graph, not
# a sparse biadjacency matrix. Fails on a fresh kernel.
p = maximum_bipartite_matching(graph)
print('Matched node: ', np.sum(p!=-1), ' / ', len(p))
for p2, p1 in enumerate(p):
    print(p2, '->', p1)

In [None]:
# Bipartite graph for a matching that may leave nodes unmatched:
#   left side  [0, N):  the N1 nodes of graph 1, then N2 dummy nodes;
#   right side [N, 2N): the N2 nodes of graph 2, then N1 dummy nodes.
# Pairing a real node with its own dummy scores min_weight, i.e. "left unmatched".
# NOTE(review): assumes networkx is imported as nx (no import is visible in this notebook).

# nodes1 / nodes2 / near_nodes_weight were never defined in this notebook. They are rebuilt
# here as the node pairs closer than MAX_PAIRING_DISTANCE and their weights (assumed intent — verify).
nodes1, nodes2 = np.nonzero(euclidian_distance < MAX_PAIRING_DISTANCE)
near_nodes_weight = weight[nodes1, nodes2]

N = N1 + N2  # previously `n1+n2`, which picked up leftover loop variables
B = nx.Graph()
B.add_nodes_from(np.arange(N), bipartite=0)
B.add_nodes_from(np.arange(N, 2*N), bipartite=1)
# B is undirected, so each edge is added once.
for n1, n2, w in zip(nodes1, nodes2, near_nodes_weight):
    B.add_edge(n1, N + n2, weight=w)
for n1 in range(N1):
    B.add_edge(n1, N + N2 + n1, weight=min_weight)
for n2 in range(N2):
    B.add_edge(N1 + n2, N + n2, weight=min_weight)
    

In [None]:
# nx.bipartite.maximum_matching ignores edge weights, and it raises AmbiguousSolution when B is
# disconnected unless top_nodes is given. The weights built above call for a maximum-weight
# matching instead (returns a set of matched (u, v) pairs).
nx.max_weight_matching(B)