In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.cnn import *

In [None]:
# Dataset selection: enable exactly one of the `if` toggles below.
# Each branch defines SHAPE (C, H, W) and `dataset`.
#
# Unused 128x128 variant, kept for reference:
#SHAPE = (3, 128, 128)
#dataset = TensorDataset(torch.load(f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt"))

if 0:
    # "pattern" tensors; the transform repeats the first axis three times
    # (presumably single-channel -> RGB; verify against the stored file).
    SHAPE = (3, 64, 64)
    dataset = TransformDataset(
        TensorDataset(torch.load(f"../datasets/pattern-{1}x{SHAPE[-2]}x{SHAPE[-1]}-uint.pt")),
        dtype=torch.float,
        multiply=1./255.,
        transforms=[lambda i: i.repeat(3, 1, 1)],
    )

if 1:
    # "kali" tensors, converted to float and multiplied by 1/255
    # (presumably uint8 -> [0, 1]; TransformDataset is defined elsewhere).
    SHAPE = (3, 64, 64)
    dataset = TransformDataset(
        TensorDataset(torch.load(f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt")),
        dtype=torch.float,
        multiply=1./255.,
        #transforms=[lambda i: i.repeat(3, 1, 1)],
    )

# Sanity check: the first tensor of the first item must match the declared shape.
assert dataset[0][0].shape == SHAPE

In [None]:
# Encoder selection: enable exactly one of the `if` toggles below.
# Each branch builds the architecture and names the checkpoint file that
# holds its weights; the weights are loaded once after the branches.
if 0:
    CODE_SIZE = 512
    from scripts.train_from_dataset import EncoderMLP
    model = EncoderMLP(SHAPE, channels=[CODE_SIZE])
    CHECKPOINT_FILE = "../checkpoints/clip2/best.pt"

if 1:
    CODE_SIZE = 512
    from scripts.train_from_dataset import EncoderMLP
    model = EncoderMLP(SHAPE, channels=[CODE_SIZE * 4, CODE_SIZE], hidden_act=nn.GELU())
    CHECKPOINT_FILE = "../checkpoints/clip3/best.pt"

if 0:
    CODE_SIZE = 512
    from scripts.train_from_dataset import EncoderTrans
    model = EncoderTrans(SHAPE, code_size=CODE_SIZE)
    CHECKPOINT_FILE = "../checkpoints/clip4-tr/best.pt"

# map_location="cpu": the model is never moved to another device in this
# notebook, and without it a checkpoint saved from a GPU run cannot be
# deserialized on a machine without CUDA.
model.load_state_dict(torch.load(CHECKPOINT_FILE, map_location="cpu")["state_dict"])

# The model is only used for inference / feature inversion below, so switch
# dropout- and batchnorm-style layers (if any) to evaluation behavior.
# `eval()` returns the module itself, so the cell still displays its repr.
model.eval()

In [None]:
class Pixels(nn.Module):
    """
    A trainable image: a single learnable tensor stored as `self.pixels`.

    The module has no forward pass; it only exists so that the tensor can be
    handed to an optimizer via `parameters()`.

    :param shape: shape of the image tensor, e.g. (channels, height, width)
    """

    def __init__(self, shape: Tuple[int, int, int]):
        super().__init__()
        self.shape = shape
        # Start from low-contrast noise: uniform values in [0.3, 0.4).
        initial_values = .3 + .1 * torch.rand(self.shape)
        self.pixels = nn.Parameter(initial_values)

In [None]:
def feature_vis(
    pixels: nn.Module,
    model: nn.Module,
    target_feature: torch.Tensor,
    batch_size: int = 50,
    num_steps: int = 4000,
    learnrate: float = 0.03,
    random_offset: int = 4,
    random_rotation: float = 3,
    show: bool = False,
):
    """
    Optimize `pixels.pixels` in place so that `model` maps it onto `target_feature`.

    For every optimizer step a batch of randomly augmented copies of the
    current image is built (edge-padded random crop, then a rotation around a
    random center), passed through `model`, and the L1 distance between the
    resulting features and `target_feature` is minimized with AdamW.

    :param pixels: module exposing the trainable image as `pixels.pixels`;
        its parameters are what the optimizer updates
    :param model: feature extractor, called on the stacked batch of augmented images
    :param target_feature: 1-D feature vector to match. It is detached
        internally, so it may be passed with or without an autograd graph.
    :param batch_size: number of augmented copies per optimizer step
    :param num_steps: total number of augmented copies processed; the number
        of optimizer steps is ceil(num_steps / batch_size)
    :param learnrate: learning rate of the AdamW optimizer
    :param random_offset: padding (in pixels) applied before the random crop
        back to the original size; 0 disables the augmentation
    :param random_rotation: rotation range in degrees passed to
        `RandomRotation`; 0 disables the augmentation
    :param show: if True, collect the clamped image after every optimizer step
        and display it in grids of 8 together with the current loss
    :return: None, the result is left in `pixels.pixels`
    """
    # The target is a constant of this optimization. Callers typically pass
    # `model(image)` with its autograd graph still attached; without the
    # detach every backward pass would also walk that graph (which is what
    # made `retain_graph=True` necessary) only to produce gradients that are
    # thrown away.
    target_feature = target_feature.detach()

    optimizer = torch.optim.AdamW(pixels.parameters(), lr=learnrate)
    display_images = []
    for batch_idx in range(0, num_steps, batch_size):
        # The final batch is truncated so that exactly `num_steps` copies are used.
        current_batch_size = min(batch_size, num_steps - batch_idx)

        pixel_batch = []
        for _ in range(current_batch_size):
            image = pixels.pixels
            shape = image.shape

            #image = VF.resize(image, [SHAPE[-2] + 8, SHAPE[-1] + 8], VF.InterpolationMode.BICUBIC)
            if random_offset:
                # Pad with edge pixels, then crop back to the original size at
                # a random position: a random translation of the image.
                image = VF.pad(image, random_offset, padding_mode="edge")
                image = VT.RandomCrop(shape[-2:])(image)

            if random_rotation:
                # Rotate around a random point inside the image rather than its center.
                center_x = int(shape[-1] * (torch.rand(1).item()))
                center_y = int(shape[-2] * (torch.rand(1).item()))
                image = VT.RandomRotation(
                    random_rotation, center=[center_x, center_y],
                    interpolation=VF.InterpolationMode.BILINEAR,
                )(image)

            pixel_batch.append(image.unsqueeze(0))
        if not pixel_batch:
            break
        pixel_batch = torch.cat(pixel_batch)

        features = model(pixel_batch)
        # Broadcast the single target vector over the batch without copying it.
        target_features = target_feature.unsqueeze(0).expand(features.shape[0], -1)
        loss = F.l1_loss(features, target_features)

        pixels.zero_grad()
        # The model's parameters also receive gradients; clear them so they
        # do not accumulate across steps.
        model.zero_grad()

        # A fresh graph is built every step, so nothing needs to be retained.
        loss.backward()
        optimizer.step()

        if show:
            with torch.no_grad():
                display_images.append(pixels.pixels.detach().clamp(0, 1))
                if len(display_images) >= 8:
                    print(float(loss))
                    display(VF.to_pil_image(make_grid(display_images)))
                    display_images.clear()

    # Flush the snapshots that did not fill a complete grid of 8.
    if display_images:
        display(VF.to_pil_image(make_grid(display_images)))

# Demo: reconstruct one dataset image from its feature vector and show the
# intermediate images while optimizing.
the_image = dataset[8][0]
display(VF.to_pil_image(the_image))
pixels = Pixels(SHAPE)

# The target feature is a fixed reference for the optimization, so compute it
# without gradient tracking instead of carrying an autograd graph into
# every backward pass.
with torch.no_grad():
    target_feature = model(the_image.unsqueeze(0)).squeeze(0)

feature_vis(
    pixels, model, show=True, 
    #random_offset=20,
    target_feature=target_feature,
)

In [None]:
def feature_vis_many(
    model: nn.Module, 
    indices: List[int], 
    num_vis: int = 1,
    **kwargs,
):
    """
    Run `feature_vis` for several dataset images and display all results in one grid.

    Relies on the notebook globals `dataset`, `SHAPE`, `Pixels` and `feature_vis`.

    The grid has one column per entry of `indices`. The first row shows the
    original dataset images, each following row shows one reconstruction
    optimized from a freshly initialized `Pixels` instance.

    :param model: feature extractor; provides the target feature of each
        dataset image and is passed on to `feature_vis`
    :param indices: dataset indices to visualize; an empty list displays nothing
    :param num_vis: number of independent reconstructions per image
    :param kwargs: forwarded unchanged to `feature_vis`
    :return: None
    """
    images = []
    for index in tqdm(indices):
        the_image = dataset[index][0]
        images.append(the_image)

        # The target is the same for every reconstruction of this image:
        # compute it once, and without an autograd graph since it is a
        # constant of the optimization.
        with torch.no_grad():
            target_feature = model(the_image.unsqueeze(0)).squeeze(0)

        for _ in range(num_vis):
            pixels = Pixels(SHAPE)
            feature_vis(
                pixels, model,
                target_feature=target_feature,
                **kwargs,
            )
            # Detach so the stored result does not keep autograd state alive.
            images.append(pixels.pixels.detach().clamp(0, 1))

    # `images` is grouped per index (original, vis 1, ..., vis n); reorder it
    # row-major so that each grid row holds the same kind of image.
    images_per_index = num_vis + 1
    grid_images = [
        images[column * images_per_index + row]
        for row in range(images_per_index)
        for column in range(len(indices))
    ]
    if not grid_images:
        # Nothing to show; make_grid cannot handle an empty list.
        return

    display(VF.to_pil_image(make_grid(grid_images, nrow=len(indices))))


# Reconstruct dataset items 0..14, one reconstruction each (default num_vis).
feature_vis_many(model, indices=list(range(15)))

In [None]:
# Same run again: the pixels start from a new random initialization and the
# augmentations are random (no seed is set in the visible cells), so the
# reconstructions will differ from the previous cell.
feature_vis_many(model, indices=list(range(15)))