In [1]:
# !wget https://gist.githubusercontent.com/huseinzol05/98974ae8c6c7a65d4bc0af9f5003786a/raw/2e06e71ef7349a57bc58cc9913ae6bae1f9f8447/mp.py

In [2]:
from scipy.spatial import KDTree
from datasketch import MinHash, MinHashLSH
from glob import glob
from tqdm import tqdm
import numpy as np
import mp
import pandas as pd

In [3]:
files = glob('embedding/*.npy')
len(files)

636921

In [6]:
embeddings = []
for f in tqdm(files):
    embeddings.append(np.load(f))
embeddings = np.array(embeddings)

100%|██████████████████████████████████████████████████████████████████████████████| 636921/636921 [00:56<00:00, 11205.05it/s]


In [7]:
from sklearn.preprocessing import normalize

def deduplicate_embeddings(embeddings, similarity_threshold=0.9):
    """
    Deduplicate embeddings based on cosine similarity threshold.
    
    Args:
        embeddings: numpy array of shape [N, dim] where N is number of embeddings
        similarity_threshold: float between 0 and 1, threshold for considering embeddings as duplicates
    
    Returns:
        unique_indices: indices of unique embeddings
        duplicate_groups: list of lists containing indices of similar embeddings
    """
    # Normalize embeddings for cosine similarity
    embeddings = normalize(embeddings)
    N = embeddings.shape[0]
    
    # Track which embeddings have been marked as duplicates
    is_duplicate = np.zeros(N, dtype=bool)
    duplicate_groups = []
    unique_indices = []
    
    # Process embeddings in batches for memory efficiency
    batch_size = 300000
    
    for i in tqdm(range(N)):
        if is_duplicate[i]:
            continue
            
        # Calculate similarities for current embedding with remaining embeddings
        start_idx = i + 1
        similar_indices = [i]
        
        while start_idx < N:
            end_idx = min(start_idx + batch_size, N)
            batch_similarities = embeddings[i:i+1] @ embeddings[start_idx:end_idx].T
            
            # Find similar embeddings in batch
            batch_similar = np.where(batch_similarities[0] >= similarity_threshold)[0]
            batch_similar_global_idx = batch_similar + start_idx
            
            # Filter out already marked duplicates
            batch_similar_global_idx = batch_similar_global_idx[~is_duplicate[batch_similar_global_idx]]
            
            if len(batch_similar_global_idx) > 0:
                similar_indices.extend(batch_similar_global_idx.tolist())
                is_duplicate[batch_similar_global_idx] = True
                
            start_idx = end_idx
            
        if len(similar_indices) > 1:
            duplicate_groups.append(similar_indices)
        else:
            unique_indices.append(i)
            
    return np.array(unique_indices), duplicate_groups

In [None]:
unique_indices, duplicate_groups = deduplicate_embeddings(embeddings, similarity_threshold=0.95)

  8%|██████▏                                                                         | 49567/636921 [15:09<4:39:53, 34.97it/s]