In [95]:
#@title Interactive WFC
class Move:
  """Unit grid steps expressed as (delta_row, delta_col) tuples.

  DOWN increases the first (row) index. The list orderings below are used as
  direction indices elsewhere in the notebook.
  """
  UP = (-1, 0)
  DOWN = (1, 0)
  LEFT = (0, -1)
  RIGHT = (0, 1)
  # Both orderings start at DOWN and visit the remaining directions in
  # opposite rotational order.
  CCW = [DOWN, RIGHT, UP, LEFT]
  CW = [DOWN, LEFT, UP, RIGHT]

from matplotlib import pyplot as plt

class Tile:
  """A rectangular block of pixels cut out of an image.

  Attributes:
    shape: (rows, cols, channels) of the tile.
    data: nested lists indexed data[row][col][channel].
    number: polynomial hash of all channel values (base 256, modulo
      1_000_000_007) used as the tile's identity. Distinct tiles can, in
      principle, collide.
  """

  def __init__(self, shape, data):
    self.shape = shape
    self.data = data
    self.number = self._get_number()

  @staticmethod
  def get_blank(shape, color=None):
    """Return a tile of `shape` with every pixel set to `color`.

    `color` defaults to a zero for each of shape[2] channels. The same
    `color` object is shared by every pixel.
    """
    if color is None:
      color = [0 for _ in range(shape[2])]  # TODO channel
    return Tile(shape, [[color for _ in range(shape[1])] for _ in range(shape[0])])

  def _get_number(self):
    """Hash the pixel data into an int in [0, 1_000_000_007)."""
    number = 0
    base = 256
    cycle = 1000000007
    for row in self.data:
      for pixel in row:
        for channel in pixel:
          # int() keeps the arithmetic in Python ints. Channels read from an
          # image array are NumPy scalars (e.g. np.uint8), and under NumPy 2
          # promotion rules `np.uint8(...) * 256` raises OverflowError.
          number = (number * base + int(channel)) % cycle
    return number

  @staticmethod
  def from_image(image, tile_size, x, y):
    """Cut the tile of size tile_size = (rows, cols) whose top-left pixel is image[x][y].

    `image` is indexed image[row][col] and must have a channel axis
    (image.shape[2] is read).
    """
    data = [[list(image[x + i][y + j]) for j in range(tile_size[1])]
            for i in range(tile_size[0])]
    shape = (tile_size[0], tile_size[1], image.shape[2])
    return Tile(shape, data)

  def display(self, ax=None):
    """Draw the tile on `ax`.

    With no axes (or with the pyplot module itself) the tile is drawn through
    pyplot and shown immediately.
    """
    # Resolve pyplot at call time instead of binding it as a default value.
    if ax is None:
      ax = plt
    ax.imshow(self.data, aspect=1)
    ax.axis('off')
    if ax is plt:
      plt.show()

