In [1]:
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import TokenTextSplitter

# Chunking configuration, kept as named constants so it is easy to tune.
# Both values are passed unchanged to TokenTextSplitter.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100

text_splitter = TokenTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

In [2]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# Hugging Face checkpoint used for named-entity recognition.
NER_MODEL_NAME = "dslim/bert-base-NER"
# Token overlap between consecutive windows when an input is longer than the
# model can take in one pass.
NER_STRIDE = 64

tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)

# `aggregation_strategy="simple"` is the supported replacement for the
# deprecated `grouped_entities=True` (same grouping of word pieces into
# entities, without the deprecation warning).
# `stride` makes the pipeline slide overlapping windows over long inputs so
# the whole text is processed. NOTE(review): without it, text beyond the
# model's maximum input length appears to be truncated and never tagged --
# confirm against the installed transformers version.
nlp = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    stride=NER_STRIDE,
)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
def extract_entities(text, entity_group="PER", ner=None):
    """Collect the named entities of one group found in ``text``.

    Args:
        text: The string to analyse.
        entity_group: Only results whose ``"entity_group"`` equals this value
            are kept (default ``"PER"``).
        ner: Optional callable taking ``text`` and returning an iterable of
            dicts with ``"start"``, ``"end"``, ``"entity_group"`` and
            ``"score"`` keys. Defaults to the module-level ``nlp`` pipeline.

    Returns:
        A dict mapping each entity's surface form (the slice of ``text``
        between ``start`` and ``end``) to
        ``{"score": <highest score, rounded to 4 places>,
        "count": <number of mentions>, "type": <entity_group>}``.
    """
    if not text:
        return {}

    if ner is None:
        ner = nlp

    entities = {}
    for result in ner(text):
        # Filter on the requested group; the caller's argument must not be
        # overwritten by the per-result value.
        if result["entity_group"] != entity_group:
            continue

        word = text[result["start"]:result["end"]]
        # Convert to a builtin float *before* rounding so that NumPy float32
        # scores do not leave representation noise in the stored value.
        score = round(float(result["score"]), 4)

        entry = entities.setdefault(
            word, {"score": 0.0, "count": 0, "type": entity_group}
        )
        # Every mention counts, not only those that set a new best score.
        entry["count"] += 1
        if score > entry["score"]:
            entry["score"] = score

    return entities

In [4]:
def entities_to_graph(entities):
    """Convert an entity mapping into the pieces needed to draw a graph.

    Args:
        entities: Mapping of entity name to its attribute dict.

    Returns:
        A 4-tuple ``(nodes, edges, node_labels, edge_labels)`` where ``nodes``
        is a list of ``(name, attributes)`` pairs, ``node_labels`` maps each
        formatted name to itself, and ``edges`` / ``edge_labels`` are empty.
    """
    nodes = list(entities.items())

    node_labels = {}
    for name in entities:
        label = f"{name}"
        node_labels[label] = label

    # No relations are extracted, so the graph has no edges.
    edges = []
    edge_labels = {}

    return nodes, edges, node_labels, edge_labels

In [5]:
import networkx as nx
import matplotlib.pyplot as plt
import io
from PIL import Image

