
# Wave function collapse

See [Wave Function Collapse — Tutorial of a Basic Example Implementation in Python](https://medium.com/swlh/wave-function-collapse-tutorial-with-a-basic-exmaple-implementation-in-python-152d83d5cdb1) by Mateusz Bugaj.

In [2]:
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import random

The original pattern to start from.

In [3]:
pattern = np.array([
    [255, 255, 255, 255],
    [255,   0,   0,   0],
    [255,   0, 138,   0],
    [255,   0,   0,   0],
], dtype=np.uint8)

It is convenient to represent patterns and subpatterns by numpy arrays, but those are not hashable and hence can not be used as keys in dictionaries or as members of sets.  Hence we define a wrapper class to ensure that the wrapped objects can be hashed.  This is not strictly required, but it will make the implementation much more transparent.

In [4]:
class Subpattern:
    
    def __init__(self, subpattern):
        self._hash = hash(tuple(subpattern.flatten().tolist()))
        self._pattern = subpattern
        
    @property
    def pattern(self):
        return self._pattern.copy()
    
    def __hash__(self):
        return self._hash
    
    def __eq__(self, other):
        if isinstance(other, Subpattern):
            return self._hash ==other._hash
        return NotImplemented   
    
    def __repr__(self):
        return str(self._pattern)

A function to extract all subpatterns from a given pattern, i.e., all $width \times height$ subpatterns, but also their rotations and reflections.  The unique patterns will be returned as keys in a dictionary that stores the probability of the subpattern.

In [5]:
def extract_subpatterns(pattern, width=2, height=2):
    counter = Counter()
    for i in range(pattern.shape[0] - width + 1):
        for j in range(pattern.shape[1] - height + 1):
            subpattern = pattern[i:i + width, j:j + height].copy()
            for k in range(4):
                counter[Subpattern(np.rot90(subpattern, k=k))] += 1
            for axis in range(2):
                counter[Subpattern(np.flip(subpattern, axis=axis))]
    total_count = sum(counter.values())
    subpatterns = {pattern: count/total_count for pattern, count in counter.items()}
    return subpatterns

In [6]:
subpattern_probs = extract_subpatterns(pattern)

For this particular starting pattern, there should be 12 unique subpatterns.

In [7]:
len(subpattern_probs)

12

For testing purposes, also create a list with the subpatterns only.

In [8]:
subpatterns = list(subpattern_probs.keys())

Starting from a given position, we can move up-left, up, up-right, left, right, down-left, down, down-right.  These directions can be represented by a tuple:
  * up-left: `(-1, -1)`
  * up: `(-1, 0)`
  * up-right: `(-1, 1)`
  * left: `(0, -1)`
  * right: `(0, 1)`
  * down-left: `(1, -1)`
  * down: `(1, 0)`
  * down-right: `(1, 1)`
  
The following function computes all directions.  It can optionally take into account that when you're at the top row, you can't go up, or if you're at the bottom-right, you can go neither down nor right.

In [9]:
def compute_directions(is_bottom=False, is_top=False, is_left=False, is_right=False):
    directions = set()
    bottom = 0 if is_bottom else 1
    top = 0 if is_top else -1
    left = 0 if is_left else -1
    right = 0 if is_right else 1
    for row in range(top, bottom + 1):
        for col in range(left, right + 1):
            if row != 0 or col != 0:
                directions.add((row, col))
    return directions

In [10]:
compute_directions(is_left=True)

{(0, 1), (-1, 1), (1, 1), (-1, 0), (1, 0)}

Class to represent the rules of next pattern placement.  For each pattern and each direction, all possible next patterns are stored in a set.  A next pattern can be placed in a given direction when the overlapping pixels match.

In [11]:
class Rules:
    
    def __init__(self, patterns):
        directions = compute_directions()
        self._rules = dict()
        for pattern in patterns:
            self._rules[pattern] = {direction: set() for direction in directions}
        for pattern, directions in self._rules.items():
            for direction in directions:
                for next_pattern in patterns:
                    tile = Rules.get_offset_pattern(pattern, (direction[0], direction[1]))
                    next_tile = Rules.get_offset_pattern(next_pattern, (-direction[0], -direction[1]))
                    if tile == next_tile:
                        self._rules[pattern][direction].add(next_pattern)
    
    def is_possible(self, pattern, direction, next_pattern):
        return next_pattern in self._rules[pattern][direction]
    
    def possibilities(self, pattern, direction):
        return self._rules[pattern][direction]
    
    @staticmethod
    def get_offset_pattern(pattern, offset):
        row, col = offset
        indices = {(i, j) for i in range(2) for j in range(2)}
        offset_indices = {(t[0] + row, t[1] + col) for t in indices}
        return [pattern.pattern[t] for t in sorted(indices & offset_indices)]

Test the static method that computes the overlap when a next pattern has a given direction.

In [12]:
pattern = Subpattern(np.array([[1, 2], [3, 4]]))

In [13]:
pattern

[[1 2]
 [3 4]]

In [14]:
for direction in compute_directions():
    print(f'{direction}: {Rules.get_offset_pattern(pattern, direction)}')

(0, 1): [2, 4]
(-1, -1): [1]
(-1, 1): [2]
(1, 1): [4]
(1, -1): [3]
(-1, 0): [1, 2]
(1, 0): [3, 4]
(0, -1): [1, 3]


In [15]:
rules = Rules(subpatterns)

In [16]:
subpatterns[0]

[[255 255]
 [255   0]]

In [17]:
rules.possibilities(subpatterns[0], (0, 1))

{[[255 255]
 [  0   0]], [[255 255]
 [  0 255]]}

In [18]:
def embed(pattern, offset=(0, 0)):
    embedding = np.empty((pattern.pattern.shape[0] + 2, pattern.pattern.shape[1] + 2))
    embedding.fill(-1)
    row = 1 + offset[0]
    col = 1 + offset[1]
    embedding[row:row + 2, col:col + 2] = pattern.pattern
    return embedding

In [19]:
embed(subpatterns[0])

array([[ -1.,  -1.,  -1.,  -1.],
       [ -1., 255., 255.,  -1.],
       [ -1., 255.,   0.,  -1.],
       [ -1.,  -1.,  -1.,  -1.]])

In [20]:
embed(subpatterns[0], (1, 1))

array([[ -1.,  -1.,  -1.,  -1.],
       [ -1.,  -1.,  -1.,  -1.],
       [ -1.,  -1., 255., 255.],
       [ -1.,  -1., 255.,   0.]])

In [21]:
def are_compatible(pattern, direction, next_pattern):
    embedding = embed(pattern)
    next_embedding = embed(next_pattern, direction)
    mask = np.logical_and(embedding > -0.5, next_embedding > -0.5)
    return (embedding[mask] == next_embedding[mask]).all()

In [22]:
are_compatible(subpatterns[0], (1, 0), subpatterns[0])

False

In [23]:
for pattern in subpatterns:
    for direction in compute_directions():
        for next_pattern in rules.possibilities(pattern, direction):
            if not are_compatible(pattern, direction, next_pattern):
                print(pattern, direction, next_pattern)

In [24]:
class Image:
    
    def __init__(self, width, height, rules, probabilities):
        self._rules = rules
        self._probs = probabilities
        self._img = list()
        for row in range(height):
            self._img.append(list())
            for _ in range(width):
                self._img[row].append(set(probabilities.keys()))
 
    @property
    def width(self):
        return len(self._img[0])

    @property
    def height(self):
        return len(self._img)
    
    def is_collapsed(self):
        all(all(map(lambda x: len(x) == 1, row)) for row in self._img)
        
    def possible_patterns_at(self, row, col):
        return self._img[row][col]
    
    def shannon_entropy_at(self, row, col):
        if len(self.possible_patterns_at(row, col)) > 1:
            entropy = 0.0
            for pattern in self.possible_patterns_at(row, col):
                entropy -= self._probs[pattern]*np.log(self._probs[pattern])
            return entropy + random.uniform(0, 0.1)
        else:
            return np.inf
    
    def collapse_site(self):
        entropy = np.array([[self.shannon_entropy_at(row, col) for col in range(self.width)] for row in range(self.height)])
        row, col = np.unravel_index(np.argmax(entropy), entropy.shape)
        pattern = max(self._probs, key=lambda x: self._probs[x])
        self._img[row][col].clear()
        self._img[row][col].add(pattern)
        return row, col
    
    def get_neighbours(self, row, col):
        directions = compute_directions(
            is_top=row== 0,
            is_bottom=row == self.height - 1,
            is_left=col == 0,
            is_right=col == self.width - 1
        )
        return [(row + d[0], col + d[1]) for d in directions]
 
    def collapse(self):
        while not self.is_collapsed():
            row, col = self.collapse_site()
            print(f'collapsing ({row}, {col})')
            stack = set()
            stack.add((row, col))
            while stack:
                row, col = stack.pop()
                # print(f'  reducing ({row}, {col})')
                directions = compute_directions(
                    is_top=row== 0,
                    is_bottom=row == self.height - 1,
                    is_left=col == 0,
                    is_right=col == self.width - 1
                )
                for d in directions:
                    # print(f'    checking ({row + d[0]}, {col + d[1]})')
                    next_patterns = self._img[row + d[0]][col + d[1]]
                    to_remove = set()
                    for next_pattern in next_patterns:
                        if not any(self._rules.is_possible(pattern, d, next_pattern) for pattern in self._img[row][col]):
                            to_remove.add(next_pattern)
                    if to_remove:
                        self._img[row + d[0]][col + d[1]] -= to_remove
                        stack.add((row + d[0], col + d[1]))
        return self.get_image()
    
    def get_image(self):
        img = np.empty((self.height, self.width))
        for row in range(self.height):
            for col in range(self.width):
                pattern = self._img[row][col].pop()
                img[row, col] = pattern.pattern[0, 0]
        return img

In [25]:
image = Image(5, 6, rules, subpattern_probs)

In [None]:
img = image.collapse()

In [378]:
for x in range(image.width):
    for y in range(image.height):
        print(len(image._img[y][x]))

9
7
3
2
5
6
9
7
3
1
3
5
9
7
3
2
5
6
9
8
5
5
8
8
11
11
10
10
11
11


In [None]:
entropy

In [377]:
image._img

[[{[[  0   0]
    [  0 138]],
   [[  0   0]
    [138   0]],
   [[  0   0]
    [255 255]],
   [[  0 255]
    [  0 255]],
   [[  0 255]
    [255 255]],
   [[255   0]
    [255   0]],
   [[255   0]
    [255 255]],
   [[255 255]
    [  0 255]],
   [[255 255]
    [255   0]]},
  {[[  0   0]
    [  0 138]],
   [[  0   0]
    [138   0]],
   [[  0   0]
    [255 255]],
   [[  0 255]
    [  0 255]],
   [[  0 255]
    [255 255]],
   [[255   0]
    [255   0]],
   [[255   0]
    [255 255]],
   [[255 255]
    [  0 255]],
   [[255 255]
    [255   0]]},
  {[[  0   0]
    [  0 138]],
   [[  0   0]
    [138   0]],
   [[  0   0]
    [255 255]],
   [[  0 255]
    [  0 255]],
   [[  0 255]
    [255 255]],
   [[255   0]
    [255   0]],
   [[255   0]
    [255 255]],
   [[255 255]
    [  0 255]],
   [[255 255]
    [255   0]]},
  {[[  0   0]
    [  0 138]],
   [[  0   0]
    [138   0]],
   [[  0   0]
    [255 255]],
   [[  0 255]
    [  0 255]],
   [[  0 255]
    [255 255]],
   [[255   0]
    [255   0]],
   [[25