In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import torch
import yaml
from medpy.io import load, save
import pyvista
import json
import numpy as np
from models import build_model
from inference import relation_infer
from utils import patchify_voxel, patchify_graph, unpatchify_graph
from multiprocessing import Pool
from functools import partial
import networkx as nx
from skimage.measure import marching_cubes
import open3d as o3d
import itertools
%load_ext autoreload
%autoreload 2

In [2]:
def plot_test_sample(image, points, edges):
    meshes = []
    graphs = []
    # image = image[:, 54:-54, 54:-54]
    porder_points = [
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [0, 1, 1],
        [1, 1, 1],
    ]
    border_edges = [
        [0, 1],
        [0, 2],
        [1, 3],
        [2, 3],
        [4, 5],
        [4, 6],
        [5, 7],
        [6, 7],
        [0, 4],
        [1, 5],
        [2, 6],
        [3, 7],
    ]

    ref_color = [[1, 0, 0] for i in range(len(edges))]
    ref_line_set = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(points + np.array([0.6, 0, 0])),
        lines=o3d.utility.Vector2iVector(edges)
    )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    graphs.append(ref_line_set)
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(points + np.array([0.6, 0, 0]))
    point_cloud.paint_uniform_color([0, 1, 0])
    graphs.append(point_cloud)

    ref_color = [[0.2, 0.2, 0.2] for i in range(len(border_edges))]
    ref_line_set = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(porder_points + np.array([0.6, 0, 0])),
        lines=o3d.utility.Vector2iVector(border_edges),
    )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    graphs.append(ref_line_set)

    verts, faces, norms, vals = marching_cubes(image > 0.0, level=0, method='lewiner')
    verts = verts / np.array(image.shape)

    mesh = np.concatenate((faces[:, :2], faces[:, 1:]), axis=0)
    adjucency = np.zeros((verts.shape[0], verts.shape[0]))

    for e in mesh:
        adjucency[e[0], e[1]] = 1.0
        adjucency[e[1], e[0]] = 1.0

    adjucency = np.triu(adjucency)
    mesh = np.array(np.where(np.triu(adjucency) > 0)).T

    pred_color = [[0, 0, 1] for i in range(len(mesh))]
    pred_line_set = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(verts - np.array([0.6, 0, 0])),
        lines=o3d.utility.Vector2iVector(mesh),
    )

    pred_line_set.colors = o3d.utility.Vector3dVector(pred_color)
    meshes.append(pred_line_set)

    pred_color = [[0.2, 0.2, 0.2] for i in range(len(border_edges))]
    pred_line_set = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector((porder_points - np.array([0.6, 0, 0]))),
        lines=o3d.utility.Vector2iVector(border_edges),
    )

    pred_line_set.colors = o3d.utility.Vector3dVector(pred_color)
    meshes.append(pred_line_set)
    o3d.visualization.draw_geometries(meshes + graphs)

