In [1]:
from typing import NamedTuple
from collections import deque

In [2]:
def parse(lines):
    '''
    Turn an iterable of "col,row" strings into a list of Point,
    keeping the input order.
    '''
    def to_point(entry):
        # Exactly two comma-separated integers are expected per entry.
        col_text, row_text = entry.split(',')
        return Point(int(col_text), int(row_text))

    return [to_point(entry) for entry in lines]

class Point(NamedTuple):
    '''A grid coordinate stored as (col, row).'''
    col: int
    row: int

    def __add__(self, other):
        '''
        Component-wise sum of two points.

        This replaces tuple concatenation, which is what `+` would
        otherwise mean for a NamedTuple.
        '''
        return Point(self.col + other.col, self.row + other.row)

    def neighbors(self, h, w):
        '''
        Yield the orthogonal neighbors of this point that lie inside a grid
        with `h` rows and `w` columns.

        Neighbors are produced in the order down, up, right, left
        (row + 1, row - 1, col + 1, col - 1).
        '''
        for step in (Point(0, 1), Point(0, -1), Point(1, 0), Point(-1, 0)):
            p = self + step
            # col runs along the width and row along the height; the bounds
            # must not be swapped or non-square grids are clipped wrongly.
            if 0 <= p.col < w and 0 <= p.row < h:
                yield p

def draw(blocks, h, w, highlight=None):
    '''
    Print a grid of `h` rows and `w` columns, one text line per row.

    '#' marks a point contained in `blocks`, 'O' marks the optional
    `highlight` point (it wins over '#'), and '.' marks everything else.
    '''
    # Membership is tested once per cell, so look points up in a set
    # instead of scanning the original sequence every time.
    blocked = set(blocks)

    for r in range(h):
        cells = []
        for c in range(w):
            here = Point(c, r)
            if here == highlight:
                cells.append('O')
            elif here in blocked:
                cells.append('#')
            else:
                cells.append('.')
        print(''.join(cells))

## Part one

Hmm, just breadth first search…wonder what's coming next?

In [3]:
def bfs(start, end, bytes, h, w):
    '''
    Breadth-first search from `start` to `end` on an `h` by `w` grid.

    `bytes` is an iterable of blocked points (the name shadows the builtin
    inside this function; it is kept so existing callers are unaffected).
    Points must be hashable and provide `neighbors(h, w)`.

    Returns a tuple `(end, steps)` where `steps` is the length of a
    shortest path, or None when `end` cannot be reached.
    '''
    blocks = set(bytes)
    seen = {start}

    # Queue of (point, steps taken to reach it).
    d = deque([(start, 0)])

    while d:
        loc, time = d.popleft()
        if loc == end:
            return loc, time

        for n in loc.neighbors(h, w):
            if n in blocks or n in seen:
                continue
            # Mark on enqueue so a point is never queued twice.
            seen.add(n)
            d.append((n, time + 1))

    # Queue exhausted without reaching `end`: no path exists.
    return None
    

### Sample data

In [4]:
# Sample puzzle input: one "col,row" pair per line, split into a list of
# strings so it can be handed straight to parse().
s = '''5,4
4,2
4,5
3,0
2,1
6,3
2,4
1,5
0,6
3,3
2,6
5,1
1,2
5,5
2,5
6,5
1,4
0,4
6,4
1,1
6,1
1,0
0,5
1,6
2,0'''.split('\n')

# Dimensions of the sample grid (coordinates run from 0 to 6).
SAMPLE_HEIGHT = 7
SAMPLE_WIDTH = 7

In [5]:

# Parse the sample and keep only the first 12 bytes.
all_bytes = parse(s)
some_bytes = all_bytes[:12]

# Show the grid with those bytes in place.
draw(some_bytes, SAMPLE_HEIGHT,SAMPLE_WIDTH)

# Shortest path from the top-left corner to the bottom-right corner.
# bfs returns None when there is no path, in which case this unpacking fails.
p, time = bfs(
    Point(0, 0),
    Point(6, 6),
    some_bytes,
    SAMPLE_HEIGHT,
    SAMPLE_WIDTH
)
print(f"Finished at {p} in {time} steps")

...#...
..#..#.
....#..
...#..#
..#..#.
.#..#..
#.#....
Finished at Point(col=6, row=6) in 22 steps


In [6]:
# Read the puzzle input, one "col,row" line per byte.
with open('input_files/18.txt') as f:
    raw = f.read().splitlines()

# Part one only considers the first 1024 bytes.
all_bytes = parse(raw)
some_bytes = all_bytes[:1024]

# Dimensions of the full grid (coordinates run from 0 to 70).
HEIGHT = 71
WIDTH = 71

# Shortest path from the top-left corner to the bottom-right corner.
# bfs returns None when there is no path, in which case this unpacking fails.
p, time = bfs(
    Point(0, 0),
    Point(70, 70),
    some_bytes,
    HEIGHT,
    WIDTH
)
print(f"Finished at {p} in {time} steps")

Finished at Point(col=70, row=70) in 302 steps


## Part two

