# Logprobs Tree Visualizer

### The code in this notebook is broken into 4 sections
1. Initial exploration of the logprobs object in OpenAI
2. Defining a TreeNode class and helper functions which are used to build a tree of sequential logprobs in a series of tokens.
3. Setting up code to build a one off visualization of this tree
4. Wrapping the code from step (3) in a function which can produce visualizations of logprobs trees for an arbitrary input token

### Note on section dependency
Section (3) depends on section (2). However, section (4), "Build log probs tree for a given starting token or prompt", redefines everything it needs, so after running Setup you can skip straight to it to start generating visualizations.

### Setup

In [None]:
!pip install openai
!pip install igraph
import os
import numpy as np
import igraph as ig
import plotly.graph_objects as go

In [2]:
from openai import OpenAI
# Keep a key that is already exported in the environment; the placeholder below is only a
# fallback to be replaced locally. Never commit a real key in this cell.
os.environ.setdefault("OPENAI_API_KEY", "YOUR_API_KEY")
client = OpenAI()

# Logprobs exploration

In [23]:
# System prompt that turns the chat model into a next-word predictor.
system_message = "You are a next word prediction machine. You will get a sequence of words from the user. You should return the next word in the sequence."

# Ask for the 3 most likely alternatives for every generated token.
completion = client.chat.completions.create(
  model="gpt-4-1106-preview",
  max_tokens=4,
  temperature=0,
  logprobs=True,
  top_logprobs=3,
  messages=[
    {"role": "system", "content": system_message},
    {"role": "user", "content": "The"}
  ]
)

response = completion.choices[0].message.content
print(response)


quick


In [24]:
# Show the alternatives (token, raw bytes, logprob) reported for the first generated token.
print(completion.choices[0].logprobs.content[0].top_logprobs)

[TopLogprob(token='quick', bytes=[113, 117, 105, 99, 107], logprob=-0.49071938), TopLogprob(token='cat', bytes=[99, 97, 116], logprob=-1.3969693), TopLogprob(token='sun', bytes=[115, 117, 110], logprob=-2.6782193)]


In [25]:
# Convert each alternative's log probability into a percentage for easier reading.
top_logprobs = completion.choices[0].logprobs.content[0].top_logprobs
for candidate in top_logprobs:
  percent = np.round(np.exp(candidate.logprob)*100,2)
  print(f"Token: {candidate.token}, Logprob: {candidate.logprob}, Linear Logprob: {percent}")

Token: quick, Logprob: -0.49071938, Linear Logprob: 61.22
Token: cat, Logprob: -1.3969693, Linear Logprob: 24.73
Token: sun, Logprob: -2.6782193, Linear Logprob: 6.87


In [26]:
# Repeat the exact same request. In the recorded outputs the logprobs differ slightly
# from the first call even though temperature is 0.
completion = client.chat.completions.create(
  model="gpt-4-1106-preview",
  max_tokens=4,
  temperature=0,
  logprobs=True,
  top_logprobs=3,
  messages=[
    {"role": "system", "content": system_message},
    {"role": "user", "content": "The"}
  ]
)

response = completion.choices[0].message.content
print(response)

quick


In [27]:
# Same report as before, for the repeated request.
first_generated_token = completion.choices[0].logprobs.content[0]
top_logprobs = first_generated_token.top_logprobs
for candidate in top_logprobs:
  percent = np.round(np.exp(candidate.logprob)*100,2)
  print(f"Token: {candidate.token}, Logprob: {candidate.logprob}, Linear Logprob: {percent}")

Token: quick, Logprob: -0.53744483, Linear Logprob: 58.42
Token: cat, Logprob: -1.3499448, Linear Logprob: 25.93
Token: sun, Logprob: -2.7718198, Linear Logprob: 6.25


# Define TreeNode class

