In [1]:
import json
import os

import duckdb
import openai
import pandas as pd
import yaml
from tqdm import tqdm
from vanna.flask import VannaFlaskApp
from vanna.remote import VannaDefault

In [2]:
# Load credentials from config.yaml (expects the keys 'openai_api' and 'vanna_ai').
with open("config.yaml", "r") as stream:
    # Let yaml.YAMLError propagate. Catching and printing it would leave PARAM
    # undefined (NameError below on a fresh kernel) or, worse, silently reuse a
    # stale PARAM left over from an earlier run of this cell.
    PARAM = yaml.safe_load(stream)

# The same OpenAI key is set on the module and on an explicit client object;
# get_embedding() uses `client`.
openai.api_key = PARAM['openai_api']
client = openai.OpenAI(api_key=PARAM['openai_api'])

# API key handed to VannaDefault later in the notebook.
vanna_ai_api = PARAM['vanna_ai']

def get_embedding(text, model="text-embedding-3-small"):
    """Return the embedding of ``text`` produced by the OpenAI embeddings endpoint.

    Newlines in ``text`` are replaced with spaces before the request is sent.
    The request goes through the module-level ``client``.

    Args:
        text: The string to embed.
        model: Name of the embedding model to request.

    Returns:
        The ``embedding`` attribute of the first item in the response data.
    """
    flattened = text.replace("\n", " ")
    response = client.embeddings.create(input=[flattened], model=model)
    first_item = response.data[0]
    return first_item.embedding

In [3]:
# Open (or create) the on-disk DuckDB database used throughout the notebook.
con = duckdb.connect("drug.db")

# Install the three extensions; duckpgq is fetched from the community
# repository, fts and vss from the default one.
for install_stmt in (
    "INSTALL duckpgq FROM community;",
    "INSTALL fts;",
    "INSTALL vss;",
):
    con.sql(install_stmt)

# Load the extensions in the same order they were installed.
for extension_name in ("duckpgq", "fts", "vss"):
    con.load_extension(extension_name)


In [4]:
# Build the VannaDefault client (imported from vanna.remote) for the model named
# 'sixing', authenticating with the Vanna key read from config.yaml.
vn = VannaDefault(model='sixing', api_key=vanna_ai_api)

In [5]:
# Delete every training item that vn.get_training_data() currently returns,
# looking each one up by its 'id' column.
df = vn.get_training_data()
# iterrows() is a generator with no length, so tqdm needs total= to draw a real
# progress bar; without it only a running item count is shown.
for index, row in tqdm(df.iterrows(), total=len(df)):
    vn.remove_training_data(id=row['id'])

8it [00:02,  3.14it/s]


In [6]:
# Connect vn to the drug.db DuckDB file (the same path opened as `con` earlier
# in this notebook) so that vn.run_sql() can be used in the cells below.
vn.connect_to_duckdb(url='drug.db')



True


In [7]:
# Create the property graph `drug_graph` via vn.run_sql():
#   - vertex tables: Drug, Disorder, MOA
#   - DrugDisorder rows are edges labelled MAY_TREAT (Drug -> Disorder)
#   - DrugMOA rows are edges labelled HAS_MOA (Drug -> MOA)
# NOTE(review): a plain CREATE presumably errors if drug_graph is already
# registered, so re-running this cell may fail — confirm against the installed
# duckpgq version.
vn.run_sql("""
CREATE PROPERTY GRAPH drug_graph
  VERTEX TABLES (
    Drug, Disorder, MOA
  )
EDGE TABLES (
  DrugDisorder 	SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui)
                DESTINATION KEY (disorder_cui) REFERENCES Disorder (disorder_cui)
  LABEL MAY_TREAT,
  DrugMOA SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui)
          DESTINATION KEY (moa_id) REFERENCES MOA (moa_id)
  LABEL HAS_MOA
);
          """)

Unnamed: 0,Success


In [8]:
# Add one question/SQL training pair. The SQL is a GRAPH_TABLE ... MATCH query
# over drug_graph that follows a HAS_MOA edge from the Drug named 'medrysone'
# to an MOA vertex and returns the MOA name (at most 5 rows).
vn.train(
    question="What is the MOA of medrysone?", 
    sql="""SELECT * FROM GRAPH_TABLE (drug_graph
        MATCH
        (d:Drug WHERE d.name = 'medrysone')-[h:HAS_MOA]->(m:MOA)
        COLUMNS (m.name AS moa_name)
      )
      LIMIT 5;""",
)

'408c039a3a0273ca92af988480d800b3-sql'

In [9]:
# Submit the schema to Vanna as DDL training text, one vn.train(ddl=...) call per
# statement. The strings are handed to vn.train as text; this cell does not pass
# them to vn.run_sql.

# Trials table.
# NOTE(review): the column list ends with a trailing comma before ')' — confirm
# that is intended in the training text.
vn.train(ddl="""CREATE TABLE Trials (
            PostingID     INTEGER NOT NULL PRIMARY KEY,
            Sponsor    VARCHAR,
            StudyTitle     VARCHAR,
            Drug   VARCHAR,
            Disorder   VARCHAR,
            Phase   VARCHAR,
            LinkToSponsorStudyRegistry   VARCHAR,
            LinkToClinicalTrials   VARCHAR,
            cui  VARCHAR[],
            PreferredUMLSName   VARCHAR[],
            )""")

