In [1]:
import sqlite3
import json
import re
import pandas as pd
import numpy as np
from collections import Counter
from queue import Queue
from itertools import combinations as combs

In [2]:
def conll2graph(record):
    """Converts sentences described using CoNLL-U format 
    (http://universaldependencies.org/format.html) to graphs. 
    Returns a dictionary of nodes (wordforms and POS tags indexed 
    by line numbers) together with a graph of the dependencies encoded 
    as adjacency lists of (node_key, relation_label, direction[up or down]) tuples."""
    graph = {}
    nodes = {}
    for line in record.splitlines():
        if line.startswith('#'):
            continue
        fields = line.strip('\n').split('\t')
        key = fields[0]
        # Ignore compound surface keys for aux, du, etc.
        # Ignore hidden additional nodes for orphan handling
        if '-' in key or '.' in key:
            continue
        wordform = fields[1] 
        pos = fields[3]
        parent = fields[6]
        relation = fields[7]
        nodes[key] = {
            'wordform': wordform,
            'pos': pos,
            'relation': relation,
            'parent': parent
        }
        if key not in graph:
            graph[key] = []
        if parent not in graph:
            graph[parent] = []
        graph[key].append((parent, relation, 'up'))
        graph[parent].append((key, relation, 'down'))
    return (nodes, graph)

In [3]:
def get_node_depth(node, graph):
    """A BFS-based implementation."""
    cur_depth = 0
    q = Queue()
    q.put(('0',0))
    visited = set()
    visited.add('0')
    while not q.empty():
        current_node, current_depth = q.get()
        for neighbour, *_ in graph[current_node]:
            if neighbour == node:
                return current_depth+1
            elif neighbour not in visited:
                q.put((neighbour, current_depth+1))
            visited.add(neighbour)
    raise IndexError("Target node unreachable")

In [4]:
def highest_or_none(indices, graph):
    if indices[0] == 'X':
        return None
    min_depth = 1000
    argmin = None
    for i in indices:
        key = str(i)
        depth = get_node_depth(key, graph)
        if depth < min_depth:
            min_depth = depth
            argmin = key
    assert argmin is not None
    return argmin

In [5]:
def get_path(node1, node2, graph):
    if node1 == node2:
        return []
    
    # BFS with edge labels for paths
    q = Queue()
    # Remembers where we came from and the edge label
    sources = {}
    
    q.put(node1)
    visited = set()
    visited.add(node1)
    
    while not q.empty():
        current = q.get()
        for neighbour, relation, direction in graph[current]:
            if neighbour == node2:
                path = [relation+'_'+direction]
                source = current
                while source != node1:
                    prev_node, prev_relation, prev_direction = sources[source]
                    path.append(prev_relation+'_'+prev_direction)
                    source = prev_node
                return list(reversed(path))
            elif neighbour not in visited:
                sources[neighbour] = (current, relation, direction)
                q.put(neighbour)
            visited.add(neighbour)
            
    raise ValueError("UD graph is not connected.")

In [32]:
from math import log2
def mutual_information(counter):
    """Returns MI and NMI based on (Kvalseth 1987)"""
    total = sum(counter.values())
    joint_probabilities = {
        pair: count / total for pair, count in counter.items()
    }
    X_marginal_counts = Counter()
    Y_marginal_counts = Counter()
    for (head, tail), count in counter.items():
        X_marginal_counts[head] += count
        Y_marginal_counts[tail] += count
    X_marginals = {
        x: count / total for x, count in X_marginal_counts.items()
    }
    Y_marginals = {
        y: count / total for y, count in Y_marginal_counts.items()
    }
    MI = 0
    # X and Y values that don't occur together
    # contribute 0 to MI and by convention can be
    # ignored.
    for pair, count in counter.items():
        head, tail = pair
        MI += joint_probabilities[pair] * log2( 
            joint_probabilities[pair] / (X_marginals[head] * Y_marginals[tail])
        )
    # Normalise by dividing by the maximum marginal entropy
    X_marginal_entropy = sum(
        -1 * p * log2(p) for p in X_marginals.values()
    )
    Y_marginal_entropy = sum(
        -1 * p * log2(p) for p in Y_marginals.values()
    )
    NMI = MI / max(X_marginal_entropy, Y_marginal_entropy)
    return (MI, NMI)

In [7]:
strip_direction = lambda x: x.split('_')[0]

In [11]:
processed_indices = []

