In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util.image import * 
from src.util import ImageFilter
from src.algo import Space2d

In [None]:
class IFS_Torch:
    """Prototype Iterated Function System renderer driven by a torch RNG.

    Each of the ``num_parameters`` affine maps is a row ``(a, b, c, d, e, f)``
    applied as ``x' = a*x + b*y + e`` and ``y' = c*x + d*y + f``. Rows are drawn
    uniformly from [-1, 1) and per-map acceptance probabilities from [0, 1).
    """

    # the orbit counts as diverged once |x| or |y| exceeds this value
    max_coordinate = 1e10

    def __init__(self, seed: Optional[int] = None, num_parameters: int = 5):
        self.rng = torch.Generator().manual_seed(seed if seed is not None else int(time.time() * 1000))
        self.parameters = torch.rand((num_parameters, 6), generator=self.rng) * 2. - 1.
        self.probabilities = torch.rand((num_parameters, ), generator=self.rng)

    def iter_coordinates(self, num_iterations: int) -> Generator[Tuple[float, float], None, None]:
        """Yield up to ``num_iterations`` orbit points as Python floats, starting at the origin.

        The map for each step is chosen by rejection sampling: draw a random row and
        accept it with probability ``self.probabilities[row]``. Iteration stops early
        (without yielding the bad point) once the orbit diverges or becomes non-finite.
        """
        x, y = 0., 0.
        num_maps = self.parameters.shape[0]
        for _ in range(num_iterations):
            param_index = None
            while param_index is None:
                idx = torch.randint(0, num_maps, (1,), generator=self.rng).item()
                if torch.rand(1, generator=self.rng).item() < self.probabilities[idx]:
                    param_index = idx

            a, b, c, d, e, f = self.parameters[param_index].tolist()
            x, y = (
                x * a + y * b + e,
                x * c + y * d + f
            )
            # inf/NaN or huge values would break the min/max normalization later
            if not (math.isfinite(x) and math.isfinite(y)):
                break
            if abs(x) > self.max_coordinate or abs(y) > self.max_coordinate:
                break
            yield x, y

    def render_coordinates(self, shape: Tuple[int, int], num_iterations: int, padding: int = 2) -> torch.Tensor:
        """Map the orbit onto integer pixel coordinates of an image of ``shape``.

        :return: int16 tensor of shape ``(N, 2)`` with columns ``(x, y)``. Each axis is
            scaled to ``[padding, size - padding]`` and clamped to valid pixel indices.
            An axis on which all points coincide is placed at the image center.
            An empty orbit yields an empty ``(0, 2)`` tensor.
        """
        points = list(self.iter_coordinates(num_iterations))
        if not points:
            return torch.zeros((0, 2), dtype=torch.int16)

        coords = torch.tensor(points, dtype=torch.float64)
        # column 0 (x) spans the last image dim, column 1 (y) the second to last
        for column, size in ((0, shape[-1]), (1, shape[-2])):
            values = coords[:, column]
            lo, hi = values.min(), values.max()
            if hi > lo:
                coords[:, column] = (values - lo) / (hi - lo) * (size - padding * 2) + padding
            else:
                # degenerate axis: avoid division by zero, center the points
                coords[:, column] = size // 2
            # the maximum maps to `size - padding`, which is out of range for padding == 0
            coords[:, column] = coords[:, column].clamp(0, size - 1)
        return coords.to(torch.int16)

    def render_image_tensor(self, shape: Tuple[int, int], num_iterations: int, padding: int = 2) -> torch.Tensor:
        """Render the attractor as a binary ``(1, H, W)`` float tensor (1 where the orbit visited)."""
        # advanced indexing needs int64 indices; marks all visited pixels at once
        coords = self.render_coordinates(shape, num_iterations, padding).long()
        image = torch.zeros((1, *shape))
        image[0, coords[:, 1], coords[:, 0]] = 1
        return image.clamp(0, 1)
        
# Time one 256x256 render of a random IFS_Torch system with 100k orbit points.
ifs = IFS_Torch()
# perf_counter is the monotonic clock intended for measuring durations
start_time = time.perf_counter()
img = ifs.render_image_tensor((256, 256), 100_000)
seconds = time.perf_counter() - start_time
print(f"render time: {seconds:.3f} s")
VF.to_pil_image(img)