In [28]:
class TreeNode:
  """One token in the logprobs tree.

  Attributes:
    token: Text of the token this node represents.
    linear_prob: Probability of the token as a percentage (100 by default, used for the root).
    depth: Distance from the root; a node created directly defaults to depth 0.
    parent: The TreeNode this node was added to, or None when it has no parent.
    children: Child TreeNode objects, in the order they were added.
  """

  def __init__(self, token, linear_prob=100, depth=0, parent=None):
    self.token = token
    self.linear_prob = linear_prob
    self.depth = depth
    # Stored on every node so a prompt can be rebuilt by walking back to the root
    # without depending on an attribute attached from outside the class.
    self.parent = parent
    self.children = []

  def add_child(self, token, log_prob):
    """Create a child node for `token`, attach it to this node and return it.

    Args:
      token: Text of the child token.
      log_prob: Natural-log probability of the token; stored as a percentage
        rounded to 2 decimals.

    Returns:
      The newly created child TreeNode (depth is this node's depth + 1).
    """
    linear_prob = np.round(np.exp(log_prob)*100,2)
    child_node = TreeNode(token, linear_prob, self.depth+1, parent=self)
    self.children.append(child_node)
    return child_node

def fetch_top_logprobs(completion):
  """Return the `top_logprobs` entry of the first token of the first choice in `completion`."""
  first_choice = completion.choices[0]
  first_token = first_choice.logprobs.content[0]
  return first_token.top_logprobs

def get_completion(prompt, model="gpt-4-1106-preview", top_logprobs=3):
  """Request a next-word prediction for `prompt`, including token logprobs.

  Uses the module-level `client`.

  Args:
    prompt: The word sequence sent as the user message.
    model: Chat model name passed to the API.
    top_logprobs: Number of alternatives requested per generated token.

  Returns:
    The completion object returned by `client.chat.completions.create`.
  """
  system_message = "You are a next word prediction machine. You will get a sequence of words from the user. You should return the next word in the sequence."

  completion = client.chat.completions.create(
    model=model,
    max_tokens=4,
    temperature=0,
    logprobs=True,
    top_logprobs=top_logprobs,
    messages=[
      {"role": "system", "content": system_message},
      {"role": "user", "content": prompt}
    ]
  )
  return completion

def build_prompt(node):
    """Rebuild the prompt text for `node` by walking its `parent` links up to the root.

    Tokens are joined root-first with single spaces and the result is stripped of
    leading/trailing whitespace. A node without a `parent` attribute ends the walk.
    """
    words = []
    current = node
    while current:
        # Prepend so the list ends up ordered from root to `node`.
        words.insert(0, current.token)
        current = getattr(current, 'parent', None)
    return " ".join(words).strip()

def expand_tree(node, current_depth, max_depth):
    """Recursively grow the tree below `node` until `max_depth` is reached.

    For each expanded node one completion is requested for the prompt rebuilt from
    the node's ancestry; every returned alternative becomes a child, which is then
    expanded in turn (depth-first).

    Args:
        node: Node to expand.
        current_depth: Depth of `node` for the purpose of the depth limit.
        max_depth: Nodes at this depth (or deeper) are not expanded.
    """
    if not current_depth < max_depth:
        return

    alternatives = fetch_top_logprobs(get_completion(build_prompt(node)))
    child_depth = current_depth + 1
    for alternative in alternatives:
        child = node.add_child(alternative.token, alternative.logprob)
        # Link back to the parent so build_prompt can walk up from the child.
        child.parent = node
        expand_tree(child, child_depth, max_depth)

In [29]:
# Build a demo tree rooted at "The". expand_tree issues one completion request per
# expanded node, so with three children per node this makes 1 + 3 = 4 API calls.
root_node = TreeNode("The")
max_depth = 2  # Adding 2 layers to the tree
expand_tree(root_node, 0, max_depth)

### Print logprobs tree

In [30]:
from collections import deque

def print_tree_by_level(root):
    """Print the tree breadth-first, one output line per depth level.

    Each node is rendered as "<token> (Prob: <linear_prob>%)" and the nodes of a
    level are separated by single spaces. Nothing is printed for a falsy `root`.
    """
    if not root:
        return

    level = [root]
    while level:
        print(" ".join(f"{member.token} (Prob: {member.linear_prob}%)" for member in level))
        # The next level is every child of the current level, in left-to-right order.
        level = [child for member in level for child in member.children]


In [13]:
# Assumes the tree has already been built and is rooted at `root_node`.
# Prints one line per depth level.
print_tree_by_level(root_node)