class TiledImage:
  """An image represented as a grid of equally shaped tiles.

  Attributes:
    tile_shape: shape of every tile (taken from tiles[0][0]).
    tiles: tiles[row][col] grid of Tile objects.
    size: (tile rows, tile cols) of the grid.
    number_to_tile: the first tile seen for each tile number.
    data: the stitched pixels, data[pixel_row][pixel_col].
  """

  def __init__(self, tiles):
    self.tile_shape = tiles[0][0].shape
    self.tiles = tiles
    self.size = (len(tiles), len(tiles[0]))
    self.number_to_tile = {}
    tile_rows, tile_cols = self.tile_shape[0], self.tile_shape[1]
    self.data = [[None for _ in range(self.size[1] * tile_cols)]
                 for _ in range(self.size[0] * tile_rows)]
    for i, row in enumerate(tiles):
      for j, tile in enumerate(row):
        for x in range(tile_rows):
          for y in range(tile_cols):
            self.data[i * tile_rows + x][j * tile_cols + y] = tile.data[x][y]
        self.number_to_tile.setdefault(tile.number, tile)

  @staticmethod
  def from_image(tile_size, image):
    """Split `image` into non-overlapping tiles of tile_size = (rows, cols).

    Edge pixels that do not fill a whole tile are dropped.

    Raises:
      ValueError: if the image is smaller than a single tile.
    """
    tile_rows, tile_cols = tile_size[0], tile_size[1]
    # The `+ 1` keeps a tile that ends exactly on the image border.
    tiles = [[Tile.from_image(image, tile_size, row_index, col_index)
              for col_index in range(0, image.shape[1] - tile_cols + 1, tile_cols)]
             for row_index in range(0, image.shape[0] - tile_rows + 1, tile_rows)]
    if not tiles or not tiles[0]:
      raise ValueError('image of shape {} is smaller than tile size {}'.format(
          tuple(image.shape), tuple(tile_size)))
    return TiledImage(tiles)

  def display(self, title=None):
    """Plot each tile in its own subplot, separated by small gaps."""
    figsize = (self.size[1] * self.tile_shape[1] * 0.4, self.size[0] * self.tile_shape[0] * 0.4)
    fig, axs = plt.subplots(self.size[0], self.size[1], squeeze=False, figsize=figsize)
    if title is not None:
      fig.suptitle(title)
    for x in range(self.size[0]):
      for y in range(self.size[1]):
        self.tiles[x][y].display(axs[x, y])
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()

  def from_generated(self, generated_size, generated, blank_color=None):
    """Build a TiledImage from a generated_size grid of tile numbers.

    Numbers with no matching tile in this image become blank tiles filled
    with `blank_color`.
    """
    tiles = [[self.get_tile_from_number(generated[i][j], blank_color)
              for j in range(generated_size[1])]
             for i in range(generated_size[0])]
    return TiledImage(tiles)

  def get_tile(self, x, y):
    """Return tiles[x][y], or None when (x, y) is outside the grid."""
    if 0 <= x < self.size[0] and 0 <= y < self.size[1]:
      return self.tiles[x][y]
    return None

  def get_tile_from_number(self, tile_number, blank_color=None):
    """Return the tile for `tile_number`, or a blank tile of `blank_color` if unknown."""
    tile = self.number_to_tile.get(tile_number)
    if tile is None:
      # Only build the blank tile when it is actually needed, and honour
      # the caller's colour.
      tile = Tile.get_blank(self.tile_shape, blank_color)
    return tile

  def display_as_one(self):
    """Show the stitched image as a single picture."""
    plt.imshow(self.data, aspect=1)
    plt.axis('off')
    plt.show()

class TileDistribution:
  """Tile and directional neighbor-pair counts gathered from tiled images.

  Attributes:
    tile_shape: shape every trained TiledImage must have.
    tile_frequency: {tile_number: occurrence count}.
    frequency_sorted: (tile_number, count) pairs, most frequent first; set by
      training.
    nei_frequency: one dict per direction in Move.CCW; nei_frequency[i][(a, b)]
      counts how often tile b sat at offset Move.CCW[i] from tile a.
  """

  def __init__(self, tile_shape):
    self.tile_shape = tile_shape
    self.tile_frequency = {}
    self.nei_frequency = [{} for _ in Move.CCW]

  def train(self, tiled_image):
    """Add the counts of `tiled_image` to this distribution.

    Raises:
      ValueError: if the image's tile shape differs from `tile_shape`.
    """
    if tiled_image.tile_shape != self.tile_shape:
      raise ValueError('Incompatible Tile Shapes! expected {}, got {}'.format(
          self.tile_shape, tiled_image.tile_shape))
    self._train_tile_frequency(tiled_image)
    self._train_tile_nei_frequency(tiled_image)

  def _train_tile_frequency(self, tiled_image):
    """Count every tile and refresh `frequency_sorted`."""
    for i in range(tiled_image.size[0]):
      for j in range(tiled_image.size[1]):
        number = tiled_image.get_tile(i, j).number
        self.tile_frequency[number] = self.tile_frequency.get(number, 0) + 1
    # Stable sort: equally frequent tiles keep first-seen order.
    self.frequency_sorted = sorted(self.tile_frequency.items(), key=lambda keyvalue: -keyvalue[1])

  def _train_tile_nei_frequency(self, tiled_image):
    """Count (tile, neighbor) pairs for each direction in Move.CCW."""
    for x in range(tiled_image.size[0]):
      for y in range(tiled_image.size[1]):
        tile = tiled_image.get_tile(x, y)
        for counts, (dx, dy) in zip(self.nei_frequency, Move.CCW):
          nei = tiled_image.get_tile(x + dx, y + dy)
          if nei is not None:
            key = (tile.number, nei.number)
            counts[key] = counts.get(key, 0) + 1

import random
import numpy as np
import os

