In [1]:
# Sample puzzle text: the first line is the string of L/R instructions and, after a
# blank line, each remaining line describes one node as "NODE = (LEFT, RIGHT)".
test_input = """RL

AAA = (BBB, CCC)
BBB = (DDD, EEE)
CCC = (ZZZ, GGG)
DDD = (DDD, DDD)
EEE = (EEE, EEE)
GGG = (GGG, GGG)
ZZZ = (ZZZ, ZZZ)"""

# Second sample in the same format; this is the one parsed by the next cell.
test_input_2 = """LLR

AAA = (BBB, BBB)
BBB = (AAA, ZZZ)
ZZZ = (ZZZ, ZZZ)"""

In [2]:
# Each direction letter becomes the tuple index used to pick a node's neighbour.
lr_index_map = dict(L=0, R=1)

# The instruction string is the first line of the puzzle text.
lines = test_input_2.split("\n")
instructions = [lr_index_map[letter] for letter in lines[0].strip()]

In [3]:
# Build {node: (left, right)} from the "NODE = (LEFT, RIGHT)" rows, which start on
# the third line. The slice [1:-1] drops the surrounding parentheses.
nodes = {
    start: tuple(mapping_string[1:-1].split(", "))
    for start, mapping_string in (line.strip().split(" = ") for line in lines[2:])
}

nodes

{'AAA': ('BBB', 'BBB'), 'BBB': ('AAA', 'ZZZ'), 'ZZZ': ('ZZZ', 'ZZZ')}

In [4]:
from itertools import cycle

In [5]:
# Follow the instructions from AAA, repeating them as often as needed, until ZZZ
# is reached or the step cap is hit.
instruction_loop = cycle(instructions)

current_node = 'AAA'
actions_limit = 1000  # safety net so a network that never reaches ZZZ cannot hang the cell
actions_taken = 0

while current_node != 'ZZZ' and actions_taken < actions_limit:
    next_instruction = next(instruction_loop)
    current_node = nodes[current_node][next_instruction]
    actions_taken += 1

if current_node == 'ZZZ':
    print(f"Took {actions_taken} actions")
else:
    # Hitting the cap is not an answer, so report it as a failure instead.
    print(f"Gave up after {actions_taken} actions without reaching ZZZ")

Took 6 actions


In [6]:
# Third sample: several start nodes (ending in "A") that are walked simultaneously.
test_input_3 = """LR

11A = (11B, XXX)
11B = (XXX, 11Z)
11Z = (11B, XXX)
22A = (22B, XXX)
22B = (22C, 22C)
22C = (22Z, 22Z)
22Z = (22B, 22B)
XXX = (XXX, XXX)"""

lines = test_input_3.split("\n")

# First line: instruction letters mapped to tuple indices.
instructions = [lr_index_map[letter] for letter in lines[0].strip()]

# Third line onwards: "NODE = (LEFT, RIGHT)" rows; [1:-1] drops the parentheses.
nodes = {
    start: tuple(mapping_string[1:-1].split(", "))
    for start, mapping_string in (line.strip().split(" = ") for line in lines[2:])
}

nodes

{'11A': ('11B', 'XXX'),
 '11B': ('XXX', '11Z'),
 '11Z': ('11B', 'XXX'),
 '22A': ('22B', 'XXX'),
 '22B': ('22C', '22C'),
 '22C': ('22Z', '22Z'),
 '22Z': ('22B', '22B'),
 'XXX': ('XXX', 'XXX')}

In [7]:
# Every node whose label ends in "A" is one of the simultaneous starting positions.
c_state = [label for label in nodes if label[-1:] == "A"]

In [8]:
def is_finished(state):
    """Return True when every label in ``state`` ends with "Z".

    An empty ``state`` counts as finished.
    """
    return all(element.endswith("Z") for element in state)

In [9]:
# Advance every path by the same instruction at each step until all of them sit
# on a node ending in "Z", or the step cap is hit.
instruction_loop = cycle(instructions)

# Rebuild the start state here so that re-running this cell gives the same count.
c_state = [node for node in nodes if node.endswith("A")]
actions_limit = 1000  # safety net so the cell cannot run forever
actions_taken = 0

while not is_finished(c_state) and actions_taken < actions_limit:
    next_instruction = next(instruction_loop)
    c_state = [nodes[element][next_instruction] for element in c_state]
    actions_taken += 1

