In [105]:
import networkx as nx
import matplotlib.pyplot as plt

In [106]:
def node_label(position, cf, ctp):
    """Build a node label of the form ``"<position>:<cf>,<ctp>"``.

    Each argument is converted with ``str()`` before being joined, e.g.
    ``node_label(0, 1, 2) == "0:1,2"``.
    """
    return f"{position!s}:{cf!s},{ctp!s}"

## Pruning Functions

In [107]:
# remove all nodes unreachable from a first node

def remove_unreachable_nodes(graph, start_index=0):
    """Remove every node that cannot be reached from a starting node.

    A node's position index is the integer before the first ``':'`` in its
    label (the ``"<index>:..."`` form produced by ``node_label``). Every node
    whose index equals ``start_index`` is a starting node. All nodes reachable
    from any starting node (following edge direction in a DiGraph) are kept;
    every other node is removed.

    Args:
        graph (nx.Graph or nx.DiGraph): Graph whose nodes are label strings
            beginning with ``"<index>:"``. Modified in place.
        start_index (int): Position index that identifies the starting nodes.
            Defaults to 0.

    Returns:
        nx.Graph or nx.DiGraph: The same ``graph`` object, after pruning.

    Raises:
        ValueError: If a node label's prefix before ``':'`` is not an integer.
    """
    start_nodes = [
        node for node in graph if int(node.split(':')[0]) == start_index
    ]

    # One traversal seeded with every start node at once, instead of a
    # separate nx.descendants() call per start node (which re-walks any
    # subgraph shared between start nodes).
    reachable = set(start_nodes)
    frontier = list(start_nodes)
    while frontier:
        current = frontier.pop()
        # neighbors() yields successors for a DiGraph and all adjacent nodes
        # for an undirected Graph, matching nx.descendants on both.
        for neighbor in graph.neighbors(current):
            if neighbor not in reachable:
                reachable.add(neighbor)
                frontier.append(neighbor)

    # Materialize the list first so the graph is not mutated while iterating.
    graph.remove_nodes_from([node for node in graph if node not in reachable])
    return graph

In [108]:
# remove nodes not on a path to a final node

def remove_unreaching_nodes(graph, final_index):
    """Remove every node that has no path to a final node.

    A node's position index is the integer before the first ``':'`` in its
    label. Every node whose index equals ``final_index`` is a final node.
    A node is kept if it is a final node itself or has a path to at least one
    final node (following edge direction in a DiGraph). All other nodes are
    removed.

    Args:
        graph (nx.Graph or nx.DiGraph): Graph whose nodes are label strings
            beginning with ``"<index>:"``. Modified in place.
        final_index (int): Position index that identifies the final nodes.

    Returns:
        nx.Graph or nx.DiGraph: The same ``graph`` object, after pruning.

    Raises:
        ValueError: If a node label's prefix before ``':'`` is not an integer.
    """
    final_nodes = [
        node for node in graph if int(node.split(':')[0]) == final_index
    ]

    # Walk edges backwards from all final nodes in a single traversal. Every
    # node reached this way has a path to some final node. This replaces a
    # per-node, per-final-node nx.descendants() check, and it uses exact
    # node identity rather than a substring test (which wrongly matched
    # e.g. "0:5,6" against "10:5,6").
    backward = graph.predecessors if graph.is_directed() else graph.neighbors
    on_path = set(final_nodes)
    frontier = list(final_nodes)
    while frontier:
        current = frontier.pop()
        for parent in backward(current):
            if parent not in on_path:
                on_path.add(parent)
                frontier.append(parent)

    # Materialize the list first so the graph is not mutated while iterating.
    graph.remove_nodes_from([node for node in graph if node not in on_path])
    return graph


In [110]:
def make_network(nodes, edges, final_index, willPrint=True):
    """Build a directed graph and prune it with the two pruning helpers.

    Args:
        nodes: Node labels, e.g. ``[a, b, c]``. The pruning helpers parse the
            integer before the first ``':'`` of each label.
        edges: ``(source, target)`` pairs, e.g. ``[(a, b), (b, c)]``.
        final_index (int): Position index of the terminal nodes, forwarded
            to ``remove_unreaching_nodes``.
        willPrint (bool): If true (the default), print the remaining nodes
            and then the remaining edges.

    Returns:
        nx.DiGraph: The pruned graph.
    """
    network = nx.DiGraph()
    network.add_nodes_from(nodes)
    network.add_edges_from(edges)

    # First drop nodes not reachable from the start nodes, then drop nodes
    # that cannot reach a node at final_index.
    network = remove_unreaching_nodes(
        remove_unreachable_nodes(network), final_index
    )

    if willPrint:
        for view in (network.nodes, network.edges):
            print(view)

    return network

In [111]:
def visualize_graph(G, figsize=(18, 18)):
    """Draw ``G`` with each node positioned by the numbers in its label.

    Labels are expected in the ``"<index>:<cf>,<ctp>"`` form built by
    ``node_label``. Each node is drawn at x = index, y = ctp.

    Args:
        G (nx.Graph or nx.DiGraph): Graph whose nodes are label strings.
        figsize (tuple[float, float]): Figure size in inches. Defaults to
            (18, 18), the value the previous ``(2^16, 2^16)`` actually
            produced, since ``^`` is bitwise XOR in Python.

    Raises:
        ValueError: If the index or ctp field is not an integer.
        IndexError: If a label contains no ``','``.
    """
    # Every node gets an explicit position, so no spring layout is needed
    # (the old spring_layout result was overwritten for every node).
    pos = {}
    for node in G.nodes:
        index = node.split(':')[0]
        ctp = node.split(',')[1]
        pos[node] = (int(index), int(ctp))

    # TODO create dictionary of y-axis positions for pitch classes

    # Create the figure *before* drawing: nx.draw renders onto the current
    # axes, so a figure created afterwards would stay empty.
    plt.figure(figsize=figsize)
    nx.draw(G, pos, with_labels=True, arrows=True)