In [97]:
#!/usr/bin/env python3
#coding: utf-8

import sys
from collections import defaultdict
import numpy as np

# parameters
#source_filename, target_filename, alignment_filename = sys.argv[1:4]

# number of sentences -- in PUD it is always 1000
SENTENCES = 1000

# field indexes
ID = 0
FORM = 1
LEMMA = 2
UPOS = 3
XPOS = 4
FEATS = 5
HEAD = 6
DEPREL = 7

# returns dict[source_id] = [target_id_1, target_id_2, target_id_3...]
# and a reverse one as well
# TODO depending on what type of alignment you use, you may not need to have a list of aligned tokens -- maybe there is at most one, or even exactly one?
def read_alignment(fh):
    """Parse the next alignment line from *fh* into two lookup tables.

    One line is consumed from *fh*. It is split on whitespace into
    ``SRC-TGT`` items, both parts being integer token indexes.

    Returns a tuple ``(src2tgt, tgt2src)`` of ``defaultdict(list)`` objects:
    the first maps every source index to the target indexes aligned to it
    (in order of appearance), the second is the reverse mapping. Looking up
    an unaligned index yields an empty list.
    """
    forward = defaultdict(list)
    backward = defaultdict(list)
    for pair in fh.readline().split():
        left, right = pair.split('-')
        src_id, tgt_id = int(left), int(right)
        forward[src_id].append(tgt_id)
        backward[tgt_id].append(src_id)
    return (forward, backward)

# returns a list of tokens, where each token is a list of fields;
# ID and HEAD are converted to integers and switched from 1-based to 0-based
# if delete_tree=True, then syntactic annotation (HEAD and DEPREL) is stripped
def read_sentence(fh, delete_tree=False):
    """Read the next sentence from *fh* as a list of tokens.

    Lines are consumed until a blank line (end of sentence) or the end of
    the file. Comment lines starting with ``#`` are ignored, and so are
    lines whose ID field is not a plain number (the "special" tokens).

    Each returned token is the list of tab-separated fields of its line,
    with ID converted to a 0-based integer. HEAD is converted to a 0-based
    integer as well, so the root is represented by -1.

    If *delete_tree* is true, the syntactic annotation is dropped: HEAD is
    set to -1 and DEPREL to ``'dep'``. In that case the HEAD column of the
    input is not parsed at all, so files without a tree (HEAD ``_``) can
    be read.
    """
    sentence = []
    for line in fh:
        if not line.strip():
            # end of sentence (also tolerates a separator line that
            # carries stray whitespace)
            break
        if line.startswith('#'):
            # ignore comments
            continue
        fields = line.strip().split('\t')
        if not fields[ID].isdigit():
            # special token -- skip it
            continue
        # make IDs 0-based to match alignment IDs
        fields[ID] = int(fields[ID]) - 1
        if delete_tree:
            # reasonable defaults; the original HEAD value is discarded,
            # so it must not be required to be numeric
            fields[HEAD] = -1       # head = root
            fields[DEPREL] = 'dep'  # generic deprel
        else:
            fields[HEAD] = int(fields[HEAD]) - 1
        sentence.append(fields)
    return sentence

# takes list of lists as input, ie as returned by read_sentence()
# switches ID and HEAD back to 1-based and converts them to strings
# joins fields by tabs and tokens by endlines and returns the CONLL string
def write_sentence(sentence):
    """Serialize *sentence* (as returned by read_sentence()) to a CONLL string.

    ID and HEAD are switched back to 1-based values and converted to
    strings; the fields of a token are joined by tabs and the tokens by
    newlines. The returned string ends with a newline.

    The input is left untouched: the conversion is done on a copy of each
    token, so the same sentence can be written repeatedly or processed
    further afterwards.
    """
    result = []
    for fields in sentence:
        # work on a copy so the caller's 0-based integer IDs stay intact
        row = list(fields)
        # switch back to 1-based IDs
        row[ID] = str(fields[ID] + 1)
        row[HEAD] = str(fields[HEAD] + 1)
        result.append('\t'.join(row))
    result.append('')
    return '\n'.join(result)

