In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *

In [None]:
class LocalGreedyLibrary:
    """
    Notebook-local prototype of a greedy patch library.

    The library holds ``n_entries`` tensors of ``shape``. :meth:`fit` assigns each
    patch of a batch to its closest entry (sum of absolute differences) and moves
    that entry towards the patch by ``lr / (1 + hits)``; with ``lr=1`` every entry
    becomes the running mean of the patches assigned to it.
    """

    def __init__(self, n_entries: int, shape: Iterable[int], mean: float = 0., std: float = 0.01):
        self.shape = tuple(shape)
        self.n_entries = n_entries
        # fixed seed: repeated runs start from the same initial entries
        self.entries = mean + std * torch.randn(n_entries, *self.shape, generator=torch.Generator().manual_seed(23))
        # number of patches assigned to each entry so far
        self.hits = [0] * n_entries

    @property
    def max_hits(self) -> int:
        """Largest hit count of any entry, 0 for an empty library."""
        # `max(self.hits)`, not `max(*self.hits)`: the unpacked form raises TypeError for one entry
        return max(self.hits) if self.hits else 0

    def copy(self) -> "LocalGreedyLibrary":
        """Return an independent copy; entries and hit counts are not shared."""
        d = self.__class__(0, self.shape)
        d.n_entries = self.n_entries
        # clone(), not `[:]`: slicing a tensor returns a view, so in-place updates
        # in `fit` on the copy would otherwise also modify this library
        d.entries = self.entries.clone()
        d.hits = self.hits.copy()
        return d

    def top_entry_index(self) -> Optional[int]:
        """Index of the entry with the most hits (first one on ties), None if empty."""
        top_idx, top_hits = None, None
        for i, hits in enumerate(self.hits):
            if top_idx is None or hits > top_hits:
                top_idx, top_hits = i, hits
        return top_idx

    def entry_ranks(self, reverse: bool = False) -> List[int]:
        """
        Returns a list of ranks for each entry,
        where rank means the index sorted by number of hits.

        With ``reverse=True`` the most-hit entry gets rank 0.
        Entries with equal hits keep their index order (stable sort).
        """
        entry_ids = sorted(range(self.n_entries), key=lambda i: self.hits[i], reverse=reverse)
        # invert the permutation in O(n) instead of calling list.index per entry
        ranks = [0] * self.n_entries
        for rank, entry_id in enumerate(entry_ids):
            ranks[entry_id] = rank
        return ranks

    def entry_hits(self, reverse: bool = False) -> Dict[int, int]:
        """
        Returns a dict of `entry-index` -> `number-of-hits`.

        Sorted by number of hits.
        """
        entry_ids = list(range(self.n_entries))
        entry_ids.sort(key=lambda i: self.hits[i], reverse=reverse)
        return {
            i: self.hits[i]
            for i in entry_ids
        }

    def sorted_entry_indices(
            self,
            by: str = "hits",
            reverse=True,
    ) -> List[int]:
        """
        Return the entry indices sorted by ``by``.

        :param by: "hits" (number of hits) or "tsne" (position on a 1-D t-SNE
            embedding of the flattened entries; imports scikit-learn lazily)
        :param reverse: sort in descending order
        :raises ValueError: for an unsupported ``by`` (only checked with >= 2 entries)
        """
        if not self.n_entries:
            return []
        entry_ids = list(range(self.n_entries))
        if self.n_entries < 2:
            return entry_ids

        if by == "hits":
            entry_ids.sort(key=lambda i: self.hits[i], reverse=reverse)

        elif by == "tsne":
            from sklearn.manifold import TSNE
            # t-SNE requires the perplexity to be smaller than the number of samples
            tsne = TSNE(1, perplexity=min(30, self.n_entries - 1))
            # detach/cpu so that this also works for entries living on the GPU
            flat = self.entries.reshape(self.entries.shape[0], -1).detach().cpu().numpy()
            positions = tsne.fit_transform(flat)
            entry_ids.sort(key=lambda i: positions[i, 0], reverse=reverse)

        else:
            raise ValueError(f"Unsupported sort by '{by}'")

        return entry_ids

    def fit(self, batch: torch.Tensor, lr: float = 1., skip_top_entries: Union[bool, int] = False):
        """
        Move the best-matching entry of every patch in ``batch`` towards that patch.

        :param batch: patches of shape ``(N, *self.shape)``
        :param lr: learning rate, scaled per entry by ``1 / (1 + hits)``
        :param skip_top_entries: forwarded to :meth:`best_entries_for`
        """
        best_entry_ids = self.best_entries_for(batch, skip_top_entries=skip_top_entries)
        for patch, entry_id in zip(batch, best_entry_ids.tolist()):
            weight = 1. / (1 + self.hits[entry_id])
            self.entries[entry_id] += lr * weight * (patch - self.entries[entry_id])
            self.hits[entry_id] += 1

    def best_entries_for(self, batch: torch.Tensor, skip_top_entries: Union[bool, int] = False) -> torch.Tensor:
        """
        Index of the closest entry (sum of absolute differences) for each patch.

        :param batch: patches of shape ``(N, *self.shape)``
        :param skip_top_entries: bool or int, do not match the N most-hit entries
            (1 for True). If every entry would be skipped, the worst match is used.
        :return: int64 tensor of shape ``(N,)``
        """
        assert batch.ndim == len(self.shape) + 1, f"Got {batch.shape}"
        assert batch.shape[1:] == self.shape, f"Got {batch.shape}"
        # broadcast: (N, 1, *shape) - (1, E, *shape) -> (N, E, *shape) -> (N, E)
        dist = (batch.unsqueeze(1) - self.entries.unsqueeze(0)).abs().flatten(2).sum(2)
        if not skip_top_entries:
            return torch.argmin(dist, 1)

        skip_top_entries = int(skip_top_entries)
        entry_ranks = self.entry_ranks(reverse=True)
        best_entries = []
        for sorted_row in torch.argsort(dist, 1).tolist():
            idx = 0
            # walk down the candidates until one is not among the top-ranked entries
            while idx + 1 < len(sorted_row) and entry_ranks[sorted_row[idx]] < skip_top_entries:
                idx += 1
            best_entries.append(sorted_row[idx])
        # int64 on the distance device, same as the argmin path above
        return torch.tensor(best_entries, dtype=torch.int64, device=dist.device)

    def plot_entries(self, min_size: int = 300, with_hits: bool = True, sort_by: Optional[str] = None):
        """
        Render all entries as one image grid.

        Entries are min/max-normalised over the whole library. With ``with_hits``
        each entry gets a 1-pixel border whose green channel is ``hits / max_hits``.

        :param min_size: grids narrower than this are upscaled (nearest neighbour)
            to this width; an empty library renders as a black square of this size
        :param with_hits: draw the hit-count border
        :param sort_by: optional entry order, see :meth:`sorted_entry_indices`
        :return: PIL image
        """
        num = self.entries.shape[0]
        # explicit leading dimension: reshape(-1, ...) raises for an empty library
        if len(self.shape) == 1:
            entries = self.entries.reshape(num, 1, 1, *self.shape)
        elif len(self.shape) == 2:
            entries = self.entries.reshape(num, 1, *self.shape)
        elif len(self.shape) == 3:
            entries = self.entries
        else:
            raise RuntimeError(f"Can't plot entries with shape {self.shape} (ndim>3)")
        if entries.shape[0]:

            e_min, e_max = entries.min(), entries.max()
            if e_min != e_max:
                entries = (entries - e_min) / (e_max - e_min)

            if with_hits:
                max_hits = max(1, self.max_hits)
                entry_list = []
                for entry, hits in zip(entries, self.hits):
                    if entry.shape[0] == 1:
                        entry = entry.repeat(3, *(1 for _ in entry.shape[1:]))
                    elif entry.shape[0] == 3:
                        pass
                    else:
                        raise ValueError(f"Can't plot entries with {entry.shape[0]} channels")

                    # RGB canvas 2 pixels larger than the entry, green channel = relative hits
                    background = torch.Tensor([0, hits / max_hits, 0])
                    background = background.reshape(3, *((1,) * (len(entry.shape) - 1)))
                    background = background.repeat(1, *(s + 2 for s in entry.shape[1:]))
                    background[:, 1:-1, 1:-1] = entry
                    entry_list.append(background)
                entries = entry_list

            if sort_by:
                if not isinstance(entries, list):
                    entries = list(entries)
                entry_ids = self.sorted_entry_indices(by=sort_by, reverse=True)
                entries = [entries[i] for i in entry_ids]

            grid = make_grid(entries, nrow=max(1, int(np.sqrt(self.n_entries))), normalize=False)
            if grid.shape[-1] < min_size:
                grid = VF.resize(
                    grid, (
                        int(grid.shape[-2] * min_size / grid.shape[-1]),
                        min_size,
                    ),
                    VF.InterpolationMode.NEAREST
                )
        else:
            grid = torch.zeros(1, min_size, min_size)
        return VF.to_pil_image(grid)

    def drop_unused(self):
        """Drop all entries that have never been hit."""
        self.drop_entries(hits_lt=1)

    def drop_entries(self, hits_lt: Optional[int] = None):
        """
        Remove entries in place.

        :param hits_lt: if given, drop every entry with strictly fewer than
            ``hits_lt`` hits; ``None`` keeps all entries
        """
        if hits_lt is None:
            return
        # strictly "less than", matching the parameter name and `drop_unused`
        keep = [i for i, hits in enumerate(self.hits) if hits >= hits_lt]
        if len(keep) == len(self.hits):
            return
        # index with an explicit long tensor so that dropping everything still
        # leaves a correctly shaped (0, *shape) tensor
        keep_idx = torch.tensor(keep, dtype=torch.int64, device=self.entries.device)
        self.entries = self.entries[keep_idx]
        self.hits = [self.hits[i] for i in keep]
        self.n_entries = len(self.hits)
            