The (Prob: 100%)
quick (Prob: 57.3%) cat (Prob: 31.64%) sun (Prob: 5.17%)
brown (Prob: 99.93%) The (Prob: 0.03%) 

 (Prob: 0.02%) sat (Prob: 99.52%) ch (Prob: 0.17%) in (Prob: 0.08%) r (Prob: 47.97%) sh (Prob: 43.68%) sets (Prob: 7.59%)


# One off tree visualization

In [31]:
def collect_edges_nodes(node, edges=None, nodes=None, parent_index=None, node_index=None):
    """
    Recursively collects nodes and edges from the custom tree structure.

    Nodes are numbered in pre-order (a parent before its children), starting at 0.

    Args:
        node: The current TreeNode being processed.
        edges: The list of edges collected so far (a new list is created when omitted).
        nodes: The list of nodes collected so far (a new list is created when omitted).
        parent_index: The index of the current node's parent in the nodes list.
        node_index: A single-item list used as a mutable counter for indexing nodes
            (starts at [0] when omitted).

    Returns:
        A tuple containing two lists: nodes (as a list of labels) and edges (as a list of (parent, child) index pairs).
    """
    # Create the accumulators per top-level call. Mutable default arguments would be
    # shared between calls, so a second call would return the first call's data too.
    if edges is None:
        edges = []
    if nodes is None:
        nodes = []
    if node_index is None:
        node_index = [0]

    current_index = node_index[0]
    nodes.append(f"{node.token} ({node.linear_prob}%)")
    if parent_index is not None:
        edges.append((parent_index, current_index))

    node_index[0] += 1
    for child in node.children:
        collect_edges_nodes(child, edges, nodes, current_index, node_index)

    return nodes, edges


In [32]:
def visualize_tree_with_igraph(nodes, edges):
    """Lay the tree out with igraph's Reingold-Tilford layout and show it with Plotly.

    Args:
        nodes: Node labels; a label's list position is the vertex index used in `edges`.
        edges: (parent_index, child_index) pairs.
    """
    # Pass the vertex count explicitly: built from an edge list alone, an edgeless
    # (root-only) tree would have no vertices and the coordinate lookups below would fail.
    G = ig.Graph(n=len(nodes), edges=edges, directed=True)
    lay = G.layout('rt')  # Reingold-Tilford layout for tree structures

    position = {k: lay[k] for k in range(len(nodes))}
    Y = [lay[k][1] for k in range(len(nodes))]
    M = max(Y)

    # Node coordinates; y is mirrored as 2*M - y (presumably so the root is drawn on top).
    Xn = [position[k][0] for k in range(len(nodes))]
    Yn = [2*M - position[k][1] for k in range(len(nodes))]
    # Edge coordinates; the None after each pair keeps the segments separate in one trace.
    Xe = []
    Ye = []
    for edge in edges:
        Xe += [position[edge[0]][0], position[edge[1]][0], None]
        Ye += [2*M - position[edge[0]][1], 2*M - position[edge[1]][1], None]

    lines = go.Scatter(x=Xe, y=Ye, mode='lines',
                        line=dict(color='rgb(210,210,210)', width=1),
                        hoverinfo='none')
    dots = go.Scatter(x=Xn, y=Yn, mode='markers', name='',
                      marker=dict(symbol='circle-dot', size=18, color='lightseagreen',
                                  line=dict(color='rgb(50,50,50)', width=1)),
                      text=nodes, hoverinfo='text', opacity=0.8)

    layout = go.Layout(title='Logprobs Tree Visualization',
                       showlegend=False, xaxis={'showgrid': False, 'zeroline': False, 'showticklabels': False},
                       yaxis={'showgrid': False, 'zeroline': False, 'showticklabels': False},
                       plot_bgcolor='white', margin={'l': 0, 'r': 0, 'b': 0, 't': 30})

    fig = go.Figure(data=[lines, dots], layout=layout)
    # Draw every label on top of its node.
    annotations = []
    for i, label in enumerate(nodes):
        annotations.append(dict(x=Xn[i], y=Yn[i], xref="x", yref="y",
                                text=label, showarrow=False, font=dict(size=12),
                                bgcolor="rgba(255, 255, 255, 0.5)"))  # Semi-transparent background for readability

    fig.update_layout(annotations=annotations)

    fig.show()


