### Import Packages

In [1]:
import pandas as pd
import logging 
import os
import time
import requests 
import tiktoken
import ray
import numpy as np
import json
import openai

from math import ceil

# In order for ray to work, make sure you uninstall pydantic and reinstall this: pip install "pydantic<2"
from class_data.data import Data
from utils.system import *

# Set OPENAI KEY
# Read the key inside a context manager so the config file handle is closed
# as soon as the JSON has been parsed (the bare open() call never closed it).
with open(get_config() / 'api.json') as api_config_file:
    os.environ["OPENAI_API_KEY"] = json.load(api_config_file)['openai_api_key']

import warnings
# Suppress every warning category for the remainder of the session.
# NOTE(review): this also hides deprecation warnings raised by libraries —
# consider narrowing the filter to the specific warnings that are noisy.
warnings.filterwarnings('ignore')

### Data

In [None]:
# Collect every raw WSJ article file matching the pattern and merge them with
# Data.concat_files(); later cells use the result as a pandas DataFrame.
wsj_art_folder = get_format_data() / 'art'
wsj_multiple = Data(folder_path=wsj_art_folder, file_pattern='wsj_art_*').concat_files()

### Parallelized: Get number of tokens (per article)

In [6]:
@ray.remote
def get_token_count(article_text, encoding_param):
    """Ray task: return how many tokens ``article_text`` has under a tiktoken encoding.

    Args:
        article_text: Text to tokenize.
        encoding_param: Name of the tiktoken encoding to look up (e.g. "cl100k_base").

    Returns:
        Length of the encoded token sequence.
    """
    tokens = tiktoken.get_encoding(encoding_param).encode(article_text)
    return len(tokens)

