In [19]:
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import torch
import torch.nn.functional as F

from german_parser.util.const import CONSTS
from german_parser.util.dataloader import TigerDatasetGenerator
import dill as pickle
from string import punctuation
from german_parser.model import TigerModel
from german_parser.util import BatchUnionFind

from german_parser.util.c_and_d import Dependency, DependencyTree, ConstituentTree, Terminal
from collections import defaultdict
from typing import Sequence

In [20]:
(train_dataloader, train_new_words), (dev_dataloader, dev_new_words), (test_dataloader, test_new_words), character_set, character_flag_generators, inverse_word_dict, inverse_sym_dict = pickle.load(open("required_vars.pkl", "rb"))
model: TigerModel = pickle.load(open("./models/tiger_model_2023_09_05-03_04_35_PM_epoch_6.pickle", "rb"))
dl_en = enumerate(test_dataloader)

In [21]:
model = model.eval().to("cpu")

# Define Methods

In [22]:
def find_tree(model: TigerModel, input: tuple[torch.Tensor, torch.Tensor], new_words_dict: dict[int, str] | None):
    """NOTE: indices are 1-indexed. index 0 corresponds to virtual root of D-tree (which is different from virtual root of C-tree)

    Args:
        model (TigerModel): _description_
        input (tuple[torch.Tensor, torch.Tensor]): _description_
        new_words_dict (dict[int, str] | None): _description_
    """    
    x, lengths = input
    self_attention, constituent_labels, attachment_orders, indices = model.forward(input, new_words_dict)
    # self_attention has size (B, T, T + 1)

    B, T, *_ = self_attention.shape
    uf = BatchUnionFind(B, model.beam_size, N=T + 1, device=model.dummy_param.device)

    # initialise beams by finding top-K most probable root nodes
    best_roots = self_attention[:, :, 0].topk(k=model.beam_size, dim=-1)

    current_root_indices = best_roots.indices # (B, K)
    joint_logits = best_roots.values # (B, K)

    edges = torch.zeros(B, model.beam_size, T, dtype=torch.long, device=model.dummy_param.device) # (B, K, T); m[b, k, t - 1] is the 1-indexed parent of 1-indexed node t, in batch b, beam k

    same_as_beams = torch.zeros(B, model.beam_size) # m[b, k] == 1 if beam k in batch b is equal to beam 1 in batch b. allows us to keep track of duplicates. duplicates are when m[b, k] != k
    
    def beams_are_unique():
        return same_as_beams == torch.arange(model.beam_size).repeat(B, 1) # return where m[b, k] == same_as_beams[b, k] != k


    for t in range(T):
        relevant_batches = t < lengths # used for updating the final arcs. otherwise, we get nans in batch b if t >= sentence_length[b]

        arc_probs = self_attention[:, t, :].log_softmax(dim=-1) # (B, T + 1)

        candidate_joint_logits = joint_logits[:, :, None] + arc_probs[:, None, :] # (B, K, T + 1); the heuristic we would like to maximise

        children = torch.tensor(t + 1, device=model.dummy_param.device) # all batches and beams share the same child index (1-indexed)
        parents = torch.arange(0, T + 1, 1, device=model.dummy_param.device, dtype=torch.long).repeat(B, model.beam_size, 1) # (B, K, T + 1), where m[b, k, t] = t to represent the index of each parent 

        # prevent cycles
        disable_mask = uf.is_same_set(children.expand_as(parents), parents) # (B, K, T + 1); m[b, k, s + 1] is true if in batch b, beam k, joining child (t + 1) and parent (s + 1) would lead to a cycle
        # avoid setting words as head that are beyond the end of the sentence
        disable_mask[:, :, 1:] |= ~indices.unsqueeze(1).to(device=model.dummy_param.device) 
        # force beams to be unique
        disable_mask |= ~beams_are_unique()[:, :, None]
        # force root indices to be enabled
        disable_mask[:, :, 1:][current_root_indices == t] = True # for the batches and beams where t would be a root node, t's parent must be 0 (cannot be 1:)
        # prevent other indices from becoming root
        disable_mask[:, :, 0][current_root_indices != t] = True
        candidate_joint_logits[disable_mask] = -torch.inf # these indices can never be a maximiser

        flattened_top_candidate_idx = candidate_joint_logits.flatten(-2, -1).topk(k=model.beam_size, dim=-1).indices # (B, K) in range [0, (K * T + 1)); for each batch, find top 10 best performing parent-beam combinations
        
        top_parents = parents.flatten(-2, -1).gather(index=flattened_top_candidate_idx, dim=-1) # (B, K); for each batch and beam, get the 1-indexed id of the parent we want to attach

        batch_names = torch.arange(model.beam_size, device=model.dummy_param.device, dtype=torch.long).view(1, -1, 1).expand(B, -1, T + 1).flatten(-2, -1) # (B, K, T + 1); m[b, k, s] = k for all b, s, k
        used_batches = batch_names.gather(index=flattened_top_candidate_idx, dim=-1) # (B, K) where each element is in the range [0, K). m[b, k] tells you what the kth new beam should be

        same_as_beams = used_batches # we have copied over these beams, so for now, these beams must be equal

        new_data = uf.data.gather(index=used_batches.unsqueeze(-1).expand_as(uf.data), dim=1)
        new_rank = uf.rank.gather(index=used_batches.unsqueeze(-1).expand_as(uf.rank), dim=1)
        new_edges = edges.gather(index=used_batches.unsqueeze(-1).expand_as(edges), dim=1)
        new_joint_logits = candidate_joint_logits.flatten(-2, -1).gather(index=flattened_top_candidate_idx, dim=-1)
        new_roots = current_root_indices.gather(index=used_batches, dim=-1)
        

        uf.data[relevant_batches] = new_data[relevant_batches]
        uf.rank[relevant_batches] = new_rank[relevant_batches]
        edges[relevant_batches] = new_edges[relevant_batches]
        joint_logits[relevant_batches] = new_joint_logits[relevant_batches]
        current_root_indices[relevant_batches] = new_roots[relevant_batches]

        uf.union(children.expand_as(top_parents), top_parents)

        edges[:, :, t] = top_parents
        new_unique_beams = top_parents.gather(index=same_as_beams, dim=-1) != top_parents  # indicates the beams that have BECOME unqiue: given any batch b, suppose beam j was a copy of beam k. suppose the new parent for beam j is different to the new parent of beam k. then m[b, j] = True
        same_as_beams[new_unique_beams] = torch.arange(model.beam_size).repeat(B, 1)[new_unique_beams]

        pass

    num_labels = constituent_labels.shape[-1]
    num_attachment_orders = attachment_orders.shape[-1]

    best_edges = edges[torch.arange(edges.shape[0]), joint_logits.argmax(dim=-1)] # (B, T) containing elements in range [0, T + 1), where m[b, t - 1] denotes the 1-indexed parent of 1-indexed node t

    label_logits_best_edges = constituent_labels.gather(index=best_edges[:, :, None, None].expand(-1, -1, -1, num_labels), dim=2).squeeze(2)
    attachment_logits_best_edges = attachment_orders.gather(index=best_edges[:, :, None, None].expand(-1, -1, -1, num_attachment_orders), dim=2).squeeze(2)

    labels_best_edges = label_logits_best_edges.argmax(-1)
    attachment_orders_best_edges = attachment_logits_best_edges.argmax(-1)

    return best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits)