## depending on the POS of the target token, get the most likely POS tags of its head
def get_head_pos(target_token_pos):
    """Return the list of POS tags preferred for the head of a token.

    *target_token_pos* is the POS tag of the dependent token. Determiners
    and adjectives prefer a NOUN head, adverbs accept VERB, ADJ or ADV, and
    every other tag (AUX and NOUN included) prefers a VERB head. A new list
    is returned on every call.
    """
    if target_token_pos in ("DET", "ADJ"):
        return ["NOUN"]
    if target_token_pos == "ADV":
        return ["VERB", "ADJ", "ADV"]
    # AUX, NOUN and all remaining tags share the same preference
    return ["VERB"]

## walk down from a token through its recorded children to
## make sure none of its descendants leads back to a token already seen (a cycle)
def detect_circle(visited, node, children_trace):
    """Tell whether the walk down from *node* meets a node a second time.

    *children_trace* maps a node to the list of its children and *visited*
    is the set of nodes seen so far; it is updated in place. The walk
    follows the children depth-first and stops at the first repeated node.

    Returns True if *node* or any node reachable from it was already in
    *visited*, False otherwise. When every node has at most one parent
    (as in a dependency tree), a repeated node means a cycle.
    """
    if node in visited:
        return True
    visited.add(node)
    if node not in children_trace:
        return False
    # the verdict of the recursive calls has to be passed up, otherwise
    # only the starting node itself could ever be reported
    return any(
        detect_circle(visited, child, children_trace)
        for child in children_trace[node]
    )


def _creates_cycle(target_sentence, dependent_id, head_id):
    """Tell whether attaching *dependent_id* under *head_id* closes a loop.

    The children map is rebuilt from the HEAD fields currently stored in
    *target_sentence* (tokens with a negative HEAD hang on the root), with
    the old attachment of the dependent replaced by the proposed one.
    """
    children = defaultdict(list)
    for token in target_sentence:
        if token[ID] != dependent_id and token[HEAD] >= 0:
            children[token[HEAD]].append(token[ID])
    children[head_id].append(dependent_id)
    return detect_circle(set(), head_id, children)


def _choose_head(target_token, target_sentence, potential_heads,
                 source_sentence, source_head_id):
    """Pick the 0-based head of *target_token* from the projected candidates.

    *potential_heads* are the target tokens aligned to the head of the
    source token (source token number *source_head_id*). A single candidate
    is taken as it is. Among several candidates, only those at most half a
    sentence away are accepted, tried in this order: the first candidate
    with the same POS as the source head, a candidate whose POS is in
    get_head_pos() of the target token, any candidate that is not PUNCT.
    If nothing fits, the following token is used; the last token of the
    sentence has no following token and becomes the root (-1).
    """
    target_token_id = target_token[ID]
    sentence_length = len(target_sentence)
    max_distance = sentence_length // 2

    def close_enough(candidate):
        return abs(candidate - target_token_id) <= max_distance

    # target_token_id + 1 would point behind the sentence for the last token
    next_token_id = target_token_id + 1
    fallback = next_token_id if next_token_id < sentence_length else -1

    if not potential_heads:
        return fallback
    if len(potential_heads) == 1:
        return potential_heads[0]

    source_head_pos = source_sentence[source_head_id][UPOS]
    same_pos = [candidate for candidate in potential_heads
                if target_sentence[candidate][UPOS] == source_head_pos]
    # the distance limit applies to the candidate that is actually chosen
    if same_pos and close_enough(same_pos[0]):
        return same_pos[0]

    ideal_head_pos = get_head_pos(target_token[UPOS])
    for candidate in potential_heads:
        if target_sentence[candidate][UPOS] in ideal_head_pos and close_enough(candidate):
            return candidate

    # just choose the first one that is not PUNCT
    for candidate in potential_heads:
        if target_sentence[candidate][UPOS] != "PUNCT" and close_enough(candidate):
            return candidate

    # all are PUNCT or too far away
    return fallback


