In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.algo import *

## Kali set iteration: `p = abs(p) / dot(p, p) - v`

In [None]:
# Render one 128x128 Kali image and upscale it 4x with nearest-neighbour
# sampling so that individual pixels stay visible.
# NOTE(review): `out_weights` comes from an unseeded `torch.rand`, so the
# colours differ on every run.
img = kali_2d(
    Space2d(
        shape=(3, 128, 128),
        offset=torch.tensor([0.1, 0.0, 0.0]),
        scale=0.01,
    ),
    param=torch.tensor([0.75, 0.75, 0.75]),
    iterations=21,
    out_weights=torch.rand(3, 3),
    accumulate="min",
    aa=10,
)
# Swap in VF.InterpolationMode.BICUBIC for a smooth upscale instead.
img = VF.resize(img, [512, 512], interpolation=VF.InterpolationMode.NEAREST)
VF.to_pil_image(img)

In [None]:
def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Union[None, bool, Callable] = None,
):
    """
    Collect up to `total` images from `iterable` and show them as one grid.

    Iteration can be aborted with a keyboard interrupt; whatever has been
    collected up to that point is still plotted.

    :param iterable: yields the images to plot
    :param total: maximum number of images to take from `iterable`
    :param nrow: forwarded as `nrow` to the grid function
    :param return_image: if True, return the PIL image instead of displaying it
    :param show_compression_ratio: label each image with
        `ImageFilter.calc_compression_ratio` rounded to 3 digits
        (takes precedence over `label`)
    :param label: True labels each image with its running index,
        a callable labels it with `label(image)`
    :return: the PIL image if `return_image` is True, otherwise None
    """
    ratio_filter = ImageFilter()
    collected, captions = [], []
    try:
        for index, sample in tqdm(enumerate(iterable), total=total):
            collected.append(sample)

            if show_compression_ratio:
                captions.append(round(ratio_filter.calc_compression_ratio(sample), 3))
            elif label is True:
                captions.append(index)
            elif callable(label):
                captions.append(label(sample))

            if len(collected) >= total:
                break
    except KeyboardInterrupt:
        # plot what has been gathered so far
        pass

    grid = (
        make_grid_labeled(collected, nrow=nrow, labels=captions)
        if captions
        else make_grid(collected, nrow=nrow)
    )
    picture = VF.to_pil_image(grid)
    if return_image:
        return picture
    display(picture)

# randomly generate and select

In [None]:
# Candidate pool #1: seeded, filtered Kali renders. `with_parameters=True`
# makes every item carry its render parameters at index 1, which is what
# `plot_xth` reads to re-render a chosen sample.
dataset = Kali2dFilteredIterableDataset(
    shape=(3, 128, 128),
    aa=2,
    size=8 * 8,
    seed=349833493,
    accumulation_modes=["min", "max"],
    min_iterations=17,
    min_scale=0.01,
    max_scale=2.0,
    min_offset=-2.0,
    max_offset=2.0,
    filter_shape=(64, 64),
    filter=ImageFilter(
        max_mean=0.3,
        min_blurred_compression_ratio=0.32,
    ),
    with_parameters=True,
)
images_and_params = list(tqdm(dataset))
# `label=True` prints each sample's index; `plot_samples` stops at its
# default `total=32`, so at most the first 32 entries are shown.
plot_samples([entry[0] for entry in images_and_params], label=True)

In [None]:
def plot_xth(x, shape=None, aa=None, pil=True, **kwargs):
    """
    Re-render the `x`-th entry of the global `images_and_params` list.

    The stored parameters of that entry are printed, then optionally
    overridden and passed to `Kali2dDataset.render`.

    :param x: index into `images_and_params`
    :param shape: if truthy, replaces the stored "shape" parameter
    :param aa: if truthy, replaces the stored "aa" parameter
    :param pil: True returns a PIL image, False returns the render result as is
    :param kwargs: further parameter overrides, applied after `shape` and `aa`
    :return: PIL image if `pil` is True, otherwise the output of
        `Kali2dDataset.render`
    """
    # copy, so the overrides below do not modify the stored parameters
    params = {
        **images_and_params[x][1],
    }
    print(params)
    if shape:
        params["shape"] = shape
    if aa:
        params["aa"] = aa
    params.update(kwargs)
    image = Kali2dDataset.render(params)
    if pil:
        return VF.to_pil_image(image)
    # previously this path fell off the end and returned None
    return image
    
