# Knowledge Graph Creation with Generative AI

In this notebook, let's explore how to leverage Google Generative AI to build and consume a knowledge graph in Neo4j.

This notebook parses Form 13 data from SEC EDGAR. While these forms are partially structured with XML, their formatting isn't always consistent and includes some non-standard practices. Instead of spending our time writing a bespoke parser to extract data from these files, we will prompt a Large Language Model (LLM) to do this for us automatically. We will then load the extracted data into a Neo4j graph with parameterized Cypher statements.

Let's install and import some libraries.

In [None]:
%%capture
%pip install --user "google-cloud-aiplatform>=1.25.0" --upgrade
%pip install --user "google-cloud-aiplatform[pipelines]>=1.25.0"
%pip install --user graphdatascience 

Now restart the kernel. That will allow the Python environment to import the new packages.

In [None]:
import json
import numpy as np
import os
import re
from string import Template

# Vertexai and gcloud
import vertexai
from vertexai.language_models import TextGenerationModel
from google.cloud import storage

# Neo4j
from graphdatascience import GraphDataScience

## Prompt Definition

In the upcoming sections we will extract knowledge that adheres to the following schema. This is a very simplified schema that represents investment management entities and the companies they own through common stock. Typically, domain experts will come up with a richer data model, and you can extend the approach below to more data and forms to fill such a model.

![](images/12-graph-data-model.png)

To extract data that fits this schema, we will use a series of prompts, each focused on a single task: extracting a specific entity. This allows for more granular extraction. While we don't do so here, in a production scenario consider running QA on the prompt pipelines to ensure the extracted information is correct.
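As one way to do that QA, the LLM output for the information table could be cross-checked against a deterministic parse of the same XML. The sketch below is an illustration, not part of the original pipeline: the helper names `parse_info_tables` and `qa_check` are hypothetical, it only compares `cusip`, `value`, and `shares`, and it assumes the XML text is well-formed with its namespace declarations intact (for example the full information table before chunking). Namespaces are ignored by matching on local tag names, mirroring the prompt's rule that `<ns1:tag>` equals `<tag>`.

```python
import xml.etree.ElementTree as ET


def local_name(tag: str) -> str:
    """Strip an ElementTree namespace, e.g. '{http://...}cusip' -> 'cusip'."""
    return tag.rsplit('}', 1)[-1]


def _num(x):
    """Best-effort float conversion; None when the value is missing or non-numeric."""
    try:
        return float(x)
    except (TypeError, ValueError):
        return None


def parse_info_tables(xml_text: str) -> list[dict]:
    """Deterministically pull cusip/value/shares from every infoTable element,
    ignoring namespace prefixes (hypothetical helper for QA only)."""
    root = ET.fromstring(xml_text)
    rows = []
    for el in root.iter():
        if local_name(el.tag) != 'infoTable':
            continue
        # el.iter() yields el itself plus all descendants (e.g. sshPrnamt inside shrsOrPrnAmt)
        fields = {local_name(c.tag): (c.text or '').strip() for c in el.iter()}
        rows.append({
            'cusip': fields.get('cusip'),
            'value': _num(fields.get('value')),
            'shares': _num(fields.get('sshPrnamt')),
        })
    return rows


def qa_check(llm_rows: list[dict], xml_text: str) -> list[str]:
    """Return human-readable mismatches between LLM-extracted rows and the XML."""
    expected = {r['cusip']: r for r in parse_info_tables(xml_text)}
    problems = []
    for row in llm_rows:
        ref = expected.get(row.get('cusip'))
        if ref is None:
            problems.append(f"unknown cusip {row.get('cusip')!r}")
            continue
        for key in ('value', 'shares'):
            if _num(row.get(key)) != ref[key]:
                problems.append(f"{row['cusip']}: {key} {row.get(key)!r} != {ref[key]!r}")
    missing = set(expected) - {r.get('cusip') for r in llm_rows}
    problems.extend(f'missing cusip {c!r}' for c in sorted(missing))
    return problems
```

An empty list from `qa_check` means the extracted rows agree with the XML on the checked fields.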

Let's go in this order to gather the data in accordance with our data model:

1. Extract Manager Information
2. Extract Filing Information

In [None]:
mgr_info_tpl = """From the text below, extract the following as json. Do not miss any of this information.
* The tags mentioned below may or may not be namespaced, so extract accordingly. E.g.: <ns1:tag> is equal to <tag>
* "managerName" - The name from the <name> tag under <filingManager> tag
* "street1" - The manager's street1 address from the <com:street1> tag under <address> tag
* "street2" - The manager's street2 address from the <com:street2> tag under <address> tag
* "city" - The manager's city address from the <com:city> tag under <address> tag
* "stateOrCountry" - The manager's stateOrCountry address from the <com:stateOrCountry> tag under <address> tag
* "zipCode" - The manager's zipCode from the <com:zipCode> tag under <address> tag
* "reportCalendarOrQuarter" - The reportCalendarOrQuarter from the <reportCalendarOrQuarter> tag under <coverPage> tag
* Return only the JSON enclosed by 3 backticks. No other text in the response

Text:
$ctext
"""

In [None]:
filing_info_tpl = """From the text below, extract the following as json. The text below contains a list of investments. Please extract the below variables into json list enclosed by 3 back ticks. Please use the quoted names below while doing this
* "cusip" - The cusip from the <cusip> tag under <infoTable> tag
* "companyName" - The name under the <nameOfIssuer> tag.
* "value" - The value from the <value> tag under <infoTable> tag. Return as a number. 
* "shares" - The sshPrnamt from the <sshPrnamt> tag under <infoTable> tag. Return as a number. 
* "sshPrnamtType" - The sshPrnamtType from the <sshPrnamtType> tag under <infoTable> tag
* "investmentDiscretion" - The investmentDiscretion from the <investmentDiscretion> tag under <infoTable> tag
* "votingSole" - The votingSole from the <votingSole> tag under <infoTable> tag
* "votingShared" - The votingShared from the <votingShared> tag under <infoTable> tag
* "votingNone" - The votingNone from the <votingNone> tag under <infoTable> tag

Text:
$ctext
"""
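The filing template asks the model to return `value` and `shares` as numbers, but an LLM may still return strings such as `"1,234"`. If such strings were loaded as-is, they would not behave as numbers in Cypher aggregations like `sum()`. A minimal normalization sketch, assuming the keys produced by `filing_info_tpl`; the helpers `to_number` and `normalize_filing` are hypothetical and not part of the original notebook:

```python
def to_number(x):
    """Coerce LLM output like 1234, "1234", or "1,234" to int/float; None if not numeric."""
    if isinstance(x, bool):  # bool is a subclass of int; treat it as non-numeric here
        return None
    if isinstance(x, (int, float)):
        return x
    if isinstance(x, str):
        s = x.replace(',', '').strip()  # drop thousands separators
        try:
            return int(s)
        except ValueError:
            try:
                return float(s)
            except ValueError:
                return None
    return None


def normalize_filing(row: dict,
                     numeric_keys=('value', 'shares', 'votingSole', 'votingShared', 'votingNone')) -> dict:
    """Return a copy of one extracted filing row with numeric fields coerced."""
    out = dict(row)
    for k in numeric_keys:
        if k in out:
            out[k] = to_number(out[k])
    return out
```

Applying `normalize_filing` to each row before loading keeps `r.value` and `r.shares` numeric on the `OWNS` relationships.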

## Functions for Using LLMs

Let's create some helper functions to send our prompt and text input to the LLM. See the [Vertex AI documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models) for foundation models.

We will use the text-bison base model. In some cases, there may be a need to fine-tune LLM models for KG creation. [Vertex AI provides an elegant way to fine-tune](https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models) where the updated weights/model stay within your tenant and the base model is frozen.

In [None]:
from typing import Optional

# wrapper for calling language model
def run_text_model(
    project_id: str,
    model_name: str,
    temperature: float,
    max_decode_steps: int,
    top_p: float,
    top_k: int,
    prompt: str,
    tuned_model_name: Optional[str] = None,
) -> str:
    """Text completion using a Large Language Model.

    project_id is unused here; the project is configured by vertexai.init().
    """
    if tuned_model_name is None:
        model = TextGenerationModel.from_pretrained(model_name)
    else:
        # get_tuned_model is a class method; `model` is not defined yet in this branch
        model = TextGenerationModel.get_tuned_model(tuned_model_name)
    response = model.predict(
        prompt,
        temperature=temperature,
        max_output_tokens=max_decode_steps,
        top_k=top_k,
        top_p=top_p,
    )
    return response.text

In [None]:
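# wrapper for entity extraction / parsing
# Uses the global project_id. temperature=0 and top_k=1 make decoding effectively greedy,
# which keeps the extraction repeatable.
def extract_entities_relationships(prompt, tuned_model_name=None):
    try:
        return run_text_model(project_id, "text-bison@001",
                              temperature=0, max_decode_steps=1024, top_p=0.8, top_k=1,
                              prompt=prompt, tuned_model_name=tuned_model_name)
    except Exception as e:
        # Log and re-raise so callers don't go on to call .split() on None
        print(e)
        raise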

In [None]:
import math

# splitting function for chunking up filing information to avoid hitting LLM token limits
def split_filing_info(s, chunk_size=5):
    pattern = r'(</(\w+:)?infoTable>)'  # raw string: \w is a regex escape, not a Python one
    matches = re.findall(pattern, s)
    if not matches:
        return [s]
    splitter = matches[0][0]  # the closing tag as written, e.g. '</ns1:infoTable>'
    _parts = s.split(splitter)
    if len(_parts) > chunk_size:
        # round up so each chunk holds at most chunk_size filings
        chunks_of_list = np.array_split(_parts, math.ceil(len(_parts) / chunk_size))
        return [splitter.join(x) for x in chunks_of_list]
    else:
        return [s]

## Test Example for Parsing
Let's start with one Form 13 file to see how we can parse it with Generative AI.

In [None]:
storage_client = storage.Client()
bucket = storage_client.bucket('neo4j-datasets')
blob = bucket.blob('form13/raw/raw_2023-05-15_archives_edgar_data_1027451_0000919574-23-003245.txt')

# download_as_bytes() replaces the deprecated download_as_string()
inp_text = blob.download_as_bytes().decode()

In [None]:
print(inp_text[:1500])

We can split the data into manager and filing info pieces using the `<XML>` tags.
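The split-on-`<XML>` approach relies on the layout this notebook assumes: the text before the first `<XML>` is header material, the first XML block holds the cover page with the manager, and the second holds the information table. A slightly more defensive sketch collects every `<XML>...</XML>` body with a non-greedy regex and fails loudly if fewer than two blocks are present. The helper names are hypothetical and the toy input in the usage check is invented for illustration.

```python
import re


def split_xml_documents(raw_txt: str) -> list[str]:
    """Return the stripped body of every <XML>...</XML> block, in document order.

    The non-greedy .*? stops at the first closing tag; re.DOTALL lets '.' span newlines.
    """
    return [m.strip() for m in re.findall(r'<XML>(.*?)</XML>', raw_txt, flags=re.DOTALL)]


def get_manager_and_filing_info_checked(raw_txt: str) -> tuple[str, str]:
    """Hypothetical variant of the split above that raises on unexpected layouts."""
    blocks = split_xml_documents(raw_txt)
    if len(blocks) < 2:
        raise ValueError(f'expected at least 2 <XML> blocks, found {len(blocks)}')
    return blocks[0], blocks[1]
```

Raising early makes a malformed filing obvious instead of surfacing later as an `IndexError` or a confusing LLM response.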

In [None]:
contents = inp_text.split('<XML>')
manager_info = contents[1].split('</XML>')[0].strip()
filing_info = contents[2].split('</XML>')[0].strip()

### Parsing Manager Information

In [None]:
# Note, you will need to set your project_id
project_id = 'neo4jbusinessdev'
location = 'us-central1'

In [None]:
#initialize / authenticate vertex AI SDK to begin
vertexai.init(project=project_id, location=location)

In [None]:
# Create prompt
prompt = Template(mgr_info_tpl).substitute(ctext=manager_info)
print(prompt)

In [None]:
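# Use LLM to parse out manager info.
# Take the part inside the ``` fences and drop an optional "json" language tag.
# (str.strip('json') would strip any of the characters j, s, o, n, not the word "json".)
manager_data = json.loads(extract_entities_relationships(prompt).split('```')[1].strip().removeprefix('json'))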
manager_data

### Parse Filing Information
We will parse filing info in a similar manner to manager information. However, because the filings include a list of many entries, we split the input into chunks so as not to exceed input or output token limits.
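A chunker can also be written so that each chunk keeps its entries whole, closing tags included, and concatenating the chunks reproduces the input exactly. This is a sketch under that design goal, not the notebook's `split_filing_info`; the name `chunk_info_tables` is hypothetical. It walks the closing `infoTable` tags (namespaced or not) and groups at most `max_entries` entries per chunk.

```python
import re


def chunk_info_tables(xml_text: str, max_entries: int = 5) -> list[str]:
    """Split an information table into chunks of at most max_entries <infoTable> elements.

    Each chunk keeps whole entries (closing tags included). Text before the first entry,
    such as the <informationTable> opening tag, stays in chunk 0, and text after the last
    entry is appended to the final chunk, so ''.join(chunks) == xml_text.
    """
    closing = re.compile(r'</(?:\w+:)?infoTable>')  # matches </infoTable> and </ns1:infoTable>
    entries, start = [], 0
    for m in closing.finditer(xml_text):
        entries.append(xml_text[start:m.end()])
        start = m.end()
    if not entries:
        return [xml_text]
    chunks = [''.join(entries[i:i + max_entries]) for i in range(0, len(entries), max_entries)]
    chunks[-1] += xml_text[start:]  # e.g. the closing </informationTable> tag
    return chunks
```

Keeping closing tags inside each chunk means every entry the model sees is complete, which may help it extract the last entry of a chunk reliably.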

In [None]:
filing_info_chunks = split_filing_info(filing_info)
len(filing_info_chunks)

In [None]:
prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunks[0])
response = extract_entities_relationships(prompt)
print(response)

## Data Ingestion

### Test Example

Let's first walk through the steps with just the one form above, then move on to parsing and ingesting multiple Form 13 files.

To start, we can run the LLM parsing over all the filing info from the form and then combine the resulting JSON into a list suitable for loading into Neo4j.
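Because the model may wrap its answer in either plain triple backticks or a ```` ```json ```` fence, a small tolerant parser makes the combine step less brittle. The sketch below is a hypothetical pair of helpers (`parse_llm_json`, `combine_filings`), not the notebook's code: it extracts the fenced payload when present, falls back to the whole reply otherwise, flattens the per-chunk lists, and stamps each row with the manager fields needed for the `OWNS` relationship.

```python
import json
import re

# Optional "json" language tag after the opening fence; DOTALL lets the payload span lines.
FENCE = re.compile(r'```(?:json)?\s*(.*?)```', re.DOTALL)


def parse_llm_json(response: str):
    """Parse JSON from an LLM reply, tolerating ``` or ```json fences."""
    m = FENCE.search(response)
    payload = m.group(1) if m else response
    return json.loads(payload)


def combine_filings(responses: list[str], manager: dict) -> list[dict]:
    """Flatten per-chunk JSON lists and add the manager fields to every row."""
    rows = []
    for r in responses:
        parsed = parse_llm_json(r)
        rows.extend(parsed if isinstance(parsed, list) else [parsed])
    for row in rows:
        row['managerName'] = manager['managerName']
        row['reportCalendarOrQuarter'] = manager['reportCalendarOrQuarter']
    return rows
```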

In [None]:
filings_list = []
for filing_info_chunk in filing_info_chunks:
    prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunk)
    response = extract_entities_relationships(prompt)
    # Take the fenced JSON and drop an optional "json" language tag
    filings_list.extend(json.loads(response.split('```')[1].strip().removeprefix('json')))

for item in filings_list:
    item['managerName'] = manager_data['managerName']
    item['reportCalendarOrQuarter'] = manager_data['reportCalendarOrQuarter']
filings_list[:5]

In [None]:
len(filings_list)

#### Establish Neo4j Connection

In [None]:
# username is neo4j by default
NEO4J_USERNAME = 'neo4j'
# You will need to change these to match your database
NEO4J_URI = '<neo4j+s://xxxxx.databases.neo4j.io>'
NEO4J_PASSWORD = '<password>'

In [None]:
# Establish connection
gds = GraphDataScience(
    NEO4J_URI,
    auth=(NEO4J_USERNAME, NEO4J_PASSWORD),
    aura_ds=True
)
gds.set_database('neo4j')

Before loading, we should create node key constraints. A node key constraint guarantees that the key property exists and is unique, and it is backed by an index, which is necessary for fast, efficient `MERGE` and `MATCH` queries. In general, if you notice that ingestion into Neo4j is slow (and getting slower), double-check that you created indexes. For this small sample it won't matter much, but it will as we ingest more data.

In [None]:
gds.run_cypher('CREATE CONSTRAINT unique_manager IF NOT EXISTS FOR (n:Manager) REQUIRE (n.managerName) IS NODE KEY')
gds.run_cypher('CREATE CONSTRAINT unique_company_id IF NOT EXISTS FOR (n:Company) REQUIRE (n.cusip) IS NODE KEY')

To merge the data, we use parameterized Cypher queries: for each node and relationship type, we send the filings in batches (in this sample, just one batch) as the `$records` parameter and expand them into rows with `UNWIND`.
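Since `UNWIND` produces one row per record, a CUSIP that appears on several lines of the same filing (13F information tables can list a holding more than once, for example split by investment discretion) is merged repeatedly, and each later `SET` on the `OWNS` relationship overwrites the earlier `value` and `shares` instead of adding to them. If totals per manager, company, and quarter are what you want, one option is to aggregate in Python first. This sketch assumes `value` and `shares` are already numeric; the helper name `aggregate_owns` is hypothetical.

```python
def aggregate_owns(records: list[dict]) -> list[dict]:
    """Collapse rows sharing (managerName, cusip, reportCalendarOrQuarter), summing value and shares.

    Other fields are taken from the first row seen for each key; input dicts are not mutated.
    """
    totals = {}  # dicts preserve insertion order, so output follows first appearance
    for r in records:
        key = (r['managerName'], r['cusip'], r['reportCalendarOrQuarter'])
        if key not in totals:
            totals[key] = {**r, 'value': 0, 'shares': 0}
        totals[key]['value'] += r['value']
        totals[key]['shares'] += r['shares']
    return list(totals.values())
```

Passing the aggregated list as `$records` then yields one row, and one `SET`, per relationship.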

In [None]:
# Merge Company Nodes.
gds.run_cypher('''
UNWIND $records AS record
MERGE (c:Company {cusip: record.cusip})
SET c.companyName = record.companyName
RETURN count(c) AS company_node_merge_count
''', params={'records':filings_list})

In [None]:
# Merge Manager Node. In this case we just have one
gds.run_cypher('''
MERGE (m:Manager {managerName: $name})
RETURN count(m) AS manager_node_merge_count
''', params={'name':manager_data['managerName']})

In [None]:
# Merge OWNS Relationship
gds.run_cypher('''
UNWIND $records AS record
MATCH (m:Manager {managerName: record.managerName})
MATCH (c:Company {cusip: record.cusip})
MERGE(m)-[r:OWNS]->(c)
SET r.reportCalendarOrQuarter = record.reportCalendarOrQuarter,
    r.value = record.value,
    r.shares = record.shares
RETURN count(r) AS owns_relationship_merge_count
''', params={'records':filings_list})

You can now check the graph to see the loaded sample data.

![](images/13-sample-graph-snapshot.png)

### Ingest Multiple Form 13 Files
We will build a pipeline using the methods above. Here we take a two-step approach: first parse all the data, then chunk that data and ingest it into Neo4j.

For the purposes of this module we will use just a few Form 13 files. Below is a subset:

In [None]:
sample_file_names = [
    'form13/raw/raw_2022-01-06_archives_edgar_data_1495703_0001495703-22-000002.txt',
    'form13/raw/raw_2022-01-03_archives_edgar_data_1875995_0001875995-22-000004.txt',
    'form13/raw/raw_2022-01-03_archives_edgar_data_1844571_0001844571-22-000001.txt'
]

In [None]:
#helper function for getting filing info
def get_manager_and_filing_info(raw_txt):
    contents = raw_txt.split('<XML>')
    manager_info = contents[1].split('</XML>')[0].strip()
    filing_info = contents[2].split('</XML>')[0].strip()
    
    return manager_info, filing_info

In [None]:
%%time

print(f'=== Parsing {len(sample_file_names)} Form 13 Files ===')

filings_list = []
manager_list = []

for file_name in sample_file_names:

    print(f'--- parsing {file_name} ---')
    try:
        #Get raw form13 file
        print('getting file text from gcloud....')
        blob = bucket.blob(file_name)
        raw_text = blob.download_as_bytes().decode()

        #Get raw manager and filing info from file
        print('getting file contents...')
        manager_info, filing_info = get_manager_and_filing_info(raw_text)

        #Parse manager info into dict using LLM
        print('Parsing submission and manager info...')
        mng_prompt = Template(mgr_info_tpl).substitute(ctext=manager_info)
        mng_response = extract_entities_relationships(mng_prompt)
        # Take the fenced JSON and drop an optional "json" language tag
        manager_data = json.loads(mng_response.split('```')[1].strip().removeprefix('json'))
        manager_list.append({'managerName': manager_data['managerName']})

        #Parse filing info into list of dicts using LLM
        print('Parsing filing info...')
        tmp_filing_list = []
        for filing_info_chunk in split_filing_info(filing_info):
            filing_prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunk)
            filing_response = extract_entities_relationships(filing_prompt)
            tmp_filing_list.extend(json.loads(filing_response.split('```')[1].strip().removeprefix('json')))
        for item in tmp_filing_list: #Add information from manager_info to enable OWNS relationship loading
            item['managerName'] = manager_data['managerName']
            item['reportCalendarOrQuarter'] = manager_data['reportCalendarOrQuarter']
        filings_list.extend(tmp_filing_list)
    except Exception:
        # Report which file failed, then re-raise with the original traceback
        print(f'failed while parsing {file_name}')
        raise


Now we can merge the manager nodes:

In [None]:
# Merge Manager Nodes.
gds.run_cypher('''
UNWIND $records AS record
MERGE (m:Manager {managerName: record.managerName})
RETURN count(m) AS manager_node_merge_count
''', params={'records':manager_list})

For the filings, let's check the length of the list:

In [None]:
len(filings_list)

While we should not need chunking for this example, below is an example of how to split the parameterized loading queries into batches in case you need to scale up.
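An alternative batching helper is a generator that yields one batch at a time with `itertools.islice`, so it also works on iterators (for example rows streamed from a parser) and never builds the full list of batches. This is a sketch, not the notebook's `chunks` function; the name `batched` is our own (Python 3.12 ships `itertools.batched`, which yields tuples rather than lists).

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def batched(items: Iterable[T], n: int = 10_000) -> Iterator[list[T]]:
    """Yield successive lists of at most n items, consuming the input lazily."""
    if n < 1:
        raise ValueError('n must be >= 1')
    it = iter(items)
    while batch := list(islice(it, n)):  # stops when the iterator is exhausted
        yield batch
```

Each yielded list can be passed directly as the `records` parameter of a `gds.run_cypher` call.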

In [None]:
# as the dataset gets bigger we will want to chunk up the filings we send to Neo4j
def chunks(xs, n=10_000):
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]

In [None]:
# Merge Company Nodes
for d in chunks(filings_list):
    res = gds.run_cypher('''
    UNWIND $records AS record
    MERGE (c:Company {cusip: record.cusip})
    SET c.companyName = record.companyName
    RETURN count(c) AS company_node_merge_count
    ''', params={'records':d})
    print(res)

In [None]:
# Merge OWNS Relationships
for d in chunks(filings_list):
    res = gds.run_cypher('''
    UNWIND $records AS record
    MATCH (m:Manager {managerName: record.managerName})
    MATCH (c:Company {cusip: record.cusip})
    MERGE(m)-[r:OWNS]->(c)
    SET r.reportCalendarOrQuarter = record.reportCalendarOrQuarter,
        r.value = record.value,
        r.shares = record.shares
    RETURN count(r) AS owns_relationship_merge_count
    ''', params={'records':d})
    print(res)

And that is it!  You can recheck the graph to see updated nodes and relationships.

This type of workflow can be applied to other unstructured data: parse entities with language models and load them into a Neo4j knowledge graph.

Before moving on to the next lab, let’s clean up the data and constraints we created.


## Clean Up

In [None]:
# Remove Data
gds.run_cypher('MATCH (n) DETACH DELETE n')

In [None]:
# Drop Constraints
gds.run_cypher('DROP CONSTRAINT unique_manager IF EXISTS')
gds.run_cypher('DROP CONSTRAINT unique_company_id IF EXISTS')