# Smoke test of LocalGreedyLibrary: fit a 4-entry library of 3-vectors (created via
# copy()) on 10 random single-patch batches, skipping the 2 most-hit entries each time.
d = LocalGreedyLibrary(4, (3,)).copy()
for i in range(10):
    d.fit(torch.rand(1, *d.shape), skip_top_entries=2)
#d.fit(torch.Tensor([[-1, 1, 0], [0, 1, 0]]), skip_top_entries=1)
#d.fit(torch.Tensor([[1, 0, 0], [0, 1, 0]]), skip_top_entries=1)
#d.fit(torch.Tensor([[1, 0, 0], [0, 1, 0]]), skip_top_entries=1)
#d = CreedyDictionary(10, (3, 2))
#d.fit(torch.rand(5, 3, 2))
# inspect the hit bookkeeping and both sort orders
print("hits   ", d.hits)
print("hitmap ", d.entry_hits(reverse=True))
print("ranks  ", d.entry_ranks(reverse=True))
print("s hits ", d.sorted_entry_indices(reverse=True))
print("s tsne ", d.sorted_entry_indices(by="tsne", reverse=True))
display(d.plot_entries(sort_by="hits"))
d.drop_entries(hits_lt=1)
# last expression of the cell: Jupyter renders the returned PIL image
d.plot_entries()

