In [None]:
import numpy as np
import pandas as pd
import networkx as nx
import nltk
import json
import math
from multiprocessing import cpu_count
import tqdm
import spacy

In [None]:
# Groups of raw ConceptNet relation names that are merged into one relation.
# Inside a group the first name is the merged relation; a name prefixed with
# '*' is mapped onto the group with head and tail swapped (see
# load_merge_relation / retrieve_eng).
relation_groups = [
    'atlocation/locatednear',
    'capableof',
    'causes/causesdesire/*motivatedbygoal',
    'createdby',
    'desires',
    'antonym/distinctfrom',
    'hascontext',
    'hasproperty',
    'hassubevent/hasfirstsubevent/haslastsubevent/hasprerequisite/entails/mannerof',
    'isa/instanceof/definedas',
    'madeof',
    'notcapableof',
    'notdesires',
    'partof/*hasa',
    'relatedto/similarto/synonym',
    'usedfor',
    'receivesaction',
]

# Each merged relation paired with the phrase used to verbalize it. The list
# index of a relation is its relation id (construct_graph enumerates
# merged_relations). Declaring the pairs together keeps merged_relations and
# relation_text aligned index-for-index.
_relation_vocab = [
    ('antonym', 'is the antonym of'),
    ('atlocation', 'is at location of'),
    ('capableof', 'is capable of'),
    ('causes', 'causes'),
    ('createdby', 'is created by'),
    ('isa', 'is a kind of'),
    ('desires', 'desires'),
    ('hassubevent', 'has subevent'),
    ('partof', 'is part of'),
    ('hascontext', 'has context'),
    ('hasproperty', 'has property'),
    ('madeof', 'is made of'),
    ('notcapableof', 'is not capable of'),
    ('notdesires', 'does not desires'),
    ('receivesaction', 'is'),
    ('relatedto', 'is related to'),
    ('usedfor', 'is used for'),
]
merged_relations = [relation for relation, _phrase in _relation_vocab]
relation_text = [phrase for _relation, phrase in _relation_vocab]
del _relation_vocab

In [None]:
def load_merge_relation(groups=None):
    '''
    Build a mapping from every raw relation name to its merged relation.

    Each group string is a '/'-separated list of relation names; the first name
    is the merged ("true") relation that every name in the group maps to.
    A name prefixed with '*' maps to '*' + true relation instead, which
    retrieve_eng uses as the signal to swap head and tail.
    e.g. 'atlocation/locatednear' maps both 'atlocation' and 'locatednear' to
    'atlocation'; 'partof/*hasa' maps 'partof' -> 'partof' and 'hasa' -> '*partof'.

    :param groups: iterable of group strings; when None (default) the
        module-level ``relation_groups`` is used.
    :return: dict mapping raw relation name -> merged relation name
        (possibly '*'-prefixed).
    '''
    if groups is None:
        groups = relation_groups
    rel_mapping = {}
    for group in groups:
        names = group.strip().split("/")
        true_rel = names[0]
        for name in names:
            if name.startswith("*"):
                # reverse relation: remember it must be flipped onto true_rel
                rel_mapping[name[1:]] = "*" + true_rel
            else:
                rel_mapping[name] = true_rel
    return rel_mapping
def del_pos(s):
    """
    Strip a trailing part-of-speech tag from an entity string.

    :param s: Entity string, e.g. '/c/en/dog/n'.
    :return: ``s`` without its last two characters when it ends with one of
        '/n', '/a', '/v' or '/r'; otherwise ``s`` unchanged.
    """
    pos_suffixes = ("/n", "/a", "/v", "/r")
    return s[:-2] if s.endswith(pos_suffixes) else s

