# Advent of Code 2017 - Day 7

In [None]:
# Read the puzzle input, one program description per line.
# Blank lines (e.g. an extra empty line at the end of the file) are skipped,
# because parsing an empty line would fail with int('').
data = []
with open('inputs_day_7.txt', 'r') as f:
  for line in f:
    line = line.strip()
    if line:
      data.append(line)

print(data[:10])

['gbyvdfh (155) -> xqmnq, iyoqt, dimle', 'oweiea (97)', 'szhxrs (14)', 'pjvwsiw (23)', 'ycbok (193) -> xibkhsl, futjpq', 'wtqnsfh (32)', 'jyphghz (5573) -> tuxkm, gcprg, aabvhmt', 'fznumf (62)', 'kogwes (98)', 'sirhaf (57)']


In [None]:
# Parse every "name (weight) -> child1, child2" line into
#   map[name] = (weight, [names of the programs directly above it])
# Lines without "->" get an empty list of programs above them.
map = {}
for entry in data:
  name = entry[:entry.find(' ')]
  lparen, rparen = entry.find('('), entry.find(')')
  weight = int(entry[lparen + 1:rparen])
  children = entry[entry.find('>') + 2:].split(', ') if '->' in entry else []
  map[name] = (weight, children)

## Part 1

There is no need to build the tower (tree) for this part. We could look for the program that is mentioned only once in the input or, more robustly, find the program that appears on the left-hand side of a "->" but never on the right-hand side of one, i.e. the program that is not above any other program.

In [None]:
from itertools import chain

# Every program that sits on top of some other program. A set makes each
# membership test below O(1).
all_above_ones = set(chain.from_iterable(above for _, above in map.values()))

# The root (bottom program) is the first program that is not above any other
# program; stays None if every program is above something.
root = next((name for name in map if name not in all_above_ones), None)

print(root)

qibuqqg


## Part 2

We could recurse over the *map* dictionary directly, but an explicit tree of nodes makes things much more convenient.

In [None]:
class Node:
  """One program in the tower: its name, its own weight, and the programs directly above it."""

  def __init__(self, name, val):
    self.name, self.val = name, val
    # Any number of children, so this is a general (not binary) tree.
    self.children = []

In [None]:
# Recursively build a tree of Node objects from the parsed map.
# Can optionally pass in a dictionary dic to save references to each node in the tree.

def build_tree(map, root, dic=None):
  """Build the subtree whose bottom program is `root` and return its Node.

  map:  dict of name -> (weight, [names of the programs directly above]),
        the same shape as the parsed puzzle input.
  root: name of the program to start from.
  dic:  optional dict; when given, every node of the subtree is stored in it
        under its program name.

  Raises KeyError if `root` or any program listed above it is missing from map.
  """
  root_node = Node(root, map[root][0])
  for c in map[root][1]:
    root_node.children.append(build_tree(map, c, dic))
  if dic is not None:  # an empty dict passed by the caller must still be filled
    dic[root] = root_node
  return root_node

# Build the whole tower once, recording every node by name for later lookups.
node_references = dict()
root_node = build_tree(map, root, dic=node_references)

In [None]:
# Total weight of a node together with everything stacked on top of it.
def find_cumulative_weight_above(root_node):
  """Return root_node.val plus, recursively, the cumulative weight of each child."""
  # Start the sum at the node's own weight so children are added in order after it.
  return sum(
    (find_cumulative_weight_above(child) for child in root_node.children),
    root_node.val,
  )

Let's first visualize the problem. The picture is only readable for the small example from the puzzle description; the tree for the actual input is too large.

In [None]:
from graphviz import Digraph

# Draw one graph node per program, labelled with its own weight and the
# cumulative weight of its subtree, plus an edge to every program above it.
g = Digraph('G', filename='Tree')

for key, (own_weight, above) in map.items():
  cumulative = find_cumulative_weight_above(node_references[key])
  g.node(key, label='{}\nweight:{}\ncum: {}'.format(key, own_weight, cumulative))
  for v in above:
    g.edge(key, v)
g.view()

'Tree.pdf'

In [None]:
# Locate the single program whose weight unbalances the tower.
# Assumes exactly one program has a wrong weight.


def _report_fix(node, wrong_weight, correct_weight):
  """Print the own weight `node` needs to balance its siblings and return its name.

  wrong_weight / correct_weight are cumulative (subtree) weights; their
  difference is applied to the node's own weight.
  """
  diff = correct_weight - wrong_weight
  print(node.name, 'Old Value:', node.val, 'New Value to Balance Tree (Old {}):'.format(diff), node.val + diff)
  return node.name


def find_wrong_node(nodes):
  """Recursively find the program that makes the sibling list `nodes` unbalanced.

  nodes: sibling Node objects that should all have the same cumulative weight
         (e.g. the root's children).

  Returns the name of the wrong program (after printing its corrected weight),
  or None if the siblings are balanced (this includes 0 or 1 sibling).
  Raises ValueError if the imbalance cannot be pinned on one program, e.g. two
  siblings with different weights whose own subtrees are both balanced.
  """
  cumulative_weights = [find_cumulative_weight_above(n) for n in nodes]
  if len(set(cumulative_weights)) <= 1:
    return None  # balanced, or nothing to compare

  # A weight shared by several siblings is taken as correct; a weight seen only once is an outlier.
  outliers = [i for i, w in enumerate(cumulative_weights) if cumulative_weights.count(w) == 1]
  majority = [w for w in cumulative_weights if cumulative_weights.count(w) > 1]

  if outliers and majority:
    i = outliers[0]
    suspect = nodes[i]
    # If the suspect's own children are unbalanced, the error is deeper;
    # otherwise (balanced children, or a leaf) the suspect itself is wrong.
    deeper = find_wrong_node(suspect.children)
    if deeper is not None:
      return deeper
    return _report_fix(suspect, cumulative_weights[i], majority[0])

  # No majority to compare against (e.g. exactly two siblings with different
  # weights). Under the single-error assumption only ancestors of the wrong
  # program have unbalanced children, so descend into whichever sibling's
  # subtree is itself unbalanced.
  for n in nodes:
    deeper = find_wrong_node(n.children)
    if deeper is not None:
      return deeper
  raise ValueError('Cannot tell which of {} is wrong; cumulative weights: {}'.format(
    [n.name for n in nodes], cumulative_weights))


wrong_node = find_wrong_node(root_node.children)
wrong_node

egbzge Old Value: 1086 New Value to Balance Tree (Old -7): 1079


'egbzge'