In [1]:
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import torch
import torch.nn.functional as F

from german_parser.util.const import CONSTS
from german_parser.util.dataloader import TigerDatasetGenerator
import dill as pickle
from string import punctuation
from german_parser.model import TigerModel
from german_parser.util import BatchUnionFind

from german_parser.util.c_and_d import Dependency, DependencyTree, ConstituentTree, Terminal
from collections import defaultdict
from typing import Sequence

In [2]:
# Restore the pickled dataloaders / vocabularies and the pickled model from disk.
# `pickle` is `dill` in this notebook (see the import cell).
# NOTE(security): unpickling can execute arbitrary code — only load files you trust.
# The files are opened in `with` blocks so the handles are closed right after loading.
with open("required_vars.pkl", "rb") as vars_file:
    (
        (train_dataloader, train_new_words),
        (dev_dataloader, dev_new_words),
        (test_dataloader, test_new_words),
        character_set,
        character_flag_generators,
        inverse_word_dict,
        inverse_sym_dict,
    ) = pickle.load(vars_file)

with open("./models/epoch_25_cpu_eval.pickle", "rb") as model_file:
    model: TigerModel = pickle.load(model_file)

# Iterator of (batch index, batch) pairs over the test set; consumed with `next(dl_en)`.
dl_en = enumerate(test_dataloader)

# Playing Around

In [11]:
# Advance the test iterator by one step; the batch index is discarded.
# Each re-run of this cell moves on to the following batch.
_, (
    x,
    l,
    target_ex,
    target_lab,
    target_att,
) = next(dl_en)

In [12]:
# Run the model's tree search on the current batch `(x, l)`.
# The result is unpacked into three per-batch outputs plus a trailing `(edges, joint_logits)` pair.
(
    best_edges,
    labels_best_edges,
    attachment_orders_best_edges,
    (edges, joint_logits),
) = model.find_tree((x, l), test_new_words)

In [14]:
# Row of the batch to inspect.
s_num = 3

# For every batch tensor: take row `s_num`, keep its first `l[s_num]` entries
# (presumably the unpadded sentence length — confirm against the dataloader)
# and move the slice to the CPU.
# Order: input words, predicted heads / labels / attachment orders, then the targets.
words, ex, lab, att, t_ex, t_lab, t_att = (
    batch_tensor[s_num, :l[s_num]].to("cpu")
    for batch_tensor in (
        x,
        best_edges,
        labels_best_edges,
        attachment_orders_best_edges,
        target_ex,
        target_lab,
        target_att,
    )
)

In [15]:
# Format the predicted tree as a `TreePlot[...]` expression (looks like Wolfram Language syntax).
# One `{parent->child, "SYM#order"}` entry per token; children are numbered from 1.
# NOTE: this rebinds `edges`, replacing the value unpacked from `model.find_tree` earlier.
edges = [
    f"{{{i_parent}->{i}, \"{inverse_sym_dict[i_lab.item()]}#{i_att}\"}}"
    for i, (i_parent, i_lab, i_att) in enumerate(zip(ex, lab, att), 1)
]

