# Import packages and initialize GCP connection

In [6]:
import os
from typing import List, Dict

import pandas as pd
import vertexai.preview.generative_models as generative_models
from google.cloud import aiplatform
from google.oauth2 import service_account
from IPython.display import display, Markdown
from langchain_community.graphs import Neo4jGraph
from vertexai.generative_models import GenerativeModel, Part, Tool, FunctionDeclaration, GenerationResponse

In [7]:
# Load the service-account key from disk and register it, together with the
# project id, as the defaults used by the Google Cloud AI Platform SDK.
credentials = service_account.Credentials.from_service_account_file(
    "gcp-service-account.json"
)

aiplatform.init(
    project = 'gcp-project-id',
    credentials = credentials,
)

# Functions

In [8]:
def create_node_file(source_path, node_name, destiny_path="graph_files/"):
    """Creates the CSV file of the selected node.

    The file is written to ``<destiny_path>/<node_name>_node.csv``.

    Inputs:
        - source_path: The path of the excel file with the original data (sheet "Data" is read)
        - node_name: Specifies the type of node that will be created ['Adverse Event', 'SOC', 'Drug'].
          Any value other than 'Drug' or 'SOC' is handled as 'Adverse Event'.
        - destiny_path: The directory where the node file will be stored. A trailing
          separator is optional and the directory is created when it does not exist."""

    # Get the data depending on the type of node
    if node_name == "Drug":
        # Drug nodes are not rows: every column header from index 5 onwards is a drug
        sheet = pd.read_excel(source_path,
                              sheet_name = "Data")
        nodes = pd.DataFrame(data = {"drug": sheet.columns[5:]})
    else:
        if node_name == "SOC":
            node_columns = ['SOC ID', 'System Organ Class']
        else:
            node_columns = ['Adverse Event ID', 'Adverse Event']

        # The last two rows of the sheet are discarded
        # (presumably footer/summary rows -- TODO confirm against the source file)
        nodes = pd.read_excel(source_path,
                              sheet_name = "Data",
                              usecols = node_columns)[:-2]

    # Remove duplicates (returns a new frame instead of mutating the slice in place)
    nodes = nodes.drop_duplicates(ignore_index = True)

    # Create the CSV file; os.path.join copes with a missing trailing separator
    if destiny_path:
        os.makedirs(destiny_path, exist_ok = True)
    node_path = os.path.join(destiny_path, node_name + '_node.csv')
    nodes.to_csv(node_path, index = False)

In [9]:
def create_relationship(base_path, rel_name="CAUSES", destiny_path="graph_files/", date=""):
    """Creates the CSV file for the selected relationship.

    The file is written to ``<destiny_path>/<rel_name>_relationship.csv``.

    Inputs:
        - base_path: Path of the excel file with the base data (sheet "Data" is read)
        - rel_name: Name of the relationship (ideally all name in UPPER CASE). "CAUSES" builds one
          row per (adverse event, drug) pair; any other name writes the
          "Adverse Event ID" / "SOC ID" pairs of the base file.
        - destiny_path: Directory where the CSV file will be created (must be the same as the nodes,
          because the CAUSES relationship reads Drug_node.csv from it). A trailing separator is optional.
        - date: The date the data was extracted from OFFX, formatted as MM/DD/YYYY
          (only necessary for CAUSES relationship)
    Raises:
        - ValueError: if rel_name is "CAUSES" and no date is given, or the date does not
          match the MM/DD/YYYY format"""

    # Fail fast: an empty date would otherwise be parsed as a missing value and the
    # null filter below would then discard every row, leaving a header-only file.
    if rel_name == "CAUSES" and not date:
        raise ValueError("A date (MM/DD/YYYY) is required to build the CAUSES relationship")

    # Read the main file
    base_data = pd.read_excel(base_path,
                              sheet_name = "Data")

    # Build the data depending on the type of relationship
    if rel_name == "CAUSES":
        extraction_date = pd.to_datetime(date, format="%m/%d/%Y")

        # Get the Drug node data
        drug_node_data = pd.read_csv(os.path.join(destiny_path, "Drug_node.csv"))

        # Create the dataset with the number of occurrences, one whole column per
        # drug (drug-major order) instead of looking up every cell individually
        adverse_event_ids = base_data["Adverse Event ID"].tolist()
        occurrences = []
        ae_ids = []
        drug_names = []

        for position, drug_name in enumerate(drug_node_data.iloc[:, 0]):
            ae_ids.extend(adverse_event_ids)
            drug_names.extend([drug_name] * len(adverse_event_ids))
            # The drug data starts at column index 5 of the base file, in the same
            # order as the rows of Drug_node.csv
            occurrences.extend(base_data.iloc[:, 5 + position].tolist())

        dates = [extraction_date] * len(occurrences)

        rel_data = pd.DataFrame(data = {"ae_id": ae_ids,
                                        "drug": drug_names,
                                        "reports": occurrences,
                                        "date": dates})
    else:
        rel_data = base_data[["Adverse Event ID", "SOC ID"]]

    # Remove null values (pairs without a report count, rows without ids)
    rel_data = rel_data.dropna(how = 'any',
                               axis = 0)

    # Create the CSV file; os.path.join copes with a missing trailing separator
    if destiny_path:
        os.makedirs(destiny_path, exist_ok = True)
    rel_path = os.path.join(destiny_path, rel_name + '_relationship.csv')
    rel_data.to_csv(rel_path, index = False)

