In [4]:
!pip install pyrebase4 gdown faiss-gpu
!gdown 1pF1u8mekNs_z_KvFIJRTdqhlFHm1lp5n

Downloading...
From: https://drive.google.com/uc?id=1pF1u8mekNs_z_KvFIJRTdqhlFHm1lp5n
To: /kaggle/working/firebase_auth.json
100%|██████████████████████████████████████| 2.40k/2.40k [00:00<00:00, 12.6MB/s]


In [5]:
import pandas as pd
import re
import pyrebase
import os
import sys
import faiss
import numpy as np
import polars as pl
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity

In [7]:
def load_db():
    config_path = 'firebase_auth.json'
    assert os.path.exists(config_path)

    config = {
      "apiKey": "AIzaSyBnWywH3ZswQNyLblBohBAp__f_F2myt5M",
      "authDomain": "datasetcollect-81ac0.firebaseapp.com",
      "databaseURL": "https://datasetcollect-81ac0-default-rtdb.firebaseio.com",
      "storageBucket": "datasetcollect-81ac0.appspot.com",
      "ServiceAccount": config_path,
    }

    firebase = pyrebase.initialize_app(config)
    db = firebase.database()
    return db

db = load_db()

In [8]:
def load_db_prompts():
    db_prompts = []
    prompts_table = db.child("prompts").get().val()
    for prompt_record in prompts_table.items():
        db_prompts.append(prompt_record[1]["prompt"])
    return db_prompts

def load_new_prompts(filename):
    new_prompts = []
    with open(filename, mode="r", encoding="UTF-8") as file:
        for line in file.readlines():
            new_prompts.append(line.strip())
    return new_prompts
        
db_prompts = load_db_prompts()
new_prompts = load_new_prompts("/kaggle/input/collected-prompts/prompts-batch-1.txt")

In [9]:
sys.path.append("/kaggle/input/sentence-transformers-222/sentence-transformers")
from sentence_transformers import SentenceTransformer

In [10]:
MAX_PROMPT_LEN = 150

def check_string(string):
        return bool(re.search(r'[^A-Za-z0-9,.\\-\\s]', string))

def check_correct_prompt(prompt):    
    return check_string(prompt) and len(prompt.split()) >= 5 and len(prompt) <= MAX_PROMPT_LEN

mixed_prompts = new_prompts.copy()
mixed_prompts.extend(db_prompts)
mixed_prompts = [prompt.replace('A sticker of', '', 1).strip() for prompt in mixed_prompts 
                if check_correct_prompt(prompt)]

In [11]:
df_mixed_prompts = pl.DataFrame({
    "prompt": mixed_prompts
})

df_mixed_prompts.head()

prompt
str
"""a classic movi…"
"""a classic blac…"
"""a tranquil zen…"
"""a legendary kr…"
"""a futuristic u…"


In [12]:
model = SentenceTransformer("/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2")

vector = model.encode(df_mixed_prompts["prompt"].to_numpy(), 
                      batch_size=512, 
                      show_progress_bar=True, 
                      device="cuda", 
                      convert_to_tensor=True)

  return self.fget.__get__(instance, owner)()


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

In [19]:
threshold = 0.8
n_neighbors = 1000  

batch_size = 1000
similar_vectors = []


resources = faiss.StandardGpuResources()
index = faiss.IndexIVFFlat(faiss.IndexFlatIP(vector.shape[1]), vector.shape[1], 5, faiss.METRIC_INNER_PRODUCT)
gpu_index = faiss.index_cpu_to_gpu(resources, 0, index)

gpu_index.train(F.normalize(vector).cpu().numpy())
gpu_index.add(F.normalize(vector).cpu().numpy())

sim_batches = []
indices_batches = []

for i in tqdm(range(0, len(vector), batch_size)):
    batch_data = vector.cpu().numpy()[i:i + batch_size]
    similarities, indices = gpu_index.search(batch_data, n_neighbors)
    sim_batches.append(similarities)
    indices_batches.append(indices)
    
    for j in range(similarities.shape[0]):
        close_vectors = indices[j, similarities[j] >= threshold]
        
        index_base = i
                
        close_vectors = close_vectors[close_vectors != index_base + j]  
        
        similar_vectors.append((index_base + j, close_vectors))



df = df_mixed_prompts.with_columns(pl.Series(values=list(range(len(df_mixed_prompts))), name="index"))
df = df.filter(~pl.col("index").is_in(np.unique(np.concatenate([x for _, x in similar_vectors])).tolist()))
df = df.to_pandas()
df['prompt'] = df['prompt'].apply(lambda p: f'A sticker of {p}')
df['prompt'] = [prompt for prompt in df['prompt'] if prompt not in db_prompts]
df.head()

  0%|          | 0/2 [00:00<?, ?it/s]

Unnamed: 0,prompt,index
0,A sticker of a classic black and white tv.,1
1,A sticker of a tranquil venice canal.,10
2,A sticker of a soulful american blues guitarist.,15
3,A sticker of a tranquil pond with water lilies.,17
4,A sticker of a cute baby kangaroo.,18


In [None]:
# TODO display and check visually dropped prompts

In [20]:
df.shape

(335, 2)

In [23]:
def dump_new_prompts(df):
    for new_prompt in df["prompt"]:
        new_prompt_record = {
            "prompt": new_prompt,
            "is_generated": False
        }
        db.child("prompts").push(new_prompt_record)
        
dump_new_prompts(df)

In [25]:
len(db.child("prompts").get().val())

335