# LLMs as classifiers (Part 2): log probabilities in practice

This notebook explores the practical aspects of using LLM log probabilities for classification. We will conduct two main experiments:
1.  **Stability Analysis**: How sensitive are log-prob scores to changes in models, prompts, and label phrasing?
2.  **Thresholded Classification**: How can we use log-prob scores to build a tunable classifier?

In [32]:
import os
import openai
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import requests
from tqdm.notebook import tqdm
import warnings

# Suppress verbose warnings
warnings.filterwarnings('ignore')

# Set up plotting style
sns.set_theme(style="whitegrid")

# Wide enough to see the generated texts
pd.set_option('display.max_colwidth', 240)

# Initialize OpenAI client (assumes OPENAI_API_KEY is set as an environment variable)
try:
    client = openai.OpenAI()
    print("OpenAI client initialized.")
except openai.OpenAIError as e:
    print(f"Error initializing OpenAI client: {e}")
    client = None

GPT5_MINI_ALIAS = "gpt-5-mini"
# GPT35_TURBO = "gpt-3.5-turbo-0125"
# MODELS_TO_TEST = [GPT5_MINI_ALIAS, GPT35_TURBO]

OpenAI client initialized.


## 1. Synthetic Data Generation

First, we generate a simple, synthetic dataset for text classification. This keeps the focus on model behavior rather than data complexity. The dataset is cached to a CSV file to avoid re-generating it on every run.

In [24]:
topics = [
    'Technology', 'Sports', 'Politics', 'Art', 'Science', 
    'Health', 'Education', 'Travel', 'Food', 'History'
]
NUM_SAMPLES = 1

def generate_text_for_topics(topics_list, verbose=False, model=GPT5_MINI_ALIAS):
    prompt = f"Generate a short sentence about {' and '.join(topics_list)}. Do not mention the topic names."
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{'role': 'user', 'content': prompt}],
            max_completion_tokens=1000
        )
        if verbose:
            print(f"{model} -> {response}")
        ret = response.choices[0].message.content.strip()
        return ret
    except Exception as e:
        print(f'Error generating text with OpenAI: {e}')
        return 'Error generating text.'

dataset = []
verbose = True
force_gen = False

# Force regeneration of the dataset using the new model
if os.path.exists('synthetic_dataset.csv') and not force_gen:
    print("Existing synthetic dataset loaded")
    df = pd.read_csv('synthetic_dataset.csv')
    # os.remove('synthetic_dataset.csv')
    # print('Removed existing synthetic_dataset.csv to regenerate.')
else:
    for _ in tqdm(range(NUM_SAMPLES), desc='Generating dataset'):
        num_topics = np.random.randint(1, 4)
        assigned_topics = np.random.choice(topics, num_topics, replace=False).tolist()
        # print(assigned_topics)
        text = generate_text_for_topics(assigned_topics, verbose=verbose)
        dataset.append({'text': text, 'topics': assigned_topics})
    df = pd.DataFrame(dataset)
    df.to_csv('synthetic_dataset.csv', index=False)

df.head()

Existing synthetic dataset loaded


Unnamed: 0,text,topics
0,Regular training teaches teamwork and discipline that boost academic success.,"['Sports', 'Education']"


## 2. Core Functions: Getting Log-Probs for Labels

We define below functions to take a piece of text and a set of possible labels, then use an LLM to calculate the log probability for each label by aggregating scores across multiple tokens.

**A note on API Support:** OpenAI, Ollama, Anthropic Claude (3.5+), and Google Gemini (Vertex AI) all support logprobs.