In [3]:
def plot_val_rel_sample(image, points1, edges1, points2, edges2, attn_map=None, relative_coords=True):
    """Show a segmentation surface and two graphs side by side in one Open3D window.

    With ``gap = max x of points1 + 10``:

    * the marching-cubes surface of ``image > 0`` is shifted by -gap in x
      (vertices stay in voxel units, not normalised);
    * graph 1 (``points1``/``edges1``) is drawn at its own coordinates;
    * graph 2 (``points2``/``edges2``) is shifted by +gap in x.

    Each graph gets red edges, green nodes and a grey box spanning
    ``[0, max(points1)]`` per axis; the same box (shifted by -gap) frames the
    surface.  Note the box extent always comes from ``points1``.

    Args:
        image: 3-D array; voxels > 0 are foreground.
        points1, edges1: first graph — (N, 3) coordinates and (E, 2) index pairs.
            NOTE(review): assumed to be in the same voxel frame as ``image``.
        points2, edges2: second graph, same layout.
        attn_map, relative_coords: accepted for compatibility; unused.
    """
    x_max, y_max, z_max = points1.max(0)
    # Spacing between the three views, computed once.
    gap = np.array([x_max + 10, 0, 0])
    border_points = np.array([
        [0, 0, 0],
        [x_max, 0, 0],
        [0, y_max, 0],
        [x_max, y_max, 0],
        [0, 0, z_max],
        [x_max, 0, z_max],
        [0, y_max, z_max],
        [x_max, y_max, z_max],
    ])
    border_edges = [
        [0, 1],
        [0, 2],
        [1, 3],
        [2, 3],
        [4, 5],
        [4, 6],
        [5, 7],
        [6, 7],
        [0, 4],
        [1, 5],
        [2, 6],
        [3, 7],
    ]

    def border_box(offset):
        # Grey bounding box translated by `offset`.
        box = o3d.geometry.LineSet(
            points=o3d.utility.Vector3dVector(border_points + offset),
            lines=o3d.utility.Vector2iVector(border_edges),
        )
        box.colors = o3d.utility.Vector3dVector([[0.2, 0.2, 0.2]] * len(border_edges))
        return box

    def graph_geometries(points, edges, offset):
        # Red edges, green nodes and the bounding box, all translated by `offset`.
        lines = o3d.geometry.LineSet(
            points=o3d.utility.Vector3dVector(points + offset),
            lines=o3d.utility.Vector2iVector(edges),
        )
        lines.colors = o3d.utility.Vector3dVector([[1, 0, 0]] * len(edges))
        nodes = o3d.geometry.PointCloud()
        nodes.points = o3d.utility.Vector3dVector(points + offset)
        nodes.paint_uniform_color([0, 1, 0])
        return [lines, nodes, border_box(offset)]

    line_sets = graph_geometries(points1, edges1, np.zeros(3)) + graph_geometries(points2, edges2, gap)

    verts, faces, _, _ = marching_cubes(image > 0.0, level=0, method='lewiner')
    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(verts - gap)
    mesh.triangles = o3d.utility.Vector3iVector(faces)
    # Without normals Open3D renders the mesh unlit (a flat silhouette).
    mesh.compute_vertex_normals()

    o3d.visualization.draw_geometries(line_sets + [border_box(-gap), mesh])

In [4]:
class obj:
    """Expose a mapping's keys as attributes: ``obj(d).key == d['key']``."""

    def __init__(self, dict1):
        vars(self).update(dict1)


def dict2obj(dict1):
    """Recursively convert a JSON-serialisable dict into nested ``obj`` instances.

    Serialising and re-parsing lets ``json``'s ``object_hook`` rebuild every
    nested mapping (including mappings inside lists) as an ``obj``.
    """
    as_json = json.dumps(dict1)
    return json.loads(as_json, object_hook=obj)


# Patch geometry passed to the patchify / unpatchify helpers.
patch_size = (64, 64, 64)
pad = (5, 5, 5)

In [5]:
# Inputs for this notebook (paths relative to the working directory).
config = "configs/synth_3D.yaml"                   # YAML experiment config; rebound to the parsed config in the next cell
model_ckpt = "trained_weights/last_checkpoint.pt"  # checkpoint with the trained network weights

### Load Config

In [6]:
# Parse the YAML file named by `config`, then replace `config` with an
# attribute-style object (config.MODEL.DECODER.OBJ_TOKEN, ...).
# NOTE: `config` is rebound here, so re-running this cell on its own fails;
# re-run the cell that sets the path first.
with open(config) as cfg_file:
    print('\n*** Config file')
    print(config)
    # NOTE(review): yaml.safe_load is the stricter choice for a plain config file.
    config = yaml.load(cfg_file, Loader=yaml.FullLoader)
    print(config['log']['exp_name'])
config = dict2obj(config)

device = torch.device("cuda")


*** Config file
configs/synth_3D.yaml
synth_data_vesselformer


### Load Model