class WeightingOptions:
  """How a cell's remaining candidate tiles are weighted when it is collapsed."""
  # Uniform over the options allowed by the constraints; frequencies ignored.
  NONE = 'No Weights'
  # Proportional to each tile's own frequency.
  FREQUENCY_WEIGHTED = 'Tile Frequency'
  # Single-tile frequency combined with pair frequency of collapsed neighbors.
  CONDITIONAL_WEIGHTED = 'Conditional Probability'
  # Based on neighbor pair frequencies (collapsed or not).
  NEI_WEIGHTED = 'Neighbors Frequency'
  # NOTE(review): the integer values below look like placeholders; only NONE,
  # FREQUENCY_WEIGHTED and NEI_WEIGHTED appear to be handled by WFC here.
  CHAIN_WEIGHTED = 4  # based on everything
  CONTEXT_WEIGHTED = 5  # based on down and left
  CONTEXT_DIST = 6

class UpdatingOptions:
  """How far a collapse is propagated through the supermap."""
  # Filter only the direct neighbors of the collapsed cell.
  NEIGHBOR = 'Neighbor'
  # Keep propagating from every neighbor whose options changed.
  CHAIN = 'Chain'

class EntropyOptions:
  """Score used to decide which uncollapsed cell is collapsed next."""
  # Count of tiles still possible in the cell.
  NUMBER_OF_OPTIONS = 'Number of Options'
  # Shannon entropy of the remaining options, weighted by tile frequency.
  SHANNON = 'Shannon Entropy'
  # A fixed positional scan order instead of an uncertainty measure.
  UP_LEFT = 'Up Left to Bottom Right'

