Skip to content

Commit

Permalink
Fix lint errors.
Browse files: browse the repository at this point in the history
  • Loading branch information
emfomy committed Apr 23, 2020
1 parent 60e14e4 commit 752301b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 92 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ disable =
too-many-ancestors,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,

[FORMAT]
Expand Down
109 changes: 17 additions & 92 deletions ckipnlp/driver/coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from treelib import (
Tree as _Tree,
Node as _Node,
)

from ckipnlp.container import (
Expand All @@ -25,10 +24,6 @@

from ckipnlp.data.parsed import (
APPOSITION_ROLES as _APPOSITION_ROLES,
HUMAN_ROLES as _HUMAN_ROLES,
NEUTRAL_ROLES as _NEUTRAL_ROLES,
OBJECT_ROLES as _OBJECT_ROLES,
SUBJECT_ROLES as _SUBJECT_ROLES,
)

from ckipnlp.data.coref import (
Expand All @@ -45,37 +40,6 @@

################################################################################################################################

from colored import * # pylint: disable=wrong-import-order

def _print_styled(style, args):
    """Print ``args`` as one space-joined line rendered with ``style``.

    Shared implementation for the ``print_*`` debug helpers below.

    ``style`` is whatever the ``fg`` / ``bg`` / ``attr`` helpers return (they
    are expected to come from the ``colored`` wildcard import above); it is
    passed to ``stylize`` unchanged. ``args`` is an iterable of arbitrary
    objects; each one is converted with ``str`` before joining.
    """
    text = ' '.join(map(str, args))
    print(stylize(text, style))  # pylint: disable=no-value-for-parameter


def print_spam(*args, **kwargs):  # pylint: disable=unused-argument
    """Print ``args`` in dim magenta.

    Keyword arguments are accepted and ignored, as before, so existing callers
    that pass them keep working.
    """
    _print_styled(fg('magenta') + attr('dim'), args)


def print_debug(*args):
    """Print ``args`` in blue."""
    _print_styled(fg('blue'), args)


def print_verbose(*args):
    """Print ``args`` in magenta."""
    _print_styled(fg('magenta'), args)


def print_info(*args):
    """Print ``args`` in cyan."""
    _print_styled(fg('cyan'), args)


def print_notice(*args):
    """Print ``args`` in bold cyan."""
    _print_styled(fg('cyan') + attr('bold'), args)


def print_warning(*args):
    """Print ``args`` in bold yellow."""
    _print_styled(fg('yellow') + attr('bold'), args)


def print_success(*args):
    """Print ``args`` in bold green."""
    _print_styled(fg('green') + attr('bold'), args)


def print_error(*args):
    """Print ``args`` in bold red."""
    _print_styled(fg('red') + attr('bold'), args)


def print_fatal(*args):
    """Print ``args`` in bold text on a red background."""
    _print_styled(bg('red') + attr('bold'), args)

################################################################################################################################

class CkipCorefChunker(_BaseDriver): # pylint: disable=too-few-public-methods
"""The CKIP co-reference driver."""

Expand All @@ -94,9 +58,6 @@ def _call(self, *, parsed):
# Get results
coref = self._get_result(tree_list, coref_tree=coref_tree)

for line in coref:
print_success(line.to_text())

return coref

def _init(self):
Expand All @@ -105,11 +66,8 @@ def _init(self):
@classmethod
def _get_coref(cls, tree_list):

node2coref_list = [{} for _ in tree_list] # tree_id => {node_id => ref_id}
coref2node = {} # ref_id => (tree_id, node_id)

coref_tree = _Tree()
coref_tree.create_node(tag='@', identifier=0)
coref_tree.create_node(identifier=0)

name2node = {} # name => (tree_id, node_id)

Expand All @@ -120,34 +78,21 @@ def _get_coref(cls, tree_list):

# Find coref
for tree_id, tree in enumerate(tree_list):
print_error()
print_error('='*8, tree_id, '='*8)

print_info('curr_source', curr_source)
print_info('curr_subject', curr_subject)
print_info('last_source', last_source)
print_info('last_subject', last_subject)

tree.show()

# Get relations
appositions = []
for rel in tree.get_relations():
print_notice(rel)
if rel.relation.data.role in _APPOSITION_ROLES:
appositions.append((rel.head.identifier, rel.tail.identifier,))

# Get sources/targets
node_ids = {}
for nid in cls._get_sources(tree): # Source
node_ids[nid] = 'Src'
print_success('Src', tree[nid] if nid >= 0 else nid)
for nid in cls._get_subjects(tree): # Human
node_ids[nid] = 'Sub'
print_warning('Sub', tree[nid] if nid >= 0 else nid)
for nid in cls._get_targets(tree): # Target
node_ids[nid] = 'Tgt'
print_notice('Tgt', tree[nid] if nid >= 0 else nid)

source_ids = {nid: ntype for nid, ntype in node_ids.items() if ntype != 'Tgt'}
target_ids = {nid: ntype for nid, ntype in node_ids.items() if ntype == 'Tgt'}
Expand All @@ -162,25 +107,21 @@ def _get_coref(cls, tree_list):

parent_id = name2node.get(source.data.word, None)
if parent_id:
coref_tree.create_node(tag=source.tag, identifier=(tree_id, sid,), parent=parent_id, data=True)
coref_tree.create_node(identifier=(tree_id, sid,), parent=parent_id, data=True)
else:
name2node[source.data.word] = curr_source
coref_tree.create_node(tag=source.tag, identifier=(tree_id, sid,), parent=coref_tree.root, data=True)
coref_tree.create_node(identifier=(tree_id, sid,), parent=coref_tree.root, data=True)

# Link targets to previous sources
for tid, ttype in target_ids.items():
for tid in target_ids:
if tid < 0 and last_subject:
coref_tree.create_node(tag=str(tid), identifier=(tree_id, tid,), parent=last_subject, data=False)
coref_tree.create_node(identifier=(tree_id, tid,), parent=last_subject, data=False)

if tid >= 0:
target = tree[tid]
if curr_source and tree[tid].data.word in _SELF_WORDS:
coref_tree.create_node(tag=target.tag, identifier=(tree_id, tid,), parent=curr_source, data=False)
coref_tree.create_node(identifier=(tree_id, tid,), parent=curr_source, data=False)
elif last_source:
coref_tree.create_node(tag=target.tag, identifier=(tree_id, tid,), parent=last_source, data=False)

coref_tree.show(key=lambda node: node.identifier, idhidden=False)
print()
coref_tree.create_node(identifier=(tree_id, tid,), parent=last_source, data=False)

for head_id, tail_id in appositions:
head_id = (tree_id, head_id,)
Expand All @@ -202,9 +143,6 @@ def _get_coref(cls, tree_list):
last_source = curr_source
last_subject = curr_subject

coref_tree.show(key=lambda node: node.identifier, idhidden=False)
print()

return coref_tree


Expand All @@ -216,35 +154,22 @@ def _get_result(cls, tree_list, *, coref_tree):
coref2node = {} # ref_id => node

for ref_id, coref_source in enumerate(coref_tree.children(coref_tree.root)):
print_success(ref_id, coref_source)
tree_id, node_id = coref_source.identifier
coref2node[ref_id] = tree_list[tree_id][node_id]
for tree_id, node_id in coref_tree.expand_tree(coref_source.identifier):
node2coref[tree_id, node_id] = ref_id

for (k1, k2,), r in node2coref.items():
if k2 >= 0:
print_notice(tree_list[k1][k2], r)
else:
print_notice((k1, k2,), r)
print()

# Generate result
tokens_list = _CorefParagraph()

for tree_id, tree in enumerate(tree_list):
tokens = _CorefSentence()
tokens_list.append(tokens)

nodes = tree.leaves()

print_error()
print_error('='*8, tree_id, '='*8)
print_notice(node2coref)
print_verbose(nodes)

if (tree_id, -1) in node2coref:
ref_id = node2coref[tree_id, -1]
tokens.append(_CorefToken(
tokens.append(_CorefToken( # pylint: disable=no-value-for-parameter
word=None,
idx=None,
coref=(ref_id, 'zero'),
Expand All @@ -253,30 +178,30 @@ def _get_result(cls, tree_list, *, coref_tree):
elif (tree_id, -2) in node2coref:
# The pos of the first leaf node starts with 'Cb'. e.g. 而且、但是、然而
node = nodes.pop(0)
tokens.append(_CorefToken(
tokens.append(_CorefToken( # pylint: disable=no-value-for-parameter
word=node.data.word,
idx=node.identifier,
coref=None,
))

ref_id = node2coref[tree_id, -2]
tokens.append(_CorefToken(
tokens.append(_CorefToken( # pylint: disable=no-value-for-parameter
word=None,
idx=None,
coref=(ref_id, 'zero'),
))

for node in nodes:
ref_id = node2coref.get((tree_id, node.identifier,), -1)
if ref_id >=0:
if ref_id >= 0:
ref_node = coref2node[ref_id]
tokens.append(_CorefToken(
tokens.append(_CorefToken( # pylint: disable=no-value-for-parameter
word=node.data.word,
idx=node.identifier,
coref=(ref_id, 'source' if node.identifier == ref_node.identifier else 'target',),
))
else:
tokens.append(_CorefToken(
tokens.append(_CorefToken( # pylint: disable=no-value-for-parameter
word=node.data.word,
idx=node.identifier,
coref=None,
Expand Down Expand Up @@ -314,9 +239,9 @@ def transform_pos(*, ws, pos, ner):

for line_ws, line_pos, line_ner in zip(ws, pos, ner):
idxmap = {idx: i for i, idx in enumerate(_np.cumsum(list(map(len, line_ws))))}
for ner in line_ner:
if ner.ner == 'PERSON':
line_pos[idxmap[ner.idx[1]]] = 'Nb'
for token in line_ner:
if token.ner == 'PERSON':
line_pos[idxmap[token.idx[1]]] = 'Nb'

########################################################################################################################

Expand Down

0 comments on commit 752301b

Please sign in to comment.