In [1]:
# Install tiktoken if it's not already installed
%pip install tiktoken

# Import necessary libraries
import json
import tiktoken

# Function to count tokens using tiktoken
def count_tokens(text, model_name="gpt-3.5-turbo"):
    """Return the number of tokens tiktoken produces for ``text``.

    Args:
        text: The string to tokenize.
        model_name: Model name used to select the tiktoken encoding.

    Returns:
        The length of the token list returned by ``encoding.encode(text)``.
    """
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # tiktoken raises KeyError for model names it cannot map to an
        # encoding; fall back to cl100k_base (the encoding of the default
        # model) instead of aborting the whole count.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

# Function to log warnings
def log_warning(warning_message):
    """Append ``warning_message`` plus a newline to ``data_warnings.log``.

    The file is opened in append mode in the current working directory, so
    it is created on first use and earlier entries are kept.

    Args:
        warning_message: The text to write as one log line.
    """
    # UTF-8 is set explicitly: the platform default encoding may be unable
    # to represent non-ASCII characters that appear in the logged messages.
    with open('data_warnings.log', 'a', encoding='utf-8') as log_file:
        log_file.write(warning_message + '\n')

# Load the dataset
# The encoding is explicit so the JSON is decoded the same way regardless of
# the platform's default locale.
with open('../data/mvp_dataset.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

total_tokens = 0
warnings = []

# Roles that do not trigger an "Unexpected role" warning. Kept as a tuple so
# the membership test also works if a role value happens to be unhashable.
EXPECTED_ROLES = ('user', 'assistant', 'system')


def _record_warning(text):
    """Remember a warning for the final summary and append it to the log file."""
    warnings.append(text)
    log_warning(text)


# Analyze the dataset for potential issues and count tokens
for entry in data:
    # A dataset entry without 'messages' is reported rather than crashing the
    # scan, matching how a missing 'content' is handled below.
    if 'messages' not in entry:
        _record_warning(f"Missing messages in entry: {entry}")
        continue

    for message in entry['messages']:
        # Check if content is missing
        if 'content' not in message:
            _record_warning(f"Missing content in message: {message}")
            continue

        content = message['content']

        # Check if content is empty
        if not content:
            _record_warning(f"Empty content in message: {message}")
            continue

        # Check if role is not one of the expected values.
        # .get() makes a missing role show up as a warning (role 'None')
        # instead of raising KeyError and aborting the whole analysis.
        role = message.get('role')
        if role not in EXPECTED_ROLES:
            _record_warning(f"Unexpected role '{role}' in message: {message}")
            continue

        # Count tokens
        if isinstance(content, dict):
            # Dict content is serialized to JSON text so it can be tokenized
            content = json.dumps(content)
        if isinstance(content, str):
            total_tokens += count_tokens(content)
        else:
            # Anything else (e.g. a list) is reported and left out of the total
            _record_warning(f"Skipping non-string content: {content}")

print(f"Total tokens in the dataset: {total_tokens}")

# Display all warnings
print("\nWarnings:")
for warning in warnings:
    print(warning)