class WFC:
  """Wave Function Collapse over a grid of tile numbers.

  The "supermap" is a supermap[row][col] grid of lists holding the tile
  numbers still allowed in each cell. The lowest-scoring cell (see
  `_get_entropy`) is collapsed to one tile by a weighted random draw, and the
  choice is propagated to neighbors using the pair counts in
  `dist.nei_frequency`.

  Args:
    dist: a trained TileDistribution (reads tile_shape, tile_frequency,
      frequency_sorted and nei_frequency).
    updating_option: an UpdatingOptions value.
    entropy_option: an EntropyOptions value.
    weighting_option: a WeightingOptions value.
  """

  def __init__(self, dist, updating_option, entropy_option, weighting_option):
    self.dist = dist
    self.tile_shape = self.dist.tile_shape
    self.updating_option = updating_option
    self.entropy_option = entropy_option
    self.weighting_option = weighting_option

  def _get_updated_possibilities(self, possibilities, collapsed_value, move_number):
    """Keep the possibilities seen at direction Move.CCW[move_number] from `collapsed_value`.

    A single-option list is returned unchanged, so collapsed cells are never
    emptied.
    """
    if len(possibilities) == 1:
      return possibilities
    seen_pairs = self.dist.nei_frequency[move_number]
    return [possibility for possibility in possibilities
            if (collapsed_value, possibility) in seen_pairs]

  def _is_in_bounds(self, x, y, start, end):
    """True when start <= (x, y) < end component-wise."""
    return start[0] <= x < end[0] and start[1] <= y < end[1]

  def _get_neighbor_position(self, x, y, start, size, move_index):
    """Cell one Move.CCW[move_index] step from (x, y), or (None, None) if it
    lies outside the region starting at `start` with extent `size`."""
    delta_x, delta_y = Move.CCW[move_index]
    n_x = delta_x + x
    n_y = delta_y + y
    end = (start[0] + size[0], start[1] + size[1])
    if self._is_in_bounds(n_x, n_y, start, end):
      return (n_x, n_y)
    return (None, None)

  def _get_options(self, tile_limit):
    """Trained tile numbers, most frequent first, truncated to `tile_limit` if given."""
    options = [tile_number for tile_number, _ in self.dist.frequency_sorted]
    if tile_limit is not None:
      options = options[:tile_limit]
    return options

  def _get_entropy(self, options, position, size):
    """Score a cell; the lowest score is collapsed first.

    Raises:
      NotImplementedError: for an unknown entropy option.
    """
    if self.entropy_option == EntropyOptions.SHANNON:
      # H = log(W) - sum(w * log w) / W over frequency weights w, W = sum(w).
      weights = [self.dist.tile_frequency[tile_number] for tile_number in options]
      total = sum(weights)
      return np.log(total) - (sum(weights * np.log(weights)) / total)
    if self.entropy_option == EntropyOptions.NUMBER_OF_OPTIONS:
      return len(options)
    if self.entropy_option == EntropyOptions.UP_LEFT:
      # Row-major rank: the top-left cell scores lowest. `size` is the block
      # extent and a block spans exactly size[1] consecutive columns, so the
      # rank is strictly increasing even with absolute coordinates.
      return position[0] * size[1] + position[1]
    raise NotImplementedError(
        'entropy option not implemented! ({!r})'.format(self.entropy_option))

  def _get_position_to_collapse(self, start, size, supermap):
    """Lowest-scoring cell of the block with more than one option, or (None, None)."""
    end = (start[0] + size[0], start[1] + size[1])
    min_entropy = None
    min_entropy_position = (None, None)
    for i in range(start[0], end[0]):
      for j in range(start[1], end[1]):
        if len(supermap[i][j]) <= 1:
          continue
        entropy = self._get_entropy(supermap[i][j], (i, j), size)
        if min_entropy is None or entropy < min_entropy:
          min_entropy = entropy
          min_entropy_position = (i, j)
    return min_entropy_position

  def _get_neighbor_weights(self, supermap, x, y, map_size):
    """Weight each option of (x, y) by pair counts with its collapsed neighbors.

    Falls back to plain tile frequencies when no collapsed neighbor supports
    any option.
    """
    options = supermap[x][y]
    collapsed_neighbors = []
    for i in range(len(Move.CCW)):
      n_x, n_y = self._get_neighbor_position(x, y, (0, 0), map_size, i)
      if n_x is not None and n_y is not None and len(supermap[n_x][n_y]) == 1:
        collapsed_neighbors.append((self.dist.nei_frequency[i], supermap[n_x][n_y][0]))
    weights = [sum(counts.get((tile_number, n_tile_number), 0)
                   for counts, n_tile_number in collapsed_neighbors)
               for tile_number in options]
    if sum(weights) == 0:
      weights = [self.dist.tile_frequency[tile_number] for tile_number in options]
    return weights

  def _get_options_probabilities(self, supermap, x, y, map_size):
    """Probabilities aligned with supermap[x][y] for the configured weighting.

    Raises:
      NotImplementedError: for a weighting option without an implementation.
    """
    options = supermap[x][y]
    if self.weighting_option == WeightingOptions.FREQUENCY_WEIGHTED:
      weights = [self.dist.tile_frequency[tile_number] for tile_number in options]
    elif self.weighting_option == WeightingOptions.NONE:
      return [1/len(options) for _ in options]
    elif self.weighting_option == WeightingOptions.NEI_WEIGHTED:
      weights = self._get_neighbor_weights(supermap, x, y, map_size)
    else:
      raise NotImplementedError(
          'weighting option not implemented! ({!r})'.format(self.weighting_option))
    total = sum(weights)
    return [weight/total for weight in weights]

  def _collapse(self, supermap, x, y, map_size):
    """Reduce cell (x, y) to one option drawn with NumPy's global RNG."""
    probabilities = self._get_options_probabilities(supermap, x, y, map_size)
    supermap[x][y] = [np.random.choice(supermap[x][y], p=probabilities)]

  def _update_supermap(self, changed_x, changed_y, start, size, supermap):
    """Filter neighbors of collapsed cells, starting from (changed_x, changed_y).

    Only single-option cells constrain their neighbors. With CHAIN updating,
    every neighbor whose option list shrank is processed in turn.
    """
    from collections import deque

    changed_queue = deque([(changed_x, changed_y)])
    while changed_queue:
      x, y = changed_queue.popleft()
      if len(supermap[x][y]) != 1:
        continue
      collapsed_value = supermap[x][y][0]
      for i in range(len(Move.CCW)):
        n_x, n_y = self._get_neighbor_position(x, y, start, size, i)
        if n_x is None or n_y is None:
          continue
        new_possibilities = self._get_updated_possibilities(supermap[n_x][n_y], collapsed_value, i)
        if len(new_possibilities) != len(supermap[n_x][n_y]):
          supermap[n_x][n_y] = new_possibilities
          if self.updating_option == UpdatingOptions.CHAIN:
            changed_queue.append((n_x, n_y))

  def _generate_block(self, start, size, supermap, map_size):
    """Collapse and propagate inside one block until no cell has several options."""
    while True:
      x, y = self._get_position_to_collapse(start, size, supermap)
      if x is None or y is None:
        break
      self._collapse(supermap, x, y, map_size)
      self._update_supermap(x, y, start, size, supermap)

  def generate(self, size, block_size=None, tile_limit=None, seed=0):
    """Generate a size = (rows, cols) grid of tile numbers.

    A block_size window slides over the map; each window is reset to all
    options, constrained by the cells bordering it, then regenerated.

    Args:
      size: (rows, cols) of the output.
      block_size: window extent; defaults to `size` (one pass).
      tile_limit: keep only this many of the most frequent tiles.
      seed: seed for NumPy's global RNG.

    Returns:
      Rows of tile numbers; a cell left with no options yields 0.
    """
    if block_size is None:
      block_size = size
    np.random.seed(seed)
    options = self._get_options(tile_limit)
    supermap = [[list(options) for _ in range(size[1])] for _ in range(size[0])]
    for i in range(size[0] - block_size[0] + 1):
      for j in range(size[1] - block_size[1] + 1):
        below = i + block_size[0]  # first row under the block
        right = j + block_size[1]  # first column right of the block
        for x in range(i, below):
          for y in range(j, right):
            supermap[x][y] = list(options)
        # Constrain the reset block from the cells just outside its borders.
        for x in range(i, below):
          if j > 0:
            self._update_supermap(x, j - 1, (0, 0), size, supermap)
          if right < size[1]:
            self._update_supermap(x, right, (0, 0), size, supermap)
        for y in range(j, right):
          if i > 0:
            self._update_supermap(i - 1, y, (0, 0), size, supermap)
          if below < size[0]:
            self._update_supermap(below, y, (0, 0), size, supermap)
        self._generate_block((i, j), block_size, supermap, size)
    return [[possibilities[0] if len(possibilities) > 0 else 0 for possibilities in row]
            for row in supermap]

