# Asset Manager Example
In this notebook, let's explore how to leverage Google Generative AI to build and consume a knowledge graph in Neo4j.

This notebook parses Form-13 data From SEC EDGAR. The Form 13 files are semi structured data that are pretty nasty to parse  We'll use generative AI to do it for us.  We will then also use the LLM to generate Cypher statments to load the extracted data into a Neo4j graph.  Then, we'll use a chatbot to query the knowledge graph we've created.

## Setup
First off, check that the Python environment you installed in the readme is running this notebook. Make sure you select the `py38` kernel in the top right of this notebook. You should see a 3.8 version when you run this command.

In [None]:
import sys
sys.version

Next we install and import some libraries 

In [None]:
%pip install --user graphdatascience
%pip install --user langchain  # library for combining functional steps around LLM calls
%pip install --user google-cloud-aiplatform  # library for accessing VertexAI
%pip install --user gradio  # for building the chat interface
%pip install --user google.cloud
%pip install --user numpy

You probably want to restart your kernel here to ensure everything is loaded.

In [None]:
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

## Prompt Definition

In the upcoming sections, we will extract knowledge adhering to the following schema. This is a very Simplified schema to denote investment management entities and companies they own through common stock. Normally, you will have Domain Experts who come up with a richer data model, and you can extend the below to work on more data/forms to fill such a model.

To achieve our Extraction goal as per the schema, We will use a series of prompts, each focused on only one task - to extract a specific entity. By this way, you can go for more granular extraction. The prompts I used here can be improved and in production scenario, you should consider running QA on the prompt pipelines to ensure that the extracted information is correct.

Let's go in this order to gather the data in accordance to out data model:

1. Extract Manager Information
2. Extract Filing Information

In [1]:
mgr_info_tpl = """From the text below, extract the following as json. Do not miss any of these information.
* The tags mentioned below may or may not namespaced. So extract accordingly. Eg: <ns1:tag> is equal to <tag>
* "name" - The name from the <name> tag under <filingManager> tag
* "street1" - The manager's street1 address from the <com:street1> tag under <address> tag
* "street2" - The manager's street2 address from the <com:street2> tag under <address> tag
* "city" - The manager's city address from the <com:city> tag under <address> tag
* "stateOrCounty" - The manager's stateOrCounty address from the <com:stateOrCountry> tag under <address> tag
* "zipCode" - The manager's zipCode from the <com:zipCode> tag under <address> tag
* "reportCalendarOrQuarter" - The reportCalendarOrQuarter from the <reportCalendarOrQuarter> tag under <address> tag
* Just return me the JSON enclosed by 3 backticks. No other text in the response

Text:
$ctext
"""

In [2]:
filing_info_tpl = """The text below contains a list of investments. Each instance of <infoTable> tag represents a unique investment. 
For each investment, please extract the below variables into json then combine into a list enclosed by 3 back ticks. Please use the quated names below while doing this
* "cusip" - The cusip from the <cusip> tag under <infoTable> tag
* "companyName" - The name under the <nameOfIssuer> tag.
* "value" - The value from the <value> tag under <infoTable> tag. Return as a number. 
* "shares" - The sshPrnamt from the <sshPrnamt> tag under <infoTable> tag. Return as a number. 
* "sshPrnamtType" - The sshPrnamtType from the <sshPrnamtType> tag under <infoTable> tag
* "investmentDiscretion" - The investmentDiscretion from the <investmentDiscretion> tag under <infoTable> tag
* "votingSole" - The votingSole from the <votingSole> tag under <infoTable> tag
* "votingShared" - The votingShared from the <votingShared> tag under <infoTable> tag
* "votingNone" - The votingNone from the <votingNone> tag under <infoTable> tag

Text:
$ctext
"""

## Functions for Using LLMs
Let's create some helper function to talk to the LLM with our prompt and text input. We will use the text-bison base model. In your use case, you might need to tune it. VertexAI provides an elegant way to finetune it. The weights will be staying within your tenant and the base model is frozen.

First off, you'll need to set your project_id.

In [3]:
project_id = 'neo4jbusinessdev'
location = 'us-central1'

In [4]:
# wrapper for calling language model
def run_text_model(
    project_id: str,
    model_name: str,
    temperature: float,
    max_decode_steps: int,
    top_p: float,
    top_k: int,
    prompt: str,
    location: str = location,
    tuned_model_name: str = None,
    ) :
    """Text Completion Use a Large Language Model."""
    vertexai.init(project=project_id, location=location)
    if tuned_model_name is None:
        model = TextGenerationModel.from_pretrained(model_name)
    else:
        model = model.get_tuned_model(tuned_model_name)
    response = model.predict(
        prompt,
        temperature=temperature,
        max_output_tokens=max_decode_steps,
        top_k=top_k,
        top_p=top_p,)
    return response.text

