In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import ImageFilter
from src.algo import *

In [None]:
@torch.no_grad()
def make_grid_labeled(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    labels: Union[bool, Iterable[str]] = True,
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    return_pil: bool = False,
    **kwargs,
) -> Union[torch.Tensor, "PIL.Image.Image"]:
    """
    Build an image grid with torchvision's ``make_grid`` and draw a text label
    into the top-left corner of every cell.

    :param tensor: batch of images as a ``[N, C, H, W]`` tensor, or a list/tuple
        of image tensors (the first element's shape is used for the cell size)
    :param labels: ``True`` labels each cell with its index, ``False``/``None``
        (or an empty iterable) draws nothing, any other iterable is converted
        element-wise with ``str()`` and used as label texts
    :param nrow, padding, normalize, value_range, scale_each, pad_value, kwargs:
        forwarded unchanged to ``make_grid``
    :param return_pil: return a PIL image instead of a tensor
    :return: the grid as a tensor, or as a PIL image if ``return_pil`` is set
    """
    # `normalize` used to be accepted but silently dropped; forward it as well
    grid = make_grid(
        tensor=tensor, nrow=nrow, padding=padding, normalize=normalize,
        value_range=value_range, scale_each=scale_each, pad_value=pad_value,
        **kwargs,
    )

    if labels is None or labels is False:
        # honor `return_pil` even when nothing is drawn
        return VF.to_pil_image(grid) if return_pil else grid

    if isinstance(tensor, (list, tuple)):
        num_images = len(tensor)
        shape = tensor[0].shape
    else:
        assert tensor.ndim == 4, f"make_grid_labeled() only supports [N, C, H, W] shape, got '{tensor.shape}'"
        num_images = tensor.shape[0]
        shape = tensor.shape[1:]

    if labels is True:
        label_texts = [str(i) for i in range(num_images)]
    else:
        label_texts = [str(i) for i in labels]

    if not label_texts:
        return VF.to_pil_image(grid) if return_pil else grid

    grid_pil = VF.to_pil_image(grid)
    draw = PIL.ImageDraw.Draw(grid_pil)

    cell_width = shape[-1] + padding
    cell_height = shape[-2] + padding
    # four one-pixel shifted black copies form an outline so the white text
    # stays readable on bright images
    outline_offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))

    for idx, label in enumerate(label_texts):
        # assumes make_grid places `padding` pixels before each cell -- TODO confirm
        # against the installed torchvision version
        x = padding + (idx % nrow) * cell_width
        y = padding + (idx // nrow) * cell_height
        for dx, dy in outline_offsets:
            draw.text((x + dx, y + dy), label, fill=(0, 0, 0))
        # 255 is the maximum 8-bit channel value (was 256, which is out of range)
        draw.text((x, y), label, fill=(255, 255, 255))

    if return_pil:
        return grid_pil

    return VF.to_tensor(grid_pil)


In [None]:
def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect up to ``total`` images from ``iterable`` and show them as a grid.

    :param iterable: anything that yields image tensors when iterated
    :param total: maximum number of images to collect
    :param nrow: number of images per grid row
    :param return_image: return the PIL image instead of displaying it
    :param show_compression_ratio: label every image with
        ``ImageFilter.calc_compression_ratio`` rounded to 3 digits;
        takes precedence over ``label``
    :param label: optional callable ``image -> label`` used when
        ``show_compression_ratio`` is off
    :return: the PIL image if ``return_image`` is set, otherwise ``None``
    :raises ValueError: if no image could be collected
    """
    samples = []
    labels = []
    # only build the filter when its compression-ratio helper is really needed
    ratio_filter = ImageFilter() if show_compression_ratio else None
    try:
        for image in tqdm(iterable, total=total):
            samples.append(image)
            if ratio_filter is not None:
                labels.append(round(ratio_filter.calc_compression_ratio(image), 3))
            elif label is not None:
                labels.append(label(image))

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # deliberate: interrupting a slow iterable still plots what was collected
        pass

    if not samples:
        raise ValueError("plot_samples() did not receive any images to plot")

    if labels:
        # request the PIL image directly instead of converting PIL -> tensor -> PIL
        image = make_grid_labeled(samples, nrow=nrow, labels=labels, return_pil=True)
    else:
        image = VF.to_pil_image(make_grid(samples, nrow=nrow))

    if return_image:
        return image
    display(image)

In [None]:
class PatternDataset(Dataset):
    """
    Map-style dataset of procedurally generated, repeating square/circle patterns.

    Every index maps deterministically to one image: a private
    ``torch.Generator`` seeded with ``idx ^ seed`` draws all random parameters
    (amplitude, radius, shape type, fill type, offset, scale, rotation, color).
    No bounds check is performed on ``idx``; ``__len__`` only reports ``size``.

    :param shape: image shape ``(C, H, W)``, ``C`` must be 1 or 3
    :param size: value reported by ``__len__``
    :param seed: XOR-combined with the index to seed the per-item generator
    :param min_scale, max_scale: range of the scale passed to ``Space2d``
    :param min_offset, max_offset: range of the 2d offset passed to ``Space2d``
    :param dtype: dtype of the random draws and of the ``Space2d`` coordinates
    :param shape_types: subset of ``("square", "circle")`` to choose from
    :param fill_types: subset of ``("border", "inside")`` to choose from
    :param aa: supersampling factor; only applied to ``"inside"`` fills and
        only when greater than 1
    """

    shape_types = ("square", "circle")
    fill_types = ("border", "inside")

    def __init__(
            self,
            shape: Tuple[int, int, int],
            size: int = 1_000_000,
            seed: int = 23,
            min_scale: float = 0.5,
            max_scale: float = 2.,
            min_offset: float = -2.,
            max_offset: float = 2.,
            dtype: torch.dtype = torch.float,
            shape_types: Optional[Iterable[str]] = shape_types,
            fill_types: Optional[Iterable[str]] = fill_types,
            aa: int = 0,
    ):
        assert shape[0] in (1, 3), f"Expecting 1 or 3 color channels, got {shape[0]}"

        super().__init__()
        self.shape = shape
        self._size = size
        self.seed = seed
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_offset = min_offset
        self.max_offset = max_offset
        self.dtype = dtype
        self.aa = aa
        self.shape_types = self.__class__.shape_types if shape_types is None else list(shape_types)
        self.fill_types = self.__class__.fill_types if fill_types is None else list(fill_types)

    def __len__(self) -> int:
        return self._size

    def _sample_offset(self, rng: torch.Generator) -> torch.Tensor:
        """Draw a 2d offset uniformly from ``[min_offset, max_offset)``."""
        # The range used to be computed as (min_offset - max_offset), which
        # produced offsets in [2*min - max, min] instead of [min, max].
        span = self.max_offset - self.min_offset
        return torch.rand(2, dtype=self.dtype, generator=rng) * span + self.min_offset

    def __getitem__(self, idx) -> torch.Tensor:
        # NOTE: the order of the random draws below defines which image an
        # index maps to -- do not reorder them.
        rng = torch.Generator().manual_seed(idx ^ self.seed)

        # sqrt biases the amplitude towards larger values (sharper edges after clamping)
        amplitude = torch.sqrt(torch.rand(1, dtype=self.dtype, generator=rng)).item() * 50.
        radius = torch.rand(1, dtype=self.dtype, generator=rng).item() * .5
        shape_type = self.shape_types[torch.randint(0, len(self.shape_types), (1,), generator=rng).item()]
        fill_type = self.fill_types[torch.randint(0, len(self.fill_types), (1,), generator=rng).item()]
        offset = self._sample_offset(rng)
        # cubing biases the scale towards min_scale
        scale = torch.pow(torch.rand(1, dtype=self.dtype, generator=rng)[0], 3.) * (self.max_scale - self.min_scale) + self.min_scale
        # 6.28 ~ 2*pi, presumably a rotation angle in radians -- TODO confirm against Space2d
        rotate_2d = torch.rand(1, dtype=self.dtype, generator=rng) * 6.28

        # supersampling is only used for the hard-edged "inside" fill
        aa = self.aa
        if fill_type != "inside":
            aa = 0

        shape = [2, *self.shape[1:]]
        if aa > 1:
            shape[1] *= aa
            shape[2] *= aa

        # assumes Space2d.space() returns a [2, H, W] coordinate tensor -- TODO confirm
        space = Space2d(
            shape=shape,
            offset=offset,
            scale=scale,
            rotate_2d=rotate_2d,
            dtype=self.dtype,
        ).space()

        # repeat: wrap the coordinates into [-0.5, 0.5) cells
        image = (space + .5) % 1. - .5

        if shape_type == "square":
            image = torch.abs(image) - radius
            if fill_type == "border":
                image, _ = torch.min(image, dim=0, keepdim=True)
            elif fill_type == "inside":
                image, _ = torch.min((image < 0.).to(image.dtype), dim=0, keepdim=True)

        elif shape_type == "circle":
            # distance to the cell center
            image = torch.sqrt(torch.sum(torch.square(image), dim=0, keepdim=True))
            if fill_type == "border":
                image, _ = torch.min(torch.abs(image - radius), dim=0, keepdim=True)
            elif fill_type == "inside":
                image, _ = torch.min(((image - radius) <= 0.).to(image.dtype), dim=0, keepdim=True)

        image = (image * amplitude).clamp(0, 1)

        if aa > 1:
            # scale the supersampled image back to the requested resolution
            image = VF.resize(image, (shape[1] // aa, shape[2] // aa), VF.InterpolationMode.BICUBIC)

        # colors: always drawn, even for single-channel output, so the random
        # stream does not depend on the channel count
        rgb = torch.rand(3, 1, 1, dtype=self.dtype, generator=rng) * .7 + .3
        brightness = .5
        if self.shape[0] == 3:
            image = (image.repeat(3, 1, 1) + brightness) * rgb

        # bicubic resizing and the color offset can leave the [0, 1] range
        return image.clamp(0, 1)
        
        
# Quick preview of the unfiltered generator: single-channel 64x64 images with
# 2x supersampling. Disabled overrides kept for reference:
#   shape_types=["circle"], min_scale=0.01, max_scale=2.,
#   min_offset=-2., max_offset=2.
dataset = PatternDataset(shape=(1, 64, 64), size=1_000_000, aa=2)

# default grid, every image labeled with its compression ratio
plot_samples(iterable=dataset, show_compression_ratio=True)

In [None]:
# (C, H, W) shape shared by all pattern datasets below and used in the output file name
SHAPE = (1, 64, 64)
# base seed for the PatternDataset instances below (dataset #2 uses SEED + 121)
SEED = 5432

# dataset #1

In [None]:
# Dataset #1: square outlines ("border" fill), scale in [1, 2].
# Note: PatternDataset only applies `aa` to "inside" fills, so aa=2 is inert here.
ds_iter_1 = IterableImageFilterDataset(
    PatternDataset(
        shape=SHAPE,
        size=1_000_000_000,
        seed=SEED,
        min_scale=1.,
        max_scale=2.,
        shape_types=["square"],
        fill_types=["border"],
        aa=2,
    ),
    # presumably drops images whose compression ratio is below the threshold -- verify in ImageFilter.
    # Disabled alternatives kept for reference:
    #   PatternDataset: min_offset=-2., max_offset=2.
    #   ImageFilter: max_compression_ratio=.9, min_scaled_compression_ratio=.7,
    #     scaled_compression_shape=(64, 64), min_blurred_compression_ratio=0.32,
    #     blurred_compression_sigma=10., blurred_compression_kernel_size=[21, 21],
    #     compression_format="png"
    filter=ImageFilter(min_compression_ratio=.08),
)
plot_samples(iterable=ds_iter_1, show_compression_ratio=True)

# dataset #2

In [None]:
# Dataset #2: filled squares ("inside" fill, 2x supersampled), scale in [1, 2].
# Uses a different seed (SEED + 121) than the other three datasets.
ds_iter_2 = IterableImageFilterDataset(
    PatternDataset(
        shape=SHAPE,
        size=1_000_000_000,
        seed=SEED + 121,
        min_scale=1.,
        max_scale=2.,
        shape_types=["square"],
        fill_types=["inside"],
        aa=2,
    ),
    # presumably drops images whose compression ratio is below the threshold -- verify in ImageFilter.
    # Disabled alternatives kept for reference:
    #   PatternDataset: min_offset=-2., max_offset=2.
    #   ImageFilter: max_compression_ratio=.9, min_scaled_compression_ratio=.7,
    #     scaled_compression_shape=(16, 16), min_blurred_compression_ratio=0.32,
    #     blurred_compression_sigma=10., blurred_compression_kernel_size=[21, 21],
    #     compression_format="png"
    filter=ImageFilter(min_compression_ratio=.08),
)
plot_samples(iterable=ds_iter_2, show_compression_ratio=True)

# dataset #3

In [None]:
# Dataset #3: circle outlines ("border" fill), scale in [1, 2].
# Note: PatternDataset only applies `aa` to "inside" fills, so aa=2 is inert here.
ds_iter_3 = IterableImageFilterDataset(
    PatternDataset(
        shape=SHAPE,
        size=1_000_000_000,
        seed=SEED,
        min_scale=1.,
        max_scale=2.,
        shape_types=["circle"],
        fill_types=["border"],
        aa=2,
    ),
    # presumably drops images whose compression ratio is below the threshold -- verify in ImageFilter.
    # Disabled alternatives kept for reference:
    #   PatternDataset: min_offset=-2., max_offset=2.
    #   ImageFilter: max_compression_ratio=.9, min_scaled_compression_ratio=.7,
    #     scaled_compression_shape=(16, 16), min_blurred_compression_ratio=0.32,
    #     blurred_compression_sigma=10., blurred_compression_kernel_size=[21, 21],
    #     compression_format="png"
    filter=ImageFilter(min_compression_ratio=.1),
)
plot_samples(iterable=ds_iter_3, show_compression_ratio=True)

# dataset #4

In [None]:
# Dataset #4: filled circles ("inside" fill, 2x supersampled), scale in [0.5, 2].
ds_iter_4 = IterableImageFilterDataset(
    PatternDataset(
        shape=SHAPE,
        size=1_000_000_000,
        seed=SEED,
        min_scale=.5,
        max_scale=2.,
        shape_types=["circle"],
        fill_types=["inside"],
        aa=2,
    ),
    # presumably drops images whose compression ratio is below the threshold -- verify in ImageFilter.
    # Disabled alternatives kept for reference:
    #   PatternDataset: min_offset=-2., max_offset=2.
    #   ImageFilter: max_compression_ratio=.9, min_scaled_compression_ratio=.7,
    #     scaled_compression_shape=(16, 16), min_blurred_compression_ratio=0.32,
    #     blurred_compression_sigma=10., blurred_compression_kernel_size=[21, 21],
    #     compression_format="png"
    filter=ImageFilter(min_compression_ratio=.07),
)
plot_samples(iterable=ds_iter_4, show_compression_ratio=True)

In [None]:
from src.datasets.interleave import InterleaveIterableDataset

# Combine the four filtered pattern streams with equal counts.
# (presumably `counts` is the number of consecutive items taken per dataset and
# `shuffle_datasets` randomizes the dataset order -- verify in InterleaveIterableDataset)
pattern_streams = [ds_iter_1, ds_iter_2, ds_iter_3, ds_iter_4]
interleaved_dataset = InterleaveIterableDataset(
    datasets=pattern_streams,
    counts=[1, 1, 1, 1],
    shuffle_datasets=True,
)

# 16x16 preview grid of the mixed stream
plot_samples(iterable=interleaved_dataset, total=16*16, nrow=16)

## store 64x64 samples image

In [None]:
# Render a 64x64 grid of samples and keep the PIL image instead of displaying it.
img = plot_samples(interleaved_dataset, total=64*64, nrow=64, return_image=True)

# Resolve the target below the current user's home directory instead of a
# hard-coded absolute path, and make sure the directory exists.
samples_image_path = Path.home() / "Pictures" / "pattern-dataset.png"
samples_image_path.parent.mkdir(parents=True, exist_ok=True)
img.save(samples_image_path)
img.size

In [None]:
# Output file for the stored dataset; the name encodes SHAPE as CxHxW
# (for the SHAPE above: "../datasets/pattern-1x64x64-uint.pt").
dataset_name = f"../datasets/pattern-{SHAPE[-3]}x{SHAPE[-2]}x{SHAPE[-1]}-uint.pt"
dataset_name

In [None]:
def store_dataset(
        images: Iterable,
        output_filename,
        max_megabyte=512,
):
    """
    Collect images from ``images``, convert them to ``uint8`` and store them
    as one concatenated tensor with ``torch.save``.

    Collection stops once the number of stored values reaches
    ``max_megabyte`` MiB (one byte per ``uint8`` value), when ``images`` is
    exhausted, or on ``KeyboardInterrupt``; in all cases the images collected
    so far are saved.

    :param images: iterable of float image tensors, expected in range [0, 1]
        (values are clamped); tensors with fewer than 4 dimensions get a
        leading batch dimension
    :param output_filename: target passed to ``torch.save``; for ``str`` /
        ``Path`` targets the parent directory is created if missing
    :param max_megabyte: size limit in MiB
    :raises ValueError: if no image was collected
    """
    max_size = max_megabyte * 1024 * 1024
    report_interval = 1024 * 1024 * 100

    tensor_batch = []
    tensor_size = 0
    last_print_size = 0
    try:
        for image in tqdm(images):

            # note: the cast truncates, it does not round
            image = (image.clamp(0, 1) * 255).to(torch.uint8)

            if image.ndim < 4:
                image = image.unsqueeze(0)
            tensor_batch.append(image)
            tensor_size += image.numel()

            # report progress roughly every 100 MiB
            if tensor_size - last_print_size > report_interval:
                last_print_size = tensor_size

                print(f"size: {tensor_size:,}")

            if tensor_size >= max_size:
                break

    except KeyboardInterrupt:
        # deliberate: an interrupted run still stores what was collected
        pass

    if not tensor_batch:
        raise ValueError("store_dataset() did not receive any images, nothing to store")

    if isinstance(output_filename, (str, Path)):
        # torch.save does not create missing directories
        Path(output_filename).parent.mkdir(parents=True, exist_ok=True)

    torch.save(torch.cat(tensor_batch), output_filename)

# write the interleaved pattern stream to disk (default size limit applies)
store_dataset(images=interleaved_dataset, output_filename=dataset_name)

In [None]:
# load the stored uint8 image tensor back and wrap it as a dataset
stored_images = torch.load(dataset_name)
ds = TensorDataset(stored_images)

In [None]:
# Preview one shuffled batch of the stored dataset as an 8x8 grid.
dl = DataLoader(ds, shuffle=True, batch_size=8*8)

# Fetch the first batch explicitly: unlike a for/break loop this fails loudly
# on an empty loader instead of silently showing a stale `img` from an earlier cell.
# The loader yields one-element batches here (ds wraps a single tensor),
# so [0] selects the image tensor.
batch = next(iter(dl))[0]
img = VF.to_pil_image(make_grid(batch, nrow=8))
img

In [None]:
# Show item 111 of Kali2dDataset at increasing resolutions; the largest one
# additionally passes aa=10.
for kali_shape, kali_kwargs in (
        ((3, 16, 16), {}),
        ((3, 32, 32), {}),
        ((3, 64, 64), {}),
        ((3, 256, 256), {"aa": 10}),
):
    display(VF.to_pil_image(Kali2dDataset(kali_shape, **kali_kwargs)[111]))

In [None]:
# Item 221 of a 3x128x128 Kali2dDataset; `img` is reused by the PCA cell below.
img = Kali2dDataset((3, 128, 128))[221]
VF.to_pil_image(img)

In [None]:
from sklearn.decomposition import PCA


def pca_error(img: torch.Tensor, n_components: int = 1) -> "PIL.Image.Image":
    """
    Reconstruct an image from its first ``n_components`` PCA components.

    Each image row (all pixels and channels of one scanline) is treated as one
    PCA sample. Despite its name the function currently returns the
    reconstructed image, not an error value.

    :param img: image tensor of shape ``[C, H, W]``
    :param n_components: number of principal components to keep
    :return: the PCA reconstruction as PIL image
    """
    channels, height, width = img.shape[-3], img.shape[-2], img.shape[-1]

    pca = PCA(n_components=n_components)
    # [C, H, W] -> [H, W, C] -> [H, W * C]: one sample per image row
    data = img.permute(1, 2, 0).reshape(height, -1).detach().cpu().numpy()
    pca.fit(data)
    features = pca.transform(data)
    r_data = torch.as_tensor(pca.inverse_transform(features), dtype=torch.float32)

    # use the actual channel count instead of a hard-coded 3
    r_img = r_data.reshape(height, width, channels).permute(2, 0, 1)

    # the reconstruction can leave [0, 1]; clamp so the 8-bit conversion does not wrap
    return VF.to_pil_image(r_img.clamp(0, 1))
    
# show the single-component PCA reconstruction of the Kali image loaded above
pca_error(img=img)