### Import Packages

In [26]:
import pandas as pd
import os
import time
import requests 
import tiktoken
import ray
import numpy as np
import json
import torch
import numpy as np

from typing import Dict
from transformers import AutoModel, AutoTokenizer
from math import ceil
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from utils.system import *
from class_data.data import Data

import warnings
warnings.filterwarnings('ignore')

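### Data

Load the five WSJ article shards (`wsj_art_1` … `wsj_art_5`) and concatenate them into a single DataFrame.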

In [None]:
# Read each WSJ article shard (1..5) and stack them row-wise into one frame.
collect = [
    pd.read_parquet(get_format_data() / 'art' / f'wsj_art_{shard}.parquet.brotli')
    for shard in range(1, 6)
]
wsj_multiple = pd.concat(collect, axis=0)

### Parallelized: Get number of tokens (per article)

In [None]:
@ray.remote
def get_token_count(article_text, encoding_param):
    """Ray task: return the number of tokens in ``article_text`` under tiktoken encoding ``encoding_param``.

    ``encode_ordinary`` is used instead of ``encode`` because ``encode`` raises when the
    text contains a special-token string (e.g. ``"<|endoftext|>"``). A single such article
    would otherwise fail the whole batch's ``ray.get``. For all other text the count is
    the same.
    """
    encoding = tiktoken.get_encoding(encoding_param)
    return len(encoding.encode_ordinary(article_text))

def process_tokens_in_batches(df, column_name, encoding_param, batch_size):
    """Count tokens for every row of ``df[column_name]`` with Ray tasks and store them in ``df['n_tokens']``.

    Rows are processed positionally in batches of ``batch_size``. Within a batch, one
    ``get_token_count`` Ray task is launched per text and the results are gathered with
    ``ray.get`` before the next batch starts.

    Args:
        df: DataFrame holding the texts. It is modified in place (a new ``n_tokens`` column).
        column_name: Name of the text column to tokenize.
        encoding_param: tiktoken encoding name passed through to ``get_token_count``.
        batch_size: Number of rows per batch; must be positive.

    Returns:
        The same ``df`` object, with ``n_tokens`` holding one count per row in row order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    texts = df[column_name]
    num_batches = -(-len(df) // batch_size)  # integer ceiling division
    all_token_counts = []
    print(f"Number of batches: {num_batches}")

    for i in range(num_batches):
        print(f"Processing batch: {i + 1}/{num_batches}")
        start_index = i * batch_size
        # Slice by position: the frame may come from concatenated shards whose index labels repeat.
        batch = texts.iloc[start_index:start_index + batch_size]

        # Start asynchronous tasks for the batch and wait for all of them
        futures = [get_token_count.remote(text, encoding_param) for text in batch]
        all_token_counts.extend(ray.get(futures))

    # Counts are collected in row order, so a plain list assigns them positionally.
    df['n_tokens'] = all_token_counts
    return df

In [None]:
# tiktoken encoding used to count tokens per article, and the per-article ceiling
# applied by the filter cell further down (articles above it are dropped).
# NOTE(review): these counts come from tiktoken, not from the tokenizer of the
# SentenceTransformer model used for embeddings later; confirm this is the limit that matters.
embedding_encoding = "cl100k_base"
max_tokens = 8_000

In [None]:
batch_size = 5000

# Process articles in batches
ray.init(num_cpus=16, ignore_reinit_error=True)
try:
    start_time = time.time()
    wsj_multiple = process_tokens_in_batches(wsj_multiple, 'body_txt', embedding_encoding, batch_size)
    elapsed_time = time.time() - start_time
    print(f"Total time to get all tokens: {round(elapsed_time)} seconds")
finally:
    # Shut Ray down even if the run fails or is interrupted. Otherwise the next
    # ray.init(..., ignore_reinit_error=True) just reuses the still-running instance.
    ray.shutdown()

In [None]:
# Drop articles above the token ceiling, reporting the row count before and after.
print(f"Length before: {len(wsj_multiple)}")
keep_mask = wsj_multiple['n_tokens'] <= max_tokens
wsj_multiple = wsj_multiple.loc[keep_mask]
print(f"Length after: {len(wsj_multiple)}")

In [None]:
# Export Data
# Write the frame as 8 contiguous, near-equal parquet chunks. The sizes match
# np.array_split: the first `len % 8` chunks get one extra row. Slicing with .iloc
# avoids calling np.array_split on a DataFrame, which goes through the deprecated
# DataFrame.swapaxes (its FutureWarning is hidden by the warnings filter at the top).
n_chunks = 8
token_dir = get_format_data() / 'token'
token_dir.mkdir(parents=True, exist_ok=True)
base_size, extra = divmod(len(wsj_multiple), n_chunks)
chunk_sizes = [base_size + 1] * extra + [base_size] * (n_chunks - extra)
bounds = np.cumsum([0] + chunk_sizes)
for i, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]), 1):
    print(i)
    wsj_multiple.iloc[start:stop].to_parquet(token_dir / f'wsj_tokens_{i}.parquet.brotli', compression='brotli')

### Parallelized: Get embeddings (per sub-batch of articles)

In [37]:
# Load articles
wsj_multiple = Data(folder_path=get_format_data() / 'token', file_pattern='wsj_tokens_*').concat_files()

Loading Data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.82it/s]


In [38]:
# Independent deep copy for the embedding step (DataFrame.copy is deep by default).
wsj_multiple_token = wsj_multiple.copy()

In [41]:
@ray.remote
def get_embedding_article(articles, model):
    """Ray task: embed a list of article texts via ``model.encode`` and return its result unchanged."""
    return model.encode(articles)

def process_articles_in_batches(df, column_name, model, batch_size, article_size, delay_per_batch):
    """Embed ``df[column_name]`` batch by batch with Ray, checkpointing every batch to parquet.

    Rows are taken positionally in batches of ``batch_size``. Each batch is split into
    sub-batches of ``article_size`` texts, and every sub-batch becomes one
    ``get_embedding_article`` Ray task. A batch's embeddings are written to
    ``get_format_data() / 'mxbai' / f'wsj_emb_mxbaiembedlargev1_{i}.parquet.brotli'``
    as a single ``embedding`` column indexed by the batch's labels from ``df.index``.

    Resuming: a batch whose file already exists is skipped. Files are keyed only by batch
    number, so each existing file is checked against the rows it should contain. A mismatch
    (e.g. ``batch_size`` or row order changed since the file was written) raises instead of
    silently leaving rows un-embedded.

    Args:
        df: DataFrame with the texts.
        column_name: Name of the text column to embed.
        model: Object with an ``encode(list_of_texts)`` method. It is placed in Ray's
            object store once and shared by all tasks.
        batch_size: Rows per checkpoint file; must be positive.
        article_size: Texts per Ray task; must be positive.
        delay_per_batch: Seconds to sleep after each processed batch (0 disables).

    Returns:
        None. Results are only written to disk.

    Raises:
        ValueError: On non-positive sizes, or when an existing checkpoint holds different rows.
    """
    if batch_size <= 0 or article_size <= 0:
        raise ValueError("batch_size and article_size must be positive integers")

    num_batches = -(-len(df) // batch_size)  # integer ceiling division
    texts = df[column_name]
    times = []
    # Put the model into the object store lazily and only once. Passing the model object
    # itself to .remote() would re-serialize it for every task.
    model_ref = None

    for i in range(num_batches):
        start_time = time.time()
        start_index = i * batch_size
        end_index = min(start_index + batch_size, len(df))
        batch_labels = df.index[start_index:end_index].tolist()

        # Check if the batch has already been processed (and really holds these rows)
        save_path = get_format_data() / 'mxbai' / f'wsj_emb_mxbaiembedlargev1_{i}.parquet.brotli'
        if save_path.exists():
            if pd.read_parquet(save_path).index.tolist() != batch_labels:
                raise ValueError(
                    f"{save_path} exists but does not contain rows {start_index}-{end_index - 1}; "
                    "it was probably written with a different batch_size. "
                    "Restore the original batch_size or clear the output folder."
                )
            print(f"Skipping batch {i + 1}/{num_batches} (already processed)")
            continue

        if model_ref is None:
            model_ref = ray.put(model)

        # Get batch (positional) and group texts into sub-batches of size article_size
        batch_texts = texts.iloc[start_index:end_index]
        sub_batches = [
            batch_texts.iloc[j:j + article_size].tolist()
            for j in range(0, len(batch_texts), article_size)
        ]

        # One Ray task per sub-batch; Ray resolves the top-level ObjectRef to the model.
        futures = [get_embedding_article.remote(sub_batch, model_ref) for sub_batch in sub_batches]
        embeddings_lists = ray.get(futures)

        # Flatten to one plain list per article
        embeddings_formatted = [embedding.tolist() for sublist in embeddings_lists for embedding in sublist]

        # Save Batch
        print(f"Saving progress to {save_path}")
        save_path.parent.mkdir(parents=True, exist_ok=True)
        embeddings_df = pd.DataFrame({'embedding': embeddings_formatted}, index=batch_labels)
        embeddings_df.to_parquet(save_path, compression='brotli')
        print("Progress saved")

        # Time taken for the batch
        batch_time = time.time() - start_time
        times.append(batch_time)
        print(f"Batch {i + 1}/{num_batches} processed in {batch_time:.2f} seconds")

        # Estimated time to finish, from the mean time of batches processed in this run
        estimated_time_left = np.mean(times) * (num_batches - (i + 1))
        hours, rem = divmod(estimated_time_left, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Estimated time to finish: {int(hours)}h {int(minutes)}m {seconds:.2f}s")

        if delay_per_batch > 0:
            time.sleep(delay_per_batch)

    return None

In [42]:
# Parameters
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
batch_size = 5
delay_per_batch = 0
article_size = 10

# Process articles in batches
ray.init(num_cpus=16, ignore_reinit_error=True)

start_time = time.time()
process_articles_in_batches(wsj_multiple_token, 'body_txt', model, batch_size, article_size, delay_per_batch)
elapsed_time = time.time() - start_time
print(f"Total time to get all embeddings: {round(elapsed_time)} seconds")

# Shutdown Ray
ray.shutdown()

2024-04-02 22:45:03,712	INFO worker.py:1507 -- Calling ray.init() again after it has already been called.


Skipping batch 1/166216 (already processed)
Skipping batch 2/166216 (already processed)
Saving progress to C:\Jonathan\QuantResearch\AlgoTradingModels\fenui\data\format\mxbai\wsj_emb_mxbaiembedlargev1_2.parquet.brotli
Progress saved
Batch 3/166216 processed in 27.63 seconds
Estimated time to finish: 1275h 48m 52.31s



KeyboardInterrupt



In [43]:
# Load articles
# Reload every saved embedding batch into one frame for inspection.
emb_dir = get_format_data() / 'mxbai'
test = Data(folder_path=emb_dir, file_pattern='wsj_emb_mxbaiembedlargev1_*').concat_files()

Loading Data: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 244.87it/s]
