In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

from src.datasets import *
from src.util.image import *
from src.algo import Space2d

# Iteration rendered by kali2d below: p = abs(p) / dot(p, p) - v   (v is the `param` vector)

In [None]:
def kali2d(
        space: "Space2d",
        param: torch.Tensor,
        iterations: int = 11,
        out_weights: Optional[torch.Tensor] = None,
        accumulate: str = "submin",  # none, mean, max, min, submin, alternate
        exponent: float = 0,
        sin_freq: float = 0,
        aa: int = 0,
) -> torch.Tensor:
    """
    Render the iterated map ``p = abs(p) / dot(p, p) - param`` over a coordinate grid.

    :param space: object providing ``shape``, ``scale`` and ``space()``; only these
        three members are used. ``space()`` is indexed as a 3-d tensor with the
        channel dimension first (the dot product is taken over dim 0).
    :param param: vector subtracted after every iteration except the last one;
        reshaped to ``(-1, 1, 1)`` so it broadcasts over the two trailing dims.
    :param iterations: number of iterations, must be at least 1.
    :param out_weights: optional matrix of shape ``(channels, out_channels)`` that
        linearly mixes the channels of every pixel.
    :param accumulate: how the per-iteration values are combined:

        - ``"none"``: value of the last iteration
        - ``"mean"``: sum of all values, divided by ``iterations``
        - ``"max"``: running maximum (starting at 0), divided by ``iterations``
        - ``"min"``: running minimum (starting at ``iterations``), multiplied by ``iterations``
        - ``"submin"``: starts at ``iterations`` and subtracts ``min(accum, value)``
          every step, divided by ``iterations``
        - ``"alternate"``: adds on even and subtracts on odd iterations,
          divided by ``iterations`` and doubled

        Any other string leaves the accumulator at zero, i.e. produces a black image.
    :param exponent: if non-zero, values are mapped through ``exp(-value * exponent)``
        before accumulation.
    :param sin_freq: if non-zero, values are mapped through ``sin(value * sin_freq)``
        before accumulation.
    :param aa: anti-aliasing factor; values above 1 render ``aa * aa`` shifted
        copies of the grid and average them.
    :return: tensor clamped to ``[0, 1]``
    :raises ValueError: if ``iterations`` is smaller than 1
    """
    if iterations < 1:
        # zero iterations would leave the "none" output undefined and divide by zero elsewhere
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    param = param.reshape(-1, 1, 1)

    def _render(pos: torch.Tensor) -> torch.Tensor:
        """Run the iteration on the coordinate tensor `pos` and combine the values."""
        if accumulate == "none":
            accum = None
        elif accumulate in ("min", "submin"):
            accum = torch.ones_like(pos) * iterations
        else:
            accum = torch.zeros_like(pos)

        for iteration in range(iterations):
            # small epsilon avoids a division by zero at the origin
            dot_prod = torch.sum(pos * pos, dim=0, keepdim=True) + 0.000001
            pos = torch.abs(pos) / dot_prod

            value = pos
            if exponent:
                value = torch.exp(-value * exponent)
            if sin_freq:
                value = torch.sin(value * sin_freq)

            if accumulate == "mean":
                accum = accum + value
            elif accumulate == "max":
                accum = torch.max(value, accum)
            elif accumulate == "min":
                accum = torch.min(value, accum)
            elif accumulate == "submin":
                accum = accum - torch.min(accum, value)
            elif accumulate == "alternate":
                accum = accum + (value if iteration % 2 == 0 else -value)

            # the parameter is not subtracted after the final iteration
            if iteration < iterations - 1:
                pos = pos - param

        if accumulate == "none":
            return value
        if accumulate == "min":
            return accum * iterations
        result = accum / iterations
        if accumulate == "alternate":
            result = result * 2
        return result

    if aa and aa > 1:
        height, width = space.shape[-2], space.shape[-1]

        def _tile(x: int, y: int) -> Tuple[slice, slice]:
            """Row/column slices of the (x, y)-th copy inside the repeated grid."""
            return slice(y * height, (y + 1) * height), slice(x * width, (x + 1) * width)

        # aa * aa copies of the grid side by side, each shifted by a fraction of one cell;
        # channel 0 receives the x shift and channel 1 the y shift
        aa_space = space.space().repeat(1, aa, aa)
        for x in range(aa):
            for y in range(aa):
                if x or y:
                    rows, cols = _tile(x, y)
                    aa_space[1, rows, cols] += (y / aa / height) * space.scale
                    aa_space[0, rows, cols] += (x / aa / width) * space.scale

        tiles = _render(aa_space)
        output = tiles[:, :height, :width]
        for x in range(aa):
            for y in range(aa):
                if x or y:
                    rows, cols = _tile(x, y)
                    output = output + tiles[:, rows, cols]
        output = output / (aa * aa)

    else:
        output = _render(space.space())

    if out_weights is not None:
        # mix the channels of every pixel; the channel count is taken from the
        # rendered tensor instead of assuming 3, and the number of output
        # channels follows from the weight matrix
        channels, height, width = output.shape
        pixels = output.permute(1, 2, 0).reshape(-1, channels)
        output = torch.matmul(pixels, out_weights).reshape(height, width, -1).permute(2, 0, 1)

    return torch.clamp(output, 0, 1)



