# WS24 - Intelligente Informationssysteme

## Block 3: Retrieval Augmented Generation

**Part 9: Advanced Information Extraction**

1. Concept Extraction with Prompting
2. Named Entity Recognition with GliNER

## Concept Extraction with Prompting

Extract concepts, entities, and relations from text and build a knowledge graph.

In [1]:
# Load some data: each youtube video transcript is one document and should be handled and chunked with llama_index
import os

from llama_index.core.schema import BaseNode
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter

# Directory that holds the input data, relative to the working directory
base_path = os.path.join(".", "data")

In [None]:
from typing import Tuple, List
import json

def _load_text(file_path: str) -> str:
    """Read a text file and return its whole content as one string.

    The file is decoded as UTF-8 explicitly so the result does not depend
    on the platform's default locale encoding.

    Args:
        file_path: Path of the file to read.

    Returns:
        The complete content of the file.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()

def _load_metadata(file_path: str) -> dict:
    """Read a JSON file and return the decoded object.

    The file is decoded as UTF-8 explicitly so the result does not depend
    on the platform's default locale encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content (expected to be a JSON object, i.e. a dict).

    Raises:
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def chunker(text: str, metadata=None, chunk_size=200, chunk_overlap=20) -> Tuple[Document, List[BaseNode]]:
    """Wrap a text in a Document and split it into nodes.

    Args:
        text: The raw text to chunk.
        metadata: Optional metadata dict attached to the document. A fresh
            empty dict is used when omitted, so no dict is shared between calls.
        chunk_size: Target chunk size handed to SentenceSplitter. According to
            the LlamaIndex documentation this is measured in tokens, not words.
        chunk_overlap: Overlap between consecutive chunks, handed to
            SentenceSplitter (same unit as chunk_size).

    Returns:
        A tuple (document, nodes) with the created Document and the nodes
        produced by the splitter.
    """
    if metadata is None:
        metadata = {}
    document = Document(text=text, metadata=metadata)
    splitter = SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        #paragraph_separator = "\n\n\n" not used
    )
    nodes = splitter.get_nodes_from_documents([document])
    return (document, nodes)

In [None]:
# Load the sample text and split it into nodes
alice_path = f"{base_path}/alice.txt"
text = _load_text(file_path=alice_path)
document, nodes = chunker(text=text, metadata={}, chunk_size=200) # chunk size includes metadata size

# Show how many nodes were produced and what the first two look like
print(f"There are {len(nodes)} nodes")
separator = "=" * 80
first_node = nodes[0]
print(first_node.text, first_node.metadata)
print(separator)
second_node = nodes[1]
print(second_node.text, second_node.metadata)

In [None]:
# Kinds of terms listed in the extraction prompt.
TERMS = [
    "object",
    "entity",
    "location",
    "person",
    "concept",
]

In [None]:
# System prompt for the extraction model. It asks for a JSON list of
# {"node_1", "node_2", "edge"} objects describing related terms of a chunk.
# The adjacent string literals below are concatenated into one prompt string.
SYS_PROMPT = (
    "You are a network graph maker who extracts terms and their relations from a given context. "
    "You are provided with a context chunk (delimited by ```). Your task is to extract the ontology "
    "of terms mentioned in the given context. These terms should represent the key concepts as per the context. \n"
    "Thought 1: While traversing through each sentence, Think about the key terms mentioned in it.\n"
        f"\tTerms may include {', '.join(TERMS)}, etc.\n"
        "\tTerms should be as atomistic as possible\n\n"
    "Thought 2: Think about how these terms can have one on one relation with other terms.\n"
        "\tTerms that are mentioned in the same sentence or the same paragraph are typically related to each other.\n"
        "\tTerms can be related to many other terms\n\n"
    "Thought 3: Find out the relation between each such related pair of terms. \n\n"
    "Format your output as a list of json. \n"
    "Each element of the list contains a pair of terms and the relation between them, like the following: \n"
    # The fence markers need their own lines; otherwise the example reads "```json[" and "]```".
    "```json\n"
    "[\n"
    "   {\n"
    '       "node_1": "A concept from extracted ontology",\n'
    '       "node_2": "A related concept from extracted ontology",\n'
    '       "edge": "relationship between the two concepts, node_1 and node_2 explained in one verb or phrase"\n'
    "   }, {...}\n"
    "]\n"
    "```"
)