In [33]:
# Assuming your tree is rooted at `root_node` and is already built.
# NOTE: this cell unpacks two values and passes two arguments, so it needs the
# collect_edges_nodes / visualize_tree_with_igraph defined in this section; the
# later redefinitions in this notebook return/take three values and replace these once run.
nodes, edges = collect_edges_nodes(root_node)
visualize_tree_with_igraph(nodes, edges)

# Build log probs tree for a given starting token or prompt

Note: each node's marker is scaled by its probability relative to the highest probability at the same depth of the tree.

### Define TreeNode class
These definitions are repeated here in case the cells in the earlier sections have not been run.

In [34]:
class TreeNode:
  """One token in the logprobs tree.

  Attributes:
    token: Text of the token this node represents.
    linear_prob: Probability of the token as a percentage (100 by default, used for the root).
    depth: Distance from the root; a node created directly defaults to depth 0.
    parent: The TreeNode this node was added to, or None when it has no parent.
    children: Child TreeNode objects, in the order they were added.
  """

  def __init__(self, token, linear_prob=100, depth=0, parent=None):
    self.token = token
    self.linear_prob = linear_prob
    self.depth = depth
    # Stored on every node so a prompt can be rebuilt by walking back to the root
    # without depending on an attribute attached from outside the class.
    self.parent = parent
    self.children = []

  def add_child(self, token, log_prob):
    """Create a child node for `token`, attach it to this node and return it.

    Args:
      token: Text of the child token.
      log_prob: Natural-log probability of the token; stored as a percentage
        rounded to 2 decimals.

    Returns:
      The newly created child TreeNode (depth is this node's depth + 1).
    """
    linear_prob = np.round(np.exp(log_prob)*100,2)
    child_node = TreeNode(token, linear_prob, self.depth+1, parent=self)
    self.children.append(child_node)
    return child_node

def fetch_top_logprobs(completion):
  """Return the alternatives reported for the first generated token of the first choice."""
  token_entries = completion.choices[0].logprobs.content
  return token_entries[0].top_logprobs

def get_completion(prompt, model="gpt-4-1106-preview", top_logprobs=3):
  """Request a next-word prediction for `prompt`, including token logprobs.

  Uses the module-level `client`.

  Args:
    prompt: The word sequence sent as the user message.
    model: Chat model name passed to the API.
    top_logprobs: Number of alternatives requested per generated token.

  Returns:
    The completion object returned by `client.chat.completions.create`.
  """
  system_message = "You are a next word prediction machine. You will get a sequence of words from the user. You should return the next word in the sequence."

  completion = client.chat.completions.create(
    model=model,
    max_tokens=4,
    temperature=0,
    logprobs=True,
    top_logprobs=top_logprobs,
    messages=[
      {"role": "system", "content": system_message},
      {"role": "user", "content": prompt}
    ]
  )
  return completion

def build_prompt(node):
    """Return the space-joined token path from the root down to `node`, stripped.

    The path is found by following `parent` attributes; a node that has no such
    attribute (or whose parent is falsy) terminates the walk.
    """
    path = []
    cursor = node
    while cursor:
        path.append(cursor.token)
        cursor = getattr(cursor, 'parent', None)
    # Tokens were gathered leaf-first; flip them so the root comes first.
    path.reverse()
    return " ".join(path).strip()

def expand_tree(node, current_depth, max_depth):
    """Depth-first expansion of `node` with the model's top alternatives.

    Stops once `current_depth` is no longer below `max_depth`. Otherwise one
    completion is requested for the prompt rebuilt from `node`, each returned
    alternative is added as a child, and each child is expanded one level deeper.
    """
    if not current_depth < max_depth:
        return

    prompt_text = build_prompt(node)
    response = get_completion(prompt_text)
    for option in fetch_top_logprobs(response):
        new_child = node.add_child(option.token, option.logprob)
        # Back-reference used by build_prompt when the child itself is expanded.
        new_child.parent = node
        expand_tree(new_child, current_depth + 1, max_depth)

### Set up functions


