In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
#from torchvision.transforms import v2
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

import clip

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import *
from src.algo.wangtiles import *
from src.datasets.generative import *
from src.models.cnn import *
from src.models.transform import *
from src.util.embedding import *
from src.models.clip import ClipSingleton

In [None]:
# Load the source image into a tensor. PIL opens files lazily and keeps the
# handle around, so use a context manager to close it as soon as the pixel
# data has been converted.
# NOTE(review): absolute, machine-specific path — consider a configurable
# image directory so the notebook runs on other machines.
with PIL.Image.open("/home/bergi/Pictures/clipig2/pixelart-cthulhu-02.png") as pil_image:
#with PIL.Image.open("/home/bergi/Pictures/clipig2/cobblestone-tile.png") as pil_image:
    image = VF.to_tensor(pil_image)
VF.to_pil_image(image)

In [None]:
@torch.no_grad()
def clip_segmentation(
    image: torch.Tensor,
    prompt: str,
    patch_size: Tuple[int, int],
    batch_size: int = 32,
    patch_sizes: Optional[Iterable[int]] = None,
):
    """
    Occlusion-based relevance map of `image` with respect to a text prompt.

    For every scale, the image is divided into a grid of square patches. Each
    patch is blacked out in turn, the occluded image is encoded with CLIP and
    compared (dot product of normalized features) with the encoded prompt.
    The per-scale scores are min/max-normalized and inverted, so patches whose
    removal lowers the similarity the most end up close to 1. The maps of all
    scales are averaged.

    :param image: image tensor, expected to be shaped (C, H, W)
    :param prompt: text that the image regions are scored against
    :param patch_size: currently NOT used (it never was); kept so existing
        calls keep working. The scales are taken from `patch_sizes`.
    :param batch_size: number of occluded images encoded per CLIP call
    :param patch_sizes: edge lengths of the square patches, one map is
        computed per entry. Defaults to ``range(30, 50)``.
    :return: tensor with the same shape as `image`, values in [0, 1]
    :raises ValueError: if `patch_sizes` is empty or a patch does not fit
        into the image
    """
    sizes = list(range(30, 50) if patch_sizes is None else patch_sizes)
    if not sizes:
        raise ValueError("patch_sizes must contain at least one size")

    # The prompt embedding is the same for every scale, so encode it only once.
    target_features = ClipSingleton.encode_text(prompt, normalize=True)

    def _iter_batches(size: Tuple[int, int]):
        # Yields stacked batches of copies of `image`, each with one patch zeroed.
        image_batch = []
        for _patch, pos in iter_image_patches(image, shape=size, with_pos=True):
            masked_image = image.clone()
            masked_image[:, pos[0]:pos[0] + size[0], pos[1]:pos[1] + size[1]] = 0
            image_batch.append(masked_image)
            if len(image_batch) == batch_size:
                yield torch.stack(image_batch)
                image_batch = []
        if image_batch:
            yield torch.stack(image_batch)

    def _get_grid(size: Tuple[int, int]):
        grid_shape = (image.shape[-2] // size[-2], image.shape[-1] // size[-1])
        if math.prod(grid_shape) == 0:
            raise ValueError(
                f"patch size {size} does not fit into image of size {tuple(image.shape[-2:])}"
            )

        dot_batches = [
            ClipSingleton.encode_image(image_batch, normalize=True) @ target_features.T
            for image_batch in _iter_batches(size)
        ]
        # assumes iter_image_patches yields exactly prod(grid_shape) patches
        # in row-major order — TODO confirm against its implementation
        grid = torch.concat(dot_batches).reshape(grid_shape).float().cpu()

        value_range = grid.max() - grid.min()
        if value_range == 0:
            # All patches scored the same: there is nothing to distinguish.
            # Avoid the 0/0 -> NaN of the plain formula and use the value the
            # formula gives for a zero numerator.
            gridn = torch.ones_like(grid)
        else:
            gridn = 1. - (grid - grid.min()) / value_range

        # NOTE(review): the grid is stretched to the full image size, while the
        # patches presumably only cover grid_shape * size pixels; for sizes
        # that do not divide the image the cells are slightly misaligned.
        return VF.resize(
            gridn.unsqueeze(0), list(image.shape[-2:]), VF.InterpolationMode.NEAREST, antialias=False
        )

    grid_sum = torch.zeros_like(image)
    for size in tqdm(sizes):
        # the (1, H, W) map broadcasts over the channel dimension
        grid_sum += _get_grid((size, size))

    # average over the number of scales actually evaluated
    grid_sum /= len(sizes)
    return grid_sum

# Multi-scale occlusion map for the prompt.
# NOTE: `patch_size` is accepted but not used by `clip_segmentation`.
grid = clip_segmentation(
    image=image,
    prompt="cobblestone",
    patch_size=(50, 50),
)

In [None]:
# Weight a copy of the image with the relevance map and show the result.
# (`grid` is already normalized and has the image's spatial size.)
image_g = image.clone().mul_(grid)
VF.to_pil_image(image_g)

In [None]:
@torch.no_grad()
def clip_segmentation_2(
    image: torch.Tensor,
    prompt: str,
    #patch_size: Tuple[int, int],
    trials: int = 100,
    batch_size: int = 32,
    mask_size_range: Tuple[int, int] = (20, 50),
    seed: int = 23,
):
    """
    Relevance map of `image` with respect to a text prompt, via random occlusion.

    `trials` copies of the image are created, each with one randomly sized and
    randomly placed rectangle set to zero. Every copy is encoded with CLIP and
    compared (dot product of normalized features) with the encoded prompt.
    Each pixel collects the average score of all copies in which it was
    occluded. The score of the unmodified image is credited to every pixel as
    well, so no pixel is left without a value. The result is min/max-normalized
    and inverted: pixels whose occlusion lowers the similarity the most end up
    close to 1.

    :param image: image tensor, expected to be shaped (C, H, W)
    :param prompt: text that the image regions are scored against
    :param trials: number of randomly occluded copies to evaluate
    :param batch_size: number of images encoded per CLIP call
    :param mask_size_range: inclusive (min, max) edge length of the occlusion
        rectangle; height and width are drawn independently and clamped to
        the image size
    :param seed: seed of the private random generator, results are
        reproducible for a fixed seed
    :return: tensor of shape (H, W) with values in [0, 1]
    """
    height, width = image.shape[-2:]
    min_size, max_size = mask_size_range
    rng = random.Random(seed)

    def _random_extent(limit: int) -> int:
        # Never draw a rectangle larger than the image along this axis.
        upper = min(max_size, limit)
        lower = min(min_size, upper)
        return rng.randint(lower, upper)

    def _iter_samples():
        # Baseline: the untouched image, attributed to the whole image area.
        yield image, ((0, 0), (height, width))
        for _ in tqdm(range(trials)):
            size = (_random_extent(height), _random_extent(width))
            # `+ 1` so the rectangle may also sit flush with the bottom/right
            # edge; otherwise the last row and column are never occluded.
            pos = (
                rng.randrange(height - size[0] + 1),
                rng.randrange(width - size[1] + 1),
            )
            masked_image = image.clone()
            masked_image[:, pos[0]:pos[0] + size[0], pos[1]:pos[1] + size[1]] = 0
            yield masked_image, (pos, size)

    def _iter_batches():
        image_batch, pos_batch = [], []
        for sample, region in _iter_samples():
            image_batch.append(sample)
            pos_batch.append(region)
            # `>=` instead of `==`: the list must also flush for batch_size=1
            if len(image_batch) >= batch_size:
                yield torch.stack(image_batch), pos_batch
                # start fresh lists instead of clearing the ones just handed out
                image_batch, pos_batch = [], []
        if image_batch:
            yield torch.stack(image_batch), pos_batch

    target_features = ClipSingleton.encode_text(prompt, normalize=True)

    grid = torch.zeros_like(image[0])
    grid_count = torch.zeros_like(image[0])
    for image_batch, pos_batch in _iter_batches():
        feature_batch = ClipSingleton.encode_image(image_batch, normalize=True)

        # one similarity per image; a single transfer instead of one per item
        dots = (feature_batch @ target_features.T).reshape(len(pos_batch)).tolist()

        for (pos, size), dot in zip(pos_batch, dots):
            grid[pos[0]:pos[0] + size[0], pos[1]:pos[1] + size[1]] += dot
            grid_count[pos[0]:pos[0] + size[0], pos[1]:pos[1] + size[1]] += 1

    mask = grid_count > 0
    grid[mask] = grid[mask] / grid_count[mask]

    value_range = grid.max() - grid.min()
    if value_range == 0:
        # Constant score map (e.g. trials=0): avoid 0/0 -> NaN and return the
        # value the formula gives for a zero numerator.
        return torch.ones_like(grid)
    return 1. - (grid - grid.min()) / value_range

# Random-occlusion relevance map for the prompt, shown as
# [original | image weighted by the map | the map itself].
grid = clip_segmentation_2(image, prompt="repetitive", trials=1000)
image_g = image.clone().mul_(grid)
VF.to_pil_image(
    make_grid([
        image,
        image_g,
        # replicate the single-channel map to match the image's channel count
        grid[None].repeat(image.shape[0], 1, 1),
    ])
)

In [None]:
# Side-by-side view: original, map-weighted image, and the map itself.
image_g = image.clone().mul_(grid)
VF.to_pil_image(
    make_grid([
        image,
        image_g,
        # replicate the single-channel map to match the image's channel count
        grid[None].repeat(image.shape[0], 1, 1),
    ])
)

In [None]:
# Interactive heatmap of the relevance map.
display(px.imshow(img=grid))

In [None]:
# Show the image weighted by the current relevance map.
image_g = image.clone().mul_(grid)
VF.to_pil_image(image_g)