In [None]:
# Sanity run of the project's GreedyLibrary (src.algo): 100 entries of 7x7,
# fitted on 100 batches of 1000 Gaussian-noise patches with the "corr" metric.
d = GreedyLibrary(100, (7, 7))
for i in tqdm(range(100)):
    d.fit(
        torch.randn(1000, *d.shape),
        metric="corr",
        skip_top_entries=0,
    )
# prune via GreedyLibrary.drop_unused, then inspect what is left
d.drop_unused()
display(d.plot_entries(sort_by="hits"))
print(d.hits)

In [None]:
# IPython shell escape: list the available dataset files.
!ls -l ../datasets

In [None]:
# Load the raw image tensor (the filename and the later 1/255 scaling suggest uint8
# pixels); the commented lines are alternative datasets.
# NOTE(review): torch.load unpickles the file — only use it on trusted local files.
ds_data = torch.load(f"../datasets/kali-uint8-{128}x{128}.pt")#[:30000]
#ds_data = torch.load(f"../datasets/ifs-1x{128}x{128}-uint8-1000x32.pt")#[:30000]
#ds_data = torch.load(f"../datasets/photos-64x64-bcr03.pt")#[:30000]

In [None]:
# Wrap the raw tensor: float conversion, scaling by 1/255 and grayscale.
ds = TransformDataset(
    TensorDataset(ds_data),
    dtype=torch.float,
    multiply=1. / 255.,
    transforms=[
        #VT.CenterCrop(64),
        #VT.RandomCrop(SHAPE[-2:]),
        VT.Grayscale(),
    ],
    num_repeat=1,
)
# Dataset statistics, estimated from the first batch of (at most) 10000 images only.
(batch,) = next(iter(DataLoader(ds, batch_size=10000)))
ds_mean = batch.mean()
ds_std = batch.std()
print(f"mean {ds_mean}, std {ds_std}")
VF.to_pil_image(ds[0][0])

In [None]:
# Train a single-channel 11x11 GreedyLibrary on random crops of the dataset.
# Entries start as near-zero noise (std is 0.001 of the dataset std).
d = GreedyLibrary(100, (1, 11, 11), std=.001*ds_std, mean=0.)
#d.entries = d.entries.cuda()

