In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
#from torchvision.transforms import v2
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

import clip

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import *
from src.algo.wangtiles import *
from src.datasets.generative import *
from src.models.cnn import *
from src.models.transform import *
from src.util.embedding import *
from src.models.clip import ClipSingleton

In [None]:
from scripts import datasets

# Tile dataset; the (1, 32, 32) argument is presumably the per-item tensor
# shape (channels, height, width) — TODO confirm against RpgTileIterableDataset.
tile_ds = datasets.RpgTileIterableDataset((1, 32, 32))
# Second argument is presumably the shuffle buffer size — TODO confirm.
ds = IterableShuffle(tile_ds, 1000)

# Preview one item. Only the last expression of a cell is rendered
# automatically, so this preview has to go through display() to be visible.
display(VF.to_pil_image(next(iter(ds))))

# One batch of 32 items; `samples` is reused by later cells.
samples = next(iter(DataLoader(ds, batch_size=32)))

VF.to_pil_image(make_grid(samples))

In [None]:
class PixelModel(nn.Module):
    """Image whose pixels are directly the trainable parameter ``code``.

    ``code`` has the requested ``shape`` and is filled with Gaussian noise
    scaled by 0.1 and shifted by 0.3.
    """

    def __init__(self, shape: Tuple[int, int, int] = (3, 224, 224)):
        super().__init__()
        self.shape = shape
        self.code = nn.Parameter(self._scaled(torch.randn(self.shape)))

    @staticmethod
    def _scaled(noise: torch.Tensor) -> torch.Tensor:
        """Map unit Gaussian noise to the initial pixel distribution."""
        return noise * .1 + .3

    def forward(self):
        """Return the pixels limited to the range [0, 1]."""
        return torch.clamp(self.code, 0, 1)

    def reset(self):
        """Overwrite ``code`` in place with fresh noise (no gradient tracking)."""
        with torch.no_grad():
            self.code.copy_(self._scaled(torch.randn_like(self.code)))
    
class DecoderModel(nn.Module):
    """One trainable latent ``code`` of shape (1, code_size) fed through ``decoder``.

    The code is initialised with Gaussian noise multiplied by ``std``.
    """

    def __init__(self, decoder: nn.Module, code_size: int, std: float = 1.):
        super().__init__()
        self.decoder = decoder
        self.code_size = code_size
        self.std = std
        initial_code = torch.randn(1, code_size) * std
        self.code = nn.Parameter(initial_code)

    def forward(self):
        """Decode the code, drop the leading batch axis and limit values to [0, 1]."""
        decoded = self.decoder(self.code)
        return torch.clamp(decoded.squeeze(0), 0, 1)

    def reset(self):
        """Overwrite ``code`` in place with fresh noise scaled by ``std``."""
        with torch.no_grad():
            self.code.copy_(torch.randn_like(self.code) * self.std)
            
class DecoderModelRGB(nn.Module):
    """Three trainable latent codes, one per decoded plane.

    ``std`` may be a float or a tensor; a tensor is stored as a
    non-trainable ``nn.Parameter``. The codes are initialised with Gaussian
    noise multiplied by ``std``.
    """

    def __init__(self, decoder: nn.Module, code_size: int, std: float = 1.):
        super().__init__()
        self.decoder = decoder
        self.code_size = code_size
        if isinstance(std, torch.Tensor):
            # Keep tensor-valued std on the module as a frozen parameter.
            self.std = nn.Parameter(std, requires_grad=False)
        else:
            self.std = std
        initial_codes = torch.randn(3, code_size) * std
        self.code = nn.Parameter(initial_codes)

    def forward(self):
        """Decode all three codes, squeeze axis 1 and limit values to [0, 1]."""
        planes = self.decoder(self.code).squeeze(1)
        return torch.clamp(planes, 0, 1)

    def reset(self):
        """Overwrite ``code`` in place with fresh noise scaled by ``std``."""
        with torch.no_grad():
            self.code.copy_(torch.randn_like(self.code) * self.std)

#print(DecoderModelRGB(dalle_decoder, 128)().shape)

