In [None]:
import os
import csv
import numpy as np
import pandas as pd

from kgtk.configure_kgtk_notebooks import ConfigureKGTK
from kgtk.functions import kgtk, kypher

In [None]:
import random

property_id = "P39"

# Make a shuffled version of the claims file for this property.
# The shuffled copy is written once and reused on later runs (the shuffle is
# skipped whenever the output file already exists). The header row of the
# input is dropped, so the shuffled file contains data rows only.
claims_path = f'/out/data/propertiesSplit_final/claims.{property_id}.tsv'
shuffled_path = f'/out/data/propertiesSplit_final/claims.{property_id}.shuffled.tsv'

if not os.path.isfile(shuffled_path):
    # `with` guarantees the input handle is closed; newline='' is what the
    # csv module expects for files it reads or writes.
    with open(claims_path, newline='') as tsv_file:
        read_tsv = csv.reader(tsv_file, delimiter="\t")
        next(read_tsv, None)  # skip the header; tolerate an empty file
        all_lines = list(read_tsv)
    random.shuffle(all_lines)
    with open(shuffled_path, 'w', newline='') as file:
        writer = csv.writer(file, delimiter='\t')
        writer.writerows(all_lines)

In [None]:
# make link prediction dataset for P31
# assumes existence of separate file with all P31 relationships for the subjects and objects connected with this property
# this runs a kypher query for each Q node and is not efficient. Could be improved with batching.
# takes about 7 hours per property currently.

import csv
import os

def _collectNodeRows(node):
    """Run the kgtk queries for `node` and return the output rows (not written yet).

    Returns a list of 6-element rows
    [node, nodeLabel, prop, connectedNode, connectedLabel, orientation].
    The list is empty when none of the node's outgoing edges is a P31 edge.
    P31 edges themselves are never included among the "subject" rows.

    Raises AttributeError when a query result lacks the expected columns
    (presumably when kgtk returns something that is not a table - TODO confirm).
    """
    rows = []
    hasP31 = False

    # Triples with `node` as subject.
    res = kgtk(f"""query
        -i /out/labels.en.tsv.gz --as labels
        -i /out/claims.wikibase-item.tsv.gz --as claims
        --match 'claims: (:{node})-[someProp]->(node2)'
        --opt 'labels: (:{node})-[:label]->(lbSubj)'
        --opt 'labels: (node2)-[:label]->(lbObj)'
        --return 'distinct node2 as `object`, someProp as `prop`, lbSubj as `subjectLabel`, lbObj as `objectLabel`'
        """)
    for obj, edge, subjLabel, objLabel in zip(res.object, res.prop, res.subjectLabel, res.objectLabel):
        # The property id is the second "-"-separated field of the edge id.
        prop = edge.split("-")[1]
        if prop == "P31":
            hasP31 = True
        else:
            rows.append([node, subjLabel, prop, obj, objLabel, "subject"])

    # Only nodes that have a P31 relationship contribute to the dataset.
    if hasP31:
        # Triples with `node` as object.
        res = kgtk(f"""query
            -i /out/labels.en.tsv.gz --as labels
            -i /out/claims.wikibase-item.tsv.gz --as claims
            --match 'claims: (node2)-[someProp]->(:{node})'
            --opt 'labels: (:{node})-[:label]->(lbObj)'
            --opt 'labels: (node2)-[:label]->(lbSubj)'
            --return 'distinct node2 as `subject`, someProp as `prop`, lbSubj as `subjectLabel`, lbObj as `objectLabel`'
            """)
        for subj, edge, subjLabel, objLabel in zip(res.subject, res.prop, res.subjectLabel, res.objectLabel):
            rows.append([node, objLabel, edge.split("-")[1], subj, subjLabel, "object"])
        return rows
    return []