In [48]:
def match_label_in_logprobs(label, logprobs_list):
    """
    Common function to check if a label can be constructed from logprobs.
    
    Args:
        label: The label string to match
        logprobs_list: List of logprob objects, each with a 'top_logprobs' attribute
                      containing tokens with 'token' and 'logprob' attributes
    
    Returns:
        tuple: (total_log_prob, success) where success is True if label fully matched
    """
    total_log_prob = 0
    label_remaining = label
    
    for token_logprob in logprobs_list:
        if not label_remaining:
            break
            
        # Check if any top logprob token matches the start of remaining label
        matched = False
        for top_token in token_logprob.top_logprobs:
            token_str = top_token.token if hasattr(top_token, 'token') else top_token['token']
            logprob_val = top_token.logprob if hasattr(top_token, 'logprob') else top_token['logprob']
            
            if label_remaining.startswith(token_str):
                total_log_prob += logprob_val
                label_remaining = label_remaining[len(token_str):]
                matched = True
                break
        
        if not matched:
            # Label path not found in this response
            return -999
    
    # Check if we matched the complete label
    if not label_remaining:
        return total_log_prob
    else:
        return -999


def get_log_probs_for_labels_openai(text, prompt_template, labels, model, client):
    """Calculates the log probability of each label for a given text using OpenAI."""
    
    prompt = prompt_template.format(text=text)
    
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=20,
        logprobs=True,
        top_logprobs=20,
    )
    
    logprobs = response.choices[0].logprobs.content
    label_scores = {}
    
    for label in labels:
        score = match_label_in_logprobs(label, logprobs)
        label_scores[label] = score
            
    return label_scores


def get_log_probs_for_labels_ollama(text, prompt_template, labels, model='dolphin-mistral', verbose=False):
    """Calculates the log probability of each label for a given text using Ollama."""
    
    prompt = prompt_template.format(text=text, topics=topics)
    
    url = 'http://localhost:11434/api/generate'
    payload = {
        'model': model,
        'prompt': prompt,
        'stream': False,
        'logprobs': True,
        'top_logprobs': 20,
        'temperature': 0,
    }
    
    try:
        response = requests.post(url, json=payload)
        response.raise_for_status()
        resp_json = response.json()

        if verbose:
            print(f"{prompt} -> {model} -> {resp_json}")
        
        label_scores = {}
        
        if 'logprobs' in resp_json and resp_json['logprobs'] is not None:
            # Convert Ollama format to common format
            logprobs_list = []
            for logprob_info in resp_json['logprobs']:
                # Wrap in object-like structure for compatibility
                class LogprobWrapper:
                    def __init__(self, top_logprobs):
                        self.top_logprobs = top_logprobs
                
                logprobs_list.append(LogprobWrapper(logprob_info['top_logprobs']))
            
            for label in labels:
                score, success = match_label_in_logprobs(label, logprobs_list)
                label_scores[label] = score
        else:
            # No logprobs available
            label_scores = {label: -999 for label in labels}
            
        return label_scores
        
    except requests.exceptions.RequestException as e:
        print(f'Error getting probs: {e}')
        return {label: -999 for label in labels}

## 3. Experiment 1: Log Probability Stability

Here, we test how log-prob scores change when we vary the model, prompt phrasing, and label style. We run a small sample of our dataset through different configurations and collect the results.

In [55]:
# Define different prompt and label styles

# lower case, shorter
simple_topics = [
    'tech', 'sports', 'politics', 'art', 'science', 
    'health', 'education', 'travel', 'food', 'history'
]
topics_verbose = [
    'article about technology', 'article about sports', 
    'article about politics', 'article about art', 
    'article about science', 'article about health', 
    'article about education', 'article about travel', 
    'article about food', 'article about history'
]

# Define prompt templates
PROMPT_TEMPLATES = {
    # Minimal/direct
    'simple': "Text: '{text}'\nChoose From: {topics}\nTopic:",
    
    # Instruction-based
    'instruction': "Classify this text into one of these topics: {topics}\n\nText: {text}\n\nTopic:",
    
    # Question format
    'question': "What are the bests topic for the following text? Choose from: {topics}\n\nText: {text}\n\nTopic:",
    
    # Chain-of-thought style
    'cot': "Read the text below and determine which topic it belongs to from this list: {topics}\n\nText: {text}\n\nThe most appropriate topic is:",
    
    # JSON/structured output
    'json': "Classify the text into one of these topics: {topics}\n\nText: '{text}'\n\nReturn JSON:\n{{\"topic\": \"",
    
    # Role-based
    'role': "You are a content classifier. Categorize this text into 1 or more topics, choosing from: {topics}\n\nText: {text}\n\nCategory:",
    
    # Few-shot style (without actual examples, just framing)
    'task': "Task: Assign one mor multiple topic labels to the given text.\nAvailable topics: {topics}\n\nText: {text}\n\nAssigned topic:",
    
    # Explicit constraint
    'constrained': "Select one or more topics from [{topics}] that best describes this text:\n\n{text}\n\nSelected topic:",
    
    # Natural language
    'natural': "Here's a text: \"{text}\"\n\nWhich of these topics does it belong to? {topics}\n\nAnswer:"
}

