In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Set

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import *
from src.algo.wangtiles import *
from src.datasets.generative import *
from src.models.cnn import *
from src.models.transform import *
from src.util.embedding import *
from src.models.clip import ClipSingleton

In [None]:
# Directory holding the cr31 wang-tile template images.
# NOTE(review): absolute, machine-specific path — point this at your own checkout.
CR31_ASSET_DIR = Path("/home/bergi/prog/python/thegame/thegame/assets/cr31")


def _load_rgb_tensor(path: Path) -> torch.Tensor:
    """Load an image file as an RGB float tensor (3, H, W) in [0, 1], closing the file afterwards."""
    with PIL.Image.open(path) as image:
        return VF.to_tensor(image.convert("RGB"))


wang_template_2c = _load_rgb_tensor(CR31_ASSET_DIR / "wang2c.png")
wang_template_2e = _load_rgb_tensor(CR31_ASSET_DIR / "wang2e.png")
wang_template_3e = _load_rgb_tensor(CR31_ASSET_DIR / "wang3e.png")
#wang_template = VF.to_tensor(PIL.Image.open(CR31_ASSET_DIR / "path.png"))[:3] * 85
#wang_template = VF.to_tensor(PIL.Image.open("/home/bergi/prog/python/thegame/thegame/assets/w2e_beach.png"))[:3]
display(VF.to_pil_image(wang_template_2e))
display(VF.to_pil_image(wang_template_2c))
display(VF.to_pil_image(wang_template_3e))


In [None]:
def smoothstep(a, b, x):
    """
    Cubic Hermite smoothstep of tensor `x` between the edges `a` and `b`.

    `x` is remapped linearly so that `a -> 0` and `b -> 1`, clamped to [0, 1]
    and shaped with t*t*(3 - 2t). Passing `a > b` gives a falling ramp.
    """
    normalized = (x - a) / (b - a)
    t = normalized.clamp(0, 1)
    t_squared = t * t
    return t_squared * (3. - 2. * t)
    
def render_wang_tile_LOCAL(
    assignments: Tuple[int, int, int, int, int, int, int, int],
    shape: Tuple[int, int],
    padding: float = 0.,
    fade: float = 1.,
    colors: Optional[Tuple[Tuple[int, int, int], ...]] = None,
) -> torch.Tensor:
    """
    Render a single wang tile as an RGB image.

    :param assignments: eight palette indices, one per ramp in `ramps` below;
        a negative index leaves that ramp unpainted.
    :param shape: output size; only the last two entries (H, W) are used, so a
        (C, H, W) tuple works too. The result always has 3 channels.
    :param padding: offset subtracted from every ramp coordinate.
    :param fade: width of the smoothstep fall-off: a painted ramp has full
        weight where `ramp - padding <= 0`, falling to 0 at `ramp - padding >= fade`.
    :param colors: RGB palette indexed by `assignments`; defaults to
        black, white, red, green, blue.
    :return: float tensor (3, H, W) clamped to [0, 1]; ramps are combined with
        a per-pixel maximum.
    """
    palette = torch.Tensor(
        ((0, 0, 0), (1, 1, 1), (1, 0, 0), (0, 1, 0), (0, 0, 1))
        if colors is None else colors
    )
    height_width = shape[-2:]

    result = torch.zeros(3, *height_width)
    grid = Space2d((2, *height_width)).space()
    grid_rot = Space2d((2, *height_width), rotate_2d=torch.pi / 4.).space()

    # One linear ramp per assignment slot, in the order `assignments` is indexed.
    ramps = [
        grid[1] + 1.,
        1. - grid_rot[1],
        1. - grid[0],
        1. - grid_rot[0],
        1. - grid[1],
        grid_rot[1] + 1.,
        grid[0] + 1.,
        grid_rot[0] + 1.,
    ]

    for slot, ramp in enumerate(ramps):
        color_index = assignments[slot]
        if color_index < 0:
            continue
        weight = smoothstep(fade, 0., ramp - padding)
        tinted = weight.expand(3, -1, -1) * palette[color_index].reshape(3, 1, 1)
        result = torch.max(result, tinted)

    return result.clamp(0, 1)

