In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import Space2d, IFS
from src.datasets.generative import *
from src.models.cnn import *
from src.util.embedding import *

In [None]:
def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Union[bool, Callable, None] = None,
):
    """Collect up to ``total`` images from ``iterable`` and show them as one grid.

    Args:
        iterable: yields images, or lists/tuples whose first element is the image.
            A 4-D image is passed through ``squeeze(0)`` (which only drops the
            leading dim when it has size 1).
        total: maximum number of images to collect; also used as the tqdm total.
        nrow: number of images per grid row.
        return_image: return the PIL grid image instead of displaying it.
        show_compression_ratio: label each image with
            ``ImageFilter().calc_compression_ratio(image)`` rounded to 3 digits.
            Takes precedence over ``label``.
        label: a callable ``entry -> label`` to label each image, ``True`` to label
            each image with its running index, ``None``/``False`` for no labels.

    Returns:
        The PIL grid image when ``return_image`` is true, otherwise ``None`` (the
        grid is displayed). Also ``None`` when no image could be collected.
    """
    samples = []
    labels = []
    # The filter is only needed for compression-ratio labels.
    image_filter = ImageFilter() if show_compression_ratio else None
    try:
        for idx, entry in enumerate(tqdm(iterable, total=total)):
            image = entry[0] if isinstance(entry, (list, tuple)) else entry
            if image.ndim == 4:
                image = image.squeeze(0)
            samples.append(image)

            if image_filter is not None:
                labels.append(round(image_filter.calc_compression_ratio(image), 3))
            elif callable(label):
                labels.append(label(entry))
            elif label:
                labels.append(idx)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # Deliberate: interrupting a slow iterable still plots what was collected so far.
        pass

    if not samples:
        # make_grid cannot stack an empty list; report instead of crashing.
        print("plot_samples: no samples collected")
        return None

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow)
    grid_image = VF.to_pil_image(grid)
    if return_image:
        return grid_image
    display(grid_image)

# load model

In [None]:
from scripts.train_autoencoder import DalleAutoencoder

In [None]:
SHAPE = (1, 64, 64)
CODE_SIZE = 128
model = DalleAutoencoder(SHAPE, vocab_size=CODE_SIZE, n_hid=64, group_count=1, n_blk_per_group=1, act_fn=nn.GELU)
# map_location="cpu": a checkpoint saved from a GPU run still loads on a CPU-only machine
model.load_state_dict(torch.load("../checkpoints/ae-d3/best.pt", map_location="cpu")["state_dict"])
# This notebook only runs inference, so disable train-mode behaviour (dropout,
# batch-norm statistic updates) if the architecture has any.
model.eval();

## plot random samples

In [None]:
# Decode 8x8 random codes drawn from N(0, 1.5^2). no_grad: the output is only
# displayed, so there is no need to build an autograd graph.
with torch.no_grad():
    random_images = model.decoder(torch.randn(8 * 8, CODE_SIZE) * 1.5).clamp(0, 1)
VF.to_pil_image(make_grid(random_images))

## transition

In [None]:
# Each grid row interpolates linearly between two random codes drawn from N(0, 1.5^2).
# Column 0 is the first code and the last column is the second one. The step
# weights follow n_steps instead of a hard-coded divisor.
n_rows, n_steps = 8, 8
endpoints = torch.randn(n_rows, 2, CODE_SIZE) * 1.5
t = torch.linspace(0., 1., n_steps).view(1, n_steps, 1)
# (n_rows, 1, C) and (1, n_steps, 1) broadcast to (n_rows, n_steps, C); row-major
# flattening keeps row i / step j at index i * n_steps + j.
features = (endpoints[:, 0:1] * (1. - t) + t * endpoints[:, 1:2]).reshape(n_rows * n_steps, CODE_SIZE)

with torch.no_grad():
    transition_images = model.decoder(features).clamp(0, 1)
VF.to_pil_image(make_grid(transition_images, nrow=n_steps))

# load some image patches

In [None]:
# Alternative datasets, kept for quick switching:
# samples = torch.load("../datasets/fonts-regular-32x32.pt")[:1000]; samples = VF.resize(samples, (64, 64), antialias=True)
# samples = torch.load("../datasets/diverse-64x64-aug4.pt")[:1000]
samples = torch.load("../datasets/kali-uint8-64x64.pt")[:5000]

# uint8 -> float in [0, 1], then average over dim 1 (presumably the colour channels)
# to get single-channel images matching SHAPE's single channel.
samples = samples.float().div(255.).mean(dim=1, keepdim=True)
samples.shape

In [None]:
# Preview the first 64 patches, each labelled with its index.
plot_samples(samples, total=64, label=True)

# get embeddings

In [None]:
with torch.no_grad():
    all_features = batch_call(model.encoder, samples, verbose=True)
    all_features_norm = normalize_embedding(all_features)

# Global statistics over every embedding value (0-d tensors).
features_mean, features_std = all_features.mean(), all_features.std()
# Per-dimension statistics, reduced over axis 0 (presumably the sample axis).
features_mean0, features_std0 = all_features.mean(dim=0), all_features.std(dim=0)
print(f"embeddings mean {features_mean} std {features_std}")

