In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import to_torch_device
from src.models.cnn import *
from src.models.encoder import EncoderConv2d

In [None]:
class EncoderConv2dLOCAL(nn.Module):
    """
    Local prototype of a convolutional image encoder.

    A ``Conv2dBlock`` is followed by one ``nn.Linear`` layer that maps the
    flattened feature map to a vector of ``code_size`` values.

    The constructor arguments are returned by ``get_extra_state`` and therefore
    stored inside ``state_dict()``. ``from_torch`` uses them to rebuild the module.
    """

    def __init__(
            self,
            shape: Tuple[int, int, int],
            kernel_size: int = 5,
            channels: Iterable[int] = (16, 32),
            code_size: int = 1024,
            act_fn: Optional[nn.Module] = nn.ReLU(),
    ):
        """
        :param shape: input shape as (channels, height, width); ``shape[0]`` is
            used as the number of input channels of the first convolution
        :param kernel_size: convolution kernel size passed to ``Conv2dBlock``
        :param channels: output channels of each convolution layer
        :param code_size: length of the produced code vector
        :param act_fn: activation module passed to ``Conv2dBlock``.
            NOTE(review): the default ``nn.ReLU()`` instance is created once at
            class-definition time and shared by every encoder that uses it.
        """
        super().__init__()
        self.shape = tuple(shape)
        self.channels = tuple(channels)
        self.kernel_size = int(kernel_size)
        self.code_size = int(code_size)
        # act_fn is not kept on self; get_extra_state() reads it back from the
        # convolution block

        self.convolution = Conv2dBlock(
            channels=[self.shape[0], *self.channels],
            kernel_size=self.kernel_size,
            act_fn=act_fn,
        )
        encoded_shape = self.convolution.get_output_shape(self.shape)
        self.linear = nn.Linear(math.prod(encoded_shape), self.code_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch ``x`` (batch dimension first) into codes of size ``code_size``."""
        return self.linear(self.convolution(x).flatten(1))

    def get_extra_state(self):
        """
        Return the constructor arguments so that they are saved in ``state_dict()``.

        NOTE(review): ``act_fn`` is an ``nn.Module`` object. Files containing it
        need full unpickling to load, not a weights-only load.
        """
        return {
            "shape": self.shape,
            "kernel_size": self.kernel_size,
            "channels": self.channels,
            "code_size": self.code_size,
            "act_fn": self.convolution._act_fn,
        }

    def set_extra_state(self, state):
        """Intentionally a no-op: the arguments are already consumed by ``from_torch``."""
        pass

    @classmethod
    def from_torch(cls, f, device: Optional[Union[str, torch.device]] = None):
        """
        Rebuild an encoder from a ``state_dict`` or from a file written with ``torch.save``.

        :param f: a state dict (``dict``/``OrderedDict``), or anything ``torch.load`` accepts
        :param device: optional target device. It is used as ``map_location`` when
            loading from a file, and the returned model is moved there.
        :return: a new encoder with the saved weights loaded
        """
        if isinstance(f, dict):  # OrderedDict is a dict subclass
            data = f
        else:
            # NOTE(review): torch.load unpickles the file, so only load trusted files
            data = torch.load(f, map_location=device)

        extra = data["_extra_state"]
        model = cls(
            shape=extra["shape"],
            kernel_size=extra["kernel_size"],
            channels=extra["channels"],
            code_size=extra["code_size"],
            act_fn=extra["act_fn"],
        )
        model.load_state_dict(data)
        if device is not None:
            model = model.to(device)
        return model
    
# Sanity check: an EncoderConv2d rebuilt from its own state_dict must produce
# the same output as the original.
enc = EncoderConv2d((1, 32, 32), kernel_size=7)
with torch.no_grad():
    probe = torch.ones(1, 1, 32, 32)
    enc_reloaded = EncoderConv2d.from_torch(enc.state_dict())
    assert torch.allclose(enc(probe), enc_reloaded(probe)), "state_dict round trip changed the encoder output"
enc_reloaded

In [None]:
# Load the pretrained conv-VAE checkpoint and keep only its encoder part.
from scripts.train_autoencoder_vae import VariationalAutoencoderConv

model = VariationalAutoencoderConv((1, 32, 32), channels=[16, 24, 32], kernel_size=5, latent_dims=128)
# map_location="cpu" lets checkpoints written on a GPU load on CPU-only machines;
# load_state_dict copies the tensors into the (CPU) model either way
data = torch.load("../checkpoints/vae-all1/snapshot.pt", map_location="cpu")
print("{:,} steps".format(data["num_input_steps"]))
model.load_state_dict(data["state_dict"])
model = model.encoder
model

In [None]:
# Copy the VAE encoder's convolution weights and its mean head (linear_mu) into a
# stand-alone EncoderConv2d so it can be saved and loaded on its own.
COPY_VAE_WEIGHTS = True  # set to False to keep the randomly initialised encoder
vae_channels = [16, 24, 32]

enc = EncoderConv2d(
    shape=(1, 32, 32),
    channels=vae_channels,
    kernel_size=5,
    code_size=128,
)
if COPY_VAE_WEIGHTS:
    src_conv = model.encoder[0]
    with torch.no_grad():
        # conv layers sit at even indices of `layers`; this assumes every conv is
        # followed by exactly one activation layer — TODO confirm against Conv2dBlock
        for layer_idx in range(len(vae_channels)):
            enc.convolution.layers[layer_idx * 2].weight.copy_(src_conv.layers[layer_idx * 2].weight)
            enc.convolution.layers[layer_idx * 2].bias.copy_(src_conv.layers[layer_idx * 2].bias)
        enc.linear.weight.copy_(model.linear_mu.weight)
        enc.linear.bias.copy_(model.linear_mu.bias)
enc

In [None]:
# Plot the VAE's deterministic code next to the copied encoder's code for the
# same all-ones input; the two lines should coincide.
ones_input = torch.ones(1, *enc.shape)
with torch.no_grad():
    vae_code = model.forward(ones_input, random=False)
    enc_code = enc(ones_input)
    fig = px.line(
        torch.concat([vae_code, enc_code]).T,
        title="check that model-copying worked",
    )
fig

### Save encoder

In [None]:
# shell listing of the target directory, to see existing files before saving below
!ls -l ../models/encoder2d/

In [None]:
# Save the copied encoder; create the target directory first if it is missing,
# otherwise torch.save fails on a fresh checkout.
ENCODER_SAVE_PATH = Path("../models/encoder2d/conv-1x32x32-128-all1.pt")
ENCODER_SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(enc.state_dict(), str(ENCODER_SAVE_PATH))

# Load encoder

In [None]:
# Load a previously trained encoder onto the CPU and show which device it uses.
encoder_path = "../models/encoderconv/encoder-1x32x32-128-photo-5.pt"
encoder = EncoderConv2d.from_torch(encoder_path, device="cpu")
encoder.device

In [None]:
# Patch stream: photos -> encoder-sized patches -> ImageFilter(min_mean=.1)
# (presumably drops dark patches) -> DissimilarImage filter (presumably drops
# near-duplicates). Take one batch of 32 * 32 patches and show it as a grid.
grid_side = 32
ds = DissimilarImageIterableDataset(
    ImageFilterIterableDataset(
        make_image_patch_dataset(
            encoder.shape,
            "~/Pictures/photos",
            recursive=True,
            interleave_images=4,
            patch_shuffle=100_000,
            scales=[1., 1. / 6.],
        ),
        filter=ImageFilter(min_mean=.1),
    ),
    max_similarity=.9,
    max_age=100_000,
    verbose=True,
)
patch_loader = DataLoader(ds, batch_size=grid_side * grid_side)
patches = next(iter(patch_loader))
VF.to_pil_image(make_grid_labeled(patches[:grid_side * grid_side], nrow=grid_side))

In [None]:
# Encode all patches and L2-normalise the codes so that dot products in the
# following cells are cosine similarities. F.normalize clamps the norm at
# eps=1e-12, so an all-zero code becomes zeros instead of NaNs.
with torch.no_grad():
    # NOTE(review): codes are rounded to 5 decimals before normalising — intent unclear, kept as-is
    features = torch.round(encoder(patches.to(encoder.device)).cpu(), decimals=5)
    features = F.normalize(features, dim=1)

px.line(pd.DataFrame(features[:50]).T.copy())

In [None]:
# Cosine similarity between the first 100 (normalised) codes.
first_features = features[:100]
sim = torch.matmul(first_features, first_features.T)
px.imshow(sim, height=1000)

In [None]:
# For each of the first 100 patches, find its most similar *other* patch and show
# the pairs whose similarity exceeds the threshold.
PAIR_SIMILARITY_THRESHOLD = .5
# mask the diagonal with -inf so a patch is never matched with itself
off_diag_sim = sim.masked_fill(torch.eye(sim.shape[0], dtype=torch.bool), float("-inf"))
values, labels = off_diag_sim.max(dim=1)
grid = []
grid_labels = []
for i, (label, value) in enumerate(zip(labels, values)):
    if value > PAIR_SIMILARITY_THRESHOLD:
        grid.extend([patches[i], patches[label]])
        grid_labels.extend(["", f"{float(value):.3f}"])
if grid:
    display(VF.to_pil_image(make_grid_labeled(grid, labels=grid_labels)))
else:
    # make_grid_labeled cannot build a grid from an empty list
    print(f"no pairs with similarity > {PAIR_SIMILARITY_THRESHOLD}")

In [None]:
# Nearest neighbours in encoder space: for each of the first 50 patches show its
# 30 most similar patches (the first one is usually the patch itself, value 1.0).
big_sim = features @ features.T
values, labels = big_sim.sort(dim=1, descending=True)
grid = []
grid_labels = []
for label_row, value_row in zip(labels[:50], values[:50]):
    for idx, v in zip(label_row[:30], value_row[:30]):
        grid.append(patches[idx])
        grid_labels.append(f"{float(v):.3f}")
# only a cell's last expression is rendered, so show the labelled grid explicitly
display(VF.to_pil_image(make_grid_labeled(grid, nrow=30, labels=grid_labels)))
VF.to_pil_image(make_grid(grid, nrow=30))

In [None]:
# Baseline: the same nearest-neighbour view computed on raw pixels instead of codes.
# F.normalize clamps the norm (eps=1e-12), so an all-black patch yields zeros, not NaNs.
flat_features = F.normalize(patches.flatten(1), dim=1)
big_sim = flat_features @ flat_features.T
values, labels = big_sim.sort(dim=1, descending=True)
grid = []
for label_row in labels[:50]:
    for idx in label_row[:30]:
        grid.append(patches[idx])
VF.to_pil_image(make_grid(grid, nrow=30))

In [None]:
# leftover interactive lookup: IPython `?` shows the docstring; does not affect results
F.hinge_embedding_loss?

# Feature visualization

In [None]:
class RandomRotation(nn.Module):
    """
    Rotate an image tensor by a random angle around a randomly jittered center.

    :param degree: maximum absolute rotation in degrees; the angle is drawn
        uniformly from [-degree, degree]
    :param random_center: center jitter as a fraction of the image size.
        0 rotates around the middle pixel. 1 allows any pixel, with the
        position clamped to the image.
    """
    def __init__(self, degree: float = 10., random_center: float = 1.):
        super().__init__()
        self.degree = degree
        self.random_center = random_center

    def forward(self, x):
        angle = (torch.rand(1).item() * 2. - 1.) * self.degree
        # relative (row, column) position in [0.5 - random_center/2, 0.5 + random_center/2]
        rel_y, rel_x = ((torch.rand(2) - .5) * self.random_center + .5).tolist()
        height, width = x.shape[-2], x.shape[-1]
        # torchvision's rotate() expects the center as [x, y], i.e. (column, row) in pixels
        center = [
            max(0, min(width - 1, int(rel_x * width))),
            max(0, min(height - 1, int(rel_y * height))),
        ]
        return VF.rotate(x, angle=angle, center=center)
        
def feature_visualization(
    encoder: "EncoderConv2d",
    target: torch.Tensor,
    shape: Optional[Tuple[int, int]] = None,
    std: float = .1,
    mean: float = .5,
    num_iter: int = 10,
    batch_size: int = 5,
    lr: float = 1.,
    max_attempts: int = 10,
):
    """
    Optimise an input image so that the encoder maps it close to ``target``.

    Starts from uniform noise. Each of ``num_iter`` steps does three things:
    - blends in a blurred copy of the image;
    - builds ``batch_size`` randomly augmented views;
    - minimises the MSE between the encoder output of the views and ``target`` with Adamax.

    If the encoder output contains NaN, the attempt is abandoned and restarted from fresh noise.

    :param encoder: encoder providing ``.shape`` (C, H, W) and ``.device``
    :param target: 1-d code vector to match
    :param shape: optional (height, width) of the optimised image. When given,
        views are randomly cropped to the encoder's input size.
    :param std: scale of the initial noise (``torch.rand``, i.e. uniform, not gaussian)
    :param mean: offset added to the initial noise
    :param num_iter: number of optimisation steps per attempt
    :param batch_size: number of augmented views per step
    :param lr: learning-rate multiplier (Adamax runs with ``lr * .04``)
    :param max_attempts: number of NaN-triggered restarts allowed before giving up
    :return: the optimised image, detached, clamped to [0, 1] and moved to the CPU
    :raises RuntimeError: if the encoder produced NaN in every attempt
    """
    target = target.unsqueeze(0).expand(batch_size, -1).to(encoder.device)
    pixel_shape = encoder.shape
    if shape:
        pixel_shape = (encoder.shape[0], *shape)

    # the augmentation objects are reused for every attempt
    augmentations = [
        VT.RandomPerspective(distortion_scale=1, p=.6),
    ]
    if shape:
        augmentations.append(VT.RandomCrop(encoder.shape[-2:]))

    num_attempts = max(1, max_attempts)
    for _attempt in range(num_attempts):
        pixels = nn.Parameter(torch.rand(pixel_shape).to(encoder.device) * std + mean)
        optimizer = torch.optim.Adamax([pixels], lr=lr * .04)

        for _step in range(num_iter):
            # blend in a blurred copy each step (presumably to suppress high-frequency noise)
            with torch.no_grad():
                mix = .35
                pixels[:] = pixels * (1. - mix) + mix * VF.gaussian_blur(pixels, 5, 5)

            # batch of independently augmented views of the same image
            views = []
            for _view_idx in range(batch_size):
                view = pixels
                for aug in augmentations:
                    view = aug(view)
                views.append(view)
            pixel_batch = torch.stack(views)

            output = encoder(pixel_batch)
            if torch.isnan(output).any():
                print("NaN eNcOuNtErEd")
                break  # abandon this attempt and restart from fresh noise

            loss = F.mse_loss(output, target)

            # Clear the image gradient every step; it lives on `pixels`, which only
            # the optimizer knows about, so it would otherwise accumulate. Encoder
            # grads are cleared too because backward() also accumulates into them
            # if they require grad.
            optimizer.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # the step loop finished without hitting NaN
            return pixels.detach().clamp(0, 1).cpu()

    raise RuntimeError(f"feature_visualization: encoder output contained NaN in all {num_attempts} attempts")

# Optimise a 40x40 image towards the code of patch 12 and show it enlarged 4x
# with nearest-neighbour interpolation.
img = feature_visualization(encoder, features[12], shape=(40, 40), num_iter=35, lr=0.8, batch_size=1)
VF.to_pil_image(VF.resize(img, [s * 4 for s in img.shape[-2:]], interpolation=VF.InterpolationMode.NEAREST))

In [None]:
# size of the last visualisation as (C, H, W)
img.size()

In [None]:
# Side-by-side grid: each of the first 16 patches (resized to 64x64) is followed
# by the image optimised to reproduce its code; shown enlarged 4x.
vis_size = (64, 64)
images = []
for idx in tqdm(range(16)):
    images += [
        VF.resize(patches[idx], vis_size),
        feature_visualization(encoder, features[idx], shape=vis_size),
    ]
img = make_grid(images, nrow=4)
VF.to_pil_image(VF.resize(img, [s * 4 for s in img.shape[-2:]], interpolation=VF.InterpolationMode.NEAREST))

In [None]:
def get_window(shape: Tuple[int, int]):
    """
    Return a 2-d periodic Hamming window of size ``shape[-2] x shape[-1]``.

    The window is the outer product of a Hamming window over the rows (height)
    and one over the columns (width). Only the last two entries of ``shape`` are used.
    """
    row_window = torch.hamming_window(shape[-2], periodic=True)
    col_window = torch.hamming_window(shape[-1], periodic=True)
    return row_window[:, None] * col_window[None, :]

def reconstruct_image(
        encoder: "EncoderConv2d",
        original: torch.Tensor,
        sub_sample: float = 1,
        noise: float = 0.,
        patch_shape: Optional[Tuple[int, int]] = None,
        num_iter: int = 10,
        lr: float = 1.,
):
    """
    Rebuild an image from its encoder codes, one patch at a time.

    Steps:
    1. Cut the image into encoder-sized patches (stride = encoder size / ``sub_sample``).
    2. Encode each patch, optionally adding gaussian noise to the codes.
    3. Optimise an image for every code with ``feature_visualization``.
    4. Blend the results back together with a Hamming window.

    If the image's channel count differs from the encoder's, each channel is
    encoded and reconstructed separately. A KeyboardInterrupt stops early and
    the partial result is returned.

    :param encoder: encoder providing ``.shape`` (C, H, W) and ``.device``
    :param original: image tensor of shape (C, H, W)
    :param sub_sample: patch overlap factor; 1 means non-overlapping patches
    :param noise: std of gaussian noise added to the codes (0 disables it)
    :param patch_shape: optional (height, width) of each reconstructed patch. The
        output is scaled by ``patch_shape / encoder.shape[-2:]``. Without it the
        output has the original's size.
    :param num_iter: optimisation steps per patch, passed to ``feature_visualization``
    :param lr: learning-rate multiplier, passed to ``feature_visualization``
    :return: tensor of shape (original channels, out_height, out_width)
    """
    enc_h, enc_w = encoder.shape[-2:]
    out_patch_h, out_patch_w = patch_shape if patch_shape else (enc_h, enc_w)
    # Exact integer scaling guarantees every scaled patch fits into the output;
    # float scaling could round the output size one pixel short.
    out_shape = (
        original.shape[-2] * out_patch_h // enc_h,
        original.shape[-1] * out_patch_w // enc_w,
    )
    # stride must be a positive int per axis (a huge sub_sample must not give 0)
    stride = tuple(max(1, int(s / sub_sample)) for s in (enc_h, enc_w))

    # the output keeps the ORIGINAL's channels (channels may be encoded one by one)
    num_channels = original.shape[0]
    recon = torch.zeros(num_channels, *out_shape)
    recon_sum = torch.zeros(num_channels, *out_shape)
    window = get_window((out_patch_h, out_patch_w))

    try:
        patches = []
        positions = []  # (channel slice, output row offset, output column offset)
        for patch, pos in iter_image_patches(
            original, shape=encoder.shape[-2:], stride=stride, with_pos=True,
        ):
            pos_y, pos_x = int(pos[0]), int(pos[1])
            offset_y = pos_y * out_patch_h // enc_h
            offset_x = pos_x * out_patch_w // enc_w
            if patch.shape[0] != encoder.shape[0]:
                # channel count differs: encode every channel as its own
                # single-channel patch (assumes a 1-channel encoder — TODO confirm)
                for chan in range(patch.shape[0]):
                    patches.append(patch[chan].unsqueeze(0).unsqueeze(0))
                    positions.append((slice(chan, chan + 1), offset_y, offset_x))
            else:
                patches.append(patch.unsqueeze(0))
                positions.append((slice(0, patch.shape[0]), offset_y, offset_x))

        if not patches:
            # image smaller than one encoder patch: nothing to reconstruct
            return recon

        with torch.no_grad():
            codes = encoder(torch.concat(patches).to(encoder.device))
            if noise:
                codes = codes + noise * torch.randn_like(codes)

        for code, (chan, offset_y, offset_x) in tqdm(zip(codes, positions), total=len(positions)):
            patch_recon = feature_visualization(encoder, code, shape=patch_shape, lr=lr, num_iter=num_iter)
            rows = slice(offset_y, offset_y + patch_recon.shape[-2])
            cols = slice(offset_x, offset_x + patch_recon.shape[-1])
            recon[chan, rows, cols] += patch_recon * window
            recon_sum[chan, rows, cols] += window

    except KeyboardInterrupt:
        # deliberate: stop early and return what has been reconstructed so far
        pass

    covered = recon_sum > 0
    recon[covered] = recon[covered] / recon_sum[covered]
    return recon

# Load a test image as a single-channel tensor, shrink it 8x and reconstruct it
# patch by patch from its encoder codes.
PICTURES_DIR = Path.home() / "Pictures"
image_path = PICTURES_DIR / "csv-turing.png"
# other images tried (relative to PICTURES_DIR):
#   __diverse/28580_1.jpg
#   __diverse/merkel_sarkozy_g8_regierungOnline_Kuehler_CMS_small_620.jpeg
#   __diverse/honecker.jpg
#   __diverse/plakat01.jpg
#   DWlZbQ5WsAQEzHT.jpg
#   there_is_no_threat.jpeg
#   diffusion/cthulhu-09.jpeg
with PIL.Image.open(image_path) as image:
    # convert while the file is open; the context manager then closes the handle
    original = VF.to_tensor(image)
original = set_image_channels(original, 1)
original = VF.resize(original, [s // 8 for s in original.shape[-2:]])
display(VF.to_pil_image(original))

img = reconstruct_image(encoder, original, sub_sample=5, noise=0., patch_shape=(64, 64), num_iter=10, lr=10.)
VF.to_pil_image(img)