In [1]:
from collections import OrderedDict, deque
from math import lcm

In [2]:
class Signal:
    """Pulse levels exchanged between modules.

    Kept as plain ints (rather than an Enum) because pulses are also used
    directly as list indices when counting them.
    """

    LOW, HIGH = 0, 1

In [3]:
class Module:
    """A pass-through module: forwards every pulse unchanged to all destinations.

    Used directly for untyped modules (e.g. ``broadcaster``) and as the base
    class of the flip-flop and conjunction modules.
    """

    name = None
    destination_modules = None

    def __init__(self, name=None, destination_modules=None):
        """
        :param name: identifier of this module.
        :param destination_modules: names of the modules pulses are sent to;
            ``None`` means "no destinations".
        """
        self.name = name
        # Normalise None to an empty list so receive() and __repr__ don't fail
        # trying to iterate / join None.
        self.destination_modules = [] if destination_modules is None else destination_modules

    def receive(self, source, signal):
        """Handle a pulse from ``source`` and return the pulses to emit.

        :returns: list of ``(destination, sender, signal)`` tuples; the incoming
            ``signal`` is forwarded unchanged to every destination.
        """
        return [(d, self.name, signal) for d in self.destination_modules]

    def __repr__(self):
        return '{} -> {}'.format(self.name, ', '.join(self.destination_modules))

    def __hash__(self):
        # NOTE: hashes the repr, which reflects mutable state (destinations and,
        # in subclasses, on/off flag or input memory). Don't mutate an instance
        # while it is stored in a set or used as a dict key.
        return hash(repr(self))

In [4]:
class FlipFlopModule(Module):
    """Flip-flop module with an on/off state.

    High pulses are ignored. A low pulse toggles the state: switching on
    emits a high pulse, switching off emits a low pulse.
    """

    on = False

    def __init__(self, name=None, destination_modules=None, on=False):
        super().__init__(name, destination_modules)
        self.on = on

    def receive(self, source, signal):
        # High pulses are swallowed without changing state.
        if signal == Signal.HIGH:
            return []
        self.on = not self.on
        emitted = Signal.HIGH if self.on else Signal.LOW
        return [(target, self.name, emitted) for target in self.destination_modules]

    def __repr__(self):
        # 'o' marks a flip-flop that is currently on, '%' one that is off.
        marker = 'o' if self.on else '%'
        targets = ', '.join(self.destination_modules)
        return '{}{} -> {}'.format(marker, self.name, targets)

In [5]:
class ConjunctionModule(Module):
    """Conjunction module: remembers the most recent pulse from each input.

    After updating its memory it emits a low pulse if every remembered input
    is high, otherwise a high pulse. ``set_input_modules`` must be called
    before any pulse is delivered, since ``receive`` writes into ``memory``.
    """

    input_modules = None
    memory = None

    def set_input_modules(self, input_modules):
        """Register the modules feeding this one; each starts remembered as low."""
        self.input_modules = input_modules
        self.memory = OrderedDict((m, Signal.LOW) for m in input_modules)

    def receive(self, source, signal):
        self.memory[source] = signal
        all_high = all(v == Signal.HIGH for v in self.memory.values())
        out = Signal.LOW if all_high else Signal.HIGH
        return [(d, self.name, out) for d in self.destination_modules]

    def __repr__(self):
        # Always lead with the '&' type marker (like the other modules' one-char
        # markers); once inputs are wired, append the remembered pulse per input
        # as "[input:signal-...]".
        prefix = '&'
        if self.memory:
            prefix += '[{}]'.format('-'.join('{}:{}'.format(k, v) for k, v in self.memory.items()))
        return '{}{} -> {}'.format(prefix, self.name, ', '.join(self.destination_modules))

In [6]:
def get_input(fname='test.txt'):
    """Parse a module configuration file into a ``{name: module}`` dict.

    Each non-blank line has the form ``<module> -> <dest>, <dest>, ...`` where
    ``<module>`` is ``%name`` (flip-flop), ``&name`` (conjunction) or a bare
    name (plain pass-through ``Module``).

    Conjunction inputs are not wired here; call ``set_inputs`` afterwards.
    """
    modules = {}
    with open(fname) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                # Skip blank lines (e.g. an extra empty line at end of file),
                # which cannot be split into "module -> destinations".
                continue
            spec, dlist = line.split(' -> ')
            dest = dlist.split(', ')
            if spec[0] == '%':
                mname = spec[1:]
                module = FlipFlopModule(mname, dest)
            elif spec[0] == '&':
                mname = spec[1:]
                module = ConjunctionModule(mname, dest)
            else:
                mname = spec
                module = Module(mname, dest)
            modules[mname] = module
    return modules

In [7]:
# Parse the small test configuration and the main input file.
test_input = get_input(fname='test.txt')
my_input = get_input(fname='input.txt')

In [8]:
[*test_input.values()]  # show every parsed module of the test configuration

[broadcaster -> a, %a -> inv, con, &inv -> b, %b -> con, &con -> output]

