In [None]:
from IPython.display import Markdown, display
from llama_index import SQLDatabase, ServiceContext
import sqlite3
from llama_index.llms import OpenAI
from llama_index.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index import VectorStoreIndex
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)
import os

db_dir = "files/db"
db_files = [os.path.join(db_dir, file) for file in os.listdir(db_dir) if file.endswith('.db')]

for db_file in db_files:
    engine = create_engine('sqlite:///' + db_file)
    conn = sqlite3.connect(db_file)
    cursor = conn.cursor()

    # Get all tables
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")

    tables = cursor.fetchall()
    for table in tables:
        print(table[0])

    table_name = table[0]
    print(f"\nTable: {table_name} in {db_file}")
    metadata_obj = MetaData()
    print(metadata_obj)
    # Get table schema
    print("Schema:")
    cursor.execute(f"PRAGMA table_info({table_name});")
    schema = cursor.fetchall()
    for column in schema:
        print(column)

    # Get the first five rows
    # print("\nFirst 5 rows:")
    cursor.execute(f"SELECT * FROM {table_name} LIMIT 5;")
    rows = cursor.fetchall()
    for row in rows:
        print(row)

    llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")
    service_context = ServiceContext.from_defaults(llm=llm)
    sql_database = SQLDatabase(engine)
    # set Logging to DEBUG for more detailed outputs
    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        (SQLTableSchema(table_name=str(table_name)))
    ]  # add a SQLTableSchema for each table

    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    query_engine = SQLTableRetrieverQueryEngine(
        sql_database, obj_index.as_retriever(similarity_top_k=3)
    )
    response = query_engine.query("How many different proteins with lymphoid tissue?")
    display(Markdown(f"<b>{response}</b>"))
    # Close the cursor and connection
    cursor.close()
    conn.close()


In [1]:
# Create combined schema of tables

import sqlite3
import json
from sqlalchemy import create_engine, MetaData, Table, Column, String
import re
def combine_schemas(db_files):
    combined_schema = {}

    for db_file in db_files:
        engine = create_engine('sqlite:///' + db_file)
        conn = sqlite3.connect(db_file)
        cursor = conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        for table in tables:
            table_name = table[0]
            metadata_obj = MetaData()
            
            # Get table schema
            cursor.execute(f"PRAGMA table_info({table_name});")
            schema = cursor.fetchall()

            # Create a Table object to store schema info
            table_obj = Table(table_name, metadata_obj)

            for column in schema:
                col_name, col_type = column[1], column[2]
                # Add column to the table object
                table_obj.append_column(Column(col_name, String))

            # Serialize table schema
            schema_info = [{"column_name": col.name, "data_type": str(col.type)} for col in table_obj.columns]
            combined_schema[f"{table_name} in {db_file}"] = schema_info

        conn.close()

    return combined_schema

def save_schema_to_json(combined_schema, filename="combined_schema.json"):
    with open(filename, "w") as file:
        json.dump(combined_schema, file, indent=4)

# Paths to your database files
db_files = ["files/db/CCLEGisticCNDB.db" , 'files/db/CCLEMutDB.db', 'files/db/CCLEVarDB.db']
all_schemas = combine_schemas(db_files)
save_schema_to_json(all_schemas)

# Now the schema is saved in 'combined_schema.json'


In [2]:
# A check
import re 
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

def extract_sql(llm_response: str) -> str:
    # If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
    sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
    if sql:
        log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
        return sql.group(1)

    sql = re.search(r"```(.*)```", llm_response, re.DOTALL)
    if sql:
        log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
        return sql.group(1)

    return llm_response

def is_sql_valid(sql: str) -> bool:
    # This is a check to see the SQL is valid and should be run
    # This simple function just checks if the SQL contains a SELECT statement

    if "SELECT" in sql.upper():
        return True
    else:
        return False


In [12]:
# LLM checks to write a SQL query

import replicate
import pandas as pd
import json
import os
import config
import importlib
importlib.reload(config)
from config import config, reset_config
from dotenv import load_dotenv
load_dotenv()
# import logging

# # Set the logging level for httpx to WARNING
# logging.getLogger("httpx").setLevel(logging.WARNING)

folder_path = 'files'
if not os.path.exists(folder_path):
    os.makedirs(folder_path)

reset_config()
config.set_mode('dbs')

INSTRUCTION = config.INSTRUCTION
F_NAME = config.F_NAME

def load_file(file_path):
    with open(file_path, 'r') as file:
        return json.load(file)