print(f"TreePlot[{{{', '.join(edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{5->1, "S#1"}, {1->2, "VP#1"}, {2->3, "PP#1"}, {2->4, "PP#1"}, {0->5, "DROOT#1"}, {8->6, "PP#1"}, {8->7, "PP#1"}, {10->8, "AP#1"}, {10->9, "AP#1"}, {11->10, "NP#1"}, {5->11, "S#1"}, {13->12, "NP#1"}, {11->13, "NP#1"}, {13->14, "NP#1"}, {14->15, "CNP#1"}, {14->16, "CNP#1"}, {5->17, "VROOT#2"}, {24->18, "S#1"}, {20->19, "NP#1"}, {23->20, "VP#1"}, {23->21, "VP#1"}, {21->22, "PP#1"}, {24->23, "S#1"}, {11->24, "NP#1"}, {5->25, "VROOT#2"}, {5->26, "S#1"}, {26->27, "PP#1"}, {26->28, "PP#1"}, {26->29, "PP#1"}, {29->30, "PP#1"}, {5->31, "VROOT#2"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [16]:
# Format the target (reference) tree as a `TreePlot[...]` expression (looks like Wolfram Language syntax).
# One `{parent->child, "SYM#order"}` entry per token; children are numbered from 1.
target_edges = [
    f"{{{i_parent}->{i}, \"{inverse_sym_dict[i_lab.item()]}#{i_att}\"}}"
    for i, (i_parent, i_lab, i_att) in enumerate(zip(t_ex, t_lab, t_att), 1)
]

print(f"TreePlot[{{{', '.join(target_edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{5->1, "S#1"}, {1->2, "AP#1"}, {2->3, "PP#1"}, {2->4, "PP#1"}, {0->5, "DROOT#1"}, {8->6, "PP#1"}, {8->7, "PP#1"}, {5->8, "S#1"}, {10->9, "AP#1"}, {11->10, "NP#1"}, {5->11, "S#1"}, {13->12, "NP#1"}, {11->13, "NP#1"}, {13->14, "NP#1"}, {14->15, "CNP#1"}, {14->16, "CNP#1"}, {5->17, "VROOT#2"}, {24->18, "S#1"}, {20->19, "NP#1"}, {23->20, "VP#1"}, {23->21, "VP#1"}, {21->22, "PP#1"}, {24->23, "S#1"}, {13->24, "NP#1"}, {5->25, "VROOT#2"}, {5->26, "S#1"}, {26->27, "PP#1"}, {26->28, "PP#1"}, {26->29, "PP#1"}, {29->30, "PP#1"}, {5->31, "VROOT#2"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [24]:
# Turn word ids back into strings: positive ids index `inverse_word_dict`,
# every other id is negated and used as an index into `test_new_words`.
the_sentence = [
    inverse_word_dict[word_id.item()] if word_id > 0 else test_new_words[-word_id.item()]
    for word_id in words
]

# Constituent tree built from the model's predicted heads, symbols and attachment orders.
c_tree = ConstituentTree.from_collection(
    heads=ex,
    syms=[inverse_sym_dict[sym_id.item()] for sym_id in lab],
    orders=att,
    words=the_sentence,
)

# Constituent tree built the same way from the target annotations.
t_c_tree = ConstituentTree.from_collection(
    heads=t_ex,
    syms=[inverse_sym_dict[sym_id.item()] for sym_id in t_lab],
    orders=t_att,
    words=the_sentence,
)

# Evaluation

In [1]:
from discodop import eval
from discodop import tree

In [40]:
# Convert the predicted and the target tree (in that order) to discodop objects:
# each tree is serialised with `get_bracket(zero_indexed=True)` and parsed by
# `tree.brackettree`, whose result is unpacked as a (tree, sentence) pair.
# `detectdisc=True` presumably enables detection of discontinuous constituents — confirm in the discodop docs.
(c_tree_tree, c_tree_sent), (t_c_tree_tree, t_c_tree_sent) = (
    tree.brackettree(parsed.get_bracket(zero_indexed=True), detectdisc=True)
    for parsed in (c_tree, t_c_tree)
)

In [77]:
# Evaluation settings. The key names match discodop / EVALB-style parameter files;
# presumably this dict is handed to `discodop.eval` (the consumer is not visible here).
# NOTE(review): 'LABELED' is an empty dict, which is falsy — if the consumer treats it
# as a boolean flag this selects unlabeled scoring; confirm that this is intended.
params = {
    'LABELED': {},
    'DELETE_LABEL': set(),
    'DISC_ONLY': False,
    'DELETE_WORD': set(),
    'EQ_LABEL': {},
    'EQ_WORD': {},
    'DELETE_ROOT_PRETERMS': 0,
    'DELETE_LABEL_FOR_LENGTH': {},
    # do not calculate these unneeded metrics
    'LA': False,
    'TED': False,
    'DEP': False,
}