TOPIC_SETS = {
    'simple' : simple_topics,
    'normal': topics,
    'verbose': topics_verbose
    # 'short': ["Technology", "Finance", "Health"],
    # 'descriptive': ["Article about Technology", "Article about Finance", "Article about Health"]
}

def run_stability_experiments():

    # Run the experiment
    experiment_results = []
    sample_df = df.sample(n=10, random_state=42) # Use a small, consistent sample for this demo

    # MODELS_TO_TEST = # WIP

    for model in tqdm(MODELS_TO_TEST, desc="Models"):
        for prompt_style, template in tqdm(PROMPT_TEMPLATES.items(), desc="Prompts", leave=False):
            for label_style, labels in tqdm(LABEL_SETS.items(), desc="Labels", leave=False):
                for _, row in sample_df.iterrows():
                    scores = get_log_probs_for_labels(row['text'], template, labels, model)
                    result = {
                        'model': model,
                        'prompt_style': prompt_style,
                        'label_style': label_style,
                        'true_label': row['label'],
                    }
                    # Normalize label keys for consistent columns
                    normalized_scores = {TOPICS[i]: scores.get(labels[i], -999) for i in range(len(TOPICS))}
                    result.update(normalized_scores)
                    experiment_results.append(result)

    results_df = pd.DataFrame(experiment_results)

    print("Experiment results:")
    display(results_df.head())


model = 'dolphin-mistral:latest'
# model='mistral:8b'
get_log_probs_for_labels_ollama(text="Technology is pretty great, although historically also sometimes a bit dangerous.", 
                                prompt_template=PROMPT_TEMPLATES['simple'], labels=TOPIC_SETS['simple'], verbose=True, model=model)

