# SETUP FOR TAG

This notebook will be used for TAG operations with the PostgreSQL database's `strava_api` schema. The general flow is as follows:

1. The user asks a question.
2. The intent of the question is identified and the question is routed accordingly.
3. If a TAG intent is identified (e.g., the user asks how many runs they have done in the last month), the LLM is invoked to generate a SQL query.
4. The SQL query is executed and, assuming no errors, the results are fed back into an additional LLM call to generate a response to be sent back to the user.
5. The resulting answer is returned to the user.
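
The five steps above can be sketched as a single function. This is only an illustration of the control flow, not the notebook's implementation: `answer_question` and its four callables are hypothetical names, the stubs stand in for the real intent classifier, LLM calls, and database, and the row value `12` is made up.

```python
from typing import Callable


def answer_question(
    question: str,
    classify_intent: Callable[[str], str],
    generate_sql: Callable[[str], str],
    run_sql: Callable[[str], list[tuple]],
    summarize: Callable[[str, list[tuple]], str],
) -> str:
    """Route a question and, for TAG intents, answer it from query results."""
    intent = classify_intent(question)   # Step 2: identify the intent
    if intent != "TAG":
        return "This question is not handled by the TAG flow."
    sql = generate_sql(question)         # Step 3: LLM generates a SQL query
    rows = run_sql(sql)                  # Step 4: execute the query...
    return summarize(question, rows)     # ...and turn the rows into an answer (step 5)


# Stubbed usage: every callable below is a placeholder with hard-coded behavior.
reply = answer_question(
    "How many runs have I done in the last month?",
    classify_intent=lambda q: "TAG",
    generate_sql=lambda q: "SELECT COUNT(*) FROM strava_api.activities",
    run_sql=lambda sql: [(12,)],
    summarize=lambda q, rows: f"You have done {rows[0][0]} runs.",
)
print(reply)
```

Passing the steps in as callables keeps the routing logic testable without a database or an API key.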

## DATABASE CONNECTION

In [72]:
# Establish the service
from os import path as ospath
from sys import path as syspath
syspath.append(ospath.abspath(".."))  # Add the parent directory (project root) to the import path

from services.database import DatabaseService
db_service = DatabaseService()

In [73]:
# Test query
from models.athlete import Athlete
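session = db_service.get_session()
athlete_name = "Patrick Lister"
try:
    result = session.query(Athlete).filter_by(athlete_name=athlete_name).first()
    if result is None:
        # .first() returns None when no row matches
        print(f"No athlete found with name {athlete_name}")
    else:
        print(f"Fetched athlete with ID {result.athlete_id} and name {result.athlete_name}: {result}")
except Exception as e:
    print(f"Error fetching data: {e}")
finally:
    session.close()  # Always release the session, even if the query fails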

Fetched athlete with ID 41580846 and name Patrick Lister: <Athlete(athlete_id=41580846, athlete_name=Patrick Lister, email=PListerJr@gmail.com)>


## LLM SETUP

In [74]:
# Set up client
from os import getenv
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()

client = OpenAI(api_key=getenv("OPENAI_API_KEY"))

# Test completion
use_streaming = False
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[
        {"role": "developer", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi!"}
    ],
    stream=use_streaming
)

if use_streaming:
    print("Assistant's streamed response: ")
    for chunk in response:
        print(chunk.choices[0].delta)
else:
    print(f"Assistant's response: {response.choices[0].message}")

Assistant's response: ChatCompletionMessage(content='Hello! How can I assist you today?', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None)


# TAG (Table Augmented Generation)

In [75]:
# Establish TAG prompt and acquire user input
tag_prompt = """
### Instructions:
You are an expert SQL assistant. Your task is to generate an optimized SQL query based on the provided database schema and the user's request. 
The query should be well-structured, efficient, and free of errors. If any assumptions are necessary, state them in a SQL comment inside the query.

### Database Schema(s):
{schema_description}

### User Request:
"{user_question}"

### Query Constraints:
- Use appropriate SQL joins if multiple tables are involved.
- Apply filtering conditions (`WHERE`, `HAVING`) based on the user's request.
- Use `LIMIT` when the user requests a subset of results.
- Ensure the query is optimized and avoids unnecessary computations.
- If aggregation is required, use `GROUP BY` appropriately.
- Use aliases for readability where necessary.
- Format the query with proper indentation for clarity.

### Output:
Provide only the final SQL query without explanations or additional text.
"""
user_input = input("Enter your question: ")

In [76]:
def clean_query(query: str):
    """
    Cleans the LLM-generated SQL query by removing Markdown-style formatting.
    """
    return (
        query.strip()  # Remove leading/trailing whitespace
        .replace("```sql", "")  # Remove opening Markdown SQL block
        .replace("```", "")  # Remove closing Markdown block
        .strip()  # Trim any remaining spaces
    )
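
`clean_query` assumes the reply is nothing but a fenced query. As a sketch of a more tolerant variant, the hypothetical helper below (`extract_sql` is not part of the notebook) uses a regular expression to pull the query out of the first fenced block, even when the model adds prose around it, and falls back to the stripped reply when no fence is present. It only recognizes an optional `sql` language tag.

```python
import re

FENCE = "`" * 3  # Built at runtime so the example contains no literal fence

# Match an opening fence with an optional "sql" tag, capture lazily up to the closing fence.
_FENCED_BLOCK = re.compile(
    re.escape(FENCE) + r"(?:sql)?\s*(.*?)" + re.escape(FENCE),
    re.DOTALL | re.IGNORECASE,
)


def extract_sql(reply: str) -> str:
    """Return the contents of the first fenced block, or the stripped reply if none exists."""
    match = _FENCED_BLOCK.search(reply)
    return (match.group(1) if match else reply).strip()


reply = f"Here is the query:\n{FENCE}sql\nSELECT 1;\n{FENCE}\nLet me know if it helps."
print(extract_sql(reply))
```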

In [77]:
# Establish the schema (manually for now)
from models.activity import Activity
from models.athlete import Athlete
activity_desc = Activity().convert_to_schema_description()
athlete_desc = Athlete().convert_to_schema_description()
schema_desc = f"{activity_desc}\n{athlete_desc}"

In [78]:
from typing import Any, Sequence  # Typing generics; sqlalchemy.Sequence is a schema construct, not a type
from sqlalchemy import Row, text
from tenacity import retry, stop_after_attempt, wait_random_exponential
error_msg = None
query_to_execute = None
debug_mode = True

@retry(
    stop=stop_after_attempt(5),
    wait=wait_random_exponential(min=1, max=10),
)
def execute_query(schema_desc: str) -> tuple[Sequence[Row[Any]], int, str]:
    """
    Asks the LLM for a SQL query, executes it, and returns the rows,
    the row count, and the ID of the completion that produced the query.
    """
    global error_msg, query_to_execute

    # Generate a query based on the user's question, prepending the previous
    # error (if any) so the model can correct its last attempt
    messages: list[dict[str, str]] = []
    if error_msg:
        messages.append({"role": "developer", "content": error_msg})
    messages.append(
        {"role": "user", "content": tag_prompt.format(schema_description=schema_desc, user_question=user_input)}
    )
    if debug_mode:
        print(f"Messages being fed in to the LLM:\n{messages}")
    query_result = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,  # Already contains the user prompt, so it is not added a second time
        store=True  # Store the completion so its messages can be listed later
    )

    # Completion ID
    completion_id = query_result.id
    if debug_mode:
        print(f"Chat completed: {completion_id}")

    # Clean up the query so it's executable
    query = query_result.choices[0].message.content
    query_to_execute = clean_query(query)

    # Execute the query
    session = db_service.get_session()
    try:
        print(f"\nExecuting this generated query: {query_to_execute}")
        result = session.execute(text(query_to_execute)).fetchall()
        error_msg = None  # Clear any error left over from a previous attempt
    except Exception as e:
        error_msg = f"An error occurred while executing this query: {query_to_execute}.\nHere is the error: {e}\nPlease generate a query to resolve this issue.\n"
        print(error_msg)
        raise  # Re-raise with the original traceback to trigger the retry mechanism
    finally:
        db_service.close_session()  # Close the session on success and before any retry
    return result, len(result), completion_id

In [None]:
display_results = True

try:
    result, num_rows, completion_id = execute_query(schema_desc=schema_desc)

    # Display the results
    if display_results:
        # Single quotes inside the f-string keep this valid on Python versions before 3.12
        print(f"\nReceived {num_rows} {'rows' if num_rows != 1 else 'row'}: ")
        for row in result:
            print(row)
    
    # Return an answer to the user
    # TODO - Fix this. Need to feed in full context of conversation. Might be easier to just save it globally
    # or within a class. For now, trying to use the OpenAI API...
    past_messages = client.chat.completions.messages.list(completion_id=completion_id)
    messages: list[dict[str, str]] = []
    for message in past_messages.data:
        # Stored messages are objects, so use attribute access; the API expects the key "content"
        messages.append(
            {"role": message.role, "content": message.content}
        )
    messages.append(
        {
            "role": "developer", 
            "content": f"""
                Based on this answer retrieved from a database schema containing running data: {result}
                Return a concise but accurate answer to the user for their question: {user_input}
                
                Importantly, I am NOT asking you to retrieve from your training data.
            """
        }
    )
    ai_response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages
    ).choices[0].message.content
    print(f"\nAI response: {ai_response}")

except Exception as e:
    print(f"The last retry attempt failed: {e}")

Messages being fed in to the LLM:
[{'role': 'developer', 'content': 'An error occurred while executing this query: SELECT COUNT(a.activity_id) AS run_count\nFROM strava_api.activities a\nJOIN strava_api.athletes ath ON a.athlete_id = ath.athlete_id\nWHERE ath.athlete_name = \'Patrick Lister\'\n  AND a.year = 2024\n  AND a.wkt_type = \'run\';  -- Assuming \'run\' is the designation for running activities.\nHere is the error: (psycopg2.errors.InvalidTextRepresentation) invalid input syntax for type integer: "run"\nLINE 6:   AND a.wkt_type = \'run\';  -- Assuming \'run\' is the designat...\n                           ^\n\n[SQL: SELECT COUNT(a.activity_id) AS run_count\nFROM strava_api.activities a\nJOIN strava_api.athletes ath ON a.athlete_id = ath.athlete_id\nWHERE ath.athlete_name = \'Patrick Lister\'\n  AND a.year = 2024\n  AND a.wkt_type = \'run\';  -- Assuming \'run\' is the designation for running activities]\n(Background on this error at: https://sqlalche.me/e/20/9h9h)\nPlease gene