In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.cnn import *

In [None]:
# Expected per-image tensor shape -- presumably (channels, height, width); the
# last two entries are also used to pick the dataset file below.
SHAPE = (1, 32, 32)
# Number of principal components requested from the PCA further down.
CODE_SIZE = 64
# Wrap the stored tensor so that dataset[i] yields a 1-tuple (image,).
dataset = TensorDataset(torch.load(f"../datasets/fonts-regular-{SHAPE[-2]}x{SHAPE[-1]}.pt"))
#dataset = TransformDataset(dataset, dtype=torch.float, multiply=255.)
# Sanity check: the first stored image must match the configured shape.
assert SHAPE == dataset[0][0].shape

In [None]:
# Pick 16 random images and show them in a single row.
images = torch.stack(
    [dataset[index][0] for index in RandomSampler(dataset, num_samples=16)]
)
VF.to_pil_image(make_grid(images, nrow=16))

# dataset -> features

In [None]:
def encode_dataset(dataset):
    """Flatten the images of ``dataset`` into the rows of one 2-D tensor.

    The dataset is read in shuffled batches of 50, so the row order of the
    result does not follow the dataset's index order. Each batch is viewed
    as ``(-1, math.prod(SHAPE))`` using the global ``SHAPE``.

    Returns:
        The concatenation of all flattened batches.
    """
    flat_size = math.prod(SHAPE)
    loader = DataLoader(dataset, batch_size=50, shuffle=True)
    flattened = [batch.view(-1, flat_size) for (batch,) in tqdm(loader)]
    return torch.cat(flattened)


# Build the flattened feature matrix; encode_dataset only reshapes and
# concatenates, so autograd tracking is switched off.
with torch.no_grad():
    features = encode_dataset(dataset)
# Show the resulting shape as the cell output.
features.shape

# PCA of images

In [None]:
from sklearn.decomposition import PCA
import numpy as np

# Fit a PCA with CODE_SIZE components on the flattened images.
pca = PCA(CODE_SIZE)
pca.fit(features)

# Principal axes, one component per row. The tensor is created in the dtype of
# `features` so the matrix products below cannot fail on a dtype mismatch.
pca_weight = torch.tensor(pca.components_, dtype=features.dtype)

# NOTE(review): sklearn centers the data while fitting, but this projection
# does not subtract pca.mean_, so `pca_features` is offset from
# pca.transform(features). The decoding cells multiply by `pca_weight` without
# adding a mean back, so they rely on this uncentered form -- confirm this is
# intended before changing it.
pca_features = features @ pca_weight.T

# to_pil_image multiplies float input by 255 and casts to uint8, which cannot
# represent the negative entries of the components. Rescale to [0, 1] for
# display only; `pca_weight` itself is left untouched.
weight_min, weight_max = pca_weight.min(), pca_weight.max()
VF.to_pil_image((pca_weight - weight_min) / (weight_max - weight_min).clamp_min(1e-12))

In [None]:
# Standard deviation of each PCA feature column (reduced over the sample axis).
px.line(pca_features.std(0))

# generate from PCA features

In [None]:
def plot_pca_samples(features, nrow=16, comps=5):
    """Decode PCA feature rows into images and display them as one grid.

    Args:
        features: Tensor whose rows are multiplied with the global ``pca_weight``.
        nrow: Number of images per grid row (passed to ``make_grid``).
        comps: Accepted but not used by this function.
    """
    decoded = (features @ pca_weight).clamp(0, 1)
    grid = make_grid(decoded.view(-1, *SHAPE), nrow=nrow)
    display(VF.to_pil_image(grid))
    
# Decode and display the first 32 rows of the PCA feature matrix.
plot_pca_samples(pca_features[:32])

# random features

In [None]:
def plot_random_pca(num=16*4, nrow=16):
    """Sample random PCA codes and display the images decoded from them.

    Each code is drawn from an independent Gaussian per component, using the
    mean and standard deviation of the corresponding column of the global
    ``pca_features``. The codes are decoded with the global ``pca_weight``,
    clipped to [0, 1] and reshaped to ``SHAPE``.

    Args:
        num: Number of random codes (images) to generate.
        nrow: Number of images per grid row (passed to ``make_grid``).
    """
    # Both statistics are reduced over the sample axis so that every component
    # keeps its own scale; a single global std would over-scale the weak
    # components and under-scale the dominant ones.
    mean = pca_features.mean(0, keepdim=True)
    std = pca_features.std(0, keepdim=True)
    noise = torch.randn(
        num, CODE_SIZE, dtype=pca_features.dtype, device=pca_features.device
    )
    features = noise * std + mean
    repros = features @ pca_weight
    repros = repros.clip(0, 1).view(-1, *SHAPE)
    display(VF.to_pil_image(make_grid(repros, nrow=nrow)))
    
