In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils")
warnings.filterwarnings("ignore", category=FutureWarning, module="thinc.shims")
warnings.filterwarnings("ignore", category=FutureWarning, module="spacy_transformers.layers")
warnings.filterwarnings("ignore", category=UserWarning, module="spacy_transformers.layers")
#warnings.filterwarnings("ignore", category=UserWarning, module="spacy.util")

import torch
import spacy
from pathlib import Path
import pandas as pd
import numpy as np
import datasets
from spacy import displacy
from functools import reduce
import utils
import rebel_spacy

from neo4j import GraphDatabase

from config import load_config
from graphs import (
    explode_columns,
    run_query,
    get_all_relationships,
    get_all_nodes,
    cleanup_database,
    create_relationship,
    create_node,
)

#from spacy_custom import get_relations
from spacy.tokens import Doc
from wasabi import msg

from coref import resolve_references

In [2]:
# Pick both device identifiers in one step: ("gpu", 0) when CUDA is available,
# otherwise ("cpu", -1).
# NOTE(review): "gpu" is not a torch.device type string ("cuda" is) — confirm
# how torch_device is consumed before passing it to torch.
torch_device, rebel_device = ("gpu", 0) if torch.cuda.is_available() else ("cpu", -1)


In [3]:
# load_config() comes from the project-local `config` module and returns a
# pair. Later cells read config.neo4j.uri, config.neo4j.username and
# secrets.neo4j_password from it to connect to Neo4j.
config, secrets = load_config()

In [4]:
# Load a dataset and retrieve just a sample in `ds`
# (`ds` is the 'train' split restricted to row indices 0-9).
roc18_ds = datasets.load_dataset("igormorgado/ROCStories2018")
ds = roc18_ds['train'].select(range(10))

Repo card metadata block was not found. Setting CardData to empty.


In [84]:
# Load the `en_core_web_lg` spaCy pipeline; `nlp` is applied to the story text
# in the cell that builds `doc`.
nlp = spacy.load("en_core_web_lg")
#nlp.analyze_pipes(pretty=True);

In [4]:
#nlp_coref = spacy.load("en_coreference_web_trf")
# use replace_listeners for the coref components
#nlp_coref.replace_listeners("transformer", "coref", ["model.tok2vec"])
#nlp_coref.replace_listeners("transformer", "span_resolver", ["model.tok2vec"])
#nlp.add_pipe("merge_entities")
#nlp.add_pipe("coref", source=nlp_coref)
#nlp.add_pipe("span_resolver", source=nlp_coref)
#nlp_coref.analyze_pipes(pretty=True);

In [6]:
#doc_nlp = nlp(text)
#doc_coref = nlp_coref(text)

In [18]:
def token_to_json(token, sentence_id):
    """Flatten a token into a plain dict of query parameters.

    The node id and the head reference are both formatted as
    ``"<sentence_id>_<index>"``, using ``token.i`` and ``token.head.i``
    respectively.

    Args:
        token: Token-like object exposing ``i``, ``text``, ``lemma_``,
            ``pos_``, ``tag_``, ``dep_``, ``is_alpha``, ``is_stop``,
            ``is_sent_start`` and ``head.i``.
        sentence_id: Identifier of the sentence the token belongs to; used as
            the prefix of ``id`` and ``head``.

    Returns:
        dict with the keys ``id``, ``sentence_id``, ``text``, ``lemma``,
        ``pos``, ``tag``, ``dep``, ``is_alpha``, ``is_stop``,
        ``is_sent_start`` and ``head``.
    """
    def scoped(index):
        # Prefix an index with the sentence id.
        return f"{sentence_id}_{index}"

    return dict(
        id=scoped(token.i),
        sentence_id=sentence_id,
        text=token.text,
        lemma=token.lemma_,
        pos=token.pos_,
        tag=token.tag_,
        dep=token.dep_,
        is_alpha=token.is_alpha,
        is_stop=token.is_stop,
        is_sent_start=token.is_sent_start,
        head=scoped(token.head.i),
    )
    

