In [31]:
!pip install -r requirements.txt



In [32]:
import os
import openai
import snowflake.connector
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from dotenv import load_dotenv
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine
import pandas as pd

# Load environment variables from a .env file
load_dotenv()
# Load OpenAI key securely:
openai_api_key = os.getenv("OPENAI_API_KEY")

In [33]:
engine = create_engine(URL(
    account=os.getenv("SNOWFLAKE_ACCOUNT"),
    user=os.getenv("SNOWFLAKE_USER"),
    password=os.getenv("SNOWFLAKE_PASSWORD"),
    warehouse=os.getenv("SNOWFLAKE_WAREHOUSE"),
    database=os.getenv("SNOWFLAKE_DATABASE"),
    schema=os.getenv("SNOWFLAKE_SCHEMA"),
    role=os.getenv("SNOWFLAKE_ROLE")
), pool_size=10, max_overflow=20, connect_args={'insecure_mode': True})

connection = engine.connect()

In [27]:
# 1. Define your schema description and a sample question
schema_description = """
Table: SALES
Columns: REGION, TOTAL_SALES, SALE_DATE
"""

question = "Show me total sales by region for last month"

# 2. LLM
llm = ChatOpenAI(
    temperature=0,
    model_name="gpt-4o-mini",
    openai_api_key=openai_api_key
)

# 3. Prompt template
sql_prompt = PromptTemplate(
    input_variables=["schema_description", "question"],
    template="""
You are an expert Snowflake SQL generator.
Given the following Snowflake schema:

{schema_description}

Convert the business question below into a valid Snowflake SQL query.

Business Question: {question}

Only output the SQL query, nothing else.
"""
)

# 4. Chain (new style)
chain = sql_prompt | llm

# 5. Call with correct keys
result = chain.invoke({
    "schema_description": schema_description,
    "question": question
})

sql_query = result.content
print("Generated SQL:")
print(sql_query)

Generated SQL:
```sql
SELECT REGION, SUM(TOTAL_SALES) AS TOTAL_SALES
FROM SALES
WHERE SALE_DATE >= DATE_TRUNC('MONTH', DATEADD('MONTH', -1, CURRENT_DATE()))
  AND SALE_DATE < DATE_TRUNC('MONTH', CURRENT_DATE())
GROUP BY REGION;
```


In [None]:
def generate_sql(schema_description: str, question: str) -> str:
    sql_query = chain.run(schema_description=schema_description, question=question)
    return sql_query.strip()

def run_query(sql_query: str):
    df = pd.read_sql(sql_query, connection)  # use your SQLAlchemy connection
    return df