In [None]:
def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """Collect up to ``total`` images from ``iterable`` and show them as a grid.

    :param iterable: yields images, or lists/tuples whose first element is the image
        (e.g. ``(image, label)`` pairs or DataLoader batches of size 1).
    :param total: maximum number of samples to collect (also used as the tqdm total).
    :param nrow: number of images per grid row.
    :param return_image: return the PIL grid instead of displaying it.
    :param show_compression_ratio: label each tile with
        ``ImageFilter().calc_compression_ratio(image)`` (rounded to 3 digits);
        takes precedence over ``label``.
    :param label: optional callable mapping the raw entry to a tile label.
    :return: the PIL grid when ``return_image`` is set, otherwise None. Returns
        None without drawing when no sample was collected.

    Ctrl-C (KeyboardInterrupt) stops collecting early and plots what was gathered.
    """
    samples = []
    labels = []
    # the filter is only needed for compression-ratio labels
    image_filter = ImageFilter() if show_compression_ratio else None
    with_labels = show_compression_ratio or label is not None
    try:
        for entry in tqdm(iterable, total=total):
            image = entry[0] if isinstance(entry, (list, tuple)) else entry
            if image.ndim == 4:
                # drop the batch dimension of a size-1 batch
                image = image.squeeze(0)

            if image_filter is not None:
                tile_label = round(image_filter.calc_compression_ratio(image), 3)
            elif label is not None:
                tile_label = label(entry)
            else:
                tile_label = None

            samples.append(image)
            if with_labels:
                labels.append(tile_label)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        pass

    if with_labels:
        # an interrupt between the two appends could leave one unlabeled sample
        samples = samples[:len(labels)]
    if not samples:
        print("plot_samples: no samples collected")
        return None

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow)
    image = VF.to_pil_image(grid)
    if return_image:
        return image
    display(image)

# IFS class

