### `initialize`

##### prerequisite files

In [141]:
import os
import shutil

# Ensure the staging directory for the prerequisite files exists.
# exist_ok=True makes this idempotent and avoids the check-then-create race of
# testing os.path.exists() before calling makedirs().
os.makedirs('/tmp/snippets', exist_ok=True)

In [None]:
# Print (but do not run) the shell commands that stage the prerequisite files
# from ~/Downloads/snippets into /tmp/snippets.
_prerequisite_files = (
    "actions_with_embeddings.csv",
    "actions.csv",
    "connection_schema.csv",
    "fuzz_run_search.py",
)
_copy_commands = "\n".join(
    f"cp ~/Downloads/snippets/{file_name} /tmp/snippets/"
    for file_name in _prerequisite_files
)
print(f"\n{_copy_commands}\n")

In [143]:
import numpy as np
from openai import OpenAI
import pandas as pd
import duckdb

# Module-level in-memory DuckDB connection with the catalogue CSV loaded as the
# 'metadata' table.
conn = duckdb.connect(database=":memory:")
conn.sql(
    "create table metadata as select * "
    "from read_csv_auto('/tmp/snippets/connection_schema.csv')"
)

# Embedding model name and the OpenAI client (constructed with max_retries=5)
# used by get_embedding().
EMBEDDING_MODEL = "text-embedding-ada-002"
client = OpenAI(max_retries=5)

def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    Computed as dot(a, b) / (||a|| * ||b||). No guard is applied for
    zero-length vectors, so a zero norm yields NumPy's division result (nan/inf).
    """
    dot_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return dot_product / norm_product

def get_embedding(text, model=EMBEDDING_MODEL):
    """Return the embedding vector for ``text`` from the module-level OpenAI client.

    Newlines in ``text`` are replaced with single spaces before the request.
    ``model`` defaults to EMBEDDING_MODEL.
    """
    single_line = " ".join(text.split("\n"))
    response = client.embeddings.create(input=[single_line], model=model)
    return response.data[0].embedding

def search_docs(df, user_query, top_n=3, to_print=True):
    """Rank the rows of ``df`` by embedding similarity to ``user_query``.

    Args:
        df: DataFrame with a ``saved_embedding`` column holding one vector per row.
            A ``similarities`` column is added (or overwritten) on this frame in place.
        user_query: Text to embed with EMBEDDING_MODEL and compare against each row.
        top_n: Number of highest-scoring rows to return.
        to_print: When true, render the result with ``display``; when false the
            function produces no output.

    Returns:
        The ``top_n`` rows of ``df`` sorted by descending similarity.
    """
    query_embedding = get_embedding(user_query, model=EMBEDDING_MODEL)

    # Score every row against the query embedding.
    df["similarities"] = df.saved_embedding.apply(
        lambda row_embedding: cosine_similarity(row_embedding, query_embedding)
    )

    # The previous unconditional print(df.head(10)) was leftover debugging: it
    # ignored to_print and dumped raw embedding vectors on every call.
    res = df.sort_values("similarities", ascending=False).head(top_n)
    if to_print:
        display(res)
    return res

def get_next_task(df, user_query, top_n=3, to_print=True):
    """Return the ``top_n`` rows of ``df`` most similar to ``user_query``.

    Thin wrapper that forwards every argument to search_docs().
    """
    return search_docs(df, user_query, top_n=top_n, to_print=to_print)

import ast

# Load the action catalogue together with its precomputed embeddings.
docs_df = pd.read_csv("/tmp/snippets/actions_with_embeddings.csv")
# dropna() returns a new frame, so the result must be assigned; the original call
# discarded it. Only rows without an embedding are dropped, because those are the
# rows ast.literal_eval cannot parse (NaN is not a string).
docs_df = docs_df.dropna(subset=["embedding"])
# The CSV stores each vector as its textual list form; parse it into a Python list.
docs_df['saved_embedding'] = docs_df['embedding'].apply(ast.literal_eval)

##### imports

In [144]:
# Imports
import os
import autogen

# Gemini imports 
import chromadb
from vertexai.generative_models import HarmBlockThreshold, HarmCategory
from autogen import ConversableAgent, AssistantAgent, UserProxyAgent
from typing_extensions import Annotated
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from autogen import Agent, AssistantAgent, ConversableAgent, UserProxyAgent
from autogen.agentchat.contrib.img_utils import _to_pil, get_image_data
#from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from autogen.code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
# Map every listed harm category to the same threshold (BLOCK_ONLY_HIGH).
safety_settings = {
    harm_category: HarmBlockThreshold.BLOCK_ONLY_HIGH
    for harm_category in (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
}

# Env
from dotenv import load_dotenv
load_dotenv()

project_id="sample-1474250537486"

import google.auth
# `import google.auth` alone does not load the transport.requests submodule, so
# import it explicitly; otherwise google.auth.transport.requests.Request() only
# works if some other import happened to load the submodule first.
import google.auth.transport.requests

# Obtain application-default credentials for the cloud-platform scope and refresh
# them so a valid access token is available.
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
creds, project = google.auth.default(scopes)
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)

# Token prices per 1k tokens.
# For more up-to-date prices see https://cloud.google.com/vertex-ai/generative-ai/pricing
prompt_price_per_1k = (
    0.000125
)

completion_token_price_per_1k = (
    0.000375
)


##### models

In [145]:
# Models
#
# One config list per provider. The Cerebras and Groq entries differ only in the
# model name, so each pair is produced by a small builder. Definition order is
# kept so that environment lookups happen in the same sequence as before.

def _cerebras_entry(model_name):
    """Build one Cerebras config entry for ``model_name``."""
    return {
        "model": model_name,
        "api_key": os.environ.get("CEREBRAS_API_KEY"),
        "api_type": "cerebras",
        "stream": False,
        "base_url": "https://api.cerebras.ai/v1",
        "temperature": 0.0,
    }


def _groq_entry(model_name):
    """Build one Groq config entry for ``model_name`` (requires GROQ_API_KEY)."""
    return {
        "model": model_name,
        "api_key": os.environ["GROQ_API_KEY"],
        "api_type": "groq",
        "frequency_penalty": 0.5,
        "max_tokens": 2048,
        "presence_penalty": 0.2,
        "seed": 42,
        "temperature": 0.0,
        "top_p": 0.2,
    }


cerebras_config_list = [_cerebras_entry("llama3.3-70b")]

gemini_config_list = [
    dict(
        model="gemini-2.0-flash-exp",
        api_type="google",
        project_id=project_id,
        location="us-central1",
        google_application_credentials="/Users/brian/key.json",
        api_rate_limit=1,
        safety_settings=safety_settings,
        temperature=0.0,
        max_tokens=7000,
    ),
]

oai_config_list = [
    dict(
        model="gpt-4o",
        api_key=os.environ["OPENAI_API_KEY"],
        temperature=0.0,
        stream=False,
    ),
]

groq_config_list = [_groq_entry("llama3-70b-8192")]

anthropic_config_list = [
    dict(
        model="claude-3-5-sonnet-20240620",
        api_key=os.environ["ANTHROPIC_API_KEY"],
        api_type="anthropic",
    ),
]

cerebras_small_config_list = [_cerebras_entry("llama3.3-8b")]

groq_small_config_list = [_groq_entry("llama3-8b-8192")]

ollama_config_list = [
    dict(
        model="llama3.2:1b",  # alternative: 'smollm2:1.7b'
        api_key="ollama",
        base_url="http://localhost:11434/v1",
    ),
]

# Wrap each provider list in the {"config_list": [...]} mapping that the agents
# receive as llm_config.
# Fix: the "small" Cerebras config previously wrapped cerebras_config_list (the
# 70b model), so selecting it silently used the large model.
CEREBRAS_SMALL_CONFIG_LIST = {"config_list": cerebras_small_config_list}
GROQ_SMALL_CONFIG_LIST = {"config_list": groq_small_config_list}
CONFIG_LIST = {"config_list": oai_config_list}
GEMINI_CONFIG_LIST = {"config_list": gemini_config_list}
OAI_CONFIG_LIST = {"config_list": oai_config_list}
CEREBRAS_CONFIG_LIST = {"config_list": cerebras_config_list}
GROQ_CONFIG_LIST = {"config_list": groq_config_list}
ANTHROPIC_CONFIG_LIST = {"config_list": anthropic_config_list}
OLLAMA_CONFIG_LIST = {"config_list": ollama_config_list}

##### config

In [146]:
# Config
#
# Selects the LLM configuration passed as llm_config to the agents built in the
# following cells. Swap the active line to change provider.
llm_config = GEMINI_CONFIG_LIST
#llm_config = OAI_CONFIG_LIST

### `agents + tools`

##### user_proxy, executor, job_assistant, sql_assistant, reviewer_assistant

In [147]:
# Agents

# Admin proxy: represents the human in the chat. It never asks for input
# (human_input_mode="NEVER") and has code execution disabled.
user_proxy = autogen.UserProxyAgent(
    name="Admin_User",
    human_input_mode="NEVER",
    code_execution_config=False,
    description="A team member that wants to efficiently complete tasks.",
    system_message=(
        "A human admin. Interact with the assistants to complete the tasks. "
        "A common workflow is identifying a connection name, then identifying a schema, "
        "then identifying a table, then running a dq job, then checking the results."
    ),
)

# Executor proxy: the agent the tool functions are registered on for execution.
# It never asks for human input and has code execution disabled.
executor = autogen.UserProxyAgent(
    name="Executor_User",
    llm_config=llm_config,
    human_input_mode="NEVER",
    code_execution_config=False,
    description=(
        "A computer terminal that performs no other action than running tool calls "
        "from sql_assistant or job_assistant."
    ),
    system_message=(
        "Executor. Execute the instructions and tools suggested by sql_assistant "
        "or job_assistant and report the results."
    ),
)

# Job assistant: LLM-backed agent whose prompt directs it to run DQ jobs via
# run_dq_job and check their status via get_job_status, always calling the tools
# for fresh results, and to include 'TERMINATE' once the question is answered.
job_assistant = autogen.AssistantAgent(
    name="Job_Assistant",
    llm_config=llm_config,
    system_message="""Job Assistant. You are able to run dq jobs and check dq job results/status. 
