In [None]:
from collections import deque
from math import lcm

In [None]:
# Puzzle input: one module declaration per line, e.g. "%ab -> cd, ef"
# (parsed by `parse` below). The path is relative to the notebook's working directory.
with open('../inputs/20.txt') as f:
    data = f.read().splitlines()

In [None]:
def parse(lines):
    """Build the module network from the puzzle's input lines.

    Each line has the form ``<source> -> <out1>, <out2>, ...`` where
    ``<source>`` is ``broadcaster``, ``%name`` (flip-flop) or any other
    prefix character followed by the name (conjunction, ``&name``).

    Args:
        lines: Iterable of input lines. Blank lines are ignored.

    Returns:
        Dict mapping module name to a spec dict with keys ``'type'``
        (``None`` for the broadcaster, ``'FLIP'`` or ``'CONJ'``) and
        ``'outputs'`` (list of destination names). Flip-flops also get
        ``'state': 0``. Conjunctions get ``'input_states'``, a dict holding
        ``0`` for every module that lists the conjunction as an output.
    """
    out = {}

    for line in lines:
        if not line.strip():
            continue

        source, _, targets = line.partition(' -> ')
        source = source.strip()
        outputs = [name.strip() for name in targets.split(',')]

        if source.startswith('broadcaster'):
            kind, name = None, 'broadcaster'
        else:
            kind = 'FLIP' if source[0] == '%' else 'CONJ'
            # Take the full name; names are not always two characters long
            # (the puzzle's examples use names like "a" and "inv").
            name = source[1:]

        spec = {'type': kind, 'outputs': outputs}
        if kind == 'FLIP':
            spec['state'] = 0
        elif kind == 'CONJ':
            spec['input_states'] = {}
        out[name] = spec

    # Second pass: the full graph is needed before a conjunction's inputs are
    # known. Each conjunction starts out remembering a low pulse (0) from
    # every module that sends to it.
    for name, spec in out.items():
        for output in spec['outputs']:
            if output in out and out[output]['type'] == 'CONJ':
                out[output]['input_states'][name] = 0

    return out

In [None]:
def push(network, pushes=None, rx_params=None):
    """Simulate one button press and propagate every resulting pulse.

    Pulses are processed in FIFO order. Flip-flop ``'state'`` and conjunction
    ``'input_states'`` inside ``network`` are mutated in place, so successive
    calls continue from the state the previous press left behind.

    Args:
        network: Module map as built by ``parse``. Each value has ``'type'``
            (``None`` for the broadcaster, ``'FLIP'`` or ``'CONJ'``) and
            ``'outputs'``, plus ``'state'`` (FLIP) or ``'input_states'`` (CONJ).
        pushes: Value stored into ``rx_params`` (the caller's press number).
            Only used when ``rx_params`` is given.
        rx_params: Optional dict with keys ``'conjunction_node'`` and
            ``'node_activation_count'``. When an input sends a high pulse to
            ``conjunction_node`` and has no entry yet,
            ``node_activation_count[input] = pushes`` is recorded. Existing
            entries are never overwritten.

    Returns:
        ``(low_count, high_count)``: the number of low and high pulses sent
        during this press. This includes the button's initial low pulse and
        pulses delivered to output-only modules such as ``rx``.
    """
    queue = deque([('button', 'broadcaster', 0)])
    low_count = 0
    high_count = 0

    while queue:
        origin, target, intensity = queue.popleft()

        if intensity:
            high_count += 1
        else:
            low_count += 1

        # Output-only modules (``rx``, or any name with no declaration, such as
        # ``output`` in the puzzle's examples) are counted above, then dropped.
        if target == 'rx' or target not in network:
            continue

        if (
            rx_params and
            intensity == 1 and
            target == rx_params['conjunction_node'] and
            origin not in rx_params['node_activation_count']
        ):
            rx_params['node_activation_count'][origin] = pushes

        node = network[target]

        # No memoisation of (origin, target, intensity) -> effects: a module's
        # response depends on its current state, which can change between
        # two identical pulses within the same press.
        if target == 'broadcaster':
            out_intensity = intensity
        elif node['type'] == 'FLIP':
            if intensity == 1:
                continue  # flip-flops ignore high pulses
            node['state'] ^= 1
            out_intensity = node['state']  # emit the new state
        elif node['type'] == 'CONJ':
            node['input_states'][origin] = intensity
            # Low only when every remembered input is high; otherwise high.
            out_intensity = 0 if all(node['input_states'].values()) else 1
        else:
            continue

        queue.extend((target, output, out_intensity) for output in node['outputs'])

    return (low_count, high_count)

