In [1]:
# @dlei5

# import libs
import numpy as np
from scipy.spatial import distance
import math

import util
import set_expan
import time
import pickle

from tree_node import TreeNode
from collections import defaultdict

In [2]:
# set up global vars

# Max number of children per node, keyed by the node's level (ROOT is level -1).
# Levels 2..5 use 1e9 as an effectively unlimited cap.
level2max_children = {-1: 10, 0: 15, 1: 5, **{lvl: 1e9 for lvl in range(2, 6)}}
# (parent_eid, child_eid) edges from the user input, keyed by the parent's level; init() appends to it.
level2reference_edges = defaultdict(list)
# Passed to set_expan.setExpan as negativeSeedEids (empty: no negative seeds configured).
negativeSeedEids = set()
load_pickle_flag = True
FLAGS_USE_TYPE = True
# NOTE(review): obtainReferenceEdges reads a global FLAGS_INITIAL_EDGE that is not defined in
# this cell; it only fails once a non-user-provided edge is encountered.

data = "dblpv2"
folder = '../../data/%s/intermediate/' % data

# Earlier seed hierarchy (the saved tree output further down appears to come from this one):
# userInput = [
#   ["ROOT", -1, ["machine learning", "data mining", "information retrieval"]],
#
#   ["machine learning", 0, ["supervised machine learning", "unsupervised learning", "reinforcement learning"]],
#   ["data mining", 0, ["pattern mining", "graph mining", "web mining", "text mining"]],
#   ["information retrieval", 0, ["document retrieval", "query processing", "relevance feedback"]],
#
#   ["supervised machine learning", 1, ["support vector machines"]]
# ]

# User-provided seed taxonomy consumed by init(). The first entry is treated as ROOT; each later
# entry is [parent entity name, parent level, [child entity names]] and must name a parent
# created by an earlier entry.
# NOTE(review): "principle component analysis" is probably a misspelling of "principal", but the
# string must match an entity name in ename2eid exactly -- confirm before changing it.
userInput = [
    ["ROOT", -1, ["machine learning", "data mining", "information retrieval"]],

    ["machine learning", 0, ["supervised machine learning", "unsupervised learning", "reinforcement learning"]],
    ["data mining", 0, ["pattern mining", "graph mining", "web mining", "text mining"]],
    ["information retrieval", 0, ["document retrieval", "query processing", "relevance feedback"]],

    ["supervised machine learning", 1, ["support vector machines", "decision trees", "random forests"]],
    ["unsupervised learning", 1, ["agglomerative clustering", "principle component analysis", "latent dirichlet allocation"]],
    ["reinforcement learning", 1, ["markov decision processes", "policy gradient"]],
    # ["supervised machine learning", 0, ["named entity recognition", "information extraction", "machine translation"]],
    # ["neural networks", 1, ["convolutional neural network", "recurrent neural networks"]],
]

In [3]:
# load pickle data
# Both input files are resolved from the shared `folder` setting so they cannot drift apart.
# NOTE: pickle.load can execute arbitrary code -- only load all_data.pickle files you produced yourself.
# NOTE(review): load_pickle_flag (config cell) is not consulted; this cell always loads the pickle.

print("=== Start loading all files ...... ===")
print("load_from_pickle")
pickle_file = folder + 'all_data.pickle'
with open(pickle_file, "rb") as fin:
    # The unpacking below must match the order in which all_data.pickle was written.
    [eid2ename, ename2eid, eid2patterns, pattern2eids, eidAndPattern2strength, eid2types, type2eids,
     eidAndType2strength, eid2embed, embed_matrix, eid2rank, rank2eid, embed_matrix_array, skipgramsByEidPairMap,
     eidPairsBySkipgramMap, weightByEidPairAndSkipgramMap] = pickle.load(fin)
print("=== Finish loading all files. HaHaHa ===")


# load PPMI
print('loading eid pair document-level PPMI score')
eidpair2PPMI = util.loadEidDocPairPPMI(folder + 'eidDocPairPPMI.txt')



=== Start loading all files ...... ===
load_from_pickle
=== Finish loading all files. HaHaHa ===
loading eid pair document-level PPMI score