In [5]:
# wrapper for entity extraction / parsing
def extract_entities_relationships(prompt, tuned_model_name=None):
    try:
        res = run_text_model(project_id, "text-bison@001", 0, 1024, 0.8, 1, prompt, location, tuned_model_name)
        return res
    except Exception as e:
        print(e)

In [6]:
# splitting function for chunking up filing information to avoid hitting LLM token limits
import re
import numpy as np

def split_filing_info(s, chunk_size=5):
    pattern = '(</(\w+:)?infoTable>)'
    splitter = re.findall(pattern, s)[0][0]
    _parts = s.split(splitter)
    if len(_parts) > chunk_size:
        chunks_of_list = np.array_split(_parts, len(_parts)/chunk_size) # max 5 filings per part
        chunks_of_str = map(lambda x: splitter.join(x), chunks_of_list)
        return list(chunks_of_str)
    else:
        return [s]

## Test Example for Parsing
Let's start with one Form 13 file to see how we can parse it with generative AI.

In [7]:
from google.cloud import storage

storage_client = storage.Client()
bucket = storage_client.bucket('neo4j-datasets')
blob = bucket.blob('form13/raw/raw_2023-05-15_archives_edgar_data_1027451_0000919574-23-003245.txt')

inp_text = blob.download_as_string().decode()

In [8]:
print(inp_text[:1500])

<SEC-DOCUMENT>0000919574-23-003245.txt : 20230515
<SEC-HEADER>0000919574-23-003245.hdr.sgml : 20230515
<ACCEPTANCE-DATETIME>20230515103943
ACCESSION NUMBER:		0000919574-23-003245
CONFORMED SUBMISSION TYPE:	13F-HR
PUBLIC DOCUMENT COUNT:		2
CONFORMED PERIOD OF REPORT:	20230331
FILED AS OF DATE:		20230515
DATE AS OF CHANGE:		20230515
EFFECTIVENESS DATE:		20230515

FILER:

	COMPANY DATA:	
		COMPANY CONFORMED NAME:			TIGER MANAGEMENT L.L.C.
		CENTRAL INDEX KEY:			0001027451
		IRS NUMBER:				000000000
		STATE OF INCORPORATION:			DE

	FILING VALUES:
		FORM TYPE:		13F-HR
		SEC ACT:		1934 Act
		SEC FILE NUMBER:	028-05892
		FILM NUMBER:		23919492

	BUSINESS ADDRESS:	
		STREET 1:		101 PARK AVENUE
		CITY:			NEW YORK
		STATE:			NY
		ZIP:			10178
		BUSINESS PHONE:		212-984-2500

	MAIL ADDRESS:	
		STREET 1:		101 PARK AVENUE
		CITY:			NEW YORK
		STATE:			NY
		ZIP:			10178

	FORMER COMPANY:	
		FORMER CONFORMED NAME:	TIGER MANAGEMENT LLC/NY
		DATE OF NAME CHANGE:	20010606
</SEC-HEADER>
<DOCUMENT>
<TYPE>

In [9]:
contents = inp_text.split('<XML>')
manager_info = contents[1].split('</XML>')[0].strip()
filing_info = contents[2].split('</XML>')[0].strip()

## Parsing Manager Information

In [10]:
import vertexai
vertexai.init(project=project_id, location=location)

In [11]:
from string import Template

prompt = Template(mgr_info_tpl).substitute(ctext=manager_info)
print(prompt)

From the text below, extract the following as json. Do not miss any of these information.
* The tags mentioned below may or may not namespaced. So extract accordingly. Eg: <ns1:tag> is equal to <tag>
* "name" - The name from the <name> tag under <filingManager> tag
* "street1" - The manager's street1 address from the <com:street1> tag under <address> tag
* "street2" - The manager's street2 address from the <com:street2> tag under <address> tag
* "city" - The manager's city address from the <com:city> tag under <address> tag
* "stateOrCounty" - The manager's stateOrCounty address from the <com:stateOrCountry> tag under <address> tag
* "zipCode" - The manager's zipCode from the <com:zipCode> tag under <address> tag
* "reportCalendarOrQuarter" - The reportCalendarOrQuarter from the <reportCalendarOrQuarter> tag under <address> tag
* Just return me the JSON enclosed by 3 backticks. No other text in the response

Text:
<?xml version="1.0" encoding="UTF-8"?>
<edgarSubmission xsi:schemaLocati

Now, let's use the LLM to parse out the data we want.

In [12]:
import json
from vertexai.language_models import TextGenerationModel

manager_data = json.loads(extract_entities_relationships(prompt).split('```')[1].strip('json'))
manager_data