if is_finished(c_state):
    print(f"Took {actions_taken} actions")
else:
    # Hitting the cap is not an answer, so report it as a failure instead.
    print(f"Gave up after {actions_taken} actions without every path being on a Z node")

Took 6 actions


In [10]:
# Show the parsed sample network together with the instruction indices.
nodes, instructions

({'11A': ('11B', 'XXX'),
  '11B': ('XXX', '11Z'),
  '11Z': ('11B', 'XXX'),
  '22A': ('22B', 'XXX'),
  '22B': ('22C', '22C'),
  '22C': ('22Z', '22Z'),
  '22Z': ('22B', '22B'),
  'XXX': ('XXX', 'XXX')},
 [0, 1])

In [11]:
# Trace ten steps starting from 22B to see how the walk repeats.
# A fresh iterator is used so the trace does not depend on how many instructions
# earlier cells happened to consume.
instruction_loop = cycle(instructions)
c_state = "22B"

for i in range(10):
    next_instruction = next(instruction_loop)
    c_state = nodes[c_state][next_instruction]
    print(i, c_state)

0 22C
1 22Z
2 22B
3 22C
4 22Z
5 22B
6 22C
7 22Z
8 22B
9 22C


Construct a wild graph

In [12]:
# Walk the combined (node, instruction index) state graph from every start node,
# recording each state's successor until an already-recorded state comes up again.
starting_nodes = [node for node in nodes if node.endswith("A")]

# NOTE(review): max_iter is assigned here but the walk below never consults it.
max_iter = 1_000_000_000

wild_mapping = {}

for starting_node in starting_nodes:
    c_node, seq_idx = starting_node, 0
    map_id = f"{c_node}_{seq_idx}"

    while map_id not in wild_mapping:
        direction = instructions[seq_idx]
        next_node = nodes[c_node][direction]

        # The instruction index wraps around, so a state is (node, position in instructions).
        next_seq_idx = (seq_idx + 1) % len(instructions)
        next_map_idx = f"{next_node}_{next_seq_idx}"
        wild_mapping[map_id] = next_map_idx

        # Advance to the successor state.
        c_node, seq_idx, map_id = next_node, next_seq_idx, next_map_idx

wild_mapping

{'11A_0': '11B_1',
 '11B_1': '11Z_0',
 '11Z_0': '11B_1',
 '22A_0': '22B_1',
 '22B_1': '22C_0',
 '22C_0': '22Z_1',
 '22Z_1': '22B_0',
 '22B_0': '22C_1',
 '22C_1': '22Z_0',
 '22Z_0': '22B_1'}

In [13]:
def parse(filename):
    """Read a puzzle file and return ``(instructions, nodes)``.

    The first line holds the instruction string of ``L``/``R`` characters; node
    rows of the form ``AAA = (BBB, CCC)`` start on the third line. Blank lines
    among or after the node rows are ignored.

    Args:
        filename: Path of the UTF-8 text file to read.

    Returns:
        A tuple ``(instructions, nodes)``: ``instructions`` is a list of tuple
        indices (0 for ``L``, 1 for ``R``) and ``nodes`` maps each node label to
        its ``(left, right)`` pair of neighbour labels.

    Raises:
        KeyError: If the instruction line contains a character other than L or R.
        ValueError: If a non-blank node row does not split into exactly two
            parts around " = ".
    """
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()

    lr_index_map = {"L": 0, "R": 1}

    instructions = [lr_index_map[c] for c in lines[0].strip()]

    nodes = {}

    for line in lines[2:]:
        line = line.strip()
        if not line:
            # A trailing empty line is common in input files and carries no node.
            continue

        start, mapping_string = line.split(" = ")
        # [1:-1] drops the surrounding parentheses before splitting the pair.
        nodes[start] = tuple(mapping_string[1:-1].split(", "))

    return instructions, nodes

In [16]:
# Load the full puzzle input; this rebinds the `instructions` and `nodes` names
# that previously held the sample data.
instructions, nodes = parse("input.txt")

Nice, this looks feasible. We know that for each starting point, we'll traverse some distance and then join a loop. We can now go about finding the lengths of each of the loops.

In [17]:
# For every start node, walk the (node, instruction index) states until a state
# that is already in wild_mapping comes up, then record the loop that was closed:
# its member states, its length and how many states preceded it.
starting_nodes = [node for node in nodes if node.endswith("A")]