In [None]:
def clean_output(text) -> list:
    """Extract the first JSON array embedded in a text (e.g. an LLM reply).

    The text is scanned for "[" characters. At each candidate position a
    JSON value is decoded; anything following that value is ignored. This
    handles replies that start directly with "[", replies with text before
    or after the array, and arrays that contain nested arrays.

    Args:
        text: The raw text that may contain a JSON array.

    Returns:
        The decoded list, or an empty list if no JSON array could be decoded.
        In that case the last decoding error (or a short notice) is printed.
    """
    decoder = json.JSONDecoder()
    last_error = None
    start = text.find("[")
    while start != -1:
        try:
            # raw_decode parses one JSON value at the beginning of the string
            # and tolerates trailing characters after it.
            parsed, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError as e:
            # Not a valid array at this bracket: try the next "[" instead.
            last_error = e
            start = text.find("[", start + 1)
        else:
            return parsed
    print(last_error if last_error is not None else "no JSON array found in text")
    return []

In [None]:
# Quick check: the JSON array is surrounded by extra characters on both sides
text = '0123[{"id":5667, "name":"Klaus"}]9'
result = clean_output(text)
print(f"{type(result)} : {result}")

In [None]:
def _clean_term(value) -> str:
    """Return a stripped, lower-cased string, or "" when value is not a string."""
    return value.strip().lower() if isinstance(value, str) else ""


def extract_nodes_edges(results:list) -> list:
    """Turn decoded relation dicts into normalized (node_1, node_2, edge) triples.

    Args:
        results: Items decoded from the model output. Each usable item is a
            dict with the keys "node_1", "node_2" and optionally "edge".

    Returns:
        A list of (node_1, node_2, edge) tuples with stripped, lower-cased
        strings. Items that are not dicts, or whose node_1/node_2 is missing,
        empty or not a string, are skipped. A missing, empty or non-string
        edge becomes "unknown".
    """
    formatted_results = []
    for entry in results or []:
        if not isinstance(entry, dict):
            # Report malformed items instead of relying on a blanket except.
            print(f"skipping malformed entry: {entry!r}")
            continue
        node_1 = _clean_term(entry.get("node_1"))
        node_2 = _clean_term(entry.get("node_2"))
        edge = _clean_term(entry.get("edge")) or "unknown"
        if node_1 and node_2:
            formatted_results.append((node_1, node_2, edge))
    return formatted_results
        

In [None]:
import ollama

# Model used for the extraction and number of attempts per chunk
# (one initial call plus one retry when nothing usable came back).
OLLAMA_MODEL = 'llama3.2:latest'
MAX_ATTEMPTS = 2

# Ask the model for the relations of every chunk and store them in the
# chunk's metadata under "graph_structure".
for i, node in enumerate(nodes):
    print(i, node.node_id, end="")
    if 'graph_structure' in node.metadata:
        # The chunk already carries a result (for instance when this cell is re-run).
        graph_structure = node.metadata['graph_structure']
    else:
        messages = [{'role': 'system', 'content': SYS_PROMPT},
                    {'role': 'user', 'content': f"context: ```{node.text}``` \n\n output: "}]
        graph_structure = []
        for attempt in range(MAX_ATTEMPTS):
            if attempt > 0:
                print(" retry", end="")
            response = ollama.chat(model=OLLAMA_MODEL, messages=messages)
            results = clean_output(response.message.content)
            graph_structure = extract_nodes_edges(results)
            if graph_structure:
                break
        # Add the key instead of replacing the dict so existing metadata is kept.
        node.metadata['graph_structure'] = graph_structure
    print()  # end the progress line of this chunk
    #print("\n",graph_structure)
    #print("="*80)

In [None]:
###### save the data to disk #####
#_json = {"document": document.to_dict(),
#         "node": [node.to_dict() for node in nodes]}
#with open(f"{base_path}/alice.json", "w") as f:
#    f.write(json.dumps(_json, indent=3))

###### just persist the nodes #######
from llama_index.core.storage.docstore import SimpleDocumentStore

# One path for both writing and reading the node store
nodes_store_path = f"{base_path}/alice_nodes.bin"

docstore = SimpleDocumentStore()
docstore.add_documents(nodes)
docstore.persist(persist_path=nodes_store_path)

###### load the data from disk #####
new_docstore = SimpleDocumentStore.from_persist_path(persist_path=nodes_store_path)
# dictionary of RefDocInfo objects containing node_ids
ref_doc_infos = new_docstore.get_all_ref_doc_info()
# Only the first reference document is used to look up the nodes
ref_doc_info = list(ref_doc_infos.values())[0]
nodes = new_docstore.get_nodes(ref_doc_info.node_ids)

In [None]:
# Build a graph based on networkx
import networkx as nx

