In [1]:
%pip install -q langchain requests ollama langchain-ollama langchain-community aim streamlit

Note: you may need to restart the kernel to use updated packages.


# 1 Text Chunking

For the sake of model max token limits, we need to process the text chapter by chapter.

Since we cannot fit the whole document into the model, we need to split the text into smaller chunks. We can use the `CharacterTextSplitter` to split the text into smaller chunks.

`chunk_overlap` sets the overlapping tokens between the chunks, which serves as a sliding window that moves over the text, avoiding splitting the text in the middle of a sentence.

In [13]:
# Text chunking

from langchain.text_splitter import CharacterTextSplitter
from langchain.schema.runnable import RunnableLambda

book_file_paths = ['./data/a-xmax-carol-stave'+ str(i+1) +'.txt' for i in range(5)]

texts = []
for book_file_path in book_file_paths:
    with open(book_file_path, 'r') as file:
        texts.append(file.read())


# Initialize the splitter
splitter = CharacterTextSplitter(
    separator = "\n\n",
    chunk_size=2048,
    chunk_overlap = 256,
)

text = texts[0]

# Create RunnableLambda for the splitter
splitter_lambda = lambda text: splitter.create_documents([text])
splitter_runnable = RunnableLambda(splitter_lambda)

chunks = splitter.create_documents([text])

chunks_with_id = [(i, chunk) for i, chunk in enumerate(chunks)]

# Iterative Entity Extraction

Entity extraction is a two-step process:
- Chunk-level Entity Extraction (CLEE)
- Entity Aggregation (EA)

For the sake of model max token limits, we need to process the text chapter by chapter.

The process is shown below, and is wrapped in another python module `extraction.py`.

First, let's get some ideas from the model on the possible entity types.

In [16]:
from langchain.schema.output_parser import StrOutputParser
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate

llama3_1_entity_type = ChatOllama(
    model="llama3.1:8b",
    temperature=0.8,
    top_k=60,
    top_p=0.9,
    num_predict=2048,
    base_url="http://127.0.0.1:11434",
)

entity_type_prompt_template = ChatPromptTemplate.from_messages([
        ("system", "You are to assist the user to construct a knowledge graph for a novel. "),
        ("human", """
The user is working on extracting entities from a book to build a knowledge graph. You goal is to provide a list of possible entity types that are commonly present in knowledge graphs for novels.
Please reply in the following format:
["entity_type1", "entity_type2", "entity_type3", ...]
"""),
    ])

entity_type_chain = entity_type_prompt_template | llama3_1_entity_type | StrOutputParser()

entity_type_chain.invoke({})

'Here\'s a list of common entity types found in knowledge graphs for novels:\n\n["Person", "Location", "Organization", "Event", "TimePeriod", "Concept", "Attribute", "Relationship"]\n\nMore specifically, these might include:\n\n* Characters (e.g. protagonist, antagonist, supporting characters)\n\t+ Person\n* Places and settings (e.g. cities, countries, fictional worlds)\n\t+ Location\n* Institutions and groups (e.g. governments, companies, gangs)\n\t+ Organization\n* Significant events (e.g. battles, discoveries, birthdays)\n\t+ Event\n* Historical periods or eras (e.g. ancient civilizations, historical events)\n\t+ TimePeriod\n* Ideas, themes, and emotions (e.g. love, friendship, redemption)\n\t+ Concept\n* Physical and mental characteristics (e.g. appearance, skills, personality traits)\n\t+ Attribute\n* Relationships between characters (e.g. friendship, romance, family ties)\n\t+ Relationship\n\nLet me know if you\'d like to proceed with extracting entities from the novel!'

## Chunk-level Entity Extraction

Next, we build a simple pipeline for Chunk-level Entity Extraction (CLEE) using the Ollama model.

In [17]:
# use ollama llama3.1:8b for entity extraction

import yaml
from langchain.schema.output_parser import StrOutputParser
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate

# read prompt files from yaml
with open('./prompts/entity.yaml', 'r') as file:
    entity_yaml = yaml.safe_load(file)

# Initialize the ollama model    
llama3_1_params = entity_yaml['chat_model_params']['llama3.1']

llama3_1 = ChatOllama(
    model="llama3.1:8b",
    temperature=0.2,
    top_k=10,
    top_p=0.6,
    num_predict=2048,
    base_url="http://127.0.0.1:11434",
)