The obvious way of running bfs after each additional dropped byte works, but is **very** slow on the big input.

In [7]:
all_bytes = parse(s)

# Re-run bfs with one more byte each time. The upper bound is
# len(all_bytes) + 1 so the prefix containing every byte is tested too;
# otherwise a blocker in last position would never be detected.
for i in range(1, len(all_bytes) + 1):
    r = bfs(
        Point(0, 0),
        Point(6, 6),
        all_bytes[:i],
        SAMPLE_HEIGHT,
        SAMPLE_WIDTH
    )
    
    if r is None:
        # all_bytes[:i] ends with all_bytes[i-1], the byte that cut the path.
        print(f"finished at step {i}. Last byte: {all_bytes[i-1]}")
        break
else:
    print("the path is never blocked")

finished at step 21. Last byte: Point(col=6, row=1)


In [8]:
all_bytes = parse(raw)

# This takes about 22 seconds!!
# Re-run bfs with one more byte each time. The upper bound is
# len(all_bytes) + 1 so the prefix containing every byte is tested too;
# otherwise a blocker in last position would never be detected.
for i in range(1, len(all_bytes) + 1):
    r = bfs(
        Point(0, 0),
        Point(70, 70),
        all_bytes[:i],
        HEIGHT,
        WIDTH
    )
    
    if r is None:
        # all_bytes[:i] ends with all_bytes[i-1], the byte that cut the path.
        print(f"finished at step {i}. Last byte: {all_bytes[i-1]}")
        break
else:
    print("the path is never blocked")


finished at step 2856. Last byte: Point(col=24, row=32)


## A quick fix:
Since we know the bytes will fall in an order like:

```
[unblocked, unblocked, unblocked, unblocked, BLOCKED, BLOCKED, BLOCKED]
```

We can just look for the first BLOCKED using binary search. BFS above returns None when there's no path, so just look for the first time that happens and avoid scanning the whole list.

In [9]:
from bisect import bisect_left

all_bytes = parse(raw)

HEIGHT = 71
WIDTH = 71


def is_blocked(byte_count):
    '''True when no path exists once the first `byte_count` bytes have fallen.'''
    return bfs(Point(0, 0), Point(70, 70), all_bytes[:byte_count], HEIGHT, WIDTH) is None


# is_blocked is False for small counts and True from some count onward, so
# bisect_left can find the first True. The searched range goes up to
# len(all_bytes) inclusive so the prefix holding every byte is tested too.
first_blocked_index = bisect_left(
    range(len(all_bytes) + 1),
    True,
    key=is_blocked,
)

# 0 would mean "blocked before any byte fell"; len + 1 means "never blocked".
assert 0 < first_blocked_index <= len(all_bytes), "no byte blocks the path"

# The prefix of that length ends with the byte that cut the path.
all_bytes[first_blocked_index - 1]


Point(col=24, row=32)

### Over-engineered, but curious how it is to solve this with disjoint sets


This problem looks like a graph connectivity problem. At a certain point it will take only a single cut to partition the graph, which would make areas unreachable from others. Maybe we can start with all the bytes dropped with the knowledge that there is no path between the start and the end. Now make sets of all connected components.

#### Two partitions 

```
++++++#  
+++++#o  
++++#oo  
+++#ooo  
++#oooo  
+#ooooo  
#oooooo  
```

Starting with the above we can tell in constant time if removing a # connects the sets (thereby allowing a path) by checking if the neighbors of the # are in both sets. The sample problem only has two disconnected components when all the blocks are down. But the real problem has many. So the strategy here is:

1. Use BFS to find all the disjoint sets with all the bytes dropped, keeping special track of the component with the start and end
2. Look at each byte in reverse order
3. Find its neighbors and all the sets containing those neighbors.
4. If the sets include the sets with the start and stop point, we're done, removing this block allows a path
5. Otherwise merge these partitions into a larger partition (paying attention to keep track of the start/stop partitions)

This will eventually get to a point where removing a block connects the start and stop. That block was the last byte to drop that blocked the path.