def process_tokens_in_batches(df, column_name, encoding_param, batch_size):
    """Count tokens for every row of ``df[column_name]`` using Ray, one batch at a time.

    Every text in a batch is submitted as its own ``get_token_count`` Ray task and the
    whole batch is awaited before the next one is submitted, which bounds the number of
    tasks in flight at once.

    Args:
        df: DataFrame holding the texts. It is modified in place: an ``n_tokens``
            column is added (or overwritten).
        column_name: Name of the column containing the text to tokenize.
        encoding_param: tiktoken encoding name forwarded to ``get_token_count``.
        batch_size: Number of rows submitted to Ray per batch; must be positive.

    Returns:
        The same ``df`` object, now carrying the ``n_tokens`` column.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    texts = df[column_name]
    total_rows = len(texts)
    # Integer ceiling division: avoids the float returned by np.ceil and repeated int() casts
    num_batches = -(-total_rows // batch_size)
    all_token_counts = []
    print(f"Number of batches: {num_batches}")

    for i in range(num_batches):
        print(f"Processing batch: {i + 1}/{num_batches}")
        start_index = i * batch_size
        end_index = start_index + batch_size
        # .iloc makes the positional slicing explicit regardless of the index labels;
        # a stop past the end is clipped, so the last batch may be shorter.
        batch = texts.iloc[start_index:end_index]

        # Start asynchronous tasks for the batch
        futures = [get_token_count.remote(text, encoding_param) for text in batch]
        all_token_counts.extend(ray.get(futures))

    # Assign the token counts back to the DataFrame (results are in row order)
    df['n_tokens'] = all_token_counts
    return df

In [17]:
embedding_encoding = "cl100k_base"  # tiktoken encoding name used when counting tokens
max_tokens = 8000  # upper token bound: longer articles are dropped by the filter cell
min_tokens = 20  # lower token bound for the article filter

In [8]:
batch_size = 5000

# Process articles in batches for WSJ
ray.init(num_cpus=16, ignore_reinit_error=True)
try:
    start_time = time.time()
    wsj_multiple = process_tokens_in_batches(wsj_multiple, 'body_txt', embedding_encoding, batch_size)
    elapsed_time = time.time() - start_time
    print(f"Total time to get all tokens: {round(elapsed_time)} seconds")
finally:
    # Shutdown Ray even when tokenization raises or the cell is interrupted,
    # so the local Ray instance does not stay alive in the background.
    ray.shutdown()

2024-04-08 14:48:03,662	INFO worker.py:1673 -- Started a local Ray instance.


Number of batches: 166
Processing batch: 1/166
Processing batch: 2/166
Processing batch: 3/166
Processing batch: 4/166
Processing batch: 5/166
Processing batch: 6/166
Processing batch: 7/166
Processing batch: 8/166
Processing batch: 9/166
Processing batch: 10/166
Processing batch: 11/166
Processing batch: 12/166
Processing batch: 13/166
Processing batch: 14/166
Processing batch: 15/166
Processing batch: 16/166
Processing batch: 17/166
Processing batch: 18/166
Processing batch: 19/166
Processing batch: 20/166
Processing batch: 21/166
Processing batch: 22/166
Processing batch: 23/166
Processing batch: 24/166
Processing batch: 25/166
Processing batch: 26/166
Processing batch: 27/166
Processing batch: 28/166
Processing batch: 29/166
Processing batch: 30/166
Processing batch: 31/166
Processing batch: 32/166
Processing batch: 33/166
Processing batch: 34/166
Processing batch: 35/166
Processing batch: 36/166
Processing batch: 37/166
Processing batch: 38/166
Processing batch: 39/166
Processing 

In [18]:
# Filter WSJ
# Keep only articles whose token count lies in [min_tokens, max_tokens], both bounds
# inclusive. The lower bound reads the min_tokens constant instead of a hard-coded 20
# so the filter and the configuration cannot drift apart.
print(f"Length before: {len(wsj_multiple)}")
wsj_multiple_after = wsj_multiple.loc[wsj_multiple.n_tokens.between(min_tokens, max_tokens)]
print(f"Length after: {len(wsj_multiple_after)}")

Length before: 766190
Length after: 765824


In [19]:
# Export Data
# Split row positions rather than handing the DataFrame itself to np.array_split
# (which emits a deprecation warning on recent pandas versions). Chunk sizes are
# the same: np.array_split puts the remainder rows in the leading chunks.
n_chunks = 8
token_folder = get_format_data() / 'token'
os.makedirs(token_folder, exist_ok=True)  # create the output folder on a fresh checkout
row_groups = np.array_split(np.arange(len(wsj_multiple_after)), n_chunks)
for i, rows in enumerate(row_groups, 1):
    print(i)
    chunk = wsj_multiple_after.iloc[rows]
    chunk.to_parquet(token_folder / f'wsj_tokens_{i}.parquet.brotli', compression='brotli')

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50


### Parallelized: Get embeddings (per article)

In [2]:
# Read in token dataset
# Load every exported wsj_tokens_* file and merge them via Data.concat_files().
token_folder_in = get_format_data() / 'token'
wsj_multiple = Data(folder_path=token_folder_in, file_pattern='wsj_tokens_*').concat_files()

Loading Data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:24<00:00,  2.04it/s]


In [3]:
@ray.remote
def get_embedding_article(article_text, model):
    """Ray task: request an OpenAI embedding for a single article.

    Newlines in the text are replaced with spaces before the request is sent.

    Args:
        article_text: Article body to embed.
        model: Name of the OpenAI embedding model.

    Returns:
        The embedding from the first item of the response, or ``None`` when any
        exception is raised (the error is logged rather than propagated).
    """
    try:
        flattened_text = article_text.replace("\n", " ")
        response = openai.embeddings.create(input=[flattened_text], model=model)
        vector = response.data[0].embedding
    except Exception as err:
        logging.error(f"An error occurred: {str(err)}")
        return None
    return vector

def process_articles_in_batches(output_folder, file_prefix, df, column_name, model, batch_size, delay_per_batch):
    """Embed ``df[column_name]`` batch by batch with Ray, checkpointing each batch to parquet.

    For batch ``i`` the embeddings are written to
    ``output_folder / f'{file_prefix}_{i}.parquet.brotli'``. If that file already exists
    the batch is skipped, so an interrupted run can be resumed by calling the function
    again with the same arguments.

    Args:
        output_folder: Path-like folder (must support the ``/`` operator) receiving the
            per-batch files. It is created if missing.
        file_prefix: File name prefix for the per-batch parquet files.
        df: DataFrame holding the texts; its index is reused as the index of each saved batch.
        column_name: Name of the column containing the text to embed.
        model: OpenAI embedding model name forwarded to ``get_embedding_article``.
        batch_size: Number of rows embedded and saved per batch; must be positive.
        delay_per_batch: Seconds to sleep after each processed batch (no sleep when <= 0).

    Returns:
        None. Results are only written to disk.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Make sure the checkpoint folder exists before the first batch is written
    os.makedirs(output_folder, exist_ok=True)

    texts = df[column_name]
    total_rows = len(df)
    # Integer ceiling division: avoids the float returned by np.ceil and repeated int() casts
    num_batches = -(-total_rows // batch_size)
    times = []

    for i in range(num_batches):
        start_time = time.time()

        # Check if the batch has already been processed
        save_path = output_folder / f'{file_prefix}_{i}.parquet.brotli'
        if save_path.exists():
            print(f"Skipping batch {i + 1}/{num_batches} (already processed)")
            continue

        # Get batch (explicit positional slicing, independent of the index labels)
        start_index = i * batch_size
        end_index = min(start_index + batch_size, total_rows)
        batch = texts.iloc[start_index:end_index]

        # Start asynchronous tasks for the batch
        futures = [get_embedding_article.remote(text, model) for text in batch]
        embeddings = ray.get(futures)

        # Failed requests come back as None and are persisted as nulls; because the
        # checkpoint file then exists, they are not retried on a later run. Surface
        # the count so the affected batch file can be deleted and recomputed.
        failed_count = sum(embedding is None for embedding in embeddings)
        if failed_count:
            logging.warning(
                "Batch %d/%d: %d of %d embeddings failed and are stored as null in %s",
                i + 1, num_batches, failed_count, len(embeddings), save_path,
            )

        # Save Batch
        print(f"Saving progress to {save_path}")
        all_indices = df.index[start_index:end_index]
        temp_df = pd.DataFrame({'ada_embedding': embeddings}, index=all_indices)
        temp_df.to_parquet(save_path, compression='brotli')
        print("Progress saved")

        # Delay between batches if specified
        if delay_per_batch > 0:
            time.sleep(delay_per_batch)

        # Time taken for the batch (includes saving and the optional delay)
        batch_time = time.time() - start_time
        times.append(batch_time)

        # Calculate and print the time taken for the batch
        print(f"Batch {i + 1}/{num_batches} processed in {batch_time:.2f} seconds")

        # Calculate and print estimated time to finish, based on the mean duration
        # of the batches processed in this call
        avg_time_per_batch = np.mean(times)
        batches_left = num_batches - (i + 1)
        estimated_time_left = avg_time_per_batch * batches_left
        hours, rem = divmod(estimated_time_left, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Estimated time to finish: {int(hours)}h {int(minutes)}m {seconds:.2f}s")

    return None

In [None]:
# Parameters
output_folder = get_format_data() / 'openai'  # folder receiving one parquet file per batch
# NOTE(review): the prefix reads as 'cc' although the frame passed to the embedding run is
# wsj_multiple — confirm this is intended. Batch files that already exist under this prefix
# are skipped rather than recomputed.
file_prefix = 'cc_emb_textemb3small'
model_name = 'text-embedding-3-small'  # OpenAI embedding model name sent with each request
batch_size = 1000  # number of articles embedded and saved per batch
delay_per_batch = 0  # seconds to sleep after each batch; 0 disables the pause

In [7]:
# Process articles in batches
ray.init(num_cpus=16, ignore_reinit_error=True)
try:
    start_time = time.time()
    process_articles_in_batches(
        output_folder=output_folder,
        file_prefix=file_prefix,
        df=wsj_multiple,
        column_name='body_txt',
        model=model_name,
        batch_size=batch_size,
        delay_per_batch=delay_per_batch,
    )
    elapsed_time = time.time() - start_time
    print(f"Total time to get all embeddings: {round(elapsed_time)} seconds")
finally:
    # Shutdown Ray even when the run raises or is interrupted, so the local
    # Ray instance does not stay alive in the background.
    ray.shutdown()

2024-04-09 10:13:08,875	INFO worker.py:1673 -- Started a local Ray instance.


Skipping batch 1/766 (already processed)
Skipping batch 2/766 (already processed)
Skipping batch 3/766 (already processed)
Skipping batch 4/766 (already processed)
Skipping batch 5/766 (already processed)
Skipping batch 6/766 (already processed)
Skipping batch 7/766 (already processed)
Skipping batch 8/766 (already processed)
Skipping batch 9/766 (already processed)
Skipping batch 10/766 (already processed)
Skipping batch 11/766 (already processed)
Skipping batch 12/766 (already processed)
Skipping batch 13/766 (already processed)
Skipping batch 14/766 (already processed)
Skipping batch 15/766 (already processed)
Skipping batch 16/766 (already processed)
Skipping batch 17/766 (already processed)
Skipping batch 18/766 (already processed)
Skipping batch 19/766 (already processed)
Skipping batch 20/766 (already processed)
Skipping batch 21/766 (already processed)
Skipping batch 22/766 (already processed)
Skipping batch 23/766 (already processed)
Skipping batch 24/766 (already processed)
S