In [10]:
# Not using rebel for now to extract Relations.
# spacy_lm = 'en_core_web_sm'
# relext_pipeline = spacy.load(spacy_lm, disable=['ner', 'lemmatizer', 'attribute_rules', 'tagger'])

# rebel_config_params = {
#     'device': rebel_device,
#     'model_name': 'Babelscape/rebel-large'
# }
# rebel_comp = relext_pipeline.add_pipe("rebel", config=rebel_config_params)

In [11]:
# doc[doc.ents[0].start:doc.ents[0].end]

In [97]:
def build_token_query() -> str:
    """Return the Cypher statement that creates one ``Token`` node.

    Every node property is bound to a query parameter of the same name
    (``$id``, ``$sentence_id``, ``$text``, ``$pos``, ``$tag``, ``$lemma``,
    ``$is_alpha``, ``$is_stop``, ``$is_sent_start``); the created node is
    returned as ``t``.
    """
    property_names = (
        "id",
        "sentence_id",
        "text",
        "pos",
        "tag",
        "lemma",
        "is_alpha",
        "is_stop",
        "is_sent_start",
    )
    # One "<name>: $<name>" line per property, indented like the query body.
    property_lines = ",\n".join(
        f"            {name}: ${name}" for name in property_names
    )
    return (
        "\n"
        "        CREATE (t:Token {\n"
        f"{property_lines}\n"
        "        })\n"
        "        RETURN t\n"
        "    "
    )

def build_dependency_query() -> str:
    """Return the Cypher statement that links a token to its head.

    The statement matches two existing ``Token`` nodes by ``$id`` and
    ``$head`` and creates a ``DEPENDS_ON`` relationship from the first to the
    second, storing ``$dep`` in the relationship's ``type`` property.
    """
    return (
        "\n"
        "        MATCH (t1:Token {id: $id}), (t2:Token {id: $head})\n"
        "        CREATE (t1)-[:DEPENDS_ON {type: $dep}]->(t2)\n"
        "    "
    )

def build_sentence_query() -> str:
    """Return the Cypher statement that chains two sentence roots.

    The statement matches two existing ``Token`` nodes by ``$id`` and
    ``$next`` and creates a ``NEXT_SENTENCE`` relationship from the first to
    the second.
    """
    return (
        "\n"
        "        MATCH (r1:Token {id: $id}), (r2:Token {id: $next})\n"
        "        CREATE (r1)-[:NEXT_SENTENCE]->(r2)\n"
        "    "
    )

In [19]:
# Peek at the first six tokens next to the document index of their head.
for token in doc[:6]:
    print(f"{token} {token.head.i}")

David 1
noticed 1
he 4
had 4
put 1
on 4


In [100]:
# Build the dependency graph of `doc` in Neo4j:
#   * one :Token node per token,
#   * a DEPENDS_ON relationship from every non-root token to its head,
#   * a NEXT_SENTENCE relationship between consecutive sentence roots.
# NOTE(review): `doc` is created by the `doc = nlp(text)` cell that appears
# further down in this notebook; on a fresh kernel that cell must run first.

# The query texts never change, so build them once instead of once per token.
token_query = build_token_query()
dependency_query = build_dependency_query()
sentence_query = build_sentence_query()

