### Import Packages

In [7]:
import pandas as pd
import os
import time
import requests 
import tiktoken
import ray
import numpy as np
import json

from math import ceil

# In order for ray to work, make sure you uninstall pydantic and reinstall this: pip install "pydantic<2"
from openai import OpenAI
from class_data.data import Data
from utils.system import *

import warnings
warnings.filterwarnings('ignore')

### Data

In [8]:
# wsj_multiple = Data(folder_path=get_format_data() / 'art', file_pattern='wsj_art_*')
# wsj_multiple = wsj_multiple.concat_files()

cc_multiple = Data(folder_path=get_format_data() / 'art', file_pattern='cc_art_*')
cc_multiple = cc_multiple.concat_files()

Loading Data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:42<00:00,  1.18it/s]


### Parallelized: Get number of tokens (per article)

In [10]:
@ray.remote
def get_token_count(article_text, encoding_param):
    encoding = tiktoken.get_encoding(encoding_param)
    token_count = len(encoding.encode(article_text))
    return token_count

def process_tokens_in_batches(df, column_name, encoding_param, batch_size):
    num_batches = np.ceil(len(df) / batch_size)
    all_token_counts = []
    print(f"Number of batches: {int(num_batches)}")

    for i in range(int(num_batches)):
        print(f"Processing batch: {i + 1}/{int(num_batches)}")
        start_index = i * batch_size
        end_index = start_index + batch_size
        batch = df[column_name][start_index:end_index]

        # Start asynchronous tasks for the batch
        futures = [get_token_count.remote(text, encoding_param) for text in batch]
        token_counts = ray.get(futures)
        all_token_counts.extend(token_counts)

    # Assign the token counts back to the DataFrame
    df['n_tokens'] = all_token_counts
    return df

In [11]:
embedding_encoding = "cl100k_base" 
max_tokens = 8191

In [13]:
batch_size = 5000

# # Process articles in batches for WSJ
# ray.init(num_cpus=16, ignore_reinit_error=True)
# start_time = time.time()
# wsj_multiple = process_tokens_in_batches(wsj_multiple, 'body_txt', embedding_encoding, batch_size)
# elapsed_time = time.time() - start_time
# print(f"Total time to get all tokens: {round(elapsed_time)} seconds")

# Process articles in batches for CC
ray.init(num_cpus=16, ignore_reinit_error=True)
start_time = time.time()
cc_multiple = process_tokens_in_batches(cc_multiple, 'body', embedding_encoding, batch_size)
elapsed_time = time.time() - start_time
print(f"Total time to get all tokens: {round(elapsed_time)} seconds")

# Shutdown Ray
ray.shutdown()

2024-04-05 21:04:20,072	INFO worker.py:1673 -- Started a local Ray instance.


Number of batches: 166
Processing batch: 1/166
Processing batch: 2/166
Processing batch: 3/166
Processing batch: 4/166
Processing batch: 5/166
Processing batch: 6/166
Processing batch: 7/166
Processing batch: 8/166
Processing batch: 9/166
Processing batch: 10/166
Processing batch: 11/166
Processing batch: 12/166
Processing batch: 13/166
Processing batch: 14/166
Processing batch: 15/166
Processing batch: 16/166
Processing batch: 17/166
Processing batch: 18/166
Processing batch: 19/166
Processing batch: 20/166
Processing batch: 21/166
Processing batch: 22/166
Processing batch: 23/166
Processing batch: 24/166
Processing batch: 25/166
Processing batch: 26/166
Processing batch: 27/166
Processing batch: 28/166
Processing batch: 29/166
Processing batch: 30/166
Processing batch: 31/166
Processing batch: 32/166
Processing batch: 33/166
Processing batch: 34/166
Processing batch: 35/166
Processing batch: 36/166
Processing batch: 37/166
Processing batch: 38/166
Processing batch: 39/166
Processing 

In [16]:
# # Filter WSJ
# print(f"Length before: {len(wsj_multiple)}")
# wsj_multiple = wsj_multiple[wsj_multiple.n_tokens <= max_tokens]
# print(f"Length after: {len(wsj_multiple)}")

# Filter CC
print(f"Length before: {len(cc_multiple)}")
cc_multiple_after = cc_multiple[cc_multiple.n_tokens <= max_tokens]
print(f"Length after: {len(cc_multiple_after)}")

Length before: 827337
Length after: 771897


In [18]:
print(f"Min Tokens: {cc_multiple.n_tokens.min()}")
print(f"Max Tokens: {cc_multiple.n_tokens.max()}")
print(f"Mean Tokens: {cc_multiple.n_tokens.mean()}")
print(f"STD Tokens: {cc_multiple.n_tokens.std()}")

Min Tokens: 0
Max Tokens: 219187
Mean Tokens: 4321.441914238091
STD Tokens: 3026.044563263028