In [7]:
net = build_model(config).to(device)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(model_ckpt, map_location='cpu')
# strict=False: with the default strict=True, load_state_dict raises on any key
# mismatch, so the missing/unexpected-key reporting below could never run.
missing_keys, unexpected_keys = net.load_state_dict(checkpoint['net'], strict=False)
# Ignore '*total_params' / '*total_ops' entries — presumably counters left by a
# FLOPs/params profiler (e.g. thop) rather than model weights.
unexpected_keys = [k for k in unexpected_keys if not k.endswith(('total_params', 'total_ops'))]
if missing_keys:
    print('Missing Keys: {}'.format(missing_keys))
if unexpected_keys:
    print('Unexpected Keys: {}'.format(unexpected_keys))

### Patching operation

In [28]:
image_data, _ = load('data/vessel_data/raw/47.nii.gz')
seg_data, _ = load('data/vessel_data/seg/47.nii.gz')
vtk_data = pyvista.read('data/vessel_data/vtk/47.vtk')

# Map the VTK graph coordinates into the image's voxel frame.
# NOTE(review): the 1/3 scale and the (-1.8, +8.3, 4.0) offsets are dataset-specific
# constants with no derivation here; confirm against the data-generation code.
shift = [np.shape(image_data)[0] / 2 - 1.8, np.shape(image_data)[1] / 2 + 8.3, 4.0]
coordinates = np.float32(np.asarray(vtk_data.points / 3.0 + shift))
vtk_data.points = coordinates
vtk_data = vtk_data.clean().extract_surface()

# Crop a sub-volume of num_patch patch "cores" per axis, starting at start_.
start_ = np.array([20, 20, 20])
num_patch = [2, 4, 6]
# Voxels each patch contributes once its padding is removed on both sides
# (64 - 2*5 = 54 for the settings above).
# NOTE(review): assumes patchify_voxel tiles with stride patch_size - 2*pad — confirm.
patch_core = np.array(patch_size) - 2 * np.array(pad)
crop_end = start_ + patch_core * np.array(num_patch)  # exclusive end index per axis
crop = tuple(slice(lo, hi) for lo, hi in zip(start_, crop_end))
# Clip box as [x0, x1, y0, y1, z0, z1]; the upper value is the last voxel index.
bounds = [v for lo, hi in zip(start_, crop_end) for v in (lo, hi - 1)]

image_data = image_data[crop]
seg_data = seg_data[crop]
vtk_data = vtk_data.clip_box(bounds, invert=False).extract_surface().clean()

# Express graph coordinates relative to the crop origin.
coordinates = np.array(vtk_data.points) - start_
vtk_data.points = coordinates
edges = np.asarray(vtk_data.lines.reshape(vtk_data.n_cells, 3))[:, 1:]

_, _, _, _, merged_graph = patchify_graph(image_data, vtk_data, patch_size, pad)
patch_list, start_ind, seq_ind, padded_shape = patchify_voxel(image_data, patch_size, pad)
seg_patch_list, _, _, _ = patchify_voxel(seg_data, patch_size, pad)

# Intensity normalisation: patches / (0.001 * max) - 0.5.
# NOTE(review): presumably mirrors the training-time preprocessing — confirm.
patches = np.float32(np.stack(patch_list))
vmax = patches.max() * 0.001
patches = patches / vmax - 0.5
patches = torch.tensor(patches).to(device).contiguous().unsqueeze(1)

segs = np.float32(np.stack(seg_patch_list))
segs = torch.tensor(segs).to(device).contiguous().unsqueeze(1)


### Use the graph from `patchify_graph` (including its newly created points) as the reference graph

In [29]:
# Use the graph returned by patchify_graph (including the points it created)
# as the reference. VTK line cells are stored flat as [n, i, j, n, i, j, ...];
# this assumes every cell is a 2-point line.
coordinates = np.array(merged_graph.points)
edges = np.asarray(merged_graph.lines).reshape(-1, 3)[:, 1:]

### Run inference on each patch