{'name': 'TIGER MANAGEMENT L.L.C.',
 'street1': '101 PARK AVENUE',
 'street2': '',
 'city': 'NEW YORK',
 'stateOrCounty': 'NY',
 'zipCode': '10178',
 'reportCalendarOrQuarter': '03-31-2023'}

## Parse Filing Information
We will parse filing info in a similar manner to manager information. Because the filings include a list of many entries however, we will want to split the input into chunks so as not to exceed input or output token limits. 

In [13]:
filing_info_chunks = split_filing_info(filing_info)
len(filing_info_chunks)

6

In [14]:
prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunks[0])
response = extract_entities_relationships(prompt)
print(response)

  ```
[
  {
    "cusip": "02079K305",
    "companyName": "ALPHABET INC",
    "value": 1488692,
    "shares": 14352,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 14352,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusip": "023135106",
    "companyName": "AMAZON COM INC",
    "value": 582556,
    "shares": 5640,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 5640,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusip": "049468101",
    "companyName": "ATLASSIAN CORPORATION",
    "value": 205404,
    "shares": 1200,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 1200,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusip": "053332102",
    "companyName": "AUTOZONE INC",
    "value": 9218063,
    "shares": 3750,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 3750,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusi

## Data Ingestion - Test Example

Let's walk through the steps to do this with just the 1 form above first, then we can move on to parsing and ingesting all the Form 13s.

To start, we can run the LLM parsing over all the filing info from the form then combine the resulting json into a list condusive for Neo4j loading.

In [15]:
filings_list = []
for filing_info_chunk in filing_info_chunks:
    prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunk)
    response = extract_entities_relationships(prompt)
    filings_list.extend(json.loads(response.replace('```', '')))

for item in filings_list:
    item['managerName'] = manager_data['name']
    item['reportCalendarOrQuarter'] = manager_data['reportCalendarOrQuarter']
filings_list[:5]

