In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
import torchvision.models as VM 
from IPython.display import display

from src.util.image import *
from src.util import *
from src.util.embedding import *
from src.models.util import *
from src.algo import ca1

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Scale the last two (height, width) dimensions of an image tensor by ``scale``.

    Each target side is truncated to an int and clamped to at least 1 pixel.
    Antialiasing is always disabled.
    """
    height, width = img.shape[-2:]
    target_size = [max(1, int(height * scale)), max(1, int(width * scale))]
    return VF.resize(img, target_size, interpolation=mode, antialias=False)

In [None]:
# VGG19 with its default pretrained weights (VGG19_Weights.DEFAULT); `vgg` is required by the cells below.
vgg = VM.vgg19(weights=VM.VGG19_Weights.DEFAULT)
print(f"params: {num_module_parameters(vgg):,}", vgg, sep="\n")

In [None]:
# Visualize the VGG weights as a single image. get_model_weight_images is not defined in this
# notebook (presumably it comes from one of the `src` star imports); normalize="each" presumably
# normalizes each weight image separately — TODO confirm.
VF.to_pil_image(get_model_weight_images(vgg, normalize="each"))

In [None]:
if False:  # disabled: switch to True to load and inspect Inception v3
    incept = VM.inception_v3(weights=VM.Inception_V3_Weights.DEFAULT)
    print(f"params: {num_module_parameters(incept):,}", incept, sep="\n")

In [None]:
if False:  # disabled: switch to True to load and inspect ShuffleNet v2 (x2.0)
    shufflenet = VM.shufflenet_v2_x2_0(weights=VM.ShuffleNet_V2_X2_0_Weights.DEFAULT)
    print(f"params: {num_module_parameters(shufflenet):,}", shufflenet, sep="\n")

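# Find distinct example patches

Reduce the patches with PCA, cluster them with k-means and display samples from every cluster to choose a diverse set of sample patches.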

In [None]:
# Load the patch dataset. NOTE(review): the file name suggests 3x32x32 uint8 patches, and later
# cells index patches[idx] as single images — confirm dtype/shape. torch.load unpickles the
# file, so only load trusted files.
patches = torch.load("../datasets/rpg-3x32x32-uint-test.pt")

In [None]:
from sklearn.decomposition import PCA

# Reduce the flattened patches to 64 PCA components as input for clustering.
# sklearn may automatically pick its randomized SVD solver for larger inputs; a fixed
# random_state keeps the components (and the clusters below) reproducible across runs.
pca = PCA(64, random_state=0)
features = pca.fit_transform(patches.flatten(1))

In [None]:
from sklearn.cluster import KMeans

# Cluster the PCA features into 32 groups; fixed random_state so the cluster grid is reproducible.
clusterer = KMeans(32, n_init=20, random_state=0)
labels = clusterer.fit_predict(features)

# cluster label (plain int) -> patch indices in ascending order
label_to_index = {}
for i, l in enumerate(labels):
    label_to_index.setdefault(int(l), []).append(i)

# Cluster sizes sorted ascending, so hist[0] is the size of the smallest cluster.
# bincount counts each integer label exactly (no reliance on histogram bin-edge placement).
hist = sorted(np.bincount(labels, minlength=32))
px.bar(hist)

In [None]:
# One grid row per cluster: the `num_samples` highest-index patches of that cluster,
# labelled with their patch index. `num_samples` is the size of the smallest cluster,
# so every row is completely filled and rows line up with clusters.
images = []
image_labels = []
num_samples = min(len(members) for members in label_to_index.values())
for label in sorted(label_to_index):
    # Slice exactly `num_samples` members (num_samples >= 1, so the negative slice is safe).
    for idx in label_to_index[label][-num_samples:]:
        images.append(patches[idx])
        image_labels.append(idx)

display(VF.to_pil_image(resize(make_grid_labeled(images, labels=image_labels, nrow=num_samples), 2)))

In [None]:
# Patch indices (into `patches`) used as query images by the similarity cells below.
# NOTE(review): presumably hand-picked from the cluster grid above — confirm.
SAMPLE_INDICES = [
    27, 4, 2, 7, 67, 153, 272, 187, 527, 124, 75, 33, 542, 35, 224, 344, 1644, 2363, 2172,
]

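# Similarity by feature

For each sample patch, show the most and least similar patches according to the similarity of their (normalized) VGG feature activations.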

In [None]:
from experiments.datasets import rpg_tile_dataset_3x32x32

def sim_by_feature(
        model,
        num_most_similar: int = 32,
        num_least_similar: int = 10,
):
    """
    Show, for each patch in SAMPLE_INDICES, its closest and farthest patches by model features.

    Every patch in the global ``patches`` is passed through ``model`` (cast to float), flattened
    per sample and passed to ``normalize_embedding``. The similarity matrix is the dot product of
    those vectors (cosine similarity, assuming normalize_embedding L2-normalizes — TODO confirm).

    One grid row is drawn per source patch. It holds the ``num_most_similar`` highest-scoring
    patches, followed by the ``num_least_similar`` lowest-scoring ones, ending with the least
    similar patch. Each image is labelled with int(similarity * 100).

    NOTE(review): the inputs are only cast to float, not rescaled or normalized the way
    pretrained torchvision models usually expect — confirm this is intended.

    :param model: callable mapping a float batch of patches to activations.
    :param num_most_similar: number of best matches per row.
    :param num_least_similar: number of worst matches per row.
    """
    # Pure feature extraction: no autograd graph needed.
    with torch.no_grad():
        features = batch_call(
            lambda t: normalize_embedding(model(t.float()).flatten(1)),
            patches, verbose=True)

    # NOTE(review): N x N over all patches — memory grows quadratically with the dataset.
    sim = features @ features.T

    images = []
    image_labels = []
    for source_idx in SAMPLE_INDICES:
        sim_row = sim[source_idx]
        order = sim_row.argsort(descending=True)
        # Explicit start index so num_least_similar=0 yields an empty slice
        # (order[-0:] would select everything).
        selected = torch.cat([order[:num_most_similar], order[len(order) - num_least_similar:]])
        for idx in selected.tolist():
            images.append(patches[idx])
            image_labels.append(int(sim_row[idx] * 100))

    display(VF.to_pil_image(make_grid_labeled(
        images, labels=image_labels, nrow=num_most_similar + num_least_similar)))


sim_by_feature(
    vgg.features[:6]
)

In [None]:
# Sweep the truncation depth of vgg.features (first 1..19 layers): print each prefix
# and show the nearest/farthest patches according to its activations.
for depth in range(1, 20):
    prefix = vgg.features[:depth]
    print(prefix)
    sim_by_feature(prefix)

In [None]:
# `sim` is local to sim_by_feature(), so recompute it here (same layer slice as the
# sim_by_feature call above). This lets the cell run on a fresh kernel and leaves `sim`
# available for inspection.
sim_model = vgg.features[:6]
with torch.no_grad():
    sim_features = batch_call(
        lambda t: normalize_embedding(sim_model(t.float()).flatten(1)),
        patches, verbose=True)
sim = sim_features @ sim_features.T

# One row per source patch: the 32 most similar patches, then the 10 least similar
# (ending with the least similar one), labelled with int(similarity * 100).
images = []
image_labels = []
for source_idx in SAMPLE_INDICES:
    sim_row = sim[source_idx]
    idx_row = sim_row.argsort(descending=True)
    for idx in torch.cat([idx_row[:32], idx_row[-10:]]).tolist():
        images.append(patches[idx])
        image_labels.append(int(sim_row[idx] * 100))

display(VF.to_pil_image(make_grid_labeled(images, labels=image_labels, nrow=42)))


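# Models

A VGG feature extractor that records every convolution output, and a trivial model whose only parameter is an image.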

In [None]:
# List only the convolution layers of the VGG feature stack.
conv_layers = [layer for layer in vgg.features if isinstance(layer, nn.Conv2d)]
for conv in conv_layers:
    print(conv)

In [None]:
class VGGFeatures(nn.Module):
    """
    Wrap the ``features`` stack of a VGG-style model and record every Conv2d output.

    After a forward pass, ``self.features`` maps ``"conv_1"``, ``"conv_2"``, ... (numbered in
    layer order) to the output of the corresponding convolution from the most recent call.
    """

    def __init__(self, vgg: nn.Module):
        super().__init__()
        self.layers = vgg.features
        # conv name -> activation from the most recent forward(); None until forward() runs
        self.features = {}
        # Conv2d module -> its "conv_<n>" name
        self._layer_map = {}
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                name = f"conv_{len(self._layer_map) + 1}"
                self._layer_map[layer] = name
                self.features[name] = None

    def forward(self, image):
        """Run ``image`` through the wrapped layers, caching each conv output; return the final output."""
        x = image
        # Iterate the layers captured in __init__, not any notebook-level `vgg` object.
        for layer in self.layers:
            x = layer(x)
            name = self._layer_map.get(layer)
            if name is not None:
                # NOTE(review): if the next layer is an in-place op (torchvision's VGG presumably
                # uses ReLU(inplace=True)), it will overwrite this cached tensor; clone here if the
                # raw pre-activation conv output is wanted.
                self.features[name] = x
        return x

    def features_concat(self, names: Optional[List[str]] = None, gram: bool = True):
        """
        Flatten and concatenate cached conv activations along the last dimension.

        :param names: conv names to include (e.g. ``["conv_1", "conv_5"]``); None includes all.
        :param gram: if True, replace each activation of shape (..., C, H, W) by its Gram matrix
            ``F @ F^T / (H * W)`` of shape (..., C, C), where F is the activation reshaped to
            (..., C, H*W). Dividing by H*W keeps layers of different resolution on a comparable scale.
        :returns: tensor of shape (..., total) holding one flattened block per selected layer,
            in layer order.
        :raises ValueError: if no cached activation matches (e.g. forward() was not called yet).
        """
        parts = []
        for name, f in self.features.items():
            if f is None or (names is not None and name not in names):
                continue
            if gram:
                flat = f.flatten(-2)  # (..., C, H*W)
                g = flat @ flat.transpose(-1, -2) / flat.shape[-1]  # (..., C, C)
                parts.append(g.flatten(-2))
            else:
                parts.append(f.flatten(-3))
        if not parts:
            raise ValueError(
                "no cached features match the requested names; call forward() first "
                "or check the layer names"
            )
        return torch.concat(parts, dim=-1)
        

In [None]:
class PixelModel(nn.Module):
    """
    Trivial "model" whose only parameter is an image tensor.

    Calling the module returns that parameter itself, so an optimizer stepping over
    ``parameters()`` updates the pixels directly.
    """

    def __init__(self, image: torch.Tensor):
        super().__init__()
        # Wrapping in nn.Parameter registers the tensor as this module's parameter.
        self.image = nn.Parameter(image)

    def forward(self):
        """Return the learnable image parameter (the same object, not a copy)."""
        pixels = self.image
        return pixels
        