#### Import Package

##### Follow these steps to install Milvus (standalone, Docker): https://milvus.io/docs/install_standalone-docker.md
##### --> To start Milvus, cd to the directory containing docker-compose.yml and run: docker compose up -d, then verify the containers with: docker compose ps
##### --> To shut down Milvus, cd to the directory containing docker-compose.yml and run: docker compose down

In [1]:
import time
import os
import pandas as pd
import numpy as np
import ray
import json
import openai as OpenAI

from datetime import datetime
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

from utils.system import *
from class_data.data import Data

#### Export Data

##### --> Make sure to run format_data.ipynb and then get_emb_openai.ipynb (in that order) to generate the data for this section
##### --> Skip this section if the exported data is already provided

In [2]:
# OpenAI embeddings (multiple articles per day), read from the 'openai' folder
wsj_multiple_openai = Data(
    folder_path=get_format_data() / 'openai',
    file_pattern='wsj_emb_textemb3small_*',
).concat_files()
print(wsj_multiple_openai.shape)

# Article data (multiple articles per day), read from the 'token' folder
wsj_multiple = Data(
    folder_path=get_format_data() / 'token',
    file_pattern='wsj_tokens_*',
).concat_files()
print(wsj_multiple.shape)

Loading Data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 832/832 [00:16<00:00, 49.49it/s]


(831077, 1)


Loading Data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.17it/s]

(831077, 4)





In [3]:
# Merge Embeddings and Article
# pd.concat(axis=1) aligns rows on the index. Both frames must therefore carry the same
# index in the same order; otherwise the join either pads rows with NaN or fails with an
# opaque reindexing error. Check up front and fail with a clear message.
if not wsj_multiple_openai.index.equals(wsj_multiple.index):
    raise ValueError(
        "Embedding and article frames do not share the same index; "
        "regenerate them with format_data.ipynb and get_emb_openai.ipynb."
    )
wsj_combine = pd.concat([wsj_multiple_openai, wsj_multiple], axis=1)

In [4]:
# Keep only index values (dates) that have at least `limit` articles.
# `limit` must be the exact value used in embedding_similarity.ipynb to align indexes.
limit = 30
count = wsj_combine['accession_number'].groupby(wsj_combine.index).count()
valid_dates_mask = count.ge(limit)
wsj_combine = wsj_combine.loc[wsj_combine.index.isin(count.index[valid_dates_mask.to_numpy()])]
print(wsj_combine.shape)

(830899, 5)


In [5]:
# Add IDs: the old index becomes a 'date' column, the fresh positional index becomes 'id',
# and the frame ends up indexed by (id, date)
wsj_combine = (
    wsj_combine
    .reset_index()
    .rename(columns={'index': 'date'})
    .rename_axis('id')
    .reset_index()
    .set_index(['id', 'date'])
)
# Add article count: number of non-null 'body_txt' entries sharing the same date
wsj_combine['article_count'] = wsj_combine.groupby(level='date')['body_txt'].transform('count')

In [6]:
# Export Data
# Split the row positions rather than the DataFrame itself: np.array_split on a DataFrame
# goes through the deprecated DataFrame.swapaxes and emits a FutureWarning on recent pandas.
# Splitting np.arange(len(...)) yields the same chunk boundaries.
N_EXPORT_CHUNKS = 50
chunk_positions = np.array_split(np.arange(len(wsj_combine)), N_EXPORT_CHUNKS)
for i, positions in enumerate(chunk_positions, 1):
    print(i)
    df = wsj_combine.iloc[positions]
    df.to_parquet(get_format_data() / 'web' / f'wsj_all_{i}.parquet.brotli', compression='brotli')

  return bound(*args, **kwds)


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50


#### Load Data

In [2]:
# Load the combined article + embedding data exported to the 'web' folder
# NOTE(review): the argument 1 passed to concat_files presumably limits how many files are read — confirm
wsj_combine = Data(
    folder_path=get_format_data() / 'web',
    file_pattern='wsj_all_*',
).concat_files(1)
wsj_combine.shape

Loading Data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.01s/it]


(16618, 6)

In [7]:
# Scratch view for inspection: move the current index into columns, then index by 'date'
x = wsj_combine.reset_index().set_index('date')

In [9]:
# Show rows dated strictly between 2013-01-01 and 2014-01-01 (both bounds exclusive)
in_2013 = (x.index > '2013-01-01') & (x.index < '2014-01-01')
x.loc[in_2013]

Unnamed: 0_level_0,id,ada_embedding,accession_number,headline,body_txt,n_tokens,article_count
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1


#### Milvus Add Embedding and Metadata

