# Knowledge Graph Creation with Generative AI

In this notebook, let's explore how to leverage Google Generative AI to build and consume a knowledge graph in Neo4j.

This notebook parses Form-13 data From SEC EAGDAR. While partially structured with XML, the formatting of these forms isn't always consistent and contains some non-standard practices.  Insead of spending our time writing a bespoke parser to extract data from these files and load into Neo4j, we will prompt a Large Language Model (LLM) to do this for us automatically.  We will then also use the LLM to generate Cypher statments to load the extracted data into a Neo4j graph.

## Setup
First off, check that the Python environment you installed in the readme is running this notebook. Make sure you select the py38 kernel in the top right of this notebook. You should see a 3.8 version when you run this command.

In [1]:
import sys
sys.version

'3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0]'

Next we install and import some libraries 

In [2]:
%%capture
%pip install --user "google-cloud-aiplatform>=1.25.0" --upgrade
%pip install --user "google-cloud-aiplatform[pipelines]>=1.25.0"
%pip install --user graphdatascience 

Now restart the kernel. That will allow the Python evironment to import the new packages.

In [3]:
import json
import numpy as np
import os
import re
from string import Template

# Vertexai and gcloud
import vertexai
from vertexai.preview.language_models import TextGenerationModel
from google.cloud import storage

# Neo4j
from graphdatascience import GraphDataScience

## Prompt Definition

In the upcoming sections, we will extract knowledge adhering to the following schema. This is a very Simplified schema to denote investment management entities and companies they own through common stock. Normally, you will have Domain Experts who come up with a richer data model, and you can extend the below to work on more data/forms to fill such a model.

![](images/12-graph-data-model.png)

To achieve our Extraction goal as per the schema, We will use a series of prompts, each focused on only one task - to extract a specific entity. By this way, you can go for more granular extraction. The prompts I used here can be improved and in production scenario, you should consider running QA on the prompt pipelines to ensure that the extracted information is correct.

Let's go in this order to gather the data in accordance to out data model:

1. Extract Manager Information
2. Extract Filing Information

In [4]:
mgr_info_tpl = """From the text below, extract the following as json. Do not miss any of these information.
* The tags mentioned below may or may not namespaced. So extract accordingly. Eg: <ns1:tag> is equal to <tag>
* "name" - The name from the <name> tag under <filingManager> tag
* "street1" - The manager's street1 address from the <com:street1> tag under <address> tag
* "street2" - The manager's street2 address from the <com:street2> tag under <address> tag
* "city" - The manager's city address from the <com:city> tag under <address> tag
* "stateOrCounty" - The manager's stateOrCounty address from the <com:stateOrCountry> tag under <address> tag
* "zipCode" - The manager's zipCode from the <com:zipCode> tag under <address> tag
* "reportCalendarOrQuarter" - The reportCalendarOrQuarter from the <reportCalendarOrQuarter> tag under <address> tag
* Just return me the JSON enclosed by 3 backticks. No other text in the response

Text:
$ctext
"""

In [5]:
filing_info_tpl = """The text below contains a list of investments. Each instance of <infoTable> tag represents a unique investment. 
For each investment, please extract the below variables into json then combine into a list enclosed by 3 back ticks. Please use the quated names below while doing this
* "cusip" - The cusip from the <cusip> tag under <infoTable> tag
* "name" - The name under the <nameOfIssuer> tag.
* "value" - The value from the <value> tag under <infoTable> tag. Return as a number. 
* "shares" - The sshPrnamt from the <sshPrnamt> tag under <infoTable> tag. Return as a number. 
* "sshPrnamtType" - The sshPrnamtType from the <sshPrnamtType> tag under <infoTable> tag
* "investmentDiscretion" - The investmentDiscretion from the <investmentDiscretion> tag under <infoTable> tag
* "votingSole" - The votingSole from the <votingSole> tag under <infoTable> tag
* "votingShared" - The votingShared from the <votingShared> tag under <infoTable> tag
* "votingNone" - The votingNone from the <votingNone> tag under <infoTable> tag

Text:
$ctext
"""

