#### Import Package

In [1]:
import chromadb
import time
import os
import pandas as pd
import numpy as np
import ray

from chromadb.config import Settings
from chromadb.utils import embedding_functions
from datetime import datetime

from utils.system import *
from class_data.data import Data

#### Export Data

##### --> Run format_data.ipynb and then get_emb_openai.ipynb (in that order) to generate the input data for this section.
###### --> Skip this section if the exported data already exists.

In [None]:
# Multiple articles per day: OpenAI embeddings (one file set) and tokenized article data (another).
# Each loader is consumed immediately so the names only ever refer to DataFrames.
wsj_multiple_openai = Data(
    folder_path=get_format_data() / 'openai',
    file_pattern='wsj_emb_openai_*',
).concat_files()
print(wsj_multiple_openai.shape)

wsj_multiple = Data(
    folder_path=get_format_data() / 'token',
    file_pattern='wsj_tokens_*',
).concat_files()
print(wsj_multiple.shape)

In [5]:
# Merge embeddings and articles column-wise.
# pd.concat(axis=1) aligns on the index. The date index repeats (several articles per day),
# so both frames must list their rows in exactly the same order. Check that up front
# instead of failing later with an opaque reindexing error.
assert wsj_multiple_openai.index.equals(wsj_multiple.index), (
    "embedding and token frames are not row-aligned; re-run the upstream notebooks"
)
wsj_combine = pd.concat([wsj_multiple_openai, wsj_multiple], axis=1)

In [None]:
# Keep only dates that have at least `limit` articles.
# `limit` must equal the value used in embedding_similarity.ipynb so the two notebooks' indexes line up.
limit = 30
count = wsj_combine.groupby(wsj_combine.index)['accession_number'].count()
valid_dates_mask = count >= limit
valid_dates = count.index[valid_dates_mask]
wsj_combine = wsj_combine.loc[wsj_combine.index.isin(valid_dates)]
print(wsj_combine.shape)

In [10]:
# Give every article a sequential integer id and index the frame by (id, date).
# The unnamed date index becomes a 'date' column, the fresh RangeIndex is named 'id',
# and both are then promoted to a two-level index.
wsj_combine = (
    wsj_combine
    .reset_index()
    .rename(columns={'index': 'date'})
    .rename_axis('id')
    .reset_index()
    .set_index(['id', 'date'])
)
# Number of articles sharing each row's date (non-null body_txt values per date)
wsj_combine['article_count'] = (
    wsj_combine.groupby(level='date')['body_txt'].transform('count')
)

In [None]:
# Export the combined frame as N_CHUNKS roughly equal, brotli-compressed parquet parts.
N_CHUNKS = 50
export_dir = get_format_data() / 'web'
export_dir.mkdir(parents=True, exist_ok=True)  # assumes get_format_data() returns a pathlib.Path
# Split row *positions* instead of the DataFrame itself. np.array_split on a DataFrame
# goes through DataFrame.swapaxes, which is deprecated in pandas 2.x. The resulting
# chunk boundaries are identical.
row_chunks = np.array_split(np.arange(len(wsj_combine)), N_CHUNKS)
for i, positions in enumerate(row_chunks, 1):
    print(i)
    wsj_combine.iloc[positions].to_parquet(
        export_dir / f'wsj_all_{i}.parquet.brotli', compression='brotli'
    )

#### Load Data

In [2]:
# Reload the combined (embedding + article) frame from the exported 'web' parquet parts
wsj_combine = Data(
    folder_path=get_format_data() / 'web',
    file_pattern='wsj_all_*',
).concat_files()
wsj_combine.shape

(830899, 6)

#### Create ChromaDB

In [3]:
# Open (or create) the persistent ChromaDB store in the backend chromadb directory.
# ALLOW_RESET is exported before the client is built; presumably chromadb's Settings read it
# to permit client.reset().
os.environ.update({'ALLOW_RESET': 'TRUE'})
chroma_path = str(get_backend_chromadb())
client = chromadb.PersistentClient(path=chroma_path)

In [None]:
# Delete the "wsj_emb" collection so it can be rebuilt from scratch.
# Guarded so the cell also runs on a fresh database where the collection does not exist yet;
# delete_collection raises in that case.
# list_collections() returns Collection objects in older chromadb releases and plain names in
# newer ones; getattr(..., 'name', c) handles both.
existing_collections = {getattr(c, 'name', c) for c in client.list_collections()}
if "wsj_emb" in existing_collections:
    client.delete_collection(name="wsj_emb")

In [4]:
# Show the collections currently stored in the persistent client
collections = client.list_collections()
collections

[]

#### ChromaDB Add Data (Parallelized)

