In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device
from src.patchdb import PatchDB
from src.models.encoder import EncoderConv2d

In [None]:
def plot_samples(ds, count=30*30, nrow=30, batch_size=100):
    """
    Collect up to ``count`` images from ``ds`` and render them as one grid image.

    When the DataLoader yields tuples/lists, only their first element (the images)
    is used. Collection can be interrupted with Ctrl-C (KeyboardInterrupt); the
    images gathered up to that point are still rendered.

    :param ds: dataset handed to a ``DataLoader``
    :param count: maximum number of images in the grid
    :param nrow: number of images per grid row
    :param batch_size: DataLoader batch size used while collecting
    :return: a PIL image of the grid, or the string ``"nothin'"`` when nothing was collected
    """
    collected = []
    num_collected = 0
    try:
        for item in DataLoader(ds, batch_size=batch_size):
            images = item[0] if isinstance(item, (list, tuple)) else item
            collected.append(images)
            num_collected += images.shape[0]
            if num_collected >= count:
                break
    except KeyboardInterrupt:
        pass

    if len(collected) == 0:
        return "nothin'"

    grid = make_grid(torch.concat(collected)[:count], nrow=nrow)
    return VF.to_pil_image(grid)
    

In [None]:
SHAPE = [1, 32, 32]

def _stride(shape: Tuple[int, int]):
    # print(shape)
    size = min(shape)
    if size <= 512:
        return 5
    
    return SHAPE[-2:]

# Patch dataset over the photo collection, sampled at several scales.
# with_pos / with_filename add the patch position and source filename to each item
# (the PatchDB build cell unpacks items as (patches, positions, filenames)).
ds_crop = make_image_patch_dataset(
    path="~/Pictures/photos",
    recursive=True,
    shape=SHAPE,
    scales=[1. / 12., 1. / 6, 1. / 3, 1.],
    stride=_stride,
    interleave_images=4,
    with_pos=True,
    with_filename=True,
    # alternatives tried:
    # image_shuffle=5,
    # transforms=[lambda x: VF.resize(x, tuple(s // 6 for s in x.shape[-2:]))], stride=5,
)
# ds_crop = IterableImageFilterDataset(ds_crop, ImageFilter(min_std=0.03))

plot_samples(ds_crop, count=30 * 60)

