In [1]:
SRLEARN_PATH = "../../srlearn"
PROJECT_PATH = ".."
import sys
sys.path.append(SRLEARN_PATH)
sys.path.append(PROJECT_PATH)

import re

from copy import deepcopy
from srlearn.database import Database

from utils.experiment import loadDatabase, getLogger
from utils.utils import cleanPreds

In [2]:
DATA_PATH = "./preprocessed"
logger = getLogger("Data Converter")

In [69]:
def convertToArity2(literals, model: dict = None):
    newLiterals = set()

    literalsAreModes = False

    if any([mode in literals[0] for mode in ["+", "-", "#", "`"]]):
        literalsAreModes = True

    if not literalsAreModes:
        assert model is not None, "Model is required to convert literals. You can extract it from a Database object by calling the method `extractSchemaPreds`."

    unaryLiterals = []
    nAryLiterals = []

    for literal in literals:
        predicate, terms = re.findall(r"(.*)\((.*)\)\.", literal)[0]
        terms = terms.split(",")
        arity = len(terms)
        if arity == 2:
            newLiterals.add(literal)
        elif arity == 1:
            unaryLiterals.append(literal)
        else:
            nAryLiterals.append(literal)

    unaryLiterals = cleanPreds(unaryLiterals)
    nAryLiterals = cleanPreds(nAryLiterals)

    for literal in unaryLiterals:
        predicate, term = re.findall(r"(.*)\((.*)\)\.", literal)[0]
        if literalsAreModes:
            for firstArgMode, secondArgMode in [("+", "+"), ("+", "-"), ("-", "+")]:
                newLiteral = f"{term}haslabel({firstArgMode}{term},{secondArgMode}{term}label)."
                newLiterals.add(newLiteral)
        else:
            termType = model[predicate][0]
            newLiteral = f"{termType}haslabel({term},{predicate})."
            newLiterals.add(newLiteral)

    nAryPredicateUniqueId = {}

    for literal in nAryLiterals:
        predicate, terms = re.findall(r"(.*)\((.*)\)\.", literal)[0]
        terms = terms.split(",")
        arity = len(terms)
        if literalsAreModes:
            for term in terms:
                for firstArgMode, secondArgMode in [("+", "+"), ("+", "-"), ("-", "+")]:
                    newLiteral = f"{predicate}{term}({firstArgMode}{predicate},{secondArgMode}{term})."
                    newLiterals.add(newLiteral)
        else:
            termTypes = model[predicate]
            if predicate not in nAryPredicateUniqueId:
                nAryPredicateUniqueId[predicate] = 1
            predicateUniqueId = nAryPredicateUniqueId[predicate]
            for term, termType in zip(terms, termTypes):
                newLiteral = f"{predicate}{termType}({predicate}{predicateUniqueId},{term})."
                newLiterals.add(newLiteral)
            nAryPredicateUniqueId[predicate] += 1
            
    return list(newLiterals)

In [70]:
def convertDatabaseToArity2(database: Database):
    database = deepcopy(database)
    model = database.extractSchemaPreds()
    database.facts = convertToArity2(database.facts, model = model)
    database.pos = convertToArity2(database.pos, model = model)
    database.neg = convertToArity2(database.neg, model = model)
    database.modes = convertToArity2(database.modes, model = None)
    return database

In [108]:
def convertDatabaseToNeo4j(database: Database):
    model = database.extractSchemaPreds()
    for predicate, terms in model.items():
        assert len(terms) == 2, "The current implementation is not able to convert predicates with arity different than 2. Try convert its predicates to arity 2."
    
    nodeQuery = set()
    relationQuery = set()
    
    for literal in (database.facts + database.pos):
        predicate, terms = re.findall(r"(.*)\((.*)\)\.", literal)[0]
        terms = terms.split(",")
        if predicate not in model:
            continue
        termTypes = [termType.capitalize() for termType in model[predicate]]
        predicate = predicate.upper()
        nodeQuery.add(f'MERGE ({terms[0]}:{termTypes[0]} {{name:"{terms[0]}"}})')
        nodeQuery.add(f'MERGE ({terms[1]}:{termTypes[1]} {{name:"{terms[1]}"}})')
        relationQuery.add(f'MERGE ({terms[0]})-[:{predicate}]-({terms[1]})')

    query = list(nodeQuery) + list(relationQuery)
    query = "\n".join(query)

    return query

