In [1]:
import sys

# Put the sibling utils directory on the import path; presumably this is where
# the `aoc` helper module imported below lives — verify against the repo layout.
sys.path.append('../utils')

In [2]:
import numpy as np
import re

from aoc import * # get_input, d4, d8, d4n
from itertools import product
from functools import reduce

In [3]:
# Puzzle identifiers; passed to get_input() and submit_answer() in later cells.
YEAR = 2021
DAY = 22

In [4]:
# Small worked example: four reboot steps, each an "on"/"off" action followed
# by inclusive x/y/z ranges. Used to sanity-check part1/part2 on known output.
sample = '''on x=10..12,y=10..12,z=10..12
on x=11..13,y=11..13,z=11..13
off x=9..11,y=9..11,z=9..11
on x=10..10,y=10..10,z=10..10'''

In [5]:
# Raw puzzle text for YEAR/DAY, obtained through the aoc helper
# (presumably downloaded or read from a local cache — verify in the helper).
inp = get_input(YEAR, DAY)

In [6]:
def do_transform(inp):
    """Split raw puzzle text into a list of non-blank lines.

    Args:
        inp: The raw input as a single string, or an iterable of lines that
            has already been split. Accepting the latter makes the function
            safe to apply twice, e.g. when the notebook cell that rebinds the
            input names is re-run.

    Returns:
        list[str]: The lines in their original order with line terminators
        removed. Empty and whitespace-only lines are dropped; the content of
        the remaining lines is left untouched.
    """
    # splitlines() also consumes "\r\n", so a blank CRLF line cannot survive
    # as a lone "\r" and break the step parser later on.
    lines = inp.splitlines() if isinstance(inp, str) else inp
    return [line for line in lines if line.strip()]

In [7]:
# Rebind both names from raw text to lists of non-empty lines; every later
# cell expects the list form.
sample = do_transform(sample)
inp = do_transform(inp)

In [8]:
class KDNode:
    """Node of a tree that recursively subdivides an axis-aligned box.

    Attributes are set in ``__init__``:
        region: One half-open ``(lo, hi)`` interval per axis.
        state: On/off flag that applies to the whole box while it is a leaf.
        children: Sub-boxes tiling ``region`` once the node has been split;
            an empty list means the node is a leaf.
    """

    def __init__(self, region, state):
        self.region = region
        self.state = state
        self.children = []

    def outside_of(self, region):
        """Return True when this box and ``region`` share no volume.

        Intervals are half-open, so boxes that merely touch on some axis
        count as disjoint.
        """
        for (own_lo, own_hi), (other_lo, other_hi) in zip(self.region, region):
            if other_lo >= own_hi or other_hi <= own_lo:
                return True
        return False

    def contained_by(self, region):
        """Return True when this (non-empty) box lies entirely inside ``region``."""
        for (own_lo, own_hi), (other_lo, other_hi) in zip(self.region, region):
            if not (other_lo <= own_lo < own_hi <= other_hi):
                return False
        return True

    def set_state(self, region, state):
        """Set ``state`` on the part of this box that overlaps ``region``.

        A box fully covered by ``region`` collapses back into a single leaf.
        A partially covered leaf is first cut along the bounds of ``region``
        that fall inside it, and the update is then pushed down to the
        children.
        """
        if self.outside_of(region):
            return

        if self.contained_by(region):
            self.state = state
            self.children = []
            return

        if not self.children:
            pieces_per_axis = []
            for (own_lo, own_hi), (cut_lo, cut_hi) in zip(self.region, region):
                # Own bounds plus whichever bounds of `region` land inside them.
                candidates = (own_lo, own_hi, cut_lo, cut_hi)
                cuts = sorted({p for p in candidates if own_lo <= p <= own_hi})
                pieces_per_axis.append(list(zip(cuts, cuts[1:])))
            # Every combination of per-axis pieces is one child; each child
            # inherits the current state of this box.
            self.children = [
                KDNode(cell, self.state) for cell in product(*pieces_per_axis)
            ]

        for child in self.children:
            child.set_state(region, state)

    def value(self):
        """Return the total volume of the "on" leaves beneath this node."""
        if self.children:
            return sum(child.value() for child in self.children)

        if not self.state:
            return 0

        volume = 1
        for lo, hi in self.region:
            volume *= hi - lo
        return volume

In [9]:
def solve(inp, k, inf):
    """Apply every reboot step to a k-dimensional grid and count the cells left on.

    Args:
        inp: Iterable of step strings shaped like
            "on x=10..12,y=10..12,z=10..12": an action word, whitespace, then
            one inclusive "lo..hi" range per axis. Any action other than
            "on" switches cells off.
        k: Number of axes. Only the first k ranges of each step are used.
        inf: The grid covers -inf..inf (inclusive) on every axis; any part of
            a step that falls outside that box has no effect.

    Returns:
        int: Number of unit cells that are on after all steps were applied
        in order.

    Raises:
        ValueError: If a step has an unpaired bound or fewer than k ranges.
        IndexError: If a step has no range specification after the action.
    """
    root = KDNode([(-inf, inf + 1)] * k, False)

    for step in inp:
        fields = step.split()
        action, spec = fields[0], fields[1]

        bounds = [int(tok) for tok in re.findall(r'-?\d+', spec)]
        if len(bounds) % 2 or len(bounds) < 2 * k:
            raise ValueError(
                'expected at least {} lo..hi ranges in step {!r}'.format(k, step))

        # Input ranges are inclusive; the tree works on half-open intervals.
        region = [(lo, hi + 1) for lo, hi in zip(bounds[0::2], bounds[1::2])][:k]
        root.set_state(region, action == 'on')

    return root.value()

In [10]:
def part1(inp):
    """Count the cubes left on inside the -50..50 cube of the 3-D grid.

    Args:
        inp: List of reboot-step lines as produced by do_transform().

    Returns:
        The cell count reported by solve() for that bounded grid.
    """
    return solve(inp, k=3, inf=50)

In [11]:
# Sanity check on the worked example before touching the real input.
part1(sample)

39

In [12]:
# Keep the result in a variable so the submission cell can reuse it.
part1_ans = part1(inp)
part1_ans

648681

In [None]:
# Hand the part 1 answer to the aoc helper (presumably this posts it to the
# puzzle site — verify before re-running, as it may not be repeatable).
submit_answer(part1_ans, YEAR, DAY)

In [13]:
def part2(inp):
    """Count the cubes left on over a 3-D grid spanning -1000000..1000000 per axis.

    Args:
        inp: List of reboot-step lines as produced by do_transform().

    Returns:
        The cell count reported by solve() for that grid.
    """
    # NOTE(review): assumes every coordinate in the input lies within
    # +/-1000000; anything beyond that bound would be ignored by solve().
    return solve(inp, k=3, inf=1000000)

In [14]:
# All cuboids of the worked example lie inside the part 1 region, so the
# wider grid yields the same count.
part2(sample)

39

In [15]:
# Keep the result in a variable so the submission cell can reuse it.
part2_ans = part2(inp)
part2_ans

1302784472088899

In [None]:
# Hand the part 2 answer to the aoc helper (presumably this posts it to the
# puzzle site — verify before re-running, as it may not be repeatable).
submit_answer(part2_ans, YEAR, DAY, level=2)