max_iter = 1_000_000_000
# n_iter is initialised once, so it counts steps across ALL start nodes together.
n_iter = 0

wild_mapping = dict()
loops = []
loop_lengths = []
distance_to_loop = []
# Flat list: one entry per visit to a node ending in "Z", over all walks.
# NOTE(review): later cells index this by start-node position, which only lines up
# if every walk records exactly one such visit -- confirm for the input at hand.
end_idxs = []

for starting_node in starting_nodes:
    seq_idx = 0
    c_node = starting_node
    # States visited by this walk, in order; used to locate where the loop begins.
    nodes_seen = []
        
    while True:
        direction = instructions[seq_idx]
        next_node = nodes[c_node][direction]
        
        # A state is the node label plus the position in the instruction list.
        map_id = f"{c_node}_{seq_idx}"
        
        if map_id in wild_mapping:
            # We've found a loop, so we need to figure out what's in it and how long it is.
            if map_id not in nodes_seen:
                # The state was recorded by an earlier walk. Nothing is appended to
                # loops / loop_lengths / distance_to_loop in this case, so those lists
                # end up shorter than starting_nodes.
                print("Found existing loop")
            else:
#                 print(f"New loop starting at {map_id}", nodes_seen)
                start_of_loop = nodes_seen.index(map_id)
                # The slice already creates a new list; .copy() makes a second copy.
                loop = nodes_seen[start_of_loop:].copy()
                
                loops.append(loop)
                loop_lengths.append(len(loop))
                # Number of states visited before the first state of the loop.
                distance_to_loop.append(len(nodes_seen) - len(loop))
            break
        
        nodes_seen.append(map_id)
        next_seq_idx = (seq_idx + 1) % len(instructions)
        next_map_idx = f"{next_node}_{next_seq_idx}"
        wild_mapping[map_id] = next_map_idx
        
        # Check if we've got an end; the stored value is this state's index in nodes_seen.
        if c_node.endswith("Z"):
            end_idxs.append(len(nodes_seen) - 1) 
        
        c_node = next_node
        
        seq_idx = next_seq_idx
        
        # Global step cap: only leaves the inner loop, the outer loop carries on.
        n_iter += 1
        if n_iter >= max_iter:
            break
        
print(f"Total map contains {len(wild_mapping)} elements")

Total map contains 107331 elements


So we might come across the end node twice in a single loop. There's definitely a simpler solution to this! But I'll forge on.

