Skip to content

Commit

Permalink
Find slots for each shape rotation before filling.
Browse files Browse the repository at this point in the history
Use numpy to test all slots at once.
Part of #56.
  • Loading branch information
donkirkby committed Feb 29, 2024
1 parent 311da2e commit 484d130
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
4 changes: 3 additions & 1 deletion .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

62 changes: 60 additions & 2 deletions four_letter_blocks/block_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from random import shuffle

import numpy as np
from scipy.ndimage import label # type: ignore

from four_letter_blocks.block import shape_rotations, normalize_coordinates, Block
from four_letter_blocks.square import Square
Expand Down Expand Up @@ -73,6 +74,35 @@ def rotated_positions(self):
rotated_positions[rotated_shape].append((x, y))
return rotated_positions

def find_slots(self) -> dict[str, np.ndarray]:
if self.state is None:
raise RuntimeError('Cannot find slots with invalid state.')

all_masks = build_masks(self.width, self.height)
slots = {}
padded = np.pad(self.state.astype(bool), (0, 3), constant_values=1)
for shape, masks in all_masks.items():
collisions = np.logical_and(masks, padded)
colliding_positions = np.any(collisions, axis=(2, 3))
open_slots = np.logical_not(colliding_positions)

gaps = np.logical_not(np.logical_or(masks, padded))
structure = np.zeros((3, 3, 3, 3), bool)
structure[1, 1, :, :] = [[0, 1, 0],
[1, 1, 1],
[0, 1, 0]]
gap_groups, group_count = label(gaps, structure=structure)
bin_counts = np.bincount(gap_groups.flatten())
uneven_groups, = np.nonzero(bin_counts % 4)
if uneven_groups[0] == 0:
uneven_groups = uneven_groups[1:]
is_uneven = np.isin(gap_groups, uneven_groups)
has_even = np.logical_not(np.any(is_uneven, axis=(2, 3)))

usable_slots = np.logical_and(open_slots, has_even)
slots[shape] = usable_slots
return slots

def display(self, state: np.ndarray | None = None) -> str:
if state is None:
state = self.state
Expand Down Expand Up @@ -102,6 +132,7 @@ def sort_blocks(self):
if old_block <= 1:
# gap or space
continue
# noinspection PyUnresolvedReferences
block_spaces = (self.state == old_block).astype(np.uint8)
state += next_block * block_spaces
next_block += 1
Expand Down Expand Up @@ -165,8 +196,10 @@ def fill(self, shape_counts: typing.Counter[str]) -> bool:
# No empty spaces left, fail.
self.state = None
return False
target_row = empty[0][0]
target_col = empty[1][0]
# noinspection PyTypeChecker
target_row: int = empty[0][0]
# noinspection PyTypeChecker
target_col: int = empty[1][0]
next_block = np.amax(start_state) + 1
if next_block == self.GAP:
next_block += 1
Expand Down Expand Up @@ -278,6 +311,7 @@ def random_fill(self, shape_counts: typing.Counter[str]):
empty = np.argwhere(self.state == 0)
np.random.shuffle(empty)
used_blocks = np.unique(self.state)
block: int
for i, block in enumerate(used_blocks[:-1]):
if block >= self.GAP and used_blocks[i+1] != block+1:
next_block = block + 1
Expand Down Expand Up @@ -317,3 +351,27 @@ def shape_coordinates() -> typing.Dict[str, typing.List[np.ndarray]]:
grid[y, x] = 1
coordinate_lists[name].append(grid)
return dict(coordinate_lists)


@cache
def build_masks(width: int, height: int) -> dict[str, np.ndarray]:
all_coordinates = shape_coordinates()
all_masks = {}
for shape, coordinate_list in all_coordinates.items():
for rotation, start_mask in enumerate(coordinate_list):
if len(coordinate_list) == 1:
name = shape
else:
name = f'{shape}{rotation}'
masks = np.zeros((height, width, height+3, width+3),
dtype=bool)
mask_height, mask_width = start_mask.shape
for row in range(height):
for col in range(width):
masks[row,
col,
row:row+mask_height,
col:col+mask_width] = start_mask
all_masks[name] = masks

return all_masks
34 changes: 34 additions & 0 deletions tests/test_block_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,37 @@ def test_fill_fail():
is_filled = packer.fill(shape_counts)

assert not is_filled


def test_find_slots():
packer = BlockPacker(start_text=dedent("""\
#..#.
.....
..#..
.....
.#..#"""))
# Not at (1, 3) or (2, 0), because they cut off something.
expected_o_slots = np.array(object=[[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]],
dtype=bool)

o_slots = packer.find_slots()['O']

assert np.array_equal(o_slots, expected_o_slots)


def test_find_slots_after_fail():
packer = BlockPacker(start_text=dedent("""\
#..#.
.....
..#..
.....
.#..#"""))
packer.fill(Counter({'O': 20})) # fails

with pytest.raises(RuntimeError,
match='Cannot find slots with invalid state.'):
packer.find_slots()

0 comments on commit 484d130

Please sign in to comment.