# Init

In [None]:
%load_ext autoreload

In [None]:
from app import app
# Create an application context from the project's `app` object and push it.
# The context is never popped in this notebook, so it stays active for all
# later cells. NOTE(review): presumably a Flask application — confirm.
context = app.app_context()
context.push()

## Imports

In [None]:
import yaml
from utils import formatters # Provide jupyter output formating

## Inputs

In [None]:
# Natural-language question used as the running example in the cells below.
query = "What is the relationship between Zn2+ and glycolate?"

## Initial check

Current response:

In [None]:
import openai

# Baseline: send only the raw question to the chat model (temperature 0),
# with no additional messages, and display the raw response.
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    temperature=0,
    messages=[{"role": "user", "content": query}],
)
response

# Extract core terms

## Init model

In [None]:
from langchain.chat_models import ChatOpenAI
# Chat model instance passed to the core-terms chain; temperature is set to 0.
llm = ChatOpenAI(temperature=0)
llm

## Prepare prompts

In [None]:
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
# Name of the prompt variable that carries the user's question.
QUERY_KEY = 'query'

# Two-message chat prompt: a fixed system instruction asking for a
# comma-separated list of terms, followed by a human message that contains
# only the question placeholder.
# NOTE(review): the system text reads "Always response with" (likely meant
# "respond"); left untouched because it is runtime prompt text.
core_terms_prompt_template = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate(
        prompt=PromptTemplate(
            template="Always response with comma separated list of chemical or biological terms identified in prompt.",
            input_variables=[]
        )
    ),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            # The f-string evaluates to a single-brace placeholder named after QUERY_KEY.
            template=f"{{{QUERY_KEY}}}",
            input_variables=[QUERY_KEY],
        )
    )
])
core_terms_prompt_template

## Create output parser

In [None]:
from langchain.output_parsers import CommaSeparatedListOutputParser
# Parser instance handed to the core-terms chain as its `output_parser`.
output_parser = CommaSeparatedListOutputParser()
output_parser

## Create chain

In [None]:
from langchain import LLMChain
# Output key under which the chain exposes its (parsed) result.
CORE_TERMS_KEY = 'core_terms'
# Combine the chat model, the core-terms prompt and the comma-separated-list
# parser into one chain; verbose=True is passed so the run is logged.
core_terms_chain = LLMChain(
    llm=llm,
    prompt=core_terms_prompt_template,
    output_parser=output_parser,
    output_key=CORE_TERMS_KEY,
    verbose=True
)
core_terms_chain

## Run

In [None]:
# Run the core-terms chain on the example question and keep the result for
# the eid lookup further down.
core_terms = core_terms_chain.run(query)
core_terms

# Find core terms eids

## Create chain construct to query db

### Prepare graph results output parser
Unwraps query output

In [None]:
from llmlib.utils.output_parsers.graph_results import GraphResultsOutputParser

### Prepare chain construct with output parser

In [None]:
from llmlib.utils.chains.graph_query import GraphQueryChain

## Get graph reference

In [None]:
from llmlib.database import Neo4j

In [None]:
# Graph handle obtained from the project's Neo4j wrapper; passed to the query chains below.
graph = Neo4j().graph()

## Create chain

In [None]:
# Column alias used in the Cypher RETURN clause and given to the output parser as its key.
query_output_key = 'output'
# Output key under which the chain exposes the term -> eid mapping.
EID_MAPPING_KEY = 'eid_mapping'
# The query is an f-string: `${CORE_TERMS_KEY}` renders as a Cypher parameter
# named after CORE_TERMS_KEY, and `{{ ... }}` renders as literal braces.
# For each term it matches Synonym nodes by lower-cased name, collects the
# non-null eids of nodes linked through HAS_SYNONYM (grouped by synonym name),
# and formats one "{term: ..., eid: [...]}" string per group.
find_eids_chain = GraphQueryChain(
    query=f"""
UNWIND ${CORE_TERMS_KEY} AS term
MATCH (s:Synonym {{lowercase_name: toLower(term)}})<-[:HAS_SYNONYM]-(n)
WHERE n.eid IS NOT NULL
WITH DISTINCT s.name as term, collect(n.eid) as eids
RETURN apoc.text.format(
    "{{term: %s, eid: [%s]}}",
    [term, apoc.text.join(eids, ',')]
) as {query_output_key}
""",
    input_keys=[CORE_TERMS_KEY],
    output_key=EID_MAPPING_KEY,
    graph=graph,
    output_parser=GraphResultsOutputParser(key=query_output_key),
    verbose=True
)
find_eids_chain