In [30]:
# Run the network patch-batch by patch-batch and collect the per-patch graphs.
patch_graphs = {'pred_nodes': [], 'pred_rels': [], 'pred_radius': []}
batch_size = 48
net.eval()  # Put the network in evaluation mode

# Inference only: no_grad avoids building autograd graphs for every batch.
# (np.int, used previously for the batch count, no longer exists in NumPy >= 1.24.)
with torch.no_grad():
    for start in range(0, patches.shape[0], batch_size):
        images = patches[start:start + batch_size].contiguous()
        h, out = net(images)
        out = relation_infer(
            h.detach(), out, net.relation_embed, net.radius_embed,
            config.MODEL.DECODER.OBJ_TOKEN, config.MODEL.DECODER.RLN_TOKEN,
            config.MODEL.DECODER.RAD_TOKEN,
        )

        # Optional quick visualisation of the first patch in the batch:
        # seg = segs[start:start + batch_size]
        # if (seg[0].squeeze().cpu().numpy() > 0.0).sum():
        #     plot_test_sample(seg[0].squeeze().cpu().numpy(), out['pred_nodes'][0], out['pred_rels'][0])

        for key in ('pred_nodes', 'pred_rels', 'pred_radius'):
            patch_graphs[key].extend(out[key])

In [31]:
# Stitch the per-patch predictions into one graph; occu_matrix is used below
# to find nodes predicted twice in overlapping patches.
print("Unifying of patch")
occu_matrix, out = unpatchify_graph(
    patch_graphs, start_ind, seq_ind, pad,
    imsize=padded_shape,
)

Unifying of patch


### Interim Visualization

In [13]:
# Quick look at the stitched prediction before duplicated nodes are merged.
edges1 = np.array(out['pred_rels'])
coord1 = np.array(out['pred_nodes'])
plot_val_rel_sample(
    seg_data.squeeze(),
    coordinates, edges,  # reference graph
    coord1, edges1,      # stitched prediction
)

### Function to find nodes to be merged

In [33]:
def merge_nodes(item, occu_matrix, iou_thresh=0.4):
    """Find node pairs from two occupancy maps that should be merged.

    ``occu_matrix[k]`` is read as a label volume whose voxels hold node ids,
    0 meaning empty. NOTE(review): ids are presumably 1-based (see the ``- 1``
    on return); confirm against ``unpatchify_graph``.  Labels ``a`` from map
    ``item[0]`` and ``b`` from map ``item[1]`` are merged when the IoU of
    their voxel sets exceeds ``iou_thresh``.

    Args:
        item: pair ``(i, j)`` of indices into ``occu_matrix``.
        occu_matrix: indexable collection of equally-shaped label arrays.
        iou_thresh: a pair is merged only if its IoU is strictly greater
            than this value (default 0.4).

    Returns:
        List of ``np.int32`` arrays ``[a - 1, b - 1]``, one per merged pair,
        in ascending ``(a, b)`` order; empty when the maps do not overlap.
    """
    img1 = np.asarray(occu_matrix[item[0]]).ravel()
    img2 = np.asarray(occu_matrix[item[1]]).ravel()

    # Voxels labelled in both maps; only labels co-occurring here can overlap.
    intersect = img1 * img2 > 0
    if not intersect.any():
        return []

    # Each distinct co-occurring (a, b) pair and its voxel overlap |A ∩ B|.
    # Counted in one pass instead of rescanning the volume for every pair.
    pairs, inter = np.unique(
        np.stack((img1[intersect], img2[intersect]), axis=1),
        axis=0,
        return_counts=True,
    )
    # Whole-volume voxel counts |A| and |B|, looked up per pair.
    labels1, counts1 = np.unique(img1, return_counts=True)
    labels2, counts2 = np.unique(img2, return_counts=True)
    size1 = counts1[np.searchsorted(labels1, pairs[:, 0])]
    size2 = counts2[np.searchsorted(labels2, pairs[:, 1])]

    # IoU with |A ∪ B| = |A| + |B| - |A ∩ B|.
    iou = inter / (size1 + size2 - inter)
    return [pair.astype(np.int32) - 1 for pair in pairs[iou > iou_thresh]]

