In [1]:
from const import CONSTS
from util import get_int_after_underscore, get_str_after_underscore, is_pairwise_disjoint

In [2]:
import xml.etree.ElementTree as ET
from pydantic import BaseModel, root_validator, Field

In [3]:
# Load the TIGER 2.2 treebank XML. The context manager closes the file handle
# as soon as parsing finishes (the original passed a bare open() that was never closed).
with open(f"{CONSTS['data_dir']}/tiger/tiger_2.2_utf8.xml", "r", encoding="utf-8") as tiger_file:
    document = ET.parse(tiger_file)

In [5]:
class Terminal(BaseModel):
    """
    A terminal node of a sentence.

    Instances are built from the attributes of a `<t>` element plus the
    terminal's 1-based position in the sentence. Every field is required.
    """

    # attributes copied from the `<t>` element
    word: str
    lemma: str
    pos: str
    morph: str
    case: str
    number: str
    gender: str
    person: str
    degree: str
    tense: str
    mood: str

    # 1-based index of the terminal within its sentence
    idx: int
        

In [6]:
from typing import Any


class Constituent(BaseModel):
    """
    A node of a sentence's constituency tree.

    Nodes refer to each other by integer id (`head`, `parent`, `children`)
    rather than by object reference.
    """

    id: int = Field()              # node's id within the sentence
    head: int | None = Field()     # node's head word TODO: remove None!!
    yld: set[int] = Field()        # node's yield
    sym: str = Field()             # node's symbol

    edge_label: str | None = Field(default=None)  # the label for the edge from this constituent to its parent
    parent: int | None = Field(default=None)      # the id of the parent node

    is_pre_terminal: bool = False
    # ids of the child nodes; default_factory gives every instance its own empty list
    # (the previous `Field(defualt=[])` misspelled the keyword, so no default was set)
    children: list[int] = Field(default_factory=list)

    @root_validator()
    @classmethod
    def _check_children_types(cls, field_values: dict[str, Any]):
        """
        Placeholder for cross-field validation of `children`.

        Currently performs no checks and returns the field values unchanged.
        """
        return field_values

