In [1]:
import collections

In [2]:
with open('day6.input') as fp:
    lines = fp.readlines()
# Strip newlines and skip blank lines (e.g. a trailing empty line); an empty
# entry would otherwise break the 'A)B' unpacking in connections().
data = [line.strip() for line in lines if line.strip()]
data[:5]

['BYZ)LMV', '2CT)GV2', '6RK)HK7', 'RJJ)MVV', 'YFQ)4LC']

In [3]:
testdata = ['COM)B', 'B)C', 'C)D', 'D)E', 'E)F', 'B)G',
            'G)H', 'D)I', 'E)J', 'J)K', 'K)L']
testdata

['COM)B', 'B)C', 'C)D', 'D)E', 'E)F', 'B)G', 'G)H', 'D)I', 'E)J', 'J)K', 'K)L']

## Part 1: total number of direct and indirect orbits ##

In [4]:
def connections(orbits):
    """Build the orbit tree from 'A)B' terms, where B directly orbits A.

    Returns (root, children): root is the single body that orbits nothing,
    and children is a defaultdict(list) mapping each body to the list of
    bodies directly orbiting it, in input order.

    Raises ValueError if no root exists (empty input or a cyclic graph) or
    if more than one body has no parent.
    """
    children = collections.defaultdict(list)
    for term in orbits:
        parent, child = term.split(')')
        children[parent].append(child)
    orbiters = {child for kids in children.values() for child in kids}
    roots = set(children) - orbiters
    if not roots:
        # Previously reported as "Multiple roots", which was misleading here.
        raise ValueError('No root of the tree: input is empty or cyclic')
    if len(roots) > 1:
        raise ValueError(f'Multiple roots of the tree: {roots}')
    (root,) = roots
    return root, children

In [5]:
root, edges = connections(orbits=testdata)
(root, edges)

('COM',
 defaultdict(list,
             {'COM': ['B'],
              'B': ['C', 'G'],
              'C': ['D'],
              'D': ['E', 'I'],
              'E': ['F', 'J'],
              'G': ['H'],
              'J': ['K'],
              'K': ['L']}))

In [6]:
def walk(root, edges, prn=False):
    """Walk the orbit tree breadth-first, one depth level per step.

    For depth d = 1, 2, ... yields d times the number of bodies found at
    depth d below root. Summing the yielded values therefore gives the total
    number of direct and indirect orbits. The walk stops after yielding for
    the first empty level, which contributes 0.

    root  -- body to start from (depth 0).
    edges -- mapping of body -> list of bodies directly orbiting it; bodies
             absent from the mapping are treated as leaves.
    prn   -- if true, print each depth and the bodies found at it.
    """
    frontier, depth = [root], 0
    while frontier:
        depth += 1
        next_level = [child
                      for body in frontier if body in edges
                      for child in edges[body]]
        if prn:
            print(depth, next_level)
        yield depth * len(next_level)
        frontier = next_level

In [7]:
# Fail loudly under Restart & Run All; use the root computed by connections().
assert sum(walk(root, edges, True)) == 42

1 ['B']
2 ['C', 'G']
3 ['D', 'H']
4 ['E', 'I']
5 ['F', 'J']
6 ['K']
7 ['L']
8 []


True

In [8]:
(puzzle_root, puzzle_edges) = connections(orbits=data)
puzzle_root

'COM'

In [9]:
# Start from the root found by connections() rather than assuming 'COM'.
sum(walk(puzzle_root, puzzle_edges))

312697

## Part 2: minimum orbital transfers from YOU to SAN ##

In [10]:
testdata2 = ['COM)B', 'B)C', 'C)D', 'D)E', 'E)F', 'B)G', 'G)H', 'D)I', 'E)J', 'J)K', 'K)L', 'K)YOU', 'I)SAN']