# Re-render pool entry 9 at 1024x1024 with aa=4.
# Further overrides can be passed as keywords, e.g. iterations=40.
img = plot_xth(9, shape=(3, 1024, 1024), aa=4)
img  # .save("../db/images/kali/kali01.png")

# dataset #2

In [None]:
# Candidate pool #2. Rebinds `dataset` and `images_and_params`, so
# `plot_xth` indexes into this pool from here on.
dataset = Kali2dFilteredIterableDataset(
    shape=(3, 128, 128),
    aa=2,
    size=8 * 16,
    seed=998,
    min_iterations=8,
    max_iterations=37,
    min_offset=0,
    max_offset=1.0,
    filter_shape=(16, 16),
    filter=ImageFilter(
        max_mean=0.4,
        min_compression_ratio=0.7,
        min_blurred_compression_ratio=0.4,
        blurred_compression_kernel_size=[15, 15],
    ),
    with_parameters=True,
)
images_and_params = list(tqdm(dataset))
plot_samples(
    [entry[0] for entry in images_and_params],
    label=True,
    total=len(images_and_params),
)

In [None]:
# Re-render pool entry 34 at 1024x1024 with aa=4.
# Further overrides can be passed as keywords, e.g. iterations=25 or scale=1.
img = plot_xth(34, shape=(3, 1024, 1024), aa=4)
img

In [None]:
# Write the image currently bound to `img` to the Kali image folder.
# The file name is picked by hand; an existing file of that name is overwritten.
img.save("../db/images/kali/kali17.png")

In [None]:
# List the folder that the `img.save` cell writes into.
!ls -l ../db/images/kali/

# dataset #3

In [None]:
# NOTE(review): SHAPE and SEED are not defined in any cell visible in this
# notebook; they have to exist in the kernel before this cell can run.
ds_iter_3 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=4,
    size=1_000_000_000,
    seed=777 + SEED,
    filter_shape=(12, 12),
    filter=ImageFilter(
        min_mean=0.05,
        max_mean=0.4,
        min_compression_ratio=0.9,
    ),
)
plot_samples(ds_iter_3, total=16 * 4)

# dataset #4

In [None]:
# "max" accumulation only, offset pinned to 0 (min_offset == max_offset).
ds_iter_4 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=10,
    size=1_000_000_000,
    seed=45878 + SEED,
    accumulation_modes=["max"],
    min_iterations=21,
    min_scale=0.5,
    max_scale=1,
    min_offset=0,
    max_offset=0,
    filter_shape=(16, 16),
    filter=ImageFilter(
        min_mean=0.05,
        max_mean=0.4,
        min_blurred_compression_ratio=0.45,
    ),
)
plot_samples(ds_iter_4, total=16 * 4)

# dataset #5 !!

In [None]:
# "min" accumulation only, scale pinned to 0.1, offsets in [-0.2, 0.2].
ds_iter_5 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=8,
    size=1_000_000_000,
    seed=339595 + SEED,
    accumulation_modes=["min"],
    min_iterations=21,
    min_scale=0.1,
    max_scale=0.1,
    min_offset=-0.2,
    max_offset=0.2,
    filter_shape=(16, 16),
    filter=ImageFilter(
        min_mean=0.05,
        max_mean=0.4,
        min_blurred_compression_ratio=0.5,
    ),
)
plot_samples(ds_iter_5, total=16 * 4)

# dataset #6