# Show a grid of images decoded from randomly sampled PCA codes (default sizes).
plot_random_pca()

# manipulate PCA features

In [None]:
# Global minimum, mean and maximum over all PCA feature values.
pca_features.min(), pca_features.mean(), pca_features.max()

In [None]:
def manipulate_pca_sample(pca_sample, nrow=16, comps=5, scale=10.):
    """Display how sweeping single PCA components changes one decoded sample.

    The grid has ``comps`` rows of ``nrow`` images. In grid row ``k`` the
    ``k``-th entry of the code is overwritten with values running linearly
    from ``-scale`` to ``+scale`` across the row; all other entries keep the
    values from ``pca_sample``. The codes are decoded with the global
    ``pca_weight``, clipped to [0, 1] and reshaped to ``SHAPE``.

    Args:
        pca_sample: 1-D PCA code to start from (it is not modified).
        nrow: Number of sweep steps, i.e. images per grid row.
        comps: Number of leading components to sweep, one grid row each.
        scale: Magnitude of the sweep; the overwritten entry spans
            ``[-scale, +scale]``.
    """
    # repeat() allocates a new tensor, so the edits below never touch the input.
    modified_samples = pca_sample.view(1, -1).repeat(nrow * comps, 1)
    for comp in range(comps):
        for step in range(nrow):
            # The sweep position is derived from nrow, so it spans -1..1 for
            # any row length; a single-step row sits at the centre.
            t = (step / (nrow - 1)) * 2. - 1. if nrow > 1 else 0.
            modified_samples[comp * nrow + step, comp] = t * scale
    repros = modified_samples @ pca_weight
    repros = repros.clip(0, 1).view(-1, *SHAPE)
    display(VF.to_pil_image(make_grid(repros, nrow=nrow)))
    
# Sweep the leading components of the PCA code stored at row 2.
manipulate_pca_sample(pca_features[2])

# blend

In [None]:
def blend_feature_plot(sample1, samples2, nrow=16):
    """Display linear blends from one PCA code towards several others.

    For every code in ``samples2`` one grid row is produced that interpolates
    from ``sample1`` (left) to that code (right) in ``nrow`` steps. The blended
    codes are decoded with the global ``pca_weight``, clipped to [0, 1] and
    reshaped to ``SHAPE``.

    Args:
        sample1: 1-D PCA code used as the start of every blend.
        samples2: Iterable of 1-D PCA codes used as blend targets.
        nrow: Number of blend steps per row. With ``nrow == 1`` only
            ``sample1`` itself is shown for each target.

    Nothing is displayed when ``samples2`` is empty.
    """
    # A single step has no interpolation range; keep it at the start point
    # instead of dividing by zero.
    weights = [i / (nrow - 1) if nrow > 1 else 0. for i in range(nrow)]
    samples = [
        (sample1 * (1. - t) + t * sample2).unsqueeze(0)
        for sample2 in samples2
        for t in weights
    ]
    if not samples:
        return

    repros = torch.cat(samples) @ pca_weight
    repros = repros.clip(0, 1).view(-1, *SHAPE)
    display(VF.to_pil_image(make_grid(repros, nrow=nrow)))
    
# Blend the PCA code at row 13 towards each of the codes at rows 20..29.
blend_feature_plot(pca_features[13], pca_features[20:30])

In [None]:
def plot_samples(
        iterable, 
        total: int = 16, 
        nrow: int = 16, 
        return_image: bool = False, 
):
    """Collect up to ``total`` images from ``iterable`` and show them as a grid.

    Args:
        iterable: Source of images; each item is handed to ``make_grid``
            unchanged.
        total: Stop collecting once this many images have been gathered.
        nrow: Number of images per grid row (passed to ``make_grid``).
        return_image: If true, return the grid image instead of displaying it.

    Returns:
        The grid image when ``return_image`` is true, otherwise ``None``.
        ``None`` is also returned, and nothing is displayed, when no image
        could be collected.
    """
    samples = []
    try:
        for image in tqdm(iterable, total=total):
            samples.append(image)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # Deliberate: interrupting a slow generator still shows what was
        # collected so far.
        pass

    # make_grid cannot build a grid from an empty list, so stop here.
    if not samples:
        return None

    image = VF.to_pil_image(make_grid(samples, nrow=nrow))
    if return_image:
        return image
    display(image)
    
# Placeholder call: the empty tuple stands in for a real iterable of images.
plot_samples(
    ()
)