# bootstrap_words_02_add_frequency_embedding

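- read the word-list CSV into a Spark DataFrame
- fetch word frequencies from the API -> joblib dict
- compute word embeddings with BERT -> joblib dict (saved together with the frequencies in a single joblib file)
- load the frequencies and embeddings and add them to the Spark DataFrame
- no "date added" column; the `version` column is used instead
- save as the Delta table bronze.words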


In [None]:
%run "./00_setup.ipynb"

In [None]:
import os
import joblib
from pyspark.sql import functions as F
from pyspark.sql.types import *
from src.constants import WORDLIST_PATH, WORDLIST_TEMP_CSV_FILENAME, WORD_DATA_FILENAME
from src.ngramsutils import get_word_frequencies_threaded
from src.embeddingutils import get_word_embeddings
###
import numpy as np
import os
from src.wordutils import get_letter_set, filter_wordlist
from src.fileutils import word_file_to_set, get_local_path
from src.constants import (WORDS_PKL_FILENAME,
                           WORDS_PARQUET_FILENAME,
                           NGRAMS_API_BASE,
                           NGRAMS_BATCH_SIZE)
from src.ngramsutils import get_word_frequencies

In [None]:
def _checkpoint_path_for(output_path: str) -> str:
    """Return the checkpoint file path that sits alongside ``output_path``.

    Only the final extension is rewritten (``data.joblib`` ->
    ``data_checkpoint.joblib``). A blind ``str.replace('.joblib', ...)`` returns
    ``output_path`` unchanged when it contains no ``.joblib`` substring, which
    would make the checkpoint and the final output the same file -- and the
    checkpoint cleanup would then delete the freshly written results.
    """
    root, ext = os.path.splitext(output_path)
    return f"{root}_checkpoint{ext or '.joblib'}"


def _atomic_joblib_dump(obj, path: str) -> None:
    """Serialize ``obj`` to ``path`` with joblib via a temp file and a rename.

    If the process dies mid-write, the previous file at ``path`` is left intact
    instead of a truncated, unloadable one.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'wb') as f:
        joblib.dump(obj, f)
    # os.replace overwrites the destination; the rename is atomic on POSIX.
    os.replace(tmp_path, path)


def collect_word_data_with_checkpoint(words_list: list,
                                      output_path: str,
                                      batch_size: int = 100,
                                      resume_job: bool = False):
    """
    Collect frequency and embedding data for a list of words with checkpointing.

    Words are processed in batches of ``batch_size``. After every batch, the
    accumulated results are written to a checkpoint file next to
    ``output_path`` so an interrupted run can be resumed. When all batches
    finish, the results are written to ``output_path`` with joblib and the
    checkpoint is removed.

    Args:
        words_list: List of words to process.
        output_path: Path of the final joblib file holding all data.
        batch_size: Number of words per batch (default 100 for API limits).
        resume_job: If True, continue from an existing checkpoint (or start
            over if none exists) instead of returning an existing final output.

    Returns:
        dict: When freshly computed, contains 'freq_dict', 'embeddings_dict',
        'processed_words', 'total_words' and 'metadata'. When ``output_path``
        already exists and ``resume_job`` is False, it is whatever that file
        contains.

    Raises:
        ValueError: If resuming from a checkpoint written with a different
            ``batch_size`` (its batch index would point at the wrong words).
        Exception: Any error from the frequency/embedding helpers is re-raised
            after a progress message; the last good checkpoint is preserved.
    """
    checkpoint_path = _checkpoint_path_for(output_path)

    # Check for existing final output
    if os.path.exists(output_path) and not resume_job:
        print(f"Final output already exists at {output_path}")
        with open(output_path, 'rb') as f:
            cached = joblib.load(f)
        # The cache is keyed only by path, so a file left by an earlier run over a
        # different word list (e.g. a test subset) would otherwise be returned silently.
        cached_words = set(cached.get('processed_words', cached.get('freq_dict', {})))
        missing_count = sum(1 for word in words_list if word not in cached_words)
        if missing_count:
            print(f"⚠️ {missing_count} of {len(words_list)} requested words are not in "
                  f"the cached output; delete {output_path} to reprocess them.")
        return cached

    # Initialize or load checkpoint
    freq_dict = {}
    embeddings_dict = {}
    start_batch = 0

    if resume_job and os.path.exists(checkpoint_path):
        with open(checkpoint_path, 'rb') as f:
            checkpoint_data = joblib.load(f)
        # 'last_batch' is a batch index, so it is only meaningful for the batch size
        # that produced it. Older checkpoints lack 'batch_size'; trust them as before.
        saved_batch_size = checkpoint_data.get('batch_size', batch_size)
        if saved_batch_size != batch_size:
            raise ValueError(
                f"Checkpoint {checkpoint_path} was written with batch_size="
                f"{saved_batch_size}, but batch_size={batch_size} was requested."
            )
        freq_dict = checkpoint_data.get('freq_dict', {})
        embeddings_dict = checkpoint_data.get('embeddings_dict', {})
        start_batch = checkpoint_data.get('last_batch', 0) + 1
        print(f"Resuming from batch {start_batch} with {len(freq_dict)} words already processed")

    total_batches = (len(words_list) + batch_size - 1) // batch_size  # ceil division

    # Initialised so the error handler can always report a batch.
    batch_idx = start_batch
    try:
        # Process remaining batches
        for batch_idx in range(start_batch, total_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(words_list))
            batch_words = words_list[start_idx:end_idx]

            print(f"Processing batch {batch_idx + 1}/{total_batches}: words {start_idx}-{end_idx-1}")

            # Get frequencies for this batch
            batch_freq_dict = get_word_frequencies_threaded(batch_words, max_workers=10)
            freq_dict.update(batch_freq_dict)

            # Get embeddings for this batch
            batch_embeddings_dict = get_word_embeddings(batch_words)
            embeddings_dict.update(batch_embeddings_dict)

            # Save checkpoint after each batch (atomically, so a crash here cannot
            # corrupt the previous good checkpoint).
            _atomic_joblib_dump(
                {
                    'freq_dict': freq_dict,
                    'embeddings_dict': embeddings_dict,
                    'last_batch': batch_idx,
                    'total_batches': total_batches,
                    'batch_size': batch_size,
                },
                checkpoint_path,
            )

            print(f"✅ Completed batch {batch_idx + 1}/{total_batches}")

    except Exception:
        # 1-based to match the progress messages; resume_job=True restarts at this batch.
        print(f"Exception occurred in batch {batch_idx + 1}/{total_batches}! "
              f"Re-run with resume_job=True to resume from this batch.")
        print(f"Current progress: {len(freq_dict)} words processed")
        raise

    # Save final output
    final_data = {
        'freq_dict': freq_dict,
        'embeddings_dict': embeddings_dict,
        'processed_words': list(freq_dict.keys()),
        'total_words': len(freq_dict),
        'metadata': {
            'batch_size': batch_size,
            'total_batches': total_batches
        }
    }

    _atomic_joblib_dump(final_data, output_path)

    print(f"✅ Processing complete! Saved {len(freq_dict)} words to {output_path}")

    # Clean up checkpoint (never the same file as output_path; see _checkpoint_path_for).
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
        print("✅ Cleaned up checkpoint file")

    return final_data

In [None]:
# Load the temporary word-list CSV into Spark with an explicit
# (word, letter_set, version) schema, then bring the "word" column back to the
# driver as a plain Python list for the frequency/embedding collection step.
initial_schema = StructType(
    [
        StructField(column_name, column_type, True)
        for column_name, column_type in (
            ("word", StringType()),
            ("letter_set", StringType()),
            ("version", IntegerType()),
        )
    ]
)

temp_path = get_local_path(f"{WORDLIST_PATH}/{WORDLIST_TEMP_CSV_FILENAME}")
spark_df = spark.read.csv(temp_path, header=True, schema=initial_schema)

# Each collected Row holds exactly one field, so it unpacks like a 1-tuple.
word_rows = spark_df.select("word").collect()
words_list = [word for (word,) in word_rows]

word_data_path = get_local_path(f"{WORDLIST_PATH}/{WORD_DATA_FILENAME}")

print(f"Read in {len(words_list)} words from csv.")

In [None]:
# TODO: Testing purposes only -- remove later
# Deterministic sample: every 350th word from index 245, stopping before index 111111.
# NOTE(review): this writes to word_data_path, the same file the full run below
# uses. With resume_job=False an existing output file is loaded and returned
# instead of being recomputed, so delete it before doing the full run.
TEST_SAMPLE = slice(245, 111111, 350)
test_words_list = words_list[TEST_SAMPLE]
print(f"testing on {len(test_words_list)} words...")
word_data = collect_word_data_with_checkpoint(
    words_list=test_words_list,
    output_path=word_data_path,
    batch_size=100,
    resume_job=False,
)
### END TODO

# Full run over every word -- first attempt uses resume_job=False:
# word_data = collect_word_data_with_checkpoint(
#     words_list,
#     word_data_path,
#     batch_size=100,
#     resume_job=False
# )

# If that run crashes, continue from the last saved checkpoint with resume_job=True:
# word_data = collect_word_data_with_checkpoint(
#     words_list,
#     word_data_path,
#     batch_size=100,
#     resume_job=True
# )


In [None]:
# Report where the combined frequency/embedding data was written.
print("✅ Saved embedding and frequency data to {}".format(word_data_path))