In [None]:
ds_iter_6 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=4,
    size=1_000_000_000,
    seed=77 + SEED,
    accumulation_modes=["min"],
    min_iterations=8,
    max_iterations=19,
    min_scale=0.1,
    max_scale=0.1,
    # NOTE(review): this was spelled `--.2`, which evaluates to +0.2 and pins
    # the offset to 0.2. If the range [-0.2, 0.2] was intended (as in
    # ds_iter_5), this has to become -0.2 — confirm before changing, since it
    # alters the generated images.
    min_offset=0.2,
    max_offset=0.2,
    filter_shape=(16, 16),
    filter=ImageFilter(
        min_mean=0.15,
        max_mean=0.5,
        min_std=0.2,
        max_std=0.3,
        min_blurred_compression_ratio=0.5,
    ),
)
plot_samples(ds_iter_6, total=16 * 4)

# dataset 7

In [None]:
ds_iter_7 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=4,
    size=1_000_000_000,
    seed=77733 + SEED,
    accumulation_modes=["none"],
    min_iterations=8,
    max_iterations=19,
    min_scale=0.1,
    max_scale=0.1,
    # NOTE(review): this was spelled `--.2`, which evaluates to +0.2 and pins
    # the offset to 0.2. If the range [-0.2, 0.2] was intended (as in
    # ds_iter_5), this has to become -0.2 — confirm before changing, since it
    # alters the generated images.
    min_offset=0.2,
    max_offset=0.2,
    filter_shape=(32, 32),
    filter=ImageFilter(
        max_mean=0.5,
        max_std=0.3,
        max_compression_ratio=0.95,
        min_blurred_compression_ratio=0.3,
    ),
)
plot_samples(ds_iter_7, total=16 * 4)

# dataset 8 !!

In [None]:
# Offset pinned to 1 (min_offset == max_offset); accumulation modes are left
# at the dataset's default.
ds_iter_8 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=4,
    size=1_000_000_000,
    seed=7445977 + SEED,
    min_iterations=8,
    max_iterations=19,
    min_scale=0.05,
    max_scale=0.1,
    min_offset=1,
    max_offset=1.0,
    filter_shape=(16, 16),
    filter=ImageFilter(
        max_mean=0.4,
        min_compression_ratio=0.7,
        min_blurred_compression_ratio=0.4,
        blurred_compression_kernel_size=[15, 15],
    ),
)
plot_samples(ds_iter_8, total=16 * 4)

# dataset 9

In [None]:
# Same scale/offset/iteration ranges as ds_iter_8, restricted to "min"/"max"
# accumulation and without the plain compression-ratio constraint.
# The seed expression (777 + SEED) is the same one ds_iter_3 uses.
ds_iter_9 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=4,
    size=1_000_000_000,
    seed=777 + SEED,
    accumulation_modes=["min", "max"],
    min_iterations=8,
    max_iterations=19,
    min_scale=0.05,
    max_scale=0.1,
    min_offset=1,
    max_offset=1.0,
    filter_shape=(16, 16),
    filter=ImageFilter(
        max_mean=0.4,
        min_blurred_compression_ratio=0.45,
        blurred_compression_kernel_size=[15, 15],
    ),
)
plot_samples(ds_iter_9, total=16 * 4)

# dataset 10

In [None]:
# Small scales (0.001 .. 0.01) with 21 .. 51 iterations; the filter uses
# "png" as its compression format.
ds_iter_10 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=4,
    size=1_000_000_000,
    seed=7696333 + SEED,
    accumulation_modes=["min", "max"],
    min_iterations=21,
    max_iterations=51,
    min_scale=0.001,
    max_scale=0.01,
    min_offset=0.5,
    max_offset=2.0,
    filter_shape=(16, 16),
    filter=ImageFilter(
        max_mean=0.4,
        min_compression_ratio=0.9,
        min_blurred_compression_ratio=0.4,
        blurred_compression_sigma=10.0,
        blurred_compression_kernel_size=[15, 15],
        compression_format="png",
    ),
)
plot_samples(ds_iter_10)

# dataset 11