def queryKgtk(node, writer, max_retries=5):
    """Write all non-P31 triples touching `node` to `writer`, if the node has a P31 edge.

    Parameters:
        node: node id interpolated into the kypher match patterns (e.g. "Q42").
        writer: object with a `writerows` method (e.g. a csv.writer).
        max_retries: how many times to re-run the queries after an
            AttributeError before giving up on this node.

    Rows are collected completely before anything is written, so a failure in
    the second query can no longer leave the "subject" rows written twice
    after a retry. Retries are bounded, so a node that fails every time is
    skipped with a message instead of recursing until RecursionError.
    """
    for attempt in range(max_retries + 1):
        try:
            rows = _collectNodeRows(node)
        except AttributeError:
            if attempt < max_retries:
                print("Caught Attribute Error; trying again")
            continue
        writer.writerows(rows)
        return
    print(f"Giving up on {node}: kgtk query still failing after {max_retries} retries")

# Build the per-property dataset: for up to `datasetSize` distinct subjects and
# `datasetSize` distinct objects taken from the shuffled claims file, call
# queryKgtk and append its rows to the subjects / objects output files.
# Re-running resumes: nodes already present in the output files are skipped.

shuffled_claims_path = f"/out/data/propertiesSplit_final/claims.{property_id}.shuffled.tsv"
subjects_path = f'/out/output/instance_prediction_datasets/{property_id}.subjects.tsv'
objects_path = f'/out/output/instance_prediction_datasets/{property_id}.objects.tsv'
output_header = ['node','nodeLabel','prop','nodeConnection','connectionLabel','nodeOrientation']

count = 0
datasetSize = 2500


def _load_seen_nodes(path):
    """Return the set of first-column values from an existing output file.

    The first row is treated as a header and skipped. A missing file yields
    an empty set. The file is closed before returning.
    """
    seen = set()
    if os.path.isfile(path):
        with open(path, newline='') as fh:
            reader = csv.reader(fh, delimiter="\t")
            next(reader, None)  # header; tolerate an empty file
            for row in reader:
                if row:
                    seen.add(row[0])
    return seen


def _needs_header(path):
    """True when `path` does not exist yet or is empty, i.e. a header must be written."""
    return not os.path.isfile(path) or os.path.getsize(path) == 0


# load pre-computed classes
dis_subj_classes = _load_seen_nodes(subjects_path)
dis_obj_classes = _load_seen_nodes(objects_path)

print(dis_subj_classes)
print(dis_obj_classes)

# Decide per file whether a header is needed *before* opening in append mode
# (opening creates the file). Deciding for both files from the subjects file
# alone breaks when only one of the two files exists.
subj_needs_header = _needs_header(subjects_path)
obj_needs_header = _needs_header(objects_path)

with open(shuffled_claims_path, newline='') as tsv_file, \
        open(subjects_path, 'a', newline='') as subj_file, \
        open(objects_path, 'a', newline='') as obj_file:
    input_tsv = csv.reader(tsv_file, delimiter="\t")
    subj_writer = csv.writer(subj_file, delimiter='\t')
    obj_writer = csv.writer(obj_file, delimiter='\t')
    if subj_needs_header:
        subj_writer.writerow(output_header)
    if obj_needs_header:
        obj_writer.writerow(output_header)

    for line in input_tsv:
        # Skip blank / short rows instead of failing with IndexError.
        if len(line) < 3:
            continue
        if line[0].startswith("Q") and line[2].startswith("Q"):
            subj = line[0]
            obj = line[2]

            # Each distinct node is recorded as seen first; it is only queried
            # while the corresponding set is still within datasetSize.
            if subj not in dis_subj_classes:
                dis_subj_classes.add(subj)
                if len(dis_subj_classes) <= datasetSize:
                    queryKgtk(subj, subj_writer)
            print("Subject classes completed: ", len(dis_subj_classes))

            if obj not in dis_obj_classes:
                dis_obj_classes.add(obj)
                if len(dis_obj_classes) <= datasetSize:
                    queryKgtk(obj, obj_writer)
            print("Object classes completed: ", len(dis_obj_classes))

            # Stop once both sides have reached the target size.
            if len(dis_subj_classes) >= datasetSize and len(dis_obj_classes) >= datasetSize:
                break