You use the run_dq_job function to run dq jobs. You use the get_job_status function to check the results, one time each time you're asked.
Because these dq job functions trigger an API call, the user will want current (up to date) results, you should execute the functions rather than rely on the historical context. 
Even if job was recently run, you can run it one more time if the user requests this. 
Example: run a dq job for tables w/ 'xyz' in the name (run dq jobs, you need a connection_name, dataset, query)
Example: whats the status of the dq jobs (check dq job results/status, no arguments needed)
Include the word 'TERMINATE' and summarize the answer, if you can answer the question from the context, rather than asking for more tasks.""",
)

# SQL assistant: LLM-backed agent whose prompt restricts it to querying the
# 'metadata' table (columns connection_name, schema_name, table_name) and to
# include 'TERMINATE' once it can answer.
# NOTE(review): the prompt contains typos ("answer the most user questions",
# "clarificaiton"); left untouched here because the prompt text is runtime input
# to the model.
sql_assistant = autogen.AssistantAgent(
    name="SQL_Assistant",
    llm_config=llm_config,
    system_message="""SQL Assistant. You provide instructions to the executor to run a query on the 'metadata' table, in order to answer the most user questions.
REMEMBER: 
- Only query the 'metadata' table.
- Only use the columns 'connection_name', 'schema_name', 'table_name'
- 'connections' typically refers to the 'connection_name' column
- 'schema' typically refers to the 'schema_name' column
- 'table' typically refers to the 'table_name' column
- single quote for escape characters
- focus on the immediate task, don't deviate from the task
If the question is unclear, try a distinct or a limit 10 query to narrow down the results. 
If you're unsure, you can try 1 attempt to interpret the question (best guess). 
If you're still not sure, or there are no results after trying a query, ask for clarificaiton.
EXAMPLE_PROMPT: use sql, count total number of tables in this schema
EXAMPLE QUERY: select * from metadata where connection_name = '<CONNECTION_NAME>' and schema_name = '<SCHEMA>' and table_name like '%<SEARCH_STRING>%' limit 30

IMPORTANT: If you can clearly answer the question, include the word 'TERMINATE' and summarize the answer, don't respond with the same query and don't ask to help with more tasks.""",
    # Upper bound on this agent's consecutive automatic replies.
    max_consecutive_auto_reply=4,
)

# Reviewer assistant: LLM-backed agent whose prompt asks it to summarize the
# question and the other assistants' results, without tool calls, and to include
# 'TERMINATE' with the summary.
# NOTE(review): the prompt misspells "assisant" and refers to planner_assistant,
# which the notebook heading marks as "not needed" - confirm the wording is still
# intended.
reviewer_assistant = autogen.AssistantAgent(
    name="Reviewer_Assistant",
    system_message="""Review the prompt and results to concisely answer the question. Summarize the initial question and answer so it's easy to understand. 
Use bullet points or lists if the sql or job assistant returns multiple results.
As the reviewer assisant, you do not include actions, arguments, or tool calls. Only the planner_assistant includes that information.
Include the word 'TERMINATE' with the summarized response, rather than asking for more tasks.""",
    llm_config=llm_config,
)

##### next steps assistant

In [111]:
# Next-steps agent: prompted to read recent conversation snippets, report the most
# recently mentioned connection / schema / table / dataset, and suggest likely
# next actions, following the few-shot examples embedded in the prompt.
# Unlike the other agents in this notebook it is pinned to OAI_CONFIG_LIST
# instead of the shared llm_config.
next_steps_agent = autogen.ConversableAgent(
    name="next_steps_agent",
    system_message="""You review conversation snippets to identify 2 key pieces of information and suggest actions for a user.

Because the workflow follows a hierarchy of connection > schema > table > job / dataset, you should provide what most recent context is inferred, likely next steps, and missing context.
Respond concisely, like the examples below to help guide the user. 