# Palette index per ramp slot; the commented lists are alternative patterns to try.
demo_assignments = [1, 0, 2, 0, 3, 0, 4, 0]
#demo_assignments = [1, -1, 1, -1, 1, -1, 1, -1]
#demo_assignments = [0, 1, 0, 2, 0, 3, 0, 4]
#demo_assignments = [3, 1, 4, 2, 1, 3, 2, 4]
img = render_wang_tile(demo_assignments, (3, 64, 64), fade=1., padding=0.)
VF.to_pil_image(img)

In [None]:
from src.algo.wangtiles import *

# Instantiate the predefined tile sets (2-edge, 2-corner and 3-edge, judging by
# the class names) and print `matching_indices` for the first tile of each.
# NOTE(review): the meaning of the integer argument is defined in src.algo.wangtiles.
wang2e = WangTiles2E()
print(wang2e.tiles[0].matching_indices(2))
wang2c = WangTiles2C()
print(wang2c.tiles[0].matching_indices(1))
wang3e = WangTiles3E()
print(wang3e.tiles[0].matching_indices(2))


In [None]:
# Wrap each loaded template image in its tile-set specific template class and
# print the resulting template object.
wt2e = WangTemplate2E(wang_template_2e)
print(wt2e)
wt2c = WangTemplate2C(wang_template_2c)
print(wt2c)
wt3e = WangTemplate3E(wang_template_3e)
print(wt3e)

In [None]:
# Inspect how many colors the `wang2e` tile set uses (shown as the cell output).
wang2e.num_colors

In [None]:
def create_template_LOCAL(
        self, 
        tile_shape: Tuple[int, int],
        padding: float = 0.,
        fade: float = 1.0,
):
    """
    Render every tile of `self.tiles` and arrange them in a square template.

    Local experimental copy of the tile set's `create_template` method
    (the tile set is passed as `self`).

    :param tile_shape: size of each rendered tile, forwarded to `render_wang_tile`.
    :param padding: forwarded to `render_wang_tile`.
    :param fade: forwarded to `render_wang_tile`.
    :return: `WangTemplate(indices, image)`, where `indices` is an (nrow, nrow)
        int64 grid of tile indices in row-major order and -1 marks blank cells
        that pad a non-square tile count.
    """
    tile_images = [
        render_wang_tile(
            assignments=tile.colors,
            shape=tile_shape,
            padding=padding,
            fade=fade,
        )
        for tile in self.tiles
    ]
    num_tiles = len(tile_images)

    # Round *up* so every tile gets a grid slot when the count is not a perfect square
    # (flooring silently dropped the remaining tiles from `indices`).
    nrow = math.ceil(math.sqrt(num_tiles))

    # Pad with blank tiles so the image is a full nrow x nrow grid, matching `indices`.
    if tile_images:
        tile_images.extend(
            torch.zeros_like(tile_images[0])
            for _ in range(nrow * nrow - num_tiles)
        )

    image = make_grid(
        tile_images, 
        padding=0, 
        nrow=nrow,
    )
    # Integer arange avoids the float round-trip of linspace -> int64.
    indices = torch.arange(nrow * nrow, dtype=torch.int64).view(nrow, nrow)
    indices[indices >= num_tiles] = -1
    return WangTemplate(indices, image)
    
# Build a template for `wang2e` with 32x32 tiles via the library `create_template`,
# then show its tile-index grid and rendered image.
wt2e_2 = wang2e.create_template((32, 32))
print(wt2e_2.indices)
VF.to_pil_image(wt2e_2.image)

In [None]:
from functools import partial
def get_image_window(
        shape: Tuple[int, int], 
        window_function: Callable = partial(torch.hamming_window, periodic=True),
):
    """
    Build a separable 2D window of size (H, W) = `shape[-2:]`.

    The result is the outer product of a 1D window of length H (rows) with one of
    length W (columns), e.g. to weight overlapping tiles before blending them.

    :param shape: tuple whose last two entries are (H, W).
    :param window_function: callable `f(length) -> 1D tensor`. Defaults to a
        periodic Hamming window. Options such as `periodic` should be bound with
        `functools.partial`. They are not forced here, so a bound `periodic=False`
        is honoured, and plain callables like `torch.ones` work too.
    :return: tensor of shape (H, W).
    """
    window_y = window_function(shape[-2])
    window_x = window_function(shape[-1])
    return torch.outer(window_y, window_x)