Text: 'Technology is pretty great, although historically also sometimes a bit dangerous.'
Choose From: ['Technology', 'Sports', 'Politics', 'Art', 'Science', 'Health', 'Education', 'Travel', 'Food', 'History']
Topic: -> dolphin-mistral:latest -> {'model': 'dolphin-mistral:latest', 'created_at': '2026-01-08T07:09:44.024202Z', 'response': "Based on the given text, I would categorize this topic as 'Science' and 'Technology'. The statement refers to technology having both great benefits (e.g., improved communication, increased efficiency) and potential risks or dangers (e.g., data privacy concerns, reliance on machines), which are key aspects of the science and technology fields.", 'done': True, 'done_reason': 'stop', 'context': [32001, 1587, 13, 1976, 460, 15052, 721, 262, 28725, 264, 10865, 16107, 13892, 28723, 13, 32000, 28705, 13, 32001, 2188, 13, 1874, 28747, 464, 8946, 1818, 2161, 349, 3468, 1598, 28725, 5432, 4264, 1944, 835, 4662, 264, 2286, 9259, 1815, 13, 1209, 12470, 3672, 28747

TypeError: cannot unpack non-iterable int object

Text: 'Technology is pretty great, although historically also sometimes a bit dangerous.'
Choose From: ['Technology', 'Sports', 'Politics', 'Art', 'Science', 'Health', 'Education', 'Travel', 'Food', 'History']
Topic: -> dolphin-mistral:latest -> {'model': 'dolphin-mistral:latest', 'created_at': '2026-01-08T07:08:28.255089Z', 'response': "The topic of this text is 'Science' and 'Technology'. The text emphasizes the benefits and potential dangers that technology brings to society. It also highlights how science has contributed significantly to technological advancements, which can be seen as both helpful and risky in different contexts.", 'done': True, 'done_reason': 'stop', 'context': [32001, 1587, 13, 1976, 460, 15052, 721, 262, 28725, 264, 10865, 16107, 13892, 28723, 13, 32000, 28705, 13, 32001, 2188, 13, 1874, 28747, 464, 8946, 1818, 2161, 349, 3468, 1598, 28725, 5432, 4264, 1944, 835, 4662, 264, 2286, 9259, 1815, 13, 1209, 12470, 3672, 28747, 5936, 8946, 1818, 2161, 647, 464, 28735,

TypeError: cannot unpack non-iterable int object

In [None]:
def

### Visualizing Stability with Histograms

The histograms below show the distribution of log-prob scores for the *correct* label. Each plot demonstrates how the scores shift due to a single changing factor (model or prompt style), revealing the instability of absolute log-prob values.

In [None]:
# Melt the dataframe to make plotting easier
melted_df = results_df.melt(
    id_vars=['model', 'prompt_style', 'label_style', 'true_label'], 
    value_vars=TOPICS, 
    var_name='assigned_label', 
    value_name='log_prob'
)

# Filter for scores of the correct label only
correct_label_scores = melted_df[melted_df['true_label'] == melted_df['assigned_label']]

# Plot 1: Distribution of scores by Model
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data=correct_label_scores[correct_label_scores['prompt_style'] == 'simple'], 
             x='log_prob', hue='model', multiple='kde', ax=ax, fill=True)
ax.set_title('Distribution of Log-Probs for Correct Labels (by Model)')
ax.set_xlabel('Log Probability')
plt.show()

# Plot 2: Distribution of scores by Prompt Style
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data=correct_label_scores[correct_label_scores['model'] == GPT5_MINI_ALIAS], 
             x='log_prob', hue='prompt_style', multiple='kde', ax=ax, fill=True)
ax.set_title(f'Distribution of Log-Probs for Correct Labels (by Prompt Style, Model: {GPT5_MINI_ALIAS})')
ax.set_xlabel('Log Probability')
plt.show()

## 4. Experiment 2: Threshold-Based Classification

Instead of picking the label with the highest score (argmax), we can set a threshold. If a label's log-prob score is above the threshold, it's assigned. This approach is useful for multi-label classification and allows us to tune the classifier's behavior (e.g., for higher precision or higher recall).

In [None]:
# This time we also use ollama
def get_ollama_topic_probs(text, model='dolphin-mistral', verbose=False):
    prompt = f"""
    Given the following text, classify it into one or more of the following topics: {', '.join(topics)}. 
    Respond with a comma-separated list of the most relevant topics only.\n\nText: {text}\n\nTopics:
    """
    url = 'http://localhost:11434/api/generate'
    payload = {
        'model': model,
        'prompt': prompt,
        'stream': False,
        'logprobs': True,
        'top_logprobs': 20
    }
    try:
        response = requests.post(url, json=payload)
        response.raise_for_status()
        resp_json = response.json()
        if verbose:
            print(text)
            print(resp_json)
        topic_probs = {topic: 0.0 for topic in topics}
        if 'logprobs' in resp_json and resp_json['logprobs'] is not None:
            for logprob_info in resp_json['logprobs']:
                for top_logprob in logprob_info['top_logprobs']:
                    token_str = top_logprob['token'].strip()
                    if token_str in topics:
                        prob = math.exp(top_logprob['logprob'])
                        topic_probs[token_str] = max(topic_probs[token_str], prob)
        return topic_probs
    except requests.exceptions.RequestException as e:
        print(f'Error getting probs: {e}')
        return {topic: 0.0 for topic in topics}

def process_dataset_ollama(df, model, verbose=False):
    results = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f'Processing with {model}'):
        probs = get_ollama_topic_probs(row['text'], model=model, verbose=verbose)
        results.append(probs)
    return pd.DataFrame(results)

def analyze_thresholds(y_true, y_prob):
    thresholds = np.linspace(0, 1, 100)
    results = []
    for t in thresholds:
        y_pred = (y_prob > t).astype(int)
        results.append({
            'threshold': t,
            'f1': f1_score(y_true, y_pred, average='samples', zero_division=0),
            'precision': precision_score(y_true, y_pred, average='samples', zero_division=0),
            'recall': recall_score(y_true, y_pred, average='samples', zero_division=0)
        })
    return pd.DataFrame(results)

def plot_analysis(df_analysis, model_name):
    plt.figure(figsize=(10, 6))
    plt.plot(df_analysis['threshold'], df_analysis['f1'], label='F1 Score')
    plt.plot(df_analysis['threshold'], df_analysis['precision'], label='Precision')
    plt.plot(df_analysis['threshold'], df_analysis['recall'], label='Recall')
    best_threshold = df_analysis.loc[df_analysis['f1'].idxmax()]
    plt.axvline(x=best_threshold['threshold'], color='r', linestyle='--', label=f'Best Threshold (F1={best_threshold["f1"]:.2f})')
    plt.title(f'Performance vs. Threshold for {model_name}')
    plt.xlabel('Threshold')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.show()
    print(f'Best threshold for {model_name}: {best_threshold["threshold"]:.2f} with F1={best_threshold["f1"]:.2f}')

In [None]:
def evaluate_thresholds(test_df, model, prompt_template, labels):
    """Calculates precision, recall, and F1 over a range of thresholds."""
    # Get log-prob scores for the entire test set
    all_scores = []
    for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc=f"Scoring {model}"):
        scores = get_log_probs_for_labels(row['text'], prompt_template, labels, model)
        scores['true_label'] = row['label']
        all_scores.append(scores)
    scores_df = pd.DataFrame(all_scores)

    threshold_metrics = []
    thresholds = np.linspace(scores_df[labels].min().min(), scores_df[labels].max().max(), 50)

    for threshold in thresholds:
        tp, fp, fn = 0, 0, 0
        for label in labels:
            # Predictions: score > threshold
            preds = scores_df[label] > threshold
            # Ground truth: true_label == label
            truths = scores_df['true_label'] == label
            
            tp += (preds & truths).sum()
            fp += (preds & ~truths).sum()
            fn += (~preds & truths).sum()

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        threshold_metrics.append({'threshold': threshold, 'precision': precision, 'recall': recall, 'f1': f1, 'model': model})
        
    return pd.DataFrame(threshold_metrics)

