#### Import Package

In [None]:
import time
import os
import pandas as pd
import numpy as np
import ray

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import CollectionStatus
from datetime import datetime

from utils.system import *
from class_data.data import Data

#### Export Data

##### --> Make sure to run format_data.ipynb and then get_emb_openai.ipynb (in that order) to generate the data for this section
##### --> Skip this if the data is already provided 

In [None]:
# OpenAI embeddings (multiple articles per day): concatenate every file
# matching 'wsj_emb_openai_*' in the 'openai' folder and report the shape
wsj_multiple_openai = Data(
    folder_path=get_format_data() / 'openai',
    file_pattern='wsj_emb_openai_*',
).concat_files()
print(wsj_multiple_openai.shape)

# Article data (multiple articles per day): concatenate every file
# matching 'wsj_tokens_*' in the 'token' folder and report the shape
wsj_multiple = Data(
    folder_path=get_format_data() / 'token',
    file_pattern='wsj_tokens_*',
).concat_files()
print(wsj_multiple.shape)

In [None]:
# Merge Embeddings and Article
# Column-wise concat (axis=1): the embedding columns and the article columns are placed side by side
# NOTE(review): assumes both frames share the same index in the same row order -- confirm
wsj_combine = pd.concat([wsj_multiple_openai, wsj_multiple], axis=1)

In [None]:
# Minimum number of articles an index value must have for its rows to be kept.
# Keep this identical to the value used in embedding_similarity.ipynb so the indexes align.
limit = 30
# Number of non-null accession numbers per index value
count = wsj_combine['accession_number'].groupby(wsj_combine.index).count()
# True for the index values that reach the limit
valid_dates_mask = count >= limit
# Keep only the rows whose index value passed the mask
wsj_combine = wsj_combine.loc[wsj_combine.index.isin(count.loc[valid_dates_mask].index)]
print(wsj_combine.shape)

In [None]:
# Build an (id, date) MultiIndex:
#   1. move the current index into a column (named 'index' when the index is unnamed) and call it 'date'
#   2. name the fresh 0..n-1 row index 'id'
#   3. move 'id' into a column as well and index by both
wsj_combine = (
    wsj_combine
    .reset_index()
    .rename(columns={'index': 'date'})
    .rename_axis('id')
    .reset_index()
    .set_index(['id', 'date'])
)
# Number of non-null article bodies sharing each date, repeated on every row of that date
wsj_combine['article_count'] = wsj_combine['body_txt'].groupby(level='date').transform('count')

In [None]:
# Export Data
# Split the frame into 50 near-equal, contiguous row chunks and write each one
# to '<format data>/web/wsj_all_<i>.parquet.brotli' (i starts at 1).
# The row *positions* are split and the rows selected with .iloc, instead of passing the
# DataFrame itself to np.array_split: that path goes through DataFrame.swapaxes, which
# newer pandas versions deprecate. The chunk boundaries are the same either way.
chunks = [
    wsj_combine.iloc[rows]
    for rows in np.array_split(np.arange(len(wsj_combine)), 50)
]
for i, df in enumerate(chunks, 1):
    # Progress indicator: number of the chunk being written
    print(i)
    df.to_parquet(get_format_data() / 'web' / f'wsj_all_{i}.parquet.brotli', compression='brotli')

#### Load Data

In [None]:
# Combined articles + OpenAI embeddings: concatenate every exported
# 'wsj_all_*' file found in the 'web' folder
wsj_combine = Data(
    folder_path=get_format_data() / 'web',
    file_pattern='wsj_all_*',
).concat_files()
# Notebook display of the loaded shape
wsj_combine.shape

#### Qdrant Add Data

##### NOTE: Persistent does not work for windows
##### --> Download Docker for windows here: https://docs.docker.com/desktop/install/windows-install/
##### --> Check if installed correctly: docker --version
##### --> To start a local server (non-persistent) run this in powershell: docker run -p 6333:6333 qdrant/qdrant:latest
##### --> For deployment (persistent), create a directory called qdrant_storage, cd to its parent directory, and run these commands in PowerShell: docker pull qdrant/qdrant --> docker run -p 6333:6333 -v ${PWD}/qdrant_storage:/qdrant/storage qdrant/qdrant



In [None]:
def db_add_group(group):
    """Upsert every row of ``group`` into the 'wsj_emb' collection in one batch.

    The first two entries of each row's index key are read as the point id
    and the article date. The row itself must provide the columns
    'ada_embedding', 'headline', 'body_txt', 'n_tokens' and 'article_count'.
    Uses the module-level ``client``.
    """
    point_ids = []
    vectors = []
    payloads = []

    for key, record in group.iterrows():
        point_ids.append(key[0])
        vectors.append(record['ada_embedding'].tolist())

        # Date stored as integer Unix seconds; time.mktime interprets the
        # time tuple in the local timezone of the machine running this.
        unix_seconds = int(time.mktime(key[1].timetuple()))
        payloads.append({
            "source": "wsj openai embedding",
            "date": unix_seconds,
            "headline": record['headline'],
            "document": record['body_txt'],
            "n_token": record['n_tokens'],
            "n_date": record['article_count'],
        })

    # One request for the whole group
    batch = models.Batch(ids=point_ids, vectors=vectors, payloads=payloads)
    client.upsert(collection_name='wsj_emb', points=batch)

def db_add_all(df, group_size):
    """Upload ``df`` in consecutive slices of at most ``group_size`` rows.

    Each slice is handed to ``db_add_group``; progress is printed before
    every slice. Nothing is uploaded when ``df`` is empty.
    """
    n_rows = len(df)
    total_groups = int(np.ceil(n_rows / group_size))
    print(f"Total groups: {total_groups}")

    # Walk the start offset of every slice: 0, group_size, 2 * group_size, ...
    for group_idx, group_start in enumerate(range(0, n_rows, group_size)):
        print("-" * 60)
        print(f"Processing group: {group_idx + 1}/{total_groups}")

        # The last slice may be shorter than group_size
        group_end = min(group_start + group_size, n_rows)
        db_add_group(df[group_start:group_end])

#### Create Qdrant (Server)

In [None]:
# Create Database in server
# Client for the Qdrant server at localhost:6333 (the port published by the docker commands above).
# db_add_group reads this module-level `client`, so run this cell before uploading.
client = QdrantClient("http://localhost:6333")

In [None]:
# Create collection (This deletes the current collection)
# The vector size is the length of the first stored embedding. It is taken by position
# (.iloc[0]) rather than by the label 0, so it does not depend on a row with id 0 being present.
client.recreate_collection(
    collection_name="wsj_emb",
    vectors_config=models.VectorParams(
        size=len(wsj_combine['ada_embedding'].iloc[0]),
        distance=models.Distance.COSINE,
    ),
    shard_number=4,
)

In [None]:
# Upload all of wsj_combine in groups of 850 rows and report the wall-clock time taken
start_time = time.time()
db_add_all(df=wsj_combine, group_size=850)
# Seconds elapsed since start_time
elasped_time = time.time() - start_time
print(f"Total Time: {elasped_time} seconds")