In [None]:
!pip install pyrebase4 gdown faiss-gpu
!gdown 1pF1u8mekNs_z_KvFIJRTdqhlFHm1lp5n

In [None]:
import os
import re
import requests
import pyrebase
import pandas as pd
from bs4 import BeautifulSoup as Soup

import sys
import faiss
import torch
import numpy as np
import polars as pl
from pathlib import Path
import torch.nn.functional as F
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Credentials file that is handed to pyrebase as the service account.
config_path = 'firebase_auth.json'
assert os.path.exists(config_path), f'Missing Firebase credentials file: {config_path}'

# NOTE(review): the API key is committed in the notebook; consider loading it from the environment.
config = {
  "apiKey": "AIzaSyBnWywH3ZswQNyLblBohBAp__f_F2myt5M",
  "authDomain": "datasetcollect-81ac0.firebaseapp.com",
  "databaseURL": "https://datasetcollect-81ac0-default-rtdb.firebaseio.com",
  "storageBucket": "datasetcollect-81ac0.appspot.com",
  # pyrebase looks this option up as "serviceAccount" (lower-case "s"). The former
  # "ServiceAccount" spelling was silently ignored, so the credentials file was never loaded.
  "serviceAccount": config_path,
}

# Build the pyrebase app and keep a handle to its Realtime Database client.
firebase = pyrebase.initialize_app(config)
db = firebase.database()

In [None]:
def start_gen():
    """Mark the generation process as running by switching the `gen_id` flag from 0 to 1.

    The flag lives in the `gen_id` node of the Firebase database: 0 means idle,
    1 means a generation run is in progress.

    Raises:
        Exception: if the flag is missing, if it is already 1 (another run is
            active), or if it holds any value other than 0 or 1.

    NOTE(review): the read and the write are two separate requests, so two notebooks
    started at the same moment can both pass the check. Nothing in the visible
    notebook resets the flag to 0 afterwards — confirm where it is released.
    """
    # `Database.get()` takes an auth token as its first argument, not a path, so
    # `db.get('gen_id')` returned the whole database root and the flag was taken from
    # whichever key happened to be last. Address the node explicitly instead.
    gen_id = db.child('gen_id').get().val()
    # Compare against None: a valid idle flag is 0, which is falsy.
    if gen_id is None:
        raise Exception('Unexpected behaviour, gen_id has to be defined, contact @round_tensor')
    if gen_id == 1:
        raise Exception('Someone is runing gen process right now')
    if gen_id != 0:
        raise Exception('Wtf why gen_id != 1 or 0')
    db.child("gen_id").set(1)

# Claim the generation flag before doing any work; start_gen() raises if the flag is
# missing, already set to 1, or holds an unexpected value.
start_gen()

In [None]:
# Upper bound on the length of a prompt, in characters.
MAX_LEN = 150


def preprocess(s):
    """Normalise a scraped prompt.

    Removes both literal backslash-n sequences and real newline characters,
    trims surrounding whitespace, and capitalises the result (first character
    upper-case, the rest lower-case).
    """
    for newline in (r'\n', '\n'):
        s = s.replace(newline, '')
    trimmed = s.strip()
    return trimmed.capitalize()

def parse_chatgpt_url(url):
    """Scrape "A sticker of ..." prompts from a shared ChatGPT conversation page.

    The page is downloaded, the text of the second <script> tag inside <body> is
    split at every occurrence of "parts", and each chunk is scanned for sentences
    that start with "A sticker of". Only candidates that end with a period and
    contain exactly one period are kept.

    Args:
        url: address of the shared conversation.

    Returns:
        set[str]: the distinct prompts found on the page.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        ValueError: if the page does not contain the expected <script> tag.
    """
    # A timeout keeps a stalled connection from blocking the notebook forever.
    r = requests.get(url, timeout=30)
    # Explicit status check; a bare `assert` disappears when Python runs with -O.
    r.raise_for_status()

    body = Soup(r.content, 'lxml').find('body')
    scripts = body.find_all('script') if body is not None else []
    if len(scripts) < 2:
        raise ValueError(f'Unexpected page layout, conversation script not found: {url}')

    string = scripts[1].text
    prompts = []
    parts_idx = [m.start() for m in re.finditer("parts", string)]
    # NOTE(review): chunks are taken between consecutive "parts" markers, so the text
    # after the last marker is never scanned — confirm this is intentional.
    for seg_start, seg_end in zip(parts_idx, parts_idx[1:]):
        # Skip the 5-character "parts" marker plus the next two characters
        # (presumably the closing quote and colon of the JSON key).
        text = string[seg_start:seg_end][7:]
        # NOTE(review): when '},' is absent find() returns -1 and this drops the last character.
        text = text[:text.find('},')]

        tail_start = text.rfind('A sticker of')
        if tail_start == -1:
            # No prompt in this chunk. Without this guard rfind()'s -1 turned the last
            # character of the chunk into a candidate, which could pass the filter as ".".
            continue

        # Every prompt that is followed by a literal backslash-n sequence.
        lst = list(re.findall(r"A sticker of .*?\\n", text, re.DOTALL))
        # The final prompt of the chunk may not be followed by one, so handle it separately.
        last = text[tail_start:].strip().capitalize()
        if last.count(r'\n') > 0:
            lst.append(last[:last.find(r'\n')])
        else:
            lst.append(last.replace('"]', ''))

        cleaned = (preprocess(q) for q in lst)
        prompts.extend(p for p in cleaned if p.endswith('.') and p.count('.') == 1)

    return set(prompts)