In [None]:
def retrieve_eng(cpnet_path, output_path, output_vocab_path):
    '''
    Extract English-only triples from a ConceptNet assertions dump.

    Each input line is tab-separated; columns 1-4 are read as the relation URI,
    head URI, tail URI and a JSON blob containing "weight". A triple is kept
    when both head and tail are English concepts ('/c/en/...'), both terms are
    purely alphabetic once '_' and '-' are removed, and the relation is known
    to load_merge_relation(). Kept triples are written to ``output_path`` as
        <relation>\t<head>\t<tail>\t<weight>
    with the relation replaced by its merged name (head and tail swapped for
    '*'-marked reverse relations). Every distinct head/tail, in order of first
    appearance, is written one per line to ``output_vocab_path``.

    :param cpnet_path: path to the ConceptNet assertions CSV (tab-separated).
    :param output_path: path of the filtered triple file to write.
    :param output_vocab_path: path of the concept vocabulary file to write.
    '''
    # The notebook's import cell binds `tqdm` to the tqdm *module*, which is
    # not callable; import the progress-bar function locally.
    from tqdm import tqdm

    rel_mapping = load_merge_relation()
    # Count lines up front so the progress bar has a total; the handle is closed.
    with open(cpnet_path, 'r', encoding='utf-8') as fin:
        num_lines = sum(1 for _ in fin)

    concept_net_vocab = []  # distinct concepts, first-seen order
    concept_seens = set()
    with open(cpnet_path, 'r', encoding='utf-8') as fin, \
            open(output_path, "w", encoding='utf-8') as fout:
        for line in tqdm(fin, total=num_lines):
            tokens = line.strip().split("\t")
            # Skip blank / truncated lines instead of failing with IndexError.
            if len(tokens) < 5:
                continue
            head_uri, tail_uri = tokens[2], tokens[3]
            # ConceptNet concept URIs start with '/c/en/'; dropping the leading
            # '/' accepts both '/c/en/...' and 'c/en/...'.
            if not (head_uri.lstrip('/').startswith('c/en/')
                    and tail_uri.lstrip('/').startswith('c/en/')):
                continue
            rel = tokens[1].split("/")[-1].lower()
            # NOTE(review): split("/")[-1] takes the last path component; for URIs
            # with extra components after the POS tag that is not the term
            # itself — confirm against the ConceptNet dump in use.
            head = del_pos(head_uri).split("/")[-1].lower()
            tail = del_pos(tail_uri).split("/")[-1].lower()
            # keep only alphabetic terms ('_' and '-' allowed as separators)
            if not head.replace("_", "").replace("-", "").isalpha():
                continue
            if not tail.replace("_", "").replace("-", "").isalpha():
                continue
            if rel not in rel_mapping:
                continue
            # map the raw relation onto its merged relation
            rel = rel_mapping[rel]
            if rel.startswith("*"):
                # '*' marks a reverse relation: swap head and tail
                head, tail, rel = tail, head, rel[1:]
            data = json.loads(tokens[4])
            fout.write('\t'.join([rel, head, tail, str(data["weight"])]) + "\n")
            for w in (head, tail):
                if w not in concept_seens:
                    concept_seens.add(w)
                    concept_net_vocab.append(w)

    with open(output_vocab_path, "w", encoding='utf-8') as fout:
        for word in concept_net_vocab:
            fout.write(word + "\n")

In [None]:
def construct_graph(cpnet_csv_path, cpnet_vocab_path, output_path, prune = True):
    '''
    Build a directed multigraph over concepts and pickle it to ``output_path``.

    Node ids are line indices in ``cpnet_vocab_path``; relation ids are indices
    into ``merged_relations``. For each line <relation>\t<head>\t<tail>\t<weight>
    of ``cpnet_csv_path`` a forward edge head->tail with attribute
    rel=<relation id> is added, plus a reverse edge tail->head with
    rel=<relation id> + len(merged_relations); both carry the line's weight.
    Self-loops and repeated (head, relation, tail) triples are skipped. With
    ``prune=True``, triples touching a small concept blacklist, or using the
    'hascontext' relation, are dropped.

    :param cpnet_csv_path: tab-separated triple file (relation, head, tail, weight).
    :param cpnet_vocab_path: concept vocabulary, one concept per line.
    :param output_path: destination file for the pickled ``nx.MultiDiGraph``.
    :param prune: drop blacklisted concepts and 'hascontext' edges when True.
    '''
    import pickle
    # The notebook's import cell binds `tqdm` to the module, which is not callable.
    from tqdm import tqdm

    blacklist = {"uk", "us", "take", "make", "object", "person", "people"}

    with open(cpnet_vocab_path, "r", encoding='utf-8') as fin:
        id2concept = [w.strip() for w in fin]
    concept2id = {w: i for i, w in enumerate(id2concept)}

    id2relation = merged_relations
    relation2id = {r: i for i, r in enumerate(id2relation)}
    num_relations = len(relation2id)

    def not_save(cpt):
        # NOTE: originally phrases like "branch out" would not be kept in the
        # graph; now only blacklisted concepts are dropped.
        return cpt in blacklist

    # Count lines up front for the progress bar; the handle is closed.
    with open(cpnet_csv_path, "r", encoding="utf-8") as fin:
        nrow = sum(1 for _ in fin)

    graph = nx.MultiDiGraph()
    # Keys are always (subj, rel, obj). Only forward keys are stored: reverse
    # edges use ids >= num_relations, so they can never match a forward lookup,
    # and storing them in another tuple order could collide with real triples.
    seen_set = set()
    with open(cpnet_csv_path, "r", encoding="utf-8") as fin:
        for line in tqdm(fin, total=nrow):
            if not line.strip():
                continue
            ls = line.strip().split('\t')
            rel = relation2id[ls[0]]
            subj = concept2id[ls[1]]
            obj = concept2id[ls[2]]
            weight = float(ls[3])
            if prune and (not_save(ls[1]) or not_save(ls[2]) or id2relation[rel] == "hascontext"):
                continue
            # drop self-loops
            if subj == obj:
                continue
            if (subj, rel, obj) in seen_set:
                continue
            seen_set.add((subj, rel, obj))
            graph.add_edge(subj, obj, rel=rel, weight=weight)
            # reverse direction gets its own relation id range
            graph.add_edge(obj, subj, rel=rel + num_relations, weight=weight)

    # nx.write_gpickle was removed in NetworkX 3.0; it was a thin wrapper around
    # pickle.dump with the highest protocol, so do that directly.
    with open(output_path, "wb") as fout:
        pickle.dump(graph, fout, pickle.HIGHEST_PROTOCOL)
    