In [1]:
import os
import json
import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from langchain.document_loaders import TextLoader
from langchain.schema import Document
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_community.llms import HuggingFaceHub


In [2]:
load_dotenv(override=True)
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')

In [3]:
def load_all_schema_files(folder_path):
    documents = []

    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        ext = os.path.splitext(filename)[1].lower()

        if ext == ".md":
            loader = TextLoader(file_path, encoding="utf-8")
            docs = loader.load()
            documents.extend(docs)

        elif ext in [".xls", ".xlsx"]:
            df = pd.read_excel(file_path)
            for _, row in df.iterrows():
                content = "\n".join([f"{col}: {row[col]}" for col in df.columns])
                documents.append(Document(page_content=content))

        else:
            print(f"Skipped unsupported file: {filename}")

    return documents

In [4]:
SCHEMA_FOLDER_PATH = r"C:your-database-path"  # フォルダ名を指定

if not os.path.exists(SCHEMA_FOLDER_PATH):
    print(f"Error: Folder '{SCHEMA_FOLDER_PATH}' not found.")
    exit()

schema_documents = load_all_schema_files(SCHEMA_FOLDER_PATH)

In [5]:
SYSTEM_PROMPT = """
You are a skilled SQL engineer. Based on the table schema information below and the user's question, generate an appropriate SQL query.

Guidelines:
- Output only the SQL query. Do not include any explanations or additional text.
- Use JOIN clauses where foreign key relationships exist.
- Select the most relevant columns and tables to fulfill the user's request.

Schema Information:
{context}

Question:
{question}
"""
PROMPT_TEMPLATE = PromptTemplate(
    input_variables=["context", "question"],
    template=SYSTEM_PROMPT
)


In [7]:
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
schema_chunks = text_splitter.split_documents(schema_documents)
embedding_function = OpenAIEmbeddings()
db = Chroma.from_documents(schema_chunks, embedding_function)
retriever = db.as_retriever()

llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0.7)

  llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0.7)


In [8]:
# --- Gradio UIのための関数定義 ---
def generate_sql_chat(user_question, chat_history):
    # スキーマ情報から関連するチャンクを検索
    relevant_schemas = retriever.get_relevant_documents(user_question)
    schema_context = "\n".join([doc.page_content for doc in relevant_schemas])
    
    # プロンプトを構築
    prompt_with_context = PROMPT_TEMPLATE.format(context=schema_context, question=user_question)

    # LLMを呼び出してSQLを生成
    response = llm.invoke(prompt_with_context)
    
    # Gradioは文字列で返す
    return response.content.strip()

In [9]:
# Gradioインターフェースの設定
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# SQL生成AIアシスタント")
    gr.Markdown("データ分析の質問を入力すると、SQLクエリを自動で生成します。")
    
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="質問を入力")
    
    def respond(message, chat_history):
        # AIからの応答を取得
        sql_query = generate_sql_chat(message, chat_history)
        
        # SQLをコードブロックとしてフォーマット
        formatted_sql = f"```sql\n{sql_query}\n```"
        
        # チャット履歴に追加
        chat_history.append((message, formatted_sql))
        return "", chat_history
    
    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

  chatbot = gr.Chatbot()


In [10]:
# サーバーを起動
demo.launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.




  relevant_schemas = retriever.get_relevant_documents(user_question)