def parse_chatgpt_urls(urls):
    """Collect the distinct prompts found across several shared ChatGPT pages.

    Each URL is parsed with `parse_chatgpt_url`; the results are merged into a
    single set and returned as a list.
    """
    unique_prompts = {prompt for link in urls for prompt in parse_chatgpt_url(link)}
    return list(unique_prompts)

In [None]:
# Shared ChatGPT conversations that contain generated sticker prompts.
urls = [
    'https://chat.openai.com/share/46828f0c-0944-463f-a394-b56c07c3249d',
    'https://chat.openai.com/share/9708aa83-e06b-4873-8b89-3855149f13d6',
]
# Scrape every page and keep the distinct prompts as a list.
sint_prompts = parse_chatgpt_urls(urls)

In [None]:
def get_prompts(tabeles_to_parse=('table1',)):
    """Read the `prompt_sticker` field of every record in the given Firebase tables.

    Args:
        tabeles_to_parse: names of the top-level database nodes to read.

    Returns:
        list[str]: the prompts of all records, table by table, in the order requested.
        Tables that do not exist or are empty contribute nothing.

    Raises:
        KeyError: if a record has no `prompt_sticker` field.
    """
    result = []
    for table in tabeles_to_parse:
        # `Database.get()` takes an auth token as its first argument, not a path, so
        # `db.get(table)` downloaded the whole database root once per requested table
        # and, with several tables, returned every prompt once per table.
        # Address each table node directly instead.
        table_data = db.child(table).get().val()
        if not table_data:
            continue
        for record_data in table_data.values():
            result.append(record_data['prompt_sticker'])
    return result


# Prompts already stored in Firebase; kept under their own name because they are
# used again later to exclude already-known prompts from the final table.
db_prompts = get_prompts()
# Append them in place to the scraped prompts so both sets go through the same pipeline.
sint_prompts += db_prompts

In [None]:
# Add the sentence-transformers directory of the attached Kaggle dataset to the import path.
# NOTE(review): `sentence_transformers` is already imported in the imports cell above, so on a
# fresh kernel this entry is added too late to affect that import — confirm whether it should move up.
sys.path.append("/kaggle/input/sentence-transformers-222/sentence-transformers")

In [None]:
def check_string(string: str) -> bool:
    """Return True when `string` is non-empty and made only of allowed characters.

    Allowed characters are ASCII letters, digits, comma, period, hyphen and
    whitespace. The result is used as a keep-condition when filtering prompts.

    The previous pattern r'[^A-Za-z0-9,.\\-\\s]' was written with doubled
    backslashes inside a raw string, so the class allowed a backslash and the
    letter "s" instead of hyphen and whitespace. Any text containing a space
    therefore "matched", and the function returned True for practically every
    prompt, including ones with quotes, escapes or non-ASCII characters.
    """
    return bool(re.fullmatch(r'[A-Za-z0-9,.\-\s]+', string))

# Keep prompts that have at least five words, are no longer than MAX_LEN characters and
# pass check_string; the first "A sticker of" is removed so only the subject is embedded.
sint_prompts = [
    prompt.replace('A sticker of', '', 1).strip()
    for prompt in sint_prompts
    if len(prompt.split()) >= 5
    and len(prompt) <= MAX_LEN
    and check_string(prompt)
]

# Single-column polars frame holding the prompts that go on to vectorisation.
df_sint = pl.DataFrame({'prompt': sint_prompts})

## Vectorization with SentenceTransformer

In [None]:
# Load the all-MiniLM-L6-v2 checkpoint from the attached Kaggle dataset.
model = SentenceTransformer("/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2")

# Embed every prompt on the GPU; the result stays a torch tensor with one row per prompt.
vector = model.encode(
    df_sint["prompt"].to_numpy(),
    batch_size=512,
    show_progress_bar=True,
    device="cuda",
    convert_to_tensor=True,
)

## Vector search