# Run evaluation for both models
threshold_results = []
test_sample_df = df.sample(n=30, random_state=123)

for model in MODELS_TO_TEST:
    metrics_df = evaluate_thresholds(test_sample_df, model, PROMPT_TEMPLATES['simple'], LABEL_SETS['short'])
    threshold_results.append(metrics_df)

all_metrics_df = pd.concat(threshold_results)

# Plot F1-score vs. Threshold
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(data=all_metrics_df, x='threshold', y='f1', hue='model', ax=ax)
ax.set_title('F1 Score vs. Decision Threshold')
ax.set_xlabel('Log-Prob Threshold')
ax.set_ylabel('F1 Score')
plt.show()

# Plot Precision-Recall Curve
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(data=all_metrics_df, x='recall', y='precision', hue='model', ax=ax, err_style=None)
ax.set_title('Precision-Recall Curve')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.show()

## 5. TL;DR & Conclusion

These experiments demonstrate a few key points:

- **Log-probs are not stable**: Their absolute values are highly sensitive to the model, prompt, and even the phrasing of the labels. They should not be treated as true probabilities.
- **Distributions differ significantly**: Models have unique "log-prob signatures." What constitutes a high score for one model may be a low score for another.
- **Relative scores are useful**: Despite their instability, log-probs are effective for *ranking* and *thresholding*. By sweeping across thresholds, we can tune a classifier's precision/recall trade-off to fit a specific need.

In short, log probabilities are a powerful but low-level tool. They provide a mechanism for building controllable classifiers, but require empirical validation to be used effectively.