In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device
from src.patchdb import PatchDB, PatchDBIndex
from src.models.encoder import *

In [None]:
# IPython shell escape: list the files in the 2D encoder model directory.
!ls -l ../models/encoder2d

In [None]:
# Load the patch encoder with device="cpu". Exactly one checkpoint is active;
# the commented-out lines are alternative checkpoints.
#encoder = EncoderConv2d.from_torch("../models/encoder2d/encoder-1x32x32-128-photo-5.pt", device="cpu")
encoder = EncoderConv2d.from_torch("../models/encoder2d/conv-1x32x32-128-all1.pt", device="cpu")
#encoder = BoltzmanEncoder2d.from_torch("../models/encoder2d/boltzman-1x32x32-128-photo-300M.pt", device="cpu")
# Echo the encoder's device as the cell output.
encoder.device

In [None]:
# Open one patch database with the encoder attached. Exactly one path below is
# active; the commented-out lines are alternative databases.
db = PatchDB(
    #"../db/photos-1x32x32-2.patchdb", 
    #"../db/photos-bigpatch-1x32x32.patchdb",
    #"../db/kali-1x32x32.patchdb",
    #"../db/diverse-1x32x32-3b.patchdb",
    #"../db/diverse-1x32x32-3b-rbm-300M.patchdb",
    #"../db/kali-1x64x64-rbm300M.patchdb",
    #"../db/bob-1x32x32-convall1.patchdb",
    #"../db/fjord-1x32x32-convall1.patchdb",
    #"../db/serifs-1x32x32-convall1.patchdb",
    #"../db/hyperplane-1x32x32-convall1.patchdb",
    #"../db/hyperplane-inv-1x32x32-convall1.patchdb",
    "../db/topping-1x32x32-convall1.patchdb",
    #"../db/cells-1x32x32-convall1.patchdb",
    #"../db/sand-1x32x32-convall1.patchdb",
    
    encoder=encoder,
    # NOTE(review): `limit` presumably caps how many patches are read -- confirm against PatchDB.
    verbose=True, limit=1_000_000, 
)
# Build the index object used by the cells below, then report how many patches
# and distinct image filenames it holds.
index = db.index()
print(f"{index.size} patches, {len(index.filenames())} images")

# random sample

In [None]:
# 30 x 30 contact sheet of patches drawn uniformly at random (with replacement)
# from the index; the resulting PIL image is the cell output.
VF.to_pil_image(
    make_grid(
        [
            index.patches[patch_id].patch
            for patch_id in (random.randrange(index.size) for _ in range(30 * 30))
        ],
        nrow=30,
    )
)

# samples per image

In [None]:
# Bucket every indexed patch under the filename of the image it came from.
samples_per_image = {}
for patch in index.patches:
    samples_per_image.setdefault(patch.filename, []).append(patch)

# For each source image (in sorted filename order) print its name and show one
# row of 30 patches drawn at random, with replacement, from that image's bucket.
for filename in sorted(samples_per_image.keys()):
    print(filename)
    patches = samples_per_image[filename]
    display(
        VF.to_pil_image(
            make_grid(
                [patches[random.randrange(len(patches))].patch for _ in range(30)],
                nrow=30,
            )
        )
    )

# random similars

In [None]:
def plot_random_similars(index, encoder, count=30, count_sim=50):
    """
    Render a grid with one column per randomly chosen query patch.

    Row ``y`` of column ``x`` shows the ``y``-th entry of the first result list
    that ``index.similar_patches`` returns for the ``x``-th query patch.

    :param index: patch index providing ``size``, ``patches`` and ``similar_patches``
    :param encoder: accepted for call compatibility; not used by this function
    :param count: number of random query patches, i.e. grid columns
    :param count_sim: number of similar patches requested per query, i.e. grid rows
    :return: PIL image of the assembled grid
    """
    shuffled_ids = list(range(index.size))
    random.shuffle(shuffled_ids)

    # One result list per query patch; the progress bar covers the lookups.
    columns = [
        index.similar_patches(index.patches[query_id].patch, count=count_sim)[0]
        for query_id in tqdm(shuffled_ids[:count])
    ]

    # make_grid lays tiles out row by row, so emit rank 0 of every column first,
    # then rank 1 of every column, and so on.
    tiles = [column[rank].patch for rank in range(count_sim) for column in columns]

    return VF.to_pil_image(make_grid(tiles, nrow=count))