In [9]:
def set_inputs(modules):
    """Tell every conjunction module which modules feed into it.

    :param modules: mapping ``{name: module}`` as returned by ``get_input``.
        Conjunction modules in it are updated in place.
    """
    inputs = {}
    for name, module in modules.items():
        for dest in module.destination_modules:
            inputs.setdefault(dest, []).append(name)
    for module in modules.values():
        if isinstance(module, ConjunctionModule):
            # A conjunction nothing points at gets an empty input list instead
            # of raising KeyError.
            module.set_input_modules(inputs.get(module.name, []))

In [10]:
# Wire up conjunction-module inputs for both configurations.
set_inputs(modules=test_input)
set_inputs(modules=my_input)

In [11]:
def solve1(modules, loops=1000):
    """Press the button ``loops`` times; return (#low pulses) * (#high pulses).

    Each press injects a low pulse from ``'button'`` into ``'broadcaster'`` and
    processes pulses in FIFO order until the queue drains. Pulses addressed to
    names with no module (e.g. ``output``) are counted but go nowhere.

    Note: mutates the state of the modules in ``modules``.
    """
    counts = [0, 0]  # indexed by signal: [LOW, HIGH]
    for _ in range(loops):
        q = deque([('broadcaster', 'button', Signal.LOW)])
        while q:
            name, source, signal = q.popleft()
            counts[signal] += 1
            if name in modules:
                q.extend(modules[name].receive(source, signal))
    return counts[0] * counts[1]

In [12]:
# solve1 mutates module state, so start from a freshly loaded configuration.
test_input = get_input('test.txt')
set_inputs(test_input)
test_answer = solve1(test_input)
# Pin the known answer for this example so Restart & Run All catches regressions.
assert test_answer == 11687500, test_answer
test_answer

11687500

In [13]:
test_simple = get_input('test_simple.txt')
set_inputs(test_simple)
simple_answer = solve1(test_simple)
# Pin the known answer for this example so Restart & Run All catches regressions.
assert simple_answer == 32000000, simple_answer
simple_answer

32000000

In [14]:
# solve1 mutates module state; reload so re-running this cell gives the same answer.
my_input = get_input('input.txt')
set_inputs(my_input)
solve1(my_input)

834323022

In [15]:
# Reload: solving part 1 changed flip-flop states and conjunction memories.
my_input = get_input(fname='input.txt')
set_inputs(modules=my_input)

In [16]:
# Export the module graph as Graphviz DOT: node declarations styled by module
# type, then one edge group per module.
with open('input.dot', 'w') as f:
    print('digraph input {', file=f)
    print('  {', file=f)
    for module in my_input.values():
        if isinstance(module, FlipFlopModule):
            print('    {} [shape=invtriangle color=blue]'.format(module.name), file=f)
        elif isinstance(module, ConjunctionModule):
            print('    {} [shape=invhouse color=red]'.format(module.name), file=f)
        else:
            print('    {} [shape=oval]'.format(module.name), file=f)
    print('  }', file=f)
    for module in my_input.values():
        # Build edges from name/destinations directly rather than slicing
        # repr(): once inputs are wired, a conjunction's repr embeds its memory
        # ("[a:0-b:0]name -> ...") and is not valid DOT after dropping one char.
        print('  {} -> {{{}}}'.format(module.name, ' '.join(module.destination_modules)), file=f)
    print('}', file=f)

In [17]:
def cycle_length(modules):
    """Press the button until the overall module state repeats; return the period.

    The state after each press is the tuple of every module's repr, which
    encodes flip-flop on/off flags and conjunction memories. The result is the
    number of presses between the first occurrence of the repeated state and
    its recurrence.

    Note: mutates the state of the modules in ``modules``.
    """
    def snapshot():
        return tuple(repr(m) for m in modules.values())

    # Map each state seen to the press count after which it occurred. The
    # initial (0-press) state is recorded too, so a return to the start counts.
    seen = {snapshot(): 0}
    presses = 0
    while True:
        presses += 1
        q = deque([('broadcaster', 'button', Signal.LOW)])
        while q:
            name, source, signal = q.popleft()
            if name in modules:
                q.extend(modules[name].receive(source, signal))
        state = snapshot()
        if state in seen:
            return presses - seen[state]
        seen[state] = presses

In [18]:
# Fresh configuration: earlier cells left test_input in a post-run state.
test_input = get_input('test.txt')
set_inputs(test_input)
test_cycle = cycle_length(test_input)
# Pin the known cycle length for this example.
assert test_cycle == 4, test_cycle
test_cycle

4

In [19]:
my_input = get_input('input.txt')
set_inputs(my_input)

broadcast_destinations = my_input['broadcaster'].destination_modules

# Measure each broadcaster branch on its own freshly loaded configuration. That
# way a measurement does not inherit module state left by the previous one, and
# my_input's broadcaster is not left rewired to a single destination.
cycle_lengths = []
for dest in broadcast_destinations:
    branch_modules = get_input('input.txt')
    set_inputs(branch_modules)
    branch_modules['broadcaster'].destination_modules = [dest]
    cycle_lengths.append(cycle_length(branch_modules))

In [20]:
# Cycle length measured for each broadcaster destination; the next cell takes their LCM.
cycle_lengths

[4001, 3739, 3821, 3943]

In [21]:
# Combine the per-branch cycle lengths. This assumes the branches cycle
# independently, so they first line up at the least common multiple of
# their lengths. `lcm` is imported at the top of the notebook.
lcm(*cycle_lengths)

225386464601017