perturbations = load_file(config.perturbations)
knowledgebase = load_file(config.knowledgebase)
db_desc = load_file(config.db_layout)

df = pd.read_excel(config.questions)
df.to_excel(config.q_original, index=False)

df['Question'] = df['Question'].str.strip()  # Removes leading/trailing whitespace

# Check for duplicate questions
duplicates = df.duplicated(subset=['Question'], keep=False)
if duplicates.any():
    print("Duplicates found. Removing duplicates.")
    df = df.drop_duplicates(subset=['Question'], keep='first')
    df.to_excel(config.q_db, index=False)
else:
    print("No duplicates found.")

# DataFrame to store the results
results_df = pd.DataFrame(columns=['Model', 'Question', 'Response'])

models = {
    # "qwen-14b": "nomagick/qwen-14b-chat:f9e1ed25e2073f72ff9a3f46545d909b1078e674da543e791dec79218072ae70",
    # "falcon-40b": "joehoover/falcon-40b-instruct:7d58d6bddc53c23fa451c403b2b5373b1e0fa094e4e0d1b98c3d02931aa07173",
    # "yi-34b": "01-ai/yi-34b-chat:914692bbe8a8e2b91a4e44203e70d170c9c5ccc1359b283c84b0ec8d47819a46",
    "mistral-7b": "mistralai/mistral-7b-instruct-v0.2:f5701ad84de5715051cb99d550539719f8a7fbcf65e0e62a3d1eb3f94720764e",
    # "llama2-70b": "meta/llama-2-70b-chat",
    # "openhermes2": "antoinelyset/openhermes-2.5-mistral-7b:d7ccd25700fb11c1787c25b580ac8d715d2b677202fe54b77f9b4a1eb7d73e2b",
    # "mixtral-instruct": "mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
    # "deepseek_33bq": "kcaverly/deepseek-coder-33b-instruct-gguf:ea964345066a8868e43aca432f314822660b72e29cab6b4b904b779014fe58fd",
    }

prompt_for_qwen="""<|im_start|>system\n {INSTRUCTION}. Please write the appropriate SQL query using these three tables. The tables can be understood as {config.db_layout}. Try to answer the following question. The SQL should be returned within ''' SQL query '''. <|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"""
prompt_for_hermes = """[
{{
  "role": "system",
  "content": "{INSTRUCTION}. Please write the appropriate SQL query using these three tables. The tables can be understood as {config.db_layout}. Try to answer the following question. The SQL should be returned within ''' SQL query ''' " 
}},
{{
  "role": "user",
  "content": {question}
}}
]"""

# Iterate through each model
for model_key, model_value in models.items():
    responses = []

    for index, row in df.iterrows():
        qn = row['Question']
        question = json.dumps(qn)

        if model_key == "yi-34b":  # Yi model
            prompt = prompt_for_qwen.format(INSTRUCTION=INSTRUCTION, question=question)
        if model_key == "qwen-14b":  # Qwen model
            prompt = prompt_for_qwen.format(INSTRUCTION=INSTRUCTION, question=question)
        elif model_key == "openhermes2":  # Hermes model
            prompt = prompt_for_hermes.format(INSTRUCTION=INSTRUCTION, question=question)
        else:
            prompt = f"{INSTRUCTION}. Please write the appropriate SQL query using these three tables. The tables can be understood as {config.db_layout}. Try to answer the following question. The SQL should be returned within ''' SQL query '''. {question}"

        try:
            print(prompt)
            output = replicate.run(
                model_value,
                input={
                  "debug": False,
                #   "top_k": 50,
                  "top_p": 0.9,
                  "prompt": prompt,
                  "temperature": 0.7,
                  "max_new_tokens": 500,
                  "min_new_tokens": -1
                }
            )
            response = ""
            response_parts = []  # Initialize an empty list to collect string representations
            print(f"Output is: {output}")
            for item in output:
                item_str = str(item).strip()  # Convert item to string
                response += item_str if len(item_str) == 1 else f" {item_str}"

            response = response.strip()
            print(f"Output is: {response}")
            extracted_sql = extract_sql(response) # Get only the SQL query
            print(f"Output is: {extracted_sql}")
            valid = is_sql_valid(response) # Check if the SQL query is valid
            response = {
                "response": extracted_sql,
                "is_valid": valid
            }

        except Exception as e:
            response = f"Error: {e}"

        new_row = pd.DataFrame({'Model': [model_key], 'Question': [qn], 'Response': [extracted_sql]})
        results_df = pd.concat([results_df, new_row], ignore_index=True)

        if index % 20 == 0:  # Save every 10 questions, adjust as needed
            results_df.to_excel(config.results_file_path, index=False, sheet_name='Sheet1')
            
