# Overview

This method uses k-NN regression weighted by the distances between CLIP embeddings.  
***Note: This method does not use generated images, only prompts.***
![](https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F8163878%2Fbf35d4beb7867bc163f28abb647128ae%2FSDIP_method.PNG?generation=1682513923845611&alt=media)

In [1]:
# Ensemble ratios, one per model.  "LB" marks the leaderboard score that was
# noted next to a model.
# NOTE(review): the five ratios sum to 1.7, not 1 -- presumably the blend is
# re-normalised where the ratios are applied; confirm against that cell.
ratio_ViT_224 = 0.5 # LB 0.55567 (openclip-l18 (900k))
CLIP_KNN_g_ratio = 0.4 # 0.55246
CLIP_KNN_ratio = 0.4 # 0.55313
ratio_ViT_384 = 0.2 # LB 0.52823
ratio_interrogator = 0.2

# openclip224-l18 models

In [2]:
import numpy as np
import pickle
import pandas as pd
from tqdm import tqdm
import torch
from sklearn.model_selection import train_test_split
from glob import glob
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from PIL import Image
from pathlib import Path
from transformers import AutoModel, AutoProcessor
from torchvision.transforms.functional import resize
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image
import torch.nn.functional as F

# Image processor for the CLIP ViT-L/14 model, read from a local Kaggle dataset.
clip_processor = AutoProcessor.from_pretrained("/kaggle/input/0511-clip224-l18-models/clip-vit-large-patch14_processor")
# NOTE(review): BATCHSIZE, SAVE_OPT_CKP, SAVE_MODEL_CKP, UNFREEZE_START and
# run_name are not referenced by the inference code visible in this notebook;
# presumably they are leftovers from the training script.
BATCHSIZE=128
SAVE_OPT_CKP = True
SAVE_MODEL_CKP = True
UNFREEZE_START = 18 # set it to lower number when significantly more samples are included.
# Prefer the GPU, fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

run_name = f'clip224-l18'

In [3]:
torch.__version__

'1.13.0'

In [4]:
class CFG:
    """Settings for the fine-tuned CLIP ViT-L/14 (224 px) inference pass.

    Later cells define further classes named ``CFG`` that rebind this name.
    """
    # Checkpoint whose state dict infer() loads into Net.
    model_path = '/kaggle/input/0515-clip224-l18-1/clip224-l18_trn07134_val07109.pt'    
    # NOTE(review): not referenced by the code visible in this notebook.
    model_name= 'clip-vit-large-patch14'
    # Handed to infer() as input_size, which does not read it.
    input_size = 224
    # Images per DataLoader batch in infer().
    batch_size = 128

In [5]:
class InferenceIMGDataset:
    """Map-style dataset that runs each image file through a CLIP processor.

    Args:
        image_paths: indexable collection of image file paths.
        clip_processor: callable invoked as ``clip_processor(images=...)``;
            defaults to the module-level ``clip_processor``.
    """

    def __init__(self, image_paths, clip_processor=clip_processor):
        self.input_processor = clip_processor
        self.images = image_paths

    def __getitem__(self, item):
        # Open the file at position ``item`` and hand it straight to the processor.
        return self.input_processor(images=Image.open(self.images[item]))

    def __len__(self):
        return len(self.images)
    
class Net(nn.Module):
    """CLIP vision tower followed by a linear projection from 1024 to 384 features."""

    def __init__(self):
        super().__init__()
        # Only the vision branch of the pretrained CLIP model is kept.
        pretrained = AutoModel.from_pretrained("/kaggle/input/0511-clip224-l18-models/clip-vit-large-patch14_model")
        self.vision = pretrained.vision_model
        self.fc = nn.Linear(1024, 384)

    def forward(self, x):
        """Return ``fc`` applied to the vision model's ``pooler_output`` for ``x``."""
        pooled = self.vision(x)['pooler_output']
        projected = self.fc(pooled)
        return projected

