#### Import Packages

##### Follow these steps to install Milvus (standalone, with Docker Compose): https://milvus.io/docs/install_standalone-docker.md
##### --> To start Milvus: `cd` to the directory containing docker-compose.yml, run `docker compose up -d`, then check that the containers are running with `docker compose ps`
##### --> To shut down Milvus: `cd` to the directory containing docker-compose.yml and run `docker compose down`

In [1]:
import time
import os
import pandas as pd
import numpy as np
import ray
import json
import openai as OpenAI

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

from utils.system import *
from class_data.data import Data

#### Export Data

##### --> Run format_data.ipynb and then get_emb_openai.ipynb (in that order) to generate the input data for this section
##### --> Skip this section if the exported data (`wsj_all_*.parquet.brotli` in the `web` folder) is already provided

In [None]:
# Load the per-article OpenAI embeddings (several articles per day) and show their shape
wsj_multiple_openai = Data(folder_path=get_format_data() / 'openai', file_pattern='wsj_emb_openai_*').concat_files()
print(wsj_multiple_openai.shape)
# Load the matching tokenized article data and show its shape
wsj_multiple = Data(folder_path=get_format_data() / 'token', file_pattern='wsj_tokens_*').concat_files()
print(wsj_multiple.shape)

In [None]:
# Merge embeddings and article data side by side.
# pd.concat(axis=1) aligns rows on the index. That index is the date, which repeats many
# times (multiple articles per day), so the two frames must have exactly the same index
# in the same order. Check this up front instead of relying on a cryptic reindexing error
# or silently NaN-padded rows.
if not wsj_multiple_openai.index.equals(wsj_multiple.index):
    raise ValueError(
        "Embedding and token frames have different indexes; rows cannot be aligned for the merge"
    )
wsj_combine = pd.concat([wsj_multiple_openai, wsj_multiple], axis=1)

In [None]:
# Keep only dates that have at least `limit` articles (non-null accession numbers).
# `limit` must equal the value used in embedding_similarity.ipynb so the indexes line up.
limit = 30
count = wsj_combine.groupby(wsj_combine.index)['accession_number'].count()
valid_dates_mask = count.ge(limit)
wsj_combine = wsj_combine.loc[wsj_combine.index.isin(count.index[valid_dates_mask])]
print(wsj_combine.shape)

In [None]:
# Give every article a sequential integer id and index the frame by (id, date).
# The unnamed date index becomes a 'date' column, the fresh positional index is named 'id',
# and both are promoted to a two-level index.
wsj_combine = (
    wsj_combine
    .reset_index()
    .rename(columns={'index': 'date'})
    .rename_axis('id')
    .reset_index()
    .set_index(['id', 'date'])
)
# Number of articles (non-null body_txt) published on each row's date
wsj_combine['article_count'] = wsj_combine.groupby(level='date')['body_txt'].transform('count')

In [None]:
# Export the combined frame as N_EXPORT_CHUNKS roughly equal, brotli-compressed parquet files
# named wsj_all_<i>.parquet.brotli (i starts at 1). The "Load Data" cell reads these back.
N_EXPORT_CHUNKS = 50
export_dir = get_format_data() / 'web'
export_dir.mkdir(parents=True, exist_ok=True)
# Split row *positions* rather than the DataFrame itself. np.array_split on a DataFrame
# relies on DataFrame.swapaxes, which is deprecated since pandas 2.1.
for i, positions in enumerate(np.array_split(np.arange(len(wsj_combine)), N_EXPORT_CHUNKS), 1):
    print(i)
    wsj_combine.iloc[positions].to_parquet(export_dir / f'wsj_all_{i}.parquet.brotli', compression='brotli')

#### Load Data

In [2]:
# Load the combined (embeddings + article data) frame from the exported wsj_all_* chunks
wsj_combine = Data(folder_path=get_format_data() / 'web', file_pattern='wsj_all_*').concat_files()
wsj_combine.shape

(830899, 6)

#### Add Data to Milvus

In [3]:
def db_add_group(group, target=None):
    """Bulk-insert one batch of articles into a Milvus collection.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows indexed by a two-level ``(id, date)`` index. Level 0 is the integer
        article id and level 1 holds objects with ``strftime`` (e.g. Timestamps).
        Required columns: ``headline``, ``body_txt``, ``n_tokens``,
        ``article_count`` and ``ada_embedding`` (one array-like vector per row).
    target : pymilvus Collection, optional
        Collection to insert into. Defaults to the notebook-global ``collection``.

    Returns
    -------
    Whatever ``target.insert`` returns, or ``None`` when ``group`` is empty
    (nothing is sent to the server).
    """
    if group.empty:
        # Nothing to insert; avoid sending an empty batch.
        return None
    if target is None:
        target = collection  # notebook-global collection created in the schema cell

    # Column-wise extraction instead of iterrows(): much faster for large batches.
    # The list order must match the collection's field order:
    # id, date, headline, document, n_token, n_date, embedding.
    insert_data = [
        group.index.get_level_values(0).tolist(),
        [day.strftime("%Y-%m-%d") for day in group.index.get_level_values(1)],
        group['headline'].tolist(),
        group['body_txt'].tolist(),
        group['n_tokens'].tolist(),
        group['article_count'].tolist(),
        [np.asarray(emb).tolist() for emb in group['ada_embedding']],
    ]

    # Bulk add to collection
    return target.insert(insert_data)

