In [6]:
from typing import Tuple
from typing import List
from collections import defaultdict
import json
import math

class AdjacencyTuples(object):
  """Directed graph in which every node has one left and one right successor.

  Built from lines of the form ``KEY = (LEFT, RIGHT)`` via `Parse`.
  """

  def __init__(self):
    # Node name -> (left successor, right successor).
    self._adjacencyMap = {}

  def Add(self, key: str, value: Tuple[str, str]):
    """Register (or overwrite) the (left, right) successors of `key`."""
    self._adjacencyMap[key] = value

  def Next(self, current: str, direction: str) -> str:
    """Return the successor of `current`: left for 'L', right for anything else.

    Raises:
      KeyError: if `current` was never added.
    """
    direction_index = 0 if direction == 'L' else 1
    return self._adjacencyMap[current][direction_index]

  def GetKeysEndingWithA(self) -> List[str]:
    """Return the node names ending in 'A', in insertion order."""
    return [k for k in self._adjacencyMap if k.endswith('A')]

  def GetKeys(self) -> List[str]:
    """Return all node names in insertion order."""
    return list(self._adjacencyMap)

  @staticmethod
  def Parse(lines: List[str]) -> 'AdjacencyTuples':
    """Build an AdjacencyTuples from ``KEY = (LEFT, RIGHT)`` lines.

    Blank or whitespace-only lines (e.g. a trailing empty line at the end of
    the input file) are skipped instead of crashing the parser.

    Raises:
      ValueError: if a non-blank line does not contain exactly one '='.
      IndexError: if the parenthesised part has no ',' separator.
    """
    adjacency_tuples = AdjacencyTuples()
    for line in lines:
      if not line.strip():
        continue
      key_str, value_str = line.split('=')
      # Drop the surrounding "(" and ")" and split into LEFT / RIGHT.
      values = value_str.strip()[1:-1].split(',')
      left = values[0].strip()
      right = values[1].strip()
      adjacency_tuples.Add(key_str.strip(), (left, right))
    return adjacency_tuples
      

# open text file
def ReadFile(filename: str):
  """Return the lines of `filename` as a list, trailing newlines kept."""
  with open(filename, 'r') as handle:
    return handle.readlines()

# parse text file
def ParseFile(lines: List[str]):
  """Split raw puzzle lines into (instruction string, AdjacencyTuples).

  Line 0 holds the L/R instructions, line 1 is skipped, and every line
  after that is handed to `AdjacencyTuples.Parse` as a node definition.
  """
  instructions = lines[0].strip()
  node_lines = lines[2:]
  return instructions, AdjacencyTuples.Parse(node_lines)

def SolvePartOne(filename: str):
  """Count the steps needed to walk from 'AAA' to 'ZZZ'.

  The L/R instruction string is followed cyclically, one instruction per step.
  """
  directions, adjacency_tuples = ParseFile(ReadFile(filename))
  node = 'AAA'
  position = 0
  step_count = 0
  while node != 'ZZZ':
    node = adjacency_tuples.Next(node, directions[position])
    step_count += 1
    # Wrap around to the start of the instructions when they run out.
    position = (position + 1) % len(directions)
  return step_count

def SolvePartTwo(filename: str):
  """Return the first step at which every '..A' start is on a '..Z' node.

  All starts move simultaneously, one instruction per step, cycling through
  the instruction string. For every node the search keeps the set of step
  counts (1-based, inside a window of `current_steps` steps, always a
  multiple of the instruction length) at which the walk from that node is on
  a Z node. Each iteration doubles the window by composing the window's
  end-node map with itself.

  Once every start has hit Z at least once:
    * if the starts share a Z step inside the window, the smallest shared
      step is returned (this is exact);
    * otherwise the LCM of the first-hit steps is returned. The LCM is only
      correct under an assumption that holds for the puzzle inputs but not in
      general: each start enters a Z cycle whose length equals its first-hit
      step, and each start follows its own path.

  If there are no starts, this returns math.lcm() == 1, as before.

  Raises:
    ValueError: if some start can never reach a Z node.
  """
  lines = ReadFile(filename)
  directions, adjacency_tuples = ParseFile(lines)
  keys = adjacency_tuples.GetKeys()
  round_length = len(directions)

  # Walk one full pass of the instructions from every node at once and
  # record at which step (1-based) each walk stands on a Z node.
  currents = list(keys)
  z_indices = defaultdict(set)
  for idx, direction in enumerate(directions):
    currents = [adjacency_tuples.Next(c, direction) for c in currents]
    for start, node in zip(keys, currents):
      if node.endswith('Z'):
        z_indices[start].add(idx + 1)

  # Window state: node reached from each key after `current_steps` steps, and
  # the Z-hit steps within that window.
  current_map = dict(zip(keys, currents))
  current_z_indices = z_indices
  current_steps = round_length

  starts = adjacency_tuples.GetKeysEndingWithA()
  # After each full pass the walk is on one of len(keys) nodes, so a start
  # that has not hit Z within len(keys) passes never will. This also stops
  # the search when `directions` is empty (the window would never grow).
  max_steps = len(keys) * round_length
  first_time_hitting_z = {}
  while True:
    for k in starts:
      # Windows only ever grow as prefixes, so the first non-empty set's
      # minimum is the global first hit.
      if k not in first_time_hitting_z and current_z_indices[k]:
        first_time_hitting_z[k] = min(current_z_indices[k])
    if len(first_time_hitting_z) == len(starts):
      # Prefer the exact answer when all starts coincide inside the window.
      if starts:
        common_z_indices = set.intersection(
            *(current_z_indices[k] for k in starts))
        if common_z_indices:
          return min(common_z_indices)
      return math.lcm(*first_time_hitting_z.values())
    if current_steps >= max_steps:
      missing = [k for k in starts if k not in first_time_hitting_z]
      raise ValueError(f'start node(s) {missing} never reach a Z node')
    # Double the window: hits in the second half are the hits of the node
    # reached at the end of the first half, shifted by `current_steps`.
    current_z_indices = defaultdict(set, {
        k: current_z_indices[k]
        | {v + current_steps for v in current_z_indices[current_map[k]]}
        for k in keys
    })
    current_map = {k: current_map[v] for k, v in current_map.items()}
    current_steps *= 2


# Part one: both worked examples must pass before the real input is solved.
# The final assert pins the accepted answer.
assert all(SolvePartOne(sample) == expected
           for sample, expected in (('sample.txt', 2), ('sample2.txt', 6)))
part_one_solution = SolvePartOne('input.txt')
print('Part One Solution:', part_one_solution)
assert part_one_solution == 17621

# Part two: run the worked example as a smoke test (its result is not
# checked), then solve the real input and pin the accepted answer.
SolvePartTwo('sample3.txt')
part_two_solution = SolvePartTwo('input.txt')
print('Part Two Solution:', part_two_solution)
assert part_two_solution == 20685524831999

Part One Solution: 17621
Part Two Solution: 20685524831999
