In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
import clip

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.cnn import *

In [None]:
from scripts.train_contrastive import EncoderConv, EncoderTrans

# Shape (C, H, W) that every dataset sample must have; also selects the file resolution.
SHAPE = (3, 64, 64)
# Which dataset to analyse: "pattern", "kali" or "photos".
DATASET_NAME = "kali"

if DATASET_NAME == "pattern":
    # presumably single-channel uint8 patterns: converted to float, scaled by 1/255
    # and repeated to 3 channels so they match SHAPE
    dataset = TensorDataset(torch.load(f"../datasets/pattern-1x{SHAPE[-2]}x{SHAPE[-1]}-uint.pt"))
    dataset = TransformDataset(dataset, dtype=torch.float, multiply=1./255., transforms=[lambda i: i.repeat(3, 1, 1)])
elif DATASET_NAME == "kali":
    dataset = TensorDataset(torch.load(f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt"))
    dataset = TransformDataset(dataset, dtype=torch.float, multiply=1./255.)
elif DATASET_NAME == "photos":
    # NOTE(review): unlike the other branches there is no dtype conversion / scaling here;
    # presumably this file already stores float images in [0, 1] - confirm.
    dataset = TensorDataset(torch.load(f"../datasets/photos-{SHAPE[-2]}x{SHAPE[-1]}-bcr03.pt"))
else:
    raise ValueError(f"unknown DATASET_NAME: {DATASET_NAME!r}")

assert dataset[0][0].shape == SHAPE, f"expected sample shape {SHAPE}, got {tuple(dataset[0][0].shape)}"
print(f"{len(dataset):,}")
VF.to_pil_image(dataset[0][0])

In [None]:
# Load a pretrained CLIP model together with its matching image preprocessing transform.
# "ViT-B/32" can be used instead of "RN50" as the backbone.
model, preproc = clip.load("RN50")

In [None]:
#preproc

# Encode the dataset into CLIP image features

In [None]:
# Sanity check: run three all-zero 224x224 RGB images through the visual tower
# and look at the range of the resulting features.
with torch.no_grad():
    f = model.visual(torch.zeros((3, 3, 224, 224), dtype=torch.half, device="cuda"))
(f.min(), f.max(), f)

In [None]:
def encode_dataset(
        dataset,
        max_count: Optional[int] = 10_000,
        batch_size: int = 3,
        device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
    """
    Encode the images of `dataset` with the visual tower of the global CLIP `model`.

    Every dataset item is indexed with `[0]` to get the image, so items are expected to be
    tuples like those yielded by `TensorDataset`. Images are resized to 224x224, converted
    to half precision, moved to `device` and passed through `model.visual`.

    NOTE(review): images are only resized, not normalized with CLIP's mean/std (the
    `preproc` transform returned by `clip.load` is not applied), so the features may be
    somewhat off-distribution.

    Interrupting the kernel (KeyboardInterrupt) stops encoding early; the features
    encoded so far are returned.

    :param dataset: dataset yielding `(image, ...)` tuples
    :param max_count: maximum number of feature vectors to return, None to encode everything
    :param batch_size: number of images per forward pass
    :param device: device the images are moved to; must match the model's device
    :return: concatenated `model.visual` outputs (presumably shape (N, D)) on `device`,
        with at most `max_count` rows
    :raises RuntimeError: if no batch was encoded (empty dataset or immediate interrupt)
    """
    feature_list = []
    count = 0
    try:
        for image_batch in tqdm(DataLoader(dataset, batch_size=batch_size)):
            image_batch = image_batch[0]
            image_batch = VF.resize(image_batch, (224, 224))
            features = model.visual(image_batch.half().to(device))
            feature_list.append(features)
            count += features.shape[0]
            if max_count is not None and count >= max_count:
                break
    except KeyboardInterrupt:
        # deliberate: allow stopping a long run and keep the partial result
        pass

    if not feature_list:
        raise RuntimeError("no features were encoded (empty dataset or interrupted before the first batch)")

    features = torch.cat(feature_list)
    if max_count is not None:
        # the last batch can overshoot max_count
        features = features[:max_count]
    return features


with torch.no_grad():
    features = encode_dataset(dataset).cpu().float()

# unit-length copies of the feature vectors, so dot products are cosine similarities
features_n = features / features.norm(dim=-1, keepdim=True)
print("shape:", features.shape)
print("min/max/mean:", features.min(), features.max(), features.mean())

# `to_pil_image` turns float tensors into bytes via `mul(255).byte()`, which only makes
# sense for values in [0, 1]. Raw features are not in that range, so min-max rescale the
# first 100 feature vectors before showing them as a grayscale image
# (one row per image, one column per feature dimension).
preview = features[:100]
preview = (preview - preview.min()) / (preview.max() - preview.min()).clamp_min(1e-12)
VF.to_pil_image(preview)

In [None]:
# normalized feature vectors of the first 10 images; transposed so each image is one line
px.line(data_frame=features_n[:10].T)

In [None]:
# the first 10 dataset images in a single row
VF.to_pil_image(
    make_grid([dataset[idx][0] for idx in range(10)], nrow=10)
)

In [None]:
# standard deviation of each raw feature dimension across all encoded images
px.line(data_frame=features.std(dim=0))

In [None]:
# Scatter of normalized feature dimensions 0 and 1 for (up to) the first 1000 images.
# Each point is multiplied by a factor rising linearly from 0.5 to 1 with its index,
# which spreads otherwise overlapping points (presumably just a visual aid).
# The point count is clipped so this also works with fewer than 1000 encoded images.
n_points = min(1000, features_n.shape[0])
radius_scale = torch.linspace(0.5, 1, n_points)
px.scatter(
    x=features_n[:n_points, 0] * radius_scale,
    y=features_n[:n_points, 1] * radius_scale,
    width=400, height=400,
    title="normalized feature dims 0 vs 1 (index-scaled)",
)

# Sort images by a 1-D t-SNE embedding of their features

In [None]:
from sklearn.manifold import TSNE

# Reduce each normalized feature vector to one scalar with a 1-D t-SNE so the images
# can be ordered along a single axis. `random_state` is fixed so the ordering (and the
# sorted image grids derived from it) is reproducible across runs.
reduction = TSNE(n_components=1, random_state=0, verbose=1)
positions = torch.as_tensor(
    reduction.fit_transform(features_n.numpy()), dtype=torch.float32,
).reshape(-1)
positions

In [None]:
# Order the images by their 1-D t-SNE position and show both ends of that ordering:
# the first and last `n_show` images (up to 500 each, downscaled to 32x32, 20 per row).
# `n_show` is clipped so the head and tail never overlap when there are fewer than
# 1000 images.
_, indices = torch.sort(positions)
n_show = min(500, len(indices) // 2)
images = [
    VF.resize(dataset[i][0], (32, 32), VF.InterpolationMode.NEAREST)
    # explicit start for the tail: `indices[-0:]` would return every index
    for i in itertools.chain(indices[:n_show], indices[len(indices) - n_show:])
]
VF.to_pil_image(make_grid(images, nrow=20))

# Save the full sorted image grid

In [None]:
# every dataset image at full resolution in t-SNE order, 64 images per row
images = list(map(lambda idx: dataset[idx][0], indices))
big_image = VF.to_pil_image(make_grid(images, nrow=64))

In [None]:
# NOTE: the file name hard-codes the dataset ("kali") and the CLIP model ("rn50");
# update it when switching either of them.
output_path = Path("~/Pictures/kali-of-tsne1d-of-clip-rn50.png").expanduser()
# PIL opens the target file directly and fails if the directory does not exist
output_path.parent.mkdir(parents=True, exist_ok=True)
big_image.save(output_path)

# Plot the most similar images

In [None]:
def get_similar_indices(
        feat: torch.Tensor,
        count: int = 10,
        reference: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Return the indices of the `count` reference rows with the highest dot product to `feat`.

    `feat` is not normalized here; with unit-length vectors on both sides (like the
    global `features_n`) the dot product is the cosine similarity.

    :param feat: query feature(s), shape (D,) or (Q, D)
    :param count: number of indices per query, clipped to the number of reference rows
    :param reference: (N, D) features to search in; defaults to the global `features_n`
    :return: int64 indices, most similar first, of shape (count,) for a 1-D query
        or (Q, count) for a 2-D query
    """
    if reference is None:
        reference = features_n

    single_query = feat.ndim == 1
    if single_query:
        feat = feat.unsqueeze(0)

    dot = feat @ reference.T
    # topk avoids sorting the full similarity row; clip k like slicing a sort would
    k = min(count, reference.shape[0])
    indices = torch.topk(dot, k, dim=-1).indices
    return indices[0] if single_query else indices

def plot_similar(indices: Iterable[int], count: int = 10):
    """
    Display a grid of the dataset images most similar to each of the given images.

    Column j shows the `count` nearest neighbours (by dot product of the normalized
    features) of the j-th query, most similar at the top; the top row is normally
    the query image itself.
    """
    query_ids = list(indices)
    query_features = torch.stack([features_n[idx] for idx in query_ids])
    neighbours = get_similar_indices(query_features, count=count)
    # transpose so the flattened list runs row by row: the r-th neighbour of every query
    grid_images = [dataset[idx][0] for idx in neighbours.T.reshape(-1)]
    display(VF.to_pil_image(make_grid(grid_images, nrow=len(query_ids))))
    
# 50 nearest neighbours of each of the dataset images 2000..2019
plot_similar(indices=range(2000, 2020), count=50)

In [None]:
# presumably precomputed CLIP features of the 64x64 kali dataset (judging by the file name)
f = torch.load(Path("../datasets") / "kali-uint8-64x64-CLIP.pt")

In [None]:
# minimum and maximum of every feature dimension over all rows of `f`
per_dim_min = f.min(dim=0).values
per_dim_max = f.max(dim=0).values
px.line(pd.DataFrame({"min": per_dim_min, "max": per_dim_max}))