### Run

In [None]:
# Look up eids for the previously extracted core terms; the result is reused
# later as the EID_MAPPING_KEY input of the graph QA chain.
eid_mapping = find_eids_chain.run(core_terms)
eid_mapping

## Combine chains test

### Combine chains

In [None]:
from langchain.chains import SimpleSequentialChain
# Run the two steps back to back: term extraction first, eid lookup second.
chain = SimpleSequentialChain(
    chains=[
        core_terms_chain,
        find_eids_chain
    ],
    verbose=True
)
chain

### Test

In [None]:
# Smoke test of the combined two-step chain on the example question.
output = chain.run(query)
output

# Respond to inquiry given database context

## Get graph reference

In [None]:
from llmlib.database import Neo4j

In [None]:
# Rebinds `graph` to a freshly obtained handle (the same call is made in an
# earlier cell), so this section can be run on its own.
graph = Neo4j().graph()

### Show schema preview

In [None]:
# Display the graph's current schema before it is replaced with the limited one below.
graph.schema

## Limit schema

The whole schema is too big to fit into the request context, so limit it to the required parts.

### Set utils

In [None]:
def result_parser_factory(key):
    """Build a parser that pulls ``row[key]`` out of every row of a query result.

    Args:
        key: Key to look up in each row.

    Returns:
        A function that takes an iterable of rows and returns a list holding
        ``row[key]`` for each row, in iteration order.
    """
    def parse(result):
        return [row[key] for row in result]

    return parse


# Parser for queries that return their payload under the 'output' key.
output_result_parser = result_parser_factory('output')

### Get preselected node properties

In [None]:
# Allow-lists that restrict the schema to the node labels and properties
# needed for this notebook.
preselected_node_labels = ["Compound", "Reaction", "EnzReaction", "Regulation"]
preselected_node_properties = ["displayName", "eid"]

# Collect, per allowed node label, the allowed properties and their types.
node_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "node" AND label in $preselected_node_labels AND property in $preselected_node_properties
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""
node_properties = output_result_parser(
    graph.query(
        node_properties_query,
        {
            "preselected_node_properties": preselected_node_properties,
            "preselected_node_labels": preselected_node_labels,
        },
    )
)
node_properties

### Get preselected relationship properties