In [23]:
def create_d_tree(heads: Sequence[int] | torch.Tensor, labels: Sequence[int] | torch.Tensor, orders: Sequence[int] | torch.Tensor, words: list[str]):
    assert len(labels) == len(heads) and len(orders) == len(labels) and len(words) == len(orders)
    num_words = len(heads)

    modifiers: dict[int, dict[int, list[Dependency]]] = defaultdict(lambda: defaultdict(lambda x=None: []))
    terminals: dict[int, Terminal] = {}

    for child, (head, label, order, word) in enumerate(zip(heads, labels, orders, words), start=1):
        if isinstance(head, torch.Tensor):
            head = int(head.item())
        if isinstance(label, torch.Tensor):
            label = int(label.item())
        if isinstance(order, torch.Tensor):
            order = int(order.item())

        terminals[child] = Terminal(
            idx=child,
            word=word
        )

        # child with head of 0 is the root node of the d-tree; don't add it to modifiers dict
        if head == 0:
            continue
        
        modifiers[head][order].append(Dependency(head=head, modifier=child, sym=inverse_sym_dict[label]))

    # with an ordering, all arcs must have the same symbol between dependency and head
    for h in modifiers:
        for o in modifiers[h]:
            dependencies = modifiers[h][o]
            closest_dependency_to_head = max(dependencies, key=lambda d: abs(d.head - d.modifier))
            for d in dependencies:
                if d.sym != closest_dependency_to_head.sym:
                    print(f"Warning: arc symbol from {d.modifier} to {d.head} ({d.sym}) does not match arc from closest dependency {closest_dependency_to_head.modifier} ({closest_dependency_to_head.sym}). Setting this to ({closest_dependency_to_head.sym})")
                    d.sym = closest_dependency_to_head.sym
            pass

    return DependencyTree(modifiers=modifiers, terminals=terminals, num_words=num_words)

def create_c_tree(heads: Sequence[int] | torch.Tensor, labels: Sequence[int] | torch.Tensor, orders: Sequence[int] | torch.Tensor, words: list[str]):
    d_tree = create_d_tree(heads, labels, orders, words)
    return ConstituentTree.from_d_tree(d_tree)

# Playing Around

In [24]:
_, (x, l, target_ex, target_lab, target_att) = next(dl_en)

In [25]:
best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits) = find_tree(model, (x, l), test_new_words)