Example Responses:
    # Example 0, No elements in context.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = none
     - schema_name = none 
     - next steps: list connections, get job results
     Note: These actions can be run at any point in the workflow. As for getting started, the most likely first step is to list distinct available connections.

    # Example 1, Only connection_name recently mentioned.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = 'xyz' 
     - schema_name = none 
     - next steps: search for a schema, count tables in each schema
     Note: Once you've identified a connection_name, the most common next step is to find a schema or analyze schemas in a connection.

    # Example 2, connection_name and schema_name recently mentioned.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = 'xyz' 
     - schema_name = 'xyz' 
     - next steps: search for table(s)
     Note: Once you've identified a connection_name and schema_name combination, the most common next step is to search for tables in a schema.

    # Example 3, table_name(s) recently mentioned.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = none
     - schema_name = none
     - table_name(s) = claims, nyse 
     - next steps: run a dq job for one or several tables
     Note: Once you've identified a table_name or several table_names, the most common next step is to run a dq job for one or several tables.

    # Example 4, recent jobs have been run.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = xyz
     - schema_name = xyz
     - table_name(s) = xyz
     - dataset(s) = xyz
     - next steps: check job results
     Note: Once you've run a job, the most common next step is to check the results of a dq job. Or you can always run the job again or start a new search.

    While this is a simple workflow, you can always start a new search or run a job again. The user may not work linearly.
      """,
    llm_config=OAI_CONFIG_LIST,
)

##### planner and context assistants (not needed)

In [112]:
# Planner assistant: prompted to answer from chat context when possible, otherwise
# to emit a task / action / arguments triple for the SQL or job assistant, using
# the few-shot examples embedded in the prompt.
# NOTE(review): the second example is headed "# Answer found in the context" but
# shows a sql look-up, which contradicts its heading; the "what schemas are in the
# bigquery connection?" prompt also appears twice with different actions. Confirm
# the intended examples before relying on this agent.
planner_assistant = autogen.AssistantAgent(
    name="Planner_Assistant",
    llm_config=llm_config,
    system_message="""You're an expert at finding the information you need from the context or guiding the user on additional information. 
First check if the answer is found in the context, for key details.
If you can't find the information in the chat context, then respond with the key details and action, depending on the type of task.

# Answer found in the context (no action needed)
Example: 
- prompt: whats the 'xyz' connection name?
- context: 'the connection name 'xyz' was found in the recent chat context.'  

# Answer found in the context
Example: 
- prompt: whats the 'xyz' connection name?
- task: sql 
- action: look-up distinct connection_names in metadata table
- arguments: none 

# Answer not found in the context, send to assistant to get more information
Example:
- prompt: what schemas are in the bigquery connection?
- task: sql 
- action: look-up connection_names in metadata table
- arguments: none 

# Send to jobassistant to perform this task
Example:
- prompt: whats the latest job results?
- task: job status  
- action: look-up job_status
- arguments: none 

# Answer not found in the context, send to sql assistant to get more information
Example:
- prompt: what schemas are in the bigquery connection?
- task: sql 
- action: look-up schema_names, by connection_name in metadata table
- arguments: connection_name 

# Answer not found in the context, send to sql assistant to get more information
Example:
- prompt: list tables w/ 'xyz' in the name
- task: sql 
- action: look-up table_names, by connection_name and schema_name in metadata table
- arguments: connection_name, schema_name

# Send to job assistant to perform this task
Example:
- prompt: run a dq job for tables w/ 'xyz' in the name, in the 'xyz' schema, use the previous connection
- task: run job  
- action: run a dq job for a table, include connection_name and schema_name
- arguments: connection_name, schema_name, table_name

These are examples of the initial analysis
When a generic technology is mentioned, it typically just refers to a connection_name, try to find a connection_name 
If you can answer the question from the chat history, go ahead and answer the question, and include the word 'TERMINATE'.
If you should send the question to the sql or job assistant, respond accordingly with the task, action, and arguments
Or if you can't find the information you need, ask for clarification.
NEVER: guess the answer, or make up an answer. use facts from the chat history or send to the sql or job assistant."""
)

# Context agent: prompted to extract the most recently mentioned connection /
# schema / table / dataset from conversation snippets and to never emit
# 'TERMINATE'. Its few-shot examples mirror those of next_steps_agent with
# slightly different notes.
context_agent = autogen.ConversableAgent(
    name="context_agent",
    system_message="""You never include the word 'TERMINATE' in your response. 
You review conversation snippets to identify 2 key pieces of information for the task at hand.

Because the workflow follows a hierarchy of connection > schema > table > job / dataset, you must identify key elements from the most recent context.
Respond concisely. 