In [None]:
# Part 1
def calc_pulses(network, presses=1000):
    """Return total high pulses times total low pulses over ``presses`` presses.

    Args:
        network: Module map as built by ``parse``. It is mutated by ``push``,
            so pass a freshly parsed network to reproduce the Part 1 answer.
        presses: Number of button presses to simulate (1000 for Part 1).

    Returns:
        ``high_total * low_total`` summed over all simulated presses.
    """
    low_total = 0
    high_total = 0

    for _ in range(presses):
        low, high = push(network)
        low_total += low
        high_total += high

    return high_total * low_total

# Use a freshly parsed network: the simulation mutates module state in place.
calc_pulses(parse(data))

In [None]:
# Part 2
def find_conjunction_node(network):
    """Locate the conjunction module that sends pulses to ``rx``.

    Args:
        network: Module map as built by ``parse``.

    Returns:
        ``[name, input_names]`` for the first (in dict order) module of type
        ``'CONJ'`` whose outputs include ``'rx'``. ``input_names`` lists the
        keys of its ``input_states``. Returns ``None`` if there is no such
        module.
    """
    feeder = next(
        (
            name
            for name, spec in network.items()
            if 'rx' in spec['outputs'] and spec['type'] == 'CONJ'
        ),
        None,
    )
    if feeder is None:
        return None
    return [feeder, list(network[feeder]['input_states'])]

def calc_rx_activation_push(network, max_presses=None):
    """Compute the button press on which ``rx`` would first receive a low pulse.

    ``rx`` is fed by a conjunction, which sends low only when every one of its
    inputs last sent high. Instead of simulating until that happens, this
    records the first press on which each input sends a high pulse to the
    conjunction and returns the LCM of those press numbers.

    NOTE(review): the LCM shortcut assumes each input fires periodically, with
    a period equal to its first firing press. That holds for the puzzle input
    but not for an arbitrary network.

    Args:
        network: Freshly parsed module map (see ``parse``). Mutated by the simulation.
        max_presses: Optional safety cap on simulated presses. ``None``
            (the default) means unbounded.

    Returns:
        The LCM of the recorded first-firing press numbers.

    Raises:
        ValueError: If no conjunction module sends to ``rx``.
        RuntimeError: If ``max_presses`` presses pass before every input fired.
    """
    found = find_conjunction_node(network)
    if found is None:
        raise ValueError("no conjunction module sends to 'rx'")
    conjunction_node, triggering_nodes = found

    node_activation_count = {}
    rx_params = {
        'conjunction_node': conjunction_node,
        'node_activation_count': node_activation_count,
    }

    pushes = 1  # press numbers are 1-based
    while len(node_activation_count) < len(triggering_nodes):
        if max_presses is not None and pushes > max_presses:
            raise RuntimeError(
                f"only {len(node_activation_count)} of {len(triggering_nodes)} "
                f"inputs of {conjunction_node!r} fired within {max_presses} presses"
            )
        push(network, pushes, rx_params)
        pushes += 1

    return lcm(*node_activation_count.values())

# Use a freshly parsed network: the simulation mutates module state in place.
calc_rx_activation_push(parse(data))