[{'cusip': '02079K305',
  'companyName': 'ALPHABET INC',
  'value': 1488692,
  'shares': 14352,
  'sshPrnamtType': 'SH',
  'investmentDiscretion': 'SOLE',
  'votingSole': 14352,
  'votingShared': 0,
  'votingNone': 0,
  'managerName': 'TIGER MANAGEMENT L.L.C.',
  'reportCalendarOrQuarter': '03-31-2023'},
 {'cusip': '023135106',
  'companyName': 'AMAZON COM INC',
  'value': 582556,
  'shares': 5640,
  'sshPrnamtType': 'SH',
  'investmentDiscretion': 'SOLE',
  'votingSole': 5640,
  'votingShared': 0,
  'votingNone': 0,
  'managerName': 'TIGER MANAGEMENT L.L.C.',
  'reportCalendarOrQuarter': '03-31-2023'},
 {'cusip': '049468101',
  'companyName': 'ATLASSIAN CORPORATION',
  'value': 205404,
  'shares': 1200,
  'sshPrnamtType': 'SH',
  'investmentDiscretion': 'SOLE',
  'votingSole': 1200,
  'votingShared': 0,
  'votingNone': 0,
  'managerName': 'TIGER MANAGEMENT L.L.C.',
  'reportCalendarOrQuarter': '03-31-2023'},
 {'cusip': '053332102',
  'companyName': 'AUTOZONE INC',
  'value': 9218063,


In [16]:
len(filings_list)

33

## Establish Neo4j Connection
Now, we're going to load data into Neo4j.  To do so, you will, of course, need a Neo4j instance.  The easiest way to get started with Neo4j on Google Cloud is with Aura, the Neo4j managed service.  That comes in a few flavors, various combinations of Professional and Enterprise, conflated with DB and DS. Those stand for database and data science.  The data science version includes the database as well.

For our purposes, we want a Neo4j AuraDS Professional.  You can deploy that from the Marketplace [here](https://console.cloud.google.com/marketplace/product/endpoints/prod.n4gcp.neo4j.io).

When you deploy select "1,000,000" nodes and "2,000,000" relationships.  Select "Node Embedding" for algorithms.  That should give you a good instance for this work.

In [19]:
# You will need to change these to match your credentials
NEO4J_URI = 'neo4j+s://xxxxx.databases.neo4j.io'
NEO4J_PASSWORD = 'your password'

# You can leave this as is.
NEO4J_USERNAME = 'neo4j'

Now, let's create a connection to the database using the Graph Data Science API.  Note this will not work with an AuraDB instance since it does not include Graph Data Science.

In [20]:
from graphdatascience import GraphDataScience

gds = GraphDataScience(
    NEO4J_URI,
    auth=(NEO4J_USERNAME, NEO4J_PASSWORD),
    aura_ds=True
)
gds.set_database('neo4j')

Before loading we should create uniqueness constraints for nodes.  This acts as a unique id and an index and is necessary for fast, efficient, queries.  In general, if you notice ingestion is super slow (and getting slower) with Neo4j, double check that you created indexes.  For this small sample it won't matter, but it will certainly imact as we ingest more data

In [21]:
gds.run_cypher('''CREATE CONSTRAINT unique_company_id IF NOT EXISTS FOR (p:Company) REQUIRE (p.cusip) IS NODE KEY;''')
gds.run_cypher('''CREATE CONSTRAINT unique_manager IF NOT EXISTS FOR (p:Manager) REQUIRE (p.managerName) IS NODE KEY;''')

To Merge in the data we can use parameterized cypher queries.  Basically we are going to send filings in batches (in this sample case just one batch) for each node and relationship type and insert them as parameters in the query.

In [22]:
# Create Company Nodes

gds.run_cypher('''
UNWIND $records AS record
MERGE (c:Company {cusip: record.cusip})
SET c.name = record.name
RETURN count(c) AS company_node_merge_count
''', params={'records':filings_list})

Unnamed: 0,company_node_merge_count
0,33


In [23]:
# Create Manager Node. In this case we just have one

gds.run_cypher('''
MERGE (m:Manager {managerName: $managerName})
''', params={'managerName':manager_data['name']})

In [24]:
# Create OWNS Relationship

gds.run_cypher('''
UNWIND $records AS record
MATCH (m:Manager {managerName: record.managerName})
MATCH (c:Company {cusip: record.cusip})
MERGE(m)-[r:OWNS]->(c)
SET r.reportCalendarOrQuarter = record.reportCalendarOrQuarter,
    r.value = record.value,
    r.shares = record.shares
RETURN count(r) AS owns_relationship_merge_count
''', params={'records':filings_list})

Unnamed: 0,owns_relationship_merge_count
0,0


You can now load Neo4j Browser through the Aura GUI to check the graph and see the loaded sample data.

## Ingest Multiple Form 13 Files
We will make a pipeline using the methods above.  In this case we will take a two-step approach, first parse all the data, then chunk that data and ingest into Neo4j.

To start, lets take a look at all the Form 13 samples we have in cloud storage. 

In [25]:
blobs = storage_client.list_blobs('neo4j-datasets', prefix='form13/raw/', delimiter='/')
file_names = [blob.name for blob in blobs if '.txt' in blob.name]
print(f'{len(file_names)} total raw form13 files')

44142 total raw form13 files


For our purposes, let's just use the first 5 in this list.

In [26]:
sample_file_names = file_names[:5]
sample_file_names

['form13/raw/raw_2022-01-03_archives_edgar_data_1026200_0001567619-22-000057.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1315339_0001315339-22-000001.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1384943_0001384943-22-000001.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1452208_0001104659-22-000174.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1624809_0001624809-22-000001.txt']

In [27]:
#helper function for getting filing info

def get_manager_and_filing_info(raw_txt):
    contents = raw_txt.split('<XML>')
    manager_info = contents[1].split('</XML>')[0].strip()
    filing_info = contents[2].split('</XML>')[0].strip()
    
    return manager_info, filing_info

In [28]:
print(f'=== Parsing {len(sample_file_names)} Form 13 Files ===')

filings_list = []
manager_list = []

for file_name in sample_file_names:
    
    print(f'--- parsing {file_name} ---')
    try:
        #Get raw form13 file
        print('getting file text from gcloud....')
        blob = bucket.blob(file_name)
        raw_text = blob.download_as_string().decode()

        #Get raw manager and filing info from file
        print('getting file contents...')
        manager_info, filing_info = get_manager_and_filing_info(raw_text)

        #Parse manager info into dict using LLM
        print('Parsing submission and manager info...')
        mng_prompt = Template(mgr_info_tpl).substitute(ctext=manager_info)
        mng_response = extract_entities_relationships(mng_prompt)
        manager_data = json.loads(mng_response.replace('```', ''))
        manager_list.append({'name': manager_data['name']})

        #Parse filing info into list of dicts using LLM
        print('Parsing filing info...')
        for filing_info_chunk in filing_info_chunks:
            filing_prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunk)
            filing_response = extract_entities_relationships(filing_prompt)
            filings_list.extend(json.loads(filing_response.replace('```', '')))
        for item in filings_list: #Add information from manager_info to enable OWNS relationship loading
            if len(manager_data['name']) == 0:
                item['managerName'] = "Private Manager"
            else:
                item['managerName'] = manager_data['name']
            item['reportCalendarOrQuarter'] = manager_data['reportCalendarOrQuarter']
    except Exception as e:
        raise e


=== Parsing 5 Form 13 Files ===
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1026200_0001567619-22-000057.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1315339_0001315339-22-000001.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1384943_0001384943-22-000001.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1452208_0001104659-22-000174.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1624809_0001624809-22-000001.txt ---
getting f

In [29]:
# Merge Manager Nodes.

gds.run_cypher('''
UNWIND $records AS record
MERGE (m:Manager {managerName: coalesce(record.managerName,"Private Manager")})
RETURN count(m) AS manager_node_merge_count
''', params={'records':manager_list})

Unnamed: 0,manager_node_merge_count
0,5


In [30]:
len(filings_list)

165

In [31]:
# as the dataset gets bigger we will want to chunk up the filings we send to Neo4j

def chunks(xs, n=10_000):
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]

In [32]:
# Creat Company Nodes

for d in chunks(filings_list):
    res = gds.run_cypher('''
    UNWIND $records AS record
    MERGE (c:Company {cusip: record.cusip})
    SET c.name = record.name
    RETURN count(c) AS company_node_merge_count
    ''', params={'records':d})
    print(res)

   company_node_merge_count
0                       165


In [33]:
# Create the OWNS Relationships

for d in chunks(filings_list):
    res = gds.run_cypher('''
    UNWIND $records AS record
    MATCH (m:Manager {name: record.managerName})
    MATCH (c:Company {cusip: record.cusip})
    MERGE(m)-[r:OWNS]->(c)
    SET r.reportCalendarOrQuarter = record.reportCalendarOrQuarter,
        r.value = record.value,
        r.shares = record.shares
    RETURN count(r) AS owns_relationship_merge_count
    ''', params={'records':d})
    print(res)

   owns_relationship_merge_count
0                              0


That's it!  You've used an LLM to parse and load data into Neo4j, creating a knowledge graph.  At this point, you may want to open Neo4j Browser or Bloom through the Aura console and explore the data you've loaded.

## Chatbot
We're now going to show how to layer a chat bot on top of the knowledge graph we created.

## Load more data
We showed how to load data with the LLM.  Now we're going to load some more rows from a CSV so our chatbot has more to work with.  This is going to take a while to run, ~15 minutes or so.

In [34]:
gds.run_cypher('''LOAD CSV WITH HEADERS FROM 'https://storage.googleapis.com/neo4j-datasets/form13/form13-v2.csv' AS row
MERGE (c:Company {cusip:row.cusip})
ON CREATE SET c.companyName=row.companyName;''')

In [35]:
gds.run_cypher('''LOAD CSV WITH HEADERS FROM 'https://storage.googleapis.com/neo4j-datasets/form13/form13-v2.csv' AS row
MERGE (m:Manager {managerName:row.managerName});''')

In [36]:
gds.run_cypher('''LOAD CSV WITH HEADERS FROM 'https://storage.googleapis.com/neo4j-datasets/form13/form13-v2.csv' AS row
MATCH (m:Manager {managerName:row.managerName})
MATCH (c:Company {cusip:row.cusip})
MERGE (m)-[r:OWNS {reportCalendarOrQuarter:date(row.reportCalendarOrQuarter)}]->(c)
SET r.value = toFloat(row.value), r.shares = toInteger(row.shares);''')

## Cypher Generation
We have to use a prompt template that clearly states what schema to use, the principles that the chatbot should follow in generating responses, and some few-shot examples to help the chatbot be more accurate in its query generation.

In [39]:
#prompt/template 
CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher translator who understands the question in english and convert to Cypher strictly based on the Neo4j Schema provided and following the instructions below:
1. Generate Cypher query compatible ONLY for Neo4j Version 5
2. Do not use EXISTS, SIZE keywords in the cypher. Use alias when using the WITH keyword
3. Please do not use same variable names for different nodes and relationships in the query.
4. Use only Nodes and relationships mentioned in the schema
5. Always enclose the Cypher output inside 3 backticks
6. Always do a case-insensitive and fuzzy search for any properties related search. Eg: to search for a Company name use `toLower(c.name) contains 'neo4j'`
7. Candidate node is synonymous to Manager
8. Always use aliases to refer the node in the query
9. 'Answer' is NOT a Cypher keyword. Answer should never be used in a query.
10. Please generate only one Cypher query per question. 
11. Cypher is NOT SQL. So, do not mix and match the syntaxes.
12. Every Cypher query always starts with a MATCH keyword.

Schema:
{schema}
Samples:
Question: Which fund manager owns most shares? What is the total portfolio value?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) RETURN m.managerName as manager, sum(distinct o.shares) as ownedShares, sum(o.value) as portfolioValue ORDER BY ownedShares DESC LIMIT 10