### Construct the whole graph

In [34]:
# Stitched prediction as a NetworkX graph: node i carries its position and
# every predicted edge carries its radius.
Whole_G = nx.Graph()
Whole_G.add_nodes_from(range(len(out['pred_nodes'])))
Whole_G.add_edges_from(out['pred_rels'])

loc_dict = dict(enumerate(out['pred_nodes']))
nx.set_node_attributes(Whole_G, loc_dict, "position")

# NOTE(review): the per-patch dict uses the key 'pred_radius'; confirm that
# unpatchify_graph returns the stitched radii as 'pred_rads'.
rad_dict = {tuple(rel): rad for rel, rad in zip(out['pred_rels'], out['pred_rads'])}
nx.set_edge_attributes(Whole_G, rad_dict, "radius")

### Find nodes to be merged

In [35]:
# Compare every unordered pair of occupancy maps 0..7 in parallel; each call
# returns the 0-based node-id pairs whose voxel footprints overlap enough.
all_pairs = list(itertools.combinations(range(8), 2))
with Pool(processes=6) as pool:
    merge_nodes_results = pool.map(partial(merge_nodes, occu_matrix=occu_matrix), all_pairs)
merge_nodes_list = list(itertools.chain.from_iterable(merge_nodes_results))

### Create a graph of mergeable nodes

In [36]:
# Nodes linked (directly or transitively) by merge pairs form one group to collapse.
merge_G = nx.from_edgelist(np.array(merge_nodes_list))
sub_graphs = list(map(merge_G.subgraph, nx.connected_components(merge_G)))

### Remove redundant nodes

In [37]:
# Collapse every group of mergeable nodes into one new node placed at the
# group's mean position. Edges to nodes outside the group are re-attached to
# the new node with their radius.
# NOTE: mutates Whole_G in place; re-running this cell without rebuilding
# Whole_G fails because the old node ids no longer exist.
new_node_idx = len(out['pred_nodes'])
for sub_g in sub_graphs:
    members = set(sub_g.nodes())
    new_node_pos = np.array([Whole_G.nodes[node_]["position"] for node_ in sub_g.nodes()]).mean(0)
    Whole_G.add_node(new_node_idx, position=new_node_pos)

    for node_ in sub_g.nodes():
        # Snapshot the neighbours: edges are added to the graph inside the loop.
        for n1 in list(Whole_G.neighbors(node_)):
            if n1 in members:
                # Both ends collapse into the new node; re-wiring this edge
                # would create a self-loop on it.
                continue
            radius_ = Whole_G.edges[n1, node_]["radius"]
            Whole_G.add_edge(n1, new_node_idx, radius=radius_)
        # remove_node also removes every edge incident to node_.
        Whole_G.remove_node(node_)
    new_node_idx += 1

In [38]:
# Relabel the surviving nodes 0..N-1 (merging leaves gaps in the ids) and
# export node coordinates, edge index pairs and per-edge radii as arrays.
pred_node = {node_: i for i, node_ in enumerate(Whole_G.nodes())}
pred_coord = np.array([Whole_G.nodes[node_]["position"] for node_ in Whole_G.nodes()])
# reshape keeps a (0, 2) shape when the graph has no edges.
pred_edges = np.array(
    [[pred_node[u], pred_node[v]] for u, v in Whole_G.edges()], dtype=int
).reshape(-1, 2)
pred_radius = np.array([Whole_G.edges[edge_]["radius"] for edge_ in Whole_G.edges()])

### Visualization

In [39]:
# Final comparison: segmentation surface (-x), reference graph, merged prediction (+x).
plot_val_rel_sample(
    seg_data.squeeze(),
    coordinates, edges,      # reference graph
    pred_coord, pred_edges,  # prediction after node merging
)