In [10]:
def get_neo4j_arch(db_url="bolt://localhost:7687", user="neo4j", password="12345678"):
    """Gets the schema of the Neo4j dataset.

    The schema is obtained through the APOC procedure ``apoc.meta.schema``.

    Inputs:
        - db_url: The url of the running Neo4j dataset
        - user: The user connected to the Neo4j dataset
        - password: The password for connecting to the Neo4j dataset
    Returns:
        The ``value`` field of the first row returned by the schema query."""

    # Same connection parameters (and defaults) as run_neo4j_query, so both
    # helpers can be pointed at a different database in the same way
    graph = Neo4jGraph(url = db_url,
                       username = user,
                       password = password)

    query = "CALL apoc.meta.schema() YIELD value RETURN value"

    return graph.query(query)[0]['value']

In [11]:
def extract_function_calls(response: GenerationResponse) -> List[Dict]:
    """Get the function calls detected by the LLM
    Inputs:
        - response: The answer you get from the LLM
    Returns:
        A list with one ``{function_name: {argument_name: argument_value}}`` dictionary
        per function call of the first candidate; empty when there is nothing to call."""

    # A response without candidates carries no function calls; indexing [0]
    # on it would raise IndexError
    if not response.candidates:
        return []

    detected_calls = response.candidates[0].function_calls or []

    # Copy the arguments into a plain dict so callers do not depend on the
    # mapping type used by the SDK
    return [{function_call.name: dict((function_call.args or {}).items())}
            for function_call in detected_calls]

In [12]:
def run_neo4j_query(query, db_url="bolt://localhost:7687", user="neo4j", password="12345678"):
    """Opens a connection to a Neo4j dataset, runs one query and returns its result
    Inputs:
        - query: Cypher code to execute on Neo4j
        - db_url: Url where the Neo4j dataset is running
        - user: User name used for the connection
        - password: Password used for the connection"""

    # Gather the connection settings, then connect and execute in one go
    connection_settings = {"url": db_url,
                           "username": user,
                           "password": password}

    return Neo4jGraph(**connection_settings).query(query)

In [13]:
def initialize_gemini(tools, temp=0, max_tkn=8192, p=0.95, model_name="gemini-1.5-flash-001"):
    """Initialize the LLM to get responses from it.
    Inputs:
        - tools: The list of tools that the LLM will have available to execute
        - temp: The LLM temperature
        - max_tkn: Maximum number of output tokens of the LLM
        - p: Value passed as the "top_p" generation setting of the LLM
        - model_name: Identifier of the Gemini model to load
    Returns:
        The configured GenerativeModel."""

    # Every harm category listed here gets the BLOCK_NONE threshold
    unblocked_categories = (generative_models.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
                            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)

    safety_settings = {category: generative_models.HarmBlockThreshold.BLOCK_NONE
                       for category in unblocked_categories}

    model = GenerativeModel(model_name,
                            generation_config = {"temperature": temp,
                                                 "max_output_tokens": max_tkn,
                                                 "top_p": p},
                            safety_settings = safety_settings,
                            tools = tools)
    return model

