In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Set

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import *
from src.datasets.generative import *
from src.models.cnn import *
from src.models.transform import *
from src.util.embedding import *
from src.models.clip import ClipSingleton

In [None]:
# Load the 4x4 Wang tile atlas image and keep only its first 3 channels.
# The context manager closes the file handle once the pixels are copied into the tensor.
# NOTE(review): absolute, machine-specific path - consider a configurable asset directory.
# NOTE(review): `* 51` presumably rescales small palette/index values into [0, 1] - confirm for this asset.
# Alternative atlases that were tried (with their own scale factors):
#   "/home/bergi/prog/python/thegame/thegame/assets/cr31/path.png"  (* 85)
#   "/home/bergi/prog/python/thegame/thegame/assets/w2e_beach.png"  (no rescale)
with PIL.Image.open("/home/bergi/prog/python/thegame/thegame/assets/cr31/wang2e.png") as template_image:
    wang_template = VF.to_tensor(template_image)[:3] * 51
print(wang_template.shape, wang_template.max())
VF.to_pil_image(wang_template)

In [None]:
# Apply the project's `RandomWangMap((4, 32))` transform to the template image and show the result.
wang_map_sampler = RandomWangMap((4, 32))
VF.to_pil_image(wang_map_sampler(wang_template))

In [None]:
# Layout of the 4x4 template image, read row by row: entry k is the tile index
# drawn in template cell k (row k // 4, column k % 4).
WANG_TEMPLATE_INDICES = [
    4, 6, 14, 12,
    5, 7, 15, 13,
    1, 3, 11, 9,
    0, 2, 10, 8,
]
# Inverse lookup: tile index -> template cell (row-major position 0..15).
WANG_INDEX_TO_TEMPLATE_INDEX = dict(
    zip(WANG_TEMPLATE_INDICES, range(len(WANG_TEMPLATE_INDICES)))
)

In [None]:
# Human-readable names of the 16 two-colour edge Wang tiles: each letter present
# (t/r/b/l) marks an edge carrying colour 1. Kept under its own name so it cannot
# clobber the vector table `WANG_TILES` when cells are re-run out of order.
WANG_TILE_NAMES = [
    "", "t", "r", "tr",
    "b", "tb", "rb", "trb",
    "l", "tl", "lr", "ltr",
    "lb", "tlb", "lbr", "tlbr",
]

# Print each name as a [top, right, bottom, left] 0/1 vector literal
# (a helper for writing out the vector table by hand).
for tile_name in WANG_TILE_NAMES:
    edge_vector = [int(side in tile_name) for side in "trbl"]
    print(f"{edge_vector},")

In [None]:
# (row, column) offset of the neighbour across each edge, ordered Top, Right, Bottom, Left.
WANG_TILE_OFFSETS = [
    [-1, 0], [0, 1], [1, 0], [0, -1],
]
# All 16 two-colour edge tiles as [top, right, bottom, left] 0/1 vectors.
# Tile i takes its edge colours from the bits of i: top = bit 0, right = bit 1,
# bottom = bit 2, left = bit 3.
WANG_TILES = [
    [(tile_index >> edge_bit) & 1 for edge_bit in range(4)]
    for tile_index in range(16)
]

In [None]:
# For each tile in WANG_TILES ([top, right, bottom, left] vectors), the indices of the
# tiles that may be placed on each of its sides:
#   "t": other.bottom == tile.top     (placed above)
#   "r": other.left   == tile.right   (placed to the right)
#   "b": other.top    == tile.bottom  (placed below)
#   "l": other.right  == tile.left    (placed to the left)
#   "tl": tiles contained in both "t" and "l".
# NOTE(review): "tl" is not a diagonal (top-left neighbour) constraint - edge tiles impose
# none - it is only the intersection above; confirm the intended use before relying on it.
MATCHING_WANG_TILES = []
for tile in WANG_TILES:
    matches = {
        side: {
            other_idx
            for other_idx, other in enumerate(WANG_TILES)
            if tile[own_edge] == other[facing_edge]
        }
        for side, own_edge, facing_edge in (("t", 0, 2), ("r", 1, 3), ("b", 2, 0), ("l", 3, 1))
    }
    matches["tl"] = matches["t"] & matches["l"]
    MATCHING_WANG_TILES.append(matches)