Question: Which fund manager owns most companies? How many shares?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) RETURN m.managerName as manager, count(distinct c) as ownedCompanies, sum(distinct o.shares) as ownedShares ORDER BY ownedCompanies DESC LIMIT 10

Question: What are the top 10 investments for Vanguard?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) WHERE toLower(m.managerName) contains "vanguard" RETURN c.companyName as Investment, sum(DISTINCT o.shares) as totalShares, sum(DISTINCT o.value) as investmentValue order by investmentValue desc limit 10

Question: What other fund managers are investing in same companies as Vanguard?
Answer: MATCH (m1:Manager) -[:OWNS]-> (c1:Company) <-[o:OWNS]- (m2:Manager) WHERE toLower(m1.managerName) contains "vanguard" AND elementId(m1) <> elementId(m2) RETURN m2.managerName as manager, sum(DISTINCT o.shares) as investedShares, sum(DISTINCT o.value) as investmentValue ORDER BY investmentValue LIMIT 10

Question: What are the top investors for Apple?
Answer: MATCH (m1:Manager) -[o:OWNS]-> (c1:Company) WHERE toLower(c1.companyName) contains "apple" RETURN distinct m1.managerName as manager, sum(o.value) as totalInvested ORDER BY totalInvested DESC LIMIT 10