In [35]:
def collect_edges_nodes(node, edges=None, nodes=None, parent_index=None, node_index=None, max_probs_per_level=None, current_level=0):
    """Flatten the tree into node records and edges, tracking the max probability per level.

    Nodes are numbered in pre-order (a parent before its children), starting at 0.

    Args:
        node: The TreeNode being processed.
        edges: Accumulator of (parent_index, child_index) pairs; created when omitted.
        nodes: Accumulator of (level, label, linear_prob) tuples; created when omitted.
        parent_index: Index of this node's parent, or None for the root.
        node_index: Single-item list used as a mutable counter; starts at [0] when omitted.
        max_probs_per_level: Dict mapping level -> highest linear_prob seen; created when omitted.
        current_level: Depth of `node` relative to the node the traversal started from.

    Returns:
        The tuple (nodes, edges, max_probs_per_level).
    """
    if edges is None: edges = []
    if nodes is None: nodes = []
    # A `[0]` default would be shared between calls and keep counting, so the indices
    # of a second tree would no longer line up with its `nodes` list.
    if node_index is None: node_index = [0]
    if max_probs_per_level is None: max_probs_per_level = {}

    # Update max probability for the current level
    max_probs_per_level[current_level] = max(max_probs_per_level.get(current_level, 0), node.linear_prob)

    current_index = node_index[0]
    node_label = f"{node.token} ({node.linear_prob}%)"
    nodes.append((current_level, node_label, node.linear_prob))  # Include level and probability for scaling
    if parent_index is not None:
        edges.append((parent_index, current_index))

    node_index[0] += 1
    for child in node.children:
        collect_edges_nodes(child, edges, nodes, current_index, node_index, max_probs_per_level, current_level + 1)

    return nodes, edges, max_probs_per_level


In [36]:
def collect_edges_nodes(node, edges=None, nodes=None, parent_index=None, node_index=None, max_probs_per_level=None, current_level=0):
    """Walk the tree in pre-order and gather what the scaled visualization needs.

    Returns the tuple (nodes, edges, max_probs_per_level) where `nodes` holds
    (level, label, linear_prob) tuples, `edges` holds (parent_index, child_index)
    pairs and `max_probs_per_level` maps each level to the largest linear_prob seen.
    Accumulators that are not supplied are created fresh for the call.
    """
    edges = [] if edges is None else edges
    nodes = [] if nodes is None else nodes
    node_index = [0] if node_index is None else node_index  # fresh counter per visualization
    max_probs_per_level = {} if max_probs_per_level is None else max_probs_per_level

    probability = node.linear_prob

    # Keep the running maximum for this level.
    best_so_far = max_probs_per_level.get(current_level, 0)
    max_probs_per_level[current_level] = max(best_so_far, probability)

    # Claim the next free index for this node.
    own_index = node_index[0]
    node_index[0] = own_index + 1

    nodes.append((current_level, f"{node.token} ({probability}%)", probability))
    if parent_index is not None:
        edges.append((parent_index, own_index))

    deeper_level = current_level + 1
    for child in node.children:
        collect_edges_nodes(child, edges, nodes, own_index, node_index, max_probs_per_level, deeper_level)

    return nodes, edges, max_probs_per_level


In [37]:
def visualize_tree_with_igraph(nodes, edges, max_probs_per_level):
    """Lay the tree out with igraph and plot it with probability-scaled markers.

    Args:
        nodes: (level, label, linear_prob) tuples; list position is the vertex index.
        edges: (parent_index, child_index) pairs.
        max_probs_per_level: Dict mapping level -> highest linear_prob at that level.
    """
    # Pass the vertex count explicitly: built from an edge list alone, an edgeless
    # (root-only) tree would have no vertices and the coordinate lookups below would fail.
    G = ig.Graph(n=len(nodes), edges=edges, directed=True)
    lay = G.layout('rt')

    position = {k: lay[k] for k in range(len(nodes))}
    Y = [lay[k][1] for k in range(len(nodes))]
    M = max(Y)

    # Preparing data for Plotly
    Xn, Yn, sizes, texts = [], [], [], []
    for k, (level, label, prob) in enumerate(nodes):
        Xn.append(position[k][0])
        Yn.append(2*M - position[k][1])  # y is mirrored around M
        # Scale node size by its probability relative to the max at its level
        relative_size = 3*prob / max_probs_per_level[level] * 40  # Base size factor
        sizes.append(relative_size)
        texts.append(label)

    lines, dots = create_plotly_traces(Xn, Yn, sizes, texts, edges, position, M)
    plot_with_plotly(lines, dots)