# Connect to Neo4j. Using the driver as a context manager closes it after the
# session has been released, and also when one of the writes raises.
with GraphDatabase.driver(
    config.neo4j.uri, auth=(config.neo4j.username, secrets.neo4j_password)
) as driver:
    with driver.session() as session:
        # Clear existing graph
        session.execute_write(cleanup_database)

        # Results of the CREATE query for every token whose dep_ is "ROOT".
        # They are concatenated with pd.concat and read through column 't'
        # below, so run_query presumably returns a DataFrame — confirm in
        # graphs.run_query.
        roots = []

        for sentence_id, sent in enumerate(doc.sents):
            # Serialise each token once; the same parameters feed both the
            # node query and the dependency query.
            sentence_rows = [
                (token, token_to_json(token, sentence_id)) for token in sent
            ]

            # Create nodes for each token in the sentence
            for token, data in sentence_rows:
                node = session.execute_write(run_query, token_query, data)
                if token.dep_ == "ROOT":
                    roots.append(node)

            # Create relationships based on dependencies within the sentence.
            # This runs after all nodes of the sentence exist, because the
            # dependency query MATCHes both endpoints.
            for token, data in sentence_rows:
                if token.head.i != token.i:  # Exclude root
                    session.execute_write(run_query, dependency_query, data)

        # Create relationships between sentence roots.
        # pd.concat raises ValueError on an empty list, so fall back to an
        # empty frame when the document produced no sentences.
        if roots:
            roots_df = pd.concat(roots, ignore_index=True)
        else:
            roots_df = pd.DataFrame(columns=["t"])
        root_ids = list(roots_df['t'].apply(lambda x: x.get('id')))
        for id_, next_ in zip(root_ids[:-1], root_ids[1:]):
            data = {"id": id_, "next": next_}
            session.execute_write(run_query, sentence_query, data)

print(f"Dependency graph created in Neo4j with {len(root_ids)} sentences linked.")

Dependency graph created in Neo4j with 5 sentences linked.


In [91]:
# Collect the 'id' of every root node stored in column 't' of `roots_df`.
root_ids = [root_node.get('id') for root_node in roots_df['t']]

In [94]:
# Preview the consecutive (current, next) pairs of root ids — the same pairing
# the graph-building cell feeds to the NEXT_SENTENCE query.
list(zip(root_ids[:-1], root_ids[1:]))

[('0_1', '1_12'), ('1_12', '2_19'), ('2_19', '3_33'), ('3_33', '4_41')]

In [57]:
# NOTE(review): `sentence_roots` is not assigned in any cell visible in this
# notebook (presumably left over from an earlier kernel session), so this cell
# fails on Restart & Run All — confirm whether `roots` was meant.
sentence_roots[0]['t'][0]['id']

'0_1'

In [58]:
# NOTE(review): relies on `sentence_roots`, which no visible cell assigns;
# the recorded output may come from a different run than the other cells —
# re-run on a fresh kernel to confirm.
sentence_roots[1]['t'][0]['id']

'1_14'

In [39]:
# Stack the per-sentence root results and show their 't' column.
# NOTE(review): `sentence_roots` is not assigned in any visible cell — confirm
# whether `roots` from the graph-building cell was meant.
pd.concat(sentence_roots, ignore_index=True)['t']

0    (is_sent_start, sentence_id, pos, is_alpha, le...
1    (is_sent_start, sentence_id, pos, is_alpha, le...
2    (is_sent_start, sentence_id, pos, is_alpha, le...
3    (is_sent_start, sentence_id, pos, is_alpha, le...
Name: t, dtype: object

In [9]:
# Turn the first sampled record into text with utils.rocstory and parse it
# with the spaCy pipeline; `doc` is what the graph-building cell iterates over.
# NOTE(review): this cell sits below the cells that read `doc`; on a fresh
# kernel it has to run before them (consider moving it up).
text = utils.rocstory(ds[0])
doc = nlp(text)

#create_dependency_graph(doc)

In [75]:
# Open a new Neo4j driver with the credentials from `config` / `secrets`.
# NOTE(review): no visible cell closes this driver — call driver.close() when done.
driver = GraphDatabase.driver(config.neo4j.uri, auth=(config.neo4j.username, secrets.neo4j_password))

In [77]:
# Open a session on `driver` outside a `with` block.
# NOTE(review): no visible cell closes it — call sess.close() when done.
sess = driver.session()