In [1]:
import heapq
from collections import Counter
from dataclasses import dataclass
from itertools import combinations, product
from math import prod

import numpy as np

In [2]:
# Path of the input file that the loading cell opens. To read example.txt
# instead, uncomment the first assignment and comment out the second.
#FILE = "example.txt"
FILE = "input.txt"

In [3]:
# Read the input as one list of space-separated tokens per line.
# The context manager closes the file handle as soon as reading is done
# instead of leaving it open until the garbage collector gets to it.
with open(FILE) as input_file:
    inputs = [
        line.rstrip('\n').split(' ')
        for line in input_file
        if line.strip()  # skip blank lines (e.g. a trailing empty line)
    ]

In [4]:
class Machine:
    """A set of indicator lights toggled by buttons.

    Each entry of ``wiring`` is one button, given as the list of light
    indices it toggles. On construction the target pattern and the buttons
    are converted to 0/1 arrays, and the minimum number of button presses
    needed to reach the target from the all-off state is computed and stored
    in ``required_permutations``.
    """

    lights: str                  # target pattern; '#' marks a light that must be on
    lights_bitmask: np.ndarray   # 0/1 vector form of ``lights``
    wiring: list[list[int]]      # per button: indices of the lights it toggles
    wiring_bitmasks: np.ndarray  # 0/1 matrix, one row per button
    joltages: list[int]          # stored as given; not used by this class
    required_permutations: int   # minimum presses, or -1 if unreachable

    def __init__(self, lights: str, wiring: list[list[int]], joltages: list[int]):
        """Store the raw fields, build the bitmasks and solve for the presses.

        Args:
            lights: Target pattern, one character per light ('#' means on).
            wiring: One list of toggled light indices per button.
            joltages: Stored unchanged on the instance.
        """
        n_lights = len(lights)

        self.lights = lights
        # Explicit dtype keeps the array integer-typed even when ``lights``
        # is empty (NumPy would otherwise default to float64).
        self.lights_bitmask = np.array(
            [1 if char == '#' else 0 for char in lights], dtype=int
        )
        self.wiring = wiring
        self.joltages = joltages

        # The reshape keeps the matrix 2-D (buttons x lights) even when there
        # are no buttons or no lights.
        self.wiring_bitmasks = np.array(
            [
                [1 if light in toggled else 0 for light in range(n_lights)]
                for toggled in map(set, wiring)
            ],
            dtype=int,
        ).reshape(len(wiring), n_lights)

        self.required_permutations = self.get_required_permutations()

    def get_required_permutations(self) -> int:
        """Return the fewest button presses that produce the target lights.

        Toggling is XOR, so pressing a button twice cancels out and the order
        of presses does not matter. It is therefore enough to search subsets
        of distinct buttons, smallest first, which also bounds the search at
        ``len(self.wiring_bitmasks)`` presses.

        Returns:
            The minimum number of presses (0 when the target is already the
            all-off state), or -1 when no combination of buttons produces
            the target.
        """
        target = self._pack(self.lights_bitmask)
        buttons = [self._pack(row) for row in self.wiring_bitmasks]

        for count in range(len(buttons) + 1):
            for pressed in combinations(buttons, count):
                state = 0
                for mask in pressed:
                    state ^= mask
                if state == target:
                    return count
        return -1

    @staticmethod
    def _pack(bits) -> int:
        """Pack a 0/1 sequence into an int, element ``i`` becoming bit ``i``."""
        return sum(1 << position for position, bit in enumerate(bits) if bit)


In [5]:
# Build one Machine per tokenised input line:
#   first token   -> lights, with the surrounding '[' / ']' removed
#   middle tokens -> buttons, each "(a,b,...)" parsed into a list of ints
#   last token    -> joltages, "{a,b,...}" parsed into a tuple of ints
parsed_inputs = [
    Machine(
        lights=tokens[0].lstrip('[').rstrip(']'),
        wiring=[
            [
                int(index)
                for index in button.lstrip('(').rstrip(')').strip().split(',')
            ]
            for button in tokens[1:-1]
        ],
        joltages=tuple(
            int(value)
            for value in tokens[-1].strip('{').strip('}').split(',')
        ),
    )
    for tokens in inputs
]

In [6]:
parsed_inputs[0].wiring

[[2, 3, 4, 5], [0, 2, 3, 5, 6], [1, 4, 6], [2, 3, 4], [0, 5]]

In [7]:
parsed_inputs[0].wiring_bitmasks

array([[0, 0, 1, 1, 1, 1, 0],
       [1, 0, 1, 1, 0, 1, 1],
       [0, 1, 0, 0, 1, 0, 1],
       [0, 0, 1, 1, 1, 0, 0],
       [1, 0, 0, 0, 0, 1, 0]])

In [8]:
parsed_inputs[0].lights_bitmask

array([0, 0, 1, 1, 1, 1, 0])

In [9]:
parsed_inputs[0].required_permutations

1

In [10]:
# Total of the per-machine minimum press counts across all parsed machines.
required_permutations = sum(
    machine.required_permutations for machine in parsed_inputs
)

In [11]:
required_permutations

399