def render_wang_map_LOCAL(
        wang_template: WangTemplate,
        tile_indices: torch.Tensor, 
        overlap: Union[int, Tuple[int, int]] = 0,
):
    """
    Render a 2D grid of template tile indices into one image.

    :param wang_template: template providing `tile_shape`, `shape` and
        `tile(idx)` (the image of one tile).
    :param tile_indices: 2D grid (rows, columns) of tile indices; negative
        entries are skipped and stay black.
    :param overlap: pixels by which neighbouring tiles overlap, either one int
        for both axes or a (y, x) pair. With a non-zero overlap each tile is
        weighted by `get_image_window` and the result is normalized by the
        accumulated weights.
    :return: tensor (C, H, W) with C = `wang_template.shape[-3]`.
    :raises ValueError: if the overlap is larger than the tile size.
    """
    # Normalize to a tuple: comparing a *list* with (0, 0) is always unequal,
    # which previously forced the windowed path even for zero overlap.
    if isinstance(overlap, int):
        overlap = (overlap, overlap)
    else:
        overlap = tuple(overlap)
    has_overlap = (overlap[-2], overlap[-1]) != (0, 0)

    tile_shape = wang_template.tile_shape

    if overlap[-2] > tile_shape[-2] or overlap[-1] > tile_shape[-1]:
        raise ValueError(
            f"`overlap` exceeds tile size, got {overlap}, tile shape is {tile_shape}"
        )

    # Distance between the top-left corners of neighbouring tiles.
    step_y = tile_shape[-2] - overlap[-2]
    step_x = tile_shape[-1] - overlap[-1]

    image = torch.zeros(
        wang_template.shape[-3], 
        tile_indices.shape[-2] * step_y + overlap[-2], 
        tile_indices.shape[-1] * step_x + overlap[-1],
    )
    if has_overlap:
        accum = torch.zeros_like(image)
        window = get_image_window(tile_shape[-2:])

    for y, row in enumerate(tile_indices):
        for x, tile_idx in enumerate(row):
            if tile_idx < 0:
                continue

            template_patch = wang_template.tile(tile_idx)
            sy = slice(y * step_y, y * step_y + tile_shape[-2])
            sx = slice(x * step_x, x * step_x + tile_shape[-1])

            if has_overlap:
                image[:, sy, sx] = image[:, sy, sx] + window * template_patch
                accum[:, sy, sx] = accum[:, sy, sx] + window
            else:
                image[:, sy, sx] = template_patch

    if has_overlap:
        # Normalize by the summed window weights; pixels no tile touched stay 0.
        mask = accum > 0
        image[mask] = image[mask] / accum[mask]

    return image

# Render each template's own tile layout back out as a map without overlap
# (the 2e/2c templates are viewed as 4x4 grids, the 3e template as 9x9).
for _template, _side in ((wt2e, 4), (wt2c, 4), (wt3e, 9)):
    display(VF.to_pil_image(render_wang_map(
        _template,
        _template.indices.view(_side, _side),
        overlap=0,
    )))

# Color assignment of the first tile in each tile set.
for _tileset in (wang2e, wang2c, wang3e):
    print(_tileset[0].colors)