## Functions for Using LLMs

Lets creater some  helper function to talk to the LLM with our prompt and text input. We will use the text-bison base model. In your use case, you might need to tune it. VertexAI provides an elegant way to finetune it. The weights will be staying within your tenant and the base model is frozen.

In [6]:
# Note, you will need to set your project_id
project_id = 'neo4jbusinessdev'
location = 'us-central1'

In [7]:
# wrapper for calling language model
def run_text_model(
    project_id: str,
    model_name: str,
    temperature: float,
    max_decode_steps: int,
    top_p: float,
    top_k: int,
    prompt: str,
    location: str = location,
    tuned_model_name: str = None,
    ) :
    """Text Completion Use a Large Language Model."""
    vertexai.init(project=project_id, location=location)
    if tuned_model_name is None:
        model = TextGenerationModel.from_pretrained(model_name)
    else:
        model = model.get_tuned_model(tuned_model_name)
    response = model.predict(
        prompt,
        temperature=temperature,
        max_output_tokens=max_decode_steps,
        top_k=top_k,
        top_p=top_p,)
    return response.text

In [8]:
# wrapper for entity extraction / parsing
def extract_entities_relationships(prompt, tuned_model_name=None):
    try:
        res = run_text_model(project_id, "text-bison@001", 0, 1024, 0.8, 1, prompt, location, tuned_model_name)
        return res
    except Exception as e:
        print(e)

In [9]:
# splitting function for chunking up filing information to avoid hitting LLM token limits
def split_filing_info(s, chunk_size=5):
    pattern = '(</(\w+:)?infoTable>)'
    splitter = re.findall(pattern, s)[0][0]
    _parts = s.split(splitter)
    if len(_parts) > chunk_size:
        chunks_of_list = np.array_split(_parts, len(_parts)/chunk_size) # max 5 filings per part
        chunks_of_str = map(lambda x: splitter.join(x), chunks_of_list)
        return list(chunks_of_str)
    else:
        return [s]

## Test Example for Parsing
Lets start with one form13 file to see how we can parse it with Generative AI.

In [10]:
storage_client = storage.Client()
bucket = storage_client.bucket('neo4j-datasets')
blob = bucket.blob('form13/raw/raw_2023-05-15_archives_edgar_data_1027451_0000919574-23-003245.txt')

inp_text = blob.download_as_string().decode()

In [11]:
print(inp_text[:1500])

<SEC-DOCUMENT>0000919574-23-003245.txt : 20230515
<SEC-HEADER>0000919574-23-003245.hdr.sgml : 20230515
<ACCEPTANCE-DATETIME>20230515103943
ACCESSION NUMBER:		0000919574-23-003245
CONFORMED SUBMISSION TYPE:	13F-HR
PUBLIC DOCUMENT COUNT:		2
CONFORMED PERIOD OF REPORT:	20230331
FILED AS OF DATE:		20230515
DATE AS OF CHANGE:		20230515
EFFECTIVENESS DATE:		20230515

FILER:

	COMPANY DATA:	
		COMPANY CONFORMED NAME:			TIGER MANAGEMENT L.L.C.
		CENTRAL INDEX KEY:			0001027451
		IRS NUMBER:				000000000
		STATE OF INCORPORATION:			DE

	FILING VALUES:
		FORM TYPE:		13F-HR
		SEC ACT:		1934 Act
		SEC FILE NUMBER:	028-05892
		FILM NUMBER:		23919492

	BUSINESS ADDRESS:	
		STREET 1:		101 PARK AVENUE
		CITY:			NEW YORK
		STATE:			NY
		ZIP:			10178
		BUSINESS PHONE:		212-984-2500

	MAIL ADDRESS:	
		STREET 1:		101 PARK AVENUE
		CITY:			NEW YORK
		STATE:			NY
		ZIP:			10178

	FORMER COMPANY:	
		FORMER CONFORMED NAME:	TIGER MANAGEMENT LLC/NY
		DATE OF NAME CHANGE:	20010606
</SEC-HEADER>
<DOCUMENT>
<TYPE>

We can split data into the manager and filing info pices using the `<XML>` tags