In [26]:
joint_logits[0]

tensor([-16.7786, -16.8741, -16.9794, -17.0701, -17.1065, -17.1906, -17.2020,
        -17.2709, -18.1423, -18.2378], grad_fn=<SelectBackward0>)

In [27]:
s_num = 3

words = x[s_num, :l[s_num]].to("cpu")

ex = best_edges[s_num, :l[s_num]].to("cpu")
lab = labels_best_edges[s_num, :l[s_num]].to("cpu")
att = attachment_orders_best_edges[s_num, :l[s_num]].to("cpu")

t_ex = target_ex[s_num, :l[s_num]].to("cpu")
t_lab = target_lab[s_num, :l[s_num]].to("cpu")
t_att = target_att[s_num, :l[s_num]].to("cpu")

In [28]:
edges = []

for i, (i_parent, i_lab, i_att) in enumerate(zip(ex, lab, att), 1):
    edges.append(f"{{{i_parent}->{i}, \"{inverse_sym_dict[i_lab.item()]}#{i_att}\"}}")

print(f"TreePlot[{{{', '.join(edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{2->1, "S#1"}, {0->2, "DROOT#1"}, {5->3, "NP#1"}, {5->4, "NP#1"}, {7->5, "VP#1"}, {7->6, "VP#1"}, {2->7, "S#1"}, {2->8, "VROOT#2"}, {20->9, "S#1"}, {11->10, "NP#1"}, {20->11, "S#1"}, {14->12, "NP#1"}, {14->13, "NP#1"}, {20->14, "S#1"}, {17->15, "NP#1"}, {17->16, "NP#1"}, {20->17, "S#1"}, {20->18, "S#1"}, {18->19, "PP#1"}, {2->20, "S#1"}, {2->21, "VROOT#2"}, {26->22, "S#1"}, {26->23, "S#1"}, {26->24, "S#1"}, {26->25, "VP#1"}, {20->26, "CS#1"}, {2->27, "VROOT#2"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [29]:
target_edges = []

for i, (i_parent, i_lab, i_att) in enumerate(zip(t_ex, t_lab, t_att), 1):
    target_edges.append(f"{{{i_parent}->{i}, \"{inverse_sym_dict[i_lab.item()]}#{i_att}\"}}")

print(f"TreePlot[{{{', '.join(target_edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{2->1, "S#1"}, {0->2, "DROOT#1"}, {5->3, "NP#1"}, {5->4, "NP#1"}, {2->5, "S#1"}, {7->6, "VP#1"}, {2->7, "S#1"}, {2->8, "VROOT#2"}, {20->9, "S#1"}, {11->10, "NP#1"}, {20->11, "S#1"}, {14->12, "NP#1"}, {14->13, "NP#1"}, {11->14, "NP#1"}, {17->15, "NP#1"}, {17->16, "NP#1"}, {20->17, "S#1"}, {20->18, "S#1"}, {18->19, "PP#1"}, {6->20, "PP#1"}, {2->21, "VROOT#2"}, {26->22, "S#1"}, {26->23, "S#1"}, {26->24, "S#1"}, {22->25, "NP#1"}, {17->26, "NP#1"}, {2->27, "VROOT#2"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [30]:
the_sentence = [inverse_word_dict[w.item()] if w > 0 else test_new_words[-w.item()] for w in words]
c_tree = create_c_tree(heads=ex, labels=lab, orders=att, words=the_sentence)
t_c_tree = create_c_tree(heads=t_ex, labels=t_lab, orders=t_att, words=the_sentence)



In [31]:
c_tree.draw()

(<toyplot.canvas.Canvas at 0x16af6ea90>,
 <toyplot.coordinates.Cartesian at 0x12b1c3490>,
 <toytree.Render.ToytreeMark at 0x31be85ed0>)

In [32]:
t_c_tree.draw()

(<toyplot.canvas.Canvas at 0x16ae3fd90>,
 <toyplot.coordinates.Cartesian at 0x16ae2f690>,
 <toytree.Render.ToytreeMark at 0x16ae10750>)

# Evaluation

In [33]:
from discodop import eval
from discodop import tree

In [40]:
c_tree_tree, c_tree_sent = tree.brackettree(c_tree.get_bracket(zero_indexed=True), detectdisc=True)
t_c_tree_tree, t_c_tree_sent = tree.brackettree(t_c_tree.get_bracket(zero_indexed=True), detectdisc=True)

In [77]:
params = {'LABELED': {},
          'DELETE_LABEL': set(),
          'DISC_ONLY': False,


          'DELETE_WORD': set(),
          'EQ_LABEL': {}, 'EQ_WORD': {},
          'DELETE_ROOT_PRETERMS': 0,
          'DELETE_LABEL_FOR_LENGTH': {}, 

          # do not calculate these unneeded metrics
          'LA': False,
          'TED': False,
          'DEP': False
          }