In [None]:
def wang_map_scanline_stochastic(
        wang_tiles: WangTiles,
        shape: Tuple[int, int],
        include: Optional[Iterable[int]] = None,
        exclude: Optional[Iterable[int]] = None,
) -> torch.Tensor:
    """
    Generate a random wang map by placing tiles row by row, left to right.

    For every cell the candidate tiles are shuffled (with the `random` module),
    and the first one consistent with the already-placed neighbours is chosen.
    The neighbours checked are the ones above and to the left; for corner tiles
    the top-left and top-right diagonals are checked too. Cells where no
    candidate fits are left as -1.

    :param wang_tiles: tile set; its `mode` selects edge and/or corner matching.
    :param shape: map size; the last two entries are (rows, columns).
    :param include: if given, only these tile indices are candidates.
    :param exclude: tile indices that are never placed.
    :return: int64 tensor (rows, columns) of tile indices, -1 for unfilled cells.
    """
    if include is not None:
        include = set(include)
    if exclude is not None:
        exclude = set(exclude)

    possible_tiles_set = include.copy() if include is not None else set(range(len(wang_tiles.tiles)))
    if exclude is not None:
        possible_tiles_set -= exclude

    height, width = shape[-2], shape[-1]
    tiles = [[-1] * width for _ in range(height)]

    def _placed(y: int, x: int):
        """The tile placed at (y, x), or None when out of bounds or left empty."""
        if 0 <= y < height and 0 <= x < width and tiles[y][x] >= 0:
            return wang_tiles[tiles[y][x]]
        return None

    def _fits(tile_idx: int, y: int, x: int) -> bool:
        """Whether tile `tile_idx` at (y, x) agrees with the already-placed neighbours."""
        tile = wang_tiles[tile_idx]
        left = _placed(y, x - 1)
        top = _placed(y - 1, x)

        if "edge" in wang_tiles.mode:
            if left is not None and not tile.matches_left(left):
                return False
            if top is not None and not tile.matches_top(top):
                return False

        if "corner" in wang_tiles.mode:
            colors = tile.colors
            if left is not None and (
                    colors[WangTiles.TopLeft] != left.colors[WangTiles.TopRight]
                    or colors[WangTiles.BottomLeft] != left.colors[WangTiles.BottomRight]
            ):
                return False
            if top is not None and (
                    colors[WangTiles.TopLeft] != top.colors[WangTiles.BottomLeft]
                    or colors[WangTiles.TopRight] != top.colors[WangTiles.BottomRight]
            ):
                return False
            top_left = _placed(y - 1, x - 1)
            if top_left is not None and not tile.matches_top_left(top_left):
                return False
            # The top-right diagonal shares our top-right corner. That is implied by
            # the top neighbour when one exists, but not when that cell stayed empty.
            top_right = _placed(y - 1, x + 1)
            if top_right is not None and colors[WangTiles.TopRight] != top_right.colors[WangTiles.BottomLeft]:
                return False

        return True

    for y in range(height):
        for x in range(width):
            candidates = list(possible_tiles_set)
            random.shuffle(candidates)
            for tile_idx in candidates:
                if _fits(tile_idx, y, x):
                    tiles[y][x] = tile_idx
                    break

    # Build the int64 tensor directly (no float32 round-trip) and keep the 2D
    # shape even for zero rows.
    return torch.tensor(tiles, dtype=torch.int64).reshape(height, width)

# Random scanline maps (10 x 24 tiles) for each tile set, rendered with freshly
# generated 32x32-tile templates.
for _tileset, _template_kwargs in (
        (wang2e, dict(padding=0, fade=.2)),
        (wang3e, dict(padding=0, fade=.2)),
        (wang2c, dict(padding=.333)),
):
    display(VF.to_pil_image(render_wang_map(
        _tileset.create_template((32, 32), **_template_kwargs),
        wang_map_scanline_stochastic(_tileset, (10, 24)),
    )))