# chunk-level entity extraction (CLEE)
clee_template = ChatPromptTemplate.from_messages([
        ("system", entity_yaml['system']),
        ("human", entity_yaml['clee']),
    ])

clee_chain = clee_template | llama3_1 | StrOutputParser()

# test CLEE
result = clee_chain.invoke({
    "text": chunks[2].page_content,
    "entity_type": entity_yaml["entity_types"],
})

print(result)

[("Scrooge", "Person"), ("Christmas Eve", "Event"), ("City", "Location"), ("Marley", is not mentioned, but Jacob Marley's ghost will be important later in the novel), ("Jacob Marley", "Person"), ("nephew", "Person")]


Next, we define an iterative function that runs the CLEE on each chunk and caches the results.

In [5]:
# Define Iterative RunnableLambda of Chunk_Level Entity Extraction

from langchain.schema.runnable import RunnableLambda
from langchain_community.tools.file_management.write import WriteFileTool
from tqdm import tqdm

def iterate_over_chunks_and_cache(chunks): 
    
    chunk_level_entities = []
    
    with tqdm(total=len(chunks), desc="Processing the chunks...") as pbar:
        for i, chunk in enumerate(chunks):      
            pbar.set_postfix({'Current chunk': chunk.page_content[:10]+'...', 'Iteration': i + 1})
            chunk_level_entities.append(clee_chain.invoke({
                "text": chunk.page_content,
                "entity_type": entity_yaml["entity_types"],
            }))
            pbar.update(1)
    
    
    # cache the chunk-level entities
    WriteFileTool().invoke({
        "file_path": "./output/chunk-level-entities.txt",
        "text": str(chunk_level_entities),
    })
    
    return {
        "entities": chunk_level_entities,
        "entity_type": entity_yaml["entity_types"],
        }

clee_iter = RunnableLambda(iterate_over_chunks_and_cache)

## Stave-level Entity Aggregation

Next, we define a template for Entity Aggregation (EA) and run it on the cached chunk-level entities for each stave.

The following is a test run on the first stave.

In [23]:
# Entity-aggregation

# params: entity-type, entities
ea_template = ChatPromptTemplate.from_messages([
        ("system", entity_yaml['system']),
        ("human", entity_yaml['ea']),
    ])

ea_chain = splitter_runnable | clee_iter | ea_template | llama3_1 | StrOutputParser()

global_entities = ea_chain.invoke(text)

WriteFileTool().invoke({
    "file_path": "./output/example-stave-aggregation.txt",
    "text": str(global_entities),
})

Processing the chunks...:   0%|          | 0/96 [00:00<?, ?it/s, Current chunk=CONTENTS

Processing the chunks...:   1%|          | 1/96 [00:03<05:11,  3.28s/it, Current chunk=CONTENTS

Processing the chunks...:  27%|██▋       | 26/96 [01:40<03:51,  3.31s/it, Current chunk="I am!"a s..., Iteration=26]

Processing the chunks...:  28%|██▊       | 27/96 [01:43<03:50,  3.34s/it, Current chunk="I am!"

Processing the chunks...: 100%|██████████| 96/96 [06:33<00:00,  4.10s/it, Current chunk=*     *   ..., Iteration=96]


'File written successfully to ./output/global-entities.txt.'

## The Result for All Staves

The above process only for the first stave. The logic can be applied to the rest of the staves and is wrapper in another python module `extraction_util.py`.
We can now run the process for all the staves.

In [6]:
# read prompt files from yaml
with open('./prompts/entity.yaml', 'r') as file:
    entity_yaml = yaml.safe_load(file)
    
from utils.extraction_util import extract_entities

stave_entities = []
for i, text in enumerate(texts):
    stave_entities.append(extract_entities(text=text, entity_yaml=entity_yaml, stave_num=i+1))

Processing the chunks of stave1...:   0%|          | 0/22 [00:00<?, ?it/s, Current chunk=STAVE ONE
Processing the chunks of stave1...:   5%|▍         | 1/22 [00:04<01:43,  4.95s/it, Current chunk=STAVE ONE
Processing the chunks of stave1...: 100%|██████████| 22/22 [01:31<00:00,  4.15s/it, Current chunk=Whether th..., Iteration=22]
Processing the chunks of stave2...:   0%|          | 0/21 [00:00<?, ?it/s, Current chunk=STAVE TWO
Processing the chunks of stave2...:   5%|▍         | 1/21 [00:05<01:44,  5.21s/it, Current chunk=STAVE TWO
Processing the chunks of stave2...:  14%|█▍        | 3/21 [00:11<01:07,  3.77s/it, Current chunk="I am!"a s..., Iteration=3]

Processing the chunks of stave2...:  19%|█▉        | 4/21 [00:15<01:00,  3.58s/it, Current chunk="I am!"

Processing the chunks of stave2...: 100%|██████████| 21/21 [01:24<00:00,  4.02s/it, Current chunk="Belle," s..., Iteration=21]
Processing the chunks of stave3...: 100%|██████████| 29/29 [01:59<00:00,  4.12s/it, Current chunk="Spi

## Aggregate over the Stave-level Entities

Next, we define a template for Stave-level Entity Aggregation and run it on the cached chunk-level entities for each stave.

We also run by chunks iteratively, provide the model with the previously extracted entities to maintain the context.

In [7]:
str(stave_entities).__len__()

6411

In [8]:
llama3_1_sle = ChatOllama(
    model="llama3.1:8b",
    temperature=0.2,
    top_k=10,
    top_p=0.6,
    num_predict=4096,
)

The entity list is still too long for LLM inputs. So, let's split the list into smaller chunks, and feed them to the model iteratively, together with the model output context from the previous iteration.

In [9]:
from tqdm import tqdm

with open('./prompts/entity.yaml', 'r') as file:
    entity_yaml = yaml.safe_load(file)
    
stave_aggregation_template = ChatPromptTemplate.from_messages([
        ("system", entity_yaml['system']),
        ("human", entity_yaml['stave-aggregation']),
    ])

stave_aggregation_chain = stave_aggregation_template | llama3_1_sle | StrOutputParser()

previous_entities = ""

for stave in tqdm(stave_entities):

    global_entities = stave_aggregation_chain.invoke({
        "entity_type": entity_yaml["entity_types"],
        "entities": stave,
        "prev_entities": previous_entities,
    })
    
    previous_entities = global_entities
    
print(global_entities)

# cache the global entities
WriteFileTool().invoke({
    "file_path": "./output/global-entities.txt",
    "text": global_entities,
})

# or using pickle
import pickle

with open('output/entities/global-entities.pickle', 'wb') as file:
    pickle.dump(global_entities, file)

100%|██████████| 5/5 [01:25<00:00, 17.16s/it]

[("Jacob Marley", "Person"), ("Scrooge", "Person"), ("Bob Cratchit", "Person"), ("Tiny Tim", "Person"), ("Fred", "Person"), ("Ghost of Jacob Marley", "Object"), ("Ghost of Christmas Present", "Object"), ("wandering Spirits", "Object"), ("saucepan", "Object"), ("door", "Object"), ("corner", "Location"), ("window", "Location"), ("Christmas Time", "Concept"), ("Past", "Concept"), ("Present", "Concept"), ("Future", "Concept"), ("Spirits of all Three", "Concept"), ("Christmas Day", "Event"), ("boy", "Person"), ("Turkey", "Object"), ("Marley", "Person"), ("Girl", "Person"), ("nephew", "Person"), ("girl", "Person"), ("Topper", "Person"), ("the plump sister", "Person"), ("Bob", "Person"), ("Cratchit", "Person"), ("City", "Location")]





# 3 Relation Extraction

First, we ask the llm to provide us with a list of possible relation types.

In [67]:
llama3_1_relation_type = ChatOllama(
    model="llama3.1:8b",
    temperature=0.8,
    top_k=60,
    top_p=0.9,
    num_predict=2048,
    base_url="http://127.0.0.1:11434",
)

relation_type_prompt_template = ChatPromptTemplate.from_messages([
        ("system", "You are to assist the user to construct a knowledge graph for a novel. "),
        ("human", """
The user is working on extracting relations from a book to build a knowledge graph. You goal is to provide a list of possible relation types that are commonly present in knowledge graphs for novels.
Please reply in the following format:
["relation_type1", "relation_type2", "relation_type3", ...]
"""),
    ])

relation_type_chain = relation_type_prompt_template | llama3_1_relation_type | StrOutputParser()

relation_type_chain.invoke({})

'Here\'s a list of common relation types found in knowledge graphs for novels:\n\n["Character-Affiliation", "Character-Relationship", "Location-Association", "Event-Occurrence", "Item-Possession", "Personality-Trait", "Age-Information", "Occupation-Profession", "Education-History", "Family-Background", "Nationality-Origin", "Conflict-Involvement", "Plot-Development", "Setting-Timeframe", "Organization-Membership", "Award-Reception"]\n\nThese relation types can be used to connect entities (e.g., characters, locations, events) and convey meaningful information about the novel. Do you want me to help you categorize or prioritize these relation types?'

## 3.1 Chunk-level Relation Extraction

Let's run Relation Extraction (RE) on each chunk with the help of extracted global entities.

Here's the a test run on the first chunk.

In [23]:
with open('./output/entities/global-entities.txt', 'r') as file:
    global_entities = eval(file.read())

with open("./prompts/relation.yaml", 'r') as file:
    relation_yaml = yaml.safe_load(file)

# chunk-level relation extraction
cl_relation_extraction_template = ChatPromptTemplate.from_messages([
        ("system", relation_yaml['system']),
        ("human", relation_yaml['chunk-relation-extraction-basic']),
    ])

cl_relation_extraction_chain = cl_relation_extraction_template | llama3_1 | StrOutputParser()

cl_relation_extraction_result = cl_relation_extraction_chain.invoke({
    "global_entities": global_entities,
    "current_chunk": chunks[3].page_content,
    "relation_types": relation_yaml["relation-types"],
})

print(cl_relation_extraction_result)

[("Scrooge", "nephew", "Character-Relationship", "Character-Affiliation"), 
 ("Scrooge", "nephew", "Conflict-Involvement", "Plot-Development"), 
 ("Scrooge", "nephew", "Age-Information", "Personality-Trait"), 
 ("Scrooge", "Marley", "Occupation-Profession", "Character-Affiliation"), 
 ("Scrooge", "Bob Cratchit", "Occupation-Profession", "Character-Affiliation")]


## Iterative Chunk-Level Relation Extraction

Now, again, run the relation extraction iteratively for all the chunks.

In [25]:
whole_book = "./data/a-xmas-carol-body.txt"

with open(whole_book, 'r') as file:
    whole_book_text = file.read()
    
# Initialize the splitter
splitter = CharacterTextSplitter(
    separator = "\n\n",
    chunk_size=2048,
    chunk_overlap = 256,
)

chunks_whole_book = splitter.create_documents([whole_book_text])


# chunk-level relation extraction
cl_relation_extraction_template = ChatPromptTemplate.from_messages([
        ("system", relation_yaml['system']),
        ("human", relation_yaml['chunk-relation-extraction-basic']),
    ])

cl_relation_extraction_chain = cl_relation_extraction_template | llama3_1 | StrOutputParser()

def iter_extract_relations(chunks, chain=cl_relation_extraction_chain):

    chunk_level_relations = []
    
    # Iterate over the chunks and extract relations
    with tqdm(total=len(chunks), desc="Processing the chunks ...") as pbar:
        
        for i, chunk in enumerate(chunks):
            
            pbar.set_postfix({'Current chunk': chunk.page_content[:10] + '...', 'Iteration': i + 1})
            
            chunk_level_relations.append(chain.invoke({
                "global_entities": global_entities,
                "current_chunk": chunk.page_content,
                "relation_types": relation_yaml["relation-types"],
            }))
            pbar.update(1)
            
    # Cache the chunk-level relations for the stave
    WriteFileTool().invoke({
        "file_path": "./output/relations/chunk-level-relations.txt",
        "text": str(chunk_level_relations),
    })
    
    return chunk_level_relations

chunk_level_relations = iter_extract_relations(chunks_whole_book)


Processing the chunks ...:   0%|          | 0/96 [00:00<?, ?it/s, Current chunk=CONTENTS

Processing the chunks ...:   1%|          | 1/96 [00:13<21:38, 13.67s/it, Current chunk=CONTENTS

Processing the chunks ...:  27%|██▋       | 26/96 [04:00<09:29,  8.14s/it, Current chunk="I am!"a s..., Iteration=26]

Processing the chunks ...:  28%|██▊       | 27/96 [04:12<10:25,  9.06s/it, Current chunk="I am!"

Processing the chunks ...: 100%|██████████| 96/96 [14:06<00:00,  8.82s/it, Current chunk=*     *   ..., Iteration=96]


In [125]:
with open('./output/relations/chunk-level-relations.txt', 'r') as file:
    chunk_level_relations = eval(file.read())

chunk_level_relations

['[("Scrooge", "Marley", "Character-Relationship", "Character-Affiliation"), \n ("Scrooge", "Marley\'s Ghost", "Character-Relationship", "Character-Affiliation"), \n ("Scrooge", "Bob Cratchit", "Character-Relationship", "Occupation-Profession"), \n ("Scrooge", "Tiny Tim", "Character-Relationship", "Item-Possession"), \n ("Jacob Marley", "Marley\'s Ghost", "Character-Relationship", "Character-Affiliation"), \n ("Christmas Time", "Past", "Location-Association", "Setting-Timeframe"), \n ("Christmas Time", "Present", "Location-Association", "Setting-Timeframe"), \n ("Christmas Time", "Future", "Location-Association", "Setting-Timeframe")]',
 '[("Scrooge", "Marley", "business partner", "Character-Relationship"),\n ("Scrooge", "Marley", "executor", "Character-Relationship"),\n ("Scrooge", "Marley", "administrator", "Character-Relationship"),\n ("Scrooge", "Marley", "assign", "Character-Relationship"),\n ("Scrooge", "Marley", "residuary legatee", "Character-Relationship"),\n ("Scrooge", "Marl

In [126]:
# output processing and cache the results
chunk_level_relations_eval = []

for relations in chunk_level_relations:
    try:
        chunk_level_relations_eval += eval(relations)
    except Exception as e:
        print(e)
    
import pickle

with open('output/relations/chunk-level-relations.pickle', 'wb') as file:
    pickle.dump(chunk_level_relations_eval, file)

unterminated string literal (detected at line 84) (<string>, line 84)


## 3.2 Entity Type Resolution

Finally, we need to resolve the entity types for the entities in the extracted relations.

The entities in the relations extracted from the last step does not really match the entities we extracted separately before.

The result will be a hashmap (python dictionary) with the entity as the key and the entity type as the value.

In [127]:
import pickle

with open('./output/relations/chunk-level-relations.pickle', 'rb') as file:
    cl_relations = pickle.load(file)
    
print(f'Number of tokens: {str(cl_relations).__len__()}')
print(f'Number of relations: {cl_relations.__len__()}')

Number of tokens: 41962
Number of relations: 586


In [128]:
entities_in_extracted_relations = []
for cl_relation in cl_relations:
    entities_in_extracted_relations.append(cl_relation[0])
    entities_in_extracted_relations.append(cl_relation[1])
    
entities_in_extracted_relations = list(set(entities_in_extracted_relations))

print(entities_in_extracted_relations.__len__())

190


Next, build another pipeline for entity type resolution.

In [129]:
# resolve entity types
from tqdm import tqdm
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate

relation_yaml = yaml.safe_load(open('./prompts/relation.yaml', 'r'))

assign_entity_types_template = ChatPromptTemplate.from_messages([
        ("system", relation_yaml['system']),
        ("human", relation_yaml['assign-entity-types']),
    ])

llama3_1_assign_entity_type = ChatOllama(
    model="llama3.1:8b",
    temperature=0.3,
    top_k=20,
    top_p=0.6,
    num_predict=4096,
    base_url="http://127.0.0.1:11434",
)

assign_entity_types_chain = assign_entity_types_template | llama3_1_entity_type | StrOutputParser()

entity_types_for_relations = []

for entity in tqdm(entities_in_extracted_relations):
    entity_types_for_relations.append(assign_entity_types_chain.invoke({
        "entity": entity,
        "entity_types": entity_yaml["entity_types"],
    }))

100%|██████████| 190/190 [01:05<00:00,  2.92it/s]


In [130]:
# error handling: if the entity type does not belong to the list, then assign the entity type as 'Miscellaneous'

final_entity_types_list = []
valid_answer = 0

for item in entity_types_for_relations:
    try:
        if item in entity_yaml["entity_types"]:
            final_entity_types_list.append(item)
            valid_answer += 1
        else:
            final_entity_types_list.append("Miscellaneous")
            
    except Exception as e:
        print(e)
        final_entity_types_list.append("Miscellaneous")
        
print(f'{valid_answer} / {entity_types_for_relations.__len__()} valid answers')

entity_types_mapping = dict(zip(entities_in_extracted_relations, final_entity_types_list))

# dump to pickle

with open('./output/relations/relations-entity-type.pickle', 'wb') as file:
    pickle.dump(entity_types_mapping, file)

185 / 190 valid answers