M = nx.MultiGraph()  # undirected multigraph: extracted relations plus proximity edges
G = nx.Graph()  # contextual proximity graph
THRESHOLD = 2  # proximity edges are added only for counts above this value

all_nodes = set()
doc_nodes_dict = {}
for node in nodes:
    # All terms of this chunk, in order of appearance and with repetitions
    chunk_terms = []
    doc_nodes_dict[node.node_id] = chunk_terms

    # Count how often each (source, target, edge) triple occurs in this chunk
    all_edges_per_node = {}
    for source, target, edge in node.metadata.get("graph_structure", []):
        all_nodes.update((source, target))
        triple = (source, target, edge)
        all_edges_per_node[triple] = all_edges_per_node.get(triple, 0) + 1
        chunk_terms.extend((source, target))

    # add edges: one per distinct triple, weighted by its count
    for (source, target, edge), weight in all_edges_per_node.items():
        M.add_edge(source, target, relation=edge, title=edge, node_id=node.node_id, weight=weight)

    # add contextual proximity: nodes (source or target) in same text chunk
    # NOTE(review): every pair is counted in both orders, so each unordered pair
    # reaches add_edge twice -- presumably two parallel edges in M; confirm intent.
    contextual_proximity_adjacency = {}
    for source in chunk_terms:
        for target in chunk_terms:
            pair = (source, target)
            contextual_proximity_adjacency[pair] = contextual_proximity_adjacency.get(pair, 0) + 1

    for (source, target), weight in contextual_proximity_adjacency.items():
        if source == target or weight <= THRESHOLD:
            continue
        proximity_attrs = dict(relation="contextual proximity", title="contextual proximity",
                               node_id=node.node_id, weight=weight)
        G.add_edge(source, target, **proximity_attrs)
        M.add_edge(source, target, **proximity_attrs)

print(M)
print(G)
print(f"There are {len(all_nodes)} unique node candidates")

In [None]:
# Use Girvan-Newman Community Detection Algorithm
communities_generator = nx.community.girvan_newman(M)
# Take the first two results of the generator; the second one is used below
top_level_communities = next(communities_generator)
next_level_communities = next(communities_generator)
# Sort the members of every community, then sort the communities themselves
communities = sorted(sorted(members) for members in next_level_communities)
print("Number of Communities = ", len(communities))
print(communities)

In [None]:
# Map a running index to each community
clusters = dict(enumerate(communities))

graph_nodes = M.nodes()

# Store the community index as color/group and the node degree as size
for cluster_index, members in clusters.items():
    for member in members:
        M.nodes[member].update(
            color=cluster_index,
            group=cluster_index,
            size=M.degree(member),
        )


In [None]:
#!pip install pyvis

In [None]:
from pyvis.network import Network

# Options for the interactive HTML view
network_options = {
    "notebook": False,
    # "bgcolor": "#1a1a1a",
    "cdn_resources": "remote",
    "height": "900px",
    "width": "100%",
    "select_menu": True,
    # "font_color": "#cccccc",
    "filter_menu": False,
}
net = Network(**network_options)

# Take nodes, edges and their attributes over from the networkx multigraph
net.from_nx(M)
# net.repulsion(node_distance=150, spring_length=400)
net.force_atlas_2based(central_gravity=0.015, gravity=-31)
# net.barnes_hut(gravity=-18100, central_gravity=5.05, spring_length=380)
net.show_buttons(filter_=["physics"])

# Write the visualization to an HTML file
net.show("./index.html", notebook=False)

In [None]:
# next step: save the graph into Neo4j

## Named Entity Recognition with GliNER

In [None]:
# !pip install gliner

In [None]:
from gliner import GLiNER

In [None]:
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
text = _load_text(file_path=f"{base_path}/alice.txt")

document, nodes = chunker(text=text, metadata={}, chunk_size=200)

# Labels handed to the model for every chunk
labels = ["object", "entity", "location", "person", "concept", "animal"]

# Maps each lower-cased entity name to the ids of the nodes it was found in
entities_in_nodes = {}
for i, node in enumerate(nodes):
    entities = model.predict_entities(node.text, labels)

    # (name, label) pairs of this chunk, also stored in the node metadata
    found = []
    node.metadata['entities'] = found
    for entity in entities:
        name = str(entity["text"]).lower()
        entity["text"] = name  # the prediction itself is lower-cased as well
        label = entity["label"]
        found.append((name, label))
        entities_in_nodes.setdefault(name, set()).add(node.node_id)
    print(node.metadata['entities'])

    # Stop after the chunk with index 11 has been processed
    if i > 10:
        break
