In [1]:
from typing import Tuple, List
# An "Interval" is really an axis-aligned cuboid: one inclusive (low, high)
# range per axis, in x, y, z order.
Interval = List[Tuple[int, int]]
# One reboot step: the action ("on" or "off") and the cuboid it applies to.
Row = Tuple[str, Interval]

In [2]:
def parse_one_coordinate(coordinate: str) -> Tuple[int, int]:
    """Parse one axis spec such as ``"x=10..12"`` into ``(10, 12)``.

    Everything before the ``=`` (the axis name, possibly preceded by an
    ``"on "``/``"off "`` action) is ignored. Bounds may be negative.
    """
    _, _, bounds = coordinate.partition("=")
    low, _, high = bounds.partition("..")
    return (int(low), int(high))

def parse_interval(string: str) -> Interval:
    """Parse ``"x=a..b,y=c..d,z=e..f"`` into ``[(a, b), (c, d), (e, f)]``.

    A leading action word (``"on x=..."``) is tolerated, because only the
    text after each ``=`` is parsed.
    """
    return [parse_one_coordinate(c) for c in string.split(",")]

In [3]:
# read_data feeds parse_interval whole input lines, so also check a line that
# still carries its action prefix and has negative bounds.
assert parse_interval("x=10..12,y=20..22,z=-20..12") == [(10, 12), (20, 22), (-20, 12)]
assert parse_interval("off x=-54..-39,y=-85..49,z=-27..7") == [(-54, -39), (-85, 49), (-27, 7)]

In [4]:
def read_data(path) -> List[Tuple[str, Interval]]:
    """Read a reboot-step file into a list of ``(action, cuboid)`` rows.

    Each non-blank line looks like ``"on x=10..12,y=10..12,z=10..12"``. The
    first space-separated word is the action, and the remainder is parsed into
    a cuboid by ``parse_interval``. Blank lines are skipped instead of crashing
    the parser.
    """
    rows = []
    # ``with`` guarantees the file handle is closed once reading is done.
    with open(path, "r") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            action, _, coords = line.partition(" ")
            rows.append((action, parse_interval(coords)))
    return rows

In [5]:
# Load the example inputs and the full input; each becomes a list of
# (action, cuboid) rows. Paths are relative to the working directory.
demo_dataset = read_data("demo.txt")
demo2_dataset = read_data("demo2.txt")
demo3_dataset = read_data("demo3.txt")
full_dataset = read_data("data.txt")
# Show the first example to eyeball the parser output.
demo_dataset

[('on', [(10, 12), (10, 12), (10, 12)]),
 ('on', [(11, 13), (11, 13), (11, 13)]),
 ('off', [(9, 11), (9, 11), (9, 11)]),
 ('on', [(10, 10), (10, 10), (10, 10)])]

# Part 1

In [6]:
def limit(number: int) -> int:
    """Clamp ``number`` into -51..51.

    -51 and 51 act as sentinels just outside the -50..50 region: a bound
    clamped to ±51 is filtered out later by ``count_on``. This ensures a range
    lying completely outside the region cannot collapse onto a valid cell.
    """
    if number > 50:
        return 51
    if number < -50:
        return -51
    return number

In [7]:
def get_limits(row: Row) -> Interval:
    """Return the row's cuboid with every bound clamped by ``limit``."""
    cuboid = row[1]
    clamped = []
    for low, high in cuboid:
        clamped.append((limit(low), limit(high)))
    return clamped

In [8]:
def count_on(dataset: List[Row]) -> int:
    """Count cubes left on inside -50..50 (on every axis) after applying all steps in order.

    Brute force: each step's clamped cuboid is expanded into its individual
    cells, which are then added to ("on") or removed from (any other action)
    the set of lit cells.
    """
    lit = set()
    for row in dataset:
        (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = get_limits(row)
        # get_limits clamps to the sentinel range -51..51; trimming to -50..50
        # here drops the sentinels, and a fully-outside range becomes empty.
        cells = {
            (x, y, z)
            for x in range(max(x_lo, -50), min(x_hi, 50) + 1)
            for y in range(max(y_lo, -50), min(y_hi, 50) + 1)
            for z in range(max(z_lo, -50), min(z_hi, 50) + 1)
        }
        if row[0] == "on":
            lit |= cells
        else:
            lit -= cells
    return len(lit)

In [9]:
# Reference answers for the two example inputs.
assert count_on(demo_dataset) == 39
assert count_on(demo2_dataset) == 590784

In [10]:
%%time
# Part 1 answer for the full input: brute force over the cells of -50..50.
count_on(full_dataset)

CPU times: user 572 ms, sys: 31.6 ms, total: 604 ms
Wall time: 602 ms


642125

# Part 2

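Let's keep a list of pairwise-disjoint cuboids (called "intervals" in the code: one inclusive range per axis).  
When an "on" cuboid overlaps an existing one, we split the new cuboid into smaller cuboids covering only its part outside the existing one. An existing cuboid lying entirely inside the new one is simply dropped. An "off" cuboid works the other way round: every existing cuboid it overlaps is replaced by its pieces lying outside the "off" cuboid.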

In [11]:
def is_contained_element(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    """Whether the inclusive range ``b`` lies entirely inside the inclusive range ``a``.

    Both endpoints of ``b`` must fall within ``[a[0], a[1]]``.
    """
    return (a[0] <= b[0] <= a[1]) and (a[0] <= b[1] <= a[1])

In [12]:
def isin(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    """Whether the inclusive ranges ``a`` and ``b`` overlap.

    For well-formed ranges (low <= high), two ranges share a value exactly
    when one of them starts inside the other.
    """
    a_start, a_end = a[0], a[1]
    b_start, b_end = b[0], b[1]
    if a_start <= b_start <= a_end:
        return True
    return b_start <= a_start <= b_end

In [13]:
# Overlapping pairs.
assert isin((-44, 5), (-5, 47))
assert isin((-27, 21), (-31, 22))
assert isin((-14, 35), (-19, 33))
# Ranges touching at a single endpoint still overlap (bounds are inclusive).
assert isin((0, 2), (2, 3))
# Disjoint ranges, in both argument orders.
assert not isin((0, 1), (2, 3))
assert not isin((2, 3), (0, 1))

In [14]:
def has_intersection(a: Interval, b: Interval) -> bool:
    """Whether cuboids ``a`` and ``b`` share at least one cell.

    Axis-aligned cuboids intersect iff their ranges overlap on all three axes.
    """
    return all(isin(a[axis], b[axis]) for axis in range(3))

In [15]:
assert has_intersection([(-5, 47), (-31, 22), (-19, 33)], [(-44, 5), (-27, 21), (-14, 35)])
assert has_intersection([(-44, 5), (-27, 21), (-14, 35)], [(-5, 47), (-31, 22), (-19, 33)])
# Overlap on two axes is not enough: disjoint along z means no shared cell.
assert not has_intersection([(0, 5), (0, 5), (0, 5)], [(0, 5), (0, 5), (6, 9)])

In [16]:
def is_contained(b: Interval, a: Interval) -> bool:
    """Whether cuboid ``b`` lies entirely inside cuboid ``a`` (checked on all three axes)."""
    return all(is_contained_element(a[axis], b[axis]) for axis in range(3))

In [17]:
def intersect_tuples(x1: int, x2: int, a1: int, a2: int) -> List[Tuple[int, int]]:
    """Split the range ``x1..x2`` at the boundaries of the overlapping range ``a1..a2``.

    Returns consecutive inclusive pieces that together cover ``x1..x2``
    exactly. The overlap ``max(x1, a1)..min(x2, a2)`` is always one of the
    pieces. A piece may be empty (low > high) when two boundaries coincide;
    ``intersect`` filters those out.

    Raises:
        ValueError: if the two ranges do not overlap.
    """
    if x1 <= a1 <= x2 <= a2:
        # a starts inside x and extends past its end.
        return [(x1, a1 - 1), (a1, x2)]
    elif a1 <= x1 <= a2 <= x2:
        # a starts before x and ends inside it.
        return [(x1, a2), (a2 + 1, x2)]
    elif a1 <= x1 <= x2 <= a2:
        # x lies entirely inside a.
        return [(x1, x2)]
    elif x1 <= a1 <= a2 <= x2:
        # a lies strictly inside x: three pieces.
        return [(x1, a1 - 1), (a1, a2), (a2 + 1, x2)]
    else:
        raise ValueError(
            f"Unexpected tuples to intersect: ({x1}, {x2}) and ({a1}, {a2})"
        )

In [18]:
assert intersect_tuples(11, 15, 10, 12) == [(11, 12), (13, 15)]
# x inside a: a single piece.
assert intersect_tuples(3, 5, 0, 10) == [(3, 5)]
# a strictly inside x: three pieces, the middle one being the overlap.
assert intersect_tuples(0, 10, 3, 5) == [(0, 2), (3, 5), (6, 10)]
# Shared lower bound: the leading piece is empty (low > high).
assert intersect_tuples(3, 5, 3, 9) == [(3, 2), (3, 5)]

In [19]:
def is_empty(i: Interval) -> bool:
    """Whether cuboid ``i`` contains no cells, i.e. some axis has low > high."""
    return any(i[axis][0] > i[axis][1] for axis in range(3))

In [20]:
def intersect(to_split: Interval, interval: Interval) -> List[Interval]:
    """Return the part of ``to_split`` lying outside ``interval``, as disjoint cuboids.

    Each axis of ``to_split`` is cut at ``interval``'s boundaries by
    ``intersect_tuples`` (up to 3 pieces per axis). This gives up to
    3 * 3 * 3 = 27 candidate sub-cuboids. The candidate inside ``interval``
    and any empty candidates are discarded, so at most 26 pieces are returned.

    ``to_split`` and ``interval`` must overlap; otherwise ``intersect_tuples``
    raises.
    """
    (x1, x2), (y1, y2), (z1, z2) = to_split
    (a1, a2), (b1, b2), (c1, c2) = interval
    candidates = [
        [X, Y, Z]
        for X in intersect_tuples(x1, x2, a1, a2)
        for Y in intersect_tuples(y1, y2, b1, b2)
        for Z in intersect_tuples(z1, z2, c1, c2)
    ]
    # Keep only non-empty pieces outside the overlap.
    return [c for c in candidates if not is_contained(c, interval) and not is_empty(c)]

In [21]:
from copy import deepcopy

def intersect_with_all(to_add: Interval, intervals: List[Interval]) -> List[Interval]:
    """Add cuboid ``to_add`` to a list of pairwise-disjoint cuboids, keeping them disjoint.

    ``to_add`` is carved against each existing cuboid in turn, and each
    current piece is handled as follows:

    * Disjoint from the existing cuboid: the piece is kept unchanged.
    * Containing the existing cuboid entirely: the piece is kept and the
      existing cuboid is dropped, since the piece now covers it.
    * Partially overlapping: the piece is replaced by its sub-cuboids
      outside the existing one.
    * Entirely inside the existing cuboid: the piece is dropped.

    Assuming ``intervals`` is pairwise disjoint, the result is too. It covers
    the union of ``intervals`` and ``to_add``. The input list is not mutated.
    """
    kept = []
    parts = [to_add]
    for interval in intervals:
        next_parts = []
        swallowed = False
        for part in parts:
            if not has_intersection(part, interval):
                next_parts.append(part)
            elif is_contained(interval, part):
                next_parts.append(part)
                swallowed = True
            elif not is_contained(part, interval):
                next_parts += intersect(part, interval)
            # Otherwise the part is already fully lit by ``interval``: drop it.
        # Filtering as we go avoids an O(n) list.remove per swallowed cuboid.
        if not swallowed:
            kept.append(interval)
        parts = next_parts
    return kept + parts
            

In [22]:
def remove_interval(to_remove: Interval, intervals: List[Interval]) -> List[Interval]:
    """Subtract cuboid ``to_remove`` from every cuboid in ``intervals``.

    Cuboids not overlapping ``to_remove`` are kept as-is. Overlapping ones are
    replaced by their pieces lying outside it, as computed by ``intersect``.
    """
    survivors = []
    for existing in intervals:
        if not has_intersection(to_remove, existing):
            survivors.append(existing)
            continue
        survivors.extend(intersect(existing, to_remove))
    return survivors

In [23]:
def tuple_size(t: Tuple[int, int]) -> int:
    """Number of integers in the inclusive range ``t`` (<= 0 when low > high)."""
    low, high = t[0], t[1]
    return high - low + 1

def interval_size(i: Interval) -> int:
    """Number of cells in cuboid ``i``: the product of its three axis lengths."""
    volume = 1
    for axis in range(3):
        volume *= tuple_size(i[axis])
    return volume

In [24]:
def validate_no_intersection(intervals: List[Interval]):
    """Raise if any two cuboids in ``intervals`` overlap (debugging aid).

    Pairs are taken by position, so two equal cuboids (which overlap
    completely) are reported as well. Each unordered pair is checked once.

    Raises:
        Exception: naming the first overlapping pair found.
    """
    from itertools import combinations

    for i, j in combinations(intervals, 2):
        if has_intersection(i, j):
            raise Exception("Intersection: " + str(i) + " and " + str(j))

In [25]:
def ok_i_ll_be_smart_to_count_on(dataset, validate: bool = False) -> int:
    """Count all cubes left on after applying every reboot step (no region limit).

    Maintains a list of pairwise-disjoint lit cuboids. "on" steps add a cuboid
    via ``intersect_with_all``; any other action subtracts it via
    ``remove_interval``. The answer is the sum of the cuboid volumes.

    Args:
        dataset: ``(action, cuboid)`` rows, as returned by ``read_data``.
        validate: if True, check after every step that the cuboids are still
            pairwise disjoint. This is slow: quadratic in the number of cuboids.
    """
    lit = []
    for action, cuboid in dataset:
        if action == "on":
            lit = intersect_with_all(cuboid, lit)
        else:
            lit = remove_interval(cuboid, lit)
        if validate:
            validate_no_intersection(lit)
    return sum(interval_size(box) for box in lit)

In [26]:
# Every cuboid of the first demo lies inside -50..50, so both solvers must agree.
assert count_on(demo_dataset) == ok_i_ll_be_smart_to_count_on(demo_dataset) == 39

In [27]:
%%time
# The expected total (~2.8e15 cubes) far exceeds what fits in -50..50, so this
# exercises the unbounded disjoint-cuboid solver.
assert ok_i_ll_be_smart_to_count_on(demo3_dataset) == 2758514936282235

CPU times: user 181 ms, sys: 189 µs, total: 181 ms
Wall time: 181 ms


In [28]:
%%time
# Part 2 answer for the full input, using the disjoint-cuboid solver.
ok_i_ll_be_smart_to_count_on(full_dataset)

CPU times: user 8.13 s, sys: 550 µs, total: 8.13 s
Wall time: 8.13 s


1235164413198198