In [12]:
contents = inp_text.split('<XML>')
manager_info = contents[1].split('</XML>')[0].strip()
filing_info = contents[2].split('</XML>')[0].strip()

### Parsing Manager Information

In [13]:
#initialize / authenticate vertex AI SDK to begin
vertexai.init(project=project_id, location=location)

In [14]:
# Create prompt
prompt = Template(mgr_info_tpl).substitute(ctext=manager_info)
print(prompt)

From the text below, extract the following as json. Do not miss any of these information.
* The tags mentioned below may or may not namespaced. So extract accordingly. Eg: <ns1:tag> is equal to <tag>
* "name" - The name from the <name> tag under <filingManager> tag
* "street1" - The manager's street1 address from the <com:street1> tag under <address> tag
* "street2" - The manager's street2 address from the <com:street2> tag under <address> tag
* "city" - The manager's city address from the <com:city> tag under <address> tag
* "stateOrCounty" - The manager's stateOrCounty address from the <com:stateOrCountry> tag under <address> tag
* "zipCode" - The manager's zipCode from the <com:zipCode> tag under <address> tag
* "reportCalendarOrQuarter" - The reportCalendarOrQuarter from the <reportCalendarOrQuarter> tag under <address> tag
* Just return me the JSON enclosed by 3 backticks. No other text in the response

Text:
<?xml version="1.0" encoding="UTF-8"?>
<edgarSubmission xsi:schemaLocati

In [15]:
# Use LLM to parse out manager info
manager_data = json.loads(extract_entities_relationships(prompt).split('```')[1].strip('json'))
manager_data

{'name': 'TIGER MANAGEMENT L.L.C.',
 'street1': '101 PARK AVENUE',
 'street2': '',
 'city': 'NEW YORK',
 'stateOrCounty': 'NY',
 'zipCode': '10178',
 'reportCalendarOrQuarter': '03-31-2023'}

### Parse Filing Information
We will parse filing info in a similar manner to manager information. Because the filings include a list of many entries however, we will want to split the input into chunks so as not to exceed input or output token limits. 

In [16]:
filing_info_chunks = split_filing_info(filing_info)
len(filing_info_chunks)

6

In [17]:
prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunks[0])
response = extract_entities_relationships(prompt)
print(response)

  ```
[
  {
    "cusip": "02079K305",
    "name": "ALPHABET INC",
    "value": 1488692,
    "shares": 14352,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 14352,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusip": "023135106",
    "name": "AMAZON COM INC",
    "value": 582556,
    "shares": 5640,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 5640,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusip": "049468101",
    "name": "ATLASSIAN CORPORATION",
    "value": 205404,
    "shares": 1200,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 1200,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusip": "053332102",
    "name": "AUTOZONE INC",
    "value": 9218063,
    "shares": 3750,
    "sshPrnamtType": "SH",
    "investmentDiscretion": "SOLE",
    "votingSole": 3750,
    "votingShared": 0,
    "votingNone": 0
  },
  {
    "cusip": "19260Q107",
    "name":

## Data Ingestion

### Test Example

Lets walk through the steps to do this with just the 1 form above first, then we can move on to parsing and ingesting all the form13s

To start we can run the LLM parsing over all the filing info from the form then combine the resulting json into a list condusive for Neo4j loading.

In [18]:
filings_list = []
for filing_info_chunk in filing_info_chunks:
    prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunk)
    response = extract_entities_relationships(prompt)
    filings_list.extend(json.loads(response.replace('```', '')))

for item in filings_list:
    item['managerName'] = manager_data['name']
    item['reportCalendarOrQuarter'] = manager_data['reportCalendarOrQuarter']
filings_list[:5]