In [None]:
class IFS:
    """Iterated Function System renderer ("chaos game") driven by a NumPy MT19937 RNG.

    Each row of ``parameters`` is an affine map ``(a, b, c, d, e, f)`` applied as
    ``x' = a*x + b*y + e`` and ``y' = c*x + d*y + f``. For every step a map is
    chosen by rejection sampling: draw a random row index and accept it with
    probability ``probabilities[row]`` (so the probabilities need not sum to 1).
    All randomness comes from ``self.rng``, so a seed reproduces both the random
    parameters and the orbit.
    """
    # the orbit counts as diverged once |x| or |y| exceeds this value
    max_coordinate = 1e10

    def __init__(
            self,
            seed: Optional[int] = None,
            num_parameters: int = 2,
            parameters: Optional[np.ndarray] = None,
            probabilities: Optional[np.ndarray] = None,
    ):
        """
        :param seed: RNG seed; drawn from Python's ``random`` module when None.
        :param num_parameters: number of random affine maps to draw.
        :param parameters: optional ``(num_maps, 6)`` array that replaces the random maps.
        :param probabilities: optional ``(num_maps,)`` acceptance probabilities that
            replace the random ones.
        """
        self.rng = np.random.Generator(np.random.MT19937(
            seed if seed is not None else random.randint(0, int(1e10))
        ))
        # draw and discard 100 bytes before sampling anything (warm-up)
        self.rng.bytes(100)
        # random defaults are always drawn, so the RNG state after __init__ (and thus
        # the orbit) does not depend on whether explicit values are passed below
        self.parameters = self.rng.uniform(-1., 1., (num_parameters, 6))
        self.probabilities = self.rng.uniform(0., 1., (num_parameters, ))
        if parameters is not None:
            self.parameters = parameters
        if probabilities is not None:
            self.probabilities = probabilities

    def iter_coordinates(self, num_iterations: int) -> Generator[Tuple[float, float], None, None]:
        """Yield up to ``num_iterations`` orbit points, starting from the origin.

        Stops early (without yielding the offending point) once a coordinate exceeds
        ``max_coordinate`` or becomes NaN/inf.
        """
        x, y = 0., 0.
        num_maps = self.parameters.shape[0]
        for _ in range(num_iterations):
            # rejection sampling of the map index
            param_index = None
            while param_index is None:
                idx = self.rng.integers(0, num_maps)
                if self.rng.uniform(0., 1.) < self.probabilities[idx]:
                    param_index = idx

            a, b, c, d, e, f = self.parameters[param_index]
            x, y = (
                x * a + y * b + e,
                x * c + y * d + f
            )
            # `<=` is False for NaN, so this one test catches divergence, inf and NaN
            if not (np.abs(x) <= self.max_coordinate and np.abs(y) <= self.max_coordinate):
                break
            yield x, y

    def render_coordinates(self, shape: Tuple[int, int], num_iterations: int, padding: int = 2) -> np.ndarray:
        """Map the orbit onto integer pixel coordinates of an image of ``shape``.

        :return: uint16 array of shape ``(N, 2)`` with columns ``(x, y)``. Each axis
            is scaled to ``[padding, size - padding]`` and clipped to valid pixel
            indices. An axis on which all points coincide is placed at the image
            center. An empty orbit yields an empty ``(0, 2)`` array.
        """
        points = list(self.iter_coordinates(num_iterations))
        if not points:
            return np.zeros((0, 2), dtype=np.uint16)

        coords = np.array(points, dtype=np.float64)
        # column 0 (x) spans the last image dim, column 1 (y) the second to last
        for column, size in ((0, shape[-1]), (1, shape[-2])):
            lo, hi = coords[:, column].min(), coords[:, column].max()
            if hi != lo:
                coords[:, column] = (coords[:, column] - lo) / (hi - lo) * (size - padding * 2) + padding
            else:
                # degenerate axis: the raw (possibly negative) value is not a pixel index
                coords[:, column] = size // 2
        # the maximum maps to `size - padding`, which is out of range for padding == 0
        np.clip(coords, 0, [shape[-1] - 1, shape[-2] - 1], out=coords)
        return coords.astype(np.uint16)

    def render_image(
            self,
            shape: Tuple[int, int],
            num_iterations: int,
            padding: int = 2,
            alpha: float = 0.1,
            patch_size: int = 1,
    ) -> np.ndarray:
        """Render the orbit as a ``(1, H, W)`` float image clipped to [0, 1].

        Every orbit point adds ``alpha`` to its pixel or, when ``patch_size > 1``, a
        ``patch_size`` x ``patch_size`` Hamming window scaled by ``alpha``.
        """
        extra_padding = patch_size if patch_size > 1 else 0
        if extra_padding:
            # render on an enlarged canvas so patches near the border fit, crop afterwards
            shape = (shape[-2] + extra_padding * 2, shape[-1] + extra_padding * 2)

        # signed indices: the patch offset below must not wrap around like uint16 would
        coords = self.render_coordinates(shape, num_iterations, padding + extra_padding).astype(np.int64)
        image = np.zeros((1, *shape))

        if patch_size <= 1:
            # unbuffered accumulation: a pixel hit k times receives k * alpha
            np.add.at(image[0], (coords[:, 1], coords[:, 0]), alpha)
        else:
            half_patch_size = patch_size // 2
            window = np.hamming(patch_size)
            # separable 2-D Hamming window, same operation order as window[i] * (window[j] * alpha)
            patch = window[:, None] * (window[None, :] * alpha)
            for x, y in coords - half_patch_size:
                image[0, y:y + patch_size, x:x + patch_size] += patch
            image = image[:, extra_padding:-extra_padding, extra_padding:-extra_padding]

        return image.clip(0, 1)
    
# The first eight seeds, one 128x128 render each, shown in a single row.
images = []
for i in tqdm(range(8)):
    ifs = IFS(seed=i)
    img = ifs.render_image((128, 128), 10_000)
    images.append(torch.Tensor(img))

grid = VF.to_pil_image(
    make_grid_labeled(images, nrow=8, labels=True)
)
grid

# num-parameters variation