# Drug vertex table, keyed by drug_cui.
vn.train(ddl="""CREATE TABLE Drug (
            drug_cui  VARCHAR NOT NULL PRIMARY KEY,
            name  VARCHAR
            )""")

# Disorder vertex table, keyed by disorder_cui; also carries a definition and a
# FLOAT[1536] definitionEmbedding column.
vn.train(ddl="""CREATE TABLE Disorder (
            disorder_cui  VARCHAR NOT NULL PRIMARY KEY,
            name  VARCHAR,
            definition  VARCHAR,
            definitionEmbedding FLOAT[1536]
            )""")

# MOA vertex table, keyed by moa_id.
vn.train(ddl="""CREATE TABLE MOA (
            moa_id  VARCHAR NOT NULL PRIMARY KEY,
            name  VARCHAR
            )""")

# Drug -> Disorder link table (both columns reference the vertex tables).
vn.train(ddl="""CREATE TABLE DrugDisorder (
            drug_cui  VARCHAR NOT NULL REFERENCES Drug(drug_cui),
            disorder_cui  VARCHAR NOT NULL REFERENCES Disorder(disorder_cui)
            )""")

# Drug -> MOA link table (both columns reference the vertex tables).
vn.train(ddl="""CREATE TABLE DrugMOA (
            drug_cui  VARCHAR NOT NULL REFERENCES Drug(drug_cui),
            moa_id  VARCHAR NOT NULL REFERENCES MOA(moa_id)
            )""")

# The property-graph definition, with the same text as the statement executed
# through vn.run_sql earlier in the notebook. As the last expression in the cell,
# the value returned by this call is what the cell displays.
vn.train(ddl="""
CREATE PROPERTY GRAPH drug_graph
  VERTEX TABLES (
    Drug, Disorder, MOA
  )
EDGE TABLES (
  DrugDisorder 	SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui)
                DESTINATION KEY (disorder_cui) REFERENCES Disorder (disorder_cui)
  LABEL MAY_TREAT,
  DrugMOA SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui)
          DESTINATION KEY (moa_id) REFERENCES MOA (moa_id)
  LABEL HAS_MOA
);
          """)

Adding ddl: CREATE TABLE Trials (
            PostingID     INTEGER NOT NULL PRIMARY KEY,
            Sponsor    VARCHAR,
            StudyTitle     VARCHAR,
            Drug   VARCHAR,
            Disorder   VARCHAR,
            Phase   VARCHAR,
            LinkToSponsorStudyRegistry   VARCHAR,
            LinkToClinicalTrials   VARCHAR,
            cui  VARCHAR[],
            PreferredUMLSName   VARCHAR[],
            )
Adding ddl: CREATE TABLE Drug (
            drug_cui  VARCHAR NOT NULL PRIMARY KEY,
            name  VARCHAR
            )
Adding ddl: CREATE TABLE Disorder (
            disorder_cui  VARCHAR NOT NULL PRIMARY KEY,
            name  VARCHAR,
            definition  VARCHAR,
            definitionEmbedding FLOAT[1536]
            )
Adding ddl: CREATE TABLE MOA (
            moa_id  VARCHAR NOT NULL PRIMARY KEY,
            name  VARCHAR
            )
Adding ddl: CREATE TABLE DrugDisorder (
            drug_cui  VARCHAR NOT NULL REFERENCES Drug(drug_cui),
            d

'336391-ddl'

In [10]:
# df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")
# plan = vn.get_training_plan_generic(df_information_schema)
# vn.train(plan=plan, documentation="In our business, we define 'joint-related disorders' as 'Rheumatoid Arthritis' and 'Arthralgia'")

In [11]:
# Launch Vanna's Flask UI for `vn` (VannaFlaskApp is imported in the notebook's
# top import cell). The recorded output shows the app served at
# http://localhost:8084 with Flask debug mode on; run() presumably blocks this
# cell until the server is interrupted.
# NOTE(review): allow_llm_to_see_data=True presumably lets database contents
# reach the LLM — confirm that is acceptable for this dataset.
VannaFlaskApp(vn, allow_llm_to_see_data=True, chart=False).run()

Your app is running at:
http://localhost:8084
 * Serving Flask app 'vanna.flask'
 * Debug mode: on
None
None
None
None
None
{'function_name': 'get_treatable_diseases', 'description': 'What kind of diseases can {drug_name} treat?', 'arguments': [{'name': 'drug_name', 'description': 'The name of the drug', 'general_type': 'STRING', 'is_user_editable': True}], 'sql_template': 'SELECT DISTINCT d.name FROM Drug AS dr JOIN DrugDisorder AS dd ON dr.drug_cui = dd.drug_cui JOIN Disorder AS d ON dd.disorder_cui = d.disorder_cui WHERE dr.name = {drug_name}', 'post_processing_code_template': ''}
[{'function_name': 'get_treatable_diseases', 'description': 'What kind of diseases can {drug_name} treat?', 'post_processing_code_template': '', 'arguments': [{'name': 'drug_name', 'description': 'The name of the drug', 'general_type': 'STRING', 'is_user_editable': True, 'available_values': None}], 'sql_template': 'SELECT DISTINCT d.name FROM Drug AS dr JOIN DrugDisorder AS dd ON dr.drug_cui = dd.drug_cui 

In [None]:
# Close the DuckDB connection opened as `con` near the top of the notebook.
con.close()