In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
#from torchvision.transforms import v2
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

import clip

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import *
from src.algo.wangtiles import *
from src.datasets.generative import *
from src.models.cnn import *
from src.models.transform import *
from src.util.embedding import *
from src.models.loss import *
from src.functional import *

In [None]:
# Preview cell: pull samples from the RPG tile dataset and show them as a grid.
# NOTE(review): this import lives outside the notebook's main import cell.
from scripts import datasets
# presumably (1, 32, 32) is the tile shape as (channels, height, width) — TODO confirm
ds = datasets.RpgTileIterableDataset((1, 32, 32))
# presumably 1000 is the shuffle buffer size — verify against IterableShuffle
ds = IterableShuffle(ds, 1000)
# This expression is not the last one in the cell and is not passed to display(),
# so the notebook does not render its result.
VF.to_pil_image(next(iter(ds)))#.shape
# First batch of 32 items from a fresh DataLoader over the dataset.
samples = next(iter(DataLoader(ds, batch_size=32)))

# Cell output: the batch rendered as a single image grid.
VF.to_pil_image(make_grid(samples))

In [None]:
class Debug(nn.Module):
    """Identity module that reports what passes through it.

    Drop it into a transform or layer pipeline to trace intermediate values:
    tensors are reported by shape, anything else by its type name. The input
    is always returned unchanged.
    """

    def __init__(self, name: str = "Debug"):
        """Store ``name``, the label printed in front of every report."""
        super().__init__()
        self.name = name

    def forward(self, x):
        """Print a one-line description of ``x`` and return ``x`` itself."""
        description = x.shape if isinstance(x, torch.Tensor) else type(x).__name__
        print(f"{self.name}: {description}")
        return x

In [None]:
# Build the working image set used by the histogram experiments below.
# NOTE(review): hard-coded absolute local path — the notebook only runs on a
# machine that has this directory.
ds = ImageFolderIterableDataset("/home/bergi/Pictures/__diverse/")
ds = TransformIterableDataset(
    ds, transforms=[
         
         #VF.resized_crop
         #Debug(">>>"),
         # force every image to 3 channels
         lambda x: set_image_channels(x, 3),
         #lambda x: VF.resize(x, (max(256, x.shape[-2] // 2), max(256, x.shape[-1] // 2)), antialias=True),
         # presumably resizes and crops to 256x256 — verify against image_resize_crop
         lambda x: image_resize_crop(x, (256, 256)),
         # NOTE(review): possibly redundant after the 256x256 resize-crop above — confirm
         VT.CenterCrop((256, 256)),
         #Debug("  >"),
         
    ]
)
# Materialise the first 64 items of the dataset; zip() stops after range(64).
images = [img for img, _ in zip(ds, range(64))]

# Cell output: all collected images in one labelled grid.
VF.to_pil_image(make_grid_labeled(images))

In [None]:
class PixelModel(nn.Module):
    """Wrap an image tensor as a single learnable parameter.

    The pixel values themselves are what gets optimised: ``forward`` takes no
    input and simply hands back the parameter.
    """

    def __init__(self, image: torch.Tensor):
        """Initialise the parameter from a copy of ``image``.

        The clone means later changes to ``image`` do not affect the model.
        """
        super().__init__()
        self.shape = image.shape
        self.code = nn.Parameter(image.clone())

    def forward(self):
        """Return the learnable pixel tensor (the parameter object itself)."""
        return self.code

    @torch.no_grad()
    def reset(self):
        """Overwrite the pixels in place with Gaussian noise (mean .3, std .1)."""
        noise = torch.randn_like(self.code)
        self.code.copy_(noise * .1 + .3)

# Smoke check: wrap the first collected image and show the shape forward() returns.
PixelModel(images[0])().shape