In [None]:
def wang_map_fill_stochastic(
        wang_tiles: WangTiles,
        tiles: torch.Tensor,
        include: Optional[Iterable[int]] = None,
        exclude: Optional[Iterable[int]] = None,
) -> None:
    """
    Fill every empty (negative) cell of the 2D map `tiles` in-place.

    Empty cells are visited in random order. For each one, the candidate tiles
    are shuffled and the first one consistent with all already-placed neighbours
    is written. Cells where no candidate fits stay negative.

    :param wang_tiles: tile set; its `mode` selects edge and/or corner matching.
    :param tiles: 2D int tensor of tile indices (negative = empty); modified in-place.
    :param include: if given, only these tile indices are candidates.
    :param exclude: tile indices that are never placed.
    """
    if include is not None:
        include = set(include)
    if exclude is not None:
        exclude = set(exclude)

    possible_tiles_set = include.copy() if include is not None else set(range(len(wang_tiles.tiles)))
    if exclude is not None:
        possible_tiles_set -= exclude

    height, width = tiles.shape[-2], tiles.shape[-1]

    # Edge neighbours: (dy, dx, name of the tile method that checks that side).
    edge_constraints = ()
    if "edge" in wang_tiles.mode:
        edge_constraints = (
            (0, -1, "matches_left"),
            (0, 1, "matches_right"),
            (-1, 0, "matches_top"),
            (1, 0, "matches_bottom"),
        )

    # Corner neighbours: (dy, dx, ((own corner, neighbour corner), ...)).
    # Because cells are filled in random order, a diagonal neighbour may already be
    # placed while the shared side neighbours are still empty, so diagonals must
    # be checked as well.
    corner_constraints = ()
    if "corner" in wang_tiles.mode:
        tl, tr = WangTiles.TopLeft, WangTiles.TopRight
        bl, br = WangTiles.BottomLeft, WangTiles.BottomRight
        corner_constraints = (
            (0, -1, ((tl, tr), (bl, br))),
            (0, 1, ((tr, tl), (br, bl))),
            (-1, 0, ((tl, bl), (tr, br))),
            (1, 0, ((bl, tl), (br, tr))),
            (-1, -1, ((tl, br),)),
            (-1, 1, ((tr, bl),)),
            (1, -1, ((bl, tr),)),
            (1, 1, ((br, tl),)),
        )

    def _placed(y: int, x: int):
        """The tile placed at (y, x), or None when out of bounds or still empty."""
        if 0 <= y < height and 0 <= x < width and tiles[y][x] >= 0:
            return wang_tiles[tiles[y][x]]
        return None

    def _fits(tile_idx: int, y: int, x: int) -> bool:
        """Whether tile `tile_idx` at (y, x) agrees with every placed neighbour."""
        tile = wang_tiles[tile_idx]
        for dy, dx, method_name in edge_constraints:
            neighbour = _placed(y + dy, x + dx)
            if neighbour is not None and not getattr(tile, method_name)(neighbour):
                return False
        for dy, dx, corner_pairs in corner_constraints:
            neighbour = _placed(y + dy, x + dx)
            if neighbour is None:
                continue
            for own_corner, neighbour_corner in corner_pairs:
                if tile.colors[own_corner] != neighbour.colors[neighbour_corner]:
                    return False
        return True

    empty_positions = [
        (y, x)
        for y, row in enumerate(tiles)
        for x, value in enumerate(row)
        if value < 0
    ]
    random.shuffle(empty_positions)

    for y, x in empty_positions:
        candidates = list(possible_tiles_set)
        random.shuffle(candidates)
        for tile_idx in candidates:
            if _fits(tile_idx, y, x):
                tiles[y][x] = tile_idx
                break
                
# Seed a 24x24 map with a few fixed tiles (-1 marks empty cells), then let
# `wang_map_fill_stochastic` complete it around them.
tiles = torch.full((24, 24), -1, dtype=torch.int64)
tiles[2:5, 2:8] = 0
tiles[:, 11] = 15
tiles[6, :] = 15

USE_EDGE_TILES = True  # False: use the `wang2c` / `wt2c` pair instead
fill_tileset, fill_template = (wang2e, wt2e) if USE_EDGE_TILES else (wang2c, wt2c)
wang_map_fill_stochastic(fill_tileset, tiles)
display(VF.to_pil_image(render_wang_map(fill_template, tiles)))

In [None]:
# Rebuild `wang3e` as a generic "edge" tile set from `get_wang_tile_colors(3)`
# (this rebinds the WangTiles3E instance created earlier) and show its template
# rendered with 32x32 tiles.
wang3e = WangTiles(get_wang_tile_colors(3), "edge")
wt3e = wang3e.create_template((32, 32))
VF.to_pil_image(wt3e.image)

In [None]:
# Spot-check two entries of the 3-color combinations returned by the library.
c = get_wang_tile_colors(3)
c[8], c[80]
#VF.to_pil_image(wt3e.tile(80))

