In [3]:
!pip install pyrebase4 gdown faiss-gpu >/dev/null
!gdown 1pF1u8mekNs_z_KvFIJRTdqhlFHm1lp5n >/dev/null

Downloading...
From: https://drive.google.com/uc?id=1pF1u8mekNs_z_KvFIJRTdqhlFHm1lp5n
To: /kaggle/working/firebase_auth.json
100%|██████████████████████████████████████| 2.40k/2.40k [00:00<00:00, 10.5MB/s]


In [4]:
import pandas as pd
import re
import pyrebase
import os
import sys
import faiss
import numpy as np
import polars as pl
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity



In [6]:
import pyrebase
import os

def load_db():
    config_path = 'firebase_auth.json'
    assert os.path.exists(config_path)

    config = {
      "apiKey": "AIzaSyBnWywH3ZswQNyLblBohBAp__f_F2myt5M",
      "authDomain": "datasetcollect-81ac0.firebaseapp.com/",
      "databaseURL": "https://datasetcollect-81ac0-default-rtdb.firebaseio.com/",
      "storageBucket": "datasetcollect-81ac0.appspot.com",
      "ServiceAccount": config_path,
    }


    firebase = pyrebase.initialize_app(config)
    db = firebase.database()
    return db

db = load_db()

print("Number of raw prompts: ", len(db.child("prompts_raw").get().val()))
print("Number of generated prompts: ", len(db.child("prompts_generated").get().val()))

cnt_picked = 0

for prompt_id, prompt_record in db.child("prompts_raw").get().val().items():
    if prompt_record["is_picked"]:
        cnt_picked += 1

print("Number of picked raw prompts: ", cnt_picked)


Number of raw prompts:  15211
Number of generated prompts:  15211
Number of picked raw prompts:  15211


## Загрузка уже имеющихся в базе данных промптов

In [18]:
def load_db_prompts():
    db_prompts = []
    prompts_table = db.child("prompts_raw").get().val()
    for prompt_record in prompts_table.items():
        db_prompts.append(prompt_record[1]["prompt"])
    return db_prompts

def load_new_prompts(filename):
    new_prompts = []
    with open(filename, mode="r", encoding="UTF-8") as file:
        for line in file.readlines():
            new_prompts.append(line.strip())
    return new_prompts
        
db_prompts = load_db_prompts()
new_prompts = load_new_prompts("/kaggle/input/collected-prompts/prompts-batch-19.txt")
print(new_prompts[:5])

['A sticker of a juicy watermelon slice with seeds and a big smile.', 'A sticker of a fluffy pancake stack drizzled in syrup with melting butter.', 'A sticker of a colorful sushi platter with delicate sashimi and sushi rolls.', 'A sticker of a steaming hot bowl of ramen with swirling noodles and savory broth.', 'A sticker of a decadent chocolate lava cake oozing with molten chocolate.']


In [19]:
print("Number of prompts in database: ", len(db_prompts))
print("Number of new prompts: ", len(new_prompts))

Number of prompts in database:  13989
Number of new prompts:  57581


In [20]:
sys.path.append("/kaggle/input/sentence-transformers-222/sentence-transformers")
from sentence_transformers import SentenceTransformer

## Первичная фильтрация некорректных промптов

In [21]:
MAX_PROMPT_LEN = 150

def check_string(string):
        return bool(re.search(r'[^A-Za-z0-9,.\\-\\s]', string))

def check_correct_prompt(prompt):    
    return check_string(prompt) and len(prompt.split()) >= 5 and len(prompt) <= MAX_PROMPT_LEN

mixed_prompts = new_prompts.copy()
mixed_prompts.extend(db_prompts)
mixed_prompts = [prompt.replace('A sticker of', '', 1).strip() for prompt in mixed_prompts 
                if check_correct_prompt(prompt)]

In [22]:
print("Number of mixed prompts: ", len(mixed_prompts))

Number of mixed prompts:  68240


In [23]:
# Преобразование имеющихся промптов в вид pd.DataFrame

df_mixed_prompts = pl.DataFrame({
    "prompt": mixed_prompts
})

df_mixed_prompts.head()

prompt
str
"""a juicy waterm…"
"""a fluffy panca…"
"""a colorful sus…"
"""a steaming hot…"
"""a decadent cho…"


## Фильтрация и удаление повторяющихся промптов

In [24]:
model = SentenceTransformer("/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2")

vector = model.encode(df_mixed_prompts["prompt"].to_numpy(), 
                      batch_size=512, 
                      show_progress_bar=True, 
                      device="cuda", 
                      convert_to_tensor=True)

Batches:   0%|          | 0/134 [00:00<?, ?it/s]

In [25]:
threshold = 0.8
n_neighbors = 1000  

batch_size = 1000
similar_vectors = []


resources = faiss.StandardGpuResources()
index = faiss.IndexIVFFlat(faiss.IndexFlatIP(vector.shape[1]), vector.shape[1], 5, faiss.METRIC_INNER_PRODUCT)
gpu_index = faiss.index_cpu_to_gpu(resources, 0, index)

gpu_index.train(F.normalize(vector).cpu().numpy())
gpu_index.add(F.normalize(vector).cpu().numpy())

sim_batches = []
indices_batches = []

for i in tqdm(range(0, len(vector), batch_size)):
    batch_data = vector.cpu().numpy()[i:i + batch_size]
    similarities, indices = gpu_index.search(batch_data, n_neighbors)
    sim_batches.append(similarities)
    indices_batches.append(indices)
    
    for j in range(similarities.shape[0]):
        close_vectors = indices[j, similarities[j] >= threshold]
        
        index_base = i
                
        close_vectors = close_vectors[close_vectors != index_base + j]  
        
        similar_vectors.append((index_base + j, close_vectors))



df = df_mixed_prompts.with_columns(pl.Series(values=list(range(len(df_mixed_prompts))), name="index"))
df = df.filter(~pl.col("index").is_in(np.unique(np.concatenate([x for _, x in similar_vectors])).tolist()))
df = df.to_pandas()
df['prompt'] = df['prompt'].apply(lambda p: f'A sticker of {p}')
df = df[~df["prompt"].isin(db_prompts)]
df.head()

  0%|          | 0/69 [00:00<?, ?it/s]

Unnamed: 0,prompt,index
0,A sticker of a farm-fresh vegetable medley wit...,22
1,A sticker of a spicy shrimp etouffee with plum...,76
2,A sticker of a refreshing sweet tea served in ...,316
3,A sticker of a colorful couscous salad with ro...,329
4,"A sticker of a sizzling barbecue grill, with s...",403


In [26]:
print("Number of left new prompts: ", len(df))

Number of left new prompts:  1222


## Добавление в базу данных новых промптов

In [28]:
def dump_new_prompts(df):
    for new_prompt in df["prompt"]:
        new_prompt_record = {
            "prompt": new_prompt,
            "is_generated": False,
            "is_picked": False
        }
        db.child("prompts_raw").push(new_prompt_record)
        
dump_new_prompts(df)

In [29]:
len(db.child("prompts_raw").get().val())

15211