def db_add_all(df, group_size):
    """Insert ``df`` into Milvus in consecutive batches of ``group_size`` rows.

    Each batch is handed to ``db_add_group``. Progress is printed per batch.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame in the layout expected by ``db_add_group``.
    group_size : int
        Maximum number of rows per insert call; must be positive.

    Raises
    ------
    ValueError
        If ``group_size`` is not positive.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")

    # Calculate the total number of groups
    n_rows = len(df)
    total_groups = int(np.ceil(n_rows / group_size))
    print(f"Total groups: {total_groups}")

    for group_idx in range(total_groups):
        print("-" * 60)
        print(f"Processing group: {group_idx + 1}/{total_groups}")

        # .iloc makes the slice explicitly positional; a bare df[a:b] may be
        # interpreted differently depending on the index type.
        group_start = group_idx * group_size
        group = df.iloc[group_start:min(group_start + group_size, n_rows)]

        # Add group
        db_add_group(group)

In [9]:
# Local Start: Milvus running on this machine (see the docker compose notes at the top)
COLLECTION_NAME = 'wsj_emb'
DIMENSION = 1536  # must equal the length of each stored embedding vector
COUNT = 100  # NOTE(review): not referenced elsewhere in this notebook section
MILVUS_HOST = 'localhost'
MILVUS_PORT = '19530'
OPENAI_ENGINE = 'text-embedding-ada-002'
# Read the key with a context manager so the file handle is closed promptly
with open(get_config() / 'api.json') as api_config_file:
    API_KEY = json.load(api_config_file)['openai_api_key']

In [28]:
# Cloud Start: remote Milvus server whose host is read from milvus.json
COLLECTION_NAME = 'wsj_emb'
DIMENSION = 1536  # must equal the length of each stored embedding vector
COUNT = 100  # NOTE(review): not referenced elsewhere in this notebook section
# Read config files with context managers so the file handles are closed promptly
with open(get_config() / 'milvus.json') as milvus_config_file:
    MILVUS_HOST = json.load(milvus_config_file)['gcp_milvus_server']
MILVUS_PORT = '19530'
OPENAI_ENGINE = 'text-embedding-ada-002'
with open(get_config() / 'api.json') as api_config_file:
    API_KEY = json.load(api_config_file)['openai_api_key']

In [29]:
# Connect to the Milvus server chosen by whichever config cell (local or cloud) ran last.
# NOTE(review): the saved output below is a connection timeout; make sure the server is running and reachable first.
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

MilvusException: <MilvusException: (code=2, message=Fail connecting to server on 34.139.188.28:19530. Timeout)>

In [73]:
# Remove collection if it already exists, so the schema and data below start from scratch.
# WARNING: this deletes the existing collection and everything previously inserted into it.
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

In [74]:
# Create the collection. Fields: article id (primary key), date string, headline, article text,
# token count, number of articles on that date, and the embedding vector.
# db_add_group inserts its columns in exactly this field order.
# NOTE(review): every VARCHAR field uses max_length=65535; confirm no article body exceeds this,
# otherwise Milvus will presumably reject that insert batch.
fields = [
    # Primary key supplied by the notebook (auto_id=False): the 'id' level created in the "Add IDs" cell
    FieldSchema(name='id', dtype=DataType.INT64, description='article id', is_primary=True, auto_id=False),
    # Date formatted as a 'YYYY-MM-DD' string before insertion
    FieldSchema(name='date', dtype=DataType.VARCHAR, description='yyyy-mm-dd date', max_length=65535),
    FieldSchema(name='headline', dtype=DataType.VARCHAR, description='article headline', max_length=65535),
    FieldSchema(name='document', dtype=DataType.VARCHAR, description='article text', max_length=65535),
    FieldSchema(name='n_token', dtype=DataType.INT64, description='number of tokens in article'),
    FieldSchema(name='n_date', dtype=DataType.INT64, description='number of articles in date'),
    # Vector length must equal DIMENSION from the config cell
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields, description='wall street journal openai embeddings')
collection = Collection(name=COLLECTION_NAME, schema=schema)

# Create an index for the collection
# IVF_FLAT index with cosine similarity; nlist is the number of clusters the vectors are partitioned into
index_params = {
    'index_type': 'IVF_FLAT',
    'metric_type': 'COSINE',
    'params': {'nlist': 1024}
}
collection.create_index(field_name="embedding", index_params=index_params)

Status(code=0, message=)

In [75]:
# Insert every article in batches of 5000 rows, then flush so the inserted rows are sealed
# and persisted (collection.num_entities only reflects flushed data).
db_add_all(df=wsj_combine, group_size=5000)
collection.flush()

Total groups: 167
------------------------------------------------------------
Processing group: 1/167
------------------------------------------------------------
Processing group: 2/167
------------------------------------------------------------
Processing group: 3/167
------------------------------------------------------------
Processing group: 4/167
------------------------------------------------------------
Processing group: 5/167
------------------------------------------------------------
Processing group: 6/167
------------------------------------------------------------
Processing group: 7/167
------------------------------------------------------------
Processing group: 8/167
------------------------------------------------------------
Processing group: 9/167
------------------------------------------------------------
Processing group: 10/167
------------------------------------------------------------
Processing group: 11/167
---------------------------------------------