In [None]:
# Relationship types the limited schema is restricted to.
preselected_relationship_types = ["CONSUMED_BY", "PRODUCES", "CATALYZES", "REGULATES"]
# Collect, per allowed relationship type, its properties and their types.
# For rows with elementType = "relationship", apoc.meta.data() reports the
# relationship type in `label` and the property's data type in `type`, so the
# allow-list is matched against `label`. (Matching it against `type` compared
# data-type names with relationship names and therefore never matched.)
rel_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship" AND label in $preselected_relationship_types
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {type: nodeLabels, properties: properties} AS output
"""
relationships_properties = output_result_parser(
    graph.query(
        rel_properties_query,
        dict(
            preselected_relationship_types=preselected_relationship_types,
        )
    )
)
relationships_properties

### Get preselected relationships

In [None]:
# List (start label, relationship type, end label) triples, keeping only the
# allowed relationship types and only triples whose both ends are allowed labels.
rel_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node" AND property in $preselected_relationship_types AND label in $preselected_node_labels
UNWIND other AS other_node
WITH label, property, toString(other_node) as other_node
WHERE other_node in $preselected_node_labels
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""
relationships = output_result_parser(
    graph.query(
        rel_query,
        {
            "preselected_relationship_types": preselected_relationship_types,
            "preselected_node_labels": preselected_node_labels,
        },
    )
)
relationships

### Compose schema

In [None]:
# Textual schema description built from the three filtered query results;
# each list is interpolated using its default string representation.
limited_schema = f"""
Node properties are the following:
{node_properties}
Relationship properties are the following:
{relationships_properties}
The relationships are the following:
{relationships}
"""
limited_schema

### Set schema on the graph

In [None]:
# Overwrite both the textual and the structured schema on the graph object
# with the limited versions assembled above.
graph.schema = limited_schema
graph.structured_schema = dict(
    node_props={entry["labels"]: entry["properties"] for entry in node_properties},
    rel_props={entry["type"]: entry["properties"] for entry in relationships_properties},
    relationships=relationships,
)
graph.structured_schema

## Prepare prompts

In [None]:
from langchain import OpenAI, PromptTemplate

# Names of the prompt variables for the graph schema and the example queries.
SCHEMA_KEY = 'schema'
EXAMPLES_KEY = 'examples'

# Instruction text for Cypher generation. Because this is an f-string, each
# triple-brace group renders as a single-brace placeholder named after the
# corresponding *_KEY constant; those placeholders are declared as the
# template's input variables below.
# NOTE(review): the prompt text contains the typo "quering"; left untouched
# because it is runtime prompt text.
cypher_generation_template = f"""
Task:Generate Cypher statements to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
If quering for relationship filter out following eids: ["PROTON"]
Schema:
{{{SCHEMA_KEY}}}
Examples:
{{{EXAMPLES_KEY}}}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statements.
Do not include any text except the generated Cypher statements.
Do not wrap statements in brackets.
Use example queries.

The question is:
{{{QUERY_KEY}}}
"""

cypher_prompt = PromptTemplate(
    input_variables=[SCHEMA_KEY, EXAMPLES_KEY, QUERY_KEY], template=cypher_generation_template
)
cypher_prompt

In [None]:
from langchain.prompts import PromptTemplate
# NOTE(review): this name is re-imported from llmlib in a later cell, which
# replaces the binding made here.
from langchain.chains import GraphCypherQAChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    AIMessagePromptTemplate,
)

# Four-message chat prompt for Cypher generation:
#   1. system message setting the expert role,
#   2. human message asking for core terms and their eid mapping,
#   3. AI message whose content is the eid-mapping placeholder,
#   4. human message carrying the Cypher-generation instructions.
# NOTE(review): the second message contains the typo "Inqury"; left untouched
# because it is runtime prompt text.
graph_qa_prompt_template = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate(
        prompt=PromptTemplate(
            template="You are an expert in field of biology and chemistry with access to graph database.",
            input_variables=[]
        )
    ),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=f"""
                Step 1: For given inquiry identify core terms.
                Step 2: Return mapping from core term to list of database 'eid' for related entities.
                Inqury: {{{QUERY_KEY}}}
            """,
            input_variables=[QUERY_KEY],
        )
    ),
    AIMessagePromptTemplate(
        prompt=PromptTemplate(
            # Renders as a single-brace placeholder named after EID_MAPPING_KEY.
            template=f"{{{EID_MAPPING_KEY}}}",
            input_variables=[EID_MAPPING_KEY]
        )
    ),
    HumanMessagePromptTemplate(prompt=cypher_prompt)
])
graph_qa_prompt_template

In [None]:
# Prompt used to phrase the final answer. In this f-string `{{context}}`
# renders as the literal placeholder `{context}`, and the triple-brace group
# renders as a placeholder named after QUERY_KEY.
qa_prompt_template = f"""You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you can use to construct an answer.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Information:
{{context}}