In [6]:
# No need to load the optimizer state for inference.
def infer(
    images,
    model_path,
    input_size,
    batch_size
):
    """Embed images with the fine-tuned CLIP vision model.

    Args:
        images: sequence of image file paths.
        model_path: path to a ``state_dict`` checkpoint for ``Net``.
        input_size: unused; kept so existing callers keep working.
        batch_size: number of images per forward pass.

    Returns:
        ``numpy.ndarray`` with one L2-normalised embedding row per image, in
        the order of ``images``.  An empty ``images`` yields an array with
        zero rows.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = InferenceIMGDataset(images, clip_processor=clip_processor)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        # Pinned host memory only helps (and only works silently) with CUDA.
        pin_memory=device.type == 'cuda',
        num_workers=2,
        drop_last=False
    )

    model = Net()
    # Deserialise onto the CPU first: a checkpoint saved from a GPU would
    # otherwise fail to load on the CPU fallback selected above.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    if len(dataset) == 0:
        # np.vstack([]) raises; return a well-formed empty result instead.
        return np.empty((0, model.fc.out_features), dtype=np.float32)

    preds = []
    with torch.no_grad():
        for batch_images in tqdm(dataloader, leave=False):
            # Element 0 of 'pixel_values' is moved to the device as the batch
            # tensor -- presumably the collated per-image processor output.
            batch_images = batch_images['pixel_values'][0].to(device)

            output = model(batch_images)
            output = F.normalize(output, p=2, dim=1)  # unit-length rows

            preds.append(output.cpu().numpy())

    return np.vstack(preds)

In [7]:
# Collect the competition images; each image id is the file name without suffix.
images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
imgIds = [image_path.stem for image_path in images]
EMBEDDING_LENGTH = 384
# One "<image id>_<component index>" key per embedding component, image-major.
imgId_eId = [
    f'{img_id}_{component}'
    for img_id in imgIds
    for component in range(EMBEDDING_LENGTH)
]

# infer() returns one 384-wide row per image; flatten it row-major so the
# values line up with imgId_eId.
embeddings4 = infer(images, CFG.model_path, CFG.input_size, CFG.batch_size).flatten()

                                             

# Library

In [8]:
# Install clip-interrogator from wheels stored in a Kaggle dataset;
# --no-index with --find-links makes pip resolve packages from that directory
# instead of the package index.
wheels_path = "/kaggle/input/clip-interrogator-wheels-x"
clip_interrogator_whl_path = f"{wheels_path}/clip_interrogator-0.4.3-py3-none-any.whl"
!pip install --no-index --find-links $wheels_path $clip_interrogator_whl_path -q

[0m

In [9]:
import os, glob, gc
import random
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.preprocessing import normalize
import torch
from torch import nn
from torchvision import transforms
import open_clip
import warnings
warnings.filterwarnings('ignore')
from matplotlib import pyplot as plt

# Config ViT-G-14

In [10]:
class CFG:
    """Settings for the ViT-g-14 k-NN section (rebinds the earlier ``CFG``)."""
    # Model file opened with torch.jit.load.
    clip_model_path = "/kaggle/input/0503-vit-g-14-models/ViT-g-14.pt"
    # Pickled image preprocessing transforms.
    clip_preproc_pkl = "/kaggle/input/0503-vit-g-14-models/preprocess.pkl"
    # NOTE(review): input_size and batch_size are not referenced by the
    # k-NN code visible in this section.
    input_size = 224
    batch_size = 64
    # Passed to seed_everything().
    seed = 42
    # Neighbours kept per test image.
    knn_topk = 100
    # Test rows handled per GPU chunk in predict_local_knn().
    knn_interval = 2500
    # Exponent applied to the distance when neighbour weights are computed.
    knn_dim = 6

In [11]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch random generators with ``seed``.

    Also exports ``PYTHONHASHSEED`` and, when CUDA is available, seeds the
    CUDA generator and switches cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(CFG.seed)

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> ref_embeddings </b></div>


In [13]:
FILES_OBJECTIVE_EMB = []
FILES_CLIP_EMB = []

In [14]:
### DiffusionDB-2M
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-dffusion2m-prompt-15m/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory); both globs are sorted, presumably so
# that file i here pairs with file i in FILES_OBJECTIVE_EMB.
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-dffusion2m-prompt-15m/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

16 16


In [15]:
### sd2-81k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-df-sd2-81k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-df-sd2-81k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

17 17


In [16]:
### /kaggle/input/0502-gpt-generated-prompts-30k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-gpt-generated-prompts-30k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-gpt-generated-prompts-30k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

18 18


In [17]:
### /kaggle/input/0502-hardcoded-42k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-hardcoded-42k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-hardcoded-42k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

19 19


In [18]:
### /kaggle/input/0502-prompt-900k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-900k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-900k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

21 21


In [19]:
### /kaggle/input/0502-prompt-chatgpt-25k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-chatgpt-25k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-chatgpt-25k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

22 22


In [20]:
### /kaggle/input/0502-prompt-magic-1m
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-magic-1m/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-magic-1m/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

33 33


In [21]:
### /kaggle/input/0502-prompt-sd2-30k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-sd2-30k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-prompt-sd2-30k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

34 34


In [22]:
### /kaggle/input/0502-laion2b-1
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-1/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-1/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

80 80


In [23]:
### /kaggle/input/0502-laion2b-2
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-2/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-2/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

126 126


In [24]:
### /kaggle/input/0502-laion2b-3
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-3/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-3/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

172 172


In [25]:
### /kaggle/input/0502-laion2b-4
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-4/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-4/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

218 218


In [26]:
### /kaggle/input/0502-laion2b-5
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-5/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-5/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

264 264


In [27]:
### /kaggle/input/0502-laion2b-6
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-6/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-6/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

310 310


In [28]:
### /kaggle/input/0502-laion2b-7
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-7/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-7/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

356 356


In [29]:
### /kaggle/input/0502-laion2b-8
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-8/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-8/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

402 402


In [30]:
### /kaggle/input/0502-laion2b-9
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-9/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-9/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

493 493


In [31]:
### /kaggle/input/0502-laion2b-10
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-10/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-laion2b-10/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

539 539


In [32]:
### /kaggle/input/gpt-prompts-50k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/gpt-prompts-50k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/gpt-prompts-50k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

540 540


In [33]:
### /kaggle/input/0502-openprompts-1400k
# Objective embeddings (all-MiniLM-L6-v2 directory)
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/0502-openprompts-1400k/all-MiniLM-L6-v2/*.npy"))

# CLIP embeddings (vitg14 directory), appended in the same sorted order
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/0502-openprompts-1400k/vitg14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

555 555


#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> Generate CLIP vision embeddings </b></div>

In [34]:
test_images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
print(len(test_images))

7


In [35]:
# Load the TorchScript CLIP model named by CFG.clip_model_path (the ViT-g-14
# file in this section), place it on the GPU, switch to eval mode and cast
# it to fp16.
clip_model = torch.jit.load(CFG.clip_model_path)
clip_model = clip_model.cuda().eval().half()

In [36]:
# Load the image preprocessing transforms from a pickle file.
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# this is only acceptable while the dataset file is trusted.
with open(CFG.clip_preproc_pkl, "rb") as fp:
    saved_preprocess = pickle.load(fp)
# Bare expression: echoes the loaded transform pipeline as the cell output.
saved_preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x781b0c64e0e0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [37]:
# Inference with CLIP vision encoder
test_vision_embeddings = []
# Inference only: no_grad stops autograd from recording every forward pass,
# which the original loop paid for and then discarded with .detach().
with torch.no_grad():
    for test_image in tqdm(test_images):
        # image / prep / embedding stay module-level names on purpose: the
        # following cell deletes them explicitly.
        image = Image.open( test_image )
        prep = saved_preprocess(image).unsqueeze(0).to(device)
        embedding = clip_model(prep.half())
        test_vision_embeddings.append( embedding.detach().cpu().numpy() )
# One row per test image, stored as fp16.
test_vision_embeddings = np.concatenate(test_vision_embeddings).astype(np.float16)
print(test_vision_embeddings.shape)

  0%|          | 0/7 [00:00<?, ?it/s]

(7, 1024)


In [38]:
del clip_model, prep, image, embedding
gc.collect()

2798

In [39]:
# Debug shortcut: with exactly 7 test images (the public test set) only the
# first five reference files of each list are searched.
if len(test_vision_embeddings) == 7:
    FILES_OBJECTIVE_EMB, FILES_CLIP_EMB = FILES_OBJECTIVE_EMB[:5], FILES_CLIP_EMB[:5]
    print("!!! Debug mode !!!")
    print(f"len(FILES_OBJECTIVE_EMB)={len(FILES_OBJECTIVE_EMB)}")
    print(f"len(FILES_CLIP_EMB)={len(FILES_CLIP_EMB)}")

!!! Debug mode !!!
len(FILES_OBJECTIVE_EMB)=5
len(FILES_CLIP_EMB)=5


#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> k-NN Regression (CUDA) </b></div>


In [40]:
def predict_local_knn(
    ref_x_embeddings, test_x_embeddings, 
    n_neighbors=CFG.knn_topk,
    interval=CFG.knn_interval,
    distance_dim=CFG.knn_dim,
    coef=1.0, # a coef to prevent from overflow
):
    """Find the nearest reference embeddings for every test embedding on the GPU.

    Both inputs are cast to fp16, scaled to unit length row-wise and compared
    with cosine distance (``1 - dot product``).  Test rows are processed in
    chunks of ``interval`` rows.

    Args:
        ref_x_embeddings: 2-D numpy array of reference embeddings.
        test_x_embeddings: 2-D numpy array of query embeddings.
        n_neighbors: neighbours kept per query row.
        interval: query rows per chunk.
        distance_dim: exponent applied to the distance in the weight formula.
        coef: factor multiplied into every weight.

    Returns:
        Tuple ``(distances, indices, weights)`` of numpy arrays with one row
        per query and ``n_neighbors`` columns; ``indices`` are row positions
        in ``ref_x_embeddings``.
    """
    delta = 0.0001

    # Reference matrix: fp16 on the GPU, every row scaled to unit length.
    ref_gpu = torch.from_numpy(ref_x_embeddings).half().to('cuda')
    ref_gpu /= ref_gpu.norm(dim=-1, keepdim=True)

    # Ceiling division so a trailing partial chunk is still processed.
    n_chunks = (test_x_embeddings.shape[0] + interval - 1) // interval
    dist_chunks, index_chunks, weight_chunks = [], [], []

    for chunk_no in range(n_chunks):
        lo, hi = chunk_no * interval, (chunk_no + 1) * interval
        with torch.no_grad():
            queries = torch.from_numpy(test_x_embeddings[lo:hi, :].copy())
            queries = queries.half().to('cuda')
            queries /= queries.norm(dim=-1, keepdim=True)

            # Cosine distance matrix, shape [N_test_in_chunk, N_ref].
            cos_dist = 1 - torch.mm(queries, ref_gpu.T)

            # Smallest distances first; widen to float64 before the power.
            nearest_dist, nearest_idx = torch.topk(cos_dist, n_neighbors, largest=False, dim=-1)
            nearest_dist = nearest_dist.double()

            # Inverse-power weighting; delta keeps the denominator non-zero.
            neighbor_weights = 1 / (nearest_dist ** distance_dim + delta) * coef
            # Negative distances (presumably fp16 rounding) get the fixed weight delta.
            neighbor_weights[nearest_dist < 0] = delta

            dist_chunks.append(nearest_dist.cpu().numpy().copy())
            index_chunks.append(nearest_idx.cpu().numpy().copy())
            weight_chunks.append(neighbor_weights.cpu().numpy().copy())

            # Drop the GPU tensors before the next chunk is allocated.
            del cos_dist, neighbor_weights, nearest_dist, nearest_idx, queries
            torch.cuda.empty_cache()

    del ref_gpu
    torch.cuda.empty_cache()
    return (
        np.concatenate(dist_chunks),
        np.concatenate(index_chunks),
        np.concatenate(weight_chunks),
    )

# Stream over the reference files, keeping a running top-k per test image so
# that only one reference file has to be in memory at a time.
for i_file, file_clip_emb in enumerate(tqdm(FILES_CLIP_EMB)):
    # Local k-NN (for each CLIP embeddings file VS CLIP vision embeddings of test images)
    ref_clip_embeddings = np.load(file_clip_emb).astype(np.float16)
    local_dists, local_emb_indecies, local_weights = predict_local_knn(
        ref_clip_embeddings, test_vision_embeddings,
        n_neighbors=CFG.knn_topk, interval=CFG.knn_interval, distance_dim=CFG.knn_dim,
        coef=0.001
    )
    # Remember which file every neighbour came from.
    local_files = np.full(local_dists.shape, i_file, dtype=np.int32)

    # merge local k-NN into global k-NN
    if i_file == 0:
        global_files = local_files
        global_dists = local_dists
        global_emb_indecies = local_emb_indecies
        global_weights = local_weights
    else:
        # Candidate pool = previous global top-k + this file's top-k.
        global_files = np.concatenate([global_files, local_files], axis=-1)
        global_dists = np.concatenate([global_dists, local_dists], axis=-1)
        global_emb_indecies = np.concatenate([global_emb_indecies, local_emb_indecies], axis=-1)
        global_weights = np.concatenate([global_weights, local_weights], axis=-1)

        # Column positions of the knn_topk smallest distances per row (unsorted).
        unsorted_min_indices = np.argpartition(global_dists, CFG.knn_topk, axis=1)[:, :CFG.knn_topk]

        # Vectorised per-row gather; replaces a Python loop over every test
        # row that was repeated for each of the four arrays.
        global_files = np.take_along_axis(global_files, unsorted_min_indices, axis=1)
        global_dists = np.take_along_axis(global_dists, unsorted_min_indices, axis=1)
        global_emb_indecies = np.take_along_axis(global_emb_indecies, unsorted_min_indices, axis=1)
        global_weights = np.take_along_axis(global_weights, unsorted_min_indices, axis=1)

    gc.collect()

  0%|          | 0/5 [00:00<?, ?it/s]

In [41]:
# for i_file, file_clip_emb in enumerate(tqdm(FILES_CLIP_EMB)):
#     # Local k-NN (for each CLIP embeddings file VS CLIP vision embeddings of test images)
#     ref_clip_embeddings = np.load(file_clip_emb).astype(np.float16)
#     with torch.no_grad():
#         local_dists, local_emb_indecies, local_weights = predict_local_knn(
#             ref_clip_embeddings, test_vision_embeddings,
#             n_neighbors=CFG.knn_topk, interval=CFG.knn_interval, distance_dim=CFG.knn_dim,
#             coef=0.001
#         )
#     local_files = np.zeros(local_dists.shape, dtype=np.int32) + i_file
    
#     # merge local k-NN into global k-NN
#     if i_file == 0:
#         global_files = local_files
#         global_dists = local_dists
#         global_emb_indecies = local_emb_indecies
#         global_weights = local_weights
#     else:
#         global_files = np.concatenate([global_files, local_files], axis=-1)
#         global_dists = np.concatenate([global_dists, local_dists], axis=-1)
#         global_emb_indecies = np.concatenate([global_emb_indecies, local_emb_indecies], axis=-1)
#         global_weights = np.concatenate([global_weights, local_weights], axis=-1)

#         unsorted_min_indices = np.argpartition(global_dists, CFG.knn_topk, axis=1)[:, :CFG.knn_topk]

#         global_files = np.vstack( [ global_files[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
#         global_dists = np.vstack( [ global_dists[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
#         global_emb_indecies = np.vstack( [ global_emb_indecies[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
#         global_weights = np.vstack( [ global_weights[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
    
#     gc.collect()

In [42]:
# Long-format neighbour table: one row per (test image, kept neighbour) pair.
# File numbers are translated into the matching objective-embedding paths.
df_knn = pd.DataFrame({
    "file": [FILES_OBJECTIVE_EMB[file_no] for file_no in global_files.flatten()],
    "dist": global_dists.flatten(),
    "emb_index": global_emb_indecies.flatten(),
    # Each test row index repeated once per neighbour, matching flatten() order.
    "test_index": np.repeat(np.arange(test_vision_embeddings.shape[0]), CFG.knn_topk),
    "weight": global_weights.flatten(),
})
df_knn

Unnamed: 0,file,dist,emb_index,test_index,weight
0,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.727539,16638,0,0.006739
1,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.750977,63326,0,0.005572
2,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.725586,53214,0,0.006848
3,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.728516,87680,0,0.006685
4,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.740234,91462,0,0.006075
...,...,...,...,...,...
695,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.696289,75445,6,0.008768
696,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.696289,20565,6,0.008768
697,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.696289,13350,6,0.008768
698,/kaggle/input/pub-dffusion2m-prompt-15m/all-Mi...,0.697266,19061,6,0.008694


In [43]:
gc.collect()

39

In [44]:
import dask.dataframe as dd
from dask.diagnostics import ProgressBar

# Convert the pandas DataFrame to a Dask DataFrame
ddf = dd.from_pandas(df_knn, npartitions=10)

# k-NN regression
# Accumulator: one 384-wide row per test image, filled in place below.
test_prompt_embeddings = np.zeros((test_vision_embeddings.shape[0], 384))

def add_weighted_embeddings(group):
    """Add one file's weighted neighbour embeddings into ``test_prompt_embeddings``.

    ``group`` holds the ``df_knn`` rows that share one ``file`` value; the
    group key (``group.name``) is the path of the embedding file to load.
    The accumulator is updated as a side effect and ``group`` is returned
    unchanged.
    """
    objective_emb_file = group.name
    # Load this group's embedding file once for all of its rows.
    ref_objective_embeddings = np.load(objective_emb_file).astype(np.float16)
    for i, r in group.iterrows():
        test_prompt_embeddings[int(r.test_index), :] += r.weight * ref_objective_embeddings[int(r.emb_index), :]
    return group

# NOTE(review): the groups run under the threaded scheduler and all of them
# add into the same shared array; confirm that concurrent "+=" on one test
# row cannot lose updates.
with ProgressBar():
    result = ddf.groupby("file").apply(add_weighted_embeddings, meta=ddf).compute(scheduler='threads')

[########################################] | 100% Completed |  9.5s


In [45]:
# # k-NN regression
# test_prompt_embeddings = np.zeros( (test_vision_embeddings.shape[0], 384))
# for (objective_emb_file, gdf) in tqdm(df_knn.groupby("file")):
#     ref_objective_embeddings = np.load(objective_emb_file).astype(np.float16) 
#     for _, r in gdf.iterrows():
#         test_prompt_embeddings[int(r.test_index), :] += r.weight * ref_objective_embeddings[int(r.emb_index), :]

In [46]:
# L2 norm, applied in place in slices of BS rows
BS = 1000
for row_start in range(0, test_prompt_embeddings.shape[0], BS):
    row_stop = row_start + BS
    embeddings = test_prompt_embeddings[row_start:row_stop, :]
    # Divide each row by its largest absolute entry (plus a tiny epsilon)
    # before the row-wise normalisation.
    embeddings = embeddings / (np.abs(embeddings).max(axis=-1, keepdims=True) + 0.0000001)
    embeddings = normalize(embeddings)
    test_prompt_embeddings[row_start:row_stop, :] = embeddings
gc.collect()

289

In [47]:
test_prompt_embeddings_g_14 = test_prompt_embeddings.flatten()

# Config (ViT-h-14)

In [48]:
class CFG:
    """Settings for the ViT-H-14 k-NN section (rebinds the earlier ``CFG``)."""
    # Model file opened with torch.jit.load.
    clip_model_path = "/kaggle/input/laion-vit-h-14-model/ViT-H-14_laion2b_s32b_b79k.pt"
    # Pickled image preprocessing transforms.
    clip_preproc_pkl = "/kaggle/input/laion-vit-h-14-model/preprocess.pkl"
    # NOTE(review): input_size and batch_size are not referenced by the
    # k-NN code visible in this section.
    input_size = 224
    batch_size = 64
    # Passed to seed_everything().
    seed = 42
    # Neighbours kept per test image.
    knn_topk = 100
    # Test rows handled per GPU chunk in predict_local_knn().
    knn_interval = 2500
    # Exponent applied to the distance when neighbour weights are computed.
    knn_dim = 6

In [49]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch random generators with ``seed``.

    Also exports ``PYTHONHASHSEED`` and, when CUDA is available, seeds the
    CUDA generator and switches cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(CFG.seed)

In [50]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> ref_embeddings </b></div>


In [51]:
FILES_OBJECTIVE_EMB = []
FILES_CLIP_EMB = []

In [52]:
### DiffusionDB-14M (https://huggingface.co/datasets/poloclub/diffusiondb)
# The embeddings are split over three Kaggle datasets (part1..part3); each
# list receives the parts in that order.
for part_no in (1, 2, 3):
    part_dir = f"/kaggle/input/pub-embeddings-diffusiondb-14m-part{part_no}"
    # Objective embeddings
    FILES_OBJECTIVE_EMB += sorted(glob.glob(f"{part_dir}/all_minilm_l6_v2/*.npy"))
    # CLIP text embeddings
    FILES_CLIP_EMB += sorted(glob.glob(f"{part_dir}/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

140 140


In [53]:
# MSCOCO 2017 (train data) (https://cocodataset.org/#download)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-mscoco/all_minilm_l6_v2/*.npy"))

# CLIP text embeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-mscoco/vith14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

146 146


In [54]:
# dataset80k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/390674)
# Objective embeddings (a single file, appended by explicit path)
FILES_OBJECTIVE_EMB.append("/kaggle/input/pub-embeddings-dataset80k/all_minilm_l6_v2/prompt_embeddings_allminilm_001.npy")

# CLIP text embeddings
FILES_CLIP_EMB.append("/kaggle/input/pub-embeddings-dataset80k/vith14/prompt_embedding_vith14_001.npy")
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

147 147


In [55]:
# Dataset30k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/391500)
# Objective embeddings (a single file, appended by explicit path)
FILES_OBJECTIVE_EMB.append("/kaggle/input/pub-embeddings-dataset30k/all_minilm_l6_v2/prompt_embeddings_allminilm_001.npy")

# CLIP text embeddings
FILES_CLIP_EMB.append("/kaggle/input/pub-embeddings-dataset30k/vith14/prompt_embedding_vith14_001.npy")
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

148 148


In [56]:
# Dataset900k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/399699)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-dataset900k/all_minilm_l6_v2/*.npy"))

# CLIP text embeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-dataset900k/vith14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

158 158


In [57]:
# Conceptual Captions
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-conceptual-captions/all_minilm_l6_v2/*.npy"))

# CLIP text embeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-conceptual-captions/vith14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

192 192


In [58]:
# SD2GPT2 (https://www.kaggle.com/datasets/xiaozhouwang/sd2gpt2)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2gpt2/all_minilm_l6_v2/*.npy"))

# CLIP text embeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2gpt2/vith14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

193 193


In [59]:
# SD2Hardcode (https://www.kaggle.com/datasets/xiaozhouwang/sd2hardcode)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2hardcode/all_minilm_l6_v2/*.npy"))

# CLIP text embeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2hardcode/vith14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

194 194


In [60]:
# ChatGPT (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/402146)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-chatgpt/all_minilm_l6_v2/*.npy"))

# CLIP text embeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-chatgpt/vith14/*.npy"))
# Running totals of both lists.
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))


195 195


In [61]:
# Laion2B-en (https://huggingface.co/datasets/laion/laion2B-en)
# (part0000-part0049 of 2000), stored as ten Kaggle datasets of five parts
# each: part0000-0004, part0005-0009, ..., part0045-0049.
for first_part in range(0, 50, 5):
    laion_dir = f"/kaggle/input/pub-embeddings-laion2b-part{first_part:04d}-{first_part + 4:04d}"
    # Objective embeddings
    FILES_OBJECTIVE_EMB += sorted(glob.glob(f"{laion_dir}/all_minilm_l6_v2/*.npy"))
    # CLIP text embeddings
    FILES_CLIP_EMB += sorted(glob.glob(f"{laion_dir}/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

665 665


#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> Generate CLIP vision embeddings </b></div>

In [62]:
test_images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
print(len(test_images))

7


In [63]:
# Load the TorchScript CLIP model (Laion/ViT-H-14) named by
# CFG.clip_model_path, place it on the GPU, switch to eval mode and cast it
# to fp16.
clip_model = torch.jit.load(CFG.clip_model_path)
clip_model = clip_model.cuda().eval().half()

In [64]:
# Load the image preprocessing transforms from a pickle file.
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# this is only acceptable while the dataset file is trusted.
with open(CFG.clip_preproc_pkl, "rb") as fp:
    saved_preprocess = pickle.load(fp)
# Bare expression: echoes the loaded transform pipeline as the cell output.
saved_preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x781b0c64e0e0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [65]:
# Inference with CLIP vision encoder
test_vision_embeddings = []
# Inference only: no_grad stops autograd from recording every forward pass,
# which the original loop paid for and then discarded with .detach().
with torch.no_grad():
    for test_image in tqdm(test_images):
        # image / prep / embedding stay module-level names on purpose: the
        # following cell deletes them explicitly.
        image = Image.open( test_image )
        prep = saved_preprocess(image).unsqueeze(0).to(device)
        embedding = clip_model(prep.half())
        test_vision_embeddings.append( embedding.detach().cpu().numpy() )
# One row per test image, stored as fp16.
test_vision_embeddings = np.concatenate(test_vision_embeddings).astype(np.float16)
print(test_vision_embeddings.shape)

  0%|          | 0/7 [00:00<?, ?it/s]

(7, 1024)


In [66]:
del clip_model, prep, image, embedding
gc.collect()

135

In [67]:
# Debug shortcut: with exactly 7 test images (the public test set) only the
# first five reference files of each list are searched.
if len(test_vision_embeddings) == 7:
    FILES_OBJECTIVE_EMB, FILES_CLIP_EMB = FILES_OBJECTIVE_EMB[:5], FILES_CLIP_EMB[:5]
    print("!!! Debug mode !!!")
    print(f"len(FILES_OBJECTIVE_EMB)={len(FILES_OBJECTIVE_EMB)}")
    print(f"len(FILES_CLIP_EMB)={len(FILES_CLIP_EMB)}")

!!! Debug mode !!!
len(FILES_OBJECTIVE_EMB)=5
len(FILES_CLIP_EMB)=5


#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> k-NN Regression (CUDA) </b></div>


In [68]:
def predict_local_knn(
    ref_x_embeddings, test_x_embeddings,
    n_neighbors=CFG.knn_topk,
    interval=CFG.knn_interval,
    distance_dim=CFG.knn_dim,
    coef=1.0, # a coef to prevent from overflow
):
    """Find the nearest reference embeddings of every test embedding on the GPU.

    Both inputs are L2-normalised row-wise, so ``1 - test @ ref.T`` is the
    cosine distance. Test rows are processed ``interval`` rows at a time.

    Args:
        ref_x_embeddings: NumPy array of reference embeddings, one per row.
        test_x_embeddings: NumPy array of query embeddings, one per row.
        n_neighbors: Number of smallest-distance neighbours kept per query.
        interval: Number of query rows sent to the GPU per batch.
        distance_dim: Exponent applied to the distance before inversion.
        coef: Multiplier applied to every inverse-distance weight.

    Returns:
        Tuple ``(dists, idxs, weights)`` of NumPy arrays with one row per
        query and ``n_neighbors`` columns: cosine distances, row indices into
        ``ref_x_embeddings`` and the corresponding weights.
    """
    eps = 0.0001

    # Reference matrix: half precision on the GPU, unit-length rows.
    ref_gpu = torch.from_numpy(ref_x_embeddings).half().to('cuda')
    ref_gpu /= ref_gpu.norm(dim=-1, keepdim=True)

    n_batches = (test_x_embeddings.shape[0] + interval - 1) // interval

    dist_chunks, idx_chunks, weight_chunks = [], [], []
    with torch.no_grad():
        for batch_no in range(n_batches):
            lo = batch_no * interval
            hi = lo + interval

            query_gpu = torch.from_numpy(test_x_embeddings[lo:hi, :].copy())
            query_gpu = query_gpu.half().to('cuda')
            query_gpu /= query_gpu.norm(dim=-1, keepdim=True)

            # Cosine distance matrix, shape [N_test_in_batch, N_ref].
            dist_matrix = 1 - torch.mm(query_gpu, ref_gpu.T)

            # Smallest n_neighbors distances per query and their indices.
            nearest_dist, nearest_idx = torch.topk(
                dist_matrix, n_neighbors, largest=False, dim=-1
            )
            nearest_dist = nearest_dist.double()

            # Inverse-distance weights; eps keeps the denominator away from 0.
            nearest_weight = 1 / (nearest_dist ** distance_dim + eps) * coef
            # Negative distances get the fixed weight eps.
            # NOTE(review): presumably guards against fp16 rounding — confirm
            # that a small (rather than large) weight is intended here.
            nearest_weight.masked_fill_(nearest_dist < 0, eps)

            dist_chunks.append(nearest_dist.cpu().numpy().copy())
            idx_chunks.append(nearest_idx.cpu().numpy().copy())
            weight_chunks.append(nearest_weight.cpu().numpy().copy())

            # Release per-batch GPU tensors before the next batch.
            del dist_matrix, nearest_weight, nearest_dist, nearest_idx, query_gpu
            torch.cuda.empty_cache()

    del ref_gpu
    torch.cuda.empty_cache()
    return (
        np.concatenate(dist_chunks),
        np.concatenate(idx_chunks),
        np.concatenate(weight_chunks),
    )

# Stream over the reference CLIP-embedding files, run a local k-NN against each
# one, and keep a running global top-k (per test image) across all files.
for i_file, file_clip_emb in enumerate(tqdm(FILES_CLIP_EMB)):
    # Local k-NN (for each CLIP embeddings file VS CLIP vision embeddings of test images)
    ref_clip_embeddings = np.load(file_clip_emb).astype(np.float16)
    local_dists, local_emb_indecies, local_weights = predict_local_knn(
        ref_clip_embeddings, test_vision_embeddings,
        n_neighbors=CFG.knn_topk, interval=CFG.knn_interval, distance_dim=CFG.knn_dim,
        coef=0.001
    )
    # Remember which file each local neighbour came from.
    local_files = np.full(local_dists.shape, i_file, dtype=np.int32)

    # merge local k-NN into global k-NN
    if i_file == 0:
        global_files = local_files
        global_dists = local_dists
        global_emb_indecies = local_emb_indecies
        global_weights = local_weights
    else:
        global_files = np.concatenate([global_files, local_files], axis=-1)
        global_dists = np.concatenate([global_dists, local_dists], axis=-1)
        global_emb_indecies = np.concatenate([global_emb_indecies, local_emb_indecies], axis=-1)
        global_weights = np.concatenate([global_weights, local_weights], axis=-1)

        # Column indices of the knn_topk smallest distances in each row
        # (unordered within the top-k).
        unsorted_min_indices = np.argpartition(global_dists, CFG.knn_topk, axis=1)[:, :CFG.knn_topk]

        # Gather the same columns from every parallel array in one vectorised
        # call instead of a Python loop over rows.
        global_files = np.take_along_axis(global_files, unsorted_min_indices, axis=1)
        global_dists = np.take_along_axis(global_dists, unsorted_min_indices, axis=1)
        global_emb_indecies = np.take_along_axis(global_emb_indecies, unsorted_min_indices, axis=1)
        global_weights = np.take_along_axis(global_weights, unsorted_min_indices, axis=1)

    gc.collect()

  0%|          | 0/5 [00:00<?, ?it/s]

In [69]:
# for i_file, file_clip_emb in enumerate(tqdm(FILES_CLIP_EMB)):
#     # Local k-NN (for each CLIP embeddings file VS CLIP vision embeddings of test images)
#     ref_clip_embeddings = np.load(file_clip_emb).astype(np.float16)
#     with torch.no_grad():
#         local_dists, local_emb_indecies, local_weights = predict_local_knn(
#             ref_clip_embeddings, test_vision_embeddings,
#             n_neighbors=CFG.knn_topk, interval=CFG.knn_interval, distance_dim=CFG.knn_dim,
#             coef=0.001
#         )
#     local_files = np.zeros(local_dists.shape, dtype=np.int32) + i_file
    
#     # merge local k-NN into global k-NN
#     if i_file == 0:
#         global_files = local_files
#         global_dists = local_dists
#         global_emb_indecies = local_emb_indecies
#         global_weights = local_weights
#     else:
#         global_files = np.concatenate([global_files, local_files], axis=-1)
#         global_dists = np.concatenate([global_dists, local_dists], axis=-1)
#         global_emb_indecies = np.concatenate([global_emb_indecies, local_emb_indecies], axis=-1)
#         global_weights = np.concatenate([global_weights, local_weights], axis=-1)

#         unsorted_min_indices = np.argpartition(global_dists, CFG.knn_topk, axis=1)[:, :CFG.knn_topk]

#         global_files = np.vstack( [ global_files[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
#         global_dists = np.vstack( [ global_dists[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
#         global_emb_indecies = np.vstack( [ global_emb_indecies[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
#         global_weights = np.vstack( [ global_weights[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
    
#     gc.collect()

In [70]:
# Long-format table of the global k-NN result: one row per (test image, neighbour).
df_knn = pd.DataFrame({
    # Map the file number of each neighbour to its objective-embedding file.
    "file": [FILES_OBJECTIVE_EMB[file_no] for file_no in global_files.flatten()],
    "dist": global_dists.flatten(),
    "emb_index": global_emb_indecies.flatten(),
    # Every test image owns CFG.knn_topk consecutive rows.
    "test_index": np.repeat(np.arange(test_vision_embeddings.shape[0]), CFG.knn_topk),
    "weight": global_weights.flatten(),
})
df_knn

Unnamed: 0,file,dist,emb_index,test_index,weight
0,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.716797,35943,0,0.007367
1,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.726562,60517,0,0.006793
2,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.709961,98421,0,0.007803
3,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.709961,98420,0,0.007803
4,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.709961,98419,0,0.007803
...,...,...,...,...,...
695,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.644531,22745,6,0.013929
696,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.670898,23578,6,0.010954
697,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.670898,23582,6,0.010954
698,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.670898,23575,6,0.010954


In [71]:
gc.collect()

39

In [72]:
import dask.dataframe as dd
from dask.diagnostics import ProgressBar

# Convert the pandas DataFrame to a Dask DataFrame
ddf = dd.from_pandas(df_knn, npartitions=10)

# k-NN regression
# Accumulator with one 384-dim row per test image; it is filled in place by
# add_weighted_embeddings below.
test_prompt_embeddings = np.zeros((test_vision_embeddings.shape[0], 384))

def add_weighted_embeddings(group):
    """Add the weighted objective embeddings of one file's neighbours to the accumulator.

    ``group`` holds the df_knn rows that share one "file" value; the group
    name is the path of the objective-embedding file, which is loaded once.
    For every row, ``weight * embedding[emb_index]`` is added to row
    ``test_index`` of the module-level ``test_prompt_embeddings``.

    Returns the group unchanged; the function is used for its side effect.
    """
    objective_emb_file = group.name
    ref_objective_embeddings = np.load(objective_emb_file).astype(np.float16)
    for i, r in group.iterrows():
        test_prompt_embeddings[int(r.test_index), :] += r.weight * ref_objective_embeddings[int(r.emb_index), :]
    return group

# NOTE(review): groups are processed with the threaded scheduler while all of
# them update the same shared array in place — confirm that concurrent "+="
# on the same row cannot lose updates.
with ProgressBar():
    result = ddf.groupby("file").apply(add_weighted_embeddings, meta=ddf).compute(scheduler='threads')

[########################################] | 100% Completed |  2.8s


In [73]:
# # k-NN regression
# test_prompt_embeddings = np.zeros( (test_vision_embeddings.shape[0], 384))
# for (objective_emb_file, gdf) in tqdm(df_knn.groupby("file")):
#     ref_objective_embeddings = np.load(objective_emb_file).astype(np.float16) 
#     for _, r in gdf.iterrows():
#         test_prompt_embeddings[int(r.test_index), :] += r.weight * ref_objective_embeddings[int(r.emb_index), :]

In [74]:
# L2 norm
# Rows are rescaled in chunks of BS: first by their max absolute value (plus a
# tiny constant to avoid dividing by zero), then passed through normalize().
# NOTE(review): normalize() must return a 2-D array here — presumably
# sklearn.preprocessing.normalize; confirm which definition is in scope.
BS=1000
num = -(-test_prompt_embeddings.shape[0] // BS)   # ceil(n_rows / BS)
for i in range(num):
    rows = slice(i*BS, (i+1)*BS)
    embeddings = test_prompt_embeddings[rows, :]
    peak = np.abs(embeddings).max(axis=-1, keepdims=True)
    embeddings = normalize( embeddings / ( peak + 0.0000001) )
    test_prompt_embeddings[rows, :] = embeddings

gc.collect()

314

In [75]:
# KNN Regression Result
# Collapse the (n_images, 384) matrix into one flat vector, image after image.
test_prompt_embeddings = test_prompt_embeddings.flatten()

# CLIP + BLIP (CLIP Interrogator)

In [76]:
# Local directory holding pre-downloaded wheels, and the clip_interrogator
# wheel inside it (used by the offline pip install in the next cell).
wheels_path = "/kaggle/input/clip-interrogator-wheels-x"
clip_interrogator_whl_path = f"{wheels_path}/clip_interrogator-0.4.3-py3-none-any.whl"

In [77]:
!pip install --no-index --find-links $wheels_path $clip_interrogator_whl_path -q

[0m

In [78]:
import inspect
import importlib

from blip.models import blip
from clip_interrogator import clip_interrogator

In [79]:
# replace tokenizer path to prevent downloading
# The installed blip module's source is rewritten on disk so that the BERT
# tokenizer is loaded from a local directory, then the module is reloaded.
blip_path = inspect.getfile(blip)

# Context managers close the file even if reading or writing raises.
with open(blip_path, "rt") as fin:
    data = fin.read()

data = data.replace(
    "BertTokenizer.from_pretrained('bert-base-uncased')", 
    "BertTokenizer.from_pretrained('/kaggle/input/clip-interrogator-models-x/bert-base-uncased')"
)

with open(blip_path, "wt") as fout:
    fout.write(data)

# reload module
importlib.reload(blip)

<module 'blip.models.blip' from '/opt/conda/lib/python3.7/site-packages/blip/models/blip.py'>

In [80]:
# fix clip_interrogator bug
# The installed clip_interrogator source is rewritten on disk so that the
# tokenizer is looked up by the model name part before the first "/", then the
# module is reloaded.
clip_interrogator_path = inspect.getfile(clip_interrogator.Interrogator)

# Context managers close the file even if reading or writing raises.
with open(clip_interrogator_path, "rt") as fin:
    data = fin.read()

data = data.replace(
    'open_clip.get_tokenizer(clip_model_name)', 
    'open_clip.get_tokenizer(config.clip_model_name.split("/", 2)[0])'
)

with open(clip_interrogator_path, "wt") as fout:
    fout.write(data)

# reload module
importlib.reload(clip_interrogator)

<module 'clip_interrogator.clip_interrogator' from '/opt/conda/lib/python3.7/site-packages/clip_interrogator/clip_interrogator.py'>

In [81]:
import os
import sys
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt 

import numpy as np
import pandas as pd
import torch
import open_clip

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')

In [82]:
# CLIP + BLIP (clip-interrogator)
# Settings for the clip-interrogator stage.
# NOTE: this rebinds the name CFG; attributes of any CFG class defined earlier
# in the notebook are not available on this one.
class CFG:
    device = "cuda"
    seed = 42
    # Length of one prompt embedding vector.
    embedding_length = 384
    # Local path of the sentence-transformer model.
    sentence_model_path = "/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2"
    # Local BLIP caption-model checkpoint.
    blip_model_path = "/kaggle/input/clip-interrogator-models-x/model_large_caption.pth"
    # Model name handed to clip_interrogator.Config.
    ci_clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
    # Architecture name handed to open_clip.create_model.
    clip_model_name = "ViT-H-14"
    # Local open_clip checkpoint loaded into that architecture.
    clip_model_path = "/kaggle/input/clip-interrogator-models-x/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    # Assigned to the interrogator config's cache_path.
    cache_path = "/kaggle/input/clip-interrogator-models-x"

In [83]:
# Load the sample submission, indexed by "<imgId>_<eId>", and preview it.
df_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')
df_submission.head()

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.018848
20057f34d_1,0.03019
20057f34d_2,0.072792
20057f34d_3,-0.000673
20057f34d_4,0.016774


In [84]:
# File names of the competition images and their ids (text before the first dot).
images = os.listdir(comp_path / 'images')
imgIds = [image_file.split('.')[0] for image_file in images]

eIds = list(range(CFG.embedding_length))

# One "<imgId>_<eId>" key per (image, embedding component), image-major order.
imgId_eId = [f"{img_id}_{e_id}" for img_id in imgIds for e_id in eIds]

# Sanity check: the generated keys must match the sample submission's index.
assert sorted(imgId_eId) == sorted(df_submission.index)

In [85]:
# Sentence-transformer loaded from the local path; used later to encode the generated prompts.
st_model = SentenceTransformer(CFG.sentence_model_path)

In [86]:
# Interrogator configuration; cache_path is pointed at the local model directory.
model_config = clip_interrogator.Config(clip_model_name=CFG.ci_clip_model_name)
model_config.cache_path = CFG.cache_path

In [87]:
# The BLIP configs directory sits next to the directory containing blip.py.
blip_package_dir = os.path.dirname(os.path.dirname(blip_path))
configs_path = os.path.join(blip_package_dir, 'configs')
med_config = os.path.join(configs_path, 'med_config.json')

# Build the BLIP caption decoder from the local checkpoint, switch it to eval
# mode, move it to the configured device and register it on the config.
blip_model = blip.blip_decoder(
    pretrained=CFG.blip_model_path,
    image_size=model_config.blip_image_eval_size,
    vit=model_config.blip_model_type,
    med_config=med_config,
).eval().to(model_config.device)
model_config.blip_model = blip_model

load checkpoint from /kaggle/input/clip-interrogator-models-x/model_large_caption.pth


In [88]:
# Create the open_clip model (fp16 when the configured device is CUDA, fp32
# otherwise), load the local checkpoint, move it to the device in eval mode,
# and register it on the interrogator config.
clip_model = open_clip.create_model(CFG.clip_model_name, precision='fp16' if model_config.device == 'cuda' else 'fp32')
open_clip.load_checkpoint(clip_model, CFG.clip_model_path)
clip_model.to(model_config.device).eval()
model_config.clip_model = clip_model

In [89]:
# Evaluation-time image transform for the CLIP model; mean/std are taken from
# the visual tower when it defines them, otherwise None is passed.
clip_preprocess = open_clip.image_transform(
    clip_model.visual.image_size,
    is_train = False,
    mean = getattr(clip_model.visual, 'image_mean', None),
    std = getattr(clip_model.visual, 'image_std', None),
)
model_config.clip_preprocess = clip_preprocess

In [90]:
# Interrogator built from the config that already carries the BLIP model, CLIP model and preprocess.
ci = clip_interrogator.Interrogator(model_config)

Loaded CLIP model and data in 2.34 seconds.


In [91]:
# Cosine similarity along dim 1, used to score image features against label embeddings.
cos = torch.nn.CosineSimilarity(dim=1)

# Stack each label table's NumPy embeddings into one tensor on the interrogator's device.
mediums_features_array = torch.stack([torch.from_numpy(t) for t in ci.mediums.embeds]).to(ci.device)
movements_features_array = torch.stack([torch.from_numpy(t) for t in ci.movements.embeds]).to(ci.device)
flavors_features_array = torch.stack([torch.from_numpy(t) for t in ci.flavors.embeds]).to(ci.device)

In [92]:
def interrogate(image: Image) -> str:
    """Build a prompt for *image* from a caption plus the best-matching labels.

    The caption is extended with the top medium, the top movement and the top
    three flavors by cosine similarity to the image features. The medium is
    left out when the caption already starts with it. The result is truncated
    with ``clip_interrogator._truncate_to_fit``.
    """
    def best_labels(table, label_features, count):
        # Labels of the `count` entries most similar to the image features.
        ranked = cos(image_features, label_features).topk(count).indices
        return [table.labels[i] for i in ranked]

    caption = ci.generate_caption(image)
    image_features = ci.image_to_features(image)

    medium = best_labels(ci.mediums, mediums_features_array, 1)[0]
    movement = best_labels(ci.movements, movements_features_array, 1)[0]
    flaves = ", ".join(best_labels(ci.flavors, flavors_features_array, 3))

    prompt = (
        f"{caption}, {movement},  {flaves}"
        if caption.startswith(medium)
        else f"{caption}, {medium}, {movement}, {flaves}"
    )
    return clip_interrogator._truncate_to_fit(prompt, ci.tokenize)

In [93]:
# Generate one prompt per competition image, in the order of `images`.
prompts = []

images_path = "../input/stable-diffusion-image-to-prompts/images/"
for image_name in images:
    img = Image.open(f"{images_path}{image_name}").convert("RGB")
    generated = interrogate(img)
    prompts.append(generated)

In [94]:
def add_text_limiters(text: str) -> str:
    """Return *text* with a newline appended to every 15th space-separated word."""
    pieces = []
    for position, word in enumerate(text.split(" "), start=1):
        if position % 15 == 0:
            word += "\n"
        pieces.append(word)
    return " ".join(pieces)

def plot_image(image: np.ndarray, original_prompt: str, generated_prompt: str) -> None:
    """Show *image* in a new figure with both prompts annotated to the right of the axes."""
    plt.figure(figsize=(10, 10))
    plt.imshow(image)

    # Both prompts are wrapped by add_text_limiters before being displayed.
    annotation_text = (
        "Original prompt:\n" + add_text_limiters(original_prompt)
        + "\n\nGenerated prompt:\n" + add_text_limiters(generated_prompt)
    )
    plt.annotate(
        annotation_text,
        xy=(1.05, 0.5),
        xycoords='axes fraction',
        ha='left',
        va='center',
        fontsize=16,
        rotation=0,
        color="#104a6e",
    )

In [95]:
# clip-interrogator
# Encode the generated prompts with the sentence-transformer and flatten the
# result into a single vector, prompt after prompt.
prompt_embeddings = st_model.encode(prompts).flatten()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

# ViT

In [96]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.preprocessing import normalize
import torch.nn.functional as F

In [97]:
# Settings for the ViT regression model.
# NOTE: this rebinds the name CFG again; attributes of earlier CFG classes are
# not available on this one.
class CFG:
    # Fine-tuned checkpoint loaded by predict().
    model_path = '/kaggle/input/0506-cv-large-384-06687/vit_large_patch16_384_0.7348_0.6687_epoch4.pth' #  LB 0.52823
    # timm architecture name.
    model_name = 'vit_large_patch16_384'
    # Size passed to transforms.Resize.
    input_size = 384
    # DataLoader batch size.
    batch_size = 64

In [98]:
class DiffusionTestDataset(Dataset):
    """Dataset that opens image files and applies a transform to each one.

    Args:
        images: Indexable collection of image paths.
        transform: Callable applied to every loaded PIL image.
    """

    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Force three channels so the 3-channel Normalize used downstream also
        # works for grayscale / RGBA / palette files (same as the
        # `.convert("RGB")` used when loading images for the interrogator),
        # and close the file handle as soon as the pixels are decoded.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert("RGB")
        return self.transform(image)

In [99]:
def predict(
    images,
    model_path,
    model_name,
    input_size,
    batch_size,
    n_tta=3,
):
    """Predict flattened, TTA-averaged 384-dim embeddings for a list of images.

    Args:
        images: Image paths, consumed through ``DiffusionTestDataset``.
        model_path: Path of the state dict loaded into the timm model.
        model_name: timm architecture name.
        input_size: Size passed to ``transforms.Resize``.
        batch_size: DataLoader batch size.
        n_tta: Number of passes over the data. Each pass applies a fresh
            random horizontal flip, and the per-pass outputs are averaged.

    Returns:
        1-D NumPy array of length ``len(images) * 384``: the mean over the
        passes of the row-wise L2-normalised model outputs.

    Raises:
        ValueError: If ``n_tta`` is smaller than 1.
    """
    if n_tta < 1:
        raise ValueError("n_tta must be at least 1")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize(input_size),
        # Random flip at inference time: this is what makes the passes differ.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDataset(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=False
    )

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=384
    )
    # map_location keeps the CPU fallback above usable even when the
    # checkpoint was saved from CUDA tensors.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tta_preds = None
    with torch.no_grad():
        for _ in range(n_tta):
            preds = []
            for X in tqdm(dataloader, leave=False):
                X_out = model(X.to(device))
                X_out = F.normalize(X_out, p=2, dim=1)  # unit-length rows
                preds.append(X_out.cpu().numpy())

            pass_preds = np.vstack(preds).flatten()
            if tta_preds is None:
                tta_preds = pass_preds
            else:
                tta_preds += pass_preds

    return tta_preds / n_tta

In [100]:
# Run the ViT model on every PNG in the competition image directory.
# NOTE(review): the order of `images` fixes the row order of embeddings3;
# confirm it matches the image order of the other vectors in the ensemble.
images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
embeddings3 = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

# Ensemble

In [101]:
import torch.nn.functional as F

def normalize(embeds):
    """L2-normalise every 384-dim embedding and return the result flattened.

    NOTE: this definition shadows ``sklearn.preprocessing.normalize``, which
    is imported under the same name earlier in the notebook.

    Args:
        embeds: Array-like whose total size is a multiple of 384 (flat or 2-D).

    Returns:
        1-D NumPy array of the same total size in which every consecutive
        block of 384 values has unit L2 norm. All-zero blocks are returned
        as zeros instead of NaN.
    """
    embeds = np.asarray(embeds).reshape(-1, 384)
    norms = np.linalg.norm(embeds, ord=2, axis=1, keepdims=True)
    # Dividing an all-zero row by 1 keeps it at zero and avoids 0/0 -> NaN.
    norms[norms == 0] = 1
    return (embeds / norms).reshape(-1)

In [102]:
# Weighted sum of the five flattened per-model prediction vectors.
# NOTE(review): the element-wise sum assumes all five vectors list the images
# in the same order — confirm, since they are built from different image lists.
test_prompt_embeddings_preds = (CLIP_KNN_g_ratio * test_prompt_embeddings_g_14) + (CLIP_KNN_ratio * test_prompt_embeddings) + (ratio_ViT_384 * embeddings3) + (ratio_ViT_224 * embeddings4)+ (ratio_interrogator *prompt_embeddings)

In [103]:
# normalize
# Rescale every 384-value block of the ensemble to unit L2 norm (result stays flat).
test_prompt_embeddings_preds = normalize(test_prompt_embeddings_preds)

In [104]:
# Build the submission index: one "<imgId>_<eId>" key per embedding component,
# image-major, following the order of test_images.
EMBEDDING_LENGTH = 384
imgIds = [image_path.stem for image_path in test_images]
imgId_eId = [
    f"{img_id}_{e_id}"
    for img_id in imgIds
    for e_id in range(EMBEDDING_LENGTH)
]

# Single-column frame of predictions, written to submission.csv.
submission = pd.DataFrame(
    data=test_prompt_embeddings_preds,
    index=pd.Index(imgId_eId, name='imgId_eId'),
    columns=['val'],
)
submission.to_csv('submission.csv')

In [105]:
submission

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
f27825b2c_0,-0.056371
f27825b2c_1,0.071624
f27825b2c_2,0.011509
f27825b2c_3,-0.030538
f27825b2c_4,-0.068708
...,...
c98f79f71_379,-0.008605
c98f79f71_380,0.077290
c98f79f71_381,0.034607
c98f79f71_382,-0.028248