def create_plotly_traces(Xn, Yn, sizes, texts, edges, position, M):
    """Build the two Plotly scatter traces for the tree: edge lines and node markers.

    Args:
        Xn, Yn: Node coordinates.
        sizes: Marker size per node.
        texts: Hover text per node.
        edges: (parent_index, child_index) pairs.
        position: Mapping from node index to its layout coordinates.
        M: Value the layout's y coordinates are mirrored around (2*M - y).

    Returns:
        The tuple (lines, dots).
    """
    edge_x, edge_y = [], []
    for edge in edges:
        start, end = position[edge[0]], position[edge[1]]
        # The trailing None separates this segment from the next one in the trace.
        edge_x.extend([start[0], end[0], None])
        edge_y.extend([2*M - start[1], 2*M - end[1], None])

    lines = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(color='rgb(210,210,210)', width=1),
        hoverinfo='none',
    )

    marker_style = dict(
        symbol='circle-dot',
        size=sizes,
        color='lightseagreen',
        line=dict(color='rgb(50,50,50)', width=1),
    )
    dots = go.Scatter(
        x=Xn,
        y=Yn,
        mode='markers',
        name='',
        marker=marker_style,
        text=texts,
        hoverinfo='text',
        opacity=0.8,
    )

    return lines, dots

def plot_with_plotly(lines, dots):
    """Assemble the figure from the edge and node traces, label every node and show it.

    Note that `dots` is modified in place: marker sizes are raised to a minimum of 10
    and `textposition` is set.
    """
    hidden_axis = {'showgrid': False, 'zeroline': False, 'showticklabels': False}
    layout = go.Layout(
        title='Tree Visualization with Scaled Nodes',
        showlegend=False,
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        plot_bgcolor='white',
        margin={'l': 0, 'r': 0, 'b': 0, 't': 30},
        hovermode='closest',  # hover reports the nearest point
    )

    # Give every marker a minimum size of 10 so tiny-probability nodes stay visible.
    visible_sizes = []
    for size in dots.marker.size:
        visible_sizes.append(size if size > 10 else 10)
    dots.marker.size = visible_sizes
    dots.textposition = 'bottom center'

    fig = go.Figure(data=[lines, dots], layout=layout)

    # Place each node's label as an annotation slightly above the node.
    labels = dots.text
    for idx in range(len(labels)):
        fig.add_annotation(x=dots.x[idx], y=dots.y[idx], text=labels[idx], showarrow=False, yshift=10)

    fig.show()



### Wrap in function

In [38]:
def visualize_tree(root_node):
  """Flatten the tree rooted at `root_node` and render it with scaled nodes."""
  collected = collect_edges_nodes(root_node)
  node_records, edge_pairs, level_maxima = collected
  visualize_tree_with_igraph(node_records, edge_pairs, level_maxima)

def build_tree(first_token, max_depth=2):
  """Build a logprobs tree starting from `first_token`.

  Args:
    first_token: Token (or short prompt) used as the root of the tree.
    max_depth: Number of layers to add below the root (default 2).

  Returns:
    The root TreeNode of the expanded tree.
  """
  root_node = TreeNode(first_token)
  expand_tree(root_node, 0, max_depth)
  return root_node

def get_tree_visualization(first_token, max_depth=2):
  """Build the logprobs tree for `first_token` and display it.

  Args:
    first_token: Token (or short prompt) used as the root of the tree.
    max_depth: Number of layers to add below the root (default 2).
  """
  root_node = build_tree(first_token, max_depth)
  visualize_tree(root_node)

### Call function

In [40]:
# Build and display the logprobs tree that starts from "The" (this issues live API requests).
get_tree_visualization("The")

In [39]:
# Same visualization for a different starting token.
get_tree_visualization("I")

In [41]:
# A multi-word prompt works as the root too; build_prompt strips the trailing space
# before the prompt is sent.
get_tree_visualization("What is ")

What is love, baby don't hurt me 🎶