In [39]:
class Sentence:
    def __init__(self, tree_element: ET.Element):
        self.head_rules = { 'S': [('s','HD',[])],
                              'VP': [('s','HD',[])],
                              'VZ': [('s','HD',[])],
                              'AVP':[('s','HD',[]),('s','PH',[]),('r','AVC',['ADV']),('l','AVC',['FM'])],
                              'AP': [('s','HD',[]),('s','PH',[])],
                              'DL': [('s','DH',[])],
                              'AA': [('s','HD',[])],
                              'ISU':[('l','UC',[])],
                              'PN': [('r','PNC',['NE','NN','FM','TRUNC','APPR','APPRART','CARD','VVFIN','VAFIN','ADJA','ADJD','XY'])],
                              'MPN': [('r','PNC',['NE','NN','FM','TRUNC','APPR','APPRART','CARD','VVFIN','VAFIN','ADJA','ADJD','XY'])],
                              'NM': [('r','NMC',['NN','CARD'])],
                              'MTA':[('r','ADC',['ADJA'])],
                              'PP': [('r','HD',['APPRART','APPR','APPO','PROAV','NE','APZR','PWAV','TRUNC']),('r','AC',['APPRART','APPR','APPO','PROAV','NE','APZR','PWAV','TRUNC']),('r','PH',['APPRART','APPR','APPO','PROAV','NE','APZR','PWAV','TRUNC']),('l','NK',['PROAV'])],
                              'CH': [('s','HD',[]),('l','UC',['FM','NE','XY','CARD','ITJ'])],
                              'NP': [('l','HD',['NN']),('l','NK',['NN']),('r','HD',['NE','PPER','PIS','PDS','PRELS','PRF','PWS','PPOSS','FM','TRUNC','ADJA','CARD','PIAT','PWAV','PROAV','ADJD','ADV','APPRART','PDAT']),('r','NK',['NE','PPER','PIS','PDS','PRELS','PRF','PWS','PPOSS','FM','TRUNC','ADJA','CARD','PIAT','PWAV','PROAV','ADJD','ADV','APPRART','PDAT']),('r','PH',['NN','NE','PPER','PIS','PDS','PRELS','PRF','PWS','PPOSS','FM','TRUNC','ADJA','CARD','PIAT','PWAV','PROAV','ADJD','ADV','APPRART','PDAT'])],              
                              'CAC':[('l','CJ',[])],
                              'CAP':[('l','CJ',[])],
                             'CAVP':[('l','CJ',[])],
                              'CCP':[('l','CJ',[])],
                              'CNP':[('l','CJ',[])],
                              'CO': [('l','CJ',[])],
                              'CPP':[('l','CJ',[])],
                              'CS': [('l','CJ',[])],
                              'CVP':[('l','CJ',[])],
                              'CVZ':[('l','CJ',[])]
                            }
        self.verb_phrase_reattach_symbols = ["S", "VP"]
        self.constituents_to_ignore_heads = ["VROOT"]
        """
        headrules is a dict[str, rules]
        where rules is list[rule]
        where rule is tuple[direction, edge_label, list[str]]
            edge_label dictates the edge to follow
            list[str] is a list of parts of speech to find along that edge
            direction is one of 'l' or 'r' for left or right: if left, follow the edge to constituent that matches the first pos in the list, if right, follow the edge to the rightmost constituent that matches the pos
        """

        # find elements in xml tree
        self.tree_graph = tree_element.find("graph")
        self.tree_graph_root_name = get_str_after_underscore(self.tree_graph.attrib["root"])
        self.tree_terminals = self.tree_graph.find("terminals")
        self.tree_nonterminals = self.tree_graph.find("nonterminals")

        self.is_discontinuous = ("discontinuous" in self.tree_graph.attrib) and (self.tree_graph.attrib["discontinuous"] == "true")

        self.terminals: dict[int, Terminal] = {} # use dict for 1-based indexing
        self.constituents: dict[int, Constituent] = {}

        # parse terminals
        for idx, t in enumerate(self.tree_terminals.iter("t")):
            assert idx + 1 == get_int_after_underscore(t.attrib["id"]), f"Terminal index '{idx + 1}' does not match index implied by its id '{t.attrib['id']}'"
            term = Terminal(**t.attrib, idx=idx + 1)
            self.terminals.update({
                idx + 1: term
            })

        # create pre-terminals from terminals: the pre-terminals are constituents with id equal to the terminal's index within the sentence
        self.constituents.update({
                t.idx: Constituent(
                    id=t.idx,
                    head=t.idx,
                    yld={t.idx},
                    sym=t.pos,
                    is_pre_terminal=True,
                    children=[t.idx]
                )
            for t in self.terminals.values()})
        
        self.integize_generator = self.integize()
        integize_dict: dict[str, int] = {}

        # create non-terminal constituents
        for nt in self.tree_nonterminals.iter("nt"):
            nt_id = get_int_after_underscore(nt.attrib["id"])
            if nt_id is None:
                nt_id = next(self.integize_generator)
                assert nt_id not in integize_dict.values()
                integize_dict[get_str_after_underscore(nt.attrib["id"])] = nt_id
            assert nt_id is not None

            nt_sym = nt.attrib["cat"]
            nt_head = None # default to None

            nt_children: list[tuple[str, int]] = [] # list of tuple[edge_label, child_id]

            nt_edges = nt.iter("edge")
            for edge in nt_edges:
                edge_id_ref = get_int_after_underscore(edge.attrib["idref"], integize_dict.get(get_str_after_underscore(edge.attrib["idref"]), None))
                edge_label = edge.attrib["label"]
                assert edge_id_ref is not None
                assert edge_id_ref in self.constituents

                nt_children.append((edge_label, edge_id_ref))
                if edge.attrib["label"] == "HD":
                    nt_head = edge_id_ref

                # add edge label and parent id to child constituent
                assert self.constituents[edge_id_ref].parent is None, f"Constituent '{edge_id_ref}' already has a parent"
                assert self.constituents[edge_id_ref].edge_label is None, f"Constituent '{edge_id_ref}' already has an edge label"
                self.constituents[edge_id_ref].parent = nt_id
                self.constituents[edge_id_ref].edge_label = edge_label

            nt_children_ylds = [self.constituents[c_id].yld for _, c_id in nt_children]
            assert is_pairwise_disjoint(*nt_children_ylds), "Yields of children nodes must be pairwise disjoint"

            self.constituents.update({
                nt_id: Constituent(
                    id=nt_id,
                    sym=nt_sym,
                    head=nt_head,
                    is_pre_terminal=False,
                    yld=set.union(*nt_children_ylds),
                    children=[c_id for _, c_id in nt_children]
                )
            })
        
        # find root
        try:
            self.root = int(self.tree_graph_root_name)
        except Exception:
            self.root = integize_dict[self.tree_graph_root_name]
    
        assert self.constituents[self.root].yld == set(range(1, len(self.terminals) + 1)), f"Root constituent must yield entire sentence"

        # attempt to find heads for all phrases that do not have a head already defined by the markup
        empty_constituent_generator = self.create_next_empty_constituent()
        constituents_without_heads = [c for c in self.constituents.values() if c.head is None and c.sym not in self.constituents_to_ignore_heads]
        for c in constituents_without_heads:
            if c.head is None:
                c.head = self.find_head(c.id)

                # reattach VP or S
                if c.head is None and c.sym in self.verb_phrase_reattach_symbols:
                    c.head = next(empty_constituent_generator)

                print(f"Found head word {c.head} for constituent {c.id}")        

    def find_head_candidates(self, children: list[int], edge_label: str, pos_list: list[str]):
        children_matching_edge = [c for c in children if self.constituents[c].edge_label == edge_label]
        assert set([v for c in children_matching_edge for v in self.constituents[c].yld]).issubset(set(self.terminals.keys())), "Head candidates must be terminals"

        if len(children_matching_edge) > 1:
            print(f"Warning: {len(children_matching_edge)} candidate edges exist")

        if not pos_list: # if pos_list is empty, then we don't care about the POS of the head
            yield [v for c in children_matching_edge for v in self.constituents[c].yld]
        
        for pos in pos_list:
            yield [v for c in children_matching_edge for v in self.constituents[c].yld if self.constituents[v].sym == pos]

    def integize(self):
        num_integers = 1048576
        while True:
            num_integers += 1
            yield num_integers

    def create_next_empty_constituent(self):
        """
        Creates an empty constituent and adds this to the sentence's constituent dict, yielding its id. Empty constituent ids are negative integers.
        """
        num_empty_constituents = 0 # counter to keep empty constituent ids unique
        while True:
            num_empty_constituents -= 1
            assert num_empty_constituents not in self.constituents, f"Empty constituent id '{num_empty_constituents}' already exists"

            self.constituents.update({
                num_empty_constituents: Constituent(
                    id=num_empty_constituents,
                    sym="<EMPTY>",
                    head=None,
                    is_pre_terminal=True,
                    yld=set(),
                    children=[]
                )
            })

            yield num_empty_constituents

    def find_head(self, idx: int) -> int | None:
        assert idx in self.constituents, f"Constituent '{idx}' does not exist"
        assert self.constituents[idx].sym in self.head_rules, f"Constituent symbol '{self.constituents[idx].sym}' has no head rules"

        for rule in self.head_rules.values():
            for dir, label, pos_list in rule:
                for head_candidates in self.find_head_candidates(self.constituents[idx].children, label, pos_list):
                    if dir == 's' and head_candidates:
                        assert len(head_candidates) == 1, "'s' direction requires unique head candidate"
                        return head_candidates[0]
                    elif dir == 'l' and head_candidates:
                        return min(head_candidates)
                    elif dir == 'r' and head_candidates:
                        return max(head_candidates)
                    
        return None