Here we have two loops, one which starts at the second element, and then sees the end nodes every 2 iterations. The second also starts on the second element, and sees the end node every 3 iterations (is it always true that even though there are two ends we'll see them equally spaced?)

In [18]:
def get_end_points(loop):
    """Return the positions in ``loop`` whose node label ends with "Z".

    Args:
        loop: Iterable of state ids shaped like ``"<node>_<instruction index>"``.
            A bare node label without the ``_<index>`` suffix is accepted too.

    Returns:
        Set of integer positions within ``loop``.
    """
    # Test the label itself (everything before the final "_") rather than a fixed
    # character position, so labels of any length are handled.
    return {
        i
        for i, node_id in enumerate(loop)
        if node_id.rsplit("_", 1)[0].endswith("Z")
    }
# Where does the end node sit inside the fourth detected loop?
get_end_points(loops[3])

{16529}

In [19]:
# Number of states each walk visited before reaching the first state of its loop.
distance_to_loop

[2, 2, 5, 2, 2, 2]

In [20]:
# Number of (node, instruction index) states in each detected loop.
loop_lengths

[21409, 14363, 15989, 16531, 19241, 19783]

In [21]:
# Index within its walk of every visit to a node ending in "Z", in discovery order.
end_idxs

[21409, 14363, 15989, 16531, 19241, 19783]

So what do we do with this information now? When will all of these loops synchronise? It seems that for 5 of the 6 loops the "end" is two elements before the end of the loop, and for the other it's 5 elements before.

In [22]:
from functools import reduce

That's a lot of different possible combinations! We can solve this by taking one loop to be our "anchor", and going around that until each of the other loops is aligned. To find out how many times we should go around each loop we need to find the difference in length between the loops, and calculate how often the two are aligned.

The misalignment between loops is the distance to the end node on the other loop when we first encounter the end on our main loop.

The nice thing about this approach is that we can check our working by running the same algorithm using each loop as our main loop.

In [23]:
# Index of the loop used as the reference ("anchor") for all comparisons.
MAIN_LOOP = 0

# Every loop index except the anchor, in ascending order.
other_loops = [i for i in range(len(starting_nodes)) if i != MAIN_LOOP]

# How much longer the anchor loop is than each other loop, and how much later
# the anchor's recorded end index is than each other loop's.
difference_per_loop = [loop_lengths[MAIN_LOOP] - loop_lengths[i] for i in other_loops]
offsets = [end_idxs[MAIN_LOOP] - end_idxs[i] for i in other_loops]

print(difference_per_loop)
print(offsets)

[7046, 5420, 4878, 2168, 1626]
[7046, 5420, 4878, 2168, 1626]


Based on these offsets and difference per loops, which happen to be the same number, we need to find out how many times we should go round the main loop to make all of them modulo their loop lengths = 0.

In [24]:
import math

In [26]:
# Small worked example used to sanity-check the helper below.
offset = 2
loop_length = 21
diff_per_loop = 17

def find_times_round_loop(offset, loop_length, diff_per_loop):
    """Count the rounds needed for ``offset`` to become 0 modulo ``loop_length``.

    Each round adds ``diff_per_loop`` to ``offset`` and reduces the result modulo
    ``loop_length``. Counting stops as soon as the offset is 0.

    Args:
        offset: Starting misalignment (integer).
        loop_length: Modulus applied after every round (integer).
        diff_per_loop: Amount added to the offset on every round (integer).

    Returns:
        The number of rounds performed; 0 if ``offset`` is already 0.

    Raises:
        ValueError: If no number of rounds can bring the offset to 0.
        ZeroDivisionError: If ``loop_length`` is 0 and ``offset`` is not.
    """
    times_round = 0
    # For integers the sequence of offsets repeats after at most |loop_length|
    # rounds, so not reaching 0 by then means it is never reached.
    max_rounds = abs(loop_length)

    while offset != 0:
        offset = (offset + diff_per_loop) % loop_length
        times_round += 1
        if offset != 0 and times_round >= max_rounds:
            raise ValueError(
                f"offset never reaches 0 (loop_length={loop_length}, "
                f"diff_per_loop={diff_per_loop})"
            )
    return times_round

print(find_times_round_loop(offset, loop_length, diff_per_loop))

11


We need to find out how many times we need to go round each loop to overcome the offset. It is also convenient that going round one more time will bring each of them back to the original position.

In [27]:
# For every loop other than the anchor: how many rounds until its offset against
# the anchor becomes 0 modulo its own length.
times_round_loop = [
    find_times_round_loop(
        end_idxs[MAIN_LOOP] - end_idxs[i],
        loop_lengths[i],
        loop_lengths[MAIN_LOOP] - loop_lengths[i],
    )
    for i in range(len(starting_nodes))
    if i != MAIN_LOOP
]

print(times_round_loop)

[52, 58, 60, 70, 72]


These look like some nice numbers - I did forget that we actually want to find out how many times we should go round the loop to get back to where we started, so I need to add 1 onto each of these outputs.

We can then see how many times we'd have to go round the main loop to return each of the loops back to where they started - effectively this shows us how long the main sequence is.

Now comes some convenient trickery. To overcome the offset of each loop, we need to go round one fewer times. However, we also go loop_length steps to get to our starting position. These two cancel each other out, and this means our final answer will be the length of the looping sequence * length of the main loop (loop 0 in this case).

We can find the length of the loop sequence as the lowest common multiple of the number of times you'd need to go round each loop to get back to where you started.

In [31]:
# One extra round per loop brings it back to where it started; the least common
# multiple of those counts is how many anchor rounds realign every loop at once.
rounds_to_realign = [rounds + 1 for rounds in times_round_loop]
loop_sequence_length = math.lcm(*rounds_to_realign)
print(f"Loop sequence length: {loop_sequence_length}")

Loop sequence length: 988641701


In [33]:
# Total steps = number of anchor rounds needed x number of steps per anchor round.
main_loop_length = loop_lengths[MAIN_LOOP]
n_steps = main_loop_length * loop_sequence_length
print(f"All camels will be on end nodes after {n_steps} steps")

All camels will be on end nodes after 21165830176709 steps


Phew that was a hard one!