In [13]:
def convert_date_to_int(date):
    """Encode a date as an integer of the form yyyymmdd.

    Args:
        date: Either a 'yyyy-mm-dd' string or an object that provides
            ``strftime`` (for example a ``datetime``).

    Returns:
        int: The date as yyyymmdd, e.g. '2013-01-05' -> 20130105.

    Raises:
        ValueError: If the text does not match the '%Y-%m-%d' format.
    """
    # Non-string inputs are rendered as 'yyyy-mm-dd' first so both paths share one parser
    iso_text = date if isinstance(date, str) else date.strftime("%Y-%m-%d")
    parsed = datetime.strptime(iso_text, "%Y-%m-%d")
    return int(f"{parsed:%Y%m%d}")
    
def db_add_group(group):
    """Bulk-insert one slice of articles into the module-level Milvus ``collection``.

    Args:
        group: DataFrame slice whose index entries are ``(id, date)`` pairs and which
            has 'headline', 'body_txt' and 'ada_embedding' columns. Every embedding
            must provide ``tolist()``.

    Returns:
        None. The rows are written through ``collection.insert``.
    """
    # Nothing to send for an empty slice
    if len(group) == 0:
        return

    # Read every field column-wise instead of row by row with iterrows(),
    # which builds a fresh Series for each row.
    ids = group.index.get_level_values(0).tolist()
    dates = [convert_date_to_int(day) for day in group.index.get_level_values(1)]
    headlines = group['headline'].tolist()
    documents = group['body_txt'].tolist()
    embeddings = [vector.tolist() for vector in group['ada_embedding']]

    # Column order follows the order the values were inserted in originally:
    # id, date, headline, document, embedding
    insert_data = [ids, dates, headlines, documents, embeddings]

    # Bulk add to collection
    collection.insert(insert_data)

