In [1]:
import os
import random

import networkx as nx
import numpy as np
import torch

import plotly.graph_objects as go

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEFAULT_RADIUS = 3                         # edge radius used when `edges_radius` is None
DEFAULT_COLOR = 'yellow'                   # marker colour for a node group without a 'color' key
DEFAULT_OPACITY = 1.0                      # trace opacity for a node group without an 'opacity' key
EDGE_COLOR = 'rgba(255, 255, 255, 0.5)'    # semi-transparent white vessel segments
NODE_COLOR = 'lightgreen'                  # marker colour when no node groups are given
NODE_SIZE = 5                              # marker size for every node trace


def _split_xyz(points):
    """Return ``([x...], [y...], [z...])`` for an iterable of 3-component points."""
    points = list(points)
    return [p[0] for p in points], [p[1] for p in points], [p[2] for p in points]


def draw_3d_vessel_network(nx_vessel_network, nodes_pos, edges_radius=None, nodes_groups=None):
    """Build an interactive Plotly 3-D figure of a vessel graph.

    Each edge is drawn as its own line trace so that it can carry its own width.
    Nodes are drawn as a single marker trace, or as one trace per entry of
    ``nodes_groups``.

    Args:
        nx_vessel_network: graph whose ``edges()`` yields ``(u, v)`` pairs of
            integer node ids (e.g. a ``networkx`` graph).
        nodes_pos: sequence of ``(x, y, z)`` positions; entry ``i`` is the
            position of node ``i``.
        edges_radius: optional sequence indexed in the same order as
            ``nx_vessel_network.edges()``. Each edge line is ``2 * radius`` wide.
            Defaults to ``DEFAULT_RADIUS`` for every edge.
        nodes_groups: optional list of dicts with key ``'nodes'`` (node ids) and
            optional keys ``'color'`` / ``'opacity'``. Missing keys fall back to
            ``DEFAULT_COLOR`` / ``DEFAULT_OPACITY``. When given, only the grouped
            nodes are drawn.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the edge traces followed by
        the node traces.
    """
    coordinates_by_node = dict(enumerate(nodes_pos))

    trace_edges = []
    for edge_idx, edge in enumerate(nx_vessel_network.edges()):
        start, end = coordinates_by_node[edge[0]], coordinates_by_node[edge[1]]
        radius = DEFAULT_RADIUS if edges_radius is None else edges_radius[edge_idx]
        # Format per axis: [beginning, ending, None]. The trailing None ends the segment.
        trace_edges.append(go.Scatter3d(
            x=[start[0], end[0], None],
            y=[start[1], end[1], None],
            z=[start[2], end[2], None],
            mode='lines',
            line=dict(color=EDGE_COLOR, width=radius * 2),
            hoverinfo='none',
        ))

    if nodes_groups:
        trace_nodes = []
        for group in nodes_groups:
            xs, ys, zs = _split_xyz(coordinates_by_node[i] for i in group['nodes'])
            trace_nodes.append(go.Scatter3d(
                x=xs, y=ys, z=zs, mode='markers',
                marker=dict(symbol='circle', size=NODE_SIZE,
                            color=group.get('color', DEFAULT_COLOR)),
                opacity=group.get('opacity', DEFAULT_OPACITY),
            ))
    else:
        xs, ys, zs = _split_xyz(coordinates_by_node.values())
        trace_nodes = [go.Scatter3d(
            x=xs, y=ys, z=zs, mode='markers',
            marker=dict(symbol='circle', size=NODE_SIZE, color=NODE_COLOR),
        )]

    # Hide every axis decoration so only the vessel network is visible.
    axis = dict(showbackground=False, showline=False, zeroline=False, showgrid=False,
                showticklabels=False, title='')
    layout = go.Layout(
        title="Vascular Networks",
        width=650,
        height=625,
        showlegend=False,
        scene=dict(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis)),
        margin=dict(t=100),
        hovermode='closest',
    )

    return go.Figure(data=[*trace_edges, *trace_nodes], layout=layout)


In [3]:
# Model hyper-parameters. HIDDEN_SIZE, NUM_LAYERS and NUM_FEATURES are passed to
# GraphEncoderRNN / GraphDecoderRNN when the models are constructed.
HIDDEN_SIZE = 512
NUM_LAYERS = 4
NUM_FEATURES = 202  # passed as `coordinates_range`; presumably the number of discrete coordinate values — TODO confirm
MAX_PATHS_FOR_EACH_REACHABLE_NODE = 2
MAX_NUM_INPUT_PATHS = 4
MAX_NUM_INPUT_NODES = 7
MAX_NUM_OUTPUT_NODES = 10

# Checkpointing: the directory and the full file path get separate names rather
# than rebinding CHECKPOINT_SAVE_PATH from a directory to a file.
CHECKPOINT_SAVE_DIR = '.'
CHECKPOINT_SAVE_NAME = 'checkpoint.pt'
CHECKPOINT_SAVE_PATH = os.path.join(CHECKPOINT_SAVE_DIR, CHECKPOINT_SAVE_NAME)

In [4]:
SEED = 42  # single source of truth for every RNG seeded below

torch.manual_seed(SEED)  # torch RNG (torch docs: seeds all devices)
random.seed(SEED)        # Python stdlib RNG
np.random.seed(SEED)     # NumPy legacy global RNG

In [None]:
from sgg.model import GraphEncoderRNN, GraphDecoderRNN

# Encoder and decoder share one architecture configuration, so it is spelled out
# once. n_dimensions=3 presumably matches the (x, y, z) node positions — TODO confirm.
_rnn_config = {
    'n_dimensions': 3,
    'hidden_size': HIDDEN_SIZE,
    'num_layers': NUM_LAYERS,
    'coordinates_range': NUM_FEATURES,
}
encoder_rnn = GraphEncoderRNN(**_rnn_config)
decoder_rnn = GraphDecoderRNN(**_rnn_config)