In [None]:
import sys
sys.path.append("..")

import random
import math
import time
import json
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *
from src.models.transform import *
from src.models.util import *
from experiments import datasets

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """Scale the last two (spatial) dimensions of ``img`` by ``scale``.

    Each target dimension is truncated to an int and never drops below 1.
    Antialiasing is disabled, so with the default NEAREST mode pixels are
    simply repeated or dropped.
    """
    def _scaled(dim):
        return max(1, int(dim * scale))

    target_size = list(map(_scaled, img.shape[-2:]))
    return VF.resize(img, target_size, mode, antialias=False)

In [None]:
# Source image for the experiment; alternative inputs are kept commented out.
# NOTE(review): absolute, machine-specific path — point this at a local image before running.
IMAGE_PATH = Path(
    #"/home/bergi/Pictures/diffusion/cthulhu-06.jpeg"
    #"/home/bergi/Pictures/__diverse/HourSlack-EarsBleed-2-72SM.jpg"
    #"/home/bergi/Pictures/__diverse/2000_subgenius_bobco_primer.jpg"
    "/home/bergi/Pictures/__diverse/longtime.png"
)
# `with` releases the file handle as soon as the image is decoded.
# convert("L") makes it single-channel grayscale, so `image` is a (1, H, W) float tensor.
with PIL.Image.open(IMAGE_PATH) as pil_image:
    image = VF.to_tensor(pil_image.convert("L"))
VF.to_pil_image(image)