In [None]:
def get_wang_tile_colors_LOCAL(num_colors: int, num_places: int = 4):
    """
    Enumerate every assignment of `num_colors` colors to `num_places` places.

    Entry `code` lists the base-`num_colors` digits of `code`, least significant
    first, so the result has `num_colors ** num_places` entries of length
    `num_places`.
    """
    combinations = []
    for code in range(num_colors ** num_places):
        digits = []
        remaining = code
        for _ in range(num_places):
            digits.append(remaining % num_colors)
            remaining //= num_colors
        combinations.append(digits)
    return combinations

# Library version with 3 colors and 2 places, for comparison with the local copy.
get_wang_tile_colors(3, 2)

In [None]:
# Tile set from all color combinations of 5 colors over 4 places, in mode "c".
# NOTE(review): the map generators test `"corner" in wang_tiles.mode`; confirm
# that WangTiles expands "c" to "corner", otherwise no matching is enforced.
wangx = WangTiles(get_wang_tile_colors(5, 4), mode="c")
wangxt = wangx.create_template((32, 32), padding=.3, fade=.8)
wangx_map = wang_map_scanline_stochastic(wangx, (36, 36))
display(VF.to_pil_image(render_wang_map(wangxt, wangx_map)))
VF.to_pil_image(wangxt.image)

In [None]:
class RandomWangMap(nn.Module):
    """
    Treats input as wang tile template and renders a random wang map.

    The input is a single template image (C, H, W) or a batch (B, C, H, W).
    Each image is used as the image of a `WangTemplate` for `self.wangtiles`,
    and a new random map of `map_size` tiles is generated and rendered per image.

    :param map_size: (rows, columns) of the generated map, in tiles.
    :param num_colors: forwarded to `get_wang_tile_colors` to build the tile set.
    :param mode: tile-set mode forwarded to `WangTiles`.
    :param overlap: forwarded to `render_wang_map`.
    :param probability: chance that an image is turned into a map; otherwise it
        is returned unchanged. NOTE(review): for batched input, a skipped image
        keeps the template size while rendered ones generally do not, so
        `torch.concat` may fail when `probability < 1`.
    """
    def __init__(
            self,
            map_size: Tuple[int, int],
            num_colors: int = 2,
            mode: str = "edge",
            overlap: Union[int, Tuple[int, int]] = 0,
            probability: float = 1.,
    ):
        super().__init__()
        self.map_size = map_size
        self.overlap = overlap
        self.probability = probability
        self.wangtiles = WangTiles(get_wang_tile_colors(num_colors), mode=mode)
        # Cached template, rebuilt whenever the input image shape changes.
        self.template: WangTemplate = None
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:            
        def _render(template: torch.Tensor) -> torch.Tensor:
            if self.probability < 1.:
                if random.random() > self.probability:
                    return template

            # Compare shapes (not the Size against the tensor itself) so the cached
            # template is reused for same-sized inputs.
            if self.template is None or self.template.image.shape != template.shape:
                self.template = self.wangtiles.create_template(template.shape)
            self.template.image = template

            # Generate the random map exactly once per image.
            # NOTE(review): previously referenced `wang_map_stochastic_scanline`,
            # which is not defined in this notebook; the generator defined here is
            # `wang_map_scanline_stochastic`.
            tile_map = wang_map_scanline_stochastic(self.wangtiles, self.map_size)
            return render_wang_map(
                self.template,
                tile_map,
                overlap=self.overlap,
            ).to(x)

        if x.ndim == 3:
            return _render(x)

        elif x.ndim == 4:
            return torch.concat([
                _render(i).unsqueeze(0)
                for i in x
            ])
        else:
            raise ValueError(f"Expected input to have 3 or 4 dimensions, got {x.shape}")

# Smoke test: render a 2-color template with 32x32 tiles and feed its image through
# RandomWangMap to get a random 5x5-tile map.
# NOTE(review): mode "c" — confirm WangTiles maps it to "corner"; the map
# generators test `"corner" in wang_tiles.mode`.
rmap = RandomWangMap((5, 5), num_colors=2, mode="c")
#display(rmap.template.image)
x = WangTiles(get_wang_tile_colors(2), "c").create_template((32, 32)).image
VF.to_pil_image(rmap(x))