In [2]:
import openai
import os
import tiktoken
import time
from typing import List, Dict, Any
import pandas as pd
import asyncio
import logging
from tabulate import tabulate
import textwrap
import nest_asyncio

# Patch asyncio so an already-running event loop can be re-entered; Jupyter runs its own
# loop, so asyncio.run()/run_until_complete() inside a cell would otherwise fail.
nest_asyncio.apply()

In [3]:
# Read the API key from the environment (never hard-code it in the notebook).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Fail loudly when the key is missing. Raising stops this cell and "Run All";
# `exit(1)` inside a Jupyter kernel does not reliably halt execution, so the
# remaining lines would have run with api_key=None.
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY is not set")

# Module-level key for any code that uses openai's global configuration; the
# explicit client below is what this notebook's embedding calls use.
openai.api_key = OPENAI_API_KEY
client = openai.OpenAI(api_key=OPENAI_API_KEY)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Returns the number of tokens in a text string.

    Args:
        string: Text to tokenize.
        encoding_name: Name of the tiktoken encoding used for counting.

    Returns:
        The number of tokens ``string`` encodes to.
    """
    encoding = tiktoken.get_encoding(encoding_name)
    # disallowed_special=() counts text that happens to contain special-token strings
    # (e.g. "<|endoftext|>") as ordinary text instead of raising ValueError, which
    # would otherwise abort a whole embedding run on one unusual document.
    return len(encoding.encode(string, disallowed_special=()))

In [16]:
def _embed_batch_records(
    batch: List[Dict[str, Any]],
    model: str,
    dimensions: int,
) -> List[Dict[str, Any]]:
    """Embed every record in ``batch`` with one API call; return copies carrying their own vectors.

    Each input record is a dict with ``case_id``, ``case_title`` and ``case_text``; each
    returned dict has the same keys plus ``embeddings``. Raises ``ValueError`` if the
    response does not contain exactly one embedding per input.
    """
    texts = [record['case_text'] for record in batch]
    response = client.embeddings.create(model=model, input=texts, dimensions=dimensions)
    # The response holds one item per input, tagged with that input's position in
    # ``index``; order by it so vector i is always paired with texts[i].
    vectors = [item.embedding for item in sorted(response.data, key=lambda item: item.index)]
    if len(vectors) != len(batch):
        raise ValueError(f"Expected {len(batch)} embeddings, got {len(vectors)}")
    return [{**record, 'embeddings': vector} for record, vector in zip(batch, vectors)]


def get_embeddings(
    df: pd.DataFrame,
    num_rows: int = 10,
    max_tokens: int = 8191, # Max tokens for text-embedding-3-small
    encoding_name: str = "cl100k_base",
    price_per_token: float = 0.02 / 1000000, # Price for text-embedding-3-small per token
    model: str = "text-embedding-3-small",
    dimensions: int = 512,
    max_batch_items: int = 2048,
) -> List[Dict[str, Any]]:
    """
    Embed ``case_text`` for the first ``num_rows`` rows of ``df`` via the OpenAI API.

    Rows are grouped into batches whose combined token count (measured with the
    ``encoding_name`` tiktoken encoding) does not exceed ``max_tokens`` and whose
    length does not exceed ``max_batch_items``; each batch is one request.
    Rows whose ``case_text`` is not a non-blank string, or which alone exceed
    ``max_tokens``, are skipped with a printed message.

    Args:
        df: Frame with ``case_id``, ``case_title`` and ``case_text`` columns.
        num_rows: Number of leading rows to process.
        max_tokens: Token budget per request and per single text.
        encoding_name: tiktoken encoding used for token counting.
        price_per_token: USD per token, used only for the printed cost estimate.
        model: Embedding model name.
        dimensions: Requested embedding dimensionality.
        max_batch_items: Maximum number of inputs per request (OpenAI documents a
            limit of 2048 inputs per embeddings request).

    Returns:
        One dict per successfully embedded row with keys ``case_id``, ``case_title``,
        ``case_text`` and ``embeddings`` -- that row's own vector. Rows from a batch
        whose request failed are omitted (the error is printed).
    """
    result_list: List[Dict[str, Any]] = []
    batch: List[Dict[str, Any]] = []
    current_tokens = 0
    total_tokens = 0
    total_batches = 0
    start_time = time.time()

    def _flush(label: str) -> None:
        """Send the pending batch (if any), collect its results and reset the batch state."""
        nonlocal batch, current_tokens, total_tokens, total_batches
        if not batch:
            return
        print(f"Processing {label} of {len(batch)} items, tokens: {current_tokens}")
        try:
            # Each row gets its own vector. (Using only response.data[0] would give
            # every row in the batch the first text's embedding.)
            result_list.extend(_embed_batch_records(batch, model, dimensions))
            total_batches += 1
            # Only successful requests are billed, so only they count toward the cost.
            total_tokens += current_tokens
        except Exception as e:
            # Best-effort: report and drop this batch, keep processing the rest.
            print(f"Error processing {label}: {e}")
        batch = []
        current_tokens = 0

    print(f"Processing {num_rows} rows...")

    for index, row in df.head(num_rows).iterrows():
        case_text = row['case_text']

        # Skip row if case_text is not a string or is empty/whitespace
        if not isinstance(case_text, str) or not case_text.strip():
            print(f"Skipping row {index} with invalid case_text: {case_text}")
            continue

        tokens = num_tokens_from_string(case_text, encoding_name)

        # A single text over the limit can never be sent; skip it before the flush
        # check so it does not force the current batch out early.
        if tokens > max_tokens:
            print(f"Skipping row {index} as it exceeds max_tokens: {tokens} tokens")
            continue

        # Send the current batch first if this text would overflow it.
        if current_tokens + tokens > max_tokens or len(batch) >= max_batch_items:
            _flush("batch")

        # Keep the row's own title alongside its text (no per-row lookup in df).
        batch.append({
            'case_id': row['case_id'],
            'case_title': row['case_title'],
            'case_text': case_text,
        })
        current_tokens += tokens

    # Process the final batch if it's not empty
    _flush("final batch")

    end_time = time.time()
    duration = end_time - start_time
    money_burned = total_tokens * price_per_token

    # Print the statistics
    print(f"Completed in {duration:.2f} seconds.")
    print(f"Total number of tokens: {total_tokens:,}")
    print(f"Total number of batches: {total_batches:,}")
    print(f"Money burned: ${money_burned:.6f}")

    return result_list

In [6]:
def wrap_text(text, wide=60):
    """Hard-wrap a string to lines of at most ``wide`` characters for table display.

    Values that are not strings (numbers, NaN, lists, None, ...) are passed
    through untouched.
    """
    if not isinstance(text, str):
        return text
    wrapped_lines = textwrap.wrap(text, wide)
    return "\n".join(wrapped_lines)

In [7]:
df = pd.read_csv('../data/legal_text_first_1000.csv')

# Preview: copy the first three rows and wrap every text cell to 60 characters
# so tabulate renders a readable grid (non-string cells pass through unchanged).
df_display = df.head(3).copy()
for column_name in df_display.columns:
    wrapped_column = df_display[column_name].apply(wrap_text, args=(60,))
    df_display[column_name] = wrapped_column

print(tabulate(df_display, headers='keys', tablefmt='grid', showindex=False))



+-----------+----------------+----------------------------------------------------------+--------------------------------------------------------------+
| case_id   | case_outcome   | case_title                                               | case_text                                                    |
| Case1     | cited          | Alpine Hardwood (Aust) Pty Ltd v Hardys Pty Ltd (No 2)   | Ordinarily that discretion will be exercised so that costs   |
|           |                | [2002] FCA 224 ; (2002) 190 ALR 121                      | follow the event and are awarded on a party and party basis. |
|           |                |                                                          | A departure from normal practice to award indemnity costs    |
|           |                |                                                          | requires some special or unusual feature in the case: Alpine |
|           |                |                                                    

In [17]:
# Embed the first 1000 cases; get_embeddings prints per-batch progress and cost stats.
all_cases_embeddings = get_embeddings(df=df, num_rows=1000)

Processing 1000 rows...
Skipping row 24 with invalid case_text: nan
Processing batch of 26 items, tokens: 8050


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 21 items, tokens: 8191


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 24 items, tokens: 7852


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 32 items, tokens: 6566


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Skipping row 107 with invalid case_text: nan
Processing batch of 11 items, tokens: 6484


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 20 items, tokens: 8137


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 21 items, tokens: 6698


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 4 items, tokens: 7393


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 4 items, tokens: 8002


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 22 items, tokens: 8033


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 12 items, tokens: 5709


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 13 items, tokens: 7445


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 11 items, tokens: 7903


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 16 items, tokens: 7932


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 7 items, tokens: 4215


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 4 items, tokens: 7362


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 12 items, tokens: 7932


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 19 items, tokens: 8094


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 24 items, tokens: 8125


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 30 items, tokens: 8131


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 28 items, tokens: 8004


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 19 items, tokens: 8101


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 7 items, tokens: 7545


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 11 items, tokens: 7854


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 16 items, tokens: 8007


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 19 items, tokens: 6838


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 15 items, tokens: 7516


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 15 items, tokens: 7997


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 9 items, tokens: 8025


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 16 items, tokens: 7598


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 16 items, tokens: 8167


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 15 items, tokens: 8170


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 21 items, tokens: 7814


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 31 items, tokens: 7992


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 14 items, tokens: 7899


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 15 items, tokens: 8061


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 12 items, tokens: 7695


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 12 items, tokens: 7996


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 23 items, tokens: 7449


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 17 items, tokens: 7545


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 22 items, tokens: 7815


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 23 items, tokens: 5640


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 18 items, tokens: 7389


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 12 items, tokens: 8100


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 15 items, tokens: 8007


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 15 items, tokens: 7925


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 14 items, tokens: 7715


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 19 items, tokens: 6775


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 15 items, tokens: 8072


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 19 items, tokens: 7539


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 8 items, tokens: 7995


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 18 items, tokens: 8009


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 10 items, tokens: 7959


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 21 items, tokens: 8107


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 17 items, tokens: 7993


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 22 items, tokens: 8143


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing batch of 23 items, tokens: 7986


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Skipping row 972 with invalid case_text: nan
Processing batch of 23 items, tokens: 7303


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Processing final batch of 19 items, tokens: 7097


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Completed in 42.79 seconds.
Total number of tokens: 450,096
Total number of batches: 59
Money burned: $0.009002


In [18]:
# Save embeddings to a csv file
df_embeddings = pd.DataFrame(all_cases_embeddings)


def _preview_vector(vector, n=5):
    """Shorten an embedding list to its first ``n`` components plus its length, for display only."""
    if not isinstance(vector, (list, tuple)):
        return vector
    head = ", ".join(f"{value:.4f}" for value in vector[:n])
    suffix = ", ..." if len(vector) > n else ""
    return f"[{head}{suffix}] ({len(vector)} dims)"


# Wrap/abbreviate a separate preview copy. Wrapping df_embeddings itself would insert
# newlines into case_text/case_title, and those altered values would end up in the CSV.
df_embeddings_display = df_embeddings.head(3).copy()
for col in df_embeddings_display.columns:
    df_embeddings_display[col] = df_embeddings_display[col].apply(
        lambda x: wrap_text(_preview_vector(x), 60)
    )

print(tabulate(df_embeddings_display, headers='keys', tablefmt='grid', showindex=False))

# NOTE: to_csv writes the list-valued 'embeddings' column as its string repr;
# parse it back (e.g. with ast.literal_eval) when reloading this file.
df_embeddings.to_csv('../data/all_cases_embeddings.csv', index=False)

+-----------+----------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
# Create a training set of 500 cases for the BM25 encoder