Example Responses:
    # Example 0, No elements in context or not context at all..
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = none
     - schema_name = none 
     - next steps: list distinct connections, or maybe get job results
     Note: The most common first step is to list distinct available connections.

    # Example 1, Only connection_name recently mentioned.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = 'xyz' 
     - schema_name = none 
     - next steps: search for a schema, count tables in each schema
     Note: Once you've identified a connection_name, a common next step is to search for a schema (exact or substring of the name).

    # Example 2, connection_name and schema_name recently mentioned.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = 'xyz' 
     - schema_name = 'xyz' 
     - next steps: search for table(s)
     Note: Once you've identified a connection_name and schema_name combination, a next step is to search for tables (exact or substring of the name) in a schema.

    # Example 3, table_name(s) recently mentioned.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = none
     - schema_name = none
     - table_name(s) = claims, nyse 
     - next steps: run a dq job for one or several tables
     Note: Once you've identified a table_name or several table_names, a common next step is to run a dq job for one or several tables.

    # Example 4, recent jobs have been run.
    Based on this recent conversation history.
    The most recently mentioned context includes:
     - connection_name = xyz
     - schema_name = xyz
     - table_name(s) = xyz
     - dataset(s) = xyz
     - next steps: check job results
     Note: Once you've run a job, a common next step is to check the results of a dq job.
      """,
    llm_config=llm_config,
)


#### tools

In [None]:
# Tools
@executor.register_for_execution()
@executor.register_for_llm()
@job_assistant.register_for_llm(description=f"""Submits a DQ Job to run a DQ check.
IMPORTANT: This requires a connection_name, dataset, query.
NOTE: Use the table name as the dataset name
The DQ Job query should use schema.table format and have a limit 10000 to always limit results. 
""")
def run_dq_job(dataset: Annotated[str, "dataset name to use"], query: Annotated[str, "query to use"], connection_name: Annotated[str, "connection name to use"]):
    """Validate the target table against the metadata catalogue, then submit a DQ job.

    Args:
        dataset: Dataset name; also used as the table name for the catalogue check.
            The job is submitted under the name "AI_" + dataset.
        query: Source query for the job. The schema is taken from the fourth
            space-separated token (the "schema.table" position in
            "select * from schema.table ...").
        connection_name: Connection used for both the catalogue check and the job.

    Returns:
        A human-readable status string: a validation failure message, a submission
        failure message (non-200 response), or a success message that includes the
        query and the API's JSON response.

    Environment: DQ_URL, DQ_USERNAME, DQ_CREDENTIAL, DQ_TENANT.
    """
    import json
    from datetime import datetime

    import requests
    import urllib3

    # NOTE(review): TLS verification is disabled for every request below
    # (verify=False) and the matching warning is silenced - confirm this is
    # acceptable for the target environment.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # Bind both names before the try block so the except handler can always format
    # its message; previously an early failure raised NameError there instead.
    table_name = dataset
    schema_name = None
    try:
        import duckdb

        schema_name = query.split(' ')[3].split('.')[0].strip()

        lookup = duckdb.connect(database=':memory:')
        try:
            lookup.sql("create table metadata as select * from read_csv_auto('/tmp/snippets/connection_schema.csv')")
            # Bound parameters keep model-supplied values out of the SQL text.
            # Backslashes are stripped from the values, as the original did on the
            # assembled statement.
            matches = lookup.execute(
                "select * from metadata where connection_name = ? and schema_name = ? and table_name = ? limit 1",
                [value.replace("\\", "") for value in (connection_name, schema_name, table_name)],
            ).fetchall()
        finally:
            lookup.close()

        if not matches:
            return f"Unable to find table {table_name} in schema {schema_name} for connection {connection_name}. Please confirm the schema and connection and try again."
    except Exception as e:
        print(e)
        return f"Unable to validate and table {table_name} in schema {schema_name} for connection {connection_name}. Please confirm the schema and connection and try again."

    def get_api_token():
        """Sign in to the DQ API and return the bearer token."""
        prod_auth = requests.post(os.environ.get("DQ_URL") + '/v3/auth/signin',
                            headers={'Accept-Language': 'en-US,en;q=0.9', 'Connection': 'keep-alive', 'Content-Type': 'application/json'},
                            data=json.dumps({'username': os.environ.get("DQ_USERNAME"), 'password': os.environ.get("DQ_CREDENTIAL"), 'iss': os.environ.get("DQ_TENANT")}),
                                verify=False)
        return prod_auth.json()["token"]

    token = get_api_token()
    headers = {'Authorization': 'Bearer ' + token}
    run_id = datetime.now().strftime("%Y-%m-%d")
    params = {'dataset': "AI_"+dataset,
              'runId': run_id,
              'pushdown': {'connectionName': connection_name,
                           'sourceQuery': query.replace("\\", "") },
              'agentId': {'id': 0}
              }
    response = requests.request('POST', os.environ.get("DQ_URL") + '/v2/run-job-json', headers=headers, json=params, verify=False)

    # Report success only when the API accepted the job; previously any status
    # code produced the "triggered successfully" message.
    if response.status_code != 200:
        return f"Unable to trigger job for {table_name} in schema {schema_name} for connection {connection_name}. status: {response.status_code} response: {response.text}"

    payload = response.json()
    print(payload)
    return f"job triggered successfully, for {table_name} in schema {schema_name} for connection {connection_name}. response: {query} \n\n {payload}"

@executor.register_for_execution()
@executor.register_for_llm()
@job_assistant.register_for_llm(description=f"""check status""")
def get_job_status():
    """Return the most recent DQ job queue entries as a markdown table string.

    Signs in to the DQ API, requests up to 5 queue entries from /v2/getowlcheckq
    and renders the dataset / status / activity columns, wrapped in ~~~ markers.

    Environment: DQ_URL, DQ_USERNAME, DQ_CREDENTIAL, DQ_TENANT.
    """
    import requests
    import json
    import urllib3
    import pandas as pd
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    base_url = os.environ.get("DQ_URL")

    # Sign in and build the bearer-token header.
    signin_body = json.dumps({
        'username': os.environ.get("DQ_USERNAME"),
        'password': os.environ.get("DQ_CREDENTIAL"),
        'iss': os.environ.get("DQ_TENANT"),
    })
    signin = requests.post(
        base_url + '/v3/auth/signin',
        headers={'Accept-Language': 'en-US,en;q=0.9', 'Connection': 'keep-alive', 'Content-Type': 'application/json'},
        data=signin_body,
        verify=False,
    )
    auth_headers = {'Authorization': 'Bearer ' + signin.json()["token"]}

    # Fetch the latest queue entries (no status filter, limit 5).
    queue = requests.get(
        base_url + '/v2/getowlcheckq',
        params={'jobStatus': '', 'limit': '5'},
        headers=auth_headers,
        verify=False,
    )

    # Other columns present in the original commented-out variant: 'runId', 'updtTs'.
    jobs = pd.DataFrame(queue.json()['data'], columns=['dataset', 'status', 'activity'])
    table = jobs.to_markdown(index=False, tablefmt='presto', floatfmt='.0%').replace('000','')
    return f"job status: ~~~{table}~~~"

@executor.register_for_execution()
@executor.register_for_llm()
@sql_assistant.register_for_llm(description=f"""Runs A SELECT statement on 'metadata' table. the columns 'connection_name', 'schema_name', 'table_name'""")
def run_sql_statement(sql_statement: Annotated[str, "SQL statement to execute"]):
    """Run ``sql_statement`` against a fresh in-memory 'metadata' table.

    The table is rebuilt from /tmp/snippets/connection_schema.csv on every call.
    Backslashes are stripped from the statement before execution. The result is
    shown, then returned as a ~~~-wrapped string containing the original statement
    and up to 15 de-duplicated rows in markdown.

    NOTE(review): the statement is executed as given; nothing here enforces the
    SELECT-only contract stated in the tool description.
    """
    import duckdb

    metadata_db = duckdb.connect(database=':memory:')
    metadata_db.sql("create table metadata as select * from read_csv_auto('/tmp/snippets/connection_schema.csv')")

    unescaped_sql = sql_statement.replace("\\","")
    relation = metadata_db.sql(unescaped_sql)
    relation.show()

    preview = relation.to_df().drop_duplicates().head(15).to_markdown(index=False)
    return f"results: ~~~{sql_statement} \n {preview}~~~"


#### transform

In [114]:
# Transform

from autogen.agentchat.contrib.capabilities import transform_messages, transforms
import json
import re
# Non-greedy regex for a "~~~ ... ~~~" wrapped span (the tool functions in this
# notebook wrap their results in ~~~ markers). Without re.DOTALL it does not match
# across newlines. NOTE(review): not referenced elsewhere in the visible cells;
# MessageRedact keeps its own copy of the pattern.
pattern = r"~~~.*?~~~"
import pprint
import copy
import re
from typing import Dict, List, Tuple

class MessageRedact:
    """Message transform that flattens tool traffic into user text and trims content.

    For every message (working on a deep copy, so the caller's list is untouched):

    * ``tool_calls`` are replaced by a plain-text summary of each function name and
      its arguments, and the message role becomes "user";
    * ``tool_responses`` are replaced by a plain-text dump of each response's
      content, and the message role becomes "user";
    * string content is truncated (a shorter limit for the first few string
      messages, a longer one afterwards);
    * in list content, ``~~~...~~~`` spans inside text items are replaced with the
      replacement marker.

    A closing user message is always appended to the returned list.
    """

    def __init__(self, early_message_count: int = 4, early_char_limit: int = 600, late_char_limit: int = 1000):
        """Configure truncation.

        Args:
            early_message_count: Number of leading string messages that get the
                shorter limit.
            early_char_limit: Character limit for those leading messages.
            late_char_limit: Character limit for all later string messages.
        """
        self._content_wrapper_pattern = r"~~~.*?~~~"
        self._replacement_string = "TRUNCATED_MESSAGE"
        self._early_message_count = early_message_count
        self._early_char_limit = early_char_limit
        self._late_char_limit = late_char_limit

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Return a transformed deep copy of ``messages`` (see the class docstring)."""
        temp_messages = copy.deepcopy(messages)

        string_messages_seen = 0
        for message in temp_messages:
            # NOTE(review): this prefix is also emitted ahead of tool *responses*;
            # kept as is so the text sent to the model does not change.
            content_string = "Context from previous tool calls: "

            if "tool_calls" in message:
                for call in message.pop("tool_calls") or []:
                    content_string += "\n function name: " + call['function']['name']
                    content_string += "\n function arguments: " + json.dumps(call['function']['arguments'])
                message['role'] = "user"
                message["content"] = content_string

            if "tool_responses" in message:
                for response in message.pop("tool_responses") or []:
                    content_string += "Context from previous tool response:" + json.dumps(response['content'])
                message["content"] = content_string
                message['role'] = "user"

            # .get(): a message without a "content" key is passed through untouched
            # instead of raising KeyError.
            content = message.get("content")
            if isinstance(content, str):
                string_messages_seen += 1
                if string_messages_seen <= self._early_message_count:
                    limit = self._early_char_limit
                else:
                    limit = self._late_char_limit
                message["content"] = content[:limit]
                # The former nested "tool_responses" redaction here was unreachable:
                # that key has always been popped by this point.
            elif isinstance(content, list):
                for item in content:
                    # Same guard as _count_redacted: skip non-dict items (e.g. plain
                    # strings) and items that carry no text.
                    if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
                        item["text"] = re.sub(self._content_wrapper_pattern, self._replacement_string, item["text"], flags=re.DOTALL)

        temp_messages.append({'content': 'Thank you. I will review this information and get back to you.', 'role': 'user', 'name': 'Admin_User'})
        return temp_messages

    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
        """Return (log text, had_effect) from the change in redaction-marker count."""
        keys_redacted = self._count_redacted(post_transform_messages) - self._count_redacted(pre_transform_messages)
        if keys_redacted > 0:
            return f"Redacted {keys_redacted} Matching Patterns.", True
        return "", False

    def _count_redacted(self, messages: List[Dict]) -> int:
        """Count messages / text items whose content contains the replacement marker."""
        count = 0
        for message in messages:
            content = message.get("content")
            if isinstance(content, str):
                if self._replacement_string in content:
                    count += 1
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, dict) and isinstance(item.get("text"), str):
                        if self._replacement_string in item["text"]:
                            count += 1
        return count