Question: What are the other top investments for fund managers investing in Apple?
Answer: MATCH (c1:Company) <-[:OWNS]- (m1:Manager) -[o:OWNS]-> (c2:Company) WHERE toLower(c1.companyName) contains "apple" AND elementId(c1) <> elementId(c2) RETURN DISTINCT c2.companyName as company, sum(o.value) as totalInvested, sum(o.shares) as totalShares ORDER BY totalInvested DESC LIMIT 10

Question: What are the top investors in the last 3 months?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) WHERE date() > o.reportCalendarOrQuarter > o.reportCalendarOrQuarter - duration({{months:3}}) RETURN distinct m.managerName as manager, sum(o.value) as totalInvested, sum(o.shares) as totalShares ORDER BY totalInvested DESC LIMIT 10

Question: What are top investments in last 6 months for Vanguard?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) WHERE toLower(m.managerName) contains "vanguard" AND date() > o.reportCalendarOrQuarter > date() - duration({{months:6}}) RETURN distinct c.companyName as company, sum(o.value) as totalInvested, sum(o.shares) as totalShares ORDER BY totalInvested DESC LIMIT 10

Question: Who are Apple's top investors in last 3 months?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) WHERE toLower(c.companyName) contains "apple" AND date() > o.reportCalendarOrQuarter > date() - duration({{months:3}}) RETURN distinct m.managerName as investor, sum(o.value) as totalInvested, sum(o.shares) as totalShares ORDER BY totalInvested DESC LIMIT 10

Question: Which fund manager under 200 million has similar investment strategy as Vanguard?
Answer: MATCH (m1:Manager) -[o1:OWNS]-> (:Company) <-[o2:OWNS]- (m2:Manager) WHERE toLower(m1.managerName) CONTAINS "vanguard" AND elementId(m1) <> elementId(m2) WITH distinct m2 AS m2, sum(distinct o2.value) AS totalVal WHERE totalVal < 200000000 RETURN m2.managerName AS manager, totalVal*0.000001 AS totalVal ORDER BY totalVal DESC LIMIT 10

Question: Who are common investors in Apple and Amazon?
Answer: MATCH (c1:Company) <-[:OWNS]- (m:Manager) -[:OWNS]-> (c2:Company) WHERE toLower(c1.companyName) contains "apple" AND toLower(c2.companyName) CONTAINS "amazon" RETURN DISTINCT m.managerName LIMIT 50

Question: What are Vanguard's top investments by shares for 2023?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) WHERE toLower(m.managerName) CONTAINS "vanguard" AND date({{year:2023}}) = date.truncate('year',o.reportCalendarOrQuarter) RETURN c.companyName AS investment, sum(o.value) AS totalValue ORDER BY totalValue DESC LIMIT 10

Question: What are Vanguard's top investments by value for 2023?
Answer: MATCH (m:Manager) -[o:OWNS]-> (c:Company) WHERE toLower(m.managerName) CONTAINS "vanguard" AND date({{year:2023}}) = date.truncate('year',o.reportCalendarOrQuarter) RETURN c.companyName AS investment, sum(o.shares) AS totalShares ORDER BY totalShares DESC LIMIT 10

Question: Which managers own FAANG stocks?
Answer: MATCH (m:Manager)-[o:OWNS]->(c:Company) WHERE toLower(c.companyName) IN [toLower("Facebook"),toLower("Apple"),toLower("Amazon"),toLower("Netflix"),toLower("Google")] RETURN m.managerName as manager, collect(distinct c.companyName) as companies