In [109]:
def convertDatabaseModelToNeo4j(database: Database):
    nodeQuery = set()
    relationQuery = set()
    model = database.extractSchemaPreds()

    for predicate, terms in model.items():
        assert len(terms) == 2, "The current implementation is not able to convert predicates with arity different than 2. Try convert its predicates to arity 2."
        predicate = predicate.upper()
        terms = [term.capitalize() for term in terms]
        nodeQuery.add(f'MERGE ({terms[0].lower()}:{terms[0]} {{name:"{terms[0]}"}})')
        nodeQuery.add(f'MERGE ({terms[1].lower()}:{terms[1]} {{name:"{terms[1]}"}})')
        relationQuery.add(f'MERGE ({terms[0].lower()})-[:{predicate}]-({terms[1].lower()})')
    
    query = list(nodeQuery) + list(relationQuery)
    query = "\n".join(query)

    return query

In [110]:
def loadDatabaseAndConvertToNeo4jQueries(
    databaseName: str,
    folds = None, 
    useRecursion = False, 
    targetPredicate = None,
    resetTargetPredicate = False, 
    negPosRatio = 1,
    maxFailedNegSamplingRetries = 50,
    logger = None
):
    datasetPath = f"{DATA_PATH}/{databaseName}"
    database = loadDatabase(
        path = datasetPath,
        folds = folds,
        useRecursion = useRecursion,
        targetPredicate = targetPredicate,
        resetTargetPredicate = resetTargetPredicate, 
        negPosRatio = negPosRatio,
        maxFailedNegSamplingRetries = maxFailedNegSamplingRetries,
        logger = logger
    )

    binaryDatabase = convertDatabaseToArity2(database)
    modelQueries = convertDatabaseModelToNeo4j(binaryDatabase)
    dataQueries = convertDatabaseToNeo4j(binaryDatabase)

    return modelQueries, dataQueries

In [121]:
modelQueries, dataQueries = loadDatabaseAndConvertToNeo4jQueries(
    databaseName = "imdb",
    folds = None, 
    useRecursion = False, 
    targetPredicate = None,
    resetTargetPredicate = False, 
    negPosRatio = 1,
    maxFailedNegSamplingRetries = 50,
    logger = logger
)
print("\nModel Query:")
print(modelQueries)

print("\nData Query:")
print(dataQueries)

2024-03-01 13:38:17,429 - Data Converter - DEBUG - 1/5 folds loaded with success.
2024-03-01 13:38:17,464 - Data Converter - DEBUG - 2/5 folds loaded with success.
2024-03-01 13:38:17,501 - Data Converter - DEBUG - 3/5 folds loaded with success.
2024-03-01 13:38:17,539 - Data Converter - DEBUG - 4/5 folds loaded with success.
2024-03-01 13:38:17,575 - Data Converter - DEBUG - 5/5 folds loaded with success.



Model Query:
MERGE (personlabel:Personlabel {name:"Personlabel"})
MERGE (movie:Movie {name:"Movie"})
MERGE (genre:Genre {name:"Genre"})
MERGE (person:Person {name:"Person"})
MERGE (movie)-[:MOVIE]-(person)
MERGE (person)-[:PERSONHASLABEL]-(personlabel)
MERGE (person)-[:GENRE]-(genre)
MERGE (person)-[:WORKEDUNDER]-(person)

Data Query:
MERGE (aauroraquattrocchi:Person {name:"aauroraquattrocchi"})
MERGE (aeddasabatini:Person {name:"aeddasabatini"})
MERGE (abruceayoung:Person {name:"abruceayoung"})
MERGE (abenjaminmouton:Person {name:"abenjaminmouton"})
MERGE (aaction:Genre {name:"aaction"})
MERGE (aelliottgould:Person {name:"aelliottgould"})
MERGE (amarenschumacher:Person {name:"amarenschumacher"})
MERGE (aradarassimov:Person {name:"aradarassimov"})
MERGE (amarcocavicchioli:Person {name:"amarcocavicchioli"})
MERGE (ascifi:Genre {name:"ascifi"})
MERGE (awilliamatherton:Person {name:"awilliamatherton"})
MERGE (alyubomirbachvarov:Person {name:"alyubomirbachvarov"})
MERGE (akonstanzebreiteb