In [None]:
class DissimilarImageIterableDatasetLOCAL(IterableDataset):
    """
    Iterable dataset that filters out images too similar to images it already yielded.

    Every source image is encoded (see ``_encode``) into an L2-normalized feature
    vector. An image is yielded only if its highest dot-product similarity with the
    stored features of previously yielded images (including earlier images yielded
    from the same batch) is ``<= max_similarity``. The first image of an iteration,
    or the first one after all stored features expired, is always yielded.

    :param dataset: source dataset; items are image tensors, or tuples/lists whose first
        element is the image tensor and whose remaining elements are yielded along with it
    :param max_similarity: similarity threshold; higher values reject fewer images
    :param max_age: if not None, a stored feature is forgotten once more than this many
        source items have been consumed since it was stored
    :param encoder: ``"flatten"``, a string starting with ``"clip"``,
        ``"encoderconv:<path>"``, or a callable (e.g. ``torch.nn.Module``) that maps an
        image batch to a feature batch
    :param batch_size: number of source items buffered and encoded together
    :param verbose: show a tqdm progress bar while filtering
    """

    def __init__(
            self,
            dataset: Union[IterableDataset, Dataset],
            max_similarity: float = .9,
            max_age: Optional[int] = None,
            encoder: Union[str, torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]] = "flatten",
            batch_size: int = 10,
            verbose: bool = False,
    ):
        self.dataset = dataset
        self.max_similarity = float(max_similarity)
        self.max_age = max_age
        self.encoder = encoder
        self.batch_size = int(batch_size)
        self.verbose = bool(verbose)
        # features of all yielded (and not yet expired) images, one row per image
        self.features: Optional[torch.Tensor] = None
        # for each row of `features`: the value of `_age` when it was stored
        self._ages: Optional[List[int]] = None
        # number of source items consumed in the current iteration
        self._age = 0
        # number of images yielded in the current iteration (shown in the progress bar)
        self._num_passed = 0

    def __iter__(self) -> Generator[Union[torch.Tensor, Tuple[torch.Tensor, ...]], None, None]:
        # every iteration starts with an empty feature memory
        self.features = None
        self._ages = None
        self._age = 0
        self._num_passed = 0
        image_batch = []
        tuple_batch = []

        def _process(data):
            # buffer one source item; filter and yield once a full batch is collected

            is_tuple = isinstance(data, (tuple, list))
            if is_tuple:
                image_batch.append(data[0])
                tuple_batch.append(data[1:])
            else:
                image_batch.append(data)
                tuple_batch.append(None)

            if len(image_batch) >= self.batch_size:
                yield from self._process_batch(image_batch, tuple_batch)
                image_batch.clear()
                tuple_batch.clear()

            self._age += 1

            self._drop_old_features()

        if not self.verbose:
            for data in self.dataset:
                yield from _process(data)

        else:
            from tqdm import tqdm

            try:
                total = len(self.dataset)
            except (TypeError, NotImplementedError):
                # datasets without __len__ (e.g. most IterableDatasets)
                total = None

            with tqdm(total=total) as progress:
                for data in self.dataset:
                    yield from _process(data)

                    progress.desc = (
                        f"filtering unsimilar images (features={self.features.shape[0] if self.features is not None else 0}"
                        f", passed={self._num_passed})"
                    )
                    progress.update(1)

        # flush the last, partial batch
        if image_batch:
            yield from self._process_batch(image_batch, tuple_batch)

    def _process_batch(self, image_batch, tuple_batch):
        """
        Encode a buffered batch and yield the items that pass the similarity filter.

        :param image_batch: list of image tensors, stacked along a new first dimension
        :param tuple_batch: per image, the remaining tuple elements, or None for bare images
            (an empty remainder also results in the bare image being yielded)
        """
        image_batch = torch.concat([i.unsqueeze(0) for i in image_batch])
        feature_batch = self._encode(image_batch)

        # nothing stored yet: the first image passes unconditionally and seeds the memory
        if self.features is None:
            if tuple_batch[0]:
                yield image_batch[0], *tuple_batch[0]
            else:
                yield image_batch[0]

            self._num_passed += 1
            self.features = feature_batch[0].unsqueeze(0)
            self._ages = [self._age]
            image_batch = image_batch[1:]
            tuple_batch = tuple_batch[1:]
            feature_batch = feature_batch[1:]

            if len(feature_batch) == 0:
                # the batch only contained the seed image
                return

        similarities = self._highest_similarities(feature_batch)

        # features of images passed from this batch; compared against as well so that
        # near-duplicates within one batch are filtered too
        features_to_add = None
        for image, tup, feature, similarity in zip(image_batch, tuple_batch, feature_batch, similarities):

            if features_to_add is not None:
                # get highest similarity with stored features and new features from batch
                similarity2 = self._highest_similarities(feature.unsqueeze(0), features_to_add)
                similarity = torch.max(similarity, similarity2)

            if similarity <= self.max_similarity:
                if tup:
                    yield image, *tup
                else:
                    yield image

                self._num_passed += 1
                self._ages.append(self._age)
                if features_to_add is None:
                    features_to_add = feature.unsqueeze(0)
                else:
                    features_to_add = torch.concat([features_to_add, feature.unsqueeze(0)])

        if features_to_add is not None:
            self.features = torch.concat([self.features, features_to_add])

    def _drop_old_features(self):
        """
        Forget stored features that were stored more than ``max_age`` source items ago.

        ``_ages`` is filled in insertion order with a non-decreasing counter, so the
        expired entries form a prefix. If every stored feature has expired, the memory
        is reset to ``None`` and the next batch starts fresh.
        """
        if self.max_age is None or not self._ages:
            return

        num_expired = 0
        for age in self._ages:
            if self._age - age > self.max_age:
                num_expired += 1
            else:
                break

        if num_expired >= len(self._ages):
            # everything expired; an empty (0, D) feature tensor would break the
            # max-reduction in `_highest_similarities`, so reset instead
            self.features = None
            self._ages = None
        elif num_expired:
            self.features = self.features[num_expired:]
            self._ages = self._ages[num_expired:]

    def _highest_similarities(self, feature_batch: torch.Tensor, features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        For each row of ``feature_batch``, return its maximum dot product with the rows of
        ``features`` (default: the stored ``self.features``).
        """
        sim = feature_batch @ (features if features is not None else self.features).T
        return sim.max(dim=1)[0]

    def _encode(self, image_batch: torch.Tensor) -> torch.Tensor:
        """
        Encode an image batch into L2-normalized feature vectors, one row per image.

        :raises ValueError: for an unknown encoder string or an unsupported encoder type
        """
        # features are only used for filtering; without no_grad an nn.Module encoder
        # would attach an autograd graph that `self.features` keeps alive and grows
        with torch.no_grad():
            if isinstance(self.encoder, str):

                if self.encoder == "flatten":
                    feature_batch = image_batch.flatten(1)

                elif self.encoder.startswith("clip"):
                    from src.models.clip import ClipSingleton
                    feature_batch = ClipSingleton.encode_image(image_batch)

                elif self.encoder.startswith("encoderconv:"):
                    from src.models.encoder import EncoderConv2d
                    if not hasattr(self, "_encoderconv"):
                        self._encoderconv = EncoderConv2d.from_torch(self.encoder.split(":", 1)[-1])
                    feature_batch = self._encoderconv.encode_image(image_batch)

                else:
                    raise ValueError(
                        f"Unsupported encoder '{self.encoder}', expected 'flatten', 'clip' or 'encoderconv:<path>'"
                    )

            elif callable(self.encoder):
                feature_batch = self.encoder(image_batch)
            else:
                raise ValueError(f"Unsupported encoder type {type(self.encoder).__name__}, expected str or callable")

            # feature_batch = feature_batch - feature_batch.mean(dim=1, keepdim=True)
            # clamp avoids 0/0 = NaN for all-zero features (e.g. black images with "flatten");
            # a stored NaN row would make every later similarity NaN and reject everything
            norm = torch.norm(feature_batch, dim=1, keepdim=True).clamp_min(1e-12)
            return feature_batch / norm
 
    
# NOTE(review): this uses `DissimilarImageIterableDataset` (presumably from the src.datasets
# star import), not the `DissimilarImageIterableDatasetLOCAL` class defined above — confirm intended.
ds_unique = DissimilarImageIterableDataset(
    ds_crop,
    max_similarity=0.999,
    max_age=200_000,
    verbose=True,
)
plot_samples(ds_unique, count=1000)

# build PatchDB

In [None]:
# Pretrained conv encoder, kept on the CPU; used below to embed patches for the PatchDB.
encoder = EncoderConv2d.from_torch(
    "../models/encoderconv/encoder-1x32x32-128-photo-5.pt",
    device="cpu",
)
encoder.device

In [None]:
db = PatchDB("../db/photos-1x32x32.patchdb")

# start from an empty database
db.clear()

count = 0
last_count = 0
try:
    # embeddings are only stored, never trained on: skip building an autograd graph
    with db, torch.no_grad():
        for patches, positions, filenames in DataLoader(ds_unique, batch_size=64):
            embeddings = encoder(patches)
            # L2-normalize; clamp avoids NaN for an all-zero embedding
            embeddings = embeddings / embeddings.norm(dim=1, keepdim=True).clamp_min(1e-12)
            # embeddings = torch.round(embeddings, decimals=5)
            patch_size = list(patches[0].shape[-2:])
            for embedding, pos, filename in zip(embeddings, positions, filenames):
                rect = pos.tolist() + patch_size
                db.add_patch(filename, rect, embedding)
                count += 1

            # flush and report size roughly every 50k patches
            if count - last_count > 50_000:
                last_count = count
                db.flush()
                print(f"{db.size_bytes():,} bytes")
except KeyboardInterrupt:
    # allow stopping the long-running build with Ctrl-C
    pass

In [None]:
import gzip

# Show the first text line of the gzip-compressed PatchDB file.
with gzip.open(db.filename, "rt") as fp:
    first_line = fp.readline()
print(first_line)

In [None]:
# raw size in bytes of one embedding: element count times the actual bytes per element
# (instead of hard-coding 4, which only holds for float32)
embedding.numel() * embedding.element_size()

In [None]:
import base64
import json  # was missing: json is not imported anywhere else in the notebook

# Compare serialized sizes of one embedding: base64 of the raw bytes vs. a JSON float list.
a = embedding.detach().numpy()
# tobytes() also works for non-contiguous arrays (a.data would raise BufferError there)
print(len(base64.b64encode(a.tobytes())))
len(json.dumps(a.tolist()))

In [None]:
# Filter the raw patches in the feature space of a small conv encoder.
# Note: this rebinds `ds_unique`, which the PatchDB build cell above also reads.
# alternative pre-filter: DissimilarImageIterableDataset(ds_crop, max_similarity=.99, max_age=200_000, verbose=False)
ds_unique = DissimilarImageIterableDataset(
    ds_crop,
    max_similarity=.4,
    max_age=200_000,
    verbose=True,
    encoder="encoderconv:../models/encoderconv/encoder-1x32x32-128-small-photo-3.pt",
)
plot_samples(ds_unique, count=2000)

In [None]:
32 ** 2  # pixels in a 32x32 patch