In [4]:
# define functions 
def sim_par_new(p_eid, c_eid, reference_edges, embed_alpha, eid2embed, eidpair2PPMI, embed_dim=100, DEBUG=False):
  '''Score how well entity ``c_eid`` fits as a child of entity ``p_eid``.

  Two signals are combined:
    * embedding_sim: cosine similarity between the candidate offset
      ``embed(p_eid) - embed(c_eid)`` and the mean ``embed(parent) - embed(child)`` offset of
      ``reference_edges``, clipped below at 0 (NaN is also mapped to 0);
    * cooccur_sim: the value stored in ``eidpair2PPMI`` for the unordered pair
      ``frozenset([c_eid, p_eid])``, or 0.0 when the pair is absent.
  The result is ``embedding_sim * sqrt(1 + cooccur_sim)``.

  reference_edges: a list of (parent, child) eids. When empty there is no target offset to
    compare against, so embedding_sim (and hence the result) is 0.0.
  embed_alpha: NOT used in the returned score (the alpha-weighted formula is disabled); it only
    appears in the DEBUG "Scaled_embedding_sim" print. Kept for interface compatibility.
  eid2embed: eid -> embedding vector. Each vector is reshaped to a single row, so both (dim,)
    and (1, dim) arrays are accepted.
  eidpair2PPMI: frozenset({eid_a, eid_b}) -> co-occurrence score.
  embed_dim: kept for backward compatibility; the dimension is taken from the embeddings.
  DEBUG: print the intermediate similarities.

  Returns the overall similarity as a float.
  '''
  def as_row(vec):
    # View an embedding as one float row so cdist always receives 2-D input.
    return np.asarray(vec, dtype=float).reshape(1, -1)

  ## embedding offset similarity
  current_relation_offset = as_row(eid2embed[p_eid]) - as_row(eid2embed[c_eid])
  embedding_sim = 0.0
  if reference_edges:
    # Mean (parent - child) offset over all reference edges, shape (1, dim).
    target_offset = np.mean(
        np.vstack([as_row(eid2embed[edge[0]]) - as_row(eid2embed[edge[1]]) for edge in reference_edges]),
        axis=0, keepdims=True)
    # Cosine similarity is undefined for a zero vector; leave embedding_sim at 0.0 in that case.
    if np.any(target_offset) and np.any(current_relation_offset):
      # cdist returns a (1, 1) matrix: take the element instead of calling float() on the
      # array, which NumPy >= 1.25 deprecates for arrays with ndim > 0.
      embedding_sim = 1.0 - float(distance.cdist(current_relation_offset, target_offset, 'cosine')[0, 0])
  if not embedding_sim > 0:  # clip negative (and NaN) similarities to 0
    embedding_sim = 0.0

  ## docPPMI similarity
  cooccur_sim = eidpair2PPMI.get(frozenset([c_eid, p_eid]), 0.0)

  # overall_sim = ( embedding_sim ** embed_alpha ) * (cooccur_sim ** (1-embed_alpha))
  overall_sim = embedding_sim * math.sqrt(1 + cooccur_sim)
  if DEBUG:
    print("Embedding_sim:", embedding_sim)
    print("Co-occurrence_sim:", cooccur_sim)
    print("Scaled_embedding_sim:", (embedding_sim ** (1 - embed_alpha)))
    print("Overall_sim:", overall_sim)
    print("%.6f" % embedding_sim, "%.6f" % cooccur_sim, "%.6f" % overall_sim)
  return overall_sim