In [16]:
def get_pos_edge_pair_counts(en, ru, alignments):
    global processed_indices # For debugging
    processed_indices = []
    pos_pairs = Counter()
    edge_pairs = Counter()
    for i in range(len(en)):
        processed_indices.append(i)
        en_n, en_g = conll2graph(en[i])
        ko_n, ko_g = conll2graph(ru[i])
        alignment = alignments[i]
        # Simplify the alignment to a set of one-to-one pairs
        one_to_one = []
        for k, v in alignment.items():
            if k == 'X':
                # Do not analyse stuff added on the Ko side for now
                continue
            head = k
            tail = str(highest_or_none(v, ko_g))
            one_to_one.append((head, tail))
        # POS joint distribution
        for pair in one_to_one:
            head, tail = pair
            # Skip technical additional nodes
            if '.' in head:
                continue
            try:
                en_pos = en_n[head]['pos']
            except KeyError:
                print(i, en[i])
                continue
            if tail == 'None':
                ko_pos = 'None'
            else:
                ko_pos = ko_n[tail]['pos']
            pos_pair = (en_pos, ko_pos)
            pos_pairs[pos_pair] += 1
        # Edge label joint distribution
        for pair in combs(one_to_one, 2):
            (en_head, ko_head), (en_tail, ko_tail) = pair
            # Skip technical additional nodes
            if '.' in head:
                continue
            en_path_arr = get_path(en_head, en_tail, en_g)
            if len(en_path_arr) > 1:
                continue
            en_path = strip_direction(en_path_arr[0])
            if ko_head == ko_tail:
                ko_path = 'Nodes collapsed'
            elif ko_head == 'None' and ko_tail == 'None':
                ko_path = 'Both endpoints unaligned'
            elif ko_head == 'None' or ko_tail == 'None':
                ko_path = 'One endpoint unaligned'
            else:
                ko_path_arr = get_path(ko_head, ko_tail, ko_g)
                ko_path = '->'.join(
                    list(map(strip_direction, ko_path_arr))
                )
            path_pair = (en_path, ko_path)
            edge_pairs[path_pair] += 1
    return (pos_pairs, edge_pairs)

In [9]:
conn = sqlite3.connect('pud_current.db')
cursor = conn.cursor()

In [10]:
def get_data(lang_code):
    en = []
    ko = []
    alignments = []
    for en_, ko_, alignment_str in cursor.execute(
        f'SELECT `en`, `ru`, `alignment` FROM `en-{lang_code}` WHERE `verified` = 1'
    ):
        en.append(en_)
        ko.append(ko_)
        alignments.append(json.loads(alignment_str))
    return (en, ko, alignments)

## En-Ru

In [17]:
en, ru, alignments = get_data('ru')

In [18]:
pos_pairs, edge_pairs = get_pos_edge_pair_counts(en, ru, alignments)

In [19]:
pos_pairs.most_common(5)

[(('NOUN', 'NOUN'), 3265),
 (('VERB', 'VERB'), 1630),
 (('PROPN', 'PROPN'), 1316),
 (('ADJ', 'ADJ'), 1080),
 (('ADV', 'ADV'), 404)]

In [20]:
edge_pairs.most_common(5)

[(('amod', 'amod'), 927),
 (('nsubj', 'nsubj'), 731),
 (('nmod', 'nmod'), 604),
 (('obl', 'obl'), 555),
 (('obj', 'obj'), 422)]

In [33]:
mutual_information(pos_pairs)

(1.5175937518573266, 0.5215017877509472)

In [34]:
mutual_information(edge_pairs)

(2.317625354988861, 0.4407009173264683)

## En-Fr

In [23]:
en_fr, fr, alignments_fr = get_data('fr')
pos_pairs_fr, edge_pairs_fr = get_pos_edge_pair_counts(en_fr, fr, alignments_fr)

In [35]:
mutual_information(pos_pairs_fr)

(1.7749031148353012, 0.599195428397317)

In [36]:
mutual_information(edge_pairs_fr)

(2.617543767359971, 0.5031209518792668)

## En-Zh

In [26]:
en_zh, zh, alignments_zh = get_data('zh')
pos_pairs_zh, edge_pairs_zh = get_pos_edge_pair_counts(en_zh, zh, alignments_zh)

In [37]:
mutual_information(pos_pairs_zh)

(1.3029629775764606, 0.4280096086028898)

In [38]:
mutual_information(edge_pairs_zh)

(2.1020000349804215, 0.34542813568916303)

## En-Ko

In [29]:
en_ko, ko, alignments_ko = get_data('ko')
pos_pairs_ko, edge_pairs_ko = get_pos_edge_pair_counts(en_ko, ko, alignments_ko)

In [39]:
mutual_information(pos_pairs_ko)

(0.9059525985449439, 0.3171634407444342)

In [40]:
mutual_information(edge_pairs_ko)

(1.8197131777245381, 0.316649169136819)