# random crops of the entry size; num_repeat=50 presumably visits every image 50 times
ds_small = TransformDataset(ds, transforms=[VT.RandomCrop(d.shape[-2:])], num_repeat=50)
#for batch, in DataLoader(ds_small, batch_size=d.n_entries):
#    d.entries = batch * torch.rand_like(batch)
#    break
try:
    # count patches seen; roughly every 5000 patches there is a hook for progress inspection
    count, last_count = 0, 0
    for batch, in tqdm(DataLoader(ds_small, batch_size=100, shuffle=True)):
        d.fit(batch, lr=1., skip_top_entries=0, zero_mean=True, metric="corr")
        count += batch.shape[0]
        if count > last_count + 5000:
            last_count = count
            # NOTE(review): top_idx is computed but unused while the print below is commented out
            top_idx = d.top_entry_index()
            #print(f"{top_idx}: {d.hits[top_idx]}")
            #break
        #    display(d.plot_entries(min_size=100))
        
except KeyboardInterrupt:
    # interrupting the kernel stops training early but keeps the partially fitted library
    pass
# NOTE(review): drop_entries(inplace=...) / sort_entries are GreedyLibrary APIs (src.algo);
# the chained call assumes drop_entries returns the library — confirm.
d.drop_entries(hits_lt=10, inplace=True).sort_entries(inplace=True, reverse=True)
display(d.plot_entries(min_size=600, sort_by="hits"))
print(sorted(d.hits, reverse=True))
display(d.plot_entries(min_size=600, sort_by="tsne"))

In [None]:
# Value range of the learned entries (displayed as the cell output).
d.entries.min(), d.entries.max()

In [None]:
# Response maps of the library on the zero-mean first image (moves the image to CUDA).
# NOTE(review): max_pool2d with kernel size 1 (and default stride 1) is an identity op here;
# d.convolve is defined in src.algo, so its output layout is an assumption — confirm.
VF.to_pil_image(make_grid(F.max_pool2d(d.convolve((ds[0][0] - ds[0][0].mean()).to("cuda")), 1).unsqueeze(1)).clamp(0, 1))

In [None]:
# Persist the learned entries.
# NOTE(review): the filename says "photos" while the active dataset load in this notebook
# is the kali set (the photos load is commented out) — confirm the name matches the data.
torch.save(d.entries, "../datasets/creedylib-1x11x11-signed-photos.pt")

In [None]:
# Show a previously saved kali library, 30 entries per row; signed_to_image
# (src.util.image) presumably maps signed values into a displayable range.
VF.to_pil_image(signed_to_image(make_grid(torch.load("../datasets/creedylib-1x11x11-signed-kali.pt"), nrow=30)))

In [None]:
# Load the saved kali library so it can be used directly as conv2d weights.
weights = torch.load("../datasets/creedylib-1x11x11-signed-kali.pt")
weights.shape

In [None]:
# First dataset image, used as the probe for the convolution experiments below.
image = ds[0][0]
VF.to_pil_image(image)

In [None]:
#c = F.conv2d(image, weights, stride=4)
#c = F.conv2d(image - image.mean(), weights - weights.mean(), stride=4)
# Convolve the zero-mean probe image with every library entry (stride 4).
# NOTE(review): image and d.entries must be on the same device; the feature-extraction
# cell moves d.entries to CUDA, after which this cell fails with a CPU image.
c = F.conv2d(image - image.mean(), d.entries, stride=4)
# bare mid-cell expression: has no effect, only the last expression is displayed
c.shape#.sum(dim=1).sum(dim=1)
# the library entries, normalised for display
display(VF.to_pil_image(make_grid(d.entries, nrow=16, normalize=True)))
# presumably one response map per entry (unbatched conv output) — clamped to [0, 1]
img = make_grid(c.unsqueeze(1), nrow=16, normalize=False)
#img = signed_to_image(img)
display(VF.to_pil_image(img.clamp(0, 1)))

# 5x5 max-pooled responses, upscaled 4x with nearest-neighbour for visibility
img = make_grid(F.max_pool2d(c, (5, 5)).unsqueeze(1), nrow=16).clamp(0, 1)
VF.to_pil_image(VF.resize(img, (img.shape[-2] * 4, img.shape[-1] * 4), VF.InterpolationMode.NEAREST))
#img = (img - img.min()) / (img.max() - img.min())
#VF.to_pil_image(VF.resize(img, (img.shape[-2] * 4, img.shape[-1] * 4), VF.InterpolationMode.NEAREST))