In [None]:
# Explicit numpy conversion instead of relying on plotly/pandas duck-typing torch tensors.
display(px.line(
    all_features[:10].detach().cpu().numpy().T,
    title="sample embeddings",
    labels={"index": "embedding dimension", "value": "value", "variable": "sample"},
))
display(px.line(
    pd.DataFrame({
        "mean": features_mean0.detach().cpu().numpy(),
        "std": features_std0.detach().cpu().numpy(),
    }),
    title="per-dimension embedding statistics",
    labels={"index": "embedding dimension", "variable": "statistic"},
))

# random samples using per-dimension embedding statistics

In [None]:
# Random codes sampled per dimension from the measured embedding mean/std.
# no_grad: display-only output, no autograd graph needed.
with torch.no_grad():
    stat_images = model.decoder(
        torch.randn(8 * 8, CODE_SIZE) * features_std0 + features_mean0
    ).clamp(0, 1)
VF.to_pil_image(make_grid(stat_images))

# morph between two samples

In [None]:
@torch.no_grad()
def morph_images(idx1, idx2, noise: float = 0., steps: int = 8):
    """Display decoded images that interpolate between the embeddings of two samples.

    Args:
        idx1, idx2: indices into the global ``samples`` tensor.
        noise: std of gaussian noise added to every interpolated code, including
            the two endpoints.
        steps: number of images in the strip. With ``steps == 1`` only the first
            endpoint is shown.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    f1, f2 = model.encoder(samples[[idx1, idx2]])

    # One weight row per step, from 0 to 1. The code size follows the encoder output
    # instead of a hard-coded 128, so a different CODE_SIZE still works.
    t = torch.linspace(0., 1., steps).unsqueeze(1)
    features = f1 * (1. - t) + t * f2
    features = features + torch.randn_like(features) * noise

    display(VF.to_pil_image(make_grid(model.decoder(features).clamp(0, 1))))


morph_images(0, 59, noise=0.5)

# randomize embedding dimensions step by step

In [None]:
@torch.no_grad()
def randomize_image(idx, count=8*8, changes_per_step: int = 3):
    """Display a random walk through embedding space starting at sample ``idx``.

    The first code is the encoding of ``samples[idx]``. Every following code copies
    its predecessor and overwrites ``changes_per_step`` randomly chosen dimensions
    (the same dimension may be picked more than once) with draws from
    N(features_mean, features_std), the global embedding statistics.

    Raises:
        ValueError: if ``count < 1``.
    """
    if count < 1:
        raise ValueError(f"count must be >= 1, got {count}")
    image = samples[idx]
    features = model.encoder(image.unsqueeze(0)).repeat(count, 1)
    # Plain floats for random.gauss instead of 0-d tensors.
    mean, std = float(features_mean), float(features_std)
    code_size = features.shape[-1]

    for i in range(1, count):
        features[i] = features[i - 1]
        for _ in range(changes_per_step):
            features[i, random.randrange(code_size)] = random.gauss(mean, std)

    display(VF.to_pil_image(make_grid(model.decoder(features).clamp(0, 1))))


randomize_image(19)

# most similar samples

In [None]:
# Pairwise dot products of the normalized embeddings (cosine similarity, presuming
# normalize_embedding L2-normalizes each row).
with torch.no_grad():
    sim = torch.matmul(all_features_norm, all_features_norm.transpose(0, 1))

In [None]:
# Similarity heatmap of the first 200 samples, with a title and axis labels so it
# stands on its own.
px.imshow(
    sim[:200, :200].detach().cpu().numpy(),
    height=1300,
    title="pairwise embedding similarity (first 200 samples)",
    labels={"x": "sample", "y": "sample", "color": "similarity"},
)

In [None]:
# For each of the first 64 samples (one grid row each), the 32 most similar samples
# in descending order of similarity.
best_ids = torch.argsort(sim[:64], dim=-1, descending=True)[:, :32].flatten()
VF.to_pil_image(make_grid(samples[best_ids], nrow=32))

# PCA of features

In [None]:
import ipywidgets
from sklearn.decomposition import PCA

# 64 principal directions of the embedding space, converted to float32 tensors.
pca = PCA(n_components=8 * 8)
pca.fit(all_features)
pca_components = torch.as_tensor(pca.components_, dtype=torch.float32)
pca_variance = torch.as_tensor(pca.explained_variance_, dtype=torch.float32)
px.line(pca_components[:10].T)

In [None]:
# Decode each principal direction, scaled by a fixed factor (alternative tried:
# pca_variance.unsqueeze(1)). Clamp like every other decode cell: to_pil_image casts
# floats to uint8, and values outside [0, 1] get corrupted by that cast.
with torch.no_grad():
    pca_images = model.decoder(pca_components * 15).clamp(0, 1)
VF.to_pil_image(make_grid(pca_images))

In [None]:
@torch.no_grad()
def morph_images_pca(images, band: int = 0, count: int = 10, scale: float = 20.):
    """Display how shifting embeddings along one PCA component changes the decodes.

    Args:
        images: batch of images passed to ``model.encoder``.
        band: index into the global ``pca_components``.
        count: number of shift amounts (grid rows), evenly spaced from
            ``-scale`` to ``+scale``. With ``count == 1`` only ``-scale`` is used.
        scale: maximum shift, as a multiple of the component vector.

    Each grid row is one shift amount; each column is one input image.

    Raises:
        ValueError: if ``count < 1``.
    """
    if count < 1:
        raise ValueError(f"count must be >= 1, got {count}")
    features = model.encoder(images)
    direction = pca_components[band]

    rows = [
        model.decoder(features + direction * (t * scale)).clamp(0, 1)
        for t in torch.linspace(-1., 1., count).tolist()
    ]
    image_grid = torch.concat(rows)

    display(VF.to_pil_image(make_grid(image_grid, nrow=features.shape[0])))


morph_images_pca(samples[:20], band=52)

In [None]:
# Scratch arithmetic: one million integer-divided by 36 cubed.
1_000_000 // 36 ** 3

In [None]:
import ipywidgets

values = [0] * 10
output = ipywidgets.Output()
display(output)


def doit(index, x):
    """Slider callback: store ``x`` at position ``index`` and re-render ``values`` in ``output``."""
    values[index] = x
    output.clear_output(wait=False)
    with output:
        # Alternatives tried: output.append_display_data(values) / display(values, output)
        display(values)


# One slider per entry of ``values``. Each slider is bound to its index via ``fixed``
# and only reports when released (continuous_update=False).
widgets = [
    ipywidgets.interactive(
        doit,
        x=ipywidgets.FloatSlider(
            value=i, min=-10, max=10, step=.1, continuous_update=False,
        ),
        index=ipywidgets.fixed(i),
    )
    for i in range(len(values))
]

for w in widgets:
    display(w)
    


In [None]:
# Scratch: render a single button widget.
b = ipywidgets.Button(
    description="hello",
)
b

In [None]:
from functools import partial

def create_param_widgets(params: dict, callback):
    """Build a vertical form with one labelled input widget per parameter.

    Args:
        params: mapping ``name -> {"type": <bool|int|float|str>, "value": <initial>}``.
            int/float entries may also carry optional ``"min"``/``"max"`` slider bounds.
            When these are absent, the ipywidgets default ranges are used, widened if
            needed so the initial value fits.
        callback: called as ``callback({name: new_value})`` whenever an input's
            value changes.

    Returns:
        An ``ipywidgets.VBox`` with one ``HBox([Label(name), input])`` row per parameter.

    Raises:
        TypeError: if a parameter's type is not one of bool, int, float or str.
    """

    def _on_change(param_name, change):
        # Only the "value" trait is observed, so change["new"] is the plain new value.
        # Without names=..., internal traits such as _property_lock also fired here.
        callback({param_name: change["new"]})

    param_widgets = []
    for param_name, param in params.items():
        param_type = param["type"]
        value = param.get("value")

        # bool must be tested before int, because bool is a subclass of int.
        if issubclass(param_type, bool):
            input_widget = ipywidgets.Checkbox(value=bool(value))
        elif issubclass(param_type, int):
            value = 0 if value is None else int(value)
            input_widget = ipywidgets.IntSlider(
                value=value,
                min=param.get("min", min(0, 2 * value)),
                max=param.get("max", max(100, 2 * value)),
            )
        elif issubclass(param_type, float):
            value = 0. if value is None else float(value)
            input_widget = ipywidgets.FloatSlider(
                value=value,
                min=param.get("min", min(0., 2 * value)),
                max=param.get("max", max(10., 2 * value)),
            )
        elif issubclass(param_type, str):
            input_widget = ipywidgets.Text(value="" if value is None else str(value))
        else:
            raise TypeError(
                f"unsupported type {param_type!r} for parameter {param_name!r}"
            )

        input_widget.observe(partial(_on_change, param_name), names="value")

        # Label (not HTML) so parameter names are shown verbatim, never parsed as markup.
        param_widgets.append(
            ipywidgets.HBox([ipywidgets.Label(param_name), input_widget])
        )

    return ipywidgets.VBox(param_widgets)


def create_widgets():
    """Demo form: parameter inputs on the left, the live ``params`` dict on the right."""
    params = dict(
        text={"type": str, "value": "strange beings"},
        number1={"type": int, "value": 10},
        number2={"type": float, "value": 23.5},
    )
    output_widget = ipywidgets.Output()

    def _on_change(pars):
        # Write the changed values back, then re-render the whole dict.
        for key in pars:
            params[key]["value"] = pars[key]
        output_widget.clear_output()
        with output_widget:
            display(params)

    form = create_param_widgets(params, _on_change)
    return ipywidgets.HBox([form, output_widget])


create_widgets()

In [None]:
ipywidgets.Text?
#b.on_click

In [None]:
ipywidgets.Button?