### Import Packages

In [1]:
import pandas as pd
import os
import time
import requests 
import tiktoken
import ray
import numpy as np
import json
import tqdm

from math import ceil

# NOTE: for Ray to work, uninstall pydantic and reinstall a 1.x release: pip install "pydantic<2"
from openai import OpenAI
from class_model.model import Model
from utils.system import *

import warnings
warnings.filterwarnings('ignore')




### Data

In [None]:
# Read the five WSJ article shards (wsj_art_1 to wsj_art_5) and stack them row-wise
collect = [
    pd.read_parquet(get_format_data() / 'art' / f'wsj_art_{i}.parquet.brotli')
    for i in range(1, 6)
]
wsj_multiple = pd.concat(collect, axis=0)

### Parallelized: Get number of tokens (per article)

In [None]:
@ray.remote
def get_token_count(article_text, encoding_param):
    """Ray task: count the tokens in one article.

    Args:
        article_text: Text to tokenize.
        encoding_param: Name of the tiktoken encoding to look up
            (e.g. "cl100k_base").

    Returns:
        Number of tokens produced by encoding ``article_text``.
    """
    tokens = tiktoken.get_encoding(encoding_param).encode(article_text)
    return len(tokens)

def process_tokens_in_batches(df, column_name, encoding_param, batch_size):
    """Count tokens for every row of ``df[column_name]`` with Ray, batch by batch.

    Each batch submits one ``get_token_count`` task per row and blocks until
    the whole batch has finished, so at most ``batch_size`` tasks are
    outstanding at any time.

    Args:
        df: DataFrame holding the texts. It is modified in place: an
            ``n_tokens`` column is added (or overwritten).
        column_name: Name of the column containing the text to tokenize.
        encoding_param: tiktoken encoding name forwarded to ``get_token_count``.
        batch_size: Number of rows per batch; must be a positive integer.

    Returns:
        The same ``df`` object, with ``n_tokens`` filled in row order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    texts = df[column_name]
    # math.ceil returns an int directly, so no repeated int(...) casts are needed
    num_batches = ceil(len(df) / batch_size)
    all_token_counts = []
    print(f"Number of batches: {num_batches}")

    for i in range(num_batches):
        print(f"Processing batch: {i + 1}/{num_batches}")
        start_index = i * batch_size
        # .iloc slices strictly by position, whatever the index labels are
        batch = texts.iloc[start_index:start_index + batch_size]

        # Start asynchronous tasks for the batch, then wait for all of them
        futures = [get_token_count.remote(text, encoding_param) for text in batch]
        all_token_counts.extend(ray.get(futures))

    # Assign the token counts back to the DataFrame (same order as the rows)
    df['n_tokens'] = all_token_counts
    return df

In [None]:
# tiktoken encoding name handed to the token counter
embedding_encoding = "cl100k_base"
# Upper bound on tokens per article; the filter cell drops anything longer
max_tokens = 8_000

In [None]:
batch_size = 5000

# Process articles in batches
ray.init(num_cpus=16, ignore_reinit_error=True)
try:
    start_time = time.time()
    wsj_multiple = process_tokens_in_batches(wsj_multiple, 'body_txt', embedding_encoding, batch_size)
    elapsed_time = time.time() - start_time
    print(f"Total time to get all tokens: {round(elapsed_time)} seconds")
finally:
    # Shutdown Ray even if tokenization fails or the cell is interrupted,
    # so worker processes are not left running in the background
    ray.shutdown()

In [None]:
# Filter: keep only articles whose token count does not exceed max_tokens
print(f"Length before: {len(wsj_multiple)}")
within_limit = wsj_multiple['n_tokens'].le(max_tokens)
wsj_multiple = wsj_multiple.loc[within_limit]
print(f"Length after: {len(wsj_multiple)}")

In [None]:
# Export Data
# Split the rows into 8 contiguous, near-equal chunks by position. The
# boundaries reproduce np.array_split's partition (the first `extra` chunks get
# one more row) without passing the DataFrame itself to np.array_split, which
# relies on the deprecated DataFrame.swapaxes.
n_chunks = 8
base, extra = divmod(len(wsj_multiple), n_chunks)
bounds = [k * base + min(k, extra) for k in range(n_chunks + 1)]
chunks = [wsj_multiple.iloc[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]

# Files are numbered from 1 to match the reader cell (wsj_tokens_1 ... wsj_tokens_8)
for i, chunk in enumerate(chunks, 1):
    print(i)
    chunk.to_parquet(get_format_data() / 'token' / f'wsj_tokens_{i}.parquet.brotli', compression='brotli')

### Parallelized: Get embeddings (per article)

In [2]:
# Read in token dataset: the eight wsj_tokens_<i> shards, concatenated row-wise
collect = [
    pd.read_parquet(get_format_data() / 'token' / f'wsj_tokens_{i}.parquet.brotli')
    for i in range(1, 9)
]
wsj_multiple = pd.concat(collect, axis=0)

In [3]:
# Work on an independent (deep) copy so the loaded frame stays untouched;
# DataFrame.copy() is deep by default
wsj_multiple_token = wsj_multiple.copy()

In [4]:
@ray.remote
def get_embedding_article(article_text, model):
    """Ray task: request an OpenAI embedding for a single article.

    Args:
        article_text: Article text; newlines are replaced with spaces before
            the request is sent.
        model: Name of the OpenAI embedding model to use.

    Returns:
        The embedding taken from the first item of the API response.
    """
    # Read the key inside a context manager so the file handle is closed
    # deterministically instead of being left open until garbage collection.
    with open(get_config() / 'api.json') as config_file:
        api_key = json.load(config_file)['openai_api_key']

    client = OpenAI(api_key=api_key)
    response = client.embeddings.create(input=[article_text.replace("\n", " ")], model=model)
    return response.data[0].embedding

def process_articles_in_batches(df, column_name, model, batch_size, delay_per_batch):
    """Embed every row of ``df[column_name]`` in batches, saving each batch to disk.

    Batch ``i`` is written to ``wsj_emb_textemb3small_{i}.parquet.brotli`` under
    ``get_format_data() / 'openai'``. A batch whose file already exists is
    skipped, so an interrupted run can be resumed by calling the function
    again with the same ``df`` and ``batch_size``.

    Args:
        df: DataFrame with the texts; the index labels of each batch become
            the index of the saved file.
        column_name: Column containing the text to embed.
        model: Embedding model name forwarded to ``get_embedding_article``.
        batch_size: Number of rows per batch; must be a positive integer.
        delay_per_batch: Seconds to sleep between processed batches
            (0 or less disables the pause).

    Returns:
        None. Results are only written to disk.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_batches = ceil(len(df) / batch_size)
    texts = df[column_name]
    # `tqdm` is imported as a module here; `write` lives on the tqdm.tqdm class,
    # so calling `tqdm.write` on the module raises AttributeError.
    log = tqdm.tqdm.write

    with tqdm.tqdm(total=num_batches, desc="Processing batches") as pbar:
        for i in range(num_batches):
            # Check if the batch has already been processed
            save_path = get_format_data() / 'openai' / f'wsj_emb_textemb3small_{i}.parquet.brotli'
            if save_path.exists():
                log(f"Skipping batch {i + 1}/{num_batches} (already processed)")
                pbar.update(1)
                continue

            # Get batch (positional slice, independent of the index labels)
            start_index = i * batch_size
            end_index = min(start_index + batch_size, len(df))
            batch = texts.iloc[start_index:end_index]

            # Start asynchronous tasks for the batch and wait for all of them
            futures = [get_embedding_article.remote(text, model) for text in batch]
            embeddings = ray.get(futures)

            # Save batch. Write to a temporary file and rename it afterwards so
            # an interrupted write never leaves a partial file at `save_path`,
            # which the resume check above would treat as a finished batch.
            # NOTE(review): assumes save_path is a pathlib.Path (it already
            # relies on .exists()) - confirm against get_format_data().
            log(f"Saving progress to {save_path}")
            all_indices = df.index[start_index:end_index].tolist()
            temp_df = pd.DataFrame({'ada_embedding': embeddings}, index=all_indices)
            tmp_path = save_path.with_name(save_path.name + '.tmp')
            temp_df.to_parquet(tmp_path, compression='brotli')
            tmp_path.replace(save_path)
            log("Progress saved")
            pbar.update(1)

            # Delay between batches if specified; nothing follows the last batch
            if delay_per_batch > 0 and i < num_batches - 1:
                time.sleep(delay_per_batch)

    return None