In [None]:
# Encode every dataset image as a feature vector: the maximum response of each
# library entry over all spatial positions (assuming d.convolve returns
# (batch, entries, H, W) — TODO confirm in src.algo).
# NOTE(review): this moves d.entries to the GPU permanently, for all later cells.
with torch.no_grad():
    d.entries = d.entries.cuda()
    def encode(batch):
        c = d.convolve(batch.cuda())
        # max over everything after the first two dims
        c, _ = c.reshape(*c.shape[:2], -1).max(dim=-1)
        return c.cpu()

    features = []
    try:
        for batch, in tqdm(DataLoader(ds, batch_size=100)):
            features.append(encode(batch))
    except KeyboardInterrupt:
        # interrupting keeps the features computed so far
        # NOTE(review): interrupting before the first batch leaves `features` empty and
        # torch.concat below raises
        pass
    features = torch.concat(features)
features.shape

In [None]:
# Order the encoded images along a 1-D t-SNE embedding of their features and
# show the first 100 images in that order.
from sklearn.manifold import TSNE
tsne = TSNE(1)
positions = tsne.fit_transform(features)
label_ids = list(range(features.shape[0]))
label_ids.sort(key=lambda i: positions[i])
VF.to_pil_image(make_grid([ds[i][0] for i in label_ids[:100]]))

In [None]:
#VF.to_pil_image(make_grid([ds[i][0] for i in label_ids], nrow=int(np.sqrt(len(label_ids))))).save("/home/bergi/Pictures/kali-1x128x128-sorted-by-tsne-of-creedylib.png")

In [None]:
def best_entries_for(
        self,
        batch: torch.Tensor,
        skip_top_entries: Union[bool, int] = False,
        metric: str = "corr",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the index of the best matching entry for each patch in the batch.

    Prototype of a library method written as a free function; ``self`` is the
    library. It reads ``self.shape``, ``self.entries``, ``self.device`` and
    ``self.entry_ranks()``.

    :param batch: Tensor of N patches of shape matching the library's shape
    :param skip_top_entries: bool or int,
        Do not match the top N entries (1 for True), sorted by number of hits.
        If every entry would be skipped, the worst-matching entry is used.
    :param metric: "l1"/"mae" (sum of absolute differences), "l2"/"mse"
        (euclidean distance) or anything starting with "corr" (negative dot
        product, so the most correlated entry has the smallest distance)
    :return: tuple of
        - Tensor of int64: entry indices, shape (N,)
        - Tensor of float: distance to the chosen entry, shape (N,)
    :raises ValueError: for an unsupported metric
    """
    assert batch.ndim == len(self.shape) + 1, f"Got {batch.shape}"
    assert batch.shape[1:] == self.shape, f"Got {batch.shape}"
    metric = metric.lower()

    # broadcast every patch against every entry: (N, 1, *shape) vs (1, E, *shape)
    patches = batch.to(self.device).unsqueeze(1)
    entries = self.entries.unsqueeze(0)

    if metric in ("l1", "mae"):
        dist = (entries - patches).abs().flatten(2).sum(2)
    elif metric in ("l2", "mse"):
        dist = (entries - patches).pow(2).flatten(2).sum(2).sqrt()
    elif metric.startswith("corr"):
        # correlation is the dot product (not a sum of differences);
        # negated so that argmin picks the best match
        dist = -(entries * patches).flatten(2).sum(2)
    else:
        raise ValueError(f"Unsupported metric '{metric}'")
    # dist: (N, E)

    if not skip_top_entries:
        indices = torch.argmin(dist, 1)
    else:
        skip_top_entries = int(skip_top_entries)
        entry_ranks = self.entry_ranks(reverse=True)
        best_entries = []
        for sorted_row in torch.argsort(dist, 1).tolist():
            idx = 0
            # walk down the candidates until one is not among the top-ranked entries
            while idx + 1 < len(sorted_row) and entry_ranks[sorted_row[idx]] < skip_top_entries:
                idx += 1
            best_entries.append(sorted_row[idx])
        indices = torch.tensor(best_entries, dtype=torch.int64, device=dist.device)

    # dist[i, indices[i]] for every row; exact integer indexing, no float offsets
    return indices, dist.gather(1, indices.unsqueeze(1)).squeeze(1)

# Quick check of the best_entries_for prototype against a real GreedyLibrary fitted on
# random data; the returned (indices, distances) tuple is the cell output.
lib = GreedyLibrary(10, (1, 2, 3))
lib.fit(torch.randn(100, 1, 2, 3))
best_entries_for(lib, torch.randn(2, 1, 2, 3), metric="l2")
#lib.plot_entries()