def _project_sentence(source_sentence, target_sentence, src2tgt, tgt2src):
    """Project DEPREL and HEAD from the source onto the target tokens.

    *target_sentence* is modified in place. Every target token is visited
    once for each source token aligned to it; a later alignment overrides
    the head set by an earlier one unless it would create a cycle.
    """
    sentence_length = len(target_sentence)
    for target_token in target_sentence:
        target_token_id = target_token[ID]
        # for each source token aligned to the target token (if any)
        for source_token_id in tgt2src.get(target_token_id, ()):
            source_token = source_sentence[source_token_id]

            # NOTE(review): the deprel is copied only when the POS tags
            # DIFFER; kept as it was, but confirm that `!=` is intended
            if source_token[UPOS] != target_token[UPOS]:
                target_token[DEPREL] = source_token[DEPREL]

            # a verb among the last five tokens is a root candidate;
            # heads are 0-based here, so the root is -1 (it is written
            # out as 0), while 0 would mean the first token
            # NOTE(review): the head chosen below still overrides this
            # unless it is rejected because of a cycle
            if target_token_id >= sentence_length - 5 and target_token[UPOS] == "VERB":
                print(sentence_length, target_token_id)
                target_token[HEAD] = -1
                print("root:", target_token_id)

            # IDs of all target tokens aligned to the head of the source
            # token (may be empty, e.g. for the source root)
            source_head_id = source_token[HEAD]
            potential_heads = src2tgt.get(source_head_id, ())
            the_head = _choose_head(target_token, target_sentence,
                                    potential_heads, source_sentence,
                                    source_head_id)

            # keep the previous head if the new one would lead back to
            # the token itself
            if not _creates_cycle(target_sentence, target_token_id, the_head):
                target_token[HEAD] = the_head


def _project_corpus(source_filename, target_filename, alignment_filename):
    """Project the source trees onto the target file, sentence by sentence.

    Reads SENTENCES sentences from both CONLL files together with one
    alignment line per sentence and returns the list of projected target
    sentences as CONLL strings.
    """
    outputs = []
    # CONLL-U files are UTF-8; do not depend on the platform default
    with open(source_filename, encoding="utf-8") as source, \
            open(target_filename, encoding="utf-8") as target, \
            open(alignment_filename, encoding="utf-8") as alignment:
        for _ in range(SENTENCES):
            (src2tgt, tgt2src) = read_alignment(alignment)
            source_sentence = read_sentence(source)
            target_sentence = read_sentence(target, delete_tree=True)
            _project_sentence(source_sentence, target_sentence, src2tgt, tgt2src)
            outputs.append(write_sentence(target_sentence))
    return outputs


source_filename = "zh_pud-ud-test.conllu"
target_filename = "ja_pud-ud-test.conllu"
alignment_filename = "zh-ja.fwd"

if __name__ == "__main__":
    # `outputs` stays a module-level name so that it can be written to a
    # file afterwards
    outputs = _project_corpus(source_filename, target_filename, alignment_filename)



46 42
root: 42
21 19
root: 19
40 37
root: 37
34 31
root: 31
26 21
root: 21
33 30
root: 30
21 17
root: 17
28 25
root: 25
32 29
root: 29
25 20
root: 20
18 15
root: 15
19 17
root: 17
37 34
root: 34
23 20
root: 20
20 15
root: 15
37 34
root: 34
15 11
root: 11
23 19
root: 19
23 19
root: 19
42 38
root: 38
31 26
root: 26
26 23
root: 23
23 19
root: 19
23 19
root: 19
36 32
root: 32
29 25
root: 25
29 26
root: 26
32 28
root: 28
13 9
root: 9
19 15
root: 15
25 21
root: 21
21 18
root: 18
42 37
root: 37
42 40
root: 40
27 24
root: 24
32 28
root: 28
42 39
root: 39
22 19
root: 19
18 13
root: 13
19 17
root: 17
53 49
root: 49
19 15
root: 15
31 29
root: 29
26 23
root: 23
23 19
root: 19
13 10
root: 10
21 16
root: 16
25 22
root: 22
40 37
root: 37
35 32
root: 32
17 14
root: 14
33 29
root: 29
27 23
root: 23
36 33
root: 33
30 25
root: 25
33 30
root: 30
28 25
root: 25
19 17
root: 17
24 20
root: 20
26 24
root: 24
27 23
root: 23
17 14
root: 14
54 49
root: 49
25 21
root: 21
52 48
root: 48
42 40
root: 40
32 29
root: 

In [98]:
# Store the projected sentences; each one is followed by a newline, which
# together with the newline ending the sentence itself gives the blank
# line that separates sentences.
with open("ja_projected.conllu", "w", encoding="utf-8") as out_file:
    out_file.writelines(str(entry) + "\n" for entry in outputs)