In [5]:
# Parameters
model_name = 'text-embedding-3-small'
batch_size = 1000
delay_per_batch = 60

# Process articles in batches
ray.init(num_cpus=16, ignore_reinit_error=True)
try:
    start_time = time.time()
    process_articles_in_batches(wsj_multiple_token, 'body_txt', model_name, batch_size, delay_per_batch)
    elapsed_time = time.time() - start_time
    print(f"Total time to get all embeddings: {round(elapsed_time)} seconds")
finally:
    # Shutdown Ray even if the run fails or is interrupted (e.g. KeyboardInterrupt),
    # so worker processes are not left running in the background
    ray.shutdown()

2024-02-26 14:56:12,216	INFO worker.py:1673 -- Started a local Ray instance.


Number of batches: 832


Processing batches:   0%|                                                                                                                                                                                                                        | 0/832 [00:00<?, ?it/s]

Processing batch: 1/832
Saving progress to C:\Jonathan\QuantResearch\AlgoTradingModels\fenui\data\format\openai\wsj_emb_textemb3small_0.parquet.brotli...


Processing batches:   0%|▏                                                                                                                                                                                                           | 1/832 [02:47<38:33:56, 167.07s/it]

Progress saved
Processing batch: 2/832


Processing batches:   0%|▏                                                                                                                                                                                                           | 1/832 [03:04<42:36:54, 184.61s/it]

KeyboardInterrupt