# Quick look at the default render on a unit-scale grid of shape (3, 300, 300).
space = Space2d(
    (3, 300, 300),
    offset=torch.Tensor([0, 0, 0]),
    scale=1.,
)
img = kali2d(space=space, param=torch.Tensor([.5, .5, .5]))
VF.to_pil_image(img[:3])

In [None]:
# Zoomed-in, anti-aliased render, upscaled with nearest-neighbour for viewing.
# `accumulate` has to be one of the modes kali2d implements ("none", "mean", "max",
# "min", "submin", "alternate"): an unrecognized value such as "x" leaves the
# accumulator at zero and the whole image comes out black.
img = kali2d(
    Space2d(
        shape=(3, 128, 128),
        offset=torch.Tensor([0.7,0,0]),
        scale=.01,
    ),
    param=torch.Tensor([.75, .75, .75]),
    iterations=21,
    #out_weights=torch.rand((3, 3)),
    accumulate="submin",
    aa=10,
)
#img = VF.resize(img, [512, 512], interpolation=VF.InterpolationMode.BICUBIC)
img = VF.resize(img, [1024, 1024], interpolation=VF.InterpolationMode.NEAREST)
VF.to_pil_image(img)

In [None]:
class Kali2dDataset(Dataset):
    """
    Map-style dataset of procedurally rendered `kali2d` images.

    Every index seeds its own random generator (``idx ^ seed``), so an item is
    reproducible and independent of the order in which items are requested.
    Offset, scale, parameter vector, accumulation mode, iteration count and
    output mixing weights are all drawn from that generator.
    """
    accumulation_choices = ["none", "mean", "min", "max"]

    def __init__(
        self,
        shape: Tuple[int, int, int],
        size: int = 1000,
        seed: int = 23,
        min_scale: float = 0.,
        max_scale: float = 2.,
        min_offset: float = -2.,
        max_offset: float = 2.,
        min_iterations: int = 1,
        max_iterations: int = 37,
        accumulation_modes: Optional[Iterable[str]] = None,
        dtype: torch.dtype = torch.float,
        aa: int = 0,
    ):
        """
        :param shape: shape handed to `Space2d`; ``shape[0]`` is also used as the
            length of the offset/param vectors and the size of the mixing matrix
        :param size: reported length of the dataset
        :param seed: value xor-ed with the item index to seed the generator
        :param min_scale: lower bound of the sampled scale
        :param max_scale: upper bound of the sampled scale
        :param min_offset: lower bound of every sampled offset component
        :param max_offset: upper bound of every sampled offset component
        :param min_iterations: lower clamp of the iteration count
        :param max_iterations: upper clamp of the iteration count
        :param accumulation_modes: modes to choose from, defaults to `accumulation_choices`
        :param dtype: dtype of all sampled tensors
        :param aa: anti-aliasing factor passed to `kali2d`
        """
        super().__init__()
        self.shape = shape
        self._size = size
        self.seed = seed
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_offset = min_offset
        self.max_offset = max_offset
        self.min_iterations = min_iterations
        self.max_iterations = max_iterations
        self.accumulation_modes = self.accumulation_choices if accumulation_modes is None else list(accumulation_modes)
        self.dtype = dtype
        self.aa = aa

    def __len__(self) -> int:
        return self._size

    def _sample_space_params(self, rng: torch.Generator) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw the offset vector (uniform in [min_offset, max_offset)) and the scalar scale."""
        offset = (
            torch.rand(self.shape[0], dtype=self.dtype, generator=rng)
            * (self.max_offset - self.min_offset) + self.min_offset
        )
        # cubing the uniform sample biases the scale towards min_scale
        scale = (
            torch.pow(torch.rand(1, dtype=self.dtype, generator=rng)[0], 3.)
            * (self.max_scale - self.min_scale) + self.min_scale
        )
        return offset, scale

    def _sample_iterations(self, rng: torch.Generator, scale) -> int:
        """Draw the iteration count: a random base plus ``int(1 / scale)``, clamped to the limits."""
        if self.max_iterations > self.min_iterations:
            base = int(torch.randint(self.min_iterations, self.max_iterations, (1,), generator=rng)[0])
        else:
            # torch.randint rejects an empty range
            base = self.min_iterations

        if float(scale) == 0.:
            # 1 / scale is unbounded, so the count saturates at the upper limit
            candidate = self.max_iterations
        else:
            candidate = base + int(1. / scale)
        return max(self.min_iterations, min(self.max_iterations, candidate))

    def __getitem__(self, idx) -> torch.Tensor:
        rng = torch.Generator().manual_seed(idx ^ self.seed)

        # NOTE: the order of the draws below determines the image for a given seed
        offset, scale = self._sample_space_params(rng)
        space = Space2d(
            shape=self.shape,
            offset=offset,
            scale=scale,
            dtype=self.dtype,
        )

        param = torch.rand(self.shape[0], dtype=self.dtype, generator=rng) * 1.2
        accumulate = self.accumulation_modes[
            int(torch.randint(0, len(self.accumulation_modes), (1,), generator=rng)[0])
        ]
        iterations = self._sample_iterations(rng, space.scale)
        out_weights = (
            torch.rand((self.shape[0], self.shape[0]), dtype=self.dtype, generator=rng) / math.sqrt(self.shape[0])
            + torch.randn((self.shape[0], self.shape[0]), dtype=self.dtype, generator=rng) * .2
        )
        return kali2d(
            space=space,
            param=param,
            iterations=iterations,
            accumulate=accumulate,
            out_weights=out_weights,
            aa=self.aa,
        )
        
# One million 64x64 samples restricted to the "min"/"max" accumulation modes.
dataset = Kali2dDataset(
    shape=(3, 64, 64),
    size=1_000_000,
    aa=2,
    accumulation_modes=["min", "max"],
    min_iterations=17,
    min_scale=0.01,
    max_scale=2.,
    min_offset=-2.,
    max_offset=2.,
)

# Show the first 16 x 16 items as a single grid image.
preview_images = [dataset[index] for index in range(16 * 16)]
VF.to_pil_image(make_grid(preview_images, nrow=16))

In [None]:
class IterableImageFilterDataset(IterableDataset):
    """
    Iterable wrapper that only yields the images of `dataset` passing a set of
    threshold filters.

    Every ``min_*`` / ``max_*`` threshold is disabled while it is ``0`` (or any
    other falsy value). The filters run in a fixed order (mean, std, compression
    ratio, compression ratio after resizing, compression ratio after blurring)
    and an image is dropped at the first filter it fails, so the more expensive
    compression checks only run for images that passed the cheaper ones.
    """

    def __init__(
        self,
        dataset: Union[IterableDataset, Dataset],
        min_mean: float = 0.,
        max_mean: float = 0.,
        min_std: float = 0.,
        max_std: float = 0.,
        min_compression_ratio: float = 0.,
        max_compression_ratio: float = 0.,
        min_scaled_compression_ratio: float = 0.,
        max_scaled_compression_ratio: float = 0.,
        scaled_compression_shape: Iterable[int] = (16, 16),
        min_blurred_compression_ratio: float = 0.,
        max_blurred_compression_ratio: float = 0.,
        blurred_compression_kernel_size: Iterable[int] = (11, 11),
        blurred_compression_sigma: float = 10.,
        compression_format: str = "png",
    ):
        """
        :param dataset: source of images; a map-style `Dataset` is read by index,
            anything else is iterated directly
        :param min_mean: / max_mean: bounds on ``image.mean()``
        :param min_std: / max_std: bounds on ``image.std()``
        :param min_compression_ratio: / max_compression_ratio: bounds on
            `calc_compression_ratio` of the image itself
        :param min_scaled_compression_ratio: / max_scaled_compression_ratio: bounds on
            the ratio of the image resized (bicubic) to `scaled_compression_shape`
        :param min_blurred_compression_ratio: / max_blurred_compression_ratio: bounds on
            the ratio of the gaussian-blurred image
        :param blurred_compression_kernel_size: kernel size of the blur
        :param blurred_compression_sigma: sigma of the blur
        :param compression_format: image format used to measure the compressed size
        """
        super().__init__()
        self.dataset = dataset
        self.min_mean = min_mean
        self.max_mean = max_mean
        self.min_std = min_std
        self.max_std = max_std
        self.min_compression_ratio = min_compression_ratio
        self.max_compression_ratio = max_compression_ratio
        self.min_scaled_compression_ratio = min_scaled_compression_ratio
        self.max_scaled_compression_ratio = max_scaled_compression_ratio
        self.scaled_compression_shape = list(scaled_compression_shape)
        self.min_blurred_compression_ratio = min_blurred_compression_ratio
        self.max_blurred_compression_ratio = max_blurred_compression_ratio
        # Materialize the iterable (like scaled_compression_shape) so a one-shot
        # iterator is not exhausted by the first image; a plain int is kept as is.
        if isinstance(blurred_compression_kernel_size, int):
            self.blurred_compression_kernel_size = blurred_compression_kernel_size
        else:
            self.blurred_compression_kernel_size = list(blurred_compression_kernel_size)
        self.blurred_compression_sigma = blurred_compression_sigma
        self.compression_format = compression_format

    def __iter__(self) -> Generator[torch.Tensor, None, None]:
        for image in self._iter_dataset():
            if self._passes_filters(image):
                yield image

    def _passes_filters(self, image: torch.Tensor) -> bool:
        """Return True if `image` lies within every enabled (min, max) threshold pair."""
        # (lower bound, upper bound, lazily evaluated measurement), in filter order
        checks = (
            (self.min_mean, self.max_mean, lambda: image.mean()),
            (self.min_std, self.max_std, lambda: image.std()),
            (
                self.min_compression_ratio,
                self.max_compression_ratio,
                lambda: self.calc_compression_ratio(image),
            ),
            (
                self.min_scaled_compression_ratio,
                self.max_scaled_compression_ratio,
                lambda: self.calc_compression_ratio(
                    VF.resize(image, self.scaled_compression_shape, interpolation=VF.InterpolationMode.BICUBIC)
                ),
            ),
            (
                self.min_blurred_compression_ratio,
                self.max_blurred_compression_ratio,
                lambda: self.calc_compression_ratio(
                    VF.gaussian_blur(image, self.blurred_compression_kernel_size, self.blurred_compression_sigma)
                ),
            ),
        )
        for lower, upper, measure in checks:
            if not (lower or upper):
                continue  # both bounds disabled: skip the measurement entirely
            value = measure()
            if lower and value < lower:
                return False
            if upper and value > upper:
                return False
        return True

    def _iter_dataset(self) -> Generator[torch.Tensor, None, None]:
        """Yield all items of the wrapped dataset, by index for map-style datasets."""
        if isinstance(self.dataset, Dataset):
            for i in range(len(self.dataset)):
                yield self.dataset[i]
        else:
            yield from self.dataset

    def calc_compression_ratio(self, image: torch.Tensor) -> float:
        """
        Encode `image` with `compression_format` and return the encoded size in
        bytes divided by the number of elements of the tensor.
        """
        img = VF.to_pil_image(image)
        fp = BytesIO()
        img.save(fp, self.compression_format)
        memory_size = math.prod(image.shape)
        compress_size = fp.tell()
        return compress_size / memory_size

            
# Only three thresholds are active here; every other min_*/max_* argument keeps
# its default of 0, which IterableImageFilterDataset treats as "disabled".
ds_iter = IterableImageFilterDataset(
    dataset,
    max_mean=.3,
    max_compression_ratio=.9,
    min_blurred_compression_ratio=.3,
)

# Collect the first 16 x 16 images that survive the filters.
total = 16 * 16
samples = []
for count, image in enumerate(tqdm(ds_iter, total=total), start=1):
    samples.append(image)
    if count >= total:
        break

VF.to_pil_image(make_grid(samples, nrow=16))

In [None]:
# Blur one sample with an 11x11 kernel and sigma 10 (the default settings of the
# blurred-compression filter) to see what that filter gets to compress.
blurred_sample = VF.gaussian_blur(dataset[55], kernel_size=[11, 11], sigma=10)
VF.to_pil_image(blurred_sample)

In [None]:
def check_image(image, min_std: float = 0.1) -> bool:
    """
    Reject images that are (almost) flat in every channel.

    The standard deviation is computed per entry of the first dimension, so the
    check works for any channel count and not only for 3-channel images.

    :param image: tensor whose first dimension is the channel dimension
    :param min_std: an image is rejected when the std of *all* channels is below this
    :return: False if every channel's std is below `min_std`, True otherwise
    """
    per_channel_std = image.reshape(image.shape[0], -1).std(1)
    return not bool(torch.all(per_channel_std < min_std))

def iter_images():
    """
    Yield training images derived from `base_images`.

    For every base image this yields one version resized/cropped to
    `target_shape` and, if the shorter side is at least 200 pixels, a number of
    random crops taken from randomly down-scaled (and sometimes rotated) copies.

    Relies on names that are not defined in this cell: `target_shape`,
    `base_images`, `set_image_channels` and `image_resize_crop`
    (presumably provided by the star imports or another cell - TODO confirm).
    """
    cropper = VT.RandomCrop(target_shape[-2:])

    def _iter_crops(image, min_size: int):
        """Yield random crops of `image` that pass `check_image`; larger images give more crops."""
        count = 2 + max(1, min(40, (min_size - 400) // 200))
        min_scale = max(.05, 1. - min_size / 400)
        num_yielded = 0
        num_tried = 0
        # give up after 5 attempts per requested crop so flat images cannot loop forever
        while num_yielded < count and num_tried < count * 5:
            # the 10th power strongly favours scales close to min_scale
            scale = min_scale + math.pow(random.random() * (1. - min_scale), 10.)
            # never shrink below the crop size
            img = VF.resize(image, [
                max(target_shape[-2], int(image.shape[-2] * scale)),
                max(target_shape[-1], int(image.shape[-1] * scale)),
            ])
            if random.random() < .5:
                # torchvision expects the rotation centre as (x, y): x is bounded
                # by the target width, y by the target height
                center = [random.randrange(target_shape[-1]), random.randrange(target_shape[-2])]
                img = VT.RandomRotation(30, center=center)(img)
            img = cropper(img)

            num_tried += 1
            if check_image(img):
                yield img
                num_yielded += 1

    for base_image in base_images:
        image = set_image_channels(base_image, target_shape[0])
        yield image_resize_crop(image, target_shape[-2:])

        min_size = min(*image.shape[-2:])
        if min_size >= 200:
            yield from _iter_crops(image, min_size)
        
# Gather 16 x 16 images that pass check_image after clamping to [0, 1].
total = 16*16
samples = []
for candidate in tqdm(iter_images(), total=total):
    candidate = candidate.clamp(0, 1)
    if check_image(candidate):
        samples.append(candidate)
    if len(samples) >= total:
        break
VF.to_pil_image(make_grid(samples, nrow=16))

In [None]:
# Per-channel standard deviation of one collected sample (the quantity check_image thresholds).
samples[15].view(3, -1).std(dim=1)

In [None]:
def store_dataset(
        images: Iterable,
        dtype=torch.float32,
        #image_folder="~/Pictures/__diverse/",
        output_filename="../datasets/photos-32x32-std01.pt",
        max_megabyte=1_000,
):
    """
    Collect images into one tensor and save it with `torch.save`.

    :param images: iterable of image tensors; tensors with fewer than 4 dims get a
        leading batch dimension, all tensors are concatenated along dim 0
    :param dtype: dtype the images are converted to before storing
    :param output_filename: target file; missing parent directories are created
    :param max_megabyte: collecting stops once the accumulated tensor data
        reaches this many megabytes
    :raises ValueError: if `images` yields nothing
    """
    tensor_batch = []
    tensor_size = 0
    last_print_size = 0
    for image in tqdm(images):
        if len(image.shape) < 4:
            image = image.unsqueeze(0)
        image = image.clamp(0, 1).to(dtype)
        tensor_batch.append(image)
        # actual byte size of the stored tensor instead of assuming 4 bytes per element
        tensor_size += image.numel() * image.element_size()

        # report progress roughly every 50 MB
        if tensor_size - last_print_size > 1024 * 1024 * 50:
            last_print_size = tensor_size

            print(f"size: {tensor_size:,}")

        if tensor_size >= max_megabyte * 1024 * 1024:
            break

    if not tensor_batch:
        raise ValueError("no images to store")

    Path(output_filename).parent.mkdir(parents=True, exist_ok=True)
    torch.save(torch.cat(tensor_batch), output_filename)

# Write the generated crops to the default output file of store_dataset.
store_dataset(images=iter_images())

In [None]:
# Reload the stored tensor and show one shuffled batch of 24 x 24 images.
#ds = TensorDataset(torch.load("../datasets/diverse-32x32-std01.pt"))
ds = TensorDataset(torch.load("../datasets/photos-32x32-std01.pt"))
dl = DataLoader(ds, shuffle=True, batch_size=24*24)
for (images,) in dl:
    # only the first batch is needed for the preview
    img = VF.to_pil_image(make_grid(images, nrow=24))
    break
img

In [None]:
# Scratch check: concatenating two (16, 8) tensors along dim 0.
t = torch.rand(16, 8)
torch.concat([t, t], dim=0).shape