In [None]:
# "min" accumulation only, offset pinned to 0, scales in [0.5, 1].
ds_iter_11 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=8,
    size=1_000_000_000,
    seed=93230 + SEED,
    accumulation_modes=["min"],
    min_iterations=21,
    min_scale=0.5,
    max_scale=1,
    min_offset=0,
    max_offset=0,
    filter_shape=(16, 16),
    filter=ImageFilter(
        max_mean=0.4,
        min_compression_ratio=0.8,
        min_blurred_compression_ratio=0.25,
        blurred_compression_sigma=10.0,
        blurred_compression_kernel_size=[15, 15],
    ),
)
plot_samples(ds_iter_11)

# dataset 12

In [None]:
# Offset pinned to 0, scales in [0.5, 0.6], at least 31 iterations;
# accumulation modes are left at the dataset's default.
ds_iter_12 = Kali2dFilteredIterableDataset(
    SHAPE,
    aa=4,
    size=1_000_000_000,
    seed=93237880 + SEED,
    min_iterations=31,
    min_scale=0.5,
    max_scale=0.6,
    min_offset=0,
    max_offset=0,
    filter_shape=(16, 16),
    filter=ImageFilter(
        max_mean=0.4,
        max_std=0.3,
        min_compression_ratio=0.8,
        min_blurred_compression_ratio=0.25,
        blurred_compression_sigma=10.0,
        blurred_compression_kernel_size=[15, 15],
    ),
)
plot_samples(ds_iter_12)

In [None]:
# Combine the individual filtered datasets into one stream
# (11 datasets, 11 matching entries in `counts`).
# NOTE(review): ds_iter_1 and ds_iter_2 are not defined in any visible cell —
# the two earlier pool cells both assign to `dataset` — so this cell fails on
# a fresh kernel. Confirm which datasets were meant.
# NOTE(review): ds_iter_12 is defined above but not part of the mix —
# confirm whether that is intentional.
from src.datasets.interleave import InterleaveIterableDataset
interleaved_dataset = InterleaveIterableDataset(
    datasets=[
        ds_iter_1, ds_iter_2, ds_iter_3, ds_iter_4, ds_iter_5, ds_iter_6, ds_iter_7, ds_iter_8, ds_iter_9, ds_iter_10, ds_iter_11
    ],
    counts=[1, 1, 1, 1, 2, 1, 1, 4, 1, 1, 1],
    shuffle_datasets=True,
)
plot_samples(interleaved_dataset, total=16*4, nrow=16)

## store 64x64 samples image

In [None]:
# Collect 64*64 samples from the interleaved stream into one 64-column grid
# image and write it to disk.
# NOTE(review): the target is an absolute, user-specific path; this cell only
# works on a machine where /home/bergi/Pictures exists.
img = plot_samples(interleaved_dataset, total=64*64, nrow=64, return_image=True)
img.save("/home/bergi/Pictures/kali-special.png")
img.size

In [None]:
# Load a stored sample sheet and divide its last tensor dimension by 64.
# NOTE(review): this opens "kali-interleaved.png", whereas the cell before
# saves "kali-special.png" — confirm which file is meant.
img = PIL.Image.open("/home/bergi/Pictures/kali-interleaved.png")
t = VF.pil_to_tensor(img)
t.shape[2]/64#reshape(3, 68*68, -1).shape 