In [None]:
# Four rows of `nrow` random systems; row r uses 2 + r affine maps.
nrow = 32
images = []
for i in tqdm(range(4 * nrow)):
    ifs = IFS(num_parameters=i // nrow + 2)
    img = ifs.render_image((128, 128), 10_000)
    images.append(torch.Tensor(img))

grid = VF.to_pil_image(
    make_grid(images, nrow=nrow)
)
grid

# class variations

In [None]:
# Each row: seven renders of one random system (its RNG advances between renders,
# so the orbits differ), then the signed difference of the first and seventh render.
images = []
for i in tqdm(range(4 * 8)):
    if i % 8 == 0:
        # a new random system starts every row
        ifs = IFS()
    if i % 8 != 7:
        img = ifs.render_image((128, 128), 10_000)
        images.append(torch.Tensor(img).repeat(3, 1, 1))
    else:
        images.append(signed_to_image(images[-7] - images[-1]).clamp(0, 1))

grid = VF.to_pil_image(make_grid(images, nrow=8))
grid

# parameter perturbations

In [None]:
# Each row uses one seed. The system is re-created for every column, so the orbit RNG
# is identical and only the parameter noise differs. Column 0 is the unperturbed
# reference. Columns 1-6 add uniform noise of size PERTURBATION to the map parameters.
# Column 7 shows the signed difference of columns 0 and 6.
PERTURBATION = 0.03
# dedicated, seeded generator: the global np.random stream made the figure unreproducible
perturbation_rng = np.random.default_rng(0)
images = []
for i in tqdm(range(8 * 8)):
    if i % 8 == 7:
        images.append(
            signed_to_image(images[-7] - images[-1]).clamp(0, 1)
        )
        continue
    ifs = IFS(seed=i // 8)
    if i % 8 != 0:
        ifs.parameters += PERTURBATION * perturbation_rng.uniform(-1., 1., ifs.parameters.shape)
    img = ifs.render_image((128, 128), 10_000)
    images.append(torch.Tensor(img).repeat(3, 1, 1))

grid = VF.to_pil_image(make_grid(images, nrow=8))
grid

# iteration depth

In [None]:
# Each row: one seed rendered with an increasing number of orbit points.
# (The largest counts run a pure-Python loop and take a while.)
ITERATIONS = [100, 1_000, 10_000, 50_000, 100_000, 200_000, 500_000, 1_000_000]
images = []
for i in tqdm(range(4 * 8)):
    ifs = IFS(seed=i // 8)
    img = ifs.render_image((128, 128), num_iterations=ITERATIONS[i % 8])
    images.append(torch.Tensor(img))

grid = VF.to_pil_image(
    make_grid(images, nrow=8)
)
grid

# CLIP-based selection

In [None]:
import clip as cliplib

# NOTE(review): not referenced in the visible cells; 512 is presumably the ViT-B/32 embedding width
CODE_SIZE = 512


class ToRGB(nn.Module):
    """Repeat the channel axis three times, e.g. ``(B, 1, H, W)`` -> ``(B, 3, H, W)``."""

    def forward(self, x):
        return x.repeat(1, 3, 1, 1)


class ToDevice(nn.Module):
    """Convert to float16 and move to the current CUDA device."""

    def forward(self, x):
        return x.half().cuda()


class FromDevice(nn.Module):
    """Bring results back to the CPU as float32."""

    def forward(self, x):
        return x.cpu().to(torch.float32)


clip, preproc = cliplib.load("ViT-B/32")

# CLIP image encoder for single-channel IFS renders:
# resize -> 3 channels -> last CLIP preprocessing step -> GPU/fp16 -> visual tower -> CPU/fp32
encoder = nn.Sequential(
    VT.Resize((224, 224), VF.InterpolationMode.BICUBIC),
    ToRGB(),
    preproc.transforms[-1],  # presumably CLIP's Normalize step; verify against cliplib's preprocess
    ToDevice(),
    clip.visual,
    FromDevice(),
)

In [None]:
# Score 1000 seeded systems against a text prompt with CLIP and show the matches.
with torch.no_grad():
    # unit-length text embedding; other prompts tried before: "leave", "bird"
    target_features = clip.encode_text(
        cliplib.tokenize(["drawing of a building"]).cuda()
    ).cpu().float()
    target_features /= target_features.norm(dim=-1, keepdim=True)

    for seed in tqdm(range(1000)):
        ifs = IFS(seed=seed)
        # cheap 32x32 preview for scoring
        image = torch.Tensor(ifs.render_image((32, 32), 1000, alpha=.5))
        features = encoder(image.unsqueeze(0))
        features /= features.norm(dim=-1, keepdim=True)
        # cosine similarity against every prompt
        dots = features @ target_features.T
        if not torch.any(dots > .23):
            continue
        # show a larger render of a matching system
        image = torch.Tensor(ifs.render_image((128, 128), 10000, alpha=.2))
        print(dots)
        display(VF.to_pil_image(image))
            

# dataset generation

In [None]:
class IFSClassIterableDataset(IterableDataset):
    """Iterable dataset of IFS renders grouped into classes.

    Yields ``(image, class_seed)`` pairs: ``num_instances_per_class`` consecutive
    images for each of ``num_classes`` accepted class seeds. Candidate class seeds
    are ``class_index ^ seed``. When ``image_filter`` is set, a candidate is kept only
    if the filter accepts a preview render (``filter_*`` settings override the
    regular ones). The first instance of a class renders the unmodified system.
    Later instances jitter the map parameters, alpha, patch size and iteration
    count, and are rejected if much darker than the first one.

    Every pass starts from a freshly seeded variation RNG, so iterating twice
    (e.g. preview, then store) yields the same data.
    """
    def __init__(
        self,
        shape: Tuple[int, int],
        num_classes: int,
        num_instances_per_class: int = 1,
        num_iterations: int = 10_000,
        seed: Optional[int] = None,
        alpha: float = 0.1,
        patch_size: int = 1,
        parameter_variation: float = 0.03,
        parameter_variation_max: Optional[float] = None,
        alpha_variation: float = 0.05,
        patch_size_variations: Optional[Iterable[int]] = None,
        num_iterations_variation: int = 0,
        image_filter: Optional[ImageFilter] = None,
        filter_num_iterations: Optional[int] = None,
        filter_shape: Optional[Tuple[int, int]] = None,
        filter_alpha: Optional[float] = None,
    ):
        # Resolve the seed once, so class seeds and the variation RNG agree.
        # `random.randint` needs int bounds: a float `1e10` raises TypeError on Python >= 3.12.
        if seed is None:
            seed = random.randint(0, int(1e10))

        self.shape = shape
        self.num_classes = num_classes
        self.num_instances_per_class = num_instances_per_class
        self.num_iterations = num_iterations
        self.seed = seed
        self.alpha = alpha
        self.patch_size = patch_size
        self.parameter_variation = parameter_variation
        self.parameter_variation_max = parameter_variation_max
        self.alpha_variation = alpha_variation
        self.patch_size_variations = list(patch_size_variations) if patch_size_variations is not None else None
        self.num_iterations_variation = num_iterations_variation

        self.image_filter = image_filter
        self.filter_num_iterations = filter_num_iterations
        self.filter_shape = filter_shape
        self.filter_alpha = filter_alpha

        # kept for backward compatibility; __iter__ starts each pass from a fresh generator
        self.rng = self._make_rng()

    def _make_rng(self) -> np.random.Generator:
        """Return a new variation RNG seeded from ``self.seed`` (42 warm-up bytes discarded)."""
        rng = np.random.Generator(np.random.MT19937(self.seed))
        rng.bytes(42)
        return rng

    def __len__(self) -> int:
        return self.num_classes * self.num_instances_per_class

    def _iter_class_seeds(self) -> Generator[int, None, None]:
        """Yield ``num_classes`` class seeds that pass ``image_filter`` (all pass without a filter)."""
        class_index = 0
        class_count = 0
        while class_count < self.num_classes:
            seed = class_index ^ self.seed

            ifs = IFS(seed=seed)
            class_index += 1

            if self.image_filter is not None:
                # preview render used only for filtering
                image = torch.Tensor(ifs.render_image(
                    shape=self.filter_shape or self.shape,
                    num_iterations=self.filter_num_iterations or self.num_iterations,
                    alpha=self.filter_alpha or self.alpha,
                    patch_size=self.patch_size,
                ))
                if not self.image_filter(image):
                    continue

            yield seed
            class_count += 1

    def __iter__(self) -> Generator[Tuple[torch.Tensor, int], None, None]:
        rng = self._make_rng()
        for seed in self._iter_class_seeds():
            instance_count = 0
            base_mean = None
            while instance_count < self.num_instances_per_class:
                # fresh system per attempt: seeded (unperturbed) parameters, same orbit RNG
                ifs = IFS(seed=seed)

                alpha = self.alpha
                patch_size = self.patch_size
                num_iterations = self.num_iterations

                if instance_count > 0:
                    # ramp factor from parameter_variation towards parameter_variation_max
                    # NOTE(review): t ranges over 2/n .. 1 for instances 1..n-1 — confirm intent
                    t = (instance_count + 1) / self.num_instances_per_class

                    amt = self.parameter_variation
                    if self.parameter_variation_max is not None:
                        amt = amt * (1. - t) + t * self.parameter_variation_max

                    ifs.parameters += amt * rng.uniform(-1., 1., ifs.parameters.shape)
                    alpha = max(.001, alpha + self.alpha_variation * rng.uniform(-1., 1.))
                    if self.patch_size_variations is not None:
                        patch_size = self.patch_size_variations[rng.integers(len(self.patch_size_variations))]
                    if self.num_iterations_variation:
                        num_iterations += rng.integers(self.num_iterations_variation)

                image = torch.Tensor(ifs.render_image(
                    shape=self.shape,
                    num_iterations=num_iterations,
                    alpha=alpha,
                    patch_size=patch_size,
                ))

                if base_mean is None:
                    # the first (unperturbed) instance is the brightness reference
                    base_mean = image.mean()
                elif image.mean() < base_mean / 1.5:
                    # reject variants much darker than the reference and retry
                    # NOTE(review): no retry limit; a class whose variants are always dark loops forever
                    continue

                yield image, seed
                instance_count += 1
                
# Preview: 32 classes x 32 instances of 128x128 renders; every grid row is one class.
ds = IFSClassIterableDataset(
    seed=int(1e6),
    num_classes=32,
    num_instances_per_class=32,
    shape=(128, 128),
    num_iterations=10_000,
    alpha=.15,
    # per-instance variation
    parameter_variation=0.05,
    parameter_variation_max=0.09,
    alpha_variation=0.12,
    patch_size_variations=[1, 1, 1, 3, 3, 5],
    num_iterations_variation=10_000,
    # class filter, evaluated on a cheap 32x32, 1000-point preview
    image_filter=ImageFilter(min_mean=0.2, max_mean=0.27),
    filter_shape=(32, 32),
    filter_num_iterations=1000,
    filter_alpha=1.,
)
plot_samples(
    ds,
    total=len(ds),
    nrow=32,
    label=lambda entry: str(entry[1]),  # tile label = class seed
)

In [None]:
# Size estimate in MiB: 128*128 bytes per uint8 image, times 16 * 1000 images
128 * 128 * 16 * 1000 / (1024 * 1024)

In [None]:
# High-resolution 512x512 render of seed 955 with 400k orbit points.
VF.to_pil_image(
    torch.Tensor(
        IFS(seed=955).render_image((512, 512), 400_000, alpha=0.1)
    )
)

In [None]:
# Eight instances of a single class. The first candidate class seed is 0 ^ 955 = 955,
# used if it passes the filter.
ds = IFSClassIterableDataset(
    seed=955,
    num_classes=1,
    num_instances_per_class=8,
    shape=(128, 128),
    num_iterations=10_000,
    alpha=.2,
    image_filter=ImageFilter(min_mean=0.2, max_mean=0.27),
    filter_shape=(32, 32),
    filter_num_iterations=1000,
    filter_alpha=1.,
)
plot_samples(
    ds,
    total=len(ds),
    nrow=8,
    label=lambda entry: str(entry[1]),  # tile label = class seed
)

# store dataset

In [None]:
# Dataset to store: 100 filtered classes x 8 instances of 128x128 renders.
dataset = IFSClassIterableDataset(
    seed=3746385,
    num_classes=100,
    num_instances_per_class=8,
    shape=(128, 128),
    num_iterations=30_000,
    alpha=.15,
    # per-instance variation
    parameter_variation=0.05,
    alpha_variation=0.12,
    patch_size_variations=[1, 1, 1, 3, 3, 5],
    num_iterations_variation=50_000,
    # class filter, evaluated on a cheap 32x32, 1000-point preview
    image_filter=ImageFilter(min_mean=0.2, max_mean=0.27),
    filter_shape=(32, 32),
    filter_num_iterations=1000,
    filter_alpha=1.,
)
plot_samples(dataset, total=32)

In [None]:
# Output path prefix (no extension): 1 channel, 128x128, uint8, 100 classes x 8 instances
dataset_name = "../datasets/ifs-1x128x128-uint8-100x8"

def store_dataset(
        images: Iterable,
        output_filename,
        max_megabyte=4096,
):
    """Quantize ``(image, label)`` pairs to uint8 and save images and labels as tensors.

    Writes ``{output_filename}.pt`` (uint8 images concatenated along dim 0) and
    ``{output_filename}-labels.pt`` (int64 labels, one per image).

    :param images: iterable of ``(image, label)`` pairs. Images are float tensors
        expected in [0, 1] (values are clamped); images with fewer than 4 dims get
        a leading batch dimension. Labels must be convertible with ``int()``.
    :param output_filename: path prefix without extension; missing parent
        directories are created.
    :param max_megabyte: stop once this many MiB of uint8 image data were collected.

    Ctrl-C (KeyboardInterrupt) ends collection early; what was gathered so far is
    still saved. Nothing is written when no image was collected.
    """
    tensor_batch = []
    label_batch = []
    tensor_size = 0
    last_print_size = 0
    max_bytes = max_megabyte * 1024 * 1024
    try:
        for image, label in tqdm(images):
            image = (image.clamp(0, 1) * 255).to(torch.uint8)
            if image.ndim < 4:
                image = image.unsqueeze(0)
            # build int64 directly: a float32 detour (torch.Tensor) rounds labels
            # above 2**24, such as large IFS seeds
            label_tensor = torch.tensor([int(label)], dtype=torch.int64)
            tensor_batch.append(image)
            label_batch.append(label_tensor)

            # uint8 -> one byte per element
            tensor_size += math.prod(image.shape)
            if tensor_size - last_print_size > 1024 * 1024 * 100:
                last_print_size = tensor_size
                print(f"size: {tensor_size:,}")

            if tensor_size >= max_bytes:
                break
    except KeyboardInterrupt:
        pass

    # an interrupt between the two appends could leave one unlabeled image
    count = min(len(tensor_batch), len(label_batch))
    if count == 0:
        print("store_dataset: nothing to store")
        return

    Path(f"{output_filename}.pt").parent.mkdir(parents=True, exist_ok=True)
    torch.save(torch.cat(tensor_batch[:count]), f"{output_filename}.pt")
    torch.save(torch.cat(label_batch[:count]), f"{output_filename}-labels.pt")

store_dataset(images=dataset, output_filename=dataset_name)

In [None]:
# Reload a stored dataset (images + labels) and inspect its first sample.
# NOTE(review): labels are large seeds (> 2**24); files written through a float32
# label conversion may contain rounded labels — verify.
dataset_name = "../datasets/ifs-1x128x128-uint8-200x32-seed3482374923"
ds = TensorDataset(*(
    torch.load(f"{dataset_name}{suffix}.pt")
    for suffix in ("", "-labels")
))
print("label:", ds[0][1])
VF.to_pil_image(ds[0][0])

In [None]:
# Stored samples in order; the dataset name suggests 200 classes x 32 instances (one class per row).
plot_samples(
    DataLoader(ds, shuffle=False),
    total=32 * 200,
    nrow=32,
    label=lambda entry: int(entry[1]),
)

In [None]:
# STORE BIG ONE: a shuffled 64x64 grid of stored samples, returned as a PIL image.
grid = plot_samples(
    DataLoader(ds, shuffle=True),
    total=64 * 64,
    nrow=64,
    label=lambda entry: int(entry[1]),
    return_image=True,
)
# grid.save("ifs-database-shuffled.png")  # pick an output path before saving

In [None]:
# Labels of the stored 200x32 dataset
labels = torch.load("../datasets/ifs-1x128x128-uint8-200x32-seed3482374923-labels.pt")

In [None]:
# Distinct class labels (seeds) present in the stored dataset
{*labels.tolist()}