def runSetExpan(seedEidsWithConfidence, numToExpan, verbose=True):
    """Expand a seed entity set with set_expan.setExpan, using the corpus statistics loaded earlier.

    seedEidsWithConfidence: (eid, confidence) pairs to expand from.
    numToExpan: forwarded to setExpan as ``max_expand_eids``.
    verbose: when True (the default, matching the original behaviour) print the raw result and
      one "eid= ... ename= ... confidence_score= ..." line per returned entry.

    Reads notebook-level globals: negativeSeedEids, eid2patterns, pattern2eids,
    eidAndPattern2strength, eid2types, type2eids, eidAndType2strength, eid2ename, eid2embed and
    FLAGS_USE_TYPE.

    Returns setExpan's result unchanged; each entry is indexed as ele[0] (eid) and ele[1]
    (confidence score).
    """
    expandedEidsWithConfidence = set_expan.setExpan(
        seedEidsWithConfidence=seedEidsWithConfidence,
        negativeSeedEids=negativeSeedEids,
        eid2patterns=eid2patterns,
        pattern2eids=pattern2eids,
        eidAndPattern2strength=eidAndPattern2strength,
        eid2types=eid2types,
        type2eids=type2eids,
        eidAndType2strength=eidAndType2strength,
        eid2ename=eid2ename,
        eid2embed=eid2embed,
        source_weights={"sg":1.0, "tp":5.0, "eb":5.0},
        max_expand_eids=numToExpan,
        use_embed=True,
        # Honour the config-cell switch instead of a hard-coded True (FLAGS_USE_TYPE was otherwise unused).
        use_type=FLAGS_USE_TYPE,
        FLAGS_VERBOSE=False,
        FLAGS_DEBUG=False
    )
    if verbose:
        print(expandedEidsWithConfidence)
        for ele in expandedEidsWithConfidence:
            print("eid=", ele[0], "ename=", eid2ename[ele[0]], "confidence_score=", ele[1])

    return expandedEidsWithConfidence


def obtainReferenceEdges(targetNode, max_initial_edges=None, verbose=True):
  '''Collect (sibling_eid, child_eid) reference edges from the siblings of ``targetNode``.

  The siblings are the children of ``targetNode.parent``, including targetNode itself. For each
  sibling, its children are walked in order and every visited edge is appended.
    * An edge where both sibling and child are user-provided does not count toward the cap.
    * Any other edge is counted; the walk of that sibling stops once the count reaches
      ``max_initial_edges``.

  targetNode: a tree node that has a parent.
  max_initial_edges: per-sibling cap on counted edges. When None, the notebook-level
    FLAGS_INITIAL_EDGE is used. It is only looked up when a counted edge is actually seen, so
    trees made only of user-provided edges work without it.
  verbose: print each sibling and child visited (original behaviour).

  Raises ValueError if targetNode has no parent (e.g. ROOT).
  Returns a list of (sibling.eid, child.eid) tuples.
  '''
  parent = targetNode.parent
  if parent is None:
    raise ValueError("obtainReferenceEdges needs a node with a parent; ROOT has no siblings")

  reference_edges = []
  for sibling in parent.children:
    if verbose:
      print(sibling)
    counted = 0
    for cousin in sibling.children:
      if verbose:
        print(cousin)
      reference_edges.append((sibling.eid, cousin.eid))
      if cousin.isUserProvided and sibling.isUserProvided:
        continue  # user-guided edges are always kept and never count toward the cap
      counted += 1
      limit = FLAGS_INITIAL_EDGE if max_initial_edges is None else max_initial_edges
      if counted >= limit:
        break
  return reference_edges


def init():
    """Build the initial taxonomy tree from the notebook-level ``userInput`` seed hierarchy.

    ``userInput[0]`` is treated as the ROOT entry. A ROOT TreeNode (eid=-1, level=-1) is created,
    and every name in its child list becomes a user-provided level-0 node.

    Each later entry ``[parent_ename, parent_level, [child_enames]]`` attaches user-provided
    children under the node already created for ``parent_ename``. The children's level is
    derived from that node (parent.level + 1); the ``parent_level`` field is not used. An entry
    whose parent has not been created yet is reported with an "[ERROR]" message and skipped.

    Side effect: the global ``level2reference_edges`` is cleared, then refilled with one
    (parent_eid, child_eid) edge per child attached by a non-ROOT entry, keyed by the parent's
    level.

    Raises KeyError if an entity name is missing from ``ename2eid``.
    Returns the ROOT TreeNode, or None if ``userInput`` is empty.
    """
    # Start from a clean slate so re-running init() (e.g. re-executing the driver cell) does not
    # append duplicate reference edges to the global.
    level2reference_edges.clear()

    rootNode = None
    ename2treeNode = {}
    for i, node in enumerate(userInput):
        if i == 0:  ## ROOT
            rootNode = TreeNode(parent=None, level=-1, eid=-1, ename="ROOT", isUserProvided=True, confidence_score=0.0,
                                max_children=level2max_children[-1])
            ename2treeNode["ROOT"] = rootNode
            for children in node[2]:
                newNode = TreeNode(parent=rootNode, level=0, eid=ename2eid[children], ename=children, isUserProvided=True,
                                   confidence_score=0.0, max_children=level2max_children[0])
                ename2treeNode[children] = newNode
                rootNode.addChildren([newNode])
        else:
            ename = node[0]
            # Fail fast (KeyError): user supervision must be an entity mention present in ename2eid.
            ename2eid[ename]
            childrens = node[2]
            if ename in ename2treeNode:  # existing node
                parent_treeNode = ename2treeNode[ename]
                childLevel = parent_treeNode.level + 1
                for children in childrens:
                    newNode = TreeNode(parent=parent_treeNode, level=childLevel, eid=ename2eid[children],
                                       ename=children, isUserProvided=True, confidence_score=0.0,
                                       max_children=level2max_children[childLevel])
                    ename2treeNode[children] = newNode
                    parent_treeNode.addChildren([newNode])
                    level2reference_edges[parent_treeNode.level].append((parent_treeNode.eid, newNode.eid))
            else:  # not existing node
                print("[ERROR] disconnected tree node: %s" % node)

    return rootNode


