In [89]:
%pip install yfiles_jupyter_graphs


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m22.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [90]:
from yfiles_jupyter_graphs import GraphWidget
import xml.etree.ElementTree as ET
import os 

In [91]:
# Namespace URIs of the XML vocabularies read by this notebook, keyed by a short prefix.
namespaces = dict(
  ODM="http://www.cdisc.org/ns/odm/v1.3",
  SDM="http://www.cdisc.org/ns/studydesign/v1.0",
)

def fully_qualified(ns, element):
  """Return ``element`` prefixed with ``ns`` wrapped in curly braces.

  For example ``fully_qualified("http://x", "Epoch")`` gives ``"{http://x}Epoch"``,
  the form that element tags are compared against elsewhere in this notebook.
  """
  return f"{{{ns}}}{element}"

def remove_ns(name):
  """Strip a known namespace prefix from ``name``.

  The URIs in ``namespaces`` are tried in order. For the first one that occurs
  anywhere in ``name``, every ``{uri}`` occurrence is removed and the result is
  returned. A name that mentions none of the known URIs is returned unchanged.
  """
  for uri in namespaces.values():
    if uri in name:
      return name.replace("{" + uri + "}", "")
  return name

# Qualified tags that process_child skips, together with everything below them.
ignore_nodes = [
  fully_qualified(namespaces[prefix], element)
  for prefix, element in (('SDM', "Summary"), ('ODM', "Description"), ('ODM', "ItemGroupDef"))
]
# NOTE(review): node_nodes and edge_nodes are not read by any code visible in this
# notebook; confirm whether they are still needed.
node_nodes = [fully_qualified(namespaces['SDM'], "Epoch")]
edge_nodes = [fully_qualified(namespaces['SDM'], "ActivityRef")]


In [92]:
# Graph accumulated by add_node/add_edge: the two element lists, plus the last id
# handed out for each kind (ids therefore start at 1).
nodes, edges = [], []
node_index = edge_index = 0

def add_node(label, properties):
  """Append a node to the module-level ``nodes`` list and return its id.

  Args:
    label: Display label. Stored under the ``'label'`` key of the node's
      properties, replacing any entry that already uses that key.
    properties: Mapping of node properties. It is copied, so the caller's
      mapping (for example an XML element's ``attrib``) is left untouched.

  Returns:
    The new node's id, i.e. ``node_index`` after incrementing it by one.
  """
  global node_index
  node_index += 1
  # Work on a copy: writing 'label' into the caller's dict would inject a
  # 'label' attribute into whatever object owns that dict.
  node_properties = dict(properties)
  node_properties['label'] = label
  nodes.append({'id': node_index, 'properties': node_properties})
  return node_index

def add_edge(label, start, end):
  """Append an edge to the module-level ``edges`` list and return its id.

  Args:
    label: Edge label, stored as ``properties['label']``.
    start: Id of the node the edge starts at.
    end: Id of the node the edge ends at.

  Returns:
    The new edge's id, i.e. ``edge_index`` after incrementing it by one.
  """
  global edge_index
  edge_index += 1
  # The key must be the string 'id'; a bare `id` would use the builtin function
  # object as the dictionary key and leave the edge without an 'id' entry.
  edges.append({'id': edge_index, 'start': start, 'end': end, 'properties': {'label': label}})
  return edge_index

In [93]:

def process_child(parent):
  """Recursively turn an XML element and its descendants into graph nodes and edges.

  A node is added for ``parent``, labelled with its tag after ``remove_ns``; the
  element's own ``attrib`` mapping (not a copy) is handed to ``add_node`` as the
  node properties. Each child whose tag is not in ``ignore_nodes`` is processed
  the same way and then linked to ``parent`` by a "child" edge. Ignored children
  are skipped along with everything beneath them.

  Returns:
    The id of the node created for ``parent``.
  """
  parent_index = add_node(remove_ns(parent.tag), parent.attrib)
  for child in parent:
    if child.tag in ignore_nodes:
      continue
    # The subtree is built first, so a child's own edges precede the edge to it.
    add_edge("child", parent_index, process_child(child))
  return parent_index

In [94]:
# Resolve the input file relative to the current working directory: abspath() of a
# bare file name is anchored there, so dirname(notebook_path) is that directory.
notebook_path = os.path.abspath("notebook.ipynb")
file_path = os.path.join(os.path.dirname(notebook_path), "source_data/lzzt_trial.xml")
tree = ET.parse(file_path)

root = tree.getroot()

# findall('.') selects the root element itself, so this walks the whole document from the top.
# NOTE(review): process_child appends to the module-level nodes/edges lists, so re-running
# this cell without first re-running the cell that resets them adds every node and edge again.
for item in root.findall('.'):
  process_child(item)

#print("NODES:", nodes)

In [95]:
def custom_node_color(index: int, node: dict):
  """Pick a fill colour for ``node`` from its ``'node_type'`` property.

  ``index`` is accepted but not used. A node without a ``'node_type'`` property,
  or with a type that is not listed below, is coloured white.
  """
  properties = node['properties']
  if 'node_type' not in properties:
    return 'white'

  node_type = properties['node_type']
  # Checked in order with ==, exactly like an if/elif chain.
  palette = (
    ('entry_exit', 'black'),
    ('anchor', '#999999'),
    ('condition', '#999999'),
    ('timepoint', '#555555'),
    ('visit', '#c1141a'),
    ('activity', '#1555bd'),
    ('bc', '#c0d6e4'),
  )
  for known_type, color in palette:
    if node_type == known_type:
      return color
  return 'white'

def custom_node_style(index: int, node: dict):
  """Pick a style dict (currently only ``'shape'``) for ``node`` from its ``'node_type'`` property.

  ``index`` is accepted but not used. A node without a ``'node_type'`` property,
  or with a type that is not listed below, is drawn as a circle. A new dict is
  returned on every call.
  """
  properties = node['properties']
  shape = 'circle'
  if 'node_type' in properties:
    node_type = properties['node_type']
    # Checked in order with ==, exactly like an if/elif chain.
    shapes = (
      ('entry_exit', 'round-rectangle'),
      ('anchor', 'hexagon2'),
      ('condition', 'diamond'),
      ('timepoint', 'ellipse'),
      ('visit', 'circle'),
      ('activity', 'circle'),
      ('bc', 'circle'),
    )
    for known_type, known_shape in shapes:
      if node_type == known_type:
        shape = known_shape
        break
  return { 'shape': shape }


In [96]:
# Build the graph widget from the node and edge lists collected above.
w = GraphWidget()
w.set_directed(True)

w.set_nodes(nodes)
w.set_edges(edges)

# Colour and shape are decided per node by the two mapping functions defined earlier.
w.set_node_color_mapping(custom_node_color)
w.set_node_styles_mapping(custom_node_style)
# Bare last expression so the notebook shows the widget as this cell's output.
w

GraphWidget(layout=Layout(height='500px', width='100%'))