# redact_handling = transform_messages.TransformMessages(transforms=[MessageRedact()])
# redact_handling.add_to_agent(user_proxy)

# Limit the message history to the 10 most recent messages
# context_handling = transform_messages.TransformMessages(
#     transforms=[
#         transforms.MessageHistoryLimiter(max_messages=10),        
#     ]
# )
# context_handling.add_to_agent(user_proxy)


### `prompts`

In [115]:
#prompt = """list the distinct connections""" 
#prompt = "list distinct connections" 
#prompt = "whats the bigquery connection name?"
#prompt = "is there a samples schema in the bigquery connection?"
#prompt = "list the distinct schemas in the bigquery connection"
#prompt = "count the distinct tables in the bigquery connection"

#prompt = """
# in that bigquery connection
# in the same schema
# list tables w/ 'nyse' in the name
# """

#prompt="""run a job for those 3 tables"""
#prompt="can you run a job for those tables (the first 4 listed)?"
#prompt = "list the tables listed w/ claims in the name, use bigquery connection, samples schema"
#prompt = "run a dq job for the claims dents and claims detail table" 
#prompt = "run a dq job for the all the claims tables listed in that connection and schema" 

#prompt="""what schema am i using?"""
#prompt="""get the latest job results"""
#prompt="""what connections do i have?"""
#prompt="""what schemas are in that bigquery connection?"""

#prompt = "in the sql server connection, what tables are in the example schema with geo_uip_stat in the name"
#prompt = "are there tables with 'customer' in the name, in the tf_esg schema in the snowflake connection?"
#prompt = "using the connecton named bigquery and the samples schema, what tables have census in the name?"
#prompt = "do all upper case, in the snowflake connection not sap, what tables are in the 'PUBLIC' schema with CUSTOMER in the name"
#prompt = "run a job for those CUSTOMER tables listed"
#prompt = "check the sql server connection, not sap, what tables are in the 'dbo' schema with accounts as part of the name?"
#prompt = "check job status"
#prompt = "check the samples schema, in the bigquery connection, tables with 'claims' in he table name. "
#prompt = "run a  job, use the samples schema, in the bigquery connection, run a job the last 3 listed w/ with 'claims_dent' in the table name. "
#prompt = "run a job for the claims_dent table listd, same bigquery connection" 
#prompt = "list the available connection names"
#prompt = "can you get a dq job results? "
#prompt = "run a dq job for the loan_customer tables, bigquery connection, samples schema"
#prompt = "run a dq job for hte samples schema, bigquery connection, claims_dent in the name"
#prompt = "list the distinct schemas in the bigquery connection"
#prompt = "list distinct connections" 
#prompt = "check job status"

# NOTE: every live `prompt = ...` assignment in this group is overwritten by a
# later assignment in the same cell, so none of these values reaches the chat.
# They remain as a scratch list of example prompts; the commented lines are
# alternatives that were tried.
prompt = """ run a job for these items
- in the samples schema, in the bigquery connection, tables with 'claims_dents_4' and claim details in the name. 
"""
#prompt = "whats the bigquery connection name?"
#prompt = "whats the status of that customer job"
#prompt = "list claims tables, samples schema, use the bigquery connection"
#prompt = "what were the names of the tables were just mentioned?"
#prompt = "what was the bigquery connection and schema mentioned?"
#prompt = "list the tables w/ patent in the name, use bigquery connection, samples schema"
#prompt = "run a job for those tables just listed."

prompt = "list the tables listed w/ claims in the name, use bigquery connection, samples schema"
#prompt = "list the tables listed w/ claims in the name, use previous connection and schema" 
#prompt = "list the tables listed w/ customer in the name, same connection and schema" 
#prompt = "list the tables listed w/ patent in the name, same connection and schema" 
#prompt = "list the tables listed w/ claims in the name, use bigquery connection, samples schema"

prompt = "run a dq job for the claims dents and claims detail table" 
#prompt = "what tables did you last mention? "
#prompt = "run a dq job for those tables"
#prompt = "run a dq job for those tables w/ customer, use bigquery connection, samples schema"
#prompt = "is there a  bigquery connection?"
#prompt = "whats the name of the bigquery connection?"
#prompt = "what connections are there?"
#prompt = "what schemas are in that bigquery connection"
#prompt = "use the samples schema, any tables that start with 'm'"

prompt = "is there a samples schema in the bigquery connection?"
#prompt = "are there tables taht start with 'g' in that schema? "
#prompt = "run a job for global_air_quality table "
#prompt = "what are the latest results?"
# prompt = "are there tables taht start with 'k' in that schema? "
# prompt = """run a job for those tables listed
# | APPROVED_BIGQUERY_PUSHDOWN | samples       | kb_participant_loc    |
# | APPROVED_BIGQUERY_PUSHDOWN | samples       | kishore               |
# | APPROVED_BIGQUERY_PUSHDOWN | samples       | kb_participant_shirt  |"""

#prompt = "what tables did you last mention? "
#prompt = "check job status"

# connection
prompt="list the distinct connections available"
prompt="""what connections do i have?"""
prompt = "whats the bigquery connection name?"
prompt="""whats the name of the bigquery connection? it's not named bigquery, check for more info"""

# schema
prompt="""what schemas are in that bigquery connection?"""
prompt="""what schema was i using? """
prompt="is there a schema w/ 'samples' in the name in that connection"
prompt="""what connection was i just using? """

# table
prompt="""
in that bigquery connection
in the same schema
list tables w/ 'nyse' in the name
"""
prompt="list first 5 tables that start with the letter 'd' "  

# job
prompt="""run a job for those tables just mentioned"""
prompt="""
in the connection APPROVED_BIGQUERY_PUSHDOWN
in the 'samples' schema
run a dq job for those 5 tables
"""

# results
prompt="""get the latest job results"""

### `chat`

##### groupchat