In [None]:
def train_histogram(
    image: torch.Tensor,
    target_histogram: torch.Tensor,
    iters: int = 10,
    batch_size: int = 1,
    lr: float = 1.,
    bins: int = 100,
    sigma: float = 100.,
    device="auto",
    hrange=(0, 1),
):
    """Optimise the pixels of ``image`` so its HSV histogram matches a target.

    The image is wrapped in a ``PixelModel`` and updated with Adam. Each step
    builds a batch of randomly erased copies of the current pixels, converts
    them with ``rgb_to_hsv``, computes a per-channel ``soft_histogram`` and
    minimises the MSE against ``target_histogram``. Progress (image snapshots
    and histogram plots) is shown with ``display`` while training runs.

    Args:
        image: Starting image; it is fed to ``rgb_to_hsv``, so presumably an
            RGB tensor of shape (C, H, W) — TODO confirm.
        target_histogram: Histogram to match, 2-D with shape (C, bins); it is
            broadcast over the batch dimension.
        iters: Number of optimisation steps.
        batch_size: Number of independently augmented copies per step.
        lr: Adam learning rate.
        bins: Number of histogram bins; must match ``target_histogram``.
        sigma: Forwarded to ``soft_histogram`` (presumably the softness of the
            bin assignment — verify against that helper).
        device: Passed through ``to_torch_device``.
        hrange: ``(min, max)`` value range forwarded to ``soft_histogram``.

    Returns:
        None. All results are displayed, nothing is returned.
    """
    device = to_torch_device(device)
    print(device)

    target_batch = target_histogram.to(device).unsqueeze(0).expand(batch_size, -1, -1)
    pixel_model = PixelModel(image).to(device)
    optimizer = torch.optim.Adam(pixel_model.parameters(), lr=lr)

    # Augmentation applied independently to every entry of the batch.
    transforms = VT.Compose([
        VT.RandomErasing(),
    ])

    # Integer interval: with the former float `iters / 100` the test
    # `it % interval == 0` misfired whenever iters was not a multiple of 100.
    snapshot_every = max(2, iters // 100)

    pixel_history = []

    def _show_history():
        # Render up to ten stored snapshots side by side.
        grid = make_grid(pixel_history[:10], nrow=len(pixel_history))
        display(VF.to_pil_image(grid.clamp(0, 1)))

    def _show_histograms(hist_batch):
        # Plot the target histogram against the batch-averaged current one.
        display(px.line(pd.DataFrame({
            "target": target_histogram.cpu().view(-1),
            "image": hist_batch.detach().cpu().mean(dim=0).view(-1),
        })))

    try:
        for it in tqdm(range(iters)):
            pixels = pixel_model()
            pixel_batch = torch.concat([
                transforms(pixels).unsqueeze(0) for _ in range(batch_size)
            ])

            B, C = pixel_batch.shape[:2]

            # One histogram per (batch entry, channel) over all pixels.
            hist_batch = soft_histogram(
                rgb_to_hsv(pixel_batch).view(B * C, -1),
                bins, *hrange, sigma,
            ).view(B, C, bins)

            loss = F.mse_loss(hist_batch, target_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if it % snapshot_every == 0:
                # `pixels` IS the parameter, which the optimizer updates in
                # place. `.cpu()` does not copy when the tensor already lives
                # on the CPU, so force a copy; otherwise every stored snapshot
                # would alias the live parameter and show the latest image.
                pixel_history.append(pixels.detach().to("cpu", copy=True))

            if it == 0:
                _show_histograms(hist_batch)
            if len(pixel_history) >= 10:
                print(float(loss))
                _show_history()
                pixel_history.clear()
                _show_histograms(hist_batch)
    except KeyboardInterrupt:
        # Deliberate: interrupting the kernel stops training early but still
        # shows whatever snapshots were collected.
        pass

    if pixel_history:
        _show_history()

# Run the histogram transfer: start from images[31] and push its HSV histogram
# towards the one of images[7].
BINS=20
train_histogram(
    images[31],
    # Target histogram; bins / range (-1, 2) / sigma 10 mirror the `bins`,
    # `hrange` and `sigma` arguments passed below.
    soft_histogram(rgb_to_hsv(images[7]), BINS, -1, 2, 10),
    #torch.ones(3, 100) * math.prod(images[0].shape[-2:]) / 100, 
    #batch_size=1,
    bins=BINS,
    lr=.0001,
    iters=10000,
    sigma=10,
    hrange=(-1, 2),
)

In [None]:
def train_histogram_2(
    source_image: torch.Tensor,
    target_image: torch.Tensor,
    iters: int = 10,
    batch_size: int = 1,
    lr: float = 1.,
    bins: int = 100,
    sigma: float = 100.,
    device="auto",
    hrange=(0, 1),
):
    """Transfer HSV histograms between two images, in both directions.

    Two passes are run. The first optimises the pixels of ``source_image``
    towards the histogram of ``target_image``; the second swaps the roles.
    Each pass wraps the image in a fresh ``PixelModel`` and trains it with
    Adam on the MSE between per-channel soft histograms. Snapshots from both
    passes are shown in one grid at the end.

    Args:
        source_image: First image; fed to ``rgb_to_hsv``, so presumably an RGB
            tensor of shape (C, H, W) — TODO confirm.
        target_image: Second image, same expectations as ``source_image``.
        iters: Optimisation steps per pass.
        batch_size: Number of independently augmented copies per step.
        lr: Adam learning rate.
        bins: Number of histogram bins.
        sigma: Forwarded to ``soft_histogram`` (presumably the softness of the
            bin assignment — verify against that helper).
        device: Passed through ``to_torch_device``.
        hrange: ``(min, max)`` value range forwarded to ``soft_histogram``.

    Returns:
        None. All results are displayed, nothing is returned.
    """
    device = to_torch_device(device)
    print(device)

    # Augmentation applied independently to every entry of the batch.
    transforms = VT.Compose([
        VT.RandomErasing(),
    ])

    # Integer interval: with the former float `iters / 10` the test
    # `it % interval == 0` misfired whenever iters was not a multiple of 10.
    snapshot_every = max(2, iters // 10)

    pixel_history = []
    for rev_it in range(2):
        if rev_it == 1:
            # Second pass: run the transfer in the opposite direction.
            source_image, target_image = target_image, source_image

        # The optimised image's histogram is taken in HSV space below, so the
        # target must be converted as well. It used to be computed from raw
        # RGB, which compared histograms from two different colour spaces.
        target_histogram = soft_histogram(rgb_to_hsv(target_image), bins, *hrange, sigma)
        print(target_histogram.shape)
        target_batch = target_histogram.to(device).unsqueeze(0).expand(batch_size, -1, -1)
        pixel_model = PixelModel(source_image).to(device)
        optimizer = torch.optim.Adam(pixel_model.parameters(), lr=lr)

        for it in tqdm(range(iters)):
            pixels = pixel_model()
            pixel_batch = torch.concat([
                transforms(pixels).unsqueeze(0) for _ in range(batch_size)
            ])

            B, C = pixel_batch.shape[:2]

            # One histogram per (batch entry, channel) over all pixels.
            hist_batch = soft_histogram(
                rgb_to_hsv(pixel_batch).view(B * C, -1),
                bins, *hrange, sigma,
            ).view(B, C, bins)

            loss = F.mse_loss(hist_batch, target_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if it % snapshot_every == 0:
                # `pixels` IS the parameter, which the optimizer updates in
                # place. `.cpu()` does not copy when the tensor already lives
                # on the CPU, so force a copy; otherwise every stored snapshot
                # would alias the live parameter and show the latest image.
                pixel_history.append(pixels.detach().to("cpu", copy=True))

            if it == 0:
                # Starting point of this pass: target vs. current histogram.
                display(px.line(pd.DataFrame({
                    "target": target_histogram.cpu().view(-1),
                    "image": hist_batch.detach().cpu().mean(dim=0).view(-1),
                })))

    if pixel_history:
        # One grid row per pass. max(1, ...) keeps nrow valid when fewer than
        # two snapshots exist; the guard skips the grid when there are none.
        per_row = max(1, len(pixel_history) // 2)
        display(VF.to_pil_image(make_grid(pixel_history, nrow=per_row).clamp(0, 1)))

# Run the two-way histogram transfer between images[31] and images[7].
train_histogram_2(
    images[31],
    images[7],
    #batch_size=1,
    bins=20,
    lr=.001,
    iters=5000,
    sigma=10,
    hrange=(-1, 2),
)

In [None]:
def visualize_wang(
    pixel_model: nn.Module,
    text: str,
    iters: int = 100,
    batch_size: int = 1,
    lr: float = 10.,
    device="auto",
):    
    """Optimise ``pixel_model`` so its output matches a CLIP text prompt.

    Each step renders the model output, builds ``batch_size`` randomly
    augmented views (Wang map, affine, 224x224 crop), embeds them with
    ``ClipSingleton.encode_image`` and minimises the L1 distance between the
    image/text dot products and 1. Intermediate images are shown with
    ``display``; a KeyboardInterrupt ends training early.

    Args:
        pixel_model: Module called without arguments to produce the image; it
            must expose a writable ``code`` tensor, which is overwritten
            before training starts.
        text: Prompt passed to ``ClipSingleton.encode_text``.
        iters: Number of optimisation steps.
        batch_size: Number of augmented views per step.
        lr: Base learning rate; Adam is created with ``lr * .1``.
        device: Passed through ``to_torch_device``.

    Returns:
        The final ``pixel_model()`` output, detached and moved to the CPU.

    NOTE(review): this function reads a global ``model`` (for
    ``model.encoder``) that is not defined in the function or in any cell
    visible here, and it opens a hard-coded absolute image path — both make
    it fail on a fresh kernel / another machine unless provided elsewhere.
    """
    device = to_torch_device(device)
    
    #optimizer = torch.optim.Adadelta(pixel_model.parameters(), lr=lr * 100.)
    # note the effective learning rate is a tenth of the `lr` argument
    optimizer = torch.optim.Adam(pixel_model.parameters(), lr=lr * .1)

    
    # Random augmentations applied to every view before it is scored by CLIP.
    transforms = VT.Compose([
        #VT.RandomErasing(),
        #VT.Pad(30, padding_mode="reflect"),
        RandomWangMap((30, 30), probability=1),
        VT.RandomAffine(
            degrees=35.,
            scale=(.3, 1.),
            #translate=(0, 4. / 64.),
        ),
        VT.RandomCrop((224, 224)),
    ])
    
    # Unit-length text embedding; fixed for the whole run, so no gradients needed.
    with torch.no_grad():
        target_embeddings = ClipSingleton.encode_text(text).to(device)
        target_embeddings = target_embeddings / target_embeddings.norm(dim=-1, keepdim=True)

    # Desired dot product (1 for unit vectors) for every view, in fp16.
    target_dots = torch.ones(batch_size, 1).half().to(device)

    def _pixels_for_clip(pixels):
        # Move to the target device, cast to fp16 and force 3 channels.
        pixels = pixels.to(device).half()
        if pixels.shape[-3] != 3:
            pixels = set_image_channels(pixels, 3)
        #if pixels.shape[-2:] != (224, 224):
        #    pixels = VF.resize(pixels, (224, 224), VT.InterpolationMode.BILINEAR, antialias=True)
        return pixels
    
    def _clip_size(pixels):
        # Resize to 224x224 unless the view already has that size.
        if pixels.shape[-2:] != (224, 224):
            pixels = VF.resize(pixels, (224, 224), VT.InterpolationMode.BILINEAR, antialias=True)
        return pixels
    
    # find best start candidate
    
    # Initialise the model's code from a template image instead of a search.
    with torch.no_grad():
        # NOTE(review): hard-coded local path; the meaning of the `* 150`
        # scaling is not evident from this file — confirm.
        wang_template = VF.to_tensor(PIL.Image.open("/home/bergi/prog/python/thegame/thegame/assets/cr31/path.png"))[:3] * 150
        wang_template = VF.resize(wang_template, (32, 32), VF.InterpolationMode.BILINEAR, antialias=True)
        # NOTE(review): `model` is not defined in this function or in any
        # visible cell — presumably left over from another notebook session.
        code = model.encoder(wang_template.unsqueeze(0)).squeeze(0)
        pixel_model.code[:] = code
    
    display(VF.to_pil_image(pixel_model()))
    
    # train
    
    pixel_history = []
    try:
        for it in tqdm(range(iters)):

            pixel_batch = []
            pixels = _pixels_for_clip(pixel_model())
            
            # each view gets its own random augmentation of the same render
            for i in range(batch_size):
                pixel_batch.append(_clip_size(transforms(pixels)).unsqueeze(0))
            pixel_batch = torch.concat(pixel_batch)

            image_embeddings = ClipSingleton.encode_image(pixel_batch, requires_grad=True)
            image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

            # dot products of the normalised image and text embeddings
            dots = image_embeddings @ target_embeddings.T

            loss = F.l1_loss(dots, target_dots)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # keep a snapshot of every second render
            if it % 2 == 0:
                pixel_history.append(pixels.to("cpu").detach())

            # every 10 snapshots: report the loss, show them, then start over
            if len(pixel_history) >= 10:
                print(float(loss))
                #display(VF.to_pil_image(pixels.detach().to("cpu")))
                display(VF.to_pil_image(make_grid(pixel_history, nrow=len(pixel_history))))
                display(VF.to_pil_image(RandomWangMap((8, 40))(pixel_history[-1])))
                pixel_history.clear()
    except KeyboardInterrupt:
        # deliberate: interrupting the kernel stops training and still returns a result
        pass
    result = pixel_model().detach().to("cpu")
    #filename = "".join(
    return result
        

# Run the CLIP-guided optimisation for one prompt and show the final image.
# NOTE(review): `DecoderModel` and `dalle_decoder` are not defined in any cell
# visible here — presumably provided by one of the star imports; confirm.
VF.to_pil_image(visualize_wang(
    #PixelModel((3, 224, 224)).to("cuda"),
    DecoderModel(dalle_decoder, 128, std=.1),
    #DecoderModelHxW(dalle_decoder, 128, (4, 4), std=0.5),
    # Prompt candidates; exactly one line is left uncommented (the active prompt).
    #"a house in the woods",
    #"deep in the woods",
    #"blood splattered wall",
    #"cracked stone surface",
    #"close-up of a smiling face",
    #"yellow square on blue background",
    #"lava river between black rock",
    #"the sphere of planet earth in front of a black background",
    #"checkerboard texture",
    #"moss and stone",
    #"stars in the sky",
    #"deep space photography",
    #"evil spiderweb",
    #"weapons of mass destruction",
    #"organic structures",
    #"Bob Dobbs",
    "underwater cobble texture",
    #"top-down view of a river between meadows",
    #"pile of guts",
    #"stone texture",
    #"top-down view of a city maze",
    #"knotted ropes",
    #"an unfriendly stone texture",
    #"love & peace pattern",
    #"friendly pattern",
    batch_size=2,
    lr=1.5,
    iters=1000,
))

In [None]:
# IPython help lookup (`?` suffix) for VT.Pad — interactive leftover; this is
# IPython-only syntax and not valid in a plain Python script.
VT.Pad?

In [None]:
# NOTE(review): bare name `cli` is not defined in any visible cell (only `clip`
# is imported explicitly). Unless a star import provides it, this raises
# NameError on a fresh kernel — looks like an unfinished cell; confirm and remove.
cli