In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from itertools import islice
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
import torchvision.models as VM 
from IPython.display import display

from src.util.image import *
from src.util import *
from src.models.util import *
from src.algo import ca1
from experiments.datasets import rpg_tile_dataset_3x32x32

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale the last two (spatial) dimensions of an image tensor by ``scale``.

    Each side becomes ``int(side * scale)`` (truncated toward zero) but is never
    smaller than 1 pixel. Interpolation defaults to nearest-neighbour and
    antialiasing is always disabled.
    """
    new_size = []
    for side in img.shape[-2:]:
        new_size.append(max(1, int(side * scale)))
    return VF.resize(img, new_size, interpolation=mode, antialias=False)

In [None]:
# Scratch calculation: integer 1% of 3*64*64 = 12288 (evaluates to 122).
# NOTE(review): presumably sizing something relative to a 3x64x64 image — TODO
# confirm what it is for, or remove this cell before sharing.
3*64*64 // 100

In [None]:
class CombineImagePatchDataset(IterableDataset):
    """Iterable dataset that composites groups of image patches into larger images.

    Patches are pulled from ``ds`` in groups of ``num_combine``. For every patch
    of a group one output image is produced: that patch, resized to ``shape``,
    is the background, and every other patch of the group is randomly flipped
    and pasted at a random position on top of it.

    :param ds: source dataset; items are either image tensors or lists/tuples
        whose first element is the image tensor (e.g. ``(image, label)``).
        Patches are indexed as ``(C, H, W)``.
    :param shape: ``(height, width)`` of the produced images. Every patch must
        fit inside this size, otherwise ``ValueError`` is raised while iterating.
    :param num_combine: number of patches per group, and therefore the number
        of images yielded per group.
    :param black_is_alpha: if True, pixels of a pasted patch whose channel sum is
        not positive (i.e. black, for non-negative images) are treated as
        transparent and let the background show through.
    """

    def __init__(
            self, 
            ds: Union[Dataset, IterableDataset],
            shape: Tuple[int, int],
            num_combine: int = 4,
            black_is_alpha: bool = True,
    ):
        self.ds = ds
        self.shape = shape
        self.num_combine = num_combine
        self.black_is_alpha = black_is_alpha
        # each pasted patch is independently flipped with probability 0.3 per axis
        self.sub_transforms = VT.Compose([
            VT.RandomVerticalFlip(.3),
            VT.RandomHorizontalFlip(.3),
        ])

    def __iter__(self):
        patches = []
        for patch in self.ds:
            # accept (image, label, ...) items as well as bare image tensors
            if isinstance(patch, (list, tuple)):
                patch = patch[0]
            patches.append(patch)

            if len(patches) >= self.num_combine:
                yield from self._iter_patch_combinations(patches)
                patches.clear()

        # flush an incomplete trailing group; a single leftover patch is dropped
        # because there is nothing to combine it with
        if len(patches) > 1:
            yield from self._iter_patch_combinations(patches)

    def _iter_patch_combinations(self, patches):
        """Yield one composite image per patch, using that patch as the background.

        :raises ValueError: if a patch is larger than the output image.
        """
        C, H, W = patches[0].shape[-3:]
        for background_idx in range(len(patches)):
            # nearest or bicubic upscaling of the background, with equal probability
            mode = VF.InterpolationMode.NEAREST if torch.randint(0, 2, (1,)).item() else VF.InterpolationMode.BICUBIC
            # clone: VF.resize may hand back its input unchanged when it already has
            # the target size, and the pasting below writes into `image` in place
            image = VF.resize(patches[background_idx], self.shape, interpolation=mode, antialias=False).clone()
            out_h, out_w = image.shape[-2:]
            if H > out_h or W > out_w:
                raise ValueError(
                    f"patch size {H}x{W} does not fit into output size {out_h}x{out_w}"
                )
            for idx in torch.randperm(len(patches)).tolist():
                if idx == background_idx:
                    continue
                # torch.randint's upper bound is exclusive; +1 makes the far edge
                # offset reachable and allows patches as large as the canvas
                x = torch.randint(0, out_w - W + 1, (1,)).item()
                y = torch.randint(0, out_h - H + 1, (1,)).item()
                patch = self.sub_transforms(patches[idx])
                if self.black_is_alpha:
                    # mask is 1 where the patch is opaque, 0 where the background shows
                    mask = (patch.sum(dim=0) > 0).float().unsqueeze(0).expand(C, -1, -1)
                    patch = patch + (1. - mask) * image[:, y: y+H, x: x+W]
                image[:, y: y+H, x: x+W] = patch

            yield image

# Preview a grid of composited images on 64x64 canvases (source tiles are
# presumably 3x32x32, per the loader's name).
# Seed torch's global RNG, which CombineImagePatchDataset draws from for flips,
# paste offsets and interpolation choice, so the grid is reproducible on
# Restart & Run All.
torch.manual_seed(0)
PREVIEW_COUNT = 128   # number of composited images to show
PREVIEW_SCALE = 1     # display zoom factor for the grid

ds = CombineImagePatchDataset(rpg_tile_dataset_3x32x32(), shape=(64, 64))
patches = list(islice(ds, PREVIEW_COUNT))
VF.to_pil_image(resize(make_grid(patches), PREVIEW_SCALE))

In [None]:
# NOTE(review): leftover IPython help lookup — remove before sharing.
# Relevant detail from the docs: torch.randint's `high` bound is exclusive.
torch.randint?