In [None]:
# # Export Data
# chunks = np.array_split(wsj_multiple, 8)
# for i, df in enumerate(chunks, 1):
#     print(i)
#     df.to_parquet(get_format_data() / 'token' / f'wsj_tokens_{i}.parquet.brotli', compression='brotli')

# Export Data
chunks = np.array_split(cc_multiple_after, 50)
for i, df in enumerate(chunks, 1):
    print(i)
    df.to_parquet(get_format_data() / 'token' / f'cc_tokens_{i}.parquet.brotli', compression='brotli')

### Parallelized: Get embeddings (per article)

In [2]:
# Read in token dataset
collect = []
for i in range(1, 9):
    collect.append(pd.read_parquet(get_format_data() / 'token' / f'wsj_tokens_{i}.parquet.brotli'))
wsj_multiple = pd.concat(collect, axis=0)

# Read in token dataset
collect = []
for i in range(1, 9):
    collect.append(pd.read_parquet(get_format_data() / 'token' / f'cc_tokens_{i}.parquet.brotli'))
wsj_multiple = pd.concat(collect, axis=0)

In [3]:
wsj_multiple_token = wsj_multiple.copy(deep=True)

In [4]:
@ray.remote
def get_embedding_article(article_text, model):
    api_key = json.load(open(get_config() / 'api.json'))['openai_api_key']
    client = OpenAI(api_key=api_key)
    embedding = client.embeddings.create(input=[article_text.replace("\n", " ")], model=model).data[0].embedding
    return embedding

def process_articles_in_batches(df, column_name, model, batch_size, delay_per_batch):
    num_batches = np.ceil(len(df) / batch_size)
    times = []
    
    for i in range(int(num_batches)):
        start_time = time.time()
        
        # Check if the batch has already been processed
        save_path = get_format_data() / 'openai' / f'wsj_emb_textemb3small_{i}.parquet.brotli'
        if save_path.exists():
            print(f"Skipping batch {i + 1}/{int(num_batches)} (already processed)")
            continue

        # Get batch
        start_index = i * batch_size
        end_index = min(start_index + batch_size, len(df))
        batch = df[column_name][start_index:end_index]
        
        # Start asynchronous tasks for the batch
        futures = [get_embedding_article.remote(text, model) for text in batch]
        embeddings = ray.get(futures)

        # Save Batch
        print(f"Saving progress to {save_path}")
        all_indices = df.index[start_index:end_index].tolist()
        temp_df = pd.DataFrame({'ada_embedding': embeddings}, index=all_indices)
        temp_df.to_parquet(save_path, compression='brotli')
        print("Progress saved")
        
        # Time taken for the batch
        end_time = time.time()
        batch_time = end_time - start_time
        times.append(batch_time)
        
        # Calculate and print the time taken for the batch
        print(f"Batch {i + 1}/{int(num_batches)} processed in {batch_time:.2f} seconds")
        
        # Calculate and print estimated time to finish
        avg_time_per_batch = np.mean(times)
        batches_left = int(num_batches) - (i + 1)
        estimated_time_left = avg_time_per_batch * batches_left
        hours, rem = divmod(estimated_time_left, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Estimated time to finish: {int(hours)}h {int(minutes)}m {seconds:.2f}s")
        
        # Delay between batches if specified
        if delay_per_batch > 0:
            time.sleep(delay_per_batch)

    return None

In [5]:
# Parameters
model_name = 'text-embedding-3-small'
batch_size = 1000
delay_per_batch = 0

# Process articles in batches
ray.init(num_cpus=16, ignore_reinit_error=True)

start_time = time.time()
process_articles_in_batches(wsj_multiple_token, 'body_txt', model_name, batch_size, delay_per_batch)
elapsed_time = time.time() - start_time
print(f"Total time to get all embeddings: {round(elapsed_time)} seconds")

# Shutdown Ray
ray.shutdown()

2024-03-03 02:20:52,436	INFO worker.py:1673 -- Started a local Ray instance.


Skipping batch 1/832 (already processed)
Skipping batch 2/832 (already processed)
Skipping batch 3/832 (already processed)
Skipping batch 4/832 (already processed)
Skipping batch 5/832 (already processed)
Skipping batch 6/832 (already processed)
Skipping batch 7/832 (already processed)
Skipping batch 8/832 (already processed)
Skipping batch 9/832 (already processed)
Skipping batch 10/832 (already processed)
Skipping batch 11/832 (already processed)
Skipping batch 12/832 (already processed)
Skipping batch 13/832 (already processed)
Skipping batch 14/832 (already processed)
Skipping batch 15/832 (already processed)
Skipping batch 16/832 (already processed)
Skipping batch 17/832 (already processed)
Skipping batch 18/832 (already processed)
Skipping batch 19/832 (already processed)
Skipping batch 20/832 (already processed)
Skipping batch 21/832 (already processed)
Skipping batch 22/832 (already processed)
Skipping batch 23/832 (already processed)
Skipping batch 24/832 (already processed)
S