In [None]:
def store_dataset(
        images: Iterable,
        output_filename=f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt",
        max_megabyte=2_048,
):
    """
    Convert images to uint8 and store them as a single tensor file.

    Each image is clamped to [0, 1], scaled by 255 and cast to `torch.uint8`.
    Images with fewer than 4 dimensions get a leading batch dimension before
    all of them are concatenated along dimension 0 and written with
    `torch.save`.

    Collection stops when `images` is exhausted, when the accumulated size
    reaches `max_megabyte`, or on a keyboard interrupt. In all three cases the
    images collected so far are stored.

    :param images: iterable of image tensors
    :param output_filename: target file; its parent directory is created
        if it does not exist. Note that the default is evaluated once, when
        this function is defined, from the global `SHAPE`.
    :param max_megabyte: size limit of the stored tensor in MiB
    :raises ValueError: if not a single image was collected
    """
    # Create the target directory up front so that a missing folder does not
    # surface only after the (potentially very long) generation loop.
    Path(output_filename).parent.mkdir(parents=True, exist_ok=True)

    max_bytes = max_megabyte * 1024 * 1024
    report_every_bytes = 100 * 1024 * 1024

    tensor_batch = []
    tensor_size = 0
    last_print_size = 0
    try:
        for image in tqdm(images):

            image = (image.clamp(0, 1) * 255).to(torch.uint8)

            if image.ndim < 4:
                image = image.unsqueeze(0)
            tensor_batch.append(image)
            # uint8 uses one byte per element, so element count == byte count
            tensor_size += image.numel()

            if tensor_size - last_print_size > report_every_bytes:
                last_print_size = tensor_size

                print(f"size: {tensor_size:,}")

            if tensor_size >= max_bytes:
                break

    except KeyboardInterrupt:
        # store what has been collected so far
        pass

    if not tensor_batch:
        # torch.cat refuses an empty list with a far less helpful message
        raise ValueError(
            f"No images were collected, nothing to store in '{output_filename}'"
        )

    torch.save(torch.cat(tensor_batch), output_filename)

# Write the interleaved stream to disk with store_dataset's defaults
# (default output path, 2_048 megabyte limit).
store_dataset(interleaved_dataset)

In [None]:
# Load the tensor file from the path that is also store_dataset's default
# output and wrap it in a TensorDataset.
# NOTE(review): the path is rebuilt from the global SHAPE, so SHAPE must still
# have the value it had when the file was written.
ds = TensorDataset(torch.load(f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt"))

In [None]:
# Show one shuffled batch of 64 stored images as an 8x8 grid.
dl = DataLoader(ds, shuffle=True, batch_size=8 * 8)
# Each TensorDataset batch is a one-element sequence; unpack the image tensor.
for (batch,) in dl:
    img = VF.to_pil_image(make_grid(batch, nrow=8))
    break
img

In [None]:
# Item 111 of Kali2dDataset at increasing resolutions; only the largest
# render passes an explicit `aa`.
for resolution, extra_kwargs in ((16, {}), (32, {}), (64, {}), (256, {"aa": 10})):
    display(VF.to_pil_image(
        Kali2dDataset((3, resolution, resolution), **extra_kwargs)[111]
    ))

In [None]:
# Fetch item 221 of a 128x128 Kali2dDataset and show it.
# `img` stays bound to the tensor (not the PIL image) and is the input of the
# `pca_error(img)` call further down.
img = Kali2dDataset((3, 128, 128))[221]
VF.to_pil_image(img)

In [None]:
from sklearn.decomposition import PCA
def pca_error(img: torch.Tensor, n_components: int = 1) -> PIL.Image.Image:
    """
    Reconstruct an image from a row-wise PCA and return the reconstruction.

    Every image row, flattened to `width * channels` values, is one PCA
    sample. The rows are projected onto `n_components` components, projected
    back and reassembled into an image.

    Despite the name, this currently returns the reconstructed image and not
    an error value. The mean absolute difference `(img - r_img).abs().mean()`
    or a normalized difference image are possible alternatives.

    :param img: image tensor of shape (channels, height, width)
    :param n_components: number of PCA components to keep
    :return: the PCA reconstruction as PIL image
    """
    channels, height, width = img.shape
    pca = PCA(n_components=n_components)
    # rows become samples, (width, channels) is flattened into the features
    data = img.permute(1, 2, 0).reshape(height, -1)
    pca.fit(data)
    projected = pca.transform(data)
    r_data = torch.Tensor(pca.inverse_transform(projected))
    r_img = r_data.reshape(height, width, channels).permute(2, 0, 1)
    # The reconstruction can leave [0, 1]; without clamping such values wrap
    # around when to_pil_image converts the float tensor to bytes.
    return VF.to_pil_image(r_img.clamp(0, 1))
    
# Expects `img` to be the tensor from the `Kali2dDataset((3, 128, 128))[221]`
# cell; several other cells rebind `img` to a PIL image.
pca_error(img)