def visualize_graph(nodes, edges, node_labels=None, edge_labels=None, title="Characters and their relations"):
    """Render a graph to an in-memory PNG and return it as a PIL image.

    Args:
        nodes: Nodes in any form accepted by ``Graph.add_nodes_from`` (plain
            nodes or ``(node, attribute_dict)`` pairs).
        edges: Edges in any form accepted by ``Graph.add_edges_from``.
        node_labels: Optional mapping of node to label text. When ``None``,
            every node in the graph is labelled with ``str(node)``.
        edge_labels: Optional mapping of edge to label text; edge labels are
            drawn only when this is not ``None``.
        title: Title placed above the drawing.

    Returns:
        A ``PIL.Image`` opened on the rendered PNG.
    """
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)

    # Use an explicit figure/axes pair rather than pyplot's implicit "current
    # figure", and always close it so repeated calls do not accumulate figures
    # even when drawing fails.
    fig, ax = plt.subplots(figsize=(20, 8))
    try:
        pos = nx.spring_layout(G, k=1, iterations=50)

        nx.draw_networkx_nodes(G, pos, alpha=0.7, ax=ax)
        nx.draw_networkx_edges(G, pos, edge_color='gray',
                               width=1, alpha=0.5, ax=ax)

        if node_labels is None:
            # Iterate the graph's own nodes: entries of `nodes` may be
            # (node, attrs_dict) pairs, which are unhashable and cannot be
            # used as dict keys.
            node_labels = {node: str(node) for node in G.nodes}
        nx.draw_networkx_labels(G, pos, node_labels, ax=ax)

        if edge_labels is not None:
            nx.draw_networkx_edge_labels(G, pos, edge_labels, ax=ax)

        ax.set_title(title)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)

    return img

In [6]:
import gradio as gr
import time
import json

def _merge_entities(entities, new_entities):
    """Fold ``new_entities`` into ``entities`` in place; return the unseen names."""
    added = []
    for name, info in new_entities.items():
        known = entities.get(name)
        if known is None:
            # Copy so later merges never mutate the per-chunk result.
            entities[name] = dict(info)
            added.append(name)
            continue

        known["count"] = known.get("count", 0) + info.get("count", 0)
        if info.get("score", 0) > known.get("score", 0):
            known["score"] = info["score"]
    return added


def read_book(file):
    """Read a book chunk by chunk, yielding progress after every chunk.

    Args:
        file: The value passed to ``TextLoader``; ``None`` means nothing was
            uploaded.

    Yields:
        A list ``[status_message, chunk_text, graph_image]`` per chunk, where
        ``graph_image`` shows every entity found so far. When ``file`` is
        ``None`` a single ``[message, "", None]`` list is yielded instead.
    """
    if file is None:
        yield ["Please upload a book first.", "", None]
        return

    loader = TextLoader(file)

    doc = loader.load()
    chunks = text_splitter.split_documents(doc)
    entities = {}

    for i, chunk in enumerate(chunks):
        new_entities = extract_entities(chunk.page_content)

        # Merge rather than dict.update(): update() would replace the stats
        # gathered from earlier chunks with those of the latest chunk.
        added = _merge_entities(entities, new_entities)

        graph_data = entities_to_graph(entities)
        graph = visualize_graph(*graph_data)

        yield [
            f"Reading page {i+1}/{len(chunks)}. Found {len(added)} new characters for a total of {len(entities.keys())}",
            chunk.page_content,
            graph
        ]

# demo = gr.Interface(
#     fn=read_book,
#     inputs=["file"],
#     outputs=["text", "text", "text"],
# )

# Gradio UI: one row with two columns -- upload controls and text outputs in
# the first column, the rendered graph image in the second.
with gr.Blocks() as app:
    gr.Markdown("The coolest book reading club!")

    with gr.Row(equal_height=True):
        with gr.Column():
            upload = gr.File(label="Upload a book")
            btn = gr.Button(value="Read book")
    
            # Receives the status message (first item yielded by read_book).
            response = gr.Textbox(label="Action")
            
            # Receives the text of the chunk being processed (second item).
            chapter = gr.Textbox(label="Current chapter")
            
        with gr.Column():
                
            # Receives the PIL image of the graph (third item).
            graph = gr.Image(type="pil")
    
        # read_book is a generator; the order of `outputs` matches the order
        # of the list it yields: [status, chunk text, graph image].
        btn.click(
            fn=read_book,
            inputs=[upload],
            outputs=[response, chapter, graph]
        )

In [7]:
# next(read_book("data/dracula.txt"))

In [8]:
# Start the local Gradio server for `app`; the URL is printed in the cell output.
app.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


