In [10]:
import os, math
from typing import List, Dict, Tuple
from termcolor import colored
from dotenv import load_dotenv

# Load settings from a local .env file (python-dotenv) so the os.getenv()
# lookups for API_KEY / ENDPOINT further down can find them.
load_dotenv()

True

In [22]:
from openai import AzureOpenAI
# Connection settings are read from the process environment; either value is
# None when the corresponding variable is not set.
endpoint = os.environ.get("ENDPOINT")
api_key = os.environ.get("API_KEY")

# Shared Azure OpenAI client used by the cells below.
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=api_key,
    api_version="2024-12-01-preview",
)

In [None]:
def get_model_response(prompt: str, top_k: int = 5, temperature: float = 0.0, complete_sentence: bool = True, model: str = "gpt-4o-mini"):
    """
    Ask the chat model to continue `prompt` and collect per-token probabilities.

    Args:
      prompt: sentence fragment sent as the user message.
      top_k: number of alternative tokens requested per position
          (forwarded as `top_logprobs`).
      temperature: sampling temperature forwarded to the API.
      complete_sentence: when True (default), a copy of `prompt` echoed by the
          model is removed, so the returned text and token list contain only
          what follows the echo. When False nothing is removed.
      model: model / deployment name passed to the API.

    Returns:
      A 3-tuple `(completion_text, token_probs, response_text)`:
        completion_text: str, the reply with the echoed prompt (and anything
            before it) removed; equal to `response_text` when the prompt was
            not echoed or `complete_sentence` is False.
        token_probs: List[Dict], one entry per remaining token:
            {
              "selected_token": str,
              "selected_prob": float,
              "top_logprobs": List[{"token": str, "probability": float}]
            }
        response_text: str, the model's full reply ("" if the reply had no
            text content).
    """
    # Conventional ordering: the system instruction precedes the user turn.
    messages = [
        {"role": "system", "content": "You are a helpful assistant. Complete the users sentence given the context."},
        {"role": "user", "content": prompt},
    ]

    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        logprobs=True,
        top_logprobs=top_k,
    )

    choice = completion.choices[0]
    # Both fields can be absent on a reply; degrade to empty results instead of crashing.
    response_text = choice.message.content or ""
    logprob_items = (choice.logprobs.content if choice.logprobs else None) or []

    prompt_token_count = 0
    completion_text = response_text
    if complete_sentence and prompt:
        prompt_token_count = _count_prompt_tokens(
            [item.token for item in logprob_items], prompt
        )
        echo_start = response_text.find(prompt)
        if echo_start != -1:
            # Keep only what follows the echo so the text matches the token list.
            completion_text = response_text[echo_start + len(prompt):]

    token_probs = [
        {
            "selected_token": item.token,
            # The API reports natural-log probabilities; convert to plain probabilities.
            "selected_prob": math.exp(item.logprob),
            "top_logprobs": [
                {"token": alt.token, "probability": math.exp(alt.logprob)}
                for alt in (item.top_logprobs or [])
            ],
        }
        for item in logprob_items[prompt_token_count:]
    ]

    return completion_text, token_probs, response_text


def _count_prompt_tokens(tokens, prompt):
    """Return how many leading tokens are used up by the first echo of `prompt` (0 if none)."""
    if not prompt:
        # An empty prompt matches everywhere; it must not swallow the first token.
        return 0
    echo_start = "".join(tokens).find(prompt)
    if echo_start == -1:
        return 0
    echo_end = echo_start + len(prompt)
    consumed = 0
    for count, token in enumerate(tokens, start=1):
        consumed += len(token)
        if consumed >= echo_end:
            return count
    return len(tokens)

In [35]:
def show_and_write_probs(token_probs: List[Dict], tables_per_line: int = 5) -> None:
    """
    Print a compact table per generated token, several tables side by side.

    Each table has a header cell with the token index and the selected token,
    followed by one cell per alternative showing its probability in percent.
    Output goes to stdout only; nothing is written to disk.

    Colors:
      - green         -> header
      - cyan          -> the alternative that is exactly the selected token
      - light magenta -> every other alternative

    Args:
      token_probs: entries with "selected_token" (str) and "top_logprobs"
          (list of {"token": str, "probability": float}); entries may carry
          different numbers of alternatives.
      tables_per_line: how many token tables are printed next to each other.

    Raises:
      ValueError: if `tables_per_line` is smaller than 1.
    """
    if tables_per_line < 1:
        raise ValueError("tables_per_line must be >= 1")

    probs_color = "light_magenta"
    token_color = "green"
    selected_token_color = "cyan"
    cell_width = 32

    # One (header, rows) pair per token. The selected flag is decided here from
    # the raw token strings, not later by parsing the formatted text, so tokens
    # containing ":" or differing only in whitespace are handled correctly.
    tables = []
    for index, entry in enumerate(token_probs):
        selected = entry["selected_token"]
        header = f"{index:>2}: {repr(selected)[1:-1]}"  # repr makes whitespace/newlines visible
        rows = []
        for alt in entry["top_logprobs"]:
            shown = repr(alt["token"])[1:-1]
            percent = alt["probability"] * 100.0
            rows.append((f"{shown:>12}: {percent:7.2f}", alt["token"] == selected))
        tables.append((header, rows))

    # Use the tallest table so no alternative is cut off; shorter ones are padded.
    row_count = max((len(rows) for _, rows in tables), default=0)

    for start in range(0, len(tables), tables_per_line):
        group = tables[start:start + tables_per_line]
        print()  # blank line between rows of tables
        print("".join(colored(f"{header:{cell_width}}", token_color) for header, _ in group))
        for row_index in range(row_count):
            cells = []
            for _, rows in group:
                if row_index < len(rows):
                    text, is_selected = rows[row_index]
                else:
                    text, is_selected = "", False  # padding for a shorter table
                color = selected_token_color if is_selected else probs_color
                cells.append(colored(f"{text:{cell_width}}", color))
            print("".join(cells))
    print()

In [42]:
prompt = "It's a lovely day at DjangoCon US, let's go to the"

# Request a continuation at temperature 0.0 with 5 candidate tokens per position.
completion, token_probs, full_sentence = get_model_response(
    prompt,
    top_k=5,
    temperature=0.0,
    complete_sentence=True,
)

print("=== SENTENCE COMPLETION ===", completion, sep="\n")

print("\n=== TOKEN PROBABILITIES (COMPLETION ONLY) ===")
show_and_write_probs(token_probs, tables_per_line=4)

=== SENTENCE COMPLETION ===
It's a lovely day at DjangoCon US! Let's go to the keynote session to hear some inspiring talks from industry leaders. After that, we can explore the various workshops and breakout sessions to learn more about Django and network with fellow developers. Don't forget to check out the sponsor booths for some cool swag and resources! And of course, we should make time for lunch and maybe a fun evening social event to unwind and connect with others in the community. What are you most excited about at the conference?

=== TOKEN PROBABILITIES (COMPLETION ONLY) ===

[32m 0: It's                        [0m[32m 1:  a                          [0m[32m 2:  lovely                     [0m[32m 3:  day                        [0m
[36m        It's:   30.65           [0m[36m           a:   99.71           [0m[36m      lovely:  100.00           [0m[36m         day:  100.00           [0m
[95m           D:   18.59           [0m[95m       great:    0.25          