Lists of sets may not be the best data structure. Perhaps the [Disjoint-set data structure](https://en.wikipedia.org/wiki/Disjoint-set_data_structure) would be even better.


### Partition all points in the graph that are not in the set of bytes

In [10]:
def find_points(start, all_points, h, w):
    '''
    Flood-fill outward from `start` and return the set of points reached.

    Only neighbors that are members of `all_points` are entered. `start`
    itself is always part of the result. The returned set is one
    connected component of the graph.
    '''
    component = {start}
    frontier = deque((start,))

    while frontier:
        here = frontier.popleft()
        for candidate in here.neighbors(h, w):
            if candidate in all_points and candidate not in component:
                component.add(candidate)
                frontier.append(candidate)

    return component
    
def get_all_points(h, w, blocks):
    '''
    Return the set of every Point on an `h`-row by `w`-column grid
    that is not listed in `blocks` (the fallen bytes).
    '''
    # every coordinate in the space as a set of Point
    grid = {Point(col, row) for row in range(h) for col in range(w)}

    # drop the fallen bytes
    return grid.difference(blocks)

def make_partitions(points, h, w):
    '''
    Split `points` (a set of open grid points) into connected components.

    Returns a list of sets. Index 0 is the component reachable from the
    top-left corner Point(0, 0) and index 1 is the component reachable
    from the bottom-right corner Point(w-1, h-1); every other component
    follows in no particular order. If both corners are connected, the
    first two entries hold the same points.
    '''
    # Make start and end partitions manually to make it easier to track them.
    # Use the `points` argument rather than any module-level variable.
    start_partition = find_points(Point(0, 0), points, h, w)
    # Point is (col, row): the last column is w-1 and the last row is h-1.
    end_partition = find_points(Point(w - 1, h - 1), points, h, w)

    # Be careful to make these easy to find
    # We will always keep them in indices 0 and 1
    partitions = [start_partition, end_partition]

    # `rest` is a fresh set, so it can be shrunk in place below.
    rest = points - start_partition.union(end_partition)

    while rest:
        next_point = rest.pop()
        next_partition = find_points(next_point, rest, h, w)
        partitions.append(next_partition)
        rest -= next_partition

    return partitions

# Dimensions of the full grid.
HEIGHT = 71
WIDTH = 71

all_bytes = parse(raw)

# Open points once every byte has fallen, grouped into connected components.
all_points = get_all_points(HEIGHT, WIDTH, all_bytes)
partitions = make_partitions(all_points, HEIGHT, WIDTH)

# make_partitions keeps the start component at index 0 and the end at index 1.
print("Start:", partitions[0])
print("End:", partitions[1])

# These are disjoint, the goal is to connect them:
print("distjoint?", partitions[0].isdisjoint(partitions[1]))

Start: {Point(col=4, row=0), Point(col=0, row=0), Point(col=2, row=0), Point(col=3, row=0), Point(col=5, row=0), Point(col=1, row=0)}
End: {Point(col=70, row=70), Point(col=69, row=70), Point(col=70, row=69)}
distjoint? True


In [11]:
def merge_sets(partitions, blocks, h, w):
    '''
    Remove blocks one by one, last-fallen first, until the first two
    partitions become connected.

    `partitions` is a list of sets of open points, with the start component
    at index 0 and the end component at index 1. It is updated in place
    as components merge. `blocks` is the sequence of fallen bytes in
    falling order; each must provide `neighbors(h, w)`.

    Returns the block whose removal joins partitions 0 and 1, or None if
    they are still separate after every block has been removed.
    '''
    # Track which cells are still blocked, based on the `blocks` argument.
    block_set = set(blocks)
    for b in reversed(blocks):
        block_set.remove(b)
        neighbors = {n for n in b.neighbors(h, w) if n not in block_set}

        # indices (ascending) of all partitions that touch these neighbors
        intersections = [
            i for i, s in enumerate(partitions) if not s.isdisjoint(neighbors)
        ]

        if not intersections:
            # block is isolated in a larger group of blocks
            partitions.append({b})
        elif 0 in intersections and 1 in intersections:
            # we're done!
            return b
        else:
            # Merge newly connected components into the lowest index, which
            # keeps the start/end partitions at indices 0 and 1.
            keep, *absorbed = intersections
            merged = set().union(*(partitions[i] for i in intersections))
            merged.add(b)
            partitions[keep] = merged
            # Delete from the highest index down so that earlier deletions
            # do not shift the positions of the ones still to be removed.
            for j in reversed(absorbed):
                del partitions[j]

    return None

# Displays the byte whose removal reconnects the start and end partitions.
# merge_sets modifies `partitions` in place, so rebuild `partitions` before
# running this cell a second time.
merge_sets(partitions, all_bytes, HEIGHT, WIDTH)

Point(col=24, row=32)

In [12]:
# Render the full grid with every byte in place.
draw(all_bytes, HEIGHT, WIDTH)

......##....#.#.##.#..####...#...######..####..##..###.###.......#..##.
######.#######.#.#################################.#.#############.####
.##..#.#.#.#..##...#.#.#...#...#.##.#.#.#...##.###.#.#.#...####..#.#.#.
####.#####.#.#.#####.###.#.###.###################.###.#.###.#.########
#.##.##..#####.#.#.##..#.#...##..##.##########.##.#.##.#.#..#.###.##.##
##.###.###.#################.#########.#.###############.###########.#.
..#####..#.###..###..#.#.#...###...#.########..#.....###.#.#.#.#.#..#.#
.###########.#.#######.###.#########.###########.#####################.
..#..#.....##..#..##.###.##..#..##.#.#...#.#...#..##....#####.#.....###
######.#.#####.#####.#######.#.#.###########.#####.###########.#####.#.
####.###.#...#.##.##.#..#..#.######..#.#####.....#####..####.#.#...###.
.###.#####.#.###############.#.#########.#.#.###.#####.#####.#########.
.#.##.####.#.#.###.#.#...#.####.#..##.##...##..#.#...###.######..#.#.#.
##.###.###.#######.#.#####.###################.#.###.#.#########