Question:
{{{QUERY_KEY}}}
Helpful Answer:"""
qa_prompt = PromptTemplate(
    input_variables=["context", QUERY_KEY], template=qa_prompt_template
)
qa_prompt

## Create chain

#### Adapt GraphCypherQAChain to accept custom input keys

In [None]:
from llmlib.utils.chains.graph_cypher_qa_chain import GraphCypherQAChain

In [None]:
# Output key under which the chain exposes its final answer.
RESPONSE_KEY = 'response'
# Build the graph QA chain with a new temperature-0 chat model, the limited
# graph, and the custom Cypher-generation and answer prompts.
graph_qa_chain = GraphCypherQAChain.from_llm(
    ChatOpenAI(temperature=0),
    graph=graph,
    verbose=True,
    cypher_prompt=graph_qa_prompt_template,
    qa_prompt=qa_prompt,
    output_key=RESPONSE_KEY,
    return_intermediate_steps=True
)
graph_qa_chain

### Declare query examples

In [None]:
# Alternation of the allowed relationship types for use inside a Cypher
# relationship pattern (types joined with "|").
all_relation_type_str = "|".join(preselected_relationship_types)
# Upper bound on the path length in the shortest-path example.
max_search_depth = 5
# Example Cypher queries supplied to the prompt under EXAMPLES_KEY.
# The second example is a plain (non-f) string, so braces are written once;
# doubled braces would reach the model verbatim. Its node pattern is also
# closed with `)` so the example is valid Cypher.
# NOTE(review): assumes the examples are substituted into the prompt as
# values and are not themselves re-parsed as templates — confirm against
# the llmlib GraphCypherQAChain.
examples = [
    f"""
    // Get shortest path between nodes
    MATCH (a), (b)
    WHERE a.eid in <list_of_term_a_eids> AND b.eid in <list_of_term_b_eids>
    MATCH path=allShortestPaths((a)-[:{all_relation_type_str}*1..{max_search_depth}]-(b))
    RETURN [e in apoc.path.elements(path) | coalesce(e.eid, type(e))] as path LIMIT 1;
    """,
    """
    // Get nodes information
    UNWIND <list_of_eids> as eid
    MATCH (node {eid: eid})
    RETURN node;
    """,
]
examples

### Declare run parameters

In [None]:
# Inputs for a standalone run of the graph QA chain: the question, the eid
# mapping computed earlier, and the example queries.
graph_qa_chain_inputs = {
    QUERY_KEY: query,
    EID_MAPPING_KEY: eid_mapping,
    EXAMPLES_KEY: examples
}
graph_qa_chain_inputs

### Run

In [None]:
# Run the graph QA chain on its own with the prepared inputs.
graph_qa_chain.run(graph_qa_chain_inputs)

### Combine chains

In [None]:
from langchain.chains import SequentialChain
# Full pipeline: term extraction, eid lookup, then graph question answering.
# Only the question and the example queries are supplied by the caller, and
# only the final response is exposed. Rebinds the `chain` name used for the
# two-step test earlier in the notebook.
chain = SequentialChain(
    chains=[
        core_terms_chain,
        find_eids_chain,
        graph_qa_chain
    ],
    input_variables=[QUERY_KEY, EXAMPLES_KEY],
    output_variables=[RESPONSE_KEY],
    verbose=True
)
chain

### Test

In [None]:
# End-to-end run of the full pipeline on the original example question.
output = chain.run({
  QUERY_KEY: query,
  EXAMPLES_KEY: examples
})
output

In [None]:
# End-to-end run with a different question (Zn2+ and glcD).
output = chain.run({
  QUERY_KEY: "What is the relationship between Zn2+ and glcD?",
  EXAMPLES_KEY: examples
})
output

In [None]:
# End-to-end run with a different question (INHBA and MTMR4).
output = chain.run({
  QUERY_KEY: "What is the relationship between INHBA and MTMR4?",
  EXAMPLES_KEY: examples
})
output