# Render the similar-patch grid for random query patches, requesting 30 similar
# patches per query instead of the default 50. The image is the cell output;
# the trailing comment is a disabled save call.
img = plot_random_similars(index, encoder, count_sim=30)
img#.save("/home/bergi/Pictures/random-similars-patchdb-photos6M.png")

In [None]:
def plot_similars(index, patch_or_id, count=30, nrow=30):
    """
    Render the patches that ``index.similar_patches`` returns for one query.

    :param index: patch index providing ``patches`` and ``similar_patches``
    :param patch_or_id: an ``int`` is treated as a position in ``index.patches``;
        anything else is passed to the lookup as the query itself
    :param count: number of similar patches to request
    :param nrow: number of tiles per grid row
    :return: PIL image of the assembled grid
    """
    query = index.patches[patch_or_id].patch if isinstance(patch_or_id, int) else patch_or_id

    # The lookup returns one result list per query; there is a single query here.
    matches = index.similar_patches(query, count=count)[0]

    return VF.to_pil_image(make_grid([match.patch for match in matches], nrow=nrow))

# Show the 2500 patches returned for the patch at index position 23 as a 50-wide grid.
# The image is the cell output; the trailing comment is a disabled save call.
img = plot_similars(index, 23, count=50*50, nrow=50)
img#.save("/home/bergi/Pictures/random-similars-patchdb-photos6M.png")

# image reconstruction

In [None]:
def get_window_2d(shape: Tuple[int, int]):
    """
    Build a separable 2D Hamming window.

    The result is the outer product of a periodic Hamming window of length
    ``shape[-2]`` (varying along rows) and one of length ``shape[-1]``
    (varying along columns).

    :param shape: sequence whose last two entries are (height, width)
    :return: tensor of shape ``(shape[-2], shape[-1])``
    """
    height, width = shape[-2], shape[-1]
    row_profile = torch.hamming_window(height, periodic=True)
    col_profile = torch.hamming_window(width, periodic=True)
    # Broadcasting (1, W) * (H, 1) -> (H, W)
    return col_profile.unsqueeze(0) * row_profile.unsqueeze(1)