MATCHING_WANG_TILES

In [None]:
class WangTiles:
    """
    A set of Wang tiles, each described by colours on its edges and/or corners.

    Every tile stores 8 colour slots, indexed clockwise from the top edge by the
    direction constants below (even = edge, odd = corner). For every tile and every
    direction, the indices of all tiles that may sit next to it in that direction
    are precomputed and returned by `Tile.matching_indices(direction)`.
    """

    # Direction / colour-slot indices, clockwise starting at the top edge.
    Top = 0
    TopRight = 1
    Right = 2
    BottomRight = 3
    Bottom = 4
    BottomLeft = 5
    Left = 6
    TopLeft = 7

    # (row, column) offset of the neighbouring map cell for each direction above.
    map_offsets = [
        [-1,  0],
        [-1,  1],
        [ 0,  1],
        [ 1,  1],
        [ 1,  0],
        [ 1, -1],
        [ 0, -1],
        [-1, -1],
    ]

    class Tile:
        """One tile of a `WangTiles` set: its index and its 8 colour slots."""

        def __init__(self, parent: "WangTiles", index: int, colors: List[int]):
            self.parent = parent
            self.index = index
            self.colors = colors
            # direction -> set of tile indices allowed in that direction;
            # filled in by `WangTiles.__init__`.
            self._matching_indices = {}

        def __repr__(self):
            return f"Tile({self.index}, {self.colors})"

        @property
        def top(self): return self.colors[WangTiles.Top]
        @property
        def top_right(self): return self.colors[WangTiles.TopRight]
        @property
        def right(self): return self.colors[WangTiles.Right]
        @property
        def bottom_right(self): return self.colors[WangTiles.BottomRight]
        @property
        def bottom(self): return self.colors[WangTiles.Bottom]
        @property
        def bottom_left(self): return self.colors[WangTiles.BottomLeft]
        @property
        def left(self): return self.colors[WangTiles.Left]
        @property
        def top_left(self): return self.colors[WangTiles.TopLeft]

        @property
        def matching_indices_top(self) -> Set[int]:
            return self.matching_indices(WangTiles.Top)
        @property
        def matching_indices_top_right(self) -> Set[int]:
            return self.matching_indices(WangTiles.TopRight)
        @property
        def matching_indices_right(self) -> Set[int]:
            return self.matching_indices(WangTiles.Right)
        @property
        def matching_indices_bottom_right(self) -> Set[int]:
            return self.matching_indices(WangTiles.BottomRight)
        @property
        def matching_indices_bottom(self) -> Set[int]:
            return self.matching_indices(WangTiles.Bottom)
        @property
        def matching_indices_bottom_left(self) -> Set[int]:
            return self.matching_indices(WangTiles.BottomLeft)
        @property
        def matching_indices_left(self) -> Set[int]:
            return self.matching_indices(WangTiles.Left)
        @property
        def matching_indices_top_left(self) -> Set[int]:
            return self.matching_indices(WangTiles.TopLeft)

        def matches(self, tile: "Tile", direction: int) -> bool:
            """
            Return True if `tile` may be placed next to this tile in `direction`.

            The slot facing `direction` must equal the opposite slot of `tile`.
            An edge neighbour also shares the two corners at the ends of the common
            edge, so those corner slots must agree too. In "edge" mode all corner
            slots are 0, so the corner check never rejects a tile there.
            """
            opposite = (direction + 4) % 8
            if self.colors[direction] != tile.colors[opposite]:
                return False
            if direction % 2 == 0:
                # e.g. Top: our TopLeft/TopRight meet their BottomLeft/BottomRight
                return (
                    self.colors[(direction - 1) % 8] == tile.colors[(opposite + 1) % 8]
                    and self.colors[(direction + 1) % 8] == tile.colors[(opposite - 1) % 8]
                )
            return True

        def matching_indices(self, direction: int) -> Set[int]:
            """Indices of all tiles that may be placed next to this one in `direction`."""
            return self._matching_indices[direction]

    def __init__(
            self,
            colors: Iterable[Iterable[int]],
            mode: str = "edge",
    ):
        """
        :param colors: one entry per tile. In "edge" mode: 4 edge colours
            (top, right, bottom, left). In "corner" mode: 4 corner colours
            (top-right, bottom-right, bottom-left, top-left). In "edgecorner" mode:
            all 8 slots in direction order.
        :param mode: "edge"/"e", "corner"/"c" or "edgecorner"/"corneredge"/"ec"/"ce".
        :raises ValueError: on an unknown `mode` or a `colors` entry of the wrong length.
        """
        if mode in ("e", "edge"):
            self.mode = "edge"
        elif mode in ("c", "corner"):
            self.mode = "corner"
        elif mode in ("ec", "ce", "edgecorner", "corneredge"):
            self.mode = "edgecorner"
        else:
            raise ValueError(f"`mode` must be one of e, edge, c, corner, ec, ce, edgecorner or corneredge, got '{mode}'")

        self.tiles = []

        # All decisions below use the normalized `self.mode`, so short aliases behave
        # exactly like their long forms.
        expected_length = 8 if self.mode == "edgecorner" else 4
        for idx, row in enumerate(colors):
            row = list(row)
            if len(row) != expected_length:
                raise ValueError(f"Item #{idx} in `colors` has length {len(row)}, expected {expected_length}")
            if self.mode == "edgecorner":
                tile_colors = row
            else:
                # edges occupy the even slots, corners the odd slots
                slot_offset = 0 if self.mode == "edge" else 1
                tile_colors = [0] * 8
                for i, c in enumerate(row):
                    tile_colors[i * 2 + slot_offset] = c

            self.tiles.append(self.Tile(
                parent=self,
                index=idx,
                colors=tile_colors,
            ))

        for tile1 in self.tiles:
            for direction in range(8):
                tile1._matching_indices[direction] = {
                    tile2.index
                    for tile2 in self.tiles
                    if tile1.matches(tile2, direction)
                }
            