import ipywidgets as widgets

!wget -q https://raw.githubusercontent.com/mxgmn/WaveFunctionCollapse/master/samples/Flowers.png


import imageio

# Optional custom input: a single PNG picked by the user. The wfc() callback
# below reads this widget and falls back to Flowers.png otherwise.
image_file = widgets.FileUpload(
    accept='.png',
    multiple=False,
)
display(image_file)

@widgets.interact_manual(
    tile_width=widgets.IntSlider(2, 1, 10), tile_height=widgets.IntSlider(2, 1, 10),
    width=widgets.IntSlider(5, 1, 20), height=widgets.IntSlider(5, 1, 20),
    updating=[UpdatingOptions.CHAIN, UpdatingOptions.NEIGHBOR],
    entropy=[EntropyOptions.NUMBER_OF_OPTIONS, EntropyOptions.SHANNON, EntropyOptions.UP_LEFT],
    weighting=[WeightingOptions.NONE, WeightingOptions.FREQUENCY_WEIGHTED, WeightingOptions.NEI_WEIGHTED],
    show_grid=False,
    seed=widgets.IntSlider(612903, 0, 1000000)
    )
def wfc(tile_width, tile_height, width, height, updating, entropy, weighting, show_grid, seed):
  """Tile the uploaded PNG (or Flowers.png), run WFC, and plot input and output."""
  uploads = image_file.value
  # NOTE(review): ipywidgets 7 stores uploads as {filename: {'content': bytes}},
  # ipywidgets 8 as a tuple of dicts whose 'content' is a memoryview; accept both.
  if isinstance(uploads, dict):
    uploads = list(uploads.values())
  im = None
  if uploads:
    try:
      im = imageio.imread(bytes(uploads[0]['content']))
    except Exception as err:  # imageio raises various types for undecodable data
      print('could not read the uploaded image ({}); '.format(err), end='')
  if im is None:
    print('using default image...')
    im = imageio.imread('Flowers.png')
  # Sizes are (rows, cols) throughout Tile/TiledImage/WFC, so height comes first.
  tiled_image = TiledImage.from_image((tile_height, tile_width), im)
  dist = TileDistribution(tiled_image.tile_shape)
  dist.train(tiled_image)
  generator = WFC(dist, updating, entropy, weighting)
  size = (height, width)
  generated = tiled_image.from_generated(size, generator.generate(size, seed=seed))
  if show_grid:
    tiled_image.display('original image')
    generated.display('generated image')
  else:
    plt.title('original image')
    tiled_image.display_as_one()
    plt.title('generated image')
    generated.display_as_one()
  # Private FileUpload counter reset, kept from the original notebook.
  image_file._counter = 0

FileUpload(value={}, accept='.png', description='Upload')

interactive(children=(IntSlider(value=2, description='tile_width', max=10, min=1), IntSlider(value=2, descript…