def getNodesByLevel(node, level):
    """Return every node whose ``level`` equals ``level`` in the subtree rooted at ``node``.

    Matches come back in pre-order with children visited left to right. The search does not
    descend below a matching node.
    """
    found = []
    stack = [node]
    while stack:
        current = stack.pop()
        if current.level == level:
            found.append(current)
        else:
            # Push children right-to-left so the leftmost child is popped (visited) first.
            stack.extend(list(current.children)[::-1])
    return found
    

In [5]:
# Build the seed tree from userInput, then grow it level by level with SetExpan.
rootNode = init()
rootNode.printSubtree(0)
# Reference (parent, child) edges, computed once from the siblings of the first level-0 node
# and reused for every level below.
reference_edges = obtainReferenceEdges(rootNode.children[0])


level = -1  # start from root
MAX_LEVEL = 2
targetNode = rootNode
while level < MAX_LEVEL:

    level += 1
    print(level)

    seedEidsWithConfidence = [(child.eid, child.confidence_score) for child in getNodesByLevel(rootNode, level)]
    # Expand 3 ** level entities: 1, 3, 9 for levels 0, 1, 2.
    newOrderedChildrenEidsWithConfidence = runSetExpan(seedEidsWithConfidence, 3 ** level)

    print(newOrderedChildrenEidsWithConfidence)

    # New nodes are only added at `level`, so the candidate parents at level-1 do not change
    # inside the loop below; compute them once.
    candidateParents = getNodesByLevel(rootNode, level - 1) if level > 0 else []

    for ele in newOrderedChildrenEidsWithConfidence:
        newChildEid = ele[0]
        confidence_score = ele[1]

        if level == 0:  # first level, just append them to root
            newChild = TreeNode(parent=rootNode, level=level, eid=newChildEid, ename=eid2ename[newChildEid],
                                isUserProvided=False, confidence_score=confidence_score,
                                max_children=level2max_children[level])
            rootNode.addChildren([newChild])
            continue

        # calculate best parent to attach to: highest strictly positive similarity wins
        max_sim = 0
        best_parent = None
        for p in candidateParents:
            sim = sim_par_new(p.eid, newChildEid, reference_edges, 0.5, eid2embed, eidpair2PPMI, embed_dim=100, DEBUG=False)
            print(p.ename, ":", eid2ename[newChildEid], sim)
            if sim > max_sim:
                max_sim = sim
                best_parent = p

        # Guard before dereferencing: no parent may have a positive similarity.
        if best_parent is None:
            print("=== no parent with positive similarity for", eid2ename[newChildEid], "- skipping")
            continue

        print("=== attaching", eid2ename[newChildEid], " to ", best_parent.ename)
        newChild = TreeNode(parent=best_parent, level=level, eid=newChildEid, ename=eid2ename[newChildEid],
                            isUserProvided=False, confidence_score=confidence_score,
                            max_children=level2max_children[level])
        best_parent.addChildren([newChild])


rootNode.printSubtree(0)