In [40]:
# Smoke test: build a Sentence from the element whose id attribute is "s31".
# The constructor prints one "Found head word ..." line per head it had to search for.
sent = Sentence(document.find('.//*[@id="s31"]')) # type:ignore

Found head word 1 for constituent 500
Found head word 7 for constituent 503
Found head word -1 for constituent 504


In [41]:
# Collect every <s> element found anywhere in the document.
all_s = document.findall(".//s")

In [82]:
%%capture

# Parse every sentence. A sentence whose construction raises is skipped, counted in
# `errors`, and recorded in `parse_failures` so the cause can be inspected afterwards.
# NOTE: because failures are skipped, positions in `sents` do not line up with `all_s`.
errors = 0
sents = []
parse_failures = []  # tuples of (position in all_s, sentence id attribute, error description)
for idx, s in enumerate(all_s):
    print(f"Parsing sentence {idx + 1}; errors {errors}")
    try:
        sents.append(Sentence(s))
    except Exception as exc:
        errors += 1
        parse_failures.append((idx, s.attrib.get("id"), f"{type(exc).__name__}: {exc}"))

In [83]:
# Number of sentences whose construction raised an exception.
errors

298

In [79]:
# Inspect the constituent table (node id -> Constituent) of the first successfully parsed sentence.
sents[0].constituents

{1: Constituent(id=1, head=1, yld={1}, sym='$(', edge_label='--', parent=1048577, is_pre_terminal=True, children=[1]),
 2: Constituent(id=2, head=2, yld={2}, sym='NE', edge_label='PNC', parent=500, is_pre_terminal=True, children=[2]),
 3: Constituent(id=3, head=3, yld={3}, sym='NE', edge_label='PNC', parent=500, is_pre_terminal=True, children=[3]),
 4: Constituent(id=4, head=4, yld={4}, sym='VAFIN', edge_label='HD', parent=502, is_pre_terminal=True, children=[4]),
 5: Constituent(id=5, head=5, yld={5}, sym='ADV', edge_label='MO', parent=502, is_pre_terminal=True, children=[5]),
 6: Constituent(id=6, head=6, yld={6}, sym='ART', edge_label='NK', parent=501, is_pre_terminal=True, children=[6]),
 7: Constituent(id=7, head=7, yld={7}, sym='ADJA', edge_label='NK', parent=501, is_pre_terminal=True, children=[7]),
 8: Constituent(id=8, head=8, yld={8}, sym='NN', edge_label='NK', parent=501, is_pre_terminal=True, children=[8]),
 9: Constituent(id=9, head=9, yld={9}, sym='$(', edge_label='--', p

In [80]:
# Number of terminals in each successfully parsed sentence, in the order of `sents`.
sents_lengths = [len(sentence.terminals) for sentence in sents]

In [81]:
# Print the position (within `sents`) of every sentence that has exactly 130 terminals.
for position, length in enumerate(sents_lengths):
    if length != 130:
        continue
    print(position)

5221
34835


In [77]:
# Sentence defines no custom repr, so this displays only the default object representation.
sents[3]

<__main__.Sentence at 0x7ff274995690>

In [78]:
# TODO: check that the behaviour when multiple candidate edges exist is as expected: 'r' means pick the word with the largest index (the rightmost word)
# TODO: check head rules satisfy definition of constituent