## How embeddings extracted from text in Kaggle evaluation kernel

In [1]:
import sys
import numpy as np
import pandas as pd
from pathlib import Path

sys.path.append('/home/toomuch/kaggle-diffusion/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('/data/kaggle/image2prompt')

In [2]:
prompts = pd.read_csv(comp_path / 'prompts.csv', index_col='imgId')
prompts.head(7)

Unnamed: 0_level_0,prompt
imgId,Unnamed: 1_level_1
20057f34d,hyper realistic photo of very friendly and dys...
227ef0887,"ramen carved out of fractal rose ebony, in the..."
92e911621,ultrasaurus holding a black bean taco in the w...
a4e1c55a9,a thundering retro robot crane inks on parchme...
c98f79f71,"portrait painting of a shimmering greek hero, ..."
d8edf2e40,an astronaut standing on a engaging white rose...
f27825b2c,Kaggle employee Phil at a donut shop ordering ...


In [3]:
sample_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')
sample_submission.head()

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.018848
20057f34d_1,0.03019
20057f34d_2,0.072792
20057f34d_3,-0.000673
20057f34d_4,0.016774


In [4]:
st_model = SentenceTransformer('/home/toomuch/kaggle-diffusion/all-MiniLM-L6-v2')
prompt_embeddings = st_model.encode(prompts['prompt']).flatten()

In [5]:
assert np.all(np.isclose(sample_submission['val'].values, prompt_embeddings, atol=1e-07))

## Extract biased texts from diffusion-db
Pipeline:
1. Extract texts & save them somewhere
1. Extract CLIP and MiniLM embeddings from them
1. Try to approximate MiniLM embeddings using CLIP and MLP

In [6]:
import pandas as pd

df = pd.read_parquet('./metadata-large.parquet')
len(list(df['prompt'].unique()))

1819808

In [7]:
df.head()

Unnamed: 0,image_name,prompt,part_id,seed,step,cfg,sampler,width,height,user_name,timestamp,image_nsfw,prompt_nsfw
0,3ccdc650-871a-4ad9-9bf2-dc475b83ed32.webp,beautiful porcelain ivory fair face woman biom...,1,2625978990,50,21.0,8,512,704,01f4e782b48faedf416083b2fbabaca2a45621b15ead23...,2022-08-20 10:03:00+00:00,0.038466,0.003089
1,1f1fcb70-63a4-40b1-ada9-2c15fb2ca10a.webp,complex 3 d render hyper detailed ultra sharp ...,1,738462306,50,10.0,8,512,704,01f4e782b48faedf416083b2fbabaca2a45621b15ead23...,2022-08-20 10:55:00+00:00,0.187317,0.001722
2,b0809c6b-cf43-4a82-99f7-6f2947d433fc.webp,complex 3 d render hyper detailed ultra sharp ...,1,1584972414,50,10.0,8,512,704,01f4e782b48faedf416083b2fbabaca2a45621b15ead23...,2022-08-20 10:55:00+00:00,0.065495,0.001722
3,b8cff57e-eb9d-467a-95a1-f6e3b8a38575.webp,complex 3 d render hyper detailed ultra sharp ...,1,2816373313,50,10.0,8,512,704,01f4e782b48faedf416083b2fbabaca2a45621b15ead23...,2022-08-20 10:55:00+00:00,0.083114,0.001722
4,298086cb-1c05-424e-b83b-a6148e8816e2.webp,complex 3 d render hyper detailed ultra sharp ...,1,3079866895,50,10.0,8,512,704,01f4e782b48faedf416083b2fbabaca2a45621b15ead23...,2022-08-20 10:55:00+00:00,0.148977,0.001722


In [8]:
import hashlib
from tqdm import tqdm

In [9]:
df['id'] = [hashlib.md5(el.encode('utf-8')).hexdigest()[:8] for el in tqdm(df['prompt'])]

print(len(df['id'].unique()) / len(df['prompt'].unique()))

100%|██████████| 14000000/14000000 [00:17<00:00, 783401.21it/s]


0.9998093205437057


In [10]:
from transformers import AutoModel, AutoTokenizer
# import torch 
openclip_model = AutoModel.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K').to('cuda:2')
openclip_tokenizer = AutoTokenizer.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')

# class Model(torch.nn.Module):
#     def __init__(self)


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [11]:
openclip_tokenizer(
    "whereas it goes",
    add_special_tokens=True,
    max_length=77,
    padding="max_length",
    return_token_type_ids=True,
    truncation=True,
)

{'input_ids': [49406, 42234, 585, 2635, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [15]:
# df['prompt'].drop_duplicates().hist()
df['prompt'].apply(lambda x: len(x.split(' '))).quantile(q=0.985)

66.0

In [16]:
import torch
import clip


class ClipDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        assert len(df.columns) == 2
        self.pairs = df.values

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        prompt_id, prompt = self.pairs[idx]
        while True:
            try:
                prompt_tensor_clip = clip.tokenize([prompt])
                break
            except RuntimeError:
                prompt = " ".join(prompt.split(" ")[:-1])

        return prompt_id, prompt, prompt_tensor_clip


class OpenClipDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer):
        assert len(df.columns) == 2
        self.pairs = df.values
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        prompt_id, prompt = self.pairs[idx]
        prompt_tensor_clip = self.tokenizer(
            prompt,
            add_special_tokens=True,
            max_length=64,
            padding="max_length",
            return_token_type_ids=True,
            truncation=True,
        )['input_ids']

        prompt_tensor_clip = torch.tensor(prompt_tensor_clip, dtype=torch.long)

        return prompt_id, prompt, prompt_tensor_clip


dataloader = torch.utils.data.DataLoader(
    # dataset=ClipDataset(df[["id", "prompt"]].drop_duplicates(subset=["id"])),
    dataset=OpenClipDataset(df[["id", "prompt"]].drop_duplicates(subset=["id"]), openclip_tokenizer),
    # batch_size=6144,
    batch_size=4096,
    num_workers=0,
    shuffle=False,
)

In [17]:
openclip_model.text_projection = openclip_model.text_projection.to('cuda:2')
openclip_model.text_model = openclip_model.text_model.to('cuda:2')
st_model = st_model.to('cuda:2')

In [18]:
# st_model.device

In [19]:
# import os 
# os.environ['OMP_NUM_THREADS'] = "1"

In [20]:
# openclip_model.encode_text

In [21]:
dump = dict()
import time


torch.cuda.empty_cache()
for batch in tqdm(dataloader):
    prompt_ids, prompts, prompt_tensors_clip = batch
    with torch.no_grad():
        # print(type(st_model.encode(prompts)))
        # raise KeyboardInterrupt
        # print('st model')
        st_embeddings = st_model.encode(prompts).tolist()
        # time.sleep(3)
        # clip_embeddings = (
        #     clip_model.encode_text(prompt_tensors_clip.squeeze().to(device))
        #     .detach()
        #     .cpu()
        #     .numpy()
        #     .tolist()
        # )
        # print(prompt_tensors_clip)
        # print('to cuda')
        prompt_tensors_clip = prompt_tensors_clip.to('cuda:2')
        # time.sleep(3)
        # print('openclip text model')
        clip_embeddings = openclip_model.text_model(prompt_tensors_clip)['pooler_output']
        # time.sleep(3)
        # print('openclip text proj')
        clip_embeddings = openclip_model.text_projection(clip_embeddings)
        # time.sleep(3)
        # print(clip_embeddings.device)
        clip_embeddings = clip_embeddings.detach().cpu().numpy().tolist()
    # print(st_embeddings)
    # print(clip_embeddings)
    for _prompt_id, _st_emb, _clip_emb, _prompt in zip(prompt_ids, st_embeddings, clip_embeddings, prompts):
        dump[_prompt_id] = {'prompt': _prompt, 'MiniLM-emb': _st_emb, 'CLIP-emb': _clip_emb}
    # raise KeyboardInterrupt

  1%|▏         | 6/445 [01:24<1:43:03, 14.09s/it]


KeyboardInterrupt: 

In [29]:
pd.DataFrame.from_dict(data=dump, orient='index').reset_index(drop=False).to_parquet('./a.parquet', index=False)

In [None]:
raise KeyboardInterrupt

In [2]:
import pandas as pd

a = pd.read_parquet('/home/toomuch/kaggle-diffusion/vectors/embeddings-__-vit-h-14-laion2B-s32B-b79K-__-MiniLM.parquet')

In [14]:
# type(a['MiniLM-emb'][0])
# list(a['MiniLM-emb'])[0]
a[:1000].to_parquet('./a.parquet')

In [3]:
import torch

class HeadDataLoader(torch.utils.data.DataLoader):
    def __init__(self, df):
        self.minilm_embeddings = list(df['MiniLM-emb'])
        self.clip_embeddings = list(df['CLIP-emb'])

    def __len__(self):
        return len(self.clip_embeddings)
    
    def __getitem__(self, idx):
        clip_emb = torch.tensor(self.clip_embeddings[idx], dtype=torch.float32)
        minilm_emb = torch.tensor(self.minilm_embeddings[idx], dtype=torch.float32)
        return clip_emb, minilm_emb
        
    

Unnamed: 0,index,prompt,MiniLM-emb,CLIP-emb
0,8bd1ade6,beautiful porcelain ivory fair face woman biom...,"[-0.008047424256801605, -0.03451211750507355, ...","[0.1196979284286499, 0.47622451186180115, 0.55..."
1,4a80a483,complex 3 d render hyper detailed ultra sharp ...,"[-0.010402043350040913, -0.06673333048820496, ...","[0.0020346567034721375, -0.027843043208122253,..."
2,1cabbbe2,complex 3 d render hyper detailed ultra sharp ...,"[-0.015670793130993843, -0.07329162955284119, ...","[-0.0025339871644973755, -0.12575972080230713,..."
3,2323fa81,complex 3 d render hyper detailed ultra sharp ...,"[-0.011442017741501331, -0.06704302877187729, ...","[-0.03198816627264023, -0.06073711812496185, 0..."
4,c3e14111,complex 3 d render hyper detailed ultra sharp ...,"[-0.014331274665892124, -0.0750780776143074, 0...","[-0.11527558416128159, -0.1218007355928421, 0...."
...,...,...,...,...
1819456,2fe9d49b,dreaming electric bicycle and electric car by ...,"[-0.010595626197755337, 0.12237173318862915, 0...","[-0.7898232936859131, -0.3672802448272705, -0...."
1819457,f428cb53,"riding neon bycicles in the woods, painted by ...","[-0.0046276310458779335, 0.0736890584230423, 0...","[0.4300571084022522, 0.36568161845207214, -0.3..."
1819458,f2a9f73f,"Ibai Llanos dressed as Willy Wonka, highly det...","[-0.08139385282993317, 0.07119607925415039, 0....","[0.16320644319057465, -0.1726306974887848, 0.1..."
1819459,46c5b930,"Ibai Berto Romero as Willy Wonka, highly detai...","[-0.0948033556342125, 0.04334269464015961, -0....","[0.10932405292987823, -0.15810289978981018, 0...."


In [99]:
# df[['id', 'prompt']].loc(lambda parachute backpackx: len(x.split(' ')))

In [100]:
# df[['id', 'prompt']].drop_duplicates(subset=['id'])

In [115]:
# from transformers import AutoModel, AutoTokenizer

# openclip_model = AutoModel.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
# openclip_tokenizer = AutoTokenizer.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [113]:
# openclip_model.text_model

CLIPTextTransformer(
  (embeddings): CLIPTextEmbeddings(
    (token_embedding): Embedding(49408, 1024)
    (position_embedding): Embedding(77, 1024)
  )
  (encoder): CLIPEncoder(
    (layers): ModuleList(
      (0-23): 24 x CLIPEncoderLayer(
        (self_attn): CLIPAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): CLIPMLP(
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        )
        (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (final_layer

In [110]:
#  'text_embed_dim',
#  'text_model',
#  'text_projection',
# dir(model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_auto_class',
 '_backward_compatibility_gradient_checkpointing',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_convert_head_mask_to_5d',
 '_create_repo',
 '_expand_inputs_for_generation',
 '_extract_past_from_model_output',
 '_forward_hooks',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_from_config',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_decoder_start_token_id',
 '_get_files_timestamps',
 '_get_logits_processor',
 '_get_logits_warpe

In [101]:
# # next(iter(dataloader))
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [96]:
device = "cuda:2" if torch.cuda.is_available() else "cpu"
# clip_model, preprocess = clip.load("RN50", device="cpu", jit=False)
clip_model, preprocess = clip.load("ViT-L/14@336px", device="cpu", jit=False)
clip_model.eval()
clip_model = clip_model.to(device)

st_model = SentenceTransformer('/home/toomuch/kaggle-diffusion/all-MiniLM-L6-v2')
st_model = st_model.to(device)

In [102]:
dump = dict()


torch.cuda.empty_cache()
for batch in tqdm(dataloader):
    prompt_ids, prompts, prompt_tensors_clip = batch
    with torch.no_grad():
        # print(type(st_model.encode(prompts)))
        # raise KeyboardInterrupt
        st_embeddings = st_model.encode(prompts).tolist()
        clip_embeddings = (
            clip_model.encode_text(prompt_tensors_clip.squeeze().to(device))
            .detach()
            .cpu()
            .numpy()
            .tolist()
        )
    # print(st_embeddings)
    # print(clip_embeddings)
    for _prompt_id, _st_emb, _clip_emb, _prompt in zip(prompt_ids, st_embeddings, clip_embeddings, prompts):
        dump[_prompt_id] = {'prompt': _prompt, 'MiniLM-emb': _st_emb, 'CLIP-emb': _clip_emb}
    # raise KeyboardInterrupt

  5%|▍         | 4492/95517 [02:24<48:49, 31.08it/s]  


KeyboardInterrupt: 

In [56]:
dump

{'982a185e': {'MiniLM-emb': [-0.0634867399930954,
   -0.010778216645121574,
   -0.013012065552175045,
   -0.0565788671374321,
   -0.02133307233452797,
   -0.055216625332832336,
   -0.026514187455177307,
   0.006786567158997059,
   -0.023862233385443687,
   0.0010450349655002356,
   -0.042546506971120834,
   -0.015005775727331638,
   0.04655912518501282,
   -0.05151678994297981,
   -0.0037952049169689417,
   0.02768847718834877,
   0.05010500177741051,
   0.0074802422896027565,
   0.022195616737008095,
   0.07230369001626968,
   0.026151690632104874,
   -0.10168015211820602,
   0.009564574807882309,
   -0.015129982493817806,
   0.001035219174809754,
   0.12775464355945587,
   0.07859566062688828,
   -0.012493176385760307,
   0.03038616292178631,
   -0.09009609371423721,
   -0.02951931394636631,
   0.11626584827899933,
   0.004252538550645113,
   0.023955387994647026,
   0.053742725402116776,
   0.09899432212114334,
   -0.05659092217683792,
   0.05053264647722244,
   -0.02342157438397407