In [116]:
# Groupchat
# Speaker graph: each key maps an agent to the agents allowed to speak right after it.
allowed_transitions = {
    user_proxy: [sql_assistant, job_assistant, executor, reviewer_assistant],
    sql_assistant: [executor, user_proxy],
    job_assistant: [executor, user_proxy],
    executor: [reviewer_assistant, user_proxy],
    reviewer_assistant: [user_proxy],
}

groupchat = autogen.GroupChat(
    agents=[user_proxy, sql_assistant, job_assistant, executor, reviewer_assistant],
    messages=[], max_round=5,
    speaker_transitions_type="allowed",
    allowed_or_disallowed_speaker_transitions=allowed_transitions,
    send_introductions=True
)

# A message may carry `content: None`; `"TERMINATE" in None` would raise TypeError,
# so a missing/None content is treated as an empty string.
manager = autogen.GroupChatManager(
    groupchat=groupchat,
    name="manager",
    llm_config=llm_config,
    is_termination_msg=lambda x: "TERMINATE" in (x.get("content") or ""),
)

In [117]:
# functions for chat and recent history

def chat(manager, prompt, groupchat):
    """Resume the group chat from a trimmed, redacted history and send a new prompt.

    The manager's stored messages are deep-copied, introduction messages are
    dropped, only the last few messages are kept and passed through
    `MessageRedact`, and a fresh introduction message is put back at the front
    before the manager is resumed.

    Args:
        manager: group chat manager; its `groupchat.messages` is the history.
        prompt: text of the new user message.
        groupchat: group chat used to build the introductions message.

    Returns:
        Tuple `(chat_result, manager)`.

    Note:
        Relies on the notebook-level names `user_proxy`, `copy` and `MessageRedact`.
    """
    retain_messages = 5
    intro_string = 'We have assembled a great team today'
    intro_message = {'content': groupchat.introductions_msg(), 'role': 'user', 'name': 'manager'}

    history = copy.deepcopy(manager.groupchat.messages)
    # Filter into a new list: removing items from a list while iterating over it
    # skips the element that follows each removal. Content may be None, so it is
    # normalised to '' before the substring test.
    history = [m for m in history if intro_string not in (m.get('content') or '')]

    # A negative slice keeps everything when fewer than `retain_messages` remain.
    recent_messages = history[-retain_messages:]

    processed_messages = MessageRedact().apply_transform(recent_messages)
    processed_messages.insert(0, intro_message)

    manager.resume(messages=processed_messages)
    # NOTE(review): `max_rounds` is forwarded unchanged from the original code;
    # confirm initiate_chat actually honours it (AutoGen documents `max_turns`).
    chat_result = user_proxy.initiate_chat(recipient=manager, message=prompt, clear_history=False, max_rounds=7)
    return chat_result, manager

def get_recent_context(manager):
    """Concatenate the text of the group chat messages into one string.

    Messages whose content is empty or None are skipped, as are messages
    containing "Hello everyone." (the introduction text). Contents are joined
    with no separator, in message order.

    Args:
        manager: object exposing `groupchat.messages`, a list of message dicts.

    Returns:
        The concatenated message contents.
    """
    pieces = []
    for message in manager.groupchat.messages:
        # Content can be None; treat it like an empty message instead of
        # letting the substring test raise TypeError.
        content = message.get('content') or ''
        if content != '' and "Hello everyone." not in content:
            pieces.append(content)
    return "".join(pieces)

#sql_assistant.llm_config['config_list']
#sql_assistant.llm_config["tools"]

##### initiate chat

In [None]:
# Start the conversation: the user proxy sends the opening prompt to the manager
# and keeps any existing history.
prompt = "list the distinct connections"

result = user_proxy.initiate_chat(manager, message=prompt, clear_history=False)

##### resume chat

In [None]:
# Follow-up turn: this prompt is sent through chat(), which resumes the manager
# from its stored messages before delivering the new prompt.
prompt = """
in the connection APPROVED_BIGQUERY_PUSHDOWN
check the 'samples' schema
list first 3 tables w/ 'census' in the name
"""

# chat() returns (chat_result, manager); `manager` is rebound to the returned value.
cr, manager = chat(manager, prompt, groupchat)

##### post chat, next steps

In [None]:
# Header for the summarisation request; the chat transcript is appended under it.
instruction_msg = """Given this conversation snippet below, review and analyze and respond with most recent context. Include the names of the specific elements mentioned.
If multiple items are mentioned, respond with the most recent context i.e. towards the bottom. 

Conversation:
"""
prompt = """
in the connection APPROVED_BIGQUERY_PUSHDOWN
in the 'samples' schema
list tables with 'claims' in the name
"""

# Run one more turn, then collect the transcript text from the manager.
cr, manager = chat(manager, prompt, groupchat)
recent_conversation = get_recent_context(manager)

# Header and transcript separated by a single newline.
instruction = "\n".join([instruction_msg, recent_conversation])

reply = next_steps_agent.generate_reply(
    messages=[{'content': instruction, 'role': 'user'}],
)
print(reply)


### `chat with pre-routine action look-up (RAG) + context`

##### config

In [148]:
# Provider selection: CONFIG_LIST is what the parsing agents below receive as
# `llm_config`. Keep exactly one assignment active; the others are alternatives.
#CONFIG_LIST = OAI_CONFIG_LIST
#CONFIG_LIST = GROQ_SMALL_CONFIG_LIST
#CONFIG_LIST = CEREBRAS_SMALL_CONFIG_LIST
#CONFIG_LIST = GROQ_CONFIG_LIST
CONFIG_LIST = GEMINI_CONFIG_LIST

##### agents

In [149]:
# System prompt for the connection parser: maps a keyword in the user's text to
# one of four fixed connection names, or the sentinel 'NONE'.
# NOTE(review): the last two lines ("don't guess" vs. "use your best guess")
# pull in opposite directions — confirm which behaviour is wanted.
connection_parsing_prompt = """You're an expert at identifying the connection name from the text.

INSTRUCTIONS:
You must respond with one of the following valid connection names: 'APPROVED_SAPHANA_PD', 'APPROVED_SNOWFLAKE_PUSHDOWN', 'APPROVED_BIGQUERY_PUSHDOWN', 'APPROVED_SQL_SERVER_PD', or 'NONE':
 - if text contains 'saphana', respond with the word 'APPROVED_SAPHANA_PD'
 - if text contains 'snowflake', respond with the word 'APPROVED_SNOWFLAKE_PUSHDOWN'
 - if text contains 'bigquery', respond with the word 'APPROVED_BIGQUERY_PUSHDOWN'
 - if text contains 'sql server', respond with 'APPROVED_SQL_SERVER_PD'
 - else if you're not sure, respond with 'NONE'

IMPORTANT: Respond with 'NONE' if you're not sure, don't guess, do not make up names.
ALWAYS: Respond with just one word
ALWAYS: Respond with a valid connection name, only use your best guess from the text.
"""

# Agent that answers with the connection name only; no code execution.
# NOTE(review): test_sql treats only "" as "no value", so a 'NONE' reply is
# queried against the metadata table as if it were a connection name.
connection_parsing_assistant = autogen.AssistantAgent(
    name="Parsing_agent",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=3,
    llm_config=CONFIG_LIST,
    system_message=connection_parsing_prompt,
    code_execution_config=False
)