[{'cusip': '02079K305',
  'name': 'ALPHABET INC',
  'value': 1488692,
  'shares': 14352,
  'sshPrnamtType': 'SH',
  'investmentDiscretion': 'SOLE',
  'votingSole': 14352,
  'votingShared': 0,
  'votingNone': 0,
  'managerName': 'TIGER MANAGEMENT L.L.C.',
  'reportCalendarOrQuarter': '03-31-2023'},
 {'cusip': '023135106',
  'name': 'AMAZON COM INC',
  'value': 582556,
  'shares': 5640,
  'sshPrnamtType': 'SH',
  'investmentDiscretion': 'SOLE',
  'votingSole': 5640,
  'votingShared': 0,
  'votingNone': 0,
  'managerName': 'TIGER MANAGEMENT L.L.C.',
  'reportCalendarOrQuarter': '03-31-2023'},
 {'cusip': '049468101',
  'name': 'ATLASSIAN CORPORATION',
  'value': 205404,
  'shares': 1200,
  'sshPrnamtType': 'SH',
  'investmentDiscretion': 'SOLE',
  'votingSole': 1200,
  'votingShared': 0,
  'votingNone': 0,
  'managerName': 'TIGER MANAGEMENT L.L.C.',
  'reportCalendarOrQuarter': '03-31-2023'},
 {'cusip': '053332102',
  'name': 'AUTOZONE INC',
  'value': 9218063,
  'shares': 3750,
  'sshPrna

In [19]:
len(filings_list)

33

#### Establish Neo4j Connection
We will assume you are using AuraDS here

In [22]:
# username is neo4j by default
NEO4J_USERNAME = 'neo4j'
# You will need to change these to match
NEO4J_URI = '<neo4j+s://xxxxx.databases.neo4j.io>'
NEO4J_PASSWORD = '<password>'

In [23]:
gds = GraphDataScience(
    NEO4J_URI,
    auth=(NEO4J_USERNAME, NEO4J_PASSWORD),
    aura_ds=True
)
gds.set_database('neo4j')

Before loading we should create uniquenss constraints for nodes.  This acts as a unique id and an index and is necessary for fast, efficient, queries.  In general, if you notice ingestion is super slow (and getting slower) with Neo4j, double check that you created indexes.  For this small sample it won't matter, but it will certainly imact as we ingest more data

In [24]:
gds.run_cypher('CREATE CONSTRAINT unique_manager IF NOT EXISTS FOR (n:Manager) REQUIRE (n.name) IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_company_id IF NOT EXISTS FOR (n:Company) REQUIRE (n.cusip) IS UNIQUE')

To Merge in the data we can use paramterized cypher queries.  Basically we are going to send filings in batches (in this sample case just one batch) for each node and relationship type and insert them as parameters in the query.

In [25]:
# Merge Company Nodes.
gds.run_cypher('''
UNWIND $records AS record
MERGE (c:Company {cusip: record.cusip})
SET c.name = record.name
RETURN count(c) AS company_node_merge_count
''', params={'records':filings_list})

Unnamed: 0,company_node_merge_count
0,33


In [26]:
# Merge Manager Node. In this case we just have one
gds.run_cypher('''
MERGE (m:Manager {name: $name})
''', params={'name':manager_data['name']})

In [27]:
# Merge OWNS Relationship
gds.run_cypher('''
UNWIND $records AS record
MATCH (m:Manager {name: record.managerName})
MATCH (c:Company {cusip: record.cusip})
MERGE(m)-[r:OWNS]->(c)
SET r.reportCalendarOrQuarter = record.reportCalendarOrQuarter,
    r.value = record.value,
    r.shares = record.shares
RETURN count(r) AS owns_relationship_merge_count
''', params={'records':filings_list})

Unnamed: 0,owns_relationship_merge_count
0,33


You can now check the graph to see the loaded sample data

![](images/13-sample-graph-snapshot.png)

### Ingest Multiple Form 13 Files
We will make a pipeline using the methods above.  In this case we will take a two-step approach, first parse all the data, then chunk that data and ingest into Neo4j.

To start, lets take a look at all the form13 samples we have in cloud storage. 

In [28]:
blobs = storage_client.list_blobs('neo4j-datasets', prefix='form13/raw/', delimiter='/')
file_names = [blob.name for blob in blobs if '.txt' in blob.name]
print(f'{len(file_names)} total raw form13 files')

44142 total raw form13 files


For purposes of this module we will just use the first 5 in this list

In [29]:
sample_file_names = file_names[:5]
sample_file_names

['form13/raw/raw_2022-01-03_archives_edgar_data_1026200_0001567619-22-000057.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1315339_0001315339-22-000001.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1384943_0001384943-22-000001.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1452208_0001104659-22-000174.txt',
 'form13/raw/raw_2022-01-03_archives_edgar_data_1624809_0001624809-22-000001.txt']

In [30]:
#helper function for getting filing info
def get_manager_and_filing_info(raw_txt):
    contents = raw_txt.split('<XML>')
    manager_info = contents[1].split('</XML>')[0].strip()
    filing_info = contents[2].split('</XML>')[0].strip()
    
    return manager_info, filing_info

In [36]:
%%time 

print(f'=== Parsing {len(sample_file_names)} Form 13 Files ===')

filings_list = []
manager_list = []

for file_name in sample_file_names:
    
    print(f'--- parsing {file_name} ---')
    try:
        #Get raw form13 file
        print('getting file text from gcloud....')
        blob = bucket.blob(file_name)
        raw_text = blob.download_as_string().decode()

        #Get raw manager and filing info from file
        print('getting file contents...')
        manager_info, filing_info = get_manager_and_filing_info(raw_text)

        #Parse manager info into dict using LLM
        print('Parsing submission and manager info...')
        mng_prompt = Template(mgr_info_tpl).substitute(ctext=manager_info)
        mng_response = extract_entities_relationships(mng_prompt)
        manager_data = json.loads(mng_response.replace('```', ''))
        manager_list.append({'name': manager_data['name']})

        #Parse filing info into list of dicts using LLM
        print('Parsing filing info...')
        for filing_info_chunk in filing_info_chunks:
            filing_prompt = Template(filing_info_tpl).substitute(ctext=filing_info_chunk)
            filing_response = extract_entities_relationships(filing_prompt)
            filings_list.extend(json.loads(filing_response.replace('```', '')))
        for item in filings_list: #Add information from manager_info to enable OWNS relationship loading
            item['managerName'] = manager_data['name']
            item['reportCalendarOrQuarter'] = manager_data['reportCalendarOrQuarter']
    except Exception as e:
        raise e


=== Parsing 5 Form 13 Files ===
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1026200_0001567619-22-000057.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1315339_0001315339-22-000001.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1384943_0001384943-22-000001.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1452208_0001104659-22-000174.txt ---
getting file text from gcloud....
getting file contents...
Parsing submission and manager info...
Parsing filing info...
--- parsing form13/raw/raw_2022-01-03_archives_edgar_data_1624809_0001624809-22-000001.txt ---
getting f

In [37]:
# Merge Manager Nodes.
gds.run_cypher('''
UNWIND $records AS record
MERGE (m:Manager {name: record.name})
RETURN count(m) AS manager_node_merge_count
''', params={'records':manager_list})

Unnamed: 0,manager_node_merge_count
0,5


In [38]:
len(filings_list)

165

In [39]:
# at the dataset gets bigger we will want to chunk up the filings we send to Neo4j
def chunks(xs, n=10_000):
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]

In [40]:
# Merge Company Nodes
for d in chunks(filings_list):
    res = gds.run_cypher('''
    UNWIND $records AS record
    MERGE (c:Company {cusip: record.cusip})
    SET c.name = record.name
    RETURN count(c) AS company_node_merge_count
    ''', params={'records':d})
    print(res)

   company_node_merge_count
0                       165


In [41]:
# Merge OWNS Relationships
for d in chunks(filings_list):
    res = gds.run_cypher('''
    UNWIND $records AS record
    MATCH (m:Manager {name: record.managerName})
    MATCH (c:Company {cusip: record.cusip})
    MERGE(m)-[r:OWNS]->(c)
    SET r.reportCalendarOrQuarter = record.reportCalendarOrQuarter,
        r.value = record.value,
        r.shares = record.shares
    RETURN count(r) AS owns_relationship_merge_count
    ''', params={'records':d})
    print(res)

   owns_relationship_merge_count
0                            165
