In [1]:
import tiktoken
import os
import pandas as pd
from time import sleep
from dotenv import load_dotenv
from pathlib import Path
from openai import OpenAI, RateLimitError

In [2]:
# All artifacts produced below are written under ./output; create it once up front.
output_dir = Path("output")
output_dir.mkdir(exist_ok=True)  # no error if the directory is already there

In [3]:
# Training split. A fixed random_state keeps the previewed rows identical across
# re-runs, so the saved notebook output is reproducible.
df_train = pd.read_parquet('data/train.parquet')
df_train.sample(5, random_state=42)

Unnamed: 0,text,label
13677,Ed Harris's work in this film is up to his usu...,1
3665,It must be assumed that those who praised this...,0
21177,"One used to say, concerning Nathaniel Hawthorn...",1
15133,Why has this not been released? I kind of thou...,1
6597,If you appreciate the renaissance in Asian hor...,0


In [4]:
# Test split. A fixed random_state keeps the previewed rows identical across
# re-runs, so the saved notebook output is reproducible.
df_test = pd.read_parquet('data/test.parquet')
df_test.sample(5, random_state=42)

Unnamed: 0,text,label
19040,Shame is rather unique as a war film (or rathe...,1
12470,This movie was terrible not only was the plot ...,0
7434,Herman has made northern drama his own with Li...,0
10653,Part of the movie's low rating is the emphasis...,0
11662,CHE! is a bad movie and deserves it reputation...,0


In [5]:
# Estimate the bill before calling the API: count tokens with tiktoken and
# multiply by the per-token price.
text_emb_3_large_cost = 0.13 / 1_000_000  # 0.13 USD per 1M tokens
ENCODING = tiktoken.get_encoding("cl100k_base")
emb_column = "text"  # column whose contents are embedded (add_embeddings reads it too)

# Each split is tokenized as one string: all rows joined with single spaces.
# This gives an estimate of the per-row token total, not an exact billing figure.
encoded_text_train, encoded_text_test = [
    ENCODING.encode(frame[emb_column].str.cat(sep=" "))
    for frame in (df_train, df_test)
]
total_tokens = len(encoded_text_train) + len(encoded_text_test)
estimated_cost = total_tokens * text_emb_3_large_cost
print(f"This embedding operation will cost a total of {estimated_cost:.2f} USD.")

This embedding operation will cost a total of 1.91 USD.


In [6]:
# Pull settings from a local .env file (if any) into the environment, then
# build the API client. Avoid displaying `key`: it would leak into saved output.
load_dotenv()
key = os.getenv("API_KEY")  # None when API_KEY is not set
client = OpenAI(api_key=key)

In [7]:
def get_embeddings_batch(text, model="text-embedding-3-large"):
    """Embed a batch of inputs with a single ``client.embeddings.create`` request.

    Args:
        text: Passed straight through as ``input=``; in this notebook it is a
            list of strings (one batch of rows).
        model: Name of the embedding model to request.

    Returns:
        A list holding the ``.embedding`` of each item in the response's
        ``data``, in the order the response lists them.
    """
    response = client.embeddings.create(model=model, input=text)
    return [item.embedding for item in response.data]

In [8]:
def _embed_with_retry(embed_fn, batch, max_retries, retry_wait):
    """Return ``embed_fn(batch)``, retrying on RateLimitError; re-raise after ``max_retries`` attempts."""
    for attempt in range(1, max_retries + 1):
        try:
            return embed_fn(batch)
        except RateLimitError:
            if attempt == max_retries:
                # Out of attempts: surface the error rather than silently dropping the batch,
                # which would leave the embedding list shorter than the frame.
                raise
            print(f"Rate limit error, waiting {retry_wait} seconds and trying again (attempt {attempt})")
            sleep(retry_wait)


def add_embeddings(df, batch_size=512, column=None, max_retries=10, retry_wait=15, embed_fn=None):
    """Embed every row of ``df[column]`` in batches and store the vectors in ``df["emb_large"]``.

    Rows are taken by position, ``batch_size`` at a time, so the result lines up with
    ``df`` whatever its index labels are. No request is made for an empty batch, so an
    empty frame or a length that is an exact multiple of ``batch_size`` is handled.

    Args:
        df: DataFrame to embed. It is modified in place (an ``"emb_large"`` column is
            added) and also returned.
        batch_size: Number of rows sent per API request.
        column: Name of the text column; defaults to the notebook-level ``emb_column``.
        max_retries: Total attempts per batch when the API raises RateLimitError.
        retry_wait: Seconds to sleep after each rate-limited attempt.
        embed_fn: Callable mapping a list of strings to a list of vectors; defaults to
            ``get_embeddings_batch``.

    Returns:
        ``df`` with the ``"emb_large"`` column holding one vector per row.

    Raises:
        RateLimitError: if a batch is still rate-limited after ``max_retries`` attempts.
        ValueError: if ``batch_size`` or ``max_retries`` is < 1, or if ``embed_fn``
            returns a different number of vectors than it was given texts.
    """
    # Resolve defaults at call time so later changes to these globals are picked up.
    if column is None:
        column = emb_column
    if embed_fn is None:
        embed_fn = get_embeddings_batch
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    texts = df[column].tolist()
    embeddings = []
    for batch_num, start_ind in enumerate(range(0, len(texts), batch_size), start=1):
        batch = texts[start_ind:start_ind + batch_size]
        end_ind = start_ind + len(batch) - 1  # inclusive, for the progress message
        print(f"Processing records for batch {batch_num} from {start_ind} to {end_ind}")
        batch_embeds = _embed_with_retry(embed_fn, batch, max_retries, retry_wait)
        if len(batch_embeds) != len(batch):
            raise ValueError(
                f"batch {batch_num}: expected {len(batch)} embeddings, got {len(batch_embeds)}"
            )
        embeddings.extend(batch_embeds)

    # Positional assignment: embeddings[i] belongs to the i-th row of df.
    df["emb_large"] = embeddings
    return df

In [9]:
# Embedding the whole split costs money and time, so reuse a previously saved result
# when one exists. Delete output/train_emb.parquet to force recomputation.
# Note: on a cache hit df_train itself gets no "emb_large" column; use df_train_emb.
train_emb_path = Path("output/train_emb.parquet")
if train_emb_path.exists():
    df_train_emb = pd.read_parquet(train_emb_path)
    # Guard against a stale cache produced from a different version of the data.
    assert len(df_train_emb) == len(df_train), (
        "cached train embeddings do not match df_train; delete the cache file and re-run"
    )
else:
    # add_embeddings adds "emb_large" to df_train in place and returns the same frame.
    df_train_emb = add_embeddings(df_train)
    df_train_emb.to_parquet(train_emb_path)

Processing records for batch 1 from 0 to 511
Processing records for batch 2 from 512 to 1023
Processing records for batch 3 from 1024 to 1535
Processing records for batch 4 from 1536 to 2047
Processing records for batch 5 from 2048 to 2559
Processing records for batch 6 from 2560 to 3071
Processing records for batch 7 from 3072 to 3583
Processing records for batch 8 from 3584 to 4095
Processing records for batch 9 from 4096 to 4607
Processing records for batch 10 from 4608 to 5119
Processing records for batch 11 from 5120 to 5631
Processing records for batch 12 from 5632 to 6143
Processing records for batch 13 from 6144 to 6655
Processing records for batch 14 from 6656 to 7167
Processing records for batch 15 from 7168 to 7679
Processing records for batch 16 from 7680 to 8191
Processing records for batch 17 from 8192 to 8703
Processing records for batch 18 from 8704 to 9215
Processing records for batch 19 from 9216 to 9727
Processing records for batch 20 from 9728 to 10239
Processing re

In [10]:
# Embedding the whole split costs money and time, so reuse a previously saved result
# when one exists. Delete output/test_emb.parquet to force recomputation.
# Note: on a cache hit df_test itself gets no "emb_large" column; use df_test_emb.
test_emb_path = Path("output/test_emb.parquet")
if test_emb_path.exists():
    df_test_emb = pd.read_parquet(test_emb_path)
    # Guard against a stale cache produced from a different version of the data.
    assert len(df_test_emb) == len(df_test), (
        "cached test embeddings do not match df_test; delete the cache file and re-run"
    )
else:
    # add_embeddings adds "emb_large" to df_test in place and returns the same frame.
    df_test_emb = add_embeddings(df_test)
    df_test_emb.to_parquet(test_emb_path)

Processing records for batch 1 from 0 to 511
Processing records for batch 2 from 512 to 1023
Processing records for batch 3 from 1024 to 1535
Processing records for batch 4 from 1536 to 2047
Processing records for batch 5 from 2048 to 2559
Processing records for batch 6 from 2560 to 3071
Processing records for batch 7 from 3072 to 3583
Processing records for batch 8 from 3584 to 4095
Processing records for batch 9 from 4096 to 4607
Processing records for batch 10 from 4608 to 5119
Processing records for batch 11 from 5120 to 5631
Processing records for batch 12 from 5632 to 6143
Processing records for batch 13 from 6144 to 6655
Processing records for batch 14 from 6656 to 7167
Processing records for batch 15 from 7168 to 7679
Processing records for batch 16 from 7680 to 8191
Processing records for batch 17 from 8192 to 8703
Processing records for batch 18 from 8704 to 9215
Processing records for batch 19 from 9216 to 9727
Processing records for batch 20 from 9728 to 10239
Processing re