Question: {question}
Answer: 
"""

Create a LangChain prompt template.  This defines the inputs that will be included as parameters into the prompt sent to the Cypher generation bot.  In our example, the inputs will be `schema` and `question`.  The question comes from the end user.  The schema is automatically inserted by the LangChain `GraphCypherQAChain` via a built in method to `Neo4jGraph`.

In [40]:
from langchain.prompts.prompt import PromptTemplate

CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema","question"], validate_template=True, template=CYPHER_GENERATION_TEMPLATE
)

We need to connect to the graph via LangChain.

In [41]:
from langchain.graphs import Neo4jGraph

graph = Neo4jGraph(
    url = NEO4J_URI, 
    username = NEO4J_USERNAME, 
    password = NEO4J_PASSWORD
)

We are defining our `chain` object, which combines Neo4j Q/A and VertexAI's `code-bison` LLM.  When the user gives a query, it first goes through `GraphCypherQAChain`, which generates a Cypher query according to the rules laid out in our prompt above.  That result set then goes to the `VertexAI` step of the chain, where the LLM is given the Neo4j result set and instructed to roll it into a natural language response.

In [43]:
from langchain.chains import GraphCypherQAChain
from langchain.llms import VertexAI

chain = GraphCypherQAChain.from_llm(
    VertexAI(model_name='code-bison',
            max_output_tokens=2048,
            temperature=0,
            top_p=0.95,
            top_k=0.40), graph=graph, verbose=True,
            cypher_prompt=CYPHER_GENERATION_PROMPT,
    return_intermediate_steps=True
)

Below we have a few examples of how we can get answers from the chatbot.

In [44]:
r2 = chain("""What are the top 10 investments for Blackrock?""")
print(f"Final answer: {r2['result']}")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m
MATCH (m:Manager) -[o:OWNS]-> (c:Company) WHERE toLower(m.managerName) contains "blackrock" RETURN c.companyName as Investment, sum(DISTINCT o.shares) as totalShares, sum(DISTINCT o.value) as investmentValue order by investmentValue desc limit 10
[0m
Full Context:
[32;1m[1;3m[{'Investment': None, 'totalShares': 4886845172, 'investmentValue': 608138653985000.0}, {'Investment': 'Apple Inc', 'totalShares': 3084362729, 'investmentValue': 304538995305000.0}, {'Investment': 'Johnson & Johnson', 'totalShares': 600132459, 'investmentValue': 66382342451000.0}, {'Investment': 'EXXON MOBIL CORP', 'totalShares': 849919665, 'investmentValue': 62624867042000.0}, {'Investment': 'TESLA INC', 'totalShares': 528248761, 'investmentValue': 59003867451000.0}, {'Investment': 'JPMORGAN CHASE & CO', 'totalShares': 586443138, 'investmentValue': 51420867302000.0}, {'Investment': 'PROCTER AND GAMBLE CO', 'totalShares': 483314

In [45]:
r3 = chain("""What are other top investments for fund managers investing in AstraZeneca?""")
print(f"Final answer: {r3['result']}")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m
MATCH (c1:Company) <-[:OWNS]- (m1:Manager) -[o:OWNS]-> (c2:Company) WHERE toLower(c1.companyName) contains "astrazeneca" AND elementId(c1) <> elementId(c2) RETURN DISTINCT c2.companyName as company, sum(o.value) as totalInvested, sum(o.shares) as totalShares ORDER BY totalInvested DESC LIMIT 10
[0m
Full Context:
[32;1m[1;3m[{'company': None, 'totalInvested': 690559809795000.0, 'totalShares': 7869502203}, {'company': 'SPDR S&P 500 ETF TR', 'totalInvested': 327922309895000.0, 'totalShares': 1229482000}, {'company': 'Apple Inc', 'totalInvested': 241033252745000.0, 'totalShares': 2916923878}, {'company': 'ISHARES TR', 'totalInvested': 104442565293000.0, 'totalShares': 1463346482}, {'company': 'TESLA INC CALL', 'totalInvested': 93992366480000.0, 'totalShares': 864643000}, {'company': 'INVESCO QQQ TR PUT', 'totalInvested': 78031957994000.0, 'totalShares': 373626750}, {'company': 'TESLA INC', 'totalInveste

In [46]:
r4 = chain("""Which fund manager under 200 million has similar investment strategy as Blackrock""")
print(f"Final answer: {r4['result']}")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m
MATCH (m1:Manager) -[o1:OWNS]-> (:Company) <-[o2:OWNS]- (m2:Manager) WHERE toLower(m1.managerName) CONTAINS "blackrock" AND elementId(m1) <> elementId(m2) WITH distinct m2 AS m2, sum(distinct o2.value) AS totalVal WHERE totalVal < 200000000 RETURN m2.managerName AS manager, totalVal*0.000001 AS totalVal ORDER BY totalVal DESC LIMIT 10
[0m
Full Context:
[32;1m[1;3m[{'manager': 'LAKE STREET ADVISORS GROUP, LLC', 'totalVal': 197.487}, {'manager': 'INCA Investments LLC', 'totalVal': 197.18599999999998}, {'manager': 'M Holdings Securities, Inc.', 'totalVal': 194.481}, {'manager': 'Avalon Global Asset Management LLC', 'totalVal': 194.107}, {'manager': 'King Wealth', 'totalVal': 192.66299999999998}, {'manager': 'TRUSTEES OF THE UNIVERSITY OF PENNSYLVANIA', 'totalVal': 190.523}, {'manager': 'Crestview Partners IV GP, L.P.', 'totalVal': 190.332}, {'manager': 'Virtu Financial LLC', 'totalVal': 188.94899999999

In [47]:
r5 = chain("""Please get me 10 common investors between Tesla and Microsoft""")
print(f"Final answer: {r5['result']}")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m
MATCH (c1:Company) <-[:OWNS]- (m:Manager) -[:OWNS]-> (c2:Company) WHERE toLower(c1.companyName) contains "tesla" AND toLower(c2.companyName) CONTAINS "microsoft" RETURN DISTINCT m.managerName LIMIT 10
[0m
Full Context:
[32;1m[1;3m[{'m.managerName': 'GROUP ONE TRADING, L.P.'}, {'m.managerName': 'JANE STREET GROUP, LLC'}, {'m.managerName': 'Optiver Holding B.V.'}, {'m.managerName': 'IMC-Chicago, LLC'}, {'m.managerName': 'CTC LLC'}, {'m.managerName': 'ADVISOR GROUP HOLDINGS, INC.'}, {'m.managerName': 'CIBC WORLD MARKETS CORP'}][0m

[1m> Finished chain.[0m
Final answer:  Here are 10 common investors between Tesla and Microsoft:
1. Baillie Gifford & Co.
2. BlackRock Inc.
3. Capital Research & Management Co. (World Investors)
4. Geode Capital Management LLC
5. Harris Associates
6. Invesco Ltd.
7. JPMorgan Chase & Co.
8. Northern Trust Corp.
9. State Street Corp.
10. The Vanguard Group, Inc. 


In [48]:
r6 = chain("""Which managers own FAANG stocks?""")
print(f"Final answer: {r6['result']}")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m
MATCH (m:Manager)-[o:OWNS]->(c:Company) WHERE toLower(c.companyName) IN [toLower("Facebook"),toLower("Apple"),toLower("Amazon"),toLower("Netflix"),toLower("Google")] RETURN m.managerName as manager, collect(distinct c.companyName) as companies
[0m
Full Context:
[32;1m[1;3m[{'manager': 'Beacon Wealthcare LLC', 'companies': ['Apple']}, {'manager': 'Pinnacle Holdings, LLC', 'companies': ['Google']}][0m

[1m> Finished chain.[0m
Final answer:  Beacon Wealthcare LLC and Pinnacle Holdings, LLC 


# Chatbot
Now we are going to use Gradio to deploy a chat interface that will have our chain behind it.

When we run the code below, a Gradio application will be deployed and can be accessed at a local URL.  We also get a public URL that can be shared for 3 days.

In [None]:
import gradio as gr
import typing_extensions
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(memory_key = "chat_history", return_messages = True)
agent_chain = chain

def chat_response(input_text,history):
    try:
        return agent_chain.run(input_text)
    except:
        return "I'm sorry, there was an error retrieving the information you requested."

interface = gr.ChatInterface(fn = chat_response,
                             title = "Investment Chatbot",
                             description = "powered by Neo4j",
                             theme = "soft",
                             chatbot = gr.Chatbot(height=500),
                             undo_btn = None,
                             clear_btn = "\U0001F5D1 Clear chat",
                             examples = ["Who are Tesla's top investors in last 3 months?",
                                         "What are the top 10 investments for Blackrock?",
                                         "Which manager owns FAANG stocks?",
                                         "What are other top investments for fund managers investing in Exxon?",
                                         "What are Vanguard's top investments by value for 2023?",
                                         "Who are the common investors between Tesla and Microsoft?"])

interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7860


## Conclusion

In this notebook, we went through the steps of connecting a LangChain agent to a Neo4j database and using it to generate Cypher queries in response to user requests via LLMs on VertexAI.

We used the `code-bison` model, but this approach can be generalized to any of the VertexAI LLMs and it can also be augmented with additional procedural steps around the generation chain to customize the user experience further for specific use cases.  The critical takeaway is the importance of Neo4j as a grounding database to anchor your chatbot to reality as it generates responses and to enable it to provide responses enriched with relevant enterprise data.  Knowledge graph is the best type of data structure to use for this type of grounding.  We also added an entrypoint agent to help provide a more acceptable user experience in cases where the input queries from users don't match the expected subject matter.