In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

from src.datasets import *
from src.util.image import * 
from src.util import ImageFilter
from src.algo import Space2d

In [None]:
def plot_samples(
        iterable, 
        total: int = 32, 
        nrow: int = 8, 
        return_image: bool = False, 
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect up to `total` images from `iterable` and render them as one grid image.

    :param iterable: any iterable of image tensors accepted by `make_grid`
    :param total: maximum number of images to collect (also the tqdm total)
    :param nrow: number of images per grid row
    :param return_image: if True, return the PIL grid image instead of displaying it
    :param show_compression_ratio: label each image with
        `ImageFilter().calc_compression_ratio(image)` rounded to 3 decimals;
        takes precedence over `label`
    :param label: optional callable mapping an image to its label
    :return: the PIL image if `return_image` is True, otherwise None
    :raises ValueError: if no image was collected

    Pressing Ctrl-C (KeyboardInterrupt) stops collecting early and plots what
    was gathered so far.
    """
    samples = []
    labels = []
    # the filter is only needed to compute compression-ratio labels
    image_filter = ImageFilter() if show_compression_ratio else None
    try:
        for image in tqdm(iterable, total=total):
            samples.append(image)
            if image_filter is not None:
                labels.append(round(image_filter.calc_compression_ratio(image), 3))
            elif label is not None:
                labels.append(label(image))

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # deliberate best-effort: allow stopping a slow (e.g. heavily filtered)
        # iterator and still plot the images collected so far
        pass

    if not samples:
        raise ValueError("plot_samples: the iterable produced no images")

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow)
    image = VF.to_pil_image(grid)
    if return_image:
        return image
    display(image)

In [None]:
def menger_sponge_2d(
        space: Space2d,
        iterations: int = 3,
        radius: float = .25,
        repeat_size: Union[float, Iterable[float]] = 1.,
        rotate_z_deg: float = 45.,
        scale_factor: float = 2.,
        offset: Iterable[float] = (0., 0.),
        shape: str = "torus",  # "square", "circle", "stripe"
        aa: int = 1,
) -> torch.Tensor:
    """
    Render a 2d "Menger sponge"-like fractal by repeatedly tiling a base shape.

    In every iteration the coordinates are folded into a `repeat_size` cell
    centered on the origin, and the distance to the base `shape` of size `radius`
    is computed. The coordinates are then rotated by `rotate_z_deg`, multiplied
    by `scale_factor` and shifted by `offset` for the next iteration.

    :param space: provides the coordinate grid (`space.space()`, or
        `space.aa_space(aa)` + `space.reduce_aa_output` when `aa > 1`) and `space.dtype`
    :param iterations: number of tiling iterations
    :param radius: size of the base shape
    :param repeat_size: tiling period, scalar or one value per coordinate axis
    :param rotate_z_deg: rotation applied to the coordinates after each iteration, in degrees
    :param scale_factor: coordinate scale applied after each iteration
    :param offset: coordinate offset (2 values) applied after each iteration
    :param shape: base shape, one of "circle", "torus", "square", "stripe"
    :param aa: anti-aliasing factor; values <= 1 render directly on `space.space()`
    :return: image tensor with values clamped to [0, 1]
    :raises ValueError: for an unknown `shape`
    """
    if shape not in ("circle", "torus", "square", "stripe"):
        raise ValueError(
            f"unknown shape {shape!r}, expected one of 'circle', 'torus', 'square', 'stripe'"
        )

    offset = torch.Tensor(list(offset)).reshape(2, 1, 1).to(space.dtype)
    if not isinstance(repeat_size, (int, float)):
        repeat_size = torch.Tensor(list(repeat_size)).reshape(2, 1, 1).to(space.dtype)
    rotate_z = math.radians(rotate_z_deg)
    # loop-invariant: the same rotation is applied after every iteration
    si = math.sin(rotate_z)
    co = math.cos(rotate_z)

    def _render(coords: torch.Tensor) -> torch.Tensor:
        # large initial value so the first `minimum` simply takes -dist
        dist_accum = torch.empty(1, *coords.shape[-2:], dtype=space.dtype).fill_(100000.)

        for iteration in range(iterations):
            # fold the coordinates into one repeat cell centered on the origin
            l_coords = (coords + repeat_size * .5) % repeat_size - repeat_size * .5

            if shape in ("circle", "torus"):
                center_dist = torch.sqrt(torch.sum(torch.square(l_coords), dim=0, keepdim=True))
                dist1 = center_dist - radius
                dist = dist1
                if shape == "torus":
                    # dist2 is measured from dist1, so the ring covers
                    # radius < center_dist < 1.5 * radius.
                    # NOTE(review): confirm this ring (and not 0.5*radius..radius) is intended
                    dist2 = dist1 - radius * .5
                    dist = torch.maximum(-dist1, dist2)
            else:  # "square" or "stripe"
                if shape == "stripe":
                    dist = torch.abs(l_coords[0]).unsqueeze(0)
                else:
                    dist = torch.abs(l_coords)
                dist, _ = torch.max(dist - radius, dim=0, keepdim=True)

            if iteration == 0:
                dist_accum = torch.minimum(dist_accum, -dist)
            else:
                dist_accum = torch.maximum(dist_accum, -dist)

            # rotate by rotate_z; note the two output components are also swapped
            # relative to a plain rotation matrix (NOTE(review): confirm intended)
            coords = torch.cat([
                (co * coords[1] + si * coords[0]).unsqueeze(0),
                (co * coords[0] - si * coords[1]).unsqueeze(0)
            ])
            coords = coords * scale_factor + offset

        # map small accumulated distances to bright pixels with a steep falloff
        output = 1. - dist_accum * 100.
        return torch.clamp(output, 0, 1)

    if aa <= 1:
        return _render(space.space())
    else:
        return space.reduce_aa_output(aa, _render(space.aa_space(aa)))

# Preview one sponge rendered on a (2, 64, 64) Space2d with aa=4, shown enlarged 3x
# with nearest-neighbour interpolation so the individual pixels stay visible.
space = Space2d((2, 64, 64), offset=torch.Tensor([0, 0]), scale=1.)
img = menger_sponge_2d(
    space,
    aa=4,
    # iterations=35,
    # rotate_z_deg=22.5,
    # offset=(.5, .2),
)
VF.to_pil_image(VF.resize(
    img,
    (img.shape[-2] * 3, img.shape[-1] * 3),
    VF.InterpolationMode.NEAREST,
))

In [None]:
class MengerSponge2dDataset(Dataset):
    """
    Deterministic, procedurally generated dataset of `menger_sponge_2d` images.

    Each item is rendered on the fly. All random parameters are drawn from a
    `torch.Generator` seeded with ``idx ^ seed``, so the same index always
    yields the same image.

    :param shape: output shape; only the last two entries (height, width) are used
    :param size: number of items (`len(dataset)`)
    :param seed: XOR-ed with the index to seed each item's generator
    :param min_scale, max_scale: range of the `Space2d` scale
    :param min_offset, max_offset: range of the (scalar) `Space2d` offset
    :param rotation_steps: `Space2d` rotation is a multiple of pi / rotation_steps in [-pi, pi)
    :param min_radius, max_radius: range of the base shape radius
    :param min_iterations, max_iterations: range of the (rounded, clamped) iteration count
    :param min_recursive_scale, max_recursive_scale: range of `scale_factor`
    :param recursive_offset_steps: each recursive offset component is a multiple of
        1 / recursive_offset_steps in [0, 1)
    :param recursive_rotation_steps: `rotate_z_deg` is a multiple of
        360 / recursive_rotation_steps in [-360, 360)
    :param shapes: subset of `available_shapes` to pick from uniformly (default: all)
    :param dtype: dtype of the random draws and the rendering space
    :param aa: anti-aliasing factor forwarded to `menger_sponge_2d`
    :raises ValueError: if `shapes` is empty or contains an unknown shape
    """
    available_shapes = ("circle", "square", "stripe", "torus")
    
    def __init__(
        self,
        shape: Tuple[int, int, int],
        size: int = 1000,
        seed: int = 23,
        min_scale: float = .1,
        max_scale: float = 1.,
        min_offset: float = 0.,
        max_offset: float = 0.,
        rotation_steps: int = 4,
        min_radius: float = .1,
        max_radius: float = .25,
        min_iterations: int = 2,
        max_iterations: int = 10,
        min_recursive_scale: float = 1.,
        max_recursive_scale: float = 3.,
        recursive_offset_steps: int = 1,
        recursive_rotation_steps: int = 4,
        shapes: Optional[Iterable[str]] = None,#("circle", "square"),
        dtype: torch.dtype = torch.float,
        aa: int = 0,
    ):
        super().__init__()
        self.shape = shape
        self._size = size
        self.seed = seed
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_offset = min_offset
        self.max_offset = max_offset
        self.rotation_steps = rotation_steps
        self.min_recursive_scale = min_recursive_scale
        self.max_recursive_scale = max_recursive_scale
        self.recursive_offset_steps = recursive_offset_steps
        self.recursive_rotation_steps = recursive_rotation_steps
        self.min_radius = min_radius
        self.max_radius = max_radius
        self.min_iterations = min_iterations
        self.max_iterations = max_iterations
        self.shapes = self.available_shapes if shapes is None else list(shapes)
        # fail early instead of on the first __getitem__
        if not self.shapes:
            raise ValueError("shapes must not be empty")
        unknown = [s for s in self.shapes if s not in self.available_shapes]
        if unknown:
            raise ValueError(f"unknown shapes {unknown}, expected any of {self.available_shapes}")
        self.dtype = dtype
        self.aa = aa
        
    def __len__(self) -> int:
        return self._size
    
    def __getitem__(self, idx) -> torch.Tensor:
        # Python-style indexing; raising IndexError also terminates the legacy
        # sequence iteration protocol (`for image in dataset`), which would
        # otherwise never stop.
        if idx < 0:
            idx += self._size
        if not 0 <= idx < self._size:
            raise IndexError(f"index out of range for dataset of size {self._size}")

        rng = torch.Generator().manual_seed(idx ^ self.seed)

        def _rand(mi: float, ma: float) -> torch.Tensor:
            # one uniform sample in [mi, ma), returned as a 1-element tensor
            return torch.rand(1, generator=rng, dtype=self.dtype) * (ma - mi) + mi

        # NOTE: the order of the random draws below defines which image each index
        # maps to; reordering them changes the whole dataset.
        view_rotation = math.pi * math.floor(_rand(-1., 1.) * self.rotation_steps) / max(1, self.rotation_steps)
        
        space = Space2d(
            shape=(2, *self.shape[-2:]),
            offset=_rand(self.min_offset, self.max_offset),
            scale=_rand(self.min_scale, self.max_scale),
            rotate_2d=view_rotation,
            dtype=self.dtype,
        )
        
        menger_shape = self.shapes[torch.randint(0, len(self.shapes), (1,), generator=rng).item()]
        iterations = max(self.min_iterations, min(self.max_iterations, 
            int(.5 + _rand(self.min_iterations, self.max_iterations))
        ))
        radius = _rand(self.min_radius, self.max_radius)
        recursive_scale = _rand(self.min_recursive_scale, self.max_recursive_scale)
        recursive_rotation_deg = 360. * math.floor(_rand(-1., 1.) * self.recursive_rotation_steps) / max(1, self.recursive_rotation_steps)
        recursive_offset = torch.rand(2, dtype=self.dtype, generator=rng)
        recursive_offset = torch.floor(recursive_offset * self.recursive_offset_steps) / max(1, self.recursive_offset_steps)
        return menger_sponge_2d(
            space=space,
            shape=menger_shape,
            iterations=iterations,
            scale_factor=recursive_scale,
            radius=radius,
            rotate_z_deg=recursive_rotation_deg,
            offset=recursive_offset,
            aa=self.aa,
        )
        
# Unfiltered preview: 64x64 sponges rendered with aa=4, each labelled with its
# compression ratio. Other knobs that were tried:
#   min_iterations=17, min_scale=0.01, max_scale=2., min_offset=-2., max_offset=2.
dataset = MengerSponge2dDataset(
    (1, 64, 64),
    size=1_000_000,
    aa=4,
)
plot_samples(dataset, total=64, nrow=8, show_compression_ratio=True)

In [None]:
class MengerSponge2dFilteredIterableDataset(IterableDataset):
    """
    Iterate a `MengerSponge2dDataset` in index order, yielding only images accepted
    by `image_filter`.

    All generator parameters are forwarded to `MengerSponge2dDataset`.
    NOTE: `min_recursive_offset` is accepted for interface compatibility but is
    currently ignored (`MengerSponge2dDataset` has no such parameter).

    :param image_filter: callable image -> bool; if None, every image is yielded
    :param filter_shape: if given and different from `shape`, the filter is evaluated
        on a separate dataset rendered with this shape (same seed and parameters)
    :param filter_aa: if given and different from `aa`, the filter is evaluated on a
        separate dataset rendered with this anti-aliasing factor
    """
    def __init__(
        self,
        shape: Tuple[int, int, int],
        size: int = 1000,
        seed: int = 23,
        min_scale: float = .1,
        max_scale: float = 1.,
        min_offset: float = 0.,
        max_offset: float = 0.,
        rotation_steps: int = 4,
        min_radius: float = .1,
        max_radius: float = .25,
        min_iterations: int = 2,
        max_iterations: int = 10,
        min_recursive_scale: float = 1.,
        max_recursive_scale: float = 3.,
        min_recursive_offset: float = 0.,
        recursive_offset_steps: int = 1,
        recursive_rotation_steps: int = 4,
        shapes: Optional[Iterable[str]] = None,
        dtype: torch.dtype = torch.float,
        aa: int = 0,
        image_filter: Optional[ImageFilter] = None,
        filter_shape: Optional[Tuple[int, int, int]] = None,
        filter_aa: Optional[int] = None,

    ):
        super().__init__()
        kwargs = dict(
            shape=shape,
            size=size,
            seed=seed,
            min_scale=min_scale,
            max_scale=max_scale,
            min_offset=min_offset,
            max_offset=max_offset,
            rotation_steps=rotation_steps,
            min_radius=min_radius,
            max_radius=max_radius,
            min_iterations=min_iterations,
            max_iterations=max_iterations,
            min_recursive_scale=min_recursive_scale,
            max_recursive_scale=max_recursive_scale,
            recursive_offset_steps=recursive_offset_steps,
            recursive_rotation_steps=recursive_rotation_steps,
            shapes=shapes,
            dtype=dtype,
            aa=aa,
        )
        self.dataset = MengerSponge2dDataset(**kwargs)
        self.image_filter = image_filter
        self.filter_dataset = None
        if image_filter is not None:
            needs_other_aa = filter_aa is not None and filter_aa != aa
            needs_other_shape = filter_shape is not None and filter_shape != shape
            if needs_other_aa or needs_other_shape:
                filter_kwargs = dict(kwargs)
                if filter_aa is not None:
                    filter_kwargs["aa"] = filter_aa
                if filter_shape is not None:
                    filter_kwargs["shape"] = filter_shape
                self.filter_dataset = MengerSponge2dDataset(**filter_kwargs)
            
    def __iter__(self) -> Generator[torch.Tensor, None, None]:
        """
        Yield the accepted images in index order.

        Inside DataLoader worker processes, each worker only handles every
        `num_workers`-th index (offset by its worker id). The workers together then
        cover each index once instead of each yielding a full duplicate copy.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            start, step = 0, 1
        else:
            start, step = worker_info.id, worker_info.num_workers

        for i in range(start, len(self.dataset), step):
            if self.image_filter is None:
                yield self.dataset[i]
            elif self.filter_dataset is None:
                image = self.dataset[i]
                if self.image_filter(image):
                    yield image
            elif self.image_filter(self.filter_dataset[i]):
                # the output image is only rendered for indices whose filter
                # render was accepted
                yield self.dataset[i]


# Filtered preview: 128x128 outputs (aa=4); the filter decision for each index is
# made on a (1, 32, 32) render of the same index.
plot_samples(
    MengerSponge2dFilteredIterableDataset(
        shape=(1, 128, 128),
        aa=4,
        max_iterations=8,
        rotation_steps=3,
        recursive_offset_steps=12,
        filter_shape=(1, 32, 32),
        # filter_aa=2,
        image_filter=ImageFilter(
            max_compression_ratio=.5,
            min_blurred_compression_ratio=.5,
        ),
    ),
    total=64,
    show_compression_ratio=True,
)

# Dataset #1: 5–8 iterations, filtered on mean brightness and blurred compression ratio

In [None]:
# Settings shared by the datasets below.
SHAPE = (1, 128, 128)   # only the last two entries (height, width) are used for rendering
AA = 4                  # forwarded as the `aa` anti-aliasing factor
SIZE = 10 ** 9          # number of candidate indices each dataset iterates over

In [None]:
# Dataset #1: the filter is evaluated on a render with aa=0 (a separate filter
# dataset is built because filter_aa differs from AA).
ds1 = MengerSponge2dFilteredIterableDataset(
    shape=SHAPE,
    size=SIZE,
    aa=AA,
    filter_aa=0,
    # filter_shape=(1, 32, 32),
    min_iterations=5,
    max_iterations=8,
    rotation_steps=3,
    recursive_offset_steps=12,
    image_filter=ImageFilter(
        max_mean=.2,
        min_blurred_compression_ratio=.3,
        # min_compression_ratio=.1,
        # max_compression_ratio=.3,
    ),
)
plot_samples(ds1, total=64, show_compression_ratio=True)

# Dataset #2: 3–5 iterations, filtered on maximum compression ratio

In [None]:
# Dataset #2: fewer iterations and finer recursive rotation/offset steps; the filter
# is evaluated directly on the output render.
ds2 = MengerSponge2dFilteredIterableDataset(
    shape=SHAPE,
    size=SIZE,
    aa=AA,
    # filter_shape=(1, 32, 32),
    # filter_aa=0,
    min_iterations=3,
    max_iterations=5,
    # rotation_steps=3,
    recursive_rotation_steps=12,
    recursive_offset_steps=12,
    image_filter=ImageFilter(
        max_compression_ratio=.2,
        # min_mean=.2,
        # min_blurred_compression_ratio=.3,
        # min_compression_ratio=.1,
    ),
)
plot_samples(ds2, total=64, show_compression_ratio=True)

# Find dissimilar samples by comparing normalized feature vectors

In [None]:
class ToRGB(nn.Module):
    """
    Turn single-channel image batches into 3-channel ones by repeating the channel.

    Inputs whose channel dimension (dim -3) already has 3 entries are returned
    unchanged instead of being tripled to 9 channels.
    """
    def forward(self, x):
        if x.ndim >= 3 and x.shape[-3] == 3:
            return x
        return x.repeat(1, 3, 1, 1)

# Feature extractor `model` used by `iter_unsimilar_images` to embed samples.
# Pick the backend here instead of toggling `if 0:` / `if 1:` blocks.
ENCODER_BACKEND = "clip"  # "clip" or "encoder-trans"
CODE_SIZE = 512

if ENCODER_BACKEND == "encoder-trans":
    from scripts.train_from_dataset import EncoderMLP, EncoderTrans

    #model = EncoderMLP((3, *SHAPE[-2:]), channels=[CODE_SIZE])
    #model.load_state_dict(torch.load("../checkpoints/clip2/best.pt")["state_dict"])
    model = EncoderTrans((3, 64, 64), code_size=CODE_SIZE)
    model.load_state_dict(torch.load("../checkpoints/clip5-tr/best.pt")["state_dict"])
    model = nn.Sequential(
        VT.Resize((64, 64), VF.InterpolationMode.BICUBIC),
        ToRGB(),
        model
    )
elif ENCODER_BACKEND == "clip":
    import clip

    # fall back to the CPU instead of failing with a hard-coded `.cuda()`
    CLIP_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    model, preproc = clip.load("ViT-B/32", device=CLIP_DEVICE)
    # match whatever dtype the loaded visual tower uses (e.g. half vs. float)
    CLIP_DTYPE = next(model.visual.parameters()).dtype

    class ToDevice(nn.Module):
        """Move inputs to the CLIP device and cast them to the visual tower's dtype."""
        def forward(self, x):
            return x.to(device=CLIP_DEVICE, dtype=CLIP_DTYPE)

    class FromDevice(nn.Module):
        """Move encoder outputs back to the CPU as float32."""
        def forward(self, x):
            return x.cpu().float()

    model = nn.Sequential(
        VT.Resize((224, 224), VF.InterpolationMode.BICUBIC),
        ToRGB(),
        # presumably CLIP's Normalize step — TODO confirm against clip's preprocess
        preproc.transforms[-1],
        ToDevice(),
        model.visual,
        FromDevice(),
    )
else:
    raise ValueError(f"unknown ENCODER_BACKEND: {ENCODER_BACKEND!r}")
    

In [None]:
def iter_unsimilar_images(
        ds,
        max_dot: float = 0.9,
        total: int = 10,
        batch_size: int = 20,
        encoder: Optional[Callable] = None,
):
    """
    Yield up to `total` images from `ds` whose features are mutually dissimilar.

    Each image is embedded by `encoder`, and its feature vector is L2-normalized
    along the last dim. The image is accepted only if its dot product with every
    previously accepted feature is below `max_dot`. The first image is always
    accepted.

    :param ds: dataset / iterable dataset of images, batched via `DataLoader`
    :param max_dot: rejection threshold on the cosine similarity
    :param total: maximum number of images to yield (<= 0 yields nothing)
    :param batch_size: DataLoader batch size used for encoding
    :param encoder: callable mapping an image batch to a feature batch;
        defaults to the notebook-level `model`
    """
    if total <= 0:
        return
    if encoder is None:
        encoder = model

    def _iter_image_and_feature():
        for image_batch in DataLoader(ds, batch_size=batch_size):
            # no_grad: accepted features are kept in `all_features`; with autograd
            # enabled (trainable encoder params) each kept feature would also keep
            # its batch's computation graph alive
            with torch.no_grad():
                features = encoder(image_batch)
            features = features / features.norm(dim=-1, keepdim=True)
            yield from zip(image_batch, features)

    count = 0
    count_tried = 0
    last_count_tried = 0
    all_features = None

    for image, feature in _iter_image_and_feature():
        count_tried += 1

        if all_features is None:
            all_features = feature.unsqueeze(0)
        else:
            dots = feature @ all_features.T

            if count_tried - last_count_tried > 1000:
                last_count_tried = count_tried
                print(f"found {count} in {count_tried}, current dots: {dots.tolist()}")

            if torch.any(dots >= max_dot):
                continue

            all_features = torch.cat([all_features, feature.unsqueeze(0)])

        yield image
        count += 1
        if count >= total:
            break

plot_samples(iter_unsimilar_images(ds1, max_dot=.95, total=32), total=32)

In [None]:
plot_samples(iter_unsimilar_images(ds2, max_dot=.95, total=128), total=128)

In [None]:
def store_dataset(
        images: Iterable,
        dtype=torch.float32,
        #image_folder="~/Pictures/__diverse/",
        output_filename="../datasets/photos-32x32-std01.pt",
        max_megabyte=1_000,
):
    """
    Collect images into one tensor and save it with `torch.save`.

    Images with fewer than 4 dims get a leading batch dim, are clamped to [0, 1],
    converted to `dtype`, and concatenated along dim 0. Collection stops once
    about `max_megabyte` MiB have been gathered.

    :param images: iterable of image tensors
    :param dtype: dtype of the stored tensor
    :param output_filename: path passed to `torch.save`
    :param max_megabyte: size limit in MiB
    :raises ValueError: if `images` is empty
    """
    max_bytes = max_megabyte * 1024 * 1024
    print_every_bytes = 1024 * 1024 * 50
    tensor_batch = []
    tensor_size = 0
    last_print_size = 0
    for image in tqdm(images):
        if image.ndim < 4:
            image = image.unsqueeze(0)
        image = image.clamp(0, 1).to(dtype)
        tensor_batch.append(image)
        # count real bytes of the stored dtype instead of assuming 4 bytes/element
        tensor_size += image.numel() * image.element_size()

        if tensor_size - last_print_size > print_every_bytes:
            last_print_size = tensor_size
            print(f"size: {tensor_size:,}")

        if tensor_size >= max_bytes:
            break

    if not tensor_batch:
        raise ValueError("store_dataset: no images to store")

    torch.save(torch.cat(tensor_batch), output_filename)

# NOTE(review): `iter_images` is not defined anywhere in this notebook, so this line
# raises NameError on a fresh kernel — confirm which generator was intended (e.g.
# `iter_unsimilar_images(ds1, ...)`). It also writes to store_dataset's default
# output_filename, which is the file the next cell loads.
store_dataset(iter_images())

In [None]:
# Preview a shuffled 24x24 grid of the stored images.
#ds = TensorDataset(torch.load("../datasets/diverse-32x32-std01.pt"))
ds = TensorDataset(torch.load("../datasets/photos-32x32-std01.pt"))
dl = DataLoader(ds, shuffle=True, batch_size=24*24)
# take only the first shuffled batch; next(iter(...)) fails loudly on an empty file
# instead of silently re-displaying a stale `img` from an earlier cell
batch = next(iter(dl))
img = VF.to_pil_image(make_grid(batch[0], nrow=24))
img

In [None]:
# Scratch check: concatenating along dim 0 stacks the rows -> torch.Size([32, 8]).
t = torch.rand(16, 8)
torch.cat((t, t), dim=0).shape