In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device
from src.patchdb import PatchDB
from src.models.encoder import *
from scripts import datasets

In [None]:
def plot_samples(ds, count=30*30, nrow=30, batch_size=10, skip: int = 0):
    """Render up to `count` samples of `ds` (after skipping `skip`) as one image grid.

    The dataset is iterated through a `DataLoader`. When a batch is a list/tuple
    (datasets yielding `(image, extra, ...)`), only its first element is used.
    A KeyboardInterrupt stops collection early and plots whatever was gathered.

    :param ds: map-style or iterable dataset yielding image tensors, or tuples whose
        first element is the image tensor.
    :param count: maximum number of samples to show.
    :param nrow: number of images per grid row (passed to `make_grid`).
    :param batch_size: DataLoader batch size used while collecting.
    :param skip: number of leading samples to discard before collecting.
    :return: a PIL image of the grid, or the string "nothin'" if nothing was collected.
    """
    if count <= 0:
        return "nothin'"

    seen = 0       # samples pulled from the loader so far, including skipped ones
    collected = 0  # samples actually kept for the grid
    batches = []
    try:
        for batch in DataLoader(ds, batch_size=batch_size):
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            # Drop only the part of this batch that still lies inside the skipped
            # range, so `skip` is honored per sample (and skip=0 keeps the first batch).
            start = max(0, skip - seen)
            seen += batch.shape[0]
            if start < batch.shape[0]:
                kept = batch[start:]
                batches.append(kept)
                collected += kept.shape[0]
            if collected >= count:
                break
    except KeyboardInterrupt:
        # deliberate: allow stopping slow datasets and still show partial results
        pass

    if not batches:
        return "nothin'"
    grid_batch = torch.cat(batches)[:count]
    return VF.to_pil_image(make_grid(grid_batch, nrow=nrow))
    

In [None]:
#plot_samples(datasets.all_patch_datasets((1, 32, 32)))
#plot_samples(datasets.photo_patch_dataset((1, 32, 32), "/home/bergi/Pictures/photos"))

In [None]:
SHAPE = [1, 32, 32]


def _scales(shape: Tuple[int, int]):
    """Return the candidate down-scale factors usable for an image of size `shape`.

    A factor is kept when the image's shorter side, multiplied by it, is at least one
    patch side (``min(SHAPE[-2:])``) and strictly below 10,000 pixels.
    """
    shorter_side = min(shape)
    patch_side = min(SHAPE[-2:])
    candidates = (1., 1./2., 1./5, 1./10, 1./20., 1./30.)
    return [
        factor
        for factor in candidates
        if patch_side <= factor * shorter_side < 10_000
    ]


def _stride(shape: Tuple[int, int]):
    """Return a stride of one step per full 10,000 px of the shorter side.

    The result is clamped to ``[1, min(SHAPE[-2:])]``.
    """
    shorter_side = min(shape)
    patch_side = min(SHAPE[-2:])
    steps = int(shorter_side / 10000)
    return max(1, min(patch_side, steps))

# Source image(s) to crop patches from.
# NOTE(review): absolute local path — adjust for your machine.
PATCH_SOURCE_PATH = "/home/bergi/prog/python/github/magic-pen/results/pattern/3/organic-structures-fantasies-of-friendship-be-0000.jpg"

# Crop SHAPE-sized patches at the scales chosen by `_scales`, with a fixed stride of 5,
# and attach each patch's position, filename and scale to the yielded items.
ds_crop = make_image_patch_dataset(
    verbose_image=True,
    path=PATCH_SOURCE_PATH,
    recursive=True,
    shape=SHAPE,
    max_image_bytes=1024 ** 3,  # 1 GiB
    scales=_scales,
    stride=5,
    with_pos=True,
    with_filename=True,
    with_scale=True,
)
# Pass crops through `ImageFilter` (all thresholds at their defaults; options such as
# min_std / max_std / min/max_compression_ratio exist), then through the similarity
# filter. The similarity filter rejects patches more than 0.99 similar to recent ones;
# max_age presumably bounds how far back comparisons reach — confirm in src.datasets.
ds_unique = DissimilarImageIterableDataset(
    ImageFilterIterableDataset(ds_crop, ImageFilter()),
    verbose=True,
    max_similarity=.99,
    max_age=10_000,
)