class DecoderModelHSV(DecoderModelRGB):
    """Variant of ``DecoderModelRGB`` that treats the three decoded planes as HSV.

    The first plane is multiplied by 3 before the result is passed through
    ``hsv_to_rgb`` and limited to [0, 1].
    """

    def forward(self):
        hsv = self.decoder(self.code).squeeze(1)
        # Scale the first plane out of place. Writing `hsv[0] = ...` would
        # modify a view of the decoder output in place, which autograd rejects
        # when the decoder's last op needs its own output for the backward pass.
        hsv = torch.cat([hsv[:1] * 3., hsv[1:]], dim=0)
        return hsv_to_rgb(hsv).clamp(0, 1)

class DecoderModelHxW(nn.Module):
    """Grid of decoded tiles, each tile driven by its own trainable latent code.

    Args:
        decoder: module mapping codes of shape (N, code_size) to images of
            shape (N, C, h, w).
        code_size: length of each latent code.
        shape: (rows, cols) of the tile grid; ``rows * cols`` codes are created.
        std: scale of the Gaussian noise used to initialise/reset the codes.
    """

    def __init__(self, decoder: nn.Module, code_size: int, shape: Tuple[int, int], std: float = 1.):
        super().__init__()
        self.decoder = decoder
        self.code_size = code_size
        self.shape = shape
        self.std = std
        self.code = nn.Parameter(torch.randn(math.prod(shape), code_size) * std)

    def forward(self):
        """Return one (C, rows * h, cols * w) image with the tiles laid out row-major.

        Tile ``y * cols + x`` occupies grid cell (row ``y``, column ``x``).
        """
        images = self.decoder(self.code).clamp(0, 1)
        grid_rows, grid_cols = self.shape[-2], self.shape[-1]
        channels, tile_h, tile_w = images.shape[-3:]
        # (rows, cols, C, h, w) -> (C, rows, h, cols, w) -> (C, rows*h, cols*w).
        # Same layout as copying every tile into its cell, without Python loops.
        tiled = images.reshape(grid_rows, grid_cols, channels, tile_h, tile_w)
        tiled = tiled.permute(2, 0, 3, 1, 4)
        return tiled.reshape(channels, grid_rows * tile_h, grid_cols * tile_w)

    def reset(self):
        """Overwrite ``code`` in place with fresh noise scaled by ``std``."""
        with torch.no_grad():
            self.code[:] = torch.randn_like(self.code) * self.std
    
    #def set_pixels(pixels: torch.Tensor, decoder_shape: Tuple[int, int, int] = (1, 64, 64)):
        #pixels = VF.resize(pixels, (self.shape[-2] * decoder_shape[-2
                                                                   
#VF.to_pil_image(DecoderModelHxW(dalle_decoder, 128, (4, 4), .2)())