def reconstruct_image(
        original: torch.Tensor,
        index: PatchDBIndex,
        sample_patch_shape: Optional[Tuple[int, int]] = None,
        sub_sample: float = 1,
        multi_sample: int = 1,  # only meaningful with random_first_patches>1
        noise: float = 0.,
        random_first_patches: int = 1,
        padding: Union[None, int, Iterable[int]] = None,
        fill: int = 0,
):
    """
    Rebuild an image from database patches that are similar to its own patches.

    The image is cut into patches with ``iter_image_patches``, the patches are
    encoded in batches of 512, and for every feature vector
    ``index.similar_patches`` is asked for ``random_first_patches`` results.
    Randomly chosen results are pasted at the source positions, weighted by a
    2D Hamming window, and the accumulated image is divided by the accumulated
    window weights wherever those are positive.

    :param original: image tensor; indexed as (channel, y, x) when pasting
    :param index: patch index; ``index.db.encoder`` must be set
    :param sample_patch_shape: (height, width) of the patches cut from
        ``original``; defaults to the last two entries of ``encoder.shape``
    :param sub_sample: the patch stride is ``patch_size / sub_sample`` per axis
        (truncated, at least 1); must be positive
    :param multi_sample: how many random picks are pasted per source patch
    :param noise: if non-zero, standard-normal noise scaled by this factor is
        added to the features before the lookup
    :param random_first_patches: number of similar patches requested per feature;
        each pasted patch is drawn at random from these
    :param padding: passed to ``torch.nn.functional.pad`` for the canvas and to
        ``iter_image_patches``; an int is expanded to four equal values
    :param fill: pad value for the canvas, also passed to ``iter_image_patches``
    :return: reconstructed tensor with the (padded) shape of ``original``
    :raises ValueError: if ``sub_sample`` is not positive or no patch could be
        extracted from ``original``

    A ``KeyboardInterrupt`` during processing is swallowed on purpose so that
    whatever has been accumulated so far is still normalized and returned.
    """
    assert index.db.encoder, "Must be defined"
    encoder = index.db.encoder

    if sub_sample <= 0:
        raise ValueError(f"sub_sample must be > 0, got {sub_sample}")

    if sample_patch_shape is None:
        sample_patch_shape = encoder.shape[-2:]
        _scale = (1., 1.)
    else:
        # NOTE(review): when sample_patch_shape differs from the encoder shape the
        # window below has the sample shape while the pasted patches keep whatever
        # shape the database patches have; this path looks unfinished -- confirm
        # the intended semantics of `_scale` before relying on it.
        _scale = (
            sample_patch_shape[-2] / encoder.shape[-2],
            sample_patch_shape[-1] / encoder.shape[-1],
        )

    if isinstance(padding, int):
        padding = [padding] * 4
    elif padding is not None and not isinstance(padding, (list, tuple)):
        # The padding is consumed twice (canvas + patch iterator), so a one-shot
        # iterable must be materialized first.
        padding = list(padding)

    # A concrete tuple instead of a single-use generator; never let the stride
    # drop to zero when sub_sample exceeds the patch size.
    stride = tuple(max(1, int(s / sub_sample)) for s in sample_patch_shape)

    recon = torch.zeros(original.shape)
    recon_sum = torch.zeros(original.shape)
    if padding:
        recon = F.pad(recon, padding, value=fill)
        # NOTE(review): the weight accumulator is pre-filled with the pixel fill
        # value as well; for fill values other than 0 or 1 this skews the
        # normalization in the padded border -- confirm this is intended.
        recon_sum = F.pad(recon_sum, padding, value=fill)
    window = get_window_2d(sample_patch_shape)

    try:
        patches = []
        positions = []
        for patch, pos in iter_image_patches(
            original,
            shape=sample_patch_shape,
            stride=stride,
            with_pos=True,
            padding=padding,
            fill=fill,
        ):
            pos = [int(p) for p in pos]
            if patch.shape[0] != encoder.shape[0]:
                # Channel count does not match the encoder: handle every channel
                # as an independent single-channel patch.
                for chan in range(patch.shape[0]):
                    patches.append(patch[chan].unsqueeze(0).unsqueeze(0))
                    positions.append([chan] + pos)
            else:
                patches.append(patch.unsqueeze(0))
                positions.append([slice(0, patch.shape[0])] + pos)

        if not patches:
            raise ValueError(
                f"no patches of shape {tuple(sample_patch_shape)} could be extracted "
                f"from an image of shape {tuple(original.shape)}"
            )

        similar_patches = []
        with torch.no_grad():
            with tqdm(desc="encoding/searching patches", total=len(patches)) as progress:
                for patch_batch, in DataLoader(TensorDataset(torch.concat(patches)), batch_size=512):
                    feature_batch = encoder.encode_image(patch_batch)

                    if noise:
                        feature_batch = feature_batch + noise * torch.randn_like(feature_batch)

                    # One result list per feature vector, kept in patch order.
                    similar_patches += index.similar_patches(feature_batch, count=random_first_patches)

                    progress.update(len(feature_batch))

        for pos, sim_patches in tqdm(zip(positions, similar_patches), total=len(positions)):
            chan, pos = pos[0], pos[1:]
            top = int(pos[0] * _scale[0])
            left = int(pos[1] * _scale[1])

            for _ in range(multi_sample):
                patch_recon = sim_patches[random.randrange(len(sim_patches))].patch
                s1 = chan
                s2 = slice(top, top + patch_recon.shape[-2])
                s3 = slice(left, left + patch_recon.shape[-1])
                # Window-weighted accumulation; normalized after the loop.
                recon[s1, s2, s3] = recon[s1, s2, s3] + patch_recon * window
                recon_sum[s1, s2, s3] = recon_sum[s1, s2, s3] + window

    except KeyboardInterrupt:
        # Deliberate best effort: return the partial reconstruction.
        pass

    # Only divide where something was accumulated to avoid division by zero.
    mask = recon_sum > 0
    recon[mask] = recon[mask] / recon_sum[mask]
    return recon