ROOT  (eid=-1, log_prob=0.000000)
	machine learning  (eid=8723, log_prob=0.000000)
		supervised machine learning  (eid=15042, log_prob=0.000000)
			support vector machines  (eid=15066, log_prob=0.000000)
		unsupervised learning  (eid=16106, log_prob=0.000000)
		reinforcement learning  (eid=12815, log_prob=0.000000)
	data mining  (eid=3362, log_prob=0.000000)
		pattern mining  (eid=11186, log_prob=0.000000)
		graph mining  (eid=6257, log_prob=0.000000)
		web mining  (eid=16714, log_prob=0.000000)
		text mining  (eid=15448, log_prob=0.000000)
	information retrieval  (eid=7348, log_prob=0.000000)
		document retrieval  (eid=4144, log_prob=0.000000)
		query processing  (eid=12403, log_prob=0.000000)
		relevance feedback  (eid=12886, log_prob=0.000000)
machine learning (eid=8723,log_prob=0.000000,parent=ROOT)
supervised machine learning (eid=15042,log_prob=0.000000,parent=machine learning)
unsupervised learning (eid=16106,log_prob=0.000000,parent=machine learning)
reinforcement learning (eid

[[8369, 0.0], [10022, 0.0], [10231, 0.0], [7816, -0.066837228701791096], [7819, -0.088544323378328599], [12494, 0.0], [8562, -0.062026353782166366], [8335, 0.0], [1064, -0.01475871170787668], [3490, -0.068937091643396425], [10023, -0.12595076182689699], [7814, -0.10041115016353214], [2171, -0.15038368124242413], [702, -0.26432780113263504], [13108, -0.1812655451779529], [9790, -0.16460656466978082], [9007, -0.15474051735125166], [8676, -0.17957688838816005], [3493, -0.24195959687456589], [12781, -0.21852516087747248], [10114, -0.22436840076812634], [12798, -0.24547904343039118], [11942, -0.21655860038836772], [7850, -0.22863959118177107], [10230, -0.24763978041197454], [8192, -0.22938079654935489], [1074, -0.40802661859348277], [15068, -0.28221397436417128], [1766, -0.29621842871454074], [1063, -0.35231387887423737], [2164, -0.30560512283502272], [16304, -0.30981703538719874], [12892, -0.28894929687073306], [2175, -0.28032224001020917], [4618, -0.29765639960347728], [929, -0.2782447051

unsupervised learning : linear discriminant analysis 0.13943266847293384
reinforcement learning : linear discriminant analysis 0.13965342937604552
document classification : linear discriminant analysis 0.180537972499909
brain-computer interface : linear discriminant analysis 0.32029939574608385
spam detection : linear discriminant analysis 0.01575947570565006
pattern mining : linear discriminant analysis 0.044305735995656126
graph mining : linear discriminant analysis 0.10243718105965094
web mining : linear discriminant analysis 0.35674725996986356
text mining : linear discriminant analysis 0.3449157715377116
mobile robotics : linear discriminant analysis 0.1509466218952915
web usage mining : linear discriminant analysis 0.20104803490280854
rule discovery : linear discriminant analysis 0.05082820848931091
microarray experiments : linear discriminant analysis 0.08691022649935311
document retrieval : linear discriminant analysis 0.13683321302413687
query processing : linear discriminant 

web usage mining : ensemble classifier 0.45646625923498685
rule discovery : ensemble classifier 0.2792750901681078
microarray experiments : ensemble classifier 0.19161563100141554
document retrieval : ensemble classifier 0.33701379305475787
query processing : ensemble classifier 0.0629893894363317
relevance feedback : ensemble classifier 0.12253494474573012
query expansion : ensemble classifier 0.27906343934344746
text clustering : ensemble classifier 0.3815567432574609
document analysis : ensemble classifier 0.38328103249786993
video retrieval : ensemble classifier 0.2934860205610469
robot control : ensemble classifier 0.25263487830417763
surveillance systems : ensemble classifier 0.25583137992812377
gait recognition : ensemble classifier 0.10503988413120369
target tracking : ensemble classifier 0.1074129140366149
=== attaching ensemble classifier  to  text mining
supervised machine learning : back-propagation neural network 0.0968552476926805
unsupervised learning : back-propagation 