In [14]:
def call_gemini(model, question):
    """Makes an API call to use the LLM and get a response back
    Inputs:
        - model: The LLM that wants to be used
        - question: The user question
    Returns:
        The text of the LLM answer: the direct answer when the LLM requested no tool
        call, otherwise the answer produced after receiving the Neo4j results."""

    chat = model.start_chat(response_validation = False)

    # Get the tool calls from the LLM
    res = chat.send_message(question)
    function_calls = extract_function_calls(res)
    print(function_calls)

    # The LLM answered without asking for a tool: there is nothing to execute and
    # no function response to send back
    if not function_calls:
        return res.text

    # Run each tool call individually (assuming the LLM only has the Neo4j tool)
    # and build one function-response part per call, so every requested call
    # receives its own result
    tool_responses = [
        Part.from_function_response(name = "query_neo4j",
                                    response = {"content": run_neo4j_query(function_call["query_neo4j"]["query"])})
        for function_call in function_calls
    ]

    # Give back to the LLM the tool responses
    final_response = chat.send_message(tool_responses)
    return final_response.text

# Testing the code

## Creating the dataset

In [110]:
# Create the node files
source_path = 'data/source_file.xlsx'
destiny_path = 'graph_files/'

print("Creating node files")
# Plain loops: the calls are made for their side effect (writing a CSV file),
# so there is no result list worth building
for node in ["Adverse Event", "SOC", "Drug"]:
    create_node_file(source_path, node, destiny_path)
print("\t> Node files created\n")

# Create the relationship files (CAUSES reads the Drug node file written above)
print("Creating relationship files")
for rel in ["CAUSES", "MANIFESTS_IN"]:
    create_relationship(source_path, rel, destiny_path, "06/13/2024")
print("\t> Relationship files created\n")

Creating node files
	> Node files created

Creating relationship files
	> Relationship files created



## Connecting to the Neo4j database

In [15]:
# Describe the Neo4j tool for the LLM: a single string parameter holding the
# Cypher query, with the current graph schema embedded in its description
neo4j_schema = get_neo4j_arch()
query_parameter = {"type": "string",
                   "description": f'The query to be executed on Neo4j (must be Cypher code) according to the following schema: {neo4j_schema}'}

query_neo4j = FunctionDeclaration(name = "query_neo4j",
                                  description = "Get data from a Neo4j dataset",
                                  parameters = {"type": "object",
                                                "properties": {"query": query_parameter}})
neo4j_tool = Tool(function_declarations = [query_neo4j])

# Build the model with the Neo4j tool as its only tool
model = initialize_gemini([neo4j_tool])

In [16]:
# Ask a question in natural language; call_gemini prints the tool calls the
# LLM requested and returns the final answer text
question = "How many different soc are?"

res = call_gemini(model = model, question = question)
print(res)

[{'query_neo4j': {'query': 'MATCH (n:SOC) RETURN count(DISTINCT n)'}}]
There are 24 different SOCs. 



In [54]:
# Second question: drugs linked to one adverse event and their report counts
question = "Give me the name and number of reports of the drugs that provoque Anaemia"

res = call_gemini(model = model, question = question)
print(res)

The following drugs provoke Anaemia:

* DTRM-12\nPhase II  with 1 report
* MH048\nPhase I  with 1 report
* poseltinib\nPhase II  with 1 report
* HWH486\n  with 1 report
* orelabrutinib\nLaunched 2021 with 1 report
* fenebrutinib\nPhase III  with 1 report
* zanubrutinib\nLaunched 2019 with 1 report
* LP-168\nPhase II  with 1 report
* luxeptinib\nPhase I/II  with 1 report
* spebrutinib\nPhase I  with 1 report
* evobrutinib\nPhase III  with 1 report
* nemtabrutinib\nPhase I/II  with 1 report
* edralbrutinib\nPhase II  with 1 report
* elsubrutinib\nPhase II  with 1 report
* acalabrutinib\nLaunched 2017 with 1 report
* TL-895\nPhase II  with 1 report
* vecabrutinib\nPhase I/II  with 1 report
* tolebrutinib\nPhase III  with 1 report
* BMS-986142\nPhase II  with 1 report
* M7583\nPhase I/II  with 1 report
* DZD-8586\n  with 1 report
* tirabrutinib\nLaunched 2020 with 1 report
* ibrutinib\nLaunched 2013 with 1 report
* pirtobrutinib\nLaunched 2023 with 1 report
* abivertinib\nPhase II  with 1 