class DecoderModelRGBHxW(nn.Module):
    """Grid of three-plane tiles; every plane of every tile has its own latent code.

    Args:
        decoder: module mapping codes of shape (N, code_size) to images whose
            last two axes are (h, w).
        code_size: length of each latent code.
        shape: (rows, cols) of the tile grid; ``rows * cols * 3`` codes are created.
        std: scale of the Gaussian noise used to initialise/reset the codes.
    """

    def __init__(self, decoder: nn.Module, code_size: int, shape: Tuple[int, int], std: float = 1.):
        super().__init__()
        self.decoder = decoder
        self.code_size = code_size
        self.shape = shape
        self.std = std
        self.code = nn.Parameter(torch.randn(math.prod(shape) * 3, code_size) * std)

    def forward(self):
        """Return one (3, rows * h, cols * w) image with the tiles laid out row-major.

        Decoded images ``3*i``, ``3*i + 1`` and ``3*i + 2`` become the three
        planes of tile ``i``; tile ``y * cols + x`` occupies grid cell
        (row ``y``, column ``x``).
        """
        images = self.decoder(self.code).clamp(0, 1)
        tile_h, tile_w = images.shape[-2:]
        grid_rows, grid_cols = self.shape[-2], self.shape[-1]
        # Group consecutive triples of decoded images into 3-plane tiles.
        images = images.reshape(images.shape[0] // 3, 3, tile_h, tile_w)
        # (rows, cols, 3, h, w) -> (3, rows, h, cols, w) -> (3, rows*h, cols*w).
        # Same layout as copying every tile into its cell, without Python loops.
        tiled = images.reshape(grid_rows, grid_cols, 3, tile_h, tile_w)
        tiled = tiled.permute(2, 0, 3, 1, 4)
        return tiled.reshape(3, grid_rows * tile_h, grid_cols * tile_w)

    def reset(self):
        """Overwrite ``code`` in place with fresh noise scaled by ``std``."""
        with torch.no_grad():
            self.code[:] = torch.randn_like(self.code) * self.std

#VF.to_pil_image(DecoderModelRGBHxW(dalle_decoder, 128, (4, 4), .2)())            

In [None]:
from scripts.train_autoencoder import DalleAutoencoder, DalleManifoldAutoencoder

# Which pretrained autoencoder to load; must be a key of AUTOENCODER_VARIANTS.
# (Replaces an `if 0: ... elif 1: ... elif 1: ...` chain whose second `elif 1`
# branch could never run.)
AUTOENCODER_VARIANT = "manifold9d-8l"


def _manifold_autoencoder():
    """Build the manifold autoencoder configuration shared by both manifold checkpoints."""
    # An alternative with decoder_n_hid=300 was tried for the "manifold-8-snapshot" checkpoint.
    return DalleManifoldAutoencoder(
        (1, 32, 32), vocab_size=128, n_hid=64, n_blk_per_group=1, act_fn=nn.GELU,
        space_to_depth=True, decoder_n_blk=8, decoder_n_layer=2, decoder_n_hid=128,
    )


# variant name -> (zero-argument model factory, checkpoint path)
AUTOENCODER_VARIANTS = {
    # gray
    "d3": (
        lambda: DalleAutoencoder((1, 64, 64), vocab_size=128, n_hid=64, group_count=1, n_blk_per_group=1, act_fn=nn.GELU),
        "../checkpoints/ae-d3/best.pt",
    ),
    # gray
    "d11-32": (
        lambda: DalleAutoencoder((1, 32, 32), vocab_size=128, n_hid=96, group_count=4, n_blk_per_group=2, act_fn=nn.GELU, space_to_depth=True),
        "../checkpoints/ae-d11-32/best.pt",
    ),
    # gray
    "manifold9d-8l": (_manifold_autoencoder, "../checkpoints/ae-manifold9d-8l/best.pt"),
    # gray
    "manifold-8-snapshot": (_manifold_autoencoder, "../checkpoints/ae-manifold-8/snapshot.pt"),
    # color
    "d4-color": (
        lambda: DalleAutoencoder((3, 64, 64), vocab_size=128, n_hid=64, group_count=1, n_blk_per_group=1, act_fn=nn.GELU),
        "../checkpoints/ae-d4-color/best.pt",
    ),
    "d8-sobel-spd": (
        lambda: DalleAutoencoder((1, 64, 64), vocab_size=128, n_hid=64, group_count=1, n_blk_per_group=1, act_fn=nn.GELU, space_to_depth=True),
        "../checkpoints/ae-d8-sobel-spd/best.pt",
    ),
}

_build_autoencoder, _checkpoint_path = AUTOENCODER_VARIANTS[AUTOENCODER_VARIANT]
model = _build_autoencoder()
# map_location="cpu" lets checkpoints written on another device load on any machine;
# the freshly built model lives on the CPU at this point anyway.
model.load_state_dict(torch.load(_checkpoint_path, map_location="cpu")["state_dict"])

# The autoencoder stays frozen; only latent codes are optimised later on.
for param in model.parameters():
    param.requires_grad = False
dalle_decoder = model.decoder
print(f"{num_module_parameters(dalle_decoder):,}")

In [None]:
# Sanity check: print the output shapes of a single-code model and a 4x4 grid model.
with torch.no_grad():
    print(DecoderModel(decoder=dalle_decoder, code_size=128)().shape)
    print(DecoderModelHxW(decoder=dalle_decoder, code_size=128, shape=(4, 4))().shape)

In [None]:
# Show 16 images decoded from freshly initialised random codes (std 0.5).
with torch.no_grad():
    display(
        VF.to_pil_image(
            make_grid([
                DecoderModelRGB(decoder=dalle_decoder, code_size=128, std=.5)()
                for _ in range(16)
            ])
        )
    )

In [None]:
# Per-dimension mean of the most recently encoded batch; written by reproduce()
# and read by a later cell.
code_mean = None

@torch.no_grad()
def reproduce(samples):
    """Encode ``samples`` with ``model.encoder``, decode them again and display the result.

    Side effects:
        * stores the mean code over the batch axis in the global ``code_mean``
        * prints the mean of the per-sample code standard deviation
        * displays originals above reconstructions, line plots of the codes
          and of ``code_mean``, and, if the decoder accepts a target size as
          second argument, reconstructions decoded at (96, 96)
    """
    global code_mean
    codes = model.encoder(samples)
    code_mean = codes.mean(0)
    print("std:", codes.std(1).mean())
    repros = dalle_decoder(codes)
    display(VF.to_pil_image(
        make_grid([
            make_grid(samples),
            make_grid(repros)
        ])
    ))
    display(px.line(codes.T))
    display(px.line(code_mean))
    try:
        repros_big = dalle_decoder(codes, (96, 96))
    except Exception:
        # Best effort: not every decoder takes an output size. Catching
        # Exception instead of using a bare `except` keeps KeyboardInterrupt
        # and SystemExit working.
        return
    display(VF.to_pil_image(
        make_grid(repros_big)
    ))

# Average over axis -3 (keeping the axis) before encoding.
reproduce(samples.mean(-3, keepdim=True))

In [None]:
# Same 16-image grid as above, but with the noise scaled by `code_mean`
# (set by reproduce()), so that cell has to run first.
with torch.no_grad():
    display(
        VF.to_pil_image(
            make_grid([
                DecoderModelRGB(decoder=dalle_decoder, code_size=128, std=code_mean)()
                for _ in range(16)
            ])
        )
    )

In [None]:
class Debug(nn.Module):
    """Pass-through module that prints what flows through it.

    Prints ``"<name>: <shape>"`` for tensors and ``"<name>: <type name>"``
    for anything else, then returns the input unchanged.
    """

    def __init__(self, name: str = "Debug"):
        super().__init__()
        self.name = name

    def forward(self, x):
        description = x.shape if isinstance(x, torch.Tensor) else type(x).__name__
        print(f"{self.name}: {description}")
        return x

In [None]:
def visualize(
    pixel_model: nn.Module,
    text: str,
    # text_neg: Optional[str] = None,
    iters: int = 100,
    batch_size: int = 1,
    lr: float = 10.,
    device="auto",
):
    """Optimise ``pixel_model`` so that its image matches ``text`` according to CLIP.

    Each step renders the model, builds ``batch_size`` randomly transformed
    views (wang map, rotation, 224x224 crop), embeds them with CLIP and
    minimises the L1 distance between the image/text cosine similarity and 1.

    Args:
        pixel_model: module with a ``code`` parameter, a ``reset()`` method and
            a no-argument ``forward`` returning an image tensor.
        text: prompt passed to ``ClipSingleton.encode_text``.
        iters: number of optimisation steps.
        batch_size: number of transformed views per step.
        lr: learning-rate knob; Adam runs with ``lr * 0.1``.
        device: passed to ``to_torch_device``; the model is moved there.

    Returns:
        The final image from ``pixel_model()``, detached and on the CPU.
        A ``KeyboardInterrupt`` during training ends the loop early and the
        current image is returned.
    """
    device = to_torch_device(device)
    print(device)
    pixel_model = pixel_model.to(device)

    #optimizer = torch.optim.Adadelta(pixel_model.parameters(), lr=lr * 100.)
    optimizer = torch.optim.Adam(pixel_model.parameters(), lr=lr * .1)

    transforms = VT.Compose([
        #VT.RandomErasing(),
        #VT.Pad(30, padding_mode="reflect"),
        #Debug("model"),
        RandomWangMap((8, 8), probability=1, overlap=0, num_colors=2),
        #Debug("wangmap"),
        VT.RandomAffine(
            degrees=35.,
            #scale=(1., 3.),
            #scale=(.3, 1.),
            #translate=(0, 4. / 64.),
        ),
        VT.RandomCrop((224, 224)),
    ])

    with torch.no_grad():
        target_embeddings = ClipSingleton.encode_text(text).to(device)
        target_embeddings = target_embeddings / target_embeddings.norm(dim=-1, keepdim=True)

    # Both embeddings are unit length, so a dot product of 1 is a perfect match.
    target_dots = torch.ones(batch_size, 1).half().to(device)

    def _pixels_for_clip(pixels):
        """Move to the device as half precision and force three channels."""
        pixels = pixels.to(device).half()
        if pixels.shape[-3] != 3:
            pixels = set_image_channels(pixels, 3)
        return pixels

    def _clip_size(pixels):
        """Resize to 224x224 unless the image already has that size."""
        if pixels.shape[-2:] != (224, 224):
            pixels = VF.resize(pixels, (224, 224), VT.InterpolationMode.BILINEAR, antialias=True)
        return pixels

    # find best start candidate

    num_seed_candidates = 8
    with torch.no_grad():
        pixel_batch = []
        code_batch = []
        for _ in tqdm(range(num_seed_candidates), desc="find best seed"):
            pixel_model.reset()
            pixels = _clip_size(transforms(_pixels_for_clip(pixel_model())))
            pixel_batch.append(pixels.unsqueeze(0))
            # Store a copy: reset() rewrites `code` in place, so appending the
            # parameter itself would make every candidate alias the last one
            # and the selection below would do nothing.
            code_batch.append(pixel_model.code.detach().clone())

        image_embeddings = ClipSingleton.encode_image(torch.concat(pixel_batch), requires_grad=True)
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

        dots = image_embeddings @ target_embeddings.T
        best_index = int(dots.flatten(0).argmax())

        pixel_model.code[:] = code_batch[best_index]

        # show example transforms
        display(VF.to_pil_image(make_grid([
            _clip_size(transforms(_pixels_for_clip(pixel_model().cpu())))
            for _ in range(8)
        ])))

    torch.cuda.empty_cache()

    # train

    pixel_history = []
    try:
        for it in tqdm(range(iters)):
            pixel_batch = []
            pixels = _pixels_for_clip(pixel_model())

            for _ in range(batch_size):
                pixel_batch.append(_clip_size(transforms(pixels)).unsqueeze(0))
            pixel_batch = torch.concat(pixel_batch)

            image_embeddings = ClipSingleton.encode_image(pixel_batch, requires_grad=True)
            image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

            dots = image_embeddings @ target_embeddings.T

            loss = F.l1_loss(dots, target_dots)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Keep every second frame; show progress once 10 frames are collected
            # (i.e. every 20 iterations).
            if it % 2 == 0:
                pixel_history.append(pixels.cpu().detach())

            if len(pixel_history) >= 10:
                print(float(loss))
                #display(VF.to_pil_image(pixels.detach().to("cpu")))
                display(VF.to_pil_image(make_grid(pixel_history, nrow=len(pixel_history))))
                display(VF.to_pil_image(RandomWangMap((8, 40), overlap=5)(pixel_history[-1])))
                pixel_history.clear()
    except KeyboardInterrupt:
        # Deliberate: interrupting the kernel stops training but still returns the result.
        pass
    result = pixel_model().detach().to("cpu")
    return result
        

# dalle_decoder.default_shape = (128, 128)

# Other prompts that were tried (swap one in as `text` below):
#   "a house in the woods"
#   "deep in the woods"
#   "blood splattered wall"
#   "cracked stone surface"
#   "close-up of a smiling face"
#   "yellow square on blue background"
#   "lava river between black rock"
#   "the sphere of planet earth in front of a black background"
#   "checkerboard texture"
#   "moss and stone"
#   "stars in the sky"
#   "deep space photography"
#   "evil spiderweb"
#   "weapons of mass destruction"
#   "organic cell structures"
#   "small dots on black background"
#   "Bob Dobbs"
#   "underwater cobble texture"
#   "top-down view of a river between meadows"
#   "pile of guts"
#   "stone texture"
#   "top-down view of a city maze"
#   "knotted ropes"
#   "an unfriendly stone texture"
#   "ocean waves"
#   "typewriter printing"
#   "cables and wires"
#   "80's car racing game, bird's eye view"
#   "cracks in the marble floor"
#   "mountain tops raising from the mist"
#   "penrose tiling"
#   "2d side-scroller"
#   "love & peace pattern"
#   "friendly pattern"
#
# Other pixel models that were tried:
#   PixelModel((3, 224, 224))
#   PixelModel((3, 32*9, 32*9))
#   DecoderModel(dalle_decoder, 128, std=.5)
#   DecoderModelRGB(dalle_decoder, 128, std=.2)
VF.to_pil_image(
    visualize(
        pixel_model=DecoderModelRGBHxW(dalle_decoder, 128, (4, 4), std=.5),
        text="role playing game outdoor map",
        iters=1000,
        batch_size=1,
        lr=1.,
    )
)

In [None]:
# Shape check: decode one image, then print its shape before and after RandomWangMap.
with torch.no_grad():
    img = DecoderModel(decoder=dalle_decoder, code_size=128, std=.3)()
    # alternative: img = DecoderModelHxW(dalle_decoder, 128, (4, 4), std=.3).cuda()()
    print(img.shape)
    img = RandomWangMap(
        (5, 5),
        probability=1,
        overlap=0,
        num_colors=2,
    )(img)
    print(img.shape)

In [None]:
def visualize_wang(
    pixel_model: nn.Module,
    text: str,
    iters: int = 100,
    batch_size: int = 1,
    lr: float = 10.,
    device="auto",
    template_path: str = "/home/bergi/prog/python/thegame/thegame/assets/cr31/path.png",
):
    """Optimise ``pixel_model`` towards ``text`` with CLIP, starting from an encoded template image.

    Works like ``visualize`` but, instead of searching random seeds, the
    initial code is obtained by encoding the image at ``template_path`` with
    the global ``model.encoder``. The views use a 30x30 wang map with random
    rotation and scaling.

    Args:
        pixel_model: module with a ``code`` parameter and a no-argument
            ``forward`` returning an image tensor.
        text: prompt passed to ``ClipSingleton.encode_text``.
        iters: number of optimisation steps.
        batch_size: number of transformed views per step.
        lr: learning-rate knob; Adam runs with ``lr * 0.1``.
        device: passed to ``to_torch_device``; the model is moved there.
        template_path: image file whose encoding is used as start code. The
            default is the original author's local path.

    Returns:
        The final image from ``pixel_model()``, detached and on the CPU.
        A ``KeyboardInterrupt`` during training ends the loop early and the
        current image is returned.
    """
    device = to_torch_device(device)
    # Move the model like visualize() does. The decoder is shared between
    # pixel models, so once any of them was moved, a model left on the CPU
    # would hold its code and its decoder on different devices.
    pixel_model = pixel_model.to(device)

    #optimizer = torch.optim.Adadelta(pixel_model.parameters(), lr=lr * 100.)
    optimizer = torch.optim.Adam(pixel_model.parameters(), lr=lr * .1)

    transforms = VT.Compose([
        #VT.RandomErasing(),
        #VT.Pad(30, padding_mode="reflect"),
        RandomWangMap((30, 30), probability=1),
        VT.RandomAffine(
            degrees=35.,
            scale=(.3, 1.),
            #translate=(0, 4. / 64.),
        ),
        VT.RandomCrop((224, 224)),
    ])

    with torch.no_grad():
        target_embeddings = ClipSingleton.encode_text(text).to(device)
        target_embeddings = target_embeddings / target_embeddings.norm(dim=-1, keepdim=True)

    # Both embeddings are unit length, so a dot product of 1 is a perfect match.
    target_dots = torch.ones(batch_size, 1).half().to(device)

    def _pixels_for_clip(pixels):
        """Move to the device as half precision and force three channels."""
        pixels = pixels.to(device).half()
        if pixels.shape[-3] != 3:
            pixels = set_image_channels(pixels, 3)
        return pixels

    def _clip_size(pixels):
        """Resize to 224x224 unless the image already has that size."""
        if pixels.shape[-2:] != (224, 224):
            pixels = VF.resize(pixels, (224, 224), VT.InterpolationMode.BILINEAR, antialias=True)
        return pixels

    # initialise the code from the encoded template

    with torch.no_grad():
        # NOTE(review): the factor 150 scales the [0, 1] output of to_tensor;
        # its purpose is not evident from this file — TODO confirm.
        # NOTE(review): `[:3]` keeps up to three channels; verify this matches
        # the channel count the encoder expects.
        wang_template = VF.to_tensor(PIL.Image.open(template_path))[:3] * 150
        wang_template = VF.resize(wang_template, (32, 32), VT.InterpolationMode.BILINEAR, antialias=True)
        # The encoder belongs to the global `model` and is not moved by this
        # function, so feed it on whatever device it currently lives on.
        encoder_param = next(iter(model.encoder.parameters()), None)
        if encoder_param is not None:
            wang_template = wang_template.to(encoder_param.device)
        code = model.encoder(wang_template.unsqueeze(0)).squeeze(0)
        pixel_model.code[:] = code.to(pixel_model.code.device)

    display(VF.to_pil_image(pixel_model().detach().cpu()))

    # train

    pixel_history = []
    try:
        for it in tqdm(range(iters)):

            pixel_batch = []
            pixels = _pixels_for_clip(pixel_model())

            for _ in range(batch_size):
                pixel_batch.append(_clip_size(transforms(pixels)).unsqueeze(0))
            pixel_batch = torch.concat(pixel_batch)

            image_embeddings = ClipSingleton.encode_image(pixel_batch, requires_grad=True)
            image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

            dots = image_embeddings @ target_embeddings.T

            loss = F.l1_loss(dots, target_dots)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Keep every second frame; show progress once 10 frames are collected
            # (i.e. every 20 iterations).
            if it % 2 == 0:
                pixel_history.append(pixels.to("cpu").detach())

            if len(pixel_history) >= 10:
                print(float(loss))
                #display(VF.to_pil_image(pixels.detach().to("cpu")))
                display(VF.to_pil_image(make_grid(pixel_history, nrow=len(pixel_history))))
                display(VF.to_pil_image(RandomWangMap((8, 40))(pixel_history[-1])))
                pixel_history.clear()
    except KeyboardInterrupt:
        # Deliberate: interrupting the kernel stops training but still returns the result.
        pass
    result = pixel_model().detach().to("cpu")
    return result
        

# Other prompts that were tried (swap one in as `text` below):
#   "a house in the woods"
#   "deep in the woods"
#   "blood splattered wall"
#   "cracked stone surface"
#   "close-up of a smiling face"
#   "yellow square on blue background"
#   "lava river between black rock"
#   "the sphere of planet earth in front of a black background"
#   "checkerboard texture"
#   "moss and stone"
#   "stars in the sky"
#   "deep space photography"
#   "evil spiderweb"
#   "weapons of mass destruction"
#   "organic structures"
#   "Bob Dobbs"
#   "top-down view of a river between meadows"
#   "pile of guts"
#   "stone texture"
#   "top-down view of a city maze"
#   "knotted ropes"
#   "an unfriendly stone texture"
#   "love & peace pattern"
#   "friendly pattern"
#
# Other pixel models that were tried:
#   PixelModel((3, 224, 224)).to("cuda")
#   DecoderModelHxW(dalle_decoder, 128, (4, 4), std=0.5)
VF.to_pil_image(
    visualize_wang(
        pixel_model=DecoderModel(dalle_decoder, 128, std=.1),
        text="underwater cobble texture",
        iters=1000,
        batch_size=2,
        lr=1.5,
    )
)

In [None]:
# NOTE(review): IPython-only help lookup (`?` suffix) left over from interactive
# use; it is not valid plain Python and can be removed from the finished notebook.
VT.Pad?

In [None]:
# NOTE(review): bare name with no call or assignment; this raises NameError
# unless one of the star-imports defines `cli`. Looks like an unfinished
# cell — TODO confirm and remove.
cli