In [1]:
# Install networkx with pip, invoked through the interpreter that is running
# this kernel (sys.executable), so the package lands in the kernel's own
# environment rather than whichever `pip` happens to be first on PATH.
# NOTE(review): the version is not pinned; consider pinning for reproducibility.
import sys
!{sys.executable} -m pip install networkx



In [2]:
import networkx as nx
import unittest as ut

In [36]:
def n_distance_neighbors(G, start_node, distance):
  """
  Return every vertex of ``G`` that lies at most ``distance`` edges away from
  ``start_node``, where the distance between two vertices is the number of
  edges on a shortest path between them.

  The shortest-path lengths are obtained from
  ``nx.single_source_shortest_path_length`` with ``distance`` passed as its
  cutoff; each vertex whose reported length does not exceed ``distance`` is
  kept. The starting vertex itself (length 0) is part of the result.

  :param G: A networkx graph.
  :param start_node: The vertex distances are measured from. Must be in G.
  :param distance: Maximum number of edges allowed between ``start_node`` and
                   a returned vertex.
  :return: The vertices within the given distance, in the order the
           shortest-path lookup yields them.
  :rtype: List
  """
  hop_counts = nx.single_source_shortest_path_length(G, start_node, cutoff=distance)
  # Keep only vertices whose hop count is within the requested radius.
  return [vertex for vertex in hop_counts if hop_counts[vertex] <= distance]


In [37]:
class TestNeighbors(ut.TestCase):
  
  def test_distance_two_simple_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,4),(2,5),(5,6),(6,7)])
    result = n_distance_neighbors(G, 1, 2)
    answer = [1,2,3,4,5]
    self.assertEqual(result, answer)

  def test_distance_three_simple_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,4),(2,5),(5,6),(6,7)])
    result = n_distance_neighbors(G, 1, 3)
    answer = [1,2,3,4,5,6]
    self.assertEqual(result, answer)

  def test_distance_three_simple_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,4),(2,5),(5,6),(6,7)])
    result = n_distance_neighbors(G, 1, 4)
    answer = [1,2,3,4,5,6,7]
    self.assertEqual(result, answer)

  def test_distance_three_simple_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,4),(2,5),(5,6),(6,7)])
    result = n_distance_neighbors(G, 1, 4)
    answer = [1,2,3,4,5,6,7]
    self.assertEqual(result, answer)

  def test_distance_two_complex_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,9), (2,3), (2,4), (2,5), (2,8), (3,6), 
                      (4,5), (4,12), (4,13),(6,7), (7,8), (9,10), (10,11), (11,12),
                      (13,14), (14,15), (14,17),(15,16) ])
    result = n_distance_neighbors(G, 1, 2)
    result.sort()
    answer = [1,2,3,4,5,6,8,9, 10]
    self.assertEqual(result, answer)

  def test_distance_three_complex_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,9), (2,3), (2,4), (2,5), (2,8), (3,6), 
                      (4,5), (4,12), (4,13),(6,7), (7,8), (9,10), (10,11), (11,12),
                      (13,14), (14,15), (14,17),(15,16) ])
    result = n_distance_neighbors(G, 1, 3)
    result.sort()
    answer = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
    self.assertEqual(result, answer)

  def test_distance_four_complex_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,9), (2,3), (2,4), (2,5), (2,8), (3,6), 
                      (4,5), (4,12), (4,13),(6,7), (7,8), (9,10), (10,11), (11,12),
                      (13,14), (14,15), (14,17),(15,16) ])
    result = n_distance_neighbors(G, 1, 4)
    result.sort()
    answer = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,14]
    self.assertEqual(result, answer)

  def test_distance_five_complex_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,9), (2,3), (2,4), (2,5), (2,8), (3,6), 
                      (4,5), (4,12), (4,13),(6,7), (7,8), (9,10), (10,11), (11,12),
                      (13,14), (14,15), (14,17),(15,16) ])
    result = n_distance_neighbors(G, 1, 5)
    result.sort()
    answer = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]
    self.assertEqual(result, answer)

  def test_distance_six_complex_graph(self):
    G = nx.Graph()
    G.add_edges_from([(1, 2), (1, 3),(1,9), (2,3), (2,4), (2,5), (2,8), (3,6), 
                      (4,5), (4,12), (4,13),(6,7), (7,8), (9,10), (10,11), (11,12),
                      (13,14), (14,15), (14,17),(15,16) ])
    result = n_distance_neighbors(G, 1, 6)
    result.sort()
    answer = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    self.assertEqual(result, answer)

# Run the tests from inside the notebook. argv=[''] supplies a placeholder
# program name so unittest does not try to parse the kernel's own command-line
# arguments, and exit=False stops it from calling sys.exit() after the run.
ut.main(argv=[''], exit=False)

.......
----------------------------------------------------------------------
Ran 7 tests in 0.008s

OK


<unittest.main.TestProgram at 0x7fe7c4594860>