results_df.to_excel(config.results_file_path, index=False, sheet_name='Sheet1')

No duplicates found.
You are an exceptional computational biologist and genomics expert and know everything about drug discovery.. Please write the appropriate SQL query using these three tables. The tables can be understood as utils/database_description.json. Try to answer the following question. The SQL should be returned within ''' SQL query '''. "Which cell lines have high dependency for the target of interest for the gene POLR3E?"
Output is: <generator object Prediction.output_iterator at 0x177adec50>
Output is: ``` SQL   -- Ass uming the following table names and column names based on the given context  -- Table: cell_ lines  -- Columns: id, name, gene_ depend ency_ sc ores  -- Table: targets  -- Columns: id, name, gene_ symbol  -- Table: gene_ target_ map  -- Columns: cell_ line_ id, target_ id, score   -- SQL query  SELECTc. name  FROM cell_ linesc  JO IN gene_ target_ mapg tm ONc. id=g tm. cell_ line_ id  JO IN targetst ONg tm. target_ id=t. id W HEREt.g ene_ symbol=' POLR3E' 

In [None]:
# GPT-4 writes a SQL query

import pandas as pd
import json
import openai
import requests
from openai import OpenAI
import time
from dotenv import load_dotenv
load_dotenv()
import os
folder_path = 'files'
if not os.path.exists(folder_path):
    os.makedirs(folder_path)

with open('config.json', 'r') as config_file:
    config = json.load(config_file)

INSTRUCTION = config['instructions']
F_NAME = config["name"]
GPT_MODEL = config["GPT_MODEL"]
INPUT_CSV_PATH = 'files/questions_db.xlsx'
OUTPUT_CSV_PATH = f'files/{F_NAME}_results_gpt4_db.xlsx'

client = OpenAI()
def show_json(obj):
    print(json.loads(obj.model_dump_json()))

assistant = client.beta.assistants.create(
    name=f"{F_NAME} AI Evaluator via reading DB",
    instructions=INSTRUCTION,
    model=GPT_MODEL,
)
show_json(assistant)

# Utility functions
def read_csv(file_path):
    return pd.read_excel(file_path)

def process_data_for_gpt(data):
    prompts = []
    for _, row in data.iterrows():
        question = row['Question']
        prompt = f"Please write the appropriate SQL query using these three table schemas {all_schemas} to answer the following question. The SQL should be returned within ''' SQL query '''.:\n\n{question}"
        prompts.append(prompt)
    return prompts

def submit_message_and_create_run(assistant_id, prompt):
    thread = client.beta.threads.create() # If you replace this globally it appends all answers to the one before.
    client.beta.threads.messages.create(thread_id=thread.id, role="user", content=prompt)
    return client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant_id), thread

def wait_on_run_and_get_response(run, thread):
    while run.status == "queued" or run.status == "in_progress":
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
        time.sleep(0.5)
    messages = client.beta.threads.messages.list(thread_id=thread.id, order="asc")
    return [m.content[0].text.value for m in messages if m.role == 'assistant']

def create_output_csv(data, responses, model_name, interim_csv_path):
    new_rows = []
    for question, response in zip(data['Question'], responses):
        new_rows.append({'Model': model_name, 'Question': question, 'Response': response})
    new_data = pd.DataFrame(new_rows)
    new_data.to_excel(interim_csv_path, index=False)

data = read_csv(INPUT_CSV_PATH)
prompts = process_data_for_gpt(data)
ASSISTANT_ID = assistant.id

responses = []
for prompt in prompts:
    run, thread = submit_message_and_create_run(ASSISTANT_ID, prompt)
    response = wait_on_run_and_get_response(run, thread)
    if isinstance(response, list):
        response = ' '.join(map(str, response))
    response = response.replace("\\\\n", "\\n")
    response = response.strip()
    extracted_sql = extract_sql(response) # Get only the SQL query
    valid = is_sql_valid(response) # Check if the SQL query is valid
    response = {
        "response": extracted_sql,
        "is_valid": valid
    }
    print(response)
    responses.append(response)

create_output_csv(data, responses, GPT_MODEL, OUTPUT_CSV_PATH)