In [6]:
@ray.remote
def db_add_group(group, api_key=None):
    """Ray task: bulk-insert one slice of article rows into the "wsj_emb" ChromaDB collection.

    Each task opens its own PersistentClient on the backend ChromaDB path. It gets or creates
    the cosine-space "wsj_emb" collection and adds the rows with their precomputed
    embeddings.

    Args:
        group: DataFrame slice indexed by (id, date), with columns 'ada_embedding' (array-like,
            must support .tolist()), 'body_txt', 'headline', 'n_tokens' and 'article_count'.
        api_key: OpenAI API key for the collection's embedding function. Defaults to the
            OPENAI_API_KEY environment variable. If Ray workers do not inherit the driver's
            environment, pass it explicitly.

    Raises:
        RuntimeError: if no API key is passed and OPENAI_API_KEY is not set.
    """
    # Never hardcode the key in the notebook; read it from the caller or the environment.
    resolved_key = api_key or os.environ.get('OPENAI_API_KEY')
    if not resolved_key:
        raise RuntimeError("OpenAI API key missing: pass api_key or set OPENAI_API_KEY")

    # Create/Get database in the backend chromadb directory
    os.environ['ALLOW_RESET'] = 'TRUE'
    client = chromadb.PersistentClient(path=str(get_backend_chromadb()))

    # OpenAI embedding function attached to the collection
    openai_ef = embedding_functions.OpenAIEmbeddingFunction(
        model_name="text-embedding-ada-002",
        api_key=resolved_key,
    )

    collection = client.get_or_create_collection(
        name="wsj_emb", embedding_function=openai_ef, metadata={"hnsw:space": "cosine"}
    )

    # Nothing to insert for an empty slice
    if group.empty:
        return

    source = "wsj openai embedding"
    embeddings, documents, metadatas, ids = [], [], [], []
    for (row_id, row_date), row in group.iterrows():
        embeddings.append(row['ada_embedding'].tolist())
        documents.append(row['body_txt'])
        metadatas.append({
            "source": source,
            # time.mktime interprets the date in the worker's *local* timezone
            "date": int(time.mktime(row_date.timetuple())),
            "headline": row['headline'],
            "n_token": row['n_tokens'],
            "n_date": row['article_count'],
        })
        ids.append(f"id{row_id}")

    # One bulk add per group
    collection.add(
        embeddings=embeddings,
        documents=documents,
        metadatas=metadatas,
        ids=ids,
    )

def db_add_all(df, group_size, batch_size, progress_file):
    """Insert every row of ``df`` into ChromaDB with parallel Ray tasks, resumably.

    Rows are cut into consecutive groups of ``group_size`` rows, selected by position. Each
    group becomes one ``db_add_group`` Ray task. Tasks are submitted ``batch_size`` at a time
    and awaited together. When a batch finishes, its index is appended to ``progress_file``,
    so a re-run skips batches that already completed.

    Args:
        df: DataFrame to insert.
        group_size: Number of rows handled by one Ray task (must be > 0).
        batch_size: Number of Ray tasks submitted and awaited together (must be > 0).
        progress_file: Path of a text file holding one completed batch index per line.

    Raises:
        ValueError: if ``group_size`` or ``batch_size`` is not positive.

    Note:
        Batch indexes only mean something for a fixed ``group_size``, ``batch_size`` and row
        order. If any of these change between runs, delete the progress file; otherwise the
        wrong rows are skipped.
    """
    if group_size <= 0 or batch_size <= 0:
        raise ValueError("group_size and batch_size must be positive")

    # Load batch indexes completed by previous runs (blank lines are ignored)
    processed_batches = set()
    if os.path.exists(progress_file):
        with open(progress_file, 'r') as file:
            processed_batches = {int(line) for line in file if line.strip()}

    # Integer ceiling division, avoiding the float round-trip of np.ceil
    total_rows = len(df)
    total_groups = -(-total_rows // group_size)
    num_batches = -(-total_groups // batch_size)
    print(f"Total groups: {total_groups}, Number of batches: {num_batches}")

    for batch_idx in range(num_batches):
        print("-" * 60)
        print(f"Processing batch: {batch_idx + 1}/{num_batches}")

        if batch_idx in processed_batches:
            print(f"Already processed (skipping): Batch {batch_idx + 1}")
            continue

        # Submit one task per group in this batch: len(futures) <= batch_size
        first_group = batch_idx * batch_size
        last_group = min(first_group + batch_size, total_groups)
        futures = []
        for group_idx in range(first_group, last_group):
            group_start = group_idx * group_size
            group_end = min(group_start + group_size, total_rows)
            # .iloc keeps the slice positional. Plain df[a:b] on an integer-labelled
            # (id, date) MultiIndex may be resolved by label. Label slices are
            # end-inclusive, so neighbouring groups would overlap.
            futures.append(db_add_group.remote(df.iloc[group_start:group_end]))

        # Block until every task in the batch has finished (re-raises task errors)
        ray.get(futures)

        # Record the batch only after all of its tasks succeeded
        with open(progress_file, 'a') as file:
            file.write(f"{batch_idx}\n")

In [8]:
# Process articles in batches: 500 rows per Ray task, 50 tasks awaited together per batch.
start_time = time.time()
ray.init(num_cpus=16, ignore_reinit_error=True)
try:
    # Progress file records finished batch indexes so an interrupted run can resume
    progress_file = get_backend_data() / "progress_chroma_db.txt"
    db_add_all(df=wsj_combine, group_size=500, batch_size=50, progress_file=progress_file)
    elapsed_time = time.time() - start_time
    print(f"Total Time: {elapsed_time} seconds")
finally:
    # Always release the Ray workers, even if a batch raised
    ray.shutdown()

2024-01-13 00:37:24,911	INFO worker.py:1673 -- Started a local Ray instance.


Number of batches: 8309
------------------------------------------------------------
Processing batch: 1/8309
------------------------------------------------------------
Processing batch: 2/8309
------------------------------------------------------------
Processing batch: 3/8309
------------------------------------------------------------
Processing batch: 4/8309
------------------------------------------------------------
Processing batch: 5/8309
------------------------------------------------------------
Processing batch: 6/8309
------------------------------------------------------------
Processing batch: 7/8309
------------------------------------------------------------
Processing batch: 8/8309
------------------------------------------------------------
Processing batch: 9/8309
------------------------------------------------------------
Processing batch: 10/8309
------------------------------------------------------------
Processing batch: 11/8309
----------------------------