In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Optional, Union

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import Space2d, IFS
from src.datasets import *
from src.models.cnn import *
from src.util.embedding import *
from src.models.clip import ClipSingleton

In [None]:
# Scratch arithmetic: 3 to the power of 8 (= 6561).
3**8

In [None]:
def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect images from `iterable` and render them as a single grid image.

    Each entry is either an image or a list/tuple whose first element is the
    image. Images with 4 dimensions get their leading dimension squeezed.
    Collection stops once `total` images were gathered, when the iterable is
    exhausted, or when the user interrupts the kernel (whatever was collected
    so far is still rendered).

    :param iterable: source of entries
    :param total: number of images after which collection stops; also used
        as the progress-bar total
    :param nrow: forwarded as `nrow` to the grid helper
    :param return_image: if True, return the PIL image instead of displaying it
    :param show_compression_ratio: if True, label every image with
        `ImageFilter.calc_compression_ratio(image)` rounded to 3 decimals;
        takes precedence over `label`
    :param label: if callable, it is called with the whole entry and its
        result is used as label; any other non-None value labels each image
        with its running index
    :return: the PIL image if `return_image` is True, otherwise None.
        None is also returned when no image could be collected.
    """
    samples = []
    labels = []
    # Only build the filter when it is actually needed for the labels.
    ratio_filter = ImageFilter() if show_compression_ratio else None
    try:
        for idx, entry in enumerate(tqdm(iterable, total=total)):
            image = entry[0] if isinstance(entry, (list, tuple)) else entry
            if image.ndim == 4:
                image = image.squeeze(0)
            samples.append(image)

            if ratio_filter is not None:
                labels.append(round(ratio_filter.calc_compression_ratio(image), 3))
            elif label is not None:
                labels.append(label(entry) if callable(label) else idx)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # Deliberate: an interrupted run still shows what was collected so far.
        pass

    if not samples:
        # The grid helper cannot stack an empty list of images; bail out
        # with a readable message instead of an obscure stacking error.
        print("plot_samples: the iterable yielded no images, nothing to display")
        return None

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow, pad_value=1)
    image = VF.to_pil_image(grid)

    if return_image:
        return image
    display(image)

In [None]:
class RpgTileIterableDataset(IterableDataset):
    """
    Iterable dataset that cuts a fixed list of tile-set images into patches.

    Every tile-set file is split into patches of its own tile size and each
    patch is resized (nearest neighbour) to the spatial size given by `shape`.
    Patches without any variation along dim 1 are skipped.
    """

    def __init__(
            self,
            shape: Tuple[int, int, int] = (3, 32, 32),
            directory: str = "~/prog/data/game-art/",
            include: Optional[str] = None,
            exclude: Optional[str] = None,
            even_files: Optional[bool] = None,
            interleave: bool = True,
    ):
        """
        :param shape: (channels, height, width) of the yielded patches
        :param directory: directory containing the tile-set files; `~` is expanded
        :param include: optional `fnmatch` pattern, only tile-sets whose file
            name matches are kept
        :param exclude: optional `fnmatch` pattern, tile-sets whose file name
            matches are dropped
        :param even_files: True keeps the entries at even list positions
            (0, 2, ...), False keeps the odd ones, None keeps all.
            Applied before `include` / `exclude`.
        :param interleave: if True, tiles are drawn round-robin from all
            tile-sets, otherwise one tile-set after the other
        """
        # Imported locally so the class does not rely on `fnmatch` leaking
        # into the namespace through one of the notebook's star imports.
        import fnmatch

        self.shape = shape
        self.directory = directory
        self.interleave = interleave
        # Each entry is passed as keyword arguments to `_iter_tiles`.
        self.tilesets = [
            dict(name="Castle2.png", shape=(16, 16)),
            dict(name="overworld_tileset_grass.png", shape=(16, 16)),
            dict(name="apocalypse.png", shape=(16, 16)),
            dict(name="PathAndObjects.png", shape=(32, 32)),
            dict(name="mininicular.png", shape=(8, 8)),
            dict(name="items.png", shape=(16, 16)),
            dict(name="roguelikeitems.png", shape=(16, 16), limit_count=181),
            dict(name="tileset_1bit.png", shape=(16, 16)),
            dict(name="MeteorRepository1Icons_fixed.png", shape=(16, 16), offset=(8, 0), stride=(17, 17)),
            dict(name="DENZI_CC0_32x32_tileset.png", shape=(32, 32)),
            dict(name="goodly-2x.png", shape=(32, 32)),
            dict(name="Fruit.png", shape=(16, 16)),
            dict(name="roguelikecreatures.png", shape=(16, 16)),
            dict(name="metroid-like.png", shape=(16, 16), limit=(128, 1000)),
            dict(name="tilesheet_complete.png", shape=(64, 64)),
            dict(name="tiles-map.png", shape=(16, 16)),
            dict(name="base_out_atlas.png", shape=(32, 32)),
            dict(name="build_atlas.png", shape=(32, 32)),
            dict(name="obj_misk_atlas.png", shape=(32, 32)),
            dict(name="Tile-set - Toen's Medieval Strategy (16x16) - v.1.0.png", shape=(16, 16), limit_count=306),
        ]
        if even_files is True:
            self.tilesets = self.tilesets[::2]
        elif even_files is False:
            self.tilesets = self.tilesets[1::2]

        if include is not None:
            self.tilesets = [
                t for t in self.tilesets
                if fnmatch.fnmatch(t["name"], include)
            ]
        if exclude is not None:
            self.tilesets = [
                t for t in self.tilesets
                if not fnmatch.fnmatch(t["name"], exclude)
            ]

    def __iter__(self):
        if not self.interleave:
            for params in self.tilesets:
                yield from self._iter_tiles(**params)
        else:
            # Round-robin: take one tile from each tile-set per pass and
            # drop a tile-set as soon as it is exhausted.
            iterables = [
                self._iter_tiles(**params)
                for params in self.tilesets
            ]
            while iterables:
                next_iterables = []
                for it in iterables:
                    try:
                        yield next(it)
                        next_iterables.append(it)
                    except StopIteration:
                        pass
                iterables = next_iterables

    def _iter_tiles(
            self, name: str,
            shape: Tuple[int, int],
            offset: Optional[Tuple[int, int]] = None,
            stride: Optional[Tuple[int, int]] = None,
            limit: Optional[Tuple[int, int]] = None,
            limit_count: Optional[int] = None,
            remove_transparent: bool = True,
    ):
        """
        Yield the resized patches of a single tile-set file.

        :param name: file name inside `self.directory`
        :param shape: patch size, forwarded to `iter_image_patches`
        :param offset: if given, the image is cropped to start at
            `[offset[0]:, offset[1]:]` of its last two dims (after `limit`)
        :param stride: forwarded to `iter_image_patches`
        :param limit: if given, the image is cropped to
            `[:limit[0], :limit[1]]` of its last two dims (before `offset`)
        :param limit_count: maximum number of patches to yield
        :param remove_transparent: multiply the colour channels of a
            4-channel image with its 4th channel before dropping that channel
        """
        path = (Path(self.directory) / name).expanduser()
        # Context manager closes the file handle once the pixel data has
        # been copied into the tensor.
        with PIL.Image.open(path) as source:
            # Palette images would otherwise turn into a single channel of
            # palette indices; same handling as RpgTileIterableBootstrapDataset.
            pil_image = source.convert("RGBA") if source.mode == "P" else source
            image = VF.to_tensor(pil_image)

        if image.shape[0] != self.shape[0]:
            if image.shape[0] == 4 and remove_transparent:
                image = image[:3] * image[3].unsqueeze(0)
            # presumably converts the channel count to self.shape[0] - helper not visible here
            image = set_image_channels(image[:3], self.shape[0])

        if limit:
            image = image[..., :limit[0], :limit[1]]
        if offset:
            image = image[..., offset[0]:, offset[1]:]

        count = 0
        for patch in iter_image_patches(image, shape, stride=stride):
            # Stop before doing any further work once the quota is reached.
            if limit_count is not None and count >= limit_count:
                break
            # Skip patches that have no variation along dim 1.
            # NOTE(review): a patch varying only along the last dim is
            # dropped as well - confirm that this is intended.
            if patch.std(1).mean() > 0.:
                yield VF.resize(patch, self.shape[-2:], VF.InterpolationMode.NEAREST, antialias=False)
                count += 1
                
# Build the dataset with its default settings and preview up to 500 tiles,
# 24 per grid row.
ds = RpgTileIterableDataset()
plot_samples(iterable=ds, total=500, nrow=24)

In [None]:
# Number of tiles the dataset yields in one full pass.
sum(1 for _ in ds)

In [None]:
# Preview up to 256 tiles drawn through the shuffle wrapper
# (10000 is presumably its buffer size - confirm against IterableShuffle).
plot_samples(IterableShuffle(ds, 10000), total=256, nrow=12)

In [None]:
# Derive a validation subset from `ds` and preview up to 256 of its tiles.
# NOTE(review): the meaning of the `10` (every 10th item?) depends on
# SplitIterableDataset - confirm.
ds_validation = SplitIterableDataset(ds, 10)
plot_samples(IterableShuffle(ds_validation, 10000), total=256, nrow=12)

In [None]:
# Scratch arithmetic: evaluates to 22.0 (what 352 and 16 stand for is not
# recorded in this notebook).
352 / 16
# Disabled experiments with SplitIterableDataset on a plain range:
#print(list(SplitIterableDataset(range(10), 4, True)))
#list(SplitIterableDataset(range(10), 4, False))

In [None]:
# Scratch arithmetic: evaluates to 72.0.
360 / 5

In [None]:
class RpgTileIterableBootstrapDataset(IterableDataset):
    """
    Variant of the tile dataset for files below `~/prog/data/game-art/DOWN/`.

    Works like RpgTileIterableDataset but additionally converts palette
    images to RGBA and lets individual grid lines or tiles be ignored.
    The tile-set list is currently empty, so the dataset yields nothing
    until entries are added to `self.tilesets`.
    """

    def __init__(
            self,
            shape: Tuple[int, int, int] = (3, 32, 32),
            directory: str = "~/prog/data/game-art/DOWN/",
            include: Optional[str] = None,
            exclude: Optional[str] = None,
            even_files: Optional[bool] = None,
            interleave: bool = False,
    ):
        """
        :param shape: (channels, height, width) of the yielded patches
        :param directory: directory containing the tile-set files; `~` is expanded
        :param include: optional `fnmatch` pattern, only tile-sets whose file
            name matches are kept
        :param exclude: optional `fnmatch` pattern, tile-sets whose file name
            matches are dropped
        :param even_files: True keeps the entries at even list positions
            (0, 2, ...), False keeps the odd ones, None keeps all.
            Applied before `include` / `exclude`.
        :param interleave: if True, tiles are drawn round-robin from all
            tile-sets, otherwise one tile-set after the other
        """
        # Imported locally so the class does not rely on `fnmatch` leaking
        # into the namespace through one of the notebook's star imports.
        import fnmatch

        self.shape = shape
        self.directory = directory
        self.interleave = interleave
        # Each entry is passed as keyword arguments to `_iter_tiles`.
        # The commented lines are templates / candidates that are not active.
        self.tilesets = [
            #dict(name="", shape=(16, 16)),
            #dict(name="", shape=(16, 16)),
            #dict(name="", shape=(16, 16)),
            #dict(name="", shape=(16, 16)),

            #dict(name="apocalypse.png", shape=(16, 16)),
            #dict(name="PathAndObjects.png", shape=(32, 32)),
            #dict(name="mininicular.png", shape=(8, 8)),
            #dict(name="items.png", shape=(16, 16)),
            #dict(name="roguelikeitems.png", shape=(16, 16), limit_count=181),
            #dict(name="tileset_1bit.png", shape=(16, 16)),
            #dict(name="MeteorRepository1Icons_fixed.png", shape=(16, 16), offset=(8, 0), stride=(17, 17)),
            #dict(name="DENZI_CC0_32x32_tileset.png", shape=(32, 32)),
            #dict(name="goodly-2x.png", shape=(32, 32)),
            #dict(name="Fruit.png", shape=(16, 16)),
            #dict(name="roguelikecreatures.png", shape=(16, 16)),
            #dict(name="metroid-like.png", shape=(16, 16), limit=(128, 1000)),
            #dict(name="tilesheet_complete.png", shape=(64, 64)),
            #dict(name="tiles-map.png", shape=(16, 16)),
            #dict(name="base_out_atlas.png", shape=(32, 32)),
            #dict(name="build_atlas.png", shape=(32, 32)),
            #dict(name="obj_misk_atlas.png", shape=(32, 32)),
            #dict(name="Tile-set - Toen's Medieval Strategy (16x16) - v.1.0.png", shape=(16, 16), limit_count=306),
        ]
        if even_files is True:
            self.tilesets = self.tilesets[::2]
        elif even_files is False:
            self.tilesets = self.tilesets[1::2]

        if include is not None:
            self.tilesets = [
                t for t in self.tilesets
                if fnmatch.fnmatch(t["name"], include)
            ]
        if exclude is not None:
            self.tilesets = [
                t for t in self.tilesets
                if not fnmatch.fnmatch(t["name"], exclude)
            ]

    def __iter__(self):
        if not self.interleave:
            for params in self.tilesets:
                yield from self._iter_tiles(**params)
        else:
            # Round-robin: take one tile from each tile-set per pass and
            # drop a tile-set as soon as it is exhausted.
            iterables = [
                self._iter_tiles(**params)
                for params in self.tilesets
            ]
            while iterables:
                next_iterables = []
                for it in iterables:
                    try:
                        yield next(it)
                        next_iterables.append(it)
                    except StopIteration:
                        pass
                iterables = next_iterables

    def _iter_tiles(
            self,
            name: str,
            shape: Tuple[int, int],
            offset: Optional[Tuple[int, int]] = None,
            stride: Optional[Tuple[int, int]] = None,
            limit: Optional[Tuple[int, int]] = None,
            count: Optional[int] = None,
            remove_transparent: bool = True,
            ignore_lines: Optional[Iterable[int]] = None,
            ignore_tiles: Optional[Iterable[Tuple[int, int]]] = None,
            limit_count: Optional[int] = None,
    ):
        """
        Yield the resized patches of a single tile-set file.

        :param name: file name inside `self.directory`
        :param shape: patch size, forwarded to `iter_image_patches`
        :param offset: if given, the image is cropped to start at
            `[offset[0]:, offset[1]:]` of its last two dims (after `limit`)
        :param stride: forwarded to `iter_image_patches`
        :param limit: if given, the image is cropped to
            `[:limit[0], :limit[1]]` of its last two dims (before `offset`)
        :param count: maximum number of patches to yield
        :param remove_transparent: multiply the colour channels of a
            4-channel image with its 4th channel before dropping that channel
        :param ignore_lines: grid indices (first position component) whose
            patches are skipped
        :param ignore_tiles: grid positions whose patches are skipped
        :param limit_count: alias for `count`, using the keyword that
            RpgTileIterableDataset's tile-set entries use; `count` wins if
            both are given
        """
        ignored_lines = set(ignore_lines) if ignore_lines else set()
        # Normalise to tuples so positions given as lists are accepted too.
        ignored_tiles = {tuple(t) for t in ignore_tiles} if ignore_tiles else set()
        max_tiles = count if count is not None else limit_count

        path = (Path(self.directory) / name).expanduser()
        # Context manager closes the file handle once the pixel data has
        # been copied into the tensor.
        with PIL.Image.open(path) as source:
            # Palette images would otherwise turn into a single channel of
            # palette indices.
            pil_image = source.convert("RGBA") if source.mode == "P" else source
            image = VF.to_tensor(pil_image)

        if image.shape[0] != self.shape[0]:
            if image.shape[0] == 4 and remove_transparent:
                image = image[:3] * image[3].unsqueeze(0)
            # presumably converts the channel count to self.shape[0] - helper not visible here
            image = set_image_channels(image[:3], self.shape[0])

        if limit:
            image = image[..., :limit[0], :limit[1]]
        if offset:
            image = image[..., offset[0]:, offset[1]:]

        num_yielded = 0
        for patch, pos in iter_image_patches(image, shape, stride=stride, with_pos=True):
            # Stop before doing any further work once the quota is reached.
            if max_tiles is not None and num_yielded >= max_tiles:
                break

            # Turn the reported position into a grid index.
            # NOTE(review): this divides by the patch `shape`, not by
            # `stride`; if `pos` is a pixel position and stride != shape the
            # index drifts for later tiles - confirm against iter_image_patches.
            grid_pos = tuple(int(p) // s for p, s in zip(pos, shape))

            if grid_pos[0] in ignored_lines or grid_pos in ignored_tiles:
                continue

            # Skip patches that have no variation along dim 1.
            # NOTE(review): a patch varying only along the last dim is
            # dropped as well - confirm that this is intended.
            if patch.std(1).mean() > 0.:
                yield VF.resize(patch, self.shape[-2:], VF.InterpolationMode.NEAREST, antialias=False)
                num_yielded += 1
                
# Preview the bootstrap dataset at 64x64 resolution.
# NOTE: this rebinds `ds`, so cells above that use `ds` will see this
# dataset when they are re-run afterwards.
ds = RpgTileIterableBootstrapDataset(shape=(3, 64, 64))
plot_samples(iterable=ds, total=5000, nrow=24)

In [None]:
import tempfile

# Call the function: the bare name only displayed the function object,
# not the temporary directory path.
tempfile.gettempdir()