In [11]:
def bidirectional_connections(orbits):
    """Build an undirected adjacency map from 'A)B' orbit terms.

    Returns a defaultdict(list) in which every body maps to the bodies it is
    directly linked to: the body it orbits (if any) and the bodies orbiting
    it, in the order the terms appear.
    """
    graph = collections.defaultdict(list)
    for center, satellite in (term.split(')') for term in orbits):
        graph[center].append(satellite)
        graph[satellite].append(center)
    return graph

In [12]:
test2conns = bidirectional_connections(orbits=testdata2)
test2conns

defaultdict(list,
            {'COM': ['B'],
             'B': ['COM', 'C', 'G'],
             'C': ['B', 'D'],
             'D': ['C', 'E', 'I'],
             'E': ['D', 'F', 'J'],
             'F': ['E'],
             'G': ['B', 'H'],
             'H': ['G'],
             'I': ['D', 'SAN'],
             'J': ['E', 'K'],
             'K': ['J', 'L', 'YOU'],
             'L': ['K'],
             'YOU': ['K'],
             'SAN': ['I']})

In [13]:
def remove_node(node, c):
    """Delete node, and every edge pointing at it, from adjacency map c in place.

    The neighbour lists that contained node are replaced with new filtered
    lists instead of being edited in place. A shallow copy of a graph
    (e.g. c = conns.copy()) can therefore be pruned without corrupting the
    lists it still shares with the original. Every occurrence of node is
    dropped, so duplicated edges are removed as well.

    Raises KeyError if node is not a key of c.
    """
    del c[node]
    for key in c:
        neighbours = c[key]
        if node in neighbours:
            # Rebinding an existing key does not resize the dict, so this is
            # safe during iteration.
            c[key] = [n for n in neighbours if n != node]

In [14]:
def endpoints(c):
    """Return (start, end): the first neighbour of 'YOU' and of 'SAN' in c.

    For a graph from bidirectional_connections, where nothing orbits YOU or
    SAN, these are the bodies that YOU and SAN orbit.
    Side effect: 'YOU' and 'SAN' are then removed from c (as keys and from
    every neighbour list) via remove_node.
    """
    start, end = (c[name][0] for name in ('YOU', 'SAN'))
    for name in ('YOU', 'SAN'):
        remove_node(name, c)
    return start, end

In [15]:
def transition(conns):
    """Return the minimum number of orbital transfers from the body 'YOU'
    orbits to the body 'SAN' orbits.

    conns is an undirected adjacency map as built by bidirectional_connections.
    A breadth-first search runs from YOU's first neighbour to SAN's first
    neighbour. The nodes 'YOU' and 'SAN' themselves are never traversed.

    conns is NOT modified. The previous version pruned a shallow copy, which
    also emptied the caller's neighbour lists, so a second call on the same
    graph could loop forever.

    Raises ValueError if 'YOU' or 'SAN' is missing, or if SAN's body cannot
    be reached from YOU's body.
    """
    # Use .get so a defaultdict argument does not grow new keys on lookup.
    if not conns.get('YOU') or not conns.get('SAN'):
        raise ValueError("both 'YOU' and 'SAN' must appear in the orbit map")
    start = conns['YOU'][0]
    end = conns['SAN'][0]
    # Seeding 'YOU' and 'SAN' as visited keeps the search from passing through them.
    seen = {'YOU', 'SAN', start}
    frontier = [start]
    dist = 0
    while end not in frontier:
        if not frontier:
            raise ValueError(f'{end!r} is unreachable from {start!r}')
        dist += 1
        next_frontier = []
        for node in frontier:
            for neighbour in conns.get(node, ()):
                if neighbour not in seen:
                    seen.add(neighbour)
                    next_frontier.append(neighbour)
        frontier = next_frontier
    return dist

In [16]:
# Fail loudly under Restart & Run All instead of silently displaying False.
assert transition(test2conns) == 4

True

In [17]:
puzzle_conns = bidirectional_connections(data)
transition(puzzle_conns)

466