def db_add_all(df, group_size):
    """Insert ``df`` in consecutive slices of at most ``group_size`` rows.

    Args:
        df: DataFrame to insert; every slice is handed to ``db_add_group``.
        group_size: Maximum number of rows per slice. Must be a positive integer.

    Raises:
        ValueError: If ``group_size`` is not positive.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be a positive integer, got {group_size}")

    total_rows = len(df)
    # Integer ceiling division (no float round-trip through np.ceil)
    total_groups = -(-total_rows // group_size)
    print(f"Total groups: {total_groups}")

    for group_idx in range(total_groups):
        print("-" * 60)
        print(f"Processing group: {group_idx + 1}/{total_groups}")

        # Create group; .iloc keeps the slice positional whatever the index type is
        group_start = group_idx * group_size
        group_end = min(group_start + group_size, total_rows)
        group = df.iloc[group_start:group_end]

        # Add group
        db_add_group(group)

In [5]:
# Local Start: settings for a Milvus server on this machine.
# Run only one of the Local / Cloud / Ngrok cells; each one overwrites the same names.
COLLECTION_NAME = 'wsj_emb'
DIMENSION = 1536
COUNT = 100
MILVUS_HOST = 'localhost'
MILVUS_PORT = '19530'
OPENAI_ENGINE = 'text-embedding-ada-002'
# Read the key inside a context manager so the file handle is closed right away
with open(get_config() / 'api.json') as api_file:
    API_KEY = json.load(api_file)['openai_api_key']

In [22]:
# Cloud Start: the Milvus host is read from the 'gcp_milvus_server' entry of milvus.json.
# Run only one of the Local / Cloud / Ngrok cells; each one overwrites the same names.
COLLECTION_NAME = 'wsj_emb'
DIMENSION = 1536
COUNT = 100
# Context managers close the config files as soon as they have been parsed
with open(get_config() / 'milvus.json') as milvus_file:
    MILVUS_HOST = json.load(milvus_file)['gcp_milvus_server']
MILVUS_PORT = '19530'
OPENAI_ENGINE = 'text-embedding-ada-002'
with open(get_config() / 'api.json') as api_file:
    API_KEY = json.load(api_file)['openai_api_key']

In [28]:
# Ngrok: settings for reaching Milvus through an ngrok TCP endpoint.
# Run only one of the Local / Cloud / Ngrok cells; each one overwrites the same names.
COLLECTION_NAME = 'wsj_emb'
DIMENSION = 1536
COUNT = 100
# NOTE(review): this host carries a 'tcp://' scheme prefix while the other configurations
# pass a bare hostname — confirm pymilvus accepts this form
MILVUS_HOST = 'tcp://8.tcp.ngrok.io'
MILVUS_PORT = '12487'
OPENAI_ENGINE = 'text-embedding-ada-002'
# Read the key inside a context manager so the file handle is closed right away
with open(get_config() / 'api.json') as api_file:
    API_KEY = json.load(api_file)['openai_api_key']

In [6]:
# (Re)connect the 'default' alias: drop the current connection first, then open a new one
connections.disconnect(alias='default')
connections.connect(alias='default', host=MILVUS_HOST, port=MILVUS_PORT)

In [7]:
# Start clean: drop the collection when one with this name is already present
collection_exists = utility.has_collection(COLLECTION_NAME)
if collection_exists:
    utility.drop_collection(COLLECTION_NAME)

In [9]:
# Create collection which includes the id, date, headline, document text, and embedding
fields = [
    FieldSchema(name='id', dtype=DataType.INT64, description='article id', is_primary=True, auto_id=False),
    # max_length is documented for VARCHAR fields only, so it is not passed for the INT32 date
    FieldSchema(name='date', dtype=DataType.INT32, description='yyyy-mm-dd date as inter yyyymmdd'),
    FieldSchema(name='headline', dtype=DataType.VARCHAR, description='article headline', max_length=65535),
    FieldSchema(name='document', dtype=DataType.VARCHAR, description='article text', max_length=65535),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields, description='wall street journal embeddings')
collection = Collection(name=COLLECTION_NAME, schema=schema)

# Create an index for the collection (built on the 'embedding' field, cosine metric)
# NOTE(review): nlist=5 is a very small cluster count — confirm it is intended for the full data set
index_params = {
    'index_type': 'IVF_FLAT',
    'metric_type': 'COSINE',
    'params': {'nlist': 5}
}
collection.create_index(field_name="embedding", index_params=index_params)

Status(code=0, message=)

In [14]:
# Insert data into collection
db_add_all(df=wsj_combine, group_size=5000)
# Seal the inserted segments once, after the last batch, so the rows are persisted
# and counted by the server
collection.flush()

Total groups: 167
------------------------------------------------------------
Processing group: 1/167
------------------------------------------------------------
Processing group: 2/167
------------------------------------------------------------
Processing group: 3/167
------------------------------------------------------------
Processing group: 4/167
------------------------------------------------------------
Processing group: 5/167
------------------------------------------------------------
Processing group: 6/167
------------------------------------------------------------
Processing group: 7/167
------------------------------------------------------------
Processing group: 8/167
------------------------------------------------------------
Processing group: 9/167
------------------------------------------------------------
Processing group: 10/167
------------------------------------------------------------
Processing group: 11/167
---------------------------------------------

#### Milvus Add Embedding

In [15]:
def convert_date_to_int(date):
    """Turn a date into its yyyymmdd integer form.

    Args:
        date: A 'yyyy-mm-dd' string, or any object with a ``strftime`` method
            (for example a ``datetime``).

    Returns:
        int: The yyyymmdd encoding, e.g. '2013-01-05' -> 20130105.

    Raises:
        ValueError: If the text is not in '%Y-%m-%d' format.
    """
    if isinstance(date, str):
        date_text = date
    else:
        # Normalize date-like objects to the same text form the string path uses
        date_text = date.strftime("%Y-%m-%d")
    moment = datetime.strptime(date_text, "%Y-%m-%d")
    return int(f"{moment:%Y%m%d}")
    
def db_add_group(group):
    """Bulk-insert ids, dates and embeddings of one slice into the module-level ``collection``.

    Args:
        group: DataFrame slice whose index entries are ``(id, date)`` pairs and which
            has an 'ada_embedding' column. Every embedding must provide ``tolist()``.

    Returns:
        None. The rows are written through ``collection.insert``.
    """
    # Nothing to send for an empty slice
    if len(group) == 0:
        return

    # Only id, date and embedding are inserted here, so the headline/document
    # lists are not built. Fields are read column-wise instead of through iterrows().
    ids = group.index.get_level_values(0).tolist()
    dates = [convert_date_to_int(day) for day in group.index.get_level_values(1)]
    embeddings = [vector.tolist() for vector in group['ada_embedding']]

    # Prepare the data for insertion: id, date, embedding
    insert_data = [ids, dates, embeddings]

    # Bulk add to collection
    collection.insert(insert_data)

def db_add_all(df, group_size):
    """Insert ``df`` in consecutive slices of at most ``group_size`` rows.

    Args:
        df: DataFrame to insert; every slice is handed to ``db_add_group``.
        group_size: Maximum number of rows per slice. Must be a positive integer.

    Raises:
        ValueError: If ``group_size`` is not positive.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be a positive integer, got {group_size}")

    total_rows = len(df)
    # Integer ceiling division (no float round-trip through np.ceil)
    total_groups = -(-total_rows // group_size)
    print(f"Total groups: {total_groups}")

    for group_idx in range(total_groups):
        print("-" * 60)
        print(f"Processing group: {group_idx + 1}/{total_groups}")

        # Create group; .iloc keeps the slice positional whatever the index type is
        group_start = group_idx * group_size
        group_end = min(group_start + group_size, total_rows)
        group = df.iloc[group_start:group_end]

        # Add group
        db_add_group(group)

In [16]:
# Local Start: settings for a Milvus server on this machine (embedding-only collection).
# Run only one of the Local / Cloud / Ngrok cells; each one overwrites the same names.
COLLECTION_NAME = 'wsj_emb_only'
DIMENSION = 1536
COUNT = 100
MILVUS_HOST = 'localhost'
MILVUS_PORT = '19530'
OPENAI_ENGINE = 'text-embedding-ada-002'
# Read the key inside a context manager so the file handle is closed right away
with open(get_config() / 'api.json') as api_file:
    API_KEY = json.load(api_file)['openai_api_key']

In [22]:
# Cloud Start: the Milvus host is read from the 'gcp_milvus_server' entry of milvus.json.
# Run only one of the Local / Cloud / Ngrok cells; each one overwrites the same names.
COLLECTION_NAME = 'wsj_emb_only'
DIMENSION = 1536
COUNT = 100
# Context managers close the config files as soon as they have been parsed
with open(get_config() / 'milvus.json') as milvus_file:
    MILVUS_HOST = json.load(milvus_file)['gcp_milvus_server']
MILVUS_PORT = '19530'
OPENAI_ENGINE = 'text-embedding-ada-002'
with open(get_config() / 'api.json') as api_file:
    API_KEY = json.load(api_file)['openai_api_key']

In [28]:
# Ngrok: settings for reaching Milvus through an ngrok TCP endpoint.
# Run only one of the Local / Cloud / Ngrok cells; each one overwrites the same names.
COLLECTION_NAME = 'wsj_emb_only'
DIMENSION = 1536
COUNT = 100
# NOTE(review): this host carries a 'tcp://' scheme prefix while the other configurations
# pass a bare hostname — confirm pymilvus accepts this form
MILVUS_HOST = 'tcp://8.tcp.ngrok.io'
MILVUS_PORT = '12487'
OPENAI_ENGINE = 'text-embedding-ada-002'
# Read the key inside a context manager so the file handle is closed right away
with open(get_config() / 'api.json') as api_file:
    API_KEY = json.load(api_file)['openai_api_key']

In [17]:
# (Re)connect the 'default' alias: drop the current connection first, then open a new one
connections.disconnect(alias='default')
connections.connect(alias='default', host=MILVUS_HOST, port=MILVUS_PORT)

In [18]:
# Start clean: drop the collection when one with this name is already present
collection_exists = utility.has_collection(COLLECTION_NAME)
if collection_exists:
    utility.drop_collection(COLLECTION_NAME)

In [19]:
# Create collection which includes the id, date, and embedding
fields = [
    FieldSchema(name='id', dtype=DataType.INT64, description='article id', is_primary=True, auto_id=False),
    # max_length is documented for VARCHAR fields only, so it is not passed for the INT32 date
    FieldSchema(name='date', dtype=DataType.INT32, description='yyyy-mm-dd date as inter yyyymmdd'),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields, description='wall street journal embeddings')
collection = Collection(name=COLLECTION_NAME, schema=schema)

# Create an index for the collection (built on the 'embedding' field, cosine metric)
# NOTE(review): nlist=5 is a very small cluster count — confirm it is intended for the full data set
index_params = {
    'index_type': 'IVF_FLAT',
    'metric_type': 'COSINE',
    'params': {'nlist': 5}
}
collection.create_index(field_name="embedding", index_params=index_params)

Status(code=0, message=)

In [20]:
# Insert data into collection
db_add_all(df=wsj_combine, group_size=5000)
# Seal the inserted segments once, after the last batch, so the rows are persisted
# and counted by the server
collection.flush()

Total groups: 167
------------------------------------------------------------
Processing group: 1/167
------------------------------------------------------------
Processing group: 2/167
------------------------------------------------------------
Processing group: 3/167
------------------------------------------------------------
Processing group: 4/167
------------------------------------------------------------
Processing group: 5/167
------------------------------------------------------------
Processing group: 6/167
------------------------------------------------------------
Processing group: 7/167
------------------------------------------------------------
Processing group: 8/167
------------------------------------------------------------
Processing group: 9/167
------------------------------------------------------------
Processing group: 10/167
------------------------------------------------------------
Processing group: 11/167
---------------------------------------------