# Load Model, Dataset

In [None]:
from google.colab import drive
import json
# Mount Google Drive under /content/drive so the dataset file read by the
# following cells is reachable; force_remount=True asks Colab to redo the
# mount even if the path is already mounted.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Load the sharded dataset from the mounted Drive into `data`.
# The encoding is pinned to UTF-8 so decoding does not depend on the
# runtime's locale default.
with open("/content/drive/My Drive/sharded_dataset.json", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
# Keep only the dataset entries whose 'task' field contains "database".
DB = [task for task in data if "database" in task['task']]

# OPENAI CODE

In [None]:
from openai import OpenAI
import math

def predictive_entropy_uncertainty_chat(prompt,
                                         temperature=1.0,
                                         max_tokens=500,
                                         logprobs=True):
    """
    Send a chat conversation to the OpenAI Chat Completions API and score it.

    The score is the mean, over generated tokens, of ``-sum(p * log p)`` taken
    across the top-20 alternatives returned for each token position. Only the
    top alternatives are visible, so this is a truncated approximation of the
    true predictive entropy (the tail of the distribution is omitted).

    Args:
        prompt: Chat messages, passed unchanged as the request's ``messages``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Forwarded as ``max_completion_tokens``.
        logprobs: Whether to request per-token log-probabilities. When falsy,
            no entropy can be computed and the returned average is ``0.0``.

    Returns:
        A tuple ``(avg_entropy, generated_text, tokens_used)`` where
        ``tokens_used`` is prompt tokens plus completion tokens as reported
        in the response's usage block.
    """

    # NOTE(review): "FILLER" is a placeholder; a real key must be supplied.
    client = OpenAI(api_key="FILLER")

    request = {
        "model": "gpt-4.1",
        "messages": prompt,
        "max_completion_tokens": max_tokens,
        "temperature": temperature,
        "logprobs": logprobs,
    }
    if logprobs:
        # `top_logprobs` is only valid when `logprobs` is enabled; sending it
        # otherwise makes the API reject the request.
        request["top_logprobs"] = 20

    resp = client.chat.completions.create(**request)

    choice = resp.choices[0]
    generated_tokens = choice.message.content

    # `choice.logprobs` (or its `content`) is None when log-probabilities were
    # not requested or not returned; treat that as "no tokens to score".
    token_logprobs = (choice.logprobs.content if choice.logprobs else None) or []

    entropies = []
    for token_info in token_logprobs:
        entropy = 0.0
        for alt in token_info.top_logprobs:
            # -p * log(p) with p recovered from the reported log-probability.
            entropy -= math.exp(alt.logprob) * alt.logprob
        entropies.append(entropy)

    avg_entropy = sum(entropies) / len(entropies) if entropies else 0.0

    tokens_used = resp.usage.completion_tokens + resp.usage.prompt_tokens

    return avg_entropy, generated_tokens, tokens_used

In [None]:
import re
import os

# Instruction that prefixes every rewrite request (few-shot example and real).
_REWRITE_INSTRUCTION = (
    "I have a set of questions and/or statements, please REWRITE all the questions/statements so that they are in the most optimal order that is the easiest to understand. DO NOT ANSWER ANY OF THE QUESTIONS JUST REWRITE. Here are the instructions:\n"
)

# Statements used in the few-shot example, one per original user turn.
_EXAMPLE_STATEMENTS = (
    "Write SQL query to find active users",
    "active users are those who made at least 5 purchases",
    "each purchase is in the purchases table",
    "users are in the users table",
    "return user names only",
)


def _rewrite_request(statements):
    """Return the rewrite instruction followed by each statement on its own line."""
    return _REWRITE_INSTRUCTION + "".join(statement + "\n" for statement in statements)


def prompt_rewrite(prompt):
    """
    Build a few-shot chat prompt asking the model to rewrite a conversation's
    user turns as one consolidated request.

    Only messages whose ``role`` is ``"user"`` are taken from `prompt`; system
    and assistant turns are dropped. The result is a new list consisting of a
    system message, one worked example (user + assistant), and a final user
    message that lists the collected user turns, one per line.

    The worked example is formatted by the same helper as the real request so
    its statements are newline-separated. Previously they were fused by
    implicit string concatenation with no separator between them.

    Args:
        prompt: List of chat message dicts with ``role`` and ``content`` keys.

    Returns:
        A new list of four chat message dicts; `prompt` is not modified.
    """
    user_messages = [item["content"] for item in prompt if item.get("role") == "user"]

    return [
        {
            "role": "system",
            "content": "You are a prompt rewriter whose main goal is to rewrite the prompt given by the user in the most optimal way without losing any information in them."
        },
        {
            "role": "user",
            "content": _rewrite_request(_EXAMPLE_STATEMENTS),
        },
        {
            "role": "assistant",
            "content": "Write an SQL query to find the names of users who have made at least 5 purchases, using the users and purchases tables."
        },
        {
            "role": "user",
            "content": _rewrite_request(user_messages),
        },
    ]

In [None]:
import random

def _append_result(out_path, new_entry):
    """Append `new_entry` to the JSON list stored at `out_path`, creating it if absent."""
    if os.path.exists(out_path):
        with open(out_path, "r", encoding="utf-8") as f:
            saved = json.load(f)
    else:
        saved = []

    saved.append(new_entry)

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(saved, f, indent=2)


def with_context_reset_chat(dataset, file_path, threshold,
                            temperature=1.0, runs=1, numQ=50):
    """
    Replay each sharded task turn by turn, resetting the context whenever the
    response entropy jumps by more than `threshold` between consecutive turns.

    For every entry, shards are sent one at a time as user messages. After each
    reply the average token entropy is compared with the previous turn's; if it
    rose by more than `threshold`, the user turns so far are condensed with
    `prompt_rewrite` (rewritten at temperature 0.2), the conversation is
    restarted from the system prompt plus that single rewritten request, and
    the model is queried again. After the last shard, the final reply, chat
    history, per-turn entropies and reset count are appended to a per-run JSON
    file derived from `file_path`.

    Args:
        dataset: Iterable of entries, each with ``schema_sql`` and ``shards``;
            every shard has ``shard`` (text) and ``shard_id``. The last-shard
            check compares ``shard_id`` with ``len(shards)`` — presumably ids
            are 1-based and contiguous; verify against the dataset.
        file_path: Output path; ``".json"`` is replaced by ``"_run{i}.json"``.
        threshold: Minimum entropy increase between turns that triggers a reset.
        temperature: Sampling temperature for answer generation. The original
            code ignored this argument and always used 1.0 (still the default).
        runs: Number of full passes over `dataset`.
        numQ: Accepted for interface compatibility; currently unused.
    """

    connectors = ["oh also, ", "I just remembered, ", "sorry i forgot to say, ", "", "oh, and ", "FYI, "]
    tokens_used = 0

    for run in range(runs):
        print(f"Run {run+1}/{runs}")
        out_path = file_path.replace(".json", f"_run{run}.json")

        for entry in dataset:
            base_system = {
                "role": "system",
                "content": (
                    f"""\nYou are helping a user write SQL queries to a database. If something is not clear, you can ask the user to clarify what they need. The schema for the database being accessed is the following:\n{entry['schema_sql']}"""
                )
            }
            shards = entry["shards"]
            print(f"Question with {len(shards)} shards")
            messages = [base_system]
            # inf guarantees the first turn can never trigger a reset.
            prev_entropy = float("inf")
            resets = 0
            entropies = []
            before_reset = None
            # One connector is drawn per entry and reused for all follow-up shards.
            choice = random.choice(connectors)

            for shard in shards:
                is_last = shard["shard_id"] == len(shards)

                user_content = shard["shard"]
                if shard['shard_id'] != 1:
                    user_content = choice + user_content
                if is_last:
                    user_content += " Please include your complete new Query in your response."
                messages.append({"role": "user", "content": user_content})

                entropy, reply, tok = predictive_entropy_uncertainty_chat(messages, temperature=temperature, logprobs=True)
                print(f"Entropy: {entropy:.4f}")
                tokens_used += tok

                if entropy - prev_entropy > threshold:
                    # Keep the pre-reset conversation for the saved history.
                    before_reset = list(messages)

                    # Condense all user turns into one request; low temperature
                    # keeps the rewrite close to deterministic.
                    messages = prompt_rewrite(messages)
                    entropy, reply, tok = predictive_entropy_uncertainty_chat(messages, temperature=0.2, logprobs=True)
                    tokens_used += tok

                    # Restart from the system prompt with the rewritten request.
                    messages = [base_system, {"role": "user", "content": reply}]
                    entropy, reply, tok = predictive_entropy_uncertainty_chat(messages, temperature=temperature, logprobs=True)
                    print(f"Reset entropy: {entropy:.4f}")
                    tokens_used += tok
                    resets += 1

                prev_entropy = entropy
                entropies.append(entropy)
                messages.append({"role": "assistant", "content": reply})

                print(f"Current Tokens Used: {tokens_used}")

                if is_last:
                    if before_reset:
                        chat_history = f"{before_reset}\n\nAFTER RESET\n\n{messages}"
                    else:
                        chat_history = messages

                    new_entry = {"final_output": reply, "chat_history": chat_history, "entropies": entropies, "resets":resets}

                    # Persist after every entry so partial progress survives an
                    # interrupted run.
                    _append_result(out_path, new_entry)