# Preview a 30x30 grid of the surviving patches (max_shuffle=0).
plot_samples(IterableShuffle(ds_unique, max_shuffle=0), count=30 * 30, skip=0)

# build PatchDB

In [None]:
# List the available 2D encoder checkpoints (the next cell loads one of them).
!ls ../models/encoder2d/

In [None]:
# Convolutional encoder used to embed patches for the PatchDB (kept on CPU).
ENCODER_CHECKPOINT = "../models/encoder2d/conv-1x32x32-128-all1.pt"
encoder = EncoderConv2d.from_torch(ENCODER_CHECKPOINT, device="cpu")
encoder.device

In [None]:
# List the existing patch database files (*.patchdb) in ../db.
!ls ../db/*.patchdb

In [None]:
# Build the patch database: embed every patch of `ds_unique` with `encoder`,
# L2-normalize each embedding and store it with the patch rectangle mapped back
# to source-image coordinates.
db = PatchDB("../db/sand-1x32x32-convall1.patchdb", writeable=True)

# presumably wipes any existing contents of the database — confirm in src.patchdb
db.clear()

FLUSH_EVERY = 50_000  # flush and report size after roughly this many new patches

count = 0
last_count = 0
try:
    with torch.no_grad():
        with db:
            for patches, positions, scales, filenames in DataLoader(ds_unique, batch_size=64):
                # F.normalize divides by max(norm, eps), so an all-zero embedding
                # stays zero instead of becoming NaN.
                embeddings = F.normalize(encoder(patches), dim=1)
                # every patch in the batch has the same height/width
                patch_hw = list(patches[0].shape[-2:])
                for embedding, pos, scale, filename in zip(embeddings, positions, scales, filenames):
                    # Map the crop rectangle back to the unscaled image. round() instead
                    # of int(): float division like 3 / 0.2 gives 14.999..., which int()
                    # would truncate to 14.
                    rect = [round(r / float(scale)) for r in pos.tolist() + patch_hw]
                    db.add_patch(filename, rect, embedding)
                    count += 1

                if count - last_count > FLUSH_EVERY:
                    last_count = count
                    db.flush()
                    print(f"{db.size_bytes():,} bytes")
except KeyboardInterrupt:
    # deliberate: allow stopping early; patches added so far are kept
    pass

f"{db.size_bytes():,} bytes"

In [None]:
import gzip

# Show the first line of the PatchDB file, opened as gzip-compressed text.
with gzip.open(db.filename, mode="rt") as handle:
    first_line = handle.readline()
print(first_line)

In [None]:
# Raw size in bytes of one embedding (the last one stored by the PatchDB cell):
# element count times bytes per element, following the tensor's actual dtype
# instead of assuming 4-byte float32.
embedding.numel() * embedding.element_size()

In [None]:
import base64
import json

# Compare two text encodings of one embedding: base64 of the raw bytes (printed)
# versus a JSON list of floats (displayed).
a = embedding.detach().numpy()
# tobytes() also works for non-contiguous arrays, unlike the raw `.data` buffer
print(len(base64.b64encode(a.tobytes())))
len(json.dumps(a.tolist()))

In [None]:
# Alternative de-duplication of `ds_crop` that passes an encoder to the similarity
# filter (presumably similarity is then measured on its embeddings), with a much lower
# threshold. Uses its own name so the `ds_unique` consumed by the PatchDB cell is not
# silently replaced when cells run out of order.
ds_unique_encoded = DissimilarImageIterableDataset(
    ds_crop,
    max_similarity=.4,
    max_age=200_000,
    verbose=True,
    encoder="encoderconv:../models/encoderconv/encoder-1x32x32-128-small-photo-3.pt",
)
plot_samples(ds_unique_encoded, count=2000)

In [None]:
# Scratch check: number of pixels in a 32x32 patch (cf. SHAPE[-2:]).
32*32
