In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util.image import * 
from src.util import ImageFilter
from src.algo import Space2d, IFS
from src.datasets.generative import *

In [None]:
def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect images from ``iterable`` and show them as a single grid image.

    Each entry may be an image tensor or a list/tuple whose first element is
    the image. A 4-dimensional image has ``squeeze(0)`` applied before it is
    collected. Iteration stops once ``total`` images have been gathered, when
    the iterable is exhausted, or on a KeyboardInterrupt (whatever has been
    collected so far is still plotted).

    :param iterable: source of entries
    :param total: number of collected images after which iteration stops;
        also passed to tqdm as the progress-bar total
    :param nrow: forwarded as ``nrow`` to the grid function
    :param return_image: if True, return the PIL image instead of displaying it
    :param show_compression_ratio: if True, caption each image with
        ``ImageFilter().calc_compression_ratio(image)`` rounded to 3 digits;
        takes precedence over ``label``
    :param label: optional callable that receives the whole entry (not just
        the image) and returns its caption
    :return: the PIL image if ``return_image`` is True, otherwise None
    """
    compression_filter = ImageFilter()
    collected, captions = [], []

    try:
        for item in tqdm(iterable, total=total):
            # (image, ...) entries carry the image in the first position
            picture = item[0] if isinstance(item, (list, tuple)) else item
            if picture.ndim == 4:
                picture = picture.squeeze(0)
            collected.append(picture)

            if show_compression_ratio:
                captions.append(round(compression_filter.calc_compression_ratio(picture), 3))
            elif label is not None:
                captions.append(label(item))

            if len(collected) >= total:
                break
    except KeyboardInterrupt:
        # interrupting the kernel stops collecting; the partial result is plotted
        pass

    # captions are only present when one of the two labelling modes was active
    if captions:
        grid = make_grid_labeled(collected, nrow=nrow, labels=captions)
    else:
        grid = make_grid(collected, nrow=nrow)
    grid_image = VF.to_pil_image(grid)

    if not return_image:
        display(grid_image)
        return None
    return grid_image

# base dataset

In [None]:
# Base IFS dataset; every entry is plotted with its second element as caption.
dataset = IFSDataset(
    shape=(128, 128),
    num_iterations=2000,
    alpha=.5,
    num_parameters=3,
    patch_size=3,
    num_classes=1000,
    start_seed=0,
)
plot_samples(dataset, label=lambda entry: entry[1])

# CLIP guidance

In [None]:
import clip as cliplib
# Presumably the width of the CLIP ViT-B/32 embedding — TODO confirm.
# NOTE(review): this constant is not referenced anywhere else in the visible notebook.
CODE_SIZE = 512
class ToRGB(nn.Module):
    """Tile the input three times along dim 1 (a 1-channel batch becomes 3-channel)."""

    def forward(self, x):
        # repeat counts per dimension: only dim 1 is tiled
        repeats = (1, 3, 1, 1)
        return x.repeat(*repeats)
class ToDevice(nn.Module):
    """Convert the input to half precision and move it to the current CUDA device."""

    def forward(self, x):
        half_precision = x.half()
        return half_precision.cuda()
class FromDevice(nn.Module):
    """Move the input to the CPU and convert it to float32."""

    def forward(self, x):
        on_cpu = x.cpu()
        return on_cpu.float()
# Load the CLIP "ViT-B/32" model together with its preprocessing pipeline.
clip, preproc = cliplib.load("ViT-B/32")

# Image batch -> CLIP visual features, as one sequential module.
encoder = nn.Sequential(
    # bicubic resize to 224x224
    VT.Resize(size=(224, 224), interpolation=VF.InterpolationMode.BICUBIC),
    # tile dim 1 three times
    ToRGB(),
    # last step of CLIP's own preprocessing (presumably the Normalize transform — TODO confirm)
    preproc.transforms[-1],
    # half precision + CUDA, run CLIP's visual tower, then back to float32 on the CPU
    ToDevice(),
    clip.visual,
    FromDevice(),
)

In [None]:
@torch.no_grad()
def encode_texts(texts: List[str]) -> torch.Tensor:
    """
    Encode ``texts`` with the CLIP text encoder.

    The texts are tokenized, encoded on the CUDA device, moved back to the CPU
    as float32 and divided by their norm along the last dimension.

    :param texts: list of strings to encode
    :return: float32 CPU tensor with one unit-length feature row per text
    """
    tokens = cliplib.tokenize(texts).cuda()
    embeddings = clip.encode_text(tokens)
    embeddings = embeddings.cpu().float()
    lengths = embeddings.norm(dim=-1, keepdim=True)
    return embeddings / lengths
@torch.no_grad()
def encode_images(images: torch.Tensor) -> torch.Tensor:
    """
    Run ``images`` through the module-level ``encoder`` and normalise the result.

    :param images: image batch accepted by ``encoder``
    :return: encoder output divided by its norm along the last dimension
    """
    embeddings = encoder(images)
    lengths = embeddings.norm(dim=-1, keepdim=True)
    return embeddings / lengths
        

In [None]:
def get_dataset_features(
        ds: IFSClassIterableDataset,
        batch_size: int = 4,
) -> Tuple[torch.Tensor, list]:
    """
    Encode every image of ``ds`` with ``encode_images`` and collect the labels.

    The dataset is expected to yield ``(image, label)`` pairs. It is read in
    batches through a DataLoader. A KeyboardInterrupt stops the iteration
    early and the features gathered so far are returned.

    :param ds: dataset yielding ``(image, label)`` pairs
    :param batch_size: number of images encoded per step (default 4)
    :return: tuple of
        - the features of all processed batches concatenated along dim 0
        - a flat list with the label of each processed image
          (``.tolist()`` of the batched labels)
    """
    features_list: List[torch.Tensor] = []
    seed_list: list = []
    try:
        for images, seeds in tqdm(DataLoader(ds, batch_size=batch_size)):
            features_list.append(encode_images(images))
            seed_list.extend(seeds.tolist())
    except KeyboardInterrupt:
        # deliberate best-effort: interrupting the kernel keeps the partial result
        pass
    return torch.cat(features_list), seed_list

# Encode all images of `dataset`; keep the per-image labels next to the features.
ds_features, ifs_seeds = get_dataset_features(dataset)
# Cell output: shape of the concatenated feature tensor.
ds_features.shape

In [None]:
# CLIP-guided search: render the IFS of seeds 0..999 and show those whose
# image features are similar enough to the text prompt, 8 images per grid.
with torch.no_grad():
    # Unit-length text features of the prompt.
    # Other prompts that were tried: "very detailed structures", "leave", "bird"
    target_features = encode_texts([
        "clustered",
    ])

    images = []
    labels = []
    for seed in tqdm(range(1000)):
        ifs = IFS(seed=seed)
        image = torch.Tensor(ifs.render_image((128, 128), 1000, alpha=.5))
        # Unit-length image features; the dot product below compares them
        # with the (also unit-length) text features.
        features = encode_images(image.unsqueeze(0))
        dots = features @ target_features.T
        if torch.any(dots > .22):
            # idea kept from the original cell: re-render matches at higher quality, e.g.
            #   image = torch.Tensor(ifs.render_image((128, 128), 10000, alpha=.2))
            images.append(image)
            labels.append(" ".join(str(round(float(d), 3)) for d in dots[0]))
        if len(images) >= 8:
            display(VF.to_pil_image(make_grid_labeled(images, labels=labels)))
            images.clear()
            labels.clear()

    # Show the remaining matches too; previously a final batch of fewer than
    # 8 images was never displayed.
    if images:
        display(VF.to_pil_image(make_grid_labeled(images, labels=labels)))

# dataset generation

In [None]:
# 128*128 * 16 * 1000 divided by 1024 twice (evaluates to 250.0).
# Presumably a dataset size estimate in MiB — TODO confirm what the factor 16 stands for.
128*128*16*1000 / 1024 / 1024

In [None]:
# Render the IFS with seed 955 at 512x512 and show it as a PIL image
# (second argument 400_000 is presumably the iteration count — TODO confirm).
VF.to_pil_image(torch.Tensor(IFS(seed=955).render_image((512, 512), 400_000, alpha=0.1)))

In [None]:
# Preview of a single class (seed 955) with 8 instances, restricted by an
# image filter on the mean value.
ds = IFSClassIterableDataset(
    num_classes=1,
    num_instances_per_class=8,
    seed=955,
    shape=(128, 128),
    num_iterations=10_000,
    alpha=.2,
    # low-resolution alternative: shape=(32, 32), num_iterations=1_000, alpha=1,
    image_filter=ImageFilter(
        min_mean=0.2,
        max_mean=0.27,
        # optional extra constraint: min_blurred_compression_ratio=.6,
    ),
    filter_shape=(32, 32),
    filter_num_iterations=1000,
    filter_alpha=1.,
)
plot_samples(
    ds,
    total=len(ds),
    nrow=8,
    # caption: second element of each entry (the seed)
    label=lambda entry: str(entry[1]),
    # alternative caption: label=lambda i: round(float(i[0].mean()), 2),
)

# store dataset

In [None]:
# Dataset that gets written to disk further below: 100 classes with
# 8 instances each, with per-instance variations and a mean-value filter.
dataset = IFSClassIterableDataset(
    num_classes=100,
    num_instances_per_class=8,
    seed=3746385,
    shape=(128, 128),
    num_iterations=30_000,
    alpha=.15,
    # low-resolution alternative: shape=(32, 32), num_iterations=1_000, alpha=1,
    parameter_variation=0.05,
    alpha_variation=0.12,
    patch_size_variations=[1, 1, 1, 3, 3, 5],
    num_iterations_variation=50_000,
    image_filter=ImageFilter(
        min_mean=0.2,
        max_mean=0.27,
        # optional extra constraint: min_blurred_compression_ratio=.6,
    ),
    filter_shape=(32, 32),
    filter_num_iterations=1000,
    filter_alpha=1.,
)
# quick visual check of the first 32 images
plot_samples(dataset, total=32)

In [None]:
# Path prefix (without extension) of the files written by store_dataset() below.
# NOTE(review): the name says "200x3" while the dataset configured above uses
# num_classes=100 and num_instances_per_class=8 — confirm the naming.
dataset_name = "../datasets/ifs-1x128x128-uint8-200x3"

def store_dataset(
        images: Iterable,
        output_filename,
        max_megabyte=4096,
):
    """
    Convert ``(image, label)`` pairs to uint8 / int64 tensors and save them.

    Images are clamped to [0, 1], scaled by 255 and cast to ``torch.uint8``.
    An image with fewer than 4 dimensions gets a leading dimension added so
    that all images can be concatenated along dim 0. Two files are written:

    - ``{output_filename}.pt``: the concatenated uint8 image tensor
    - ``{output_filename}-labels.pt``: the int64 label tensor, one per entry

    Iteration stops when ``images`` is exhausted, when the number of stored
    values (one byte each) reaches ``max_megabyte`` MiB, or on a
    KeyboardInterrupt. In all three cases the data gathered so far is saved.

    :param images: iterable of ``(image, label)``; ``label`` must be
        convertible with ``int()``
    :param output_filename: path prefix of the two output files
    :param max_megabyte: size limit of the image tensor in MiB
    """
    max_bytes = max_megabyte * 1024 * 1024
    report_interval = 1024 * 1024 * 100  # progress line every 100 MiB

    tensor_batch = []
    label_batch = []
    tensor_size = 0
    last_print_size = 0
    try:
        for image, label in tqdm(images):

            image = (image.clamp(0, 1) * 255).to(torch.uint8)

            if image.ndim < 4:
                image = image.unsqueeze(0)
            tensor_batch.append(image)
            # Build the int64 tensor directly. torch.Tensor([label]) creates a
            # float32 tensor first, which cannot represent integers above
            # 2**24 exactly and would silently corrupt large seeds/labels.
            label_batch.append(torch.tensor([int(label)], dtype=torch.int64))

            # uint8 -> one byte per element
            tensor_size += image.numel()

            if tensor_size - last_print_size > report_interval:
                last_print_size = tensor_size

                print(f"size: {tensor_size:,}")

            if tensor_size >= max_bytes:
                break

    except KeyboardInterrupt:
        # deliberate best-effort: an interrupted run still saves what it has
        pass

    tensor_batch = torch.cat(tensor_batch)
    torch.save(tensor_batch, f"{output_filename}.pt")
    label_batch = torch.cat(label_batch)
    torch.save(label_batch, f"{output_filename}-labels.pt")

# Iterate `dataset` and write its images and labels to "<dataset_name>.pt" / "<dataset_name>-labels.pt".
store_dataset(dataset, dataset_name)

In [None]:
# Load a stored dataset (image tensor + label tensor) back as a TensorDataset.
# NOTE(review): this path prefix differs from the one passed to store_dataset()
# above, so the files must already exist from an earlier run — confirm.
dataset_name = "../datasets/ifs-1x128x128-uint8-200x32-seed3482374923"
ds = TensorDataset(
    torch.load(f"{dataset_name}.pt"),
    torch.load(f"{dataset_name}-labels.pt"),
)
# Sanity check: print the first label and show the first image as cell output.
print("label:", ds[0][1])
VF.to_pil_image(ds[0][0])

In [None]:
# Plot up to 32*200 images in stored order, 32 per row, each captioned with its label.
plot_samples(DataLoader(ds, shuffle=False), total=32*200, nrow=32, label=lambda e: int(e[1]))

In [None]:
# Build one big labelled grid (64x64 images, shuffled order) and keep it as a
# PIL image instead of displaying it.

grid = plot_samples(DataLoader(ds, shuffle=True), total=64*64, nrow=64, label=lambda e: int(e[1]), return_image=True)
# Saving is disabled; NOTE(review): the path below is an absolute, user-specific location.
#grid.save("/home/bergi/Pictures/ifs-database-shuffled.png")

In [None]:
# Image tensor used by the plotting and PCA cells below.
# Only one source is loaded: previously the IFS file was loaded as well and
# then immediately overwritten by the next line, wasting I/O and memory.
# Switch sources by swapping which line is commented out.
#images = torch.load("../datasets/ifs-1x128x128-uint8-1000x32.pt")
images = torch.load("../datasets/kali-uint8-128x128.pt")[:10_000]
#labels = torch.load("../datasets/ifs-1x128x128-uint8-1000x32-labels.pt")
images.shape

In [None]:
# Alternative without shuffling: iterate the tensor directly.
#plot_samples(images, nrow=32, total=32*32)
# Show 16x16 randomly ordered images from the loaded tensor.
plot_samples(DataLoader(TensorDataset(images), shuffle=True), nrow=16, total=16*16)

# show PCA weights

In [None]:
from sklearn.decomposition import PCA, KernelPCA

# Fit a 256-component PCA; every image is flattened into one row.
pca = PCA(n_components=256)
pca.fit(images.numpy().reshape(len(images), -1))

In [None]:
# Show the PCA components as images: rescale all of them jointly to [0, 1]
# (global min/max) and give each one the shape of a single image.
weights = torch.Tensor(pca.components_)
mi = weights.min()
ma = weights.max()
weights = weights.sub(mi).div(ma - mi)
VF.to_pil_image(make_grid(
    list(weights.reshape(len(weights), *images[0].shape))
))

In [None]:
# Show one 100x100 tile per component, built from `pca.eigenvectors_`.
# NOTE(review): `eigenvectors_` is an attribute of KernelPCA, but `pca` is
# fitted as a plain `PCA(256)` earlier in this notebook, so this cell looks
# like a leftover from a KernelPCA experiment — confirm before running.
# The 100x100 reshape presumably matches the 10_000 loaded images — TODO confirm.
weights = torch.Tensor(pca.eigenvectors_).permute(1, 0).reshape(pca.n_components, 100, 100)
VF.to_pil_image(make_grid(
    [w.unsqueeze(0) for w in weights]
))

In [None]:
# Scratch tensors: `a` is drawn from a standard normal distribution,
# `b` uniformly from [0, 1); both have shape (3, 4).
a = torch.randn((3, 4))
b = torch.rand((3, 4))

In [None]:
# Scratch experiment: KL divergence between a constant input of 11 and a constant target of 1.
# NOTE(review): F.kl_div is documented to expect its first argument in log-space;
# raw values are passed here — confirm this is intended.
F.kl_div(torch.ones(3, 4)+10, torch.ones(3, 4))

In [None]:
# Scratch: natural logarithm of 0.3.
torch.log(torch.tensor(.3))