<a href="https://colab.research.google.com/github/mshumer/openai-logit-bias-classification-walkthrough/blob/main/OpenAI_powered_Classification_with_Logit_Bias.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Logit Bias for OpenAI-powered Classification
By Matt Shumer (https://twitter.com/mattshumer_)

GitHub repo: https://github.com/mshumer/openai-logit-bias-classification-walkthrough

Reusable notebook for creating classifiers with OpenAI models and logit bias, benchmarking each model's performance for your use-case, and exporting the classifiers you create for use in your code.

In the first cell, add your OpenAI API key to get started.

In [None]:
# This notebook uses the pre-1.0 `openai.ChatCompletion` interface, so pin the SDK below 1.0
!pip install "openai<1.0"
!pip install prettytable
!pip install tiktoken

import openai
import tiktoken
import prettytable

openai.api_key = "ADD YOUR OPENAI KEY HERE"

**First, a bit about logit bias.**

> Logit bias is a technique for influencing the output probabilities of a language model such as GPT-3. It works by adding a constant value to the logits (the raw scores the model assigns to each token) of specific tokens (words or pieces of words) before the softmax function converts those scores into probabilities.

In simpler terms, logit bias lets you steer the model's output toward or away from certain tokens, making them more or less likely to be chosen in the generated text. In the OpenAI API, you pass `logit_bias` as a mapping from token ID to a bias value between -100 and 100.
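To make the mechanism concrete, here is a minimal sketch in plain Python. The four-token vocabulary and its logit values are made up for illustration, and `softmax` and `apply_logit_bias` are hypothetical helper names, not part of the OpenAI SDK.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(score - peak) for tok, score in logits.items()}
    total = sum(exps.values())
    return {tok: value / total for tok, value in exps.items()}

def apply_logit_bias(logits, bias):
    """Add a constant to the logits of selected tokens, leaving the rest untouched."""
    return {tok: score + bias.get(tok, 0.0) for tok, score in logits.items()}

# Made-up logits for a four-token vocabulary (illustrative values only).
logits = {"true": 1.0, "false": 0.5, "maybe": 3.0, "the": 2.0}
bias = {"true": 100.0, "false": 100.0}

before = softmax(logits)
after = softmax(apply_logit_bias(logits, bias))

print({tok: round(p, 3) for tok, p in before.items()})
print({tok: round(p, 3) for tok, p in after.items()})
```

Before the bias, "maybe" is the most likely token. After adding 100 to both "true" and "false", those two tokens hold essentially all of the probability, and their relative order is unchanged because both received the same bias. That is why an equal, large bias on every label works for classification: the model still chooses between the labels, but can no longer pick anything else.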

Now, let's get started. In this notebook, we're using logit bias to create powerful classifiers with very little effort.

------------

In this cell, we're creating a binary true/false classifier. We're using it to classify statements as true or false, but feel free to modify it to support your use-case, by a) updating the test cases, and b) updating the messages in the OpenAI call.


This cell is used to benchmark multiple OpenAI models against your use-case, to show you which model is best. You can also see the average latency of each model.


As with all the cells in this notebook, you can adjust the models if you don't yet have access to GPT-4.

In [None]:
from prettytable import PrettyTable
import time

# here we're doing true or false statements... feel free to adjust for your use-case
test_cases = [
    {
        'input': 'The United States includes New York City.',
        'expected_output': 'true'
    },
    {
        'input': 'The Earth is flat.',
        'expected_output': 'false'
    },
    {
        'input': 'The capital of France is Paris.',
        'expected_output': 'true'
    },
    {
        'input': 'The Moon is made of green cheese.',
        'expected_output': 'false'
    },
    {
        'input': 'Water boils at 100 degrees Celsius at sea level.',
        'expected_output': 'true'
    }
]

models = ['gpt-4', 'gpt-4-0613', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo'] # choose what models you'd like to benchmark
model_results = {model: {'correct': 0, 'total': 0} for model in models}

# Initialize the table
table = PrettyTable()
table.field_names = ["Input", "Expected"] + models

# Wrap the text in the "Input" column
table.max_width["Input"] = 100

# Initialize timers
model_timers = {model: 0 for model in models}

for test_case in test_cases:
    row = [test_case['input'], test_case['expected_output']]
    for model in models:
        start_time = time.time()

        x = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": "Determine if the input statement is true or false. Return 'true' if it is true, return 'false' if it is false."}, # adjust for your use-case
                {"role": "user", "content": f"Here is the statement: `{test_case['input']}`"} # adjust for your use-case
            ],
            logit_bias={
                '1904': 100,  # true to 100
                '3934': 100,  # false to 100
            },
            max_tokens=1,
            temperature=0,
        ).choices[0].message.content

        end_time = time.time()
        model_timers[model] += (end_time - start_time)

        status = "✅" if x == test_case['expected_output'] else "❌"
        row.append(status)

        # Update model results
        if x == test_case['expected_output']:
            model_results[model]['correct'] += 1
        model_results[model]['total'] += 1

    table.add_row(row)

print(table)

# Calculate and print the percentage of correct answers and average time for each model
for model in models:
    correct = model_results[model]['correct']
    total = model_results[model]['total']
    percentage = (correct / total) * 100
    avg_time = model_timers[model] / total
    print(f"{model} got {percentage:.2f}% correct. Average time: {avg_time:.2f} seconds.")

In this cell, we're creating a more flexible classifier, with labels we can define ourselves. In the example here, we're using it to sort statements into happy, sad, and neutral buckets, but feel free to modify it to support your use-case, by a) updating the test cases, and b) updating the messages in the OpenAI call.


This cell is used to benchmark multiple OpenAI models against your use-case, to show you which model is best. You can also see the average latency of each model.


As with all the cells in this notebook, you can adjust the models if you don't yet have access to GPT-4.
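One thing worth checking when you swap in your own labels: the classifier biases only the first token of each label and requests a single token back, so labels should start with distinct tokens, and a label that spans several tokens will come back truncated. The sketch below shows one way to check this. The helper names are hypothetical, and `toy_encode` with its made-up IDs is only a stand-in so the example runs offline; in the notebook you would pass something like `tiktoken.encoding_for_model(model).encode` instead.

```python
def first_token_ids(labels, encode):
    """Map each label to the ID of its first token, rejecting ambiguous label sets."""
    ids = {}
    for label in labels:
        tokens = encode(label)
        if not tokens:
            raise ValueError(f"label {label!r} produced no tokens")
        first = tokens[0]
        if first in ids.values():
            raise ValueError(f"label {label!r} shares its first token with another label")
        ids[label] = first
    return ids

def multi_token_labels(labels, encode):
    """Return labels longer than one token, which max_tokens=1 would truncate."""
    return [label for label in labels if len(encode(label)) > 1]

# Toy vocabulary with made-up token IDs (NOT real tiktoken values),
# used only so the example runs without downloading an encoding.
TOY_VOCAB = {"happy": [10], "sad": [11], "neutral": [12, 13], "negative": [12, 14]}

def toy_encode(text):
    return TOY_VOCAB[text]

labels = ["happy", "sad", "neutral"]
print(first_token_ids(labels, toy_encode))
print(multi_token_labels(labels, toy_encode))
```

If `multi_token_labels` returns anything for your real tokenizer, either pick a shorter label or compare the model's one-token output against the decoded first token rather than the full label string.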

In [None]:
from prettytable import PrettyTable
import time
import openai
import tiktoken

# Function to get token IDs for a given model family
def get_token_ids(model_family, text):
    if 'gpt-4' in model_family:
        model_enc = tiktoken.encoding_for_model("gpt-4")
    else:
        # gpt-3.5-turbo models use the same cl100k_base encoding as gpt-4;
        # pass a model name that tiktoken recognizes
        model_enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return model_enc.encode(text)

# here we're doing happy, sad, neutral sentiment analysis... feel free to adjust for your use-case
test_cases = [
    {
        'input': 'My name is Matt',
        'expected_output': 'neutral'
    },
    {
        'input': 'I am so happy today!',
        'expected_output': 'happy'
    },
    {
        'input': 'I had a bad day.',
        'expected_output': 'sad'
    },
    {
        'input': 'The temperature is 50 degrees.',
        'expected_output': 'neutral'
    },
    {
        'input': 'I just won the lottery!',
        'expected_output': 'happy'
    }
]

models = ['gpt-4', 'gpt-4-0613', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo'] # choose what models you'd like to benchmark
model_results = {model: {'correct': 0, 'total': 0} for model in models}

# Initialize the table
table = PrettyTable()
table.field_names = ["Input", "Expected"] + models

# Wrap the text in the "Input" column
table.max_width["Input"] = 100

# Initialize timers
model_timers = {model: 0 for model in models}

for test_case in test_cases:

    row = [test_case['input'], test_case['expected_output']]

    for model in models:

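        # Token IDs are looked up with the tokenizer of models[0]. That is fine here
        # because every model in the list shares one encoding (cl100k_base); if you
        # benchmark models with different tokenizers, look the IDs up per model.
        logit_bias_values = {
            'happy': get_token_ids(models[0], 'happy')[0],
            'sad': get_token_ids(models[0], 'sad')[0],
            'neutral': get_token_ids(models[0], 'neutral')[0]
        }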

        start_time = time.time()

        x = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "user", "content": "Determine the sentiment of the input statement. Return 'happy', 'sad', or 'neutral'."}, # adjust for your use-case
                {"role": "user", "content": f"Here is the statement: `{test_case['input']}`"} # adjust for your use-case
            ],
            logit_bias={
                str(logit_bias_values['happy']): 100,
                str(logit_bias_values['sad']): 100,
                str(logit_bias_values['neutral']): 100
            },
            max_tokens=1,
            temperature=0,
        ).choices[0].message.content

        end_time = time.time()
        model_timers[model] += (end_time - start_time)

        status = "✅" if x == test_case['expected_output'] else "❌"
        row.append(status)

        # Update model results
        if x == test_case['expected_output']:
            model_results[model]['correct'] += 1
        model_results[model]['total'] += 1

    table.add_row(row)

print(table)

# Calculate and print the percentage of correct answers and average time for each model
for model in models:
    correct = model_results[model]['correct']
    total = model_results[model]['total']
    percentage = (correct / total) * 100
    avg_time = model_timers[model] / total
    print(f"{model} got {percentage:.2f}% correct. Average time: {avg_time:.2f} seconds.")

# Now that we've benchmarked, we can pull out the classifiers to use in external code.



First, the 'true'/'false' classifier. Feel free to modify it for your use-case.

In [None]:
statement = "The capital of France is Paris."  # replace with the text you want to classify

result = openai.ChatCompletion.create(
    model='gpt-3.5-turbo', # adjust to the ideal model for your use-case
    messages=[
        # same system prompt as the benchmark cell, so you export what you benchmarked
        {"role": "system", "content": "Determine if the input statement is true or false. Return 'true' if it is true, return 'false' if it is false."}, # adjust for your use-case
        {"role": "user", "content": f"Here is the statement: `{statement}`"} # adjust for your use-case
    ],
    logit_bias={
        '1904': 100,  # true to 100
        '3934': 100,  # false to 100
    },
    max_tokens=1,
    temperature=0,
).choices[0].message.content

print(result)

Now, the dynamic classifier. Feel free to modify it for your use-case.

Note -- token IDs are fixed for a given tokenizer, so to speed this up a bit, run `get_token_ids` once, grab the outputs, and hardcode them into your OpenAI call.
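As a minimal sketch of the hardcoding step, the helper below turns a label-to-token-ID mapping into the string-keyed dictionary that `logit_bias` expects. `build_logit_bias` is a hypothetical helper name, and the IDs shown are the ones used by the true/false classifier above; substitute the values that `get_token_ids` prints for your own labels.

```python
def build_logit_bias(token_ids, bias=100):
    """Turn {label: token_id} into the {str(token_id): bias} mapping used for logit_bias."""
    if not -100 <= bias <= 100:
        raise ValueError("bias must be between -100 and 100")
    return {str(token_id): bias for token_id in token_ids.values()}

# IDs copied from the true/false classifier above; replace them with the
# values printed by get_token_ids for your own labels.
TRUE_FALSE_TOKEN_IDS = {"true": 1904, "false": 3934}

logit_bias = build_logit_bias(TRUE_FALSE_TOKEN_IDS)
print(logit_bias)  # {'1904': 100, '3934': 100}
```

Keeping the IDs in one named constant also makes it obvious which values need regenerating if you later move to a model with a different tokenizer.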

In [None]:
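import openai
import tiktoken

# Function to get token IDs for a given model family
def get_token_ids(model_family, text):
    if 'gpt-4' in model_family:
        model_enc = tiktoken.encoding_for_model("gpt-4")
    else:
        # gpt-3.5-turbo models use the same cl100k_base encoding as gpt-4;
        # pass a model name that tiktoken recognizes
        model_enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return model_enc.encode(text)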

# Adjust these for your use-case
logit_bias_values = {
    'happy': get_token_ids('gpt-3.5-turbo', 'happy')[0],
    'sad': get_token_ids('gpt-3.5-turbo', 'sad')[0],
    'neutral': get_token_ids('gpt-3.5-turbo', 'neutral')[0]
}


statement = "I am so happy today!"  # replace with the text you want to classify

result = openai.ChatCompletion.create(
    model='gpt-3.5-turbo',
    messages=[
        {"role": "user", "content": "Determine the sentiment of the input statement. Return 'happy', 'sad', or 'neutral'."}, # adjust for your use-case
        {"role": "user", "content": f"Here is the statement: `{statement}`"} # adjust for your use-case
    ],
    logit_bias={
        str(logit_bias_values['happy']): 100,
        str(logit_bias_values['sad']): 100,
        str(logit_bias_values['neutral']): 100
    },
    max_tokens=1,
    temperature=0,
).choices[0].message.content

print(result)