# Load the source image. Exactly one path is active; the commented-out lines are
# alternative inputs. NOTE(review): these are absolute paths on one machine.
original = PIL.Image.open(
    #"/home/bergi/Pictures/csv-turing.png"
    #"/home/bergi/Pictures/__diverse/28580_1.jpg"
    #"/home/bergi/Pictures/__diverse/merkel_sarkozy_g8_regierungOnline_Kuehler_CMS_small_620.jpeg"
    #"/home/bergi/Pictures/__diverse/honecker.jpg"
    #"/home/bergi/Pictures/__diverse/plakat01.jpg"
    #"/home/bergi/Pictures/DWlZbQ5WsAQEzHT.jpg"
    #"/home/bergi/Pictures/there_is_no_threat.jpeg"
    #"/home/bergi/Pictures/diffusion/cthulhu-09.jpeg"
    #"/home/bergi/Pictures/__diverse/HourSlack-EarsBleed-2-72SM.jpg"
    #"/home/bergi/Pictures/__diverse/parental_advisory.jpg"
    #"/home/bergi/Pictures/__diverse/BP_Logo_01.png"
    #"/home/bergi/Pictures/__diverse/994066096.jpg"
    #"/home/bergi/Pictures/__diverse/EuropeAfricaNato.jpg"
    #"/home/bergi/Pictures/__diverse/nobluesky2.jpg"
    #"/home/bergi/Pictures/__diverse/Zeisswerk_Jena_um_1910.jpg"
    #"/home/bergi/Pictures/__diverse/1907_Panic.png"
    #"/home/bergi/Pictures/__diverse/schildkroe.bmp"
    #"/home/bergi/Pictures/__diverse/nk_hammer+brush+sickle.jpg"
    #"/home/bergi/Pictures/__diverse/Klaus_Naumann.jpg"
    #"/home/bergi/Pictures/__diverse/cheney2.jpg"
    #"/home/bergi/Pictures/__diverse/1233217389_800.jpg"
    "/home/bergi/Pictures/IMG-20210108-WA0001.jpg"
)
# Convert to a tensor, then to one channel.
# NOTE(review): set_image_channels is a project helper; presumably it converts
# the image to the given channel count -- confirm.
original = VF.to_tensor(original)
original = set_image_channels(original, 1)
# Shrink height and width by a factor of 1.5 (truncated to int).
original = VF.resize(original, [int(s / 1.5) for s in original.shape[-2:]])

# Reconstruct with a stride of half a patch (sub_sample=2), pasting one pick
# per source patch drawn from 100 requested similar patches.
img = reconstruct_image(
    original, index, sub_sample=2, 
    #noise=.0005, 
    random_first_patches=100,
    multi_sample=1,
    #sample_patch_shape=(),
    #padding=5, fill=1,
)
# Show the reconstruction first, then the source image for comparison.
display(VF.to_pil_image(img))
display(VF.to_pil_image(original))

In [None]:
# Show `img` with its values inverted (1 - value).
VF.to_pil_image(1. - img)

# reconstruct at random positions