class WangTiles2E(WangTiles):
    """
    Edge-mode set of all 16 two-colour (0/1) Wang tiles.

    Tile i gets its [top, right, bottom, left] edge colours from bits 0..3 of i.
    """

    def __init__(self):
        edge_colors = [
            [(tile_index >> edge_bit) & 1 for edge_bit in range(4)]
            for tile_index in range(16)
        ]
        super().__init__(colors=edge_colors, mode="edge")
        
wang2e = WangTiles2E()
# Indices of the tiles that may be placed to the right of tile 0 (direction 2 == Right).
wang2e.tiles[0].matching_indices(WangTiles.Right)

In [None]:
class WangTemplate:
    """
    An atlas image holding Wang tiles on a regular grid.

    `indices[y][x]` is the tile index drawn in grid cell (y, x). Each grid cell
    covers `tile_shape` pixels of `image`.
    """

    def __init__(
            self,
            indices: Iterable[Iterable[int]],
            image: torch.Tensor,
    ):
        """
        :param indices: 2D grid (rows of tile indices) describing the atlas layout.
        :param image: atlas tensor; its last three dims are used as (C, H, W).
        """
        self.indices = torch.as_tensor(
            [[int(idx) for idx in row] for row in indices], dtype=torch.int64,
        )
        self.image = image
        # tile index -> (row, col) in the grid. Keys must be plain ints: 0-d tensors
        # hash by identity, so an int lookup in `tile()` could never find them.
        self._index_to_pos = {
            idx: (y, x)
            for y, row in enumerate(self.indices.tolist())
            for x, idx in enumerate(row)
        }

    def __repr__(self):
        return f"WangTemplate(shape={tuple(self.shape)}, tile_shape={self.tile_shape})"

    @property
    def shape(self) -> Tuple[int, int, int]:
        """Shape of the full atlas image."""
        return self.image.shape

    @property
    def tile_shape(self) -> Tuple[int, int, int]:
        """(channels, height, width) of a single tile."""
        return (
            self.image.shape[-3],
            self.image.shape[-2] // self.indices.shape[-2],
            self.image.shape[-1] // self.indices.shape[-1],
        )

    def tile(
            self,
            index: Optional[int] = None,
            position: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        """
        Return the image patch of one tile, selected by tile `index` or grid `position`.

        :param index: tile index as stored in `indices`.
        :param position: (row, col) of the grid cell.
        :return: a slice (view) of `image` covering that grid cell.
        :raises ValueError: if neither or both of `index` and `position` are given.
        :raises KeyError: if `index` does not appear in `indices`.
        """
        if index is None and position is None:
            raise ValueError(f"Expected `index` or `position`, got none")
        if index is not None and position is not None:
            raise ValueError(f"Expected `index` or `position`, got both")

        if position is not None:
            row, col = position[0], position[1]
        else:
            row, col = self._index_to_pos[int(index)]

        _, tile_h, tile_w = self.tile_shape
        return self.image[
            ...,
            row * tile_h: (row + 1) * tile_h,
            col * tile_w: (col + 1) * tile_w,
        ]
    
class WangTemplate2E(WangTemplate):
    """Template for the 16 two-colour edge tiles, laid out in a fixed 4x4 grid."""

    def __init__(
            self,
            image: torch.Tensor,
    ):
        # Tile index drawn in each cell of the 4x4 atlas, row by row.
        layout = (
            (4, 6, 14, 12),
            (5, 7, 15, 13),
            (1, 3, 11, 9),
            (0, 2, 10, 8),
        )
        super().__init__(
            indices=[list(layout_row) for layout_row in layout],
            image=image,
        )
        
# Wrap the loaded atlas image in the 4x4 two-colour edge layout and show its repr.
wt = WangTemplate2E(image=wang_template)
wt

In [None]:
from functools import partial
def get_image_window(
        shape: Tuple[int, int],
        window_function: Callable = partial(torch.hamming_window, periodic=True),
):
    """
    Build a separable 2D window of size (shape[-2], shape[-1]).

    :param shape: (height, width); only the last two entries are used.
    :param window_function: called as `window_function(n)` and must return a 1D
        tensor of length n (e.g. a `torch.*_window` function, optionally wrapped in
        `functools.partial` to set options such as `periodic`).
    :return: tensor of shape (height, width) equal to the outer product of the
        vertical and horizontal 1D windows.
    """
    # `periodic` is not forced here: doing so would override a caller's
    # `partial(..., periodic=False)` and break functions without that keyword.
    window_y = window_function(shape[-2])
    window_x = window_function(shape[-1])
    return torch.outer(window_y, window_x)

def render_wang_map(
        wang_template: torch.Tensor,
        tile_indices: torch.Tensor,
        overlap: Union[int, Tuple[int, int]] = 0,
        index_to_template_index: Optional[dict] = None,
):
    """
    Render a 2D grid of Wang tile indices into an image using a 4x4 tile atlas.

    :param wang_template: (C, H, W) atlas image; H and W must be divisible by 4.
        Tile i is read from atlas cell `index_to_template_index[i]`, counted
        row-major with 4 cells per row.
    :param tile_indices: 2D tensor of tile indices; negative entries stay empty (0).
    :param overlap: number of pixels by which neighbouring tiles overlap, as an int
        or (y, x). With a non-zero overlap, tiles are blended with a separable
        window and normalized by the accumulated window weight.
    :param index_to_template_index: mapping tile index -> atlas cell (0..15).
        Defaults to the module-level `WANG_INDEX_TO_TEMPLATE_INDEX`.
    :return: tensor of shape
        (C, rows * (tile_h - overlap_y) + overlap_y, cols * (tile_w - overlap_x) + overlap_x).
    :raises ValueError: if the atlas size is not divisible by 4, or if the overlap
        is not smaller than the tile size.
    """
    if index_to_template_index is None:
        index_to_template_index = WANG_INDEX_TO_TEMPLATE_INDEX
    if isinstance(overlap, int):
        overlap = (overlap, overlap)
    # Compare plain ints: a list never equals the tuple (0, 0), which would
    # silently force the blending path even for zero overlap.
    overlap_y, overlap_x = int(overlap[-2]), int(overlap[-1])
    blend = (overlap_y, overlap_x) != (0, 0)

    for s in wang_template.shape[-2:]:
        if not s % 4 == 0:
            raise ValueError(f"`wang_template` size must be divisible by 4, got {wang_template.shape}")

    tile_size_y = wang_template.shape[-2] // 4
    tile_size_x = wang_template.shape[-1] // 4

    # An overlap equal to the tile size gives a zero step: every tile would land
    # on the same pixels.
    if overlap_y >= tile_size_y or overlap_x >= tile_size_x:
        raise ValueError(
            f"`overlap` must be smaller than the tile size, got {overlap}, tile size is {[tile_size_y, tile_size_x]}"
        )

    step_y = tile_size_y - overlap_y
    step_x = tile_size_x - overlap_x
    image = torch.zeros(
        wang_template.shape[-3],
        tile_indices.shape[-2] * step_y + overlap_y,
        tile_indices.shape[-1] * step_x + overlap_x,
    )
    if blend:
        accum = torch.zeros_like(image)
        window = get_image_window((tile_size_y, tile_size_x))

    for y, row in enumerate(tile_indices):
        for x, tile_idx in enumerate(row):
            if tile_idx < 0:
                continue

            t_idx = index_to_template_index[int(tile_idx)]
            ty, tx = divmod(t_idx, 4)
            template_patch = wang_template[
                :,
                ty * tile_size_y: (ty + 1) * tile_size_y,
                tx * tile_size_x: (tx + 1) * tile_size_x,
            ]
            sy = slice(y * step_y, y * step_y + tile_size_y)
            sx = slice(x * step_x, x * step_x + tile_size_x)
            if blend:
                image[:, sy, sx] += window * template_patch
                accum[:, sy, sx] += window
            else:
                image[:, sy, sx] = template_patch

    if blend:
        # normalize by the summed window weight; untouched pixels stay 0
        mask = accum > 0
        image[mask] = image[mask] / accum[mask]
    return image

# Re-render the atlas in its own 4x4 layout, blending neighbouring tiles over 4 pixels.
VF.to_pil_image(
    render_wang_map(
        wang_template,
        torch.tensor(WANG_TEMPLATE_INDICES, dtype=torch.int64).reshape(4, 4),
        overlap=4,
    )
)

In [None]:
def random_wang_map(
        shape: Tuple[int, int],
        include: Optional[Iterable[int]] = None,
        exclude: Optional[Iterable[int]] = None,
        rng: Optional[random.Random] = None,
) -> torch.Tensor:
    """
    Fill a (rows, cols) grid with random, edge-compatible Wang tile indices.

    Cells are filled row by row, left to right. Each cell receives a randomly chosen
    tile whose left and top edges agree with the already placed neighbours,
    according to `MATCHING_WANG_TILES`. A cell with no fitting candidate stays -1.

    :param shape: (rows, cols); only the last two entries are used.
    :param include: tile indices to choose from (default: every index of `WANG_TILES`).
    :param exclude: tile indices that must never be chosen.
    :param rng: optional `random.Random` for reproducible maps; by default the
        global `random` module state is used.
    :return: int64 tensor of shape (rows, cols); -1 marks unfilled cells.
    """
    if rng is None:
        rng = random
    allowed = set(include) if include is not None else set(range(len(WANG_TILES)))
    if exclude is not None:
        allowed -= set(exclude)
    # a fixed base order, so that a seeded `rng` reproduces the same map
    candidates = sorted(allowed)

    rows, cols = shape[-2], shape[-1]
    tiles = [[-1] * cols for _ in range(rows)]
    for y in range(rows):
        for x in range(cols):
            left = tiles[y][x - 1] if x >= 1 else -1
            top = tiles[y - 1][x] if y >= 1 else -1

            order = list(candidates)
            rng.shuffle(order)
            for tile_idx in order:
                matching = MATCHING_WANG_TILES[tile_idx]
                # negative neighbours are empty cells and impose no constraint
                if left >= 0 and left not in matching["l"]:
                    continue
                if top >= 0 and top not in matching["t"]:
                    continue
                tiles[y][x] = tile_idx
                break

    return torch.tensor(tiles, dtype=torch.int64)

# Render a random 10x24 map with the loaded template. To restrict the tile set, pass
# e.g. `exclude=(1, 2, 3, 5, 10, 11)`, `exclude=(15,)` or `include=(0, 1, 4, 2, 8, 15)`.
random_map_indices = random_wang_map((10, 24))
VF.to_pil_image(render_wang_map(wang_template, random_map_indices))