In [None]:
# Similarity cut-off at or above which two prompts count as near-duplicates.
# It must not be too strict: "red ferrari", "green ferrari", "black ferrari" should be
# dropped, while "red ferrari" and "red ford mustang" should both be kept.
# 0.5 is used here to get only a few prompts for the example; 0.8 is a reasonable value.
threshold = 0.50
n_neighbors = 1000  # neighbours requested per query

batch_size = 1000  # queries sent to the index per search call
similar_vectors = []  # (row index, indices of its neighbours above the threshold)

# L2-normalise once so that the inner product equals cosine similarity. The same array
# is used for training, indexing AND querying: previously the queries were sent
# un-normalised, and the tensor was copied to the CPU again on every batch.
normalized_vectors = F.normalize(vector).cpu().numpy()

resources = faiss.StandardGpuResources()
# IVF index with 5 inverted lists over an inner-product flat quantizer.
index = faiss.IndexIVFFlat(faiss.IndexFlatIP(vector.shape[1]), vector.shape[1], 5, faiss.METRIC_INNER_PRODUCT)
gpu_index = faiss.index_cpu_to_gpu(resources, 0, index)

gpu_index.train(normalized_vectors)
gpu_index.add(normalized_vectors)

# Raw search results per batch, kept for the inspection cells below.
sim_batches = []
indices_batches = []

for i in tqdm(range(0, len(normalized_vectors), batch_size)):
    batch_data = normalized_vectors[i:i + batch_size]
    similarities, indices = gpu_index.search(batch_data, n_neighbors)
    sim_batches.append(similarities)
    indices_batches.append(indices)

    for j in range(similarities.shape[0]):
        row_id = i + j  # position of this query in the full prompt list
        close_vectors = indices[j, similarities[j] >= threshold]
        # A vector always matches itself; keep only the other rows.
        close_vectors = close_vectors[close_vectors != row_id]
        similar_vectors.append((row_id, close_vectors))



# Row indices that appear in at least one neighbour list, i.e. prompts that are
# close to some other prompt.
# NOTE(review): every prompt listed as somebody's neighbour is removed, so a group of
# mutually similar prompts is dropped as a whole rather than reduced to one
# representative — confirm this is the intended behaviour.
neighbour_lists = [x for _, x in similar_vectors]
# np.concatenate raises on an empty list, so guard the no-data case.
duplicate_ids = np.unique(np.concatenate(neighbour_lists)).tolist() if neighbour_lists else []

df = df_sint.with_columns(pl.Series(values=list(range(len(df_sint))), name="index"))
if duplicate_ids:
    df = df.filter(~pl.col("index").is_in(duplicate_ids))
df = df.to_pandas()
df['prompt'] = df['prompt'].apply(lambda p: f'A sticker of {p}')
# Drop prompts that are already stored in the database. The former
# `df['prompt'] = [i for i in df['prompt'] if i not in db_prompts]` assigned a shorter
# list back to the column and raised a length-mismatch error as soon as one row was
# filtered out; filtering the rows themselves is what was meant.
df = df[~df['prompt'].isin(db_prompts)].reset_index(drop=True)

In [None]:
# Show a few prompts that have at least one neighbour at or above the threshold,
# together with their first such neighbour.

N_display = 3
c = 0

for query_idx, neighbours in similar_vectors:
    if c == N_display:
        break
    if neighbours.size == 0:
        continue
    # Locate the query's row inside the batched search results.
    batch_no, row_no = divmod(query_idx, batch_size)
    # Rank 1 is read as the closest other prompt (rank 0 presumably being the query itself).
    sim_i = sim_batches[batch_no][row_no][1]
    first_neighbour = int(neighbours[0])
    print(query_idx, df_sint[query_idx]['prompt'].item())
    print(first_neighbour, df_sint[first_neighbour]['prompt'].item())
    print(f'Similarity is: {sim_i}', end='\n\n')
    c += 1

In [None]:
# Show a few prompts that have no neighbour at or above the threshold, each with its
# N_top closest results so the cut-off can be judged by eye.

N_top = 3
N_display = 3
c = 0

for query_idx, neighbours in similar_vectors:
    if c == N_display:
        break
    if neighbours.size != 0:
        continue
    print(query_idx, df_sint[query_idx]['prompt'].item())
    print('__'*10)
    # Locate the query's row inside the batched search results.
    batch_no, row_no = divmod(query_idx, batch_size)
    # Ranks start at 1 to skip the first result (presumably the query itself).
    for rank in range(1, 1 + N_top):
        sim_i = sim_batches[batch_no][row_no][rank]
        idx_i = indices_batches[batch_no][row_no][rank]
        print(int(idx_i), df_sint[int(idx_i)]['prompt'].item(), sim_i)
    c += 1
    print()

In [None]:
# Save the final prompt table to prompts.csv in the working directory, using pandas'
# default settings (so the DataFrame index is written as the first column).
df.to_csv('prompts.csv')