In [None]:
@torch.no_grad()
def reconstruct_image_random(
        original: torch.Tensor,
        index: PatchDBIndex,
        num_patches: int = 1000,
        noise: float = 0.,
        random_first_patches: int = 1,
        batch_size: int = 200,
        mix: float = .5,
):
    """
    Rebuild an image by blending database patches over random square regions.

    For ``num_patches`` random squares of ``original`` the crop is resized to
    the encoder's patch size and encoded, ``index.similar_patches`` is queried,
    and a randomly chosen result is resized back to the square and blended into
    channel 0 of the canvas with weight ``mix``. Later squares are blended over
    earlier ones.

    :param original: image tensor; only channel 0 of the canvas is written
    :param index: patch index; ``index.db.encoder`` must be set
    :param num_patches: number of random squares to process
    :param noise: if non-zero, standard-normal noise scaled by this factor is
        added to the features before the lookup
    :param random_first_patches: passed to ``index.similar_patches`` as its
        second argument; the pasted patch is drawn at random from the results
    :param batch_size: number of crops encoded and looked up per batch
    :param mix: blend weight of the pasted patch, ``canvas * (1 - mix) + mix * patch``
    :return: tensor with the shape of ``original``

    A ``KeyboardInterrupt`` during processing is swallowed on purpose so that
    the partially blended canvas is returned.
    """
    assert index.db.encoder, "Must be defined"
    encoder = index.db.encoder

    recon = torch.zeros(original.shape)

    # Squares are at most a third of the image height, but must also fit into
    # the width; the lower bound is half the encoder patch height, clamped so
    # the range is never empty (small or very narrow images).
    max_side = max(1, min(original.shape[-2] // 3, original.shape[-1]))
    min_side = max(1, min(encoder.shape[-2] // 2, max_side))

    def iter_patch_batches():
        """Yield (crops resized to the encoder shape, [top, left, height, width] per crop)."""
        patch_batch = []
        pos_batch = []
        for _ in tqdm(range(num_patches)):
            s = random.randint(min_side, max_side)
            rect = [s, s]
            pos = [
                random.randint(0, recon.shape[-2] - rect[-2]),
                random.randint(0, recon.shape[-1] - rect[-1])
            ]
            patch = VF.resize(VF.crop(original, *pos, *rect), encoder.shape[-2:])

            patch_batch.append(patch.unsqueeze(0))
            pos_batch.append(pos + rect)

            if len(patch_batch) >= batch_size:
                yield torch.concat(patch_batch), pos_batch
                # Fresh lists, so the batch handed to the consumer is never mutated.
                patch_batch = []
                pos_batch = []

        if patch_batch:
            yield torch.concat(patch_batch), pos_batch

    try:
        for patch_batch, pos_batch in iter_patch_batches():
            feature_batch = encoder.encode_image(patch_batch)

            if noise:
                feature_batch = feature_batch + noise * torch.randn_like(feature_batch)

            sim_patches_batch = index.similar_patches(feature_batch, random_first_patches)

            for (top, left, height, width), sim_patches in zip(pos_batch, sim_patches_batch):
                patch_recon = sim_patches[random.randrange(len(sim_patches))].patch
                patch_recon = VF.resize(patch_recon, [height, width])

                region = (0, slice(top, top + height), slice(left, left + width))
                recon[region] = recon[region] * (1. - mix) + mix * patch_recon

    except KeyboardInterrupt:
        # Deliberate best effort: return the partial reconstruction.
        pass

    return recon

# Load the source image. Exactly one path is active; the commented-out lines are
# alternative inputs. NOTE(review): these are absolute paths on one machine.
original = PIL.Image.open(
    "/home/bergi/Pictures/csv-turing.png"
    #"/home/bergi/Pictures/__diverse/28580_1.jpg"
    #"/home/bergi/Pictures/__diverse/merkel_sarkozy_g8_regierungOnline_Kuehler_CMS_small_620.jpeg"
    #"/home/bergi/Pictures/__diverse/honecker.jpg"
    #"/home/bergi/Pictures/__diverse/plakat01.jpg"
    #"/home/bergi/Pictures/DWlZbQ5WsAQEzHT.jpg"
    #"/home/bergi/Pictures/there_is_no_threat.jpeg"
    #"/home/bergi/Pictures/diffusion/cthulhu-09.jpeg"
    #"/home/bergi/Pictures/__diverse/HourSlack-EarsBleed-2-72SM.jpg"
    #"/home/bergi/Pictures/__diverse/parental_advisory.jpg"
    #"/home/bergi/Pictures/__diverse/BP_Logo_01.png"
    #"/home/bergi/Pictures/__diverse/994066096.jpg"
    #"/home/bergi/Pictures/__diverse/EuropeAfricaNato.jpg"
    #"/home/bergi/Pictures/__diverse/nobluesky2.jpg"
    #"/home/bergi/Pictures/__diverse/Zeisswerk_Jena_um_1910.jpg"
    #"/home/bergi/Pictures/__diverse/1907_Panic.png"
    #"/home/bergi/Pictures/__diverse/schildkroe.bmp"
    #"/home/bergi/Pictures/__diverse/nk_hammer+brush+sickle.jpg"
    #"/home/bergi/Pictures/__diverse/Klaus_Naumann.jpg"
    #"/home/bergi/Pictures/__diverse/cheney2.jpg"
    #"/home/bergi/Pictures/__diverse/1233217389_800.jpg"
)
# Convert to a tensor, then to one channel.
# NOTE(review): set_image_channels is a project helper; presumably it converts
# the image to the given channel count -- confirm.
original = VF.to_tensor(original)
original = set_image_channels(original, 1)
# The scale factor is 1, so the target size equals the current size; change the
# factor to rescale the input.
original = VF.resize(original, [s * 1 for s in original.shape[-2:]])

# Blend database patches over 1000 random square regions with weight 0.5.
img = reconstruct_image_random(
    original, index, 
    num_patches=1000,
    #noise=.002, 
    mix=.5,
    #random_first_patches=100,
)
# Show the reconstruction first, then the source image for comparison.
display(VF.to_pil_image(img))
display(VF.to_pil_image(original))

# extend by patches

In [None]:
def _extend_patch_image_once(image, index, mix, num_rand):
    """Pad `image` by reflection once and blend database patches over its border tiles."""
    p_shape = index.db.patch_shape
    patch_h, patch_w = p_shape[-2], p_shape[-1]

    # NOTE(review): for a 2-element padding torchvision applies the first value
    # to left/right and the second to top/bottom, so the height-derived value is
    # used horizontally here -- confirm this is intended for non-square patches.
    padding = (patch_h // 4, patch_w // 3)
    new_image = VF.pad(image, padding, padding_mode="reflect")
    height, width = new_image.shape[-2], new_image.shape[-1]

    # Patch-sized tiles along the top and bottom edges, then along the left and
    # right edges; only tiles that fit completely are used.
    border_slices = []
    for x in range(0, width - patch_w + 1, patch_w):
        border_slices.append((slice(0, patch_h), slice(x, x + patch_w)))
        border_slices.append((slice(-patch_h, None), slice(x, x + patch_w)))
    for y in range(0, height - patch_h + 1, patch_h):
        border_slices.append((slice(y, y + patch_h), slice(0, patch_w)))
        border_slices.append((slice(y, y + patch_h), slice(-patch_w, None)))

    if not border_slices:
        # The padded image is smaller than one patch: nothing to blend.
        return new_image

    # All source tiles are read before any tile is overwritten.
    source_patches = [new_image[..., s1, s2].unsqueeze(0) for s1, s2 in border_slices]

    sim_patches = index.similar_patches(torch.concat(source_patches), num_rand)
    for (s1, s2), candidates in zip(border_slices, sim_patches):
        chosen = candidates[random.randrange(len(candidates))]
        new_image[..., s1, s2] = new_image[..., s1, s2] * (1. - mix) + mix * chosen.patch
    return new_image


def extend_patch_image(image, index, iters=1, mix=1., num_rand=1):
    """
    Grow an image outwards by repeatedly padding it and re-texturing its border.

    Each iteration pads the image by reflection (a quarter / a third of the
    patch size), looks up similar database patches for every patch-sized tile
    along the four edges and blends one randomly chosen result per tile into
    the image with weight ``mix``.

    :param image: image tensor whose last two dimensions are (height, width)
    :param index: patch index providing ``db.patch_shape`` and ``similar_patches``
    :param iters: number of iterations; at least one iteration is always run
    :param mix: blend weight of the pasted patch, ``tile * (1 - mix) + mix * patch``
    :param num_rand: passed to ``index.similar_patches`` as its second argument
        on every iteration; the pasted patch is drawn at random from the results
    :return: the extended image tensor

    The iteration number is printed for every iteration after the first.
    """
    for step in range(1, max(1, iters) + 1):
        if step > 1:
            print(step, end=", ")
        image = _extend_patch_image_once(image, index, mix=mix, num_rand=num_rand)
    return image

# Seed: the patch at index position 57 resized to 64x64; extended for 35
# iterations with mix=0.1 and num_rand=1000. The PIL image is the cell output.
ex = extend_patch_image(VF.resize(index.patches[57].patch, (64, 64)), index, 35, mix=.1, num_rand=1000)
VF.to_pil_image(ex)