In [None]:
class ImageMatchingPersuit:
    """Learns a dictionary of small image "atoms" from random patches of an image.

    Loosely inspired by matching pursuit. Random image patches are compared
    against every atom via dot products. One of the better-matching atoms is
    then moved towards each patch by ``atom += learnrate * (patch - atom)``.

    The class name keeps its original spelling ("Persuit") so existing callers
    keep working.
    """

    def __init__(
            self,
            atom_shape: Tuple[int, int, int],
            size: int,
    ):
        """
        :param atom_shape: shape of one atom, indexed as ``(C, H, W)``. ``C * H * W``
            must equal the size of a flattened image patch (see ``get_atom_matches``).
        :param size: number of atoms in the dictionary.
        """
        self.atom_shape = atom_shape
        # tiny random init so the atoms start out distinct but near zero
        self.atoms = torch.rand(size, *atom_shape) * 0.001
        # number of times each atom was updated by `digest_image`
        self.atom_counts = torch.zeros(size, dtype=torch.int64)

    def _random_position(self, image: torch.Tensor) -> Tuple[int, int]:
        """Return a uniformly random top-left corner ``(y, x)`` of an atom-sized window
        that lies completely inside ``image``.

        Both bounds are inclusive, so windows touching the bottom/right border can be
        drawn. An image exactly as large as an atom always yields ``(0, 0)``.

        :raises ValueError: if ``image`` is smaller than an atom in either spatial dimension.
        """
        max_y = image.shape[-2] - self.atom_shape[-2]
        max_x = image.shape[-1] - self.atom_shape[-1]
        if max_y < 0 or max_x < 0:
            raise ValueError(
                f"image of size {tuple(image.shape[-2:])} is smaller than atom size {tuple(self.atom_shape[-2:])}"
            )
        return random.randint(0, max_y), random.randint(0, max_x)

    def get_random_image_patches(
            self, image: torch.Tensor, count: int,
    ) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
        """Cut ``count`` atom-sized patches at random positions out of ``image``.

        :param image: expected to be a 3-d ``(C, H, W)`` tensor (the slicing below assumes that rank).
        :param count: number of patches to cut; 0 yields an empty batch.
        :return: ``(patches, positions)``. ``patches`` has shape ``(count, C, atom_H, atom_W)``
            and ``positions`` is the list of ``(y, x)`` top-left corners.
        """
        height, width = self.atom_shape[-2], self.atom_shape[-1]
        positions = [self._random_position(image) for _ in range(count)]
        if not positions:
            # torch.concat() refuses an empty list, so build the empty batch explicitly
            return image.new_empty((0, image.shape[0], height, width)), positions

        patches = [image[None, :, y: y + height, x: x + width] for y, x in positions]
        return torch.concat(patches), positions

    def get_atom_matches(self, patches: torch.Tensor) -> torch.Tensor:
        """Dot product of every flattened patch with every flattened atom.

        :return: tensor of shape ``[P, A]`` (patches x atoms).
        """
        return patches.flatten(1) @ self.atoms.flatten(1).T

    def subtract_random_atoms(self, image: torch.Tensor, count: int, weight: float = 1.):
        """Subtract ``count`` randomly chosen atoms, each scaled by ``weight``, at random
        positions from a copy of ``image``.

        :return: ``(new_image, positions, atom_ids)``. ``image`` itself is not modified.
        """
        image = image.clone()
        height, width = self.atom_shape[-2], self.atom_shape[-1]
        positions = []
        atom_ids = []
        for _ in range(count):
            y, x = self._random_position(image)
            atom_idx = random.randrange(self.atoms.shape[0])

            image[:, y: y + height, x: x + width] -= self.atoms[atom_idx] * weight

            positions.append((y, x))
            atom_ids.append(atom_idx)

        return image, positions, atom_ids

    def random_trial_learning(
            self,
            image: torch.Tensor,
            iterations: int = 100,
            patches_per_trial: int = 10,
            learnrate=0.0001,
    ):
        """Run ``iterations`` random trials of ``subtract_random_atoms`` and learn from the best one.

        The trial whose residual image has the lowest mean wins. Each atom used in it
        is moved towards the image patch under its position.

        NOTE(review): ``residual.mean()`` rewards removing as much intensity as possible
        rather than a close fit. An absolute or squared error may be intended — confirm.

        :return: ``(loss, residual_image, positions, atom_ids)`` of the best trial.
        :raises ValueError: if ``iterations < 1``.
        """
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")

        # Only the best trial is kept instead of storing every residual image.
        # Strict `<` keeps the earliest trial on ties, matching a stable sort.
        best = None
        for _ in range(iterations):
            sub_image, positions, atom_ids = self.subtract_random_atoms(image, patches_per_trial)
            loss = sub_image.mean()
            if best is None or loss < best[0]:
                best = (loss, sub_image, positions, atom_ids)

        height, width = self.atom_shape[-2], self.atom_shape[-1]
        _, _, positions, atom_ids = best
        for (y, x), atom_id in zip(positions, atom_ids):
            patch = image[:, y: y + height, x: x + width]
            self.atoms[atom_id] += learnrate * (patch - self.atoms[atom_id])

        return best

    def digest_image(
            self,
            image: torch.Tensor,
            iterations: int = 100,
            patches_per_iter: int = 1000,
            learnrate=0.001,
    ):
        """Train the atoms on random patches of ``image``.

        Every iteration cuts ``patches_per_iter`` patches and makes each one zero-mean.
        Each patch then updates one atom, drawn uniformly from the better-matching half
        of the dictionary (at least one candidate). That atom is moved towards the patch
        with rate ``learnrate`` and its entry in ``atom_counts`` is incremented.

        :return: the zero-mean copy of ``image`` that was sampled from; it is not modified here.
        """
        image = image - image.mean()  # zero-mean copy; the caller's tensor is untouched
        # max(1, ...) keeps randrange valid for dictionaries with fewer than two atoms
        num_candidates = max(1, self.atoms.shape[0] // 2)

        for _ in tqdm(range(iterations)):
            patches, _ = self.get_random_image_patches(image, patches_per_iter)
            # subtract every patch's own mean before matching
            patches = patches - patches.flatten(1).mean(1)[:, None, None, None]
            matches = self.get_atom_matches(patches)
            if matches.numel():
                # Scaling by a positive constant leaves the ranking unchanged.
                # The clamp avoids 0/0 -> NaN when every patch is flat.
                matches = matches / matches.abs().max().clamp_min(1e-12)
            best_indices = matches.argsort(dim=-1, descending=True)

            for patch, best_atom_ids in zip(patches, best_indices):
                atom_id = best_atom_ids[random.randrange(num_candidates)]
                # patch is already zero-mean; re-centering only absorbs rounding error
                self.atoms[atom_id] += learnrate * ((patch - patch.mean()) - self.atoms[atom_id])
                self.atom_counts[atom_id] += 1

        return image

    def display_atoms(self, scale: float = 1., normalize: bool = True):
        """Show all atoms as a single image grid in the notebook output.

        :param scale: zoom factor passed to ``resize`` (nearest-neighbour by default).
        :param normalize: forwarded to ``make_grid``.
        """
        # make_grid's `nrow` is the number of images per row; ceil(sqrt(n)) gives a near-square grid
        per_row = int(math.ceil(math.sqrt(self.atoms.shape[0])))
        grid = make_grid(self.atoms, nrow=per_row, normalize=normalize)
        display(VF.to_pil_image(resize(grid, scale)))
        
# Fixed seeds so "Restart & Run All" reproduces the same dictionary:
# atom init uses torch, position/atom sampling uses the `random` module.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

ATOM_SHAPE = (1, 16, 16)  # (C, H, W); C=1 matches the grayscale image loaded above
NUM_ATOMS = 8 * 8         # display_atoms lays these out as an 8x8 grid

mp = ImageMatchingPersuit(ATOM_SHAPE, NUM_ATOMS)
mp.random_trial_learning(image)
dimage = mp.digest_image(image)
mp.display_atoms(scale=2)
display(px.bar(
    mp.atom_counts.numpy(),
    title="Updates per atom",
    labels={"index": "atom", "value": "updates"},
))

In [None]:
# Keep training the dictionary from the previous cell for a longer run.
# Interrupting the kernel only stops training early: the atoms learned so far
# are kept and still shown below.
try:
    dimage = mp.digest_image(image, iterations=1000)
except KeyboardInterrupt:
    pass  # intentional early stop
mp.display_atoms(scale=2)
print(mp.atom_counts)
display(px.histogram(mp.atom_counts), px.bar(mp.atom_counts))


In [None]:
# Gram matrix of the flattened atoms: entry (i, j) is the dot product of atom i and atom j.
# Large off-diagonal values indicate pairs of atoms that are similar to each other.
atoms_flat = mp.atoms.flatten(1)
px.imshow(
    (atoms_flat @ atoms_flat.T).numpy(),
    labels={"x": "atom", "y": "atom", "color": "dot product"},
    title="Atom similarity (dot products)",
)

In [None]:
# NOTE(review): scratch arithmetic with no stated purpose (a sample or patch budget?) — document or delete.
20 * 60_000