# System prompt for the schema parser: asks for a one-word schema name taken
# from the user's message, or the sentinel 'NA' when unsure.
# NOTE(review): the sentinel here is 'NA' while the connection prompt uses
# 'NONE'; test_sql treats only "" as "no value".
schema_parsing_prompt = """You help the human identify the inferred SCHEMA name. 
You must parse out the name that represents the schema they want to use.

You only care about finding the SCHEMA name
IMPORTANT: Ignore connection entity references
IMPORTANT: Ignore table entity references

Always Respond with one word (the name you believe to be correct). 

Example: In the snowflake connection, using the public schema, what tables have xyz in the name.
Response: public

REMEMBER: Just one word answer, no explanations. 
ALWAYS: Respond with a valid schema name, only use your best guess from the message, do not make up names.
Response with 'NA' if you're not sure, don't guess. 
"""

# Agent that answers with the schema name only; no code execution.
schema_parsing_assistant = autogen.AssistantAgent(
    name="Schema_Parsing_agent",
    llm_config=CONFIG_LIST,
    system_message=schema_parsing_prompt,
    max_consecutive_auto_reply=4,
    code_execution_config=False
)

# System prompt for the table parser: asks for a one-word table name (possibly
# only a substring of it) taken from the user's message, or 'NA' when unsure.
# The reply is used by test_sql as a `like '%...%'` search string.
table_parsing_prompt = """You help the human identify the inferred TABLE name. 
You must parse out the name that represents the table they want to use. It could be a substring of the table name.

You only care about finding the TABLE name reference
IMPORTANT: Ignore connection and schema entity references

Always Respond with one word (the name you believe to be correct). 

Example: In the snowflake connection, using the public schema, what tables have 'abc' in the name.
Response: abc

REMEMBER: Just one word answer, no explanations. 
ALWAYS: Respond with a valid table name, only use your best guess from the message, do not make up names.
IMPORTANT: Response with 'NA' if you're not sure, don't guess. 
"""

# Agent that answers with the table search string only; no code execution.
table_parsing_assistant = autogen.AssistantAgent(
    name="Table_Parsing_agent",
    llm_config=CONFIG_LIST,
    system_message=table_parsing_prompt,
    max_consecutive_auto_reply=4,
    code_execution_config=False
)

##### pre-screen functions

In [150]:
def test_sql(c, s, t, conn):
    validated_c = ""
    validated_s = ""
    validated_t = ""
    element_string = ""
    return_string = ""
    try:
        if c != "":
            connection_statement = f"""select distinct connection_name from metadata where lower(connection_name) = '{c.lower()}'"""
            rs = conn.sql(connection_statement.replace("\\",""))
            c_rc = rs.to_df().count().iloc[0]
            
            if c_rc > 0:
                return_string = "Previous SQL results history: " + connection_statement + "\n"
                #return_string += "Row Count: " + str(c_rc) + "\n"
                return_string += rs.to_df().drop_duplicates().head(20).to_markdown(index=False)
                validated_c = c
            else:
                print("No results found")

        if s != "" and c_rc > 0:
            schema_statement = f"""select distinct schema_name from metadata where lower(connection_name) = '{c.lower()}' and lower(schema_name) = '{s.lower()}' """ 
            
            rs = conn.sql(schema_statement.replace("\\",""))
            s_rc = rs.to_df().count().iloc[0]
            if s_rc > 0:
                return_string = "Previous SQL results history: " + schema_statement + "\n"
                #return_string += "Row Count: " + str(s_rc) + "\n"
                return_string += rs.to_df().drop_duplicates().head(20).to_markdown(index=False)
                validated_s = s
            else:
                print("No results found")

        if t != "" and c_rc > 0 and s_rc > 0:
            
            table_statement = f"""select * from metadata where lower(connection_name) = '{c.lower()}' and lower(schema_name) = '{s.lower()}' and lower(table_name) like '%{t.lower()}%' """ 
            rs = conn.sql(table_statement.replace("\\",""))
            t_rc = rs.to_df().count().iloc[0]
            if t_rc > 0:
                return_string = "Previous SQL results history: " + table_statement + "\n"
                return_string += rs.to_df().drop_duplicates().head(20).to_markdown(index=False)
                validated_t = t
            else:
                print("No results found")
    except: 
        pass
    
    return {
        "element_context" : element_string, 
        "sql_context" : return_string,
        "validated_connection" : validated_c,
        "validated_schema" : validated_s,
        "validated_table" : validated_t
        }

In [151]:
def _reply_text(reply):
    """Collapse an agent reply (dict, str or None) into one stripped, single-line string."""
    if reply is None:
        return ""
    if isinstance(reply, dict):
        reply = reply.get("content") or ""
    return str(reply).replace("\n", "").strip()


def _ask_agent(agent, prompt, retries=0):
    """Ask an agent for a reply, re-asking up to `retries` times while it returns None."""
    messages = [{"role" : "user", "content": prompt}]
    reply = agent.generate_reply(messages=messages)
    for _ in range(retries):
        if reply is not None:
            break
        reply = agent.generate_reply(messages=messages)
    return reply


def filter_step(prompt, suggested_task):
    """Infer connection / schema / table from a prompt and validate them with `test_sql`.

    When `suggested_task` contains 'sql', all three parsing agents are asked
    (each re-asked up to three times while it returns None). Otherwise only the
    connection agent is asked once and a fixed hint is used as `filter_context`.
    Replies are normalised the same way whether they arrive as a dict with a
    'content' key or as a plain string: newlines removed, whitespace stripped.

    Args:
        prompt: the user's request text.
        suggested_task: action label chosen for the prompt.

    Returns:
        dict with keys:
            filter_context: hint text (non-SQL tasks) or "".
            sql_context: the dict returned by `test_sql`.
            connection_name / schema_name / table_name: values inferred by the agents.
            validated_connection / validated_schema / validated_table: the
                values `test_sql` confirmed, "" when not confirmed.

    Note:
        Uses the notebook-level `connection_parsing_assistant`,
        `schema_parsing_assistant`, `table_parsing_assistant`, `test_sql` and `conn`.
    """
    c = ""
    s = ""
    t = ""
    filter_context = ""

    if 'sql' in suggested_task.strip():
        c_ai = _ask_agent(connection_parsing_assistant, prompt, retries=3)
        s_ai = _ask_agent(schema_parsing_assistant, prompt, retries=3)
        t_ai = _ask_agent(table_parsing_assistant, prompt, retries=3)

        print(c_ai)
        print(s_ai)
        print(t_ai)

        c, s, t = _reply_text(c_ai), _reply_text(s_ai), _reply_text(t_ai)
        identified_elements_string = f"{c} > {s} > {t}"
        print("IDENTIFIED ELEMENTS: ", identified_elements_string)
    else:
        filter_context = "You will likely want to either 'list distinct connections' to or request 'check job status" 
        c_ai = _ask_agent(connection_parsing_assistant, prompt)
        print(c_ai)
        c = _reply_text(c_ai)

    sql_context = test_sql(c, s, t, conn)

    # All validated values come from test_sql in both branches, so an
    # unvalidated level is "" rather than a placeholder string.
    return {
        "filter_context" : filter_context,
        "sql_context" : sql_context,
        "connection_name" : c,
        "schema_name" : s,
        "table_name" : t,
        "validated_connection" : sql_context.get("validated_connection", ""),
        "validated_schema" : sql_context.get("validated_schema", ""),
        "validated_table" : sql_context.get("validated_table", "")
        }

In [152]:
def search_docs(df, user_query, top_n=4, to_print=True):
    """Rank rows of `df` by cosine similarity between `saved_embedding` and the query.

    The query is embedded with `EMBEDDING_MODEL`; a `similarities` column is
    written onto `df` itself, and the `top_n` highest-scoring rows are returned.
    `to_print` is accepted but not used by this version.
    """
    query_vector = get_embedding(user_query, model=EMBEDDING_MODEL)

    def _score(saved_vector):
        return cosine_similarity(saved_vector, query_vector)

    df["similarities"] = df.saved_embedding.apply(_score)
    ranked = df.sort_values("similarities", ascending=False)
    return ranked.head(top_n)

In [153]:
def get_next_task(df, user_query, top_n=3, to_print=True):
    """Return the `top_n` rows of `df` most similar to `user_query` (delegates to `search_docs`)."""
    return search_docs(df, user_query, top_n=top_n, to_print=to_print)

In [154]:
def format_filter_context(suggested_task, filter_context):
    """Render the dict produced by `filter_step` as one context string.

    Layout: the `filter_context` hint text first; for tasks containing 'sql'
    the `element_context` and `sql_context` strings from the nested
    `sql_context` dict follow; then one labelled line per validated element.

    Validated values that are "", None or the placeholder string "None" are
    omitted, the hint is emitted exactly once for every task type, and a
    missing `sql_context` entry is treated as empty.

    Args:
        suggested_task: action label; SQL extras are included when it contains 'sql'.
        filter_context: dict with the keys returned by `filter_step`.

    Returns:
        The assembled context string ("" when there is nothing to report).
    """
    sections = [str(filter_context.get("filter_context", "") or "")]

    if "sql" in suggested_task.strip():
        sql_context = filter_context.get("sql_context") or {}
        sections.append(str(sql_context.get("element_context", "")))
        sections.append(str(sql_context.get("sql_context", "")))

    labelled_keys = (
        ("connection_name", "validated_connection"),
        ("schema_name", "validated_schema"),
        ("search string", "validated_table"),
    )
    for label, key in labelled_keys:
        value = filter_context.get(key, "")
        if value not in ("", "None", None):
            sections.append(f"\n{label}: {value}")

    return "".join(sections)

In [155]:
# Groupchat 
# Speaker graph for the RAG-assisted flow: each key maps an agent to the agents
# allowed to speak right after it. Here user_proxy cannot hand over directly to
# executor.
allowed_transitions = {
    user_proxy: [sql_assistant, job_assistant, reviewer_assistant],
    sql_assistant: [executor, user_proxy],
    job_assistant: [executor, user_proxy],
    executor: [reviewer_assistant, user_proxy],
    reviewer_assistant: [user_proxy],
}

# NOTE(review): this cell rebinds `groupchat` but builds no new GroupChatManager;
# the run cell passes the already-existing `manager` to chat() — confirm that
# manager is meant to keep the group chat it was created with.
groupchat = autogen.GroupChat(
    agents=[user_proxy, sql_assistant, job_assistant, executor, reviewer_assistant], 
    messages=[], max_round=5,
    speaker_transitions_type="allowed", 
    allowed_or_disallowed_speaker_transitions=allowed_transitions,
    send_introductions=True
)

# State carried between executions of the run cell: every non-empty SQL context
# seen so far, and the most recent one.
previous_history = []
previous_sql_output = ""

#previous_history[-3:]

##### run 

In [None]:
# Candidate prompts; only the last active assignment is used.
#prompt = "list distinct connections" 
#prompt = "whats the bigquery connection name?"
#prompt = "is there a samples schema in the bigquery connection?"
#prompt = "list the distinct schemas in the bigquery connection"
#prompt = "count the distinct tables in the bigquery connection"

prompt = "list the tables listed w/ claims in the name, use bigquery connection, samples schema"
#prompt = "run a dq job for the claims dents and claims detail table" 
prompt = "run a dq job for the all the claims tables listed in that connection and schema" 

#prompt = "what tables did you last mention? "
prompt = "check job status"

# fetch next task
next_task = get_next_task(docs_df, prompt, top_n=4, to_print=False)
print("NEXT TASK: ", next_task[['action', 'lookup_values', 'similarities']].to_markdown(index=False))
suggested_task = next_task.action.iloc[0]

# Carry over the SQL context produced by the previous execution of this cell.
# `filter_context` does not exist on the first execution; that is the only
# failure tolerated here — anything else should surface.
try:
    previous_filter_context = filter_context
except NameError:
    previous_filter_context = {}

previous_message = str((previous_filter_context.get("sql_context") or {}).get("sql_context", ""))
if previous_message != "":
    previous_history.append(previous_message)
    previous_sql_output = previous_message
print("PREVIOUS SQL OUTPUT: ", previous_sql_output)

# fetch context
filter_context = filter_step(prompt, suggested_task)
filter_context_string = format_filter_context(suggested_task, filter_context)

print("-"*100)
p = ""
p += "\nPrompt: " + prompt + "\n"
p += "Please assist with the prompt above. \n"

# Previous SQL output is only attached for non-"general" tasks.
if previous_sql_output.strip() != "" and suggested_task.strip() != "general":
    p += previous_sql_output

# The inferred context is capped at 1000 characters.
if filter_context_string.strip() != "":
    p += "\nBest Guess Inferred Context: " + filter_context_string[0:1000]
    print(p)

cr, manager = chat(manager, p, groupchat)


### ```extra```

##### initial embeddings generated by openai


In [141]:

import numpy as np
from openai import OpenAI
import pandas as pd

# One-off embedding generation (see the commented-out code at the end of this cell):
#   input file:    "/tmp/snippets/actions.csv"
#   input column:  lookup_values (text to embed)
#   output column: embedding

# OpenAI client and the embedding model used by get_embedding() below.
client = OpenAI(max_retries=5)
EMBEDDING_MODEL="text-embedding-ada-002"

def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors `a` and `b`."""
    dot_product = np.dot(a, b)
    magnitude = np.linalg.norm(a) * np.linalg.norm(b)
    return dot_product / magnitude

def get_embedding(text, model=EMBEDDING_MODEL):
    """Return the embedding vector for `text` using the module-level OpenAI `client`.

    Newlines in `text` are replaced with spaces before the request is made.
    """
    single_line = " ".join(text.split("\n"))
    response = client.embeddings.create(input=[single_line], model=model)
    return response.data[0].embedding

def search_docs(df, user_query, top_n=4, to_print=True):
    """Rank rows of `df` by cosine similarity between `embedding` and the query.

    This variant reads the `embedding` column. It writes a `similarities`
    column onto `df` itself, optionally displays the best rows, and returns
    the `top_n` highest-scoring rows.
    """
    query_vector = get_embedding(user_query, model=EMBEDDING_MODEL)

    def _score(row_vector):
        return cosine_similarity(row_vector, query_vector)

    df["similarities"] = df.embedding.apply(_score)
    best_matches = df.sort_values("similarities", ascending=False).head(top_n)

    if to_print:
        display(best_matches)
    return best_matches

# df = pd.read_csv("/tmp/snippets/actions.csv", index_col=0)
# df = df.dropna()
# df['embedding'] = df.lookup_values.apply(lambda x: get_embedding(x, model=EMBEDDING_MODEL))

# save embeddings
# df.to_csv("/tmp/snippets/actions_with_embeddings.csv")