In [3]:
import logging
# Configure the root logger so INFO-level records are printed, and get this notebook's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
import torch
import torch.nn.functional as F

from german_parser.util.const import CONSTS
from german_parser.util.dataloader import TigerDatasetGenerator
import dill as pickle
from string import punctuation
from german_parser.model import TigerModel
from german_parser.util import BatchUnionFind

In [4]:
# NOTE(review): unpickling executes arbitrary code — only load files produced by this project.
# `with` closes the file handle instead of leaving it to the garbage collector.
with open("required_vars.pkl", "rb") as required_vars_file:
    (
        (train_dataloader, train_new_words),
        (dev_dataloader, dev_new_words),
        (test_dataloader, test_new_words),
        character_set,
        character_flag_generators,
        inverse_word_dict,
        inverse_sym_dict,
    ) = pickle.load(required_vars_file)

In [5]:
# NOTE(review): unpickling executes arbitrary code — only load checkpoints produced by this project.
# `with` closes the file handle instead of leaving it to the garbage collector.
with open("./models/epoch_25_cpu_eval.pickle", "rb") as model_file:
    model: TigerModel = pickle.load(model_file)

In [6]:
# (batch_index, batch) iterator over the test set; each `next(dl_en)` advances it by one batch.
dl_en = enumerate(iter(test_dataloader))

In [7]:
# Take the next batch from the test-set iterator.
_, (
    x,           # word ids (passed to find_tree together with l)
    l,           # sentence lengths
    target_ex,   # gold heads
    target_lab,  # gold arc labels
    target_att,  # gold attachment orders
) = next(dl_en)

In [26]:
def find_tree(model: TigerModel, input: tuple[torch.Tensor, torch.Tensor], new_words_dict: dict[int, str] | None):
    """Beam-search decode the best dependency tree (D-tree) for every sentence in a batch.

    NOTE: word indices are 1-indexed. Index 0 corresponds to the virtual root of the D-tree
    (which is different from the virtual root of the C-tree).

    Words are attached left to right. At step t every beam proposes each node 0..T as the parent of
    word t + 1. Candidates that would close a cycle, point at a padding position, or extend a beam that
    duplicates an earlier beam are masked out. The best `model.beam_size` (beam, parent) pairs survive.
    A beam's score is the sum of the log-softmax scores of its arcs.

    Args:
        model (TigerModel): trained parser. Must expose `forward`, `beam_size` and `dummy_param`
            (the latter is only used to determine the device).
        input (tuple[torch.Tensor, torch.Tensor]): `(x, lengths)`, passed to `model.forward` unchanged;
            `lengths` also decides which sentences are still being decoded at each step.
        new_words_dict (dict[int, str] | None): forwarded to `model.forward`.

    Returns:
        tuple: `(best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits))`.
            - `best_edges` is (B, T), where `best_edges[b, t - 1]` is the 1-indexed parent of word t in the
              highest-scoring beam.
            - `labels_best_edges` and `attachment_orders_best_edges` are the argmax label and attachment-order
              predictions for those arcs.
            - `edges` (B, K, T) and `joint_logits` (B, K) are the final state of all beams.
            Positions >= a sentence's length are padding and should be ignored.
    """
    _, lengths = input
    device = model.dummy_param.device
    K = model.beam_size

    with torch.no_grad():  # decoding only: no autograd graph is needed
        self_attention, constituent_labels, attachment_orders, indices = model.forward(input, new_words_dict)
    # self_attention has size (B, T, T + 1)

    B, T, *_ = self_attention.shape
    lengths = torch.as_tensor(lengths, device=device)  # keep every mask on the same device as the beam state
    # assumes `indices` is a boolean (B, T) mask of real (non-padding) words — TODO confirm against TigerModel.forward
    beyond_sentence_end = ~indices.unsqueeze(1).to(device=device)  # (B, 1, T)

    uf = BatchUnionFind(B, K, N=T + 1, device=device)
    joint_logits = torch.zeros(B, K, device=device)  # (B, K); summed arc log-probabilities of each beam
    edges = torch.zeros(B, K, T, dtype=torch.long, device=device)  # (B, K, T); edges[b, k, t - 1] is the 1-indexed parent of node t

    # same_as_beams[b, k] is the index of the first beam in batch b holding exactly the same partial tree as
    # beam k, so beam k is unique iff same_as_beams[b, k] == k. All beams start as the same empty tree (copies
    # of beam 0). Only unique beams may be extended, which stops the beam from filling up with duplicates.
    same_as_beams = torch.zeros(B, K, dtype=torch.long, device=device)
    beam_ids = torch.arange(K, device=device, dtype=torch.long)  # (K,)
    # candidate_beam[b, k * (T + 1) + s] = k: maps a flattened (beam, parent) candidate back to its beam
    candidate_beam = beam_ids.view(1, -1, 1).expand(B, -1, T + 1).flatten(-2, -1)  # (B, K * (T + 1))

    for t in range(T):
        # only sentences that still have words left are updated; otherwise we get nans in batch b if t >= lengths[b]
        relevant_batches = t < lengths  # (B,)

        arc_probs = self_attention[:, t, :].log_softmax(dim=-1)  # (B, T + 1)
        candidate_joint_logits = joint_logits[:, :, None] + arc_probs[:, None, :]  # (B, K, T + 1); the heuristic we would like to maximise

        child = torch.tensor(t + 1, device=device)  # all batches and beams share the same child index (1-indexed)
        parents = torch.arange(0, T + 1, 1, device=device, dtype=torch.long).repeat(B, K, 1)  # (B, K, T + 1); parents[b, k, s] = s

        # invalid[b, k, s] is True if attaching child t + 1 to parent s in beam k is not allowed
        invalid = uf.is_same_set(child.expand_as(parents), parents)  # would create a cycle
        invalid[:, :, 1:] |= beyond_sentence_end  # parent is beyond the end of the sentence
        invalid |= (same_as_beams != beam_ids)[:, :, None]  # beam duplicates an earlier beam

        candidate_joint_logits[invalid] = -torch.inf  # masked candidates can never be a maximiser
        top = candidate_joint_logits.flatten(-2, -1).topk(k=K, dim=-1)  # best K (beam, parent) pairs per batch
        new_joint_logits = top.values  # (B, K)
        top_parents = parents.flatten(-2, -1).gather(index=top.indices, dim=-1)  # (B, K); 1-indexed parent chosen by each new beam
        source_beams = candidate_beam.gather(index=top.indices, dim=-1)  # (B, K) in [0, K); old beam each new beam extends

        # New beams j and k hold the same partial tree iff their source beams were identical AND they attach
        # child t + 1 to the same parent. `source_beams` uses the OLD beam numbering, so compare the sources'
        # canonical ids rather than comparing against positions in the new numbering.
        source_canonical = same_as_beams.gather(index=source_beams, dim=-1)  # (B, K)
        identical = (
            (source_canonical[:, :, None] == source_canonical[:, None, :])
            & (top_parents[:, :, None] == top_parents[:, None, :])
        )  # (B, K, K)
        new_same_as_beams = identical.long().argmax(dim=-1)  # first beam with the same tree (k itself when unique)

        new_data = uf.data.gather(index=source_beams.unsqueeze(-1).expand_as(uf.data), dim=1)
        new_rank = uf.rank.gather(index=source_beams.unsqueeze(-1).expand_as(uf.rank), dim=1)
        new_edges = edges.gather(index=source_beams.unsqueeze(-1).expand_as(edges), dim=1)

        uf.data[relevant_batches] = new_data[relevant_batches]
        uf.rank[relevant_batches] = new_rank[relevant_batches]
        edges[relevant_batches] = new_edges[relevant_batches]
        joint_logits[relevant_batches] = new_joint_logits[relevant_batches]
        same_as_beams[relevant_batches] = new_same_as_beams[relevant_batches]

        uf.union(child.expand_as(top_parents), top_parents)
        edges[:, :, t] = top_parents

    num_labels = constituent_labels.shape[-1]
    num_attachment_orders = attachment_orders.shape[-1]

    best_beams = joint_logits.argmax(dim=-1)  # (B,)
    best_edges = edges[torch.arange(B, device=device), best_beams]  # (B, T) in [0, T + 1); best_edges[b, t - 1] is the 1-indexed parent of node t

    # assumes constituent_labels / attachment_orders hold per-(child, parent) scores along dim 2 — TODO confirm
    head_index = best_edges[:, :, None, None]  # (B, T, 1, 1)
    label_logits_best_edges = constituent_labels.gather(index=head_index.expand(-1, -1, -1, num_labels), dim=2).squeeze(2)
    attachment_logits_best_edges = attachment_orders.gather(index=head_index.expand(-1, -1, -1, num_attachment_orders), dim=2).squeeze(2)

    labels_best_edges = label_logits_best_edges.argmax(-1)
    attachment_orders_best_edges = attachment_logits_best_edges.argmax(-1)

    return best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits)

In [27]:
# Decode the current test batch with beam search.
best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits) = find_tree(
    model=model,
    input=(x, l),
    new_words_dict=test_new_words,
)

In [31]:
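# Sanity check (commented out; it runs find_tree over the entire test set): print "YES!" for every sentence
# whose K final beams do not all end with the same joint log-score.
# NOTE: if re-enabled, it overwrites the globals x, l, best_edges, ... that the cells below rely on.
# for n, (x, l, *_) in enumerate(test_dataloader):
#     best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits) = find_tree(model, (x, l), test_new_words)
#     print(f"{n} of {len(test_dataloader)}", end="\r")
#     for i in range(len(x)):
#         if not joint_logits[i].allclose(joint_logits[i,0]):
#             print("YES!")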

In [16]:
# Inspect one sentence of the batch: trim predictions and gold annotations to its real length, on the CPU.
s_num = 0

(
    words,
    ex, lab, att,        # predicted heads, labels, attachment orders
    t_ex, t_lab, t_att,  # gold heads, labels, attachment orders
) = (
    batch_tensor[s_num, :l[s_num]].to("cpu")
    for batch_tensor in (
        x,
        best_edges, labels_best_edges, attachment_orders_best_edges,
        target_ex, target_lab, target_att,
    )
)

In [17]:
# Mathematica TreePlot of the predicted tree: one `{parent->child, "LABEL#order"}` edge per word.
# Deliberately NOT named `edges`: that would clobber the (B, K, T) beam tensor returned by find_tree.
pred_treeplot_edges = [
    f"{{{parent}->{child}, \"{inverse_sym_dict[label.item()]}#{order}\"}}"
    for child, (parent, label, order) in enumerate(zip(ex, lab, att), 1)
]

print(f"TreePlot[{{{', '.join(pred_treeplot_edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{11->1, "S#1"}, {1->2, "PP#1"}, {11->3, "VROOT#2"}, {5->4, "PP#1"}, {1->5, "PP#1"}, {5->6, "PP#1"}, {5->7, "PP#1"}, {5->8, "PP#1"}, {8->9, "PP#1"}, {11->10, "VROOT#2"}, {0->11, "DROOT#1"}, {13->12, "NP#1"}, {11->13, "S#1"}, {11->14, "VROOT#2"}, {24->15, "S#1"}, {18->16, "NP#1"}, {18->17, "NP#1"}, {24->18, "S#1"}, {18->19, "NP#1"}, {19->20, "PP#1"}, {19->21, "PP#1"}, {23->22, "AP#1"}, {24->23, "S#1"}, {13->24, "NP#1"}, {11->25, "VROOT#2"}, {33->26, "S#1"}, {28->27, "NP#1"}, {33->28, "S#1"}, {32->29, "VP#1"}, {32->30, "VP#1"}, {30->31, "AVP#1"}, {33->32, "S#1"}, {22->33, "AVP#1"}, {11->34, "VROOT#2"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [18]:
# Mathematica TreePlot of the gold tree: one `{parent->child, "LABEL#order"}` edge per word.
target_edges = [
    f"{{{parent}->{child}, \"{inverse_sym_dict[label.item()]}#{order}\"}}"
    for child, (parent, label, order) in enumerate(zip(t_ex, t_lab, t_att), 1)
]

print(f"TreePlot[{{{', '.join(target_edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{11->1, "S#1"}, {1->2, "PP#1"}, {11->3, "VROOT#2"}, {5->4, "PP#1"}, {1->5, "PP#1"}, {5->6, "PP#1"}, {5->7, "PP#1"}, {5->8, "PP#1"}, {8->9, "PP#1"}, {11->10, "VROOT#2"}, {0->11, "DROOT#1"}, {13->12, "NP#1"}, {11->13, "S#1"}, {11->14, "VROOT#2"}, {24->15, "S#1"}, {18->16, "NP#1"}, {18->17, "NP#1"}, {24->18, "S#1"}, {18->19, "NP#1"}, {19->20, "PP#1"}, {19->21, "PP#1"}, {23->22, "AP#1"}, {24->23, "S#1"}, {13->24, "NP#1"}, {11->25, "VROOT#2"}, {33->26, "S#1"}, {28->27, "NP#1"}, {33->28, "S#1"}, {32->29, "VP#1"}, {32->30, "VP#1"}, {30->31, "AVP#1"}, {33->32, "S#1"}, {22->33, "AVP#1"}, {11->34, "VROOT#2"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [19]:
from german_parser.util.c_and_d import Dependency, DependencyTree, ConstituentTree, Terminal
from collections import defaultdict

def create_d_tree(heads: torch.Tensor, labels: torch.Tensor, orders: torch.Tensor, word_ids: torch.Tensor, word_dict: dict[int, str], sym_dict: dict[int, str] | None = None):
    """Build a DependencyTree for one sentence from per-word head, label and attachment-order predictions.

    Words are 1-indexed: element i - 1 of each input describes word i. A head of 0 marks the root word of
    the D-tree, which gets a Terminal but no incoming Dependency.

    Args:
        heads: 1-indexed head of each word (0 for the D-tree root).
        labels: symbol id of each word's incoming arc, decoded with `sym_dict`.
        orders: attachment order of each word; dependencies are grouped by (head, order).
        word_ids: id of each word, decoded with `word_dict`.
        word_dict: maps word ids to their surface strings.
        sym_dict: maps symbol ids to symbol names. Defaults to the notebook-global `inverse_sym_dict`.

    Returns:
        DependencyTree whose `modifiers[head][order]` lists the dependencies attached to `head` with that order.
    """
    if sym_dict is None:
        sym_dict = inverse_sym_dict  # notebook global loaded from required_vars.pkl

    modifiers: dict[int, dict[int, list[Dependency]]] = defaultdict(lambda: defaultdict(list))
    terminals: dict[int, Terminal] = {}

    for child, (head, label, order, word) in enumerate(zip(heads, labels, orders, word_ids), start=1):
        # int() accepts both single-element tensors and plain Python ints
        head, label, order, word = int(head), int(label), int(order), int(word)

        terminals[child] = Terminal(idx=child, word=word_dict[word])

        # child with head of 0 is the root node of the d-tree; don't add it to modifiers dict
        if head == 0:
            continue

        modifiers[head][order].append(Dependency(head=head, modifier=child, sym=sym_dict[label]))

    return DependencyTree(modifiers=modifiers, terminals=terminals, num_words=len(heads))

def create_c_tree(heads: torch.Tensor, labels: torch.Tensor, orders: torch.Tensor, word_ids: torch.Tensor, word_dict: dict[int, str]):
    """Build a constituent tree (C-tree): build the D-tree with `create_d_tree`, then convert it via
    `ConstituentTree.from_d_tree`. Arguments are forwarded unchanged to `create_d_tree`."""
    dependency_tree = create_d_tree(
        heads=heads,
        labels=labels,
        orders=orders,
        word_ids=word_ids,
        word_dict=word_dict,
    )
    return ConstituentTree.from_d_tree(dependency_tree)

In [20]:
# Vocabulary plus this split's new words; the new-word ids are negated here — presumably that is how they
# appear in `x` (TODO confirm against the dataloader).
test_word_dict = {**inverse_word_dict, **{-i: w for i, w in test_new_words.items()}}

c_tree = create_c_tree(heads=ex, labels=lab, orders=att, word_ids=words, word_dict=test_word_dict)        # predicted
t_c_tree = create_c_tree(heads=t_ex, labels=t_lab, orders=t_att, word_ids=words, word_dict=test_word_dict)  # gold

In [21]:
def get_bracket(c_tree: ConstituentTree, node: int|None=None, ignore_pre_terminal_sym: bool=False, ignore_non_terminal_sym: bool=False, ignore_all_syms: bool=False):
    """Render the subtree of `c_tree` rooted at `node` as a bracketed string.

    Pre-terminals render as ``(SYM id=word)``. Every other constituent renders as ``(SYM `` followed by
    its children's brackets (in `children` order) and a closing ``)``.

    Args:
        c_tree: tree to render.
        node: key into `c_tree.constituents` (and `c_tree.terminals` for pre-terminals); None means `c_tree.root`.
        ignore_pre_terminal_sym: print ``?`` instead of pre-terminal symbols.
        ignore_non_terminal_sym: print ``?`` instead of non-terminal symbols.
        ignore_all_syms: print ``?`` instead of every symbol.

    Returns:
        str: the bracketed representation.
    """
    flags = dict(
        ignore_pre_terminal_sym=ignore_pre_terminal_sym,
        ignore_non_terminal_sym=ignore_non_terminal_sym,
        ignore_all_syms=ignore_all_syms,
    )
    if node is None:
        return get_bracket(c_tree, node=c_tree.root, **flags)

    hide_pre_terminal = ignore_all_syms or ignore_pre_terminal_sym
    hide_non_terminal = ignore_all_syms or ignore_non_terminal_sym

    constituent = c_tree.constituents[node]
    if constituent.is_pre_terminal:
        sym = "?" if hide_pre_terminal else constituent.sym
        return f"({sym} {constituent.id}={c_tree.terminals[node].word})"

    sym = "?" if hide_non_terminal else constituent.sym
    rendered_children = "".join(get_bracket(c_tree, node=child, **flags) for child in constituent.children)
    return f"({sym} {rendered_children})"

In [22]:
# Bracketed form of the predicted C-tree.
get_bracket(c_tree=c_tree, ignore_all_syms=False)

'(VROOT (S (? 11=bestehe)(PP (? 1=Bei)(? 2=Großversuchen)(PP (? 5=bei)(? 4=wie)(? 6=der)(? 7=AOK)(PP (? 8=in)(? 9=Sachsen))))(NP (? 13=Gefahr)(? 12=die)(S (? 24=werde)(? 15=daß)(NP (? 18=Druck)(? 16=der)(? 17=soziale)(PP (? 19=auf)(? 20=die)(? 21=Patienten)))(AP (? 23=groß)(AVP (? 22=so)(S (? 33=sei)(? 26=daß)(NP (? 28=Freiwilligkeit)(? 27=die))(VP (? 32=gewährleistet)(? 29=faktisch)(AVP (? 30=nicht)(? 31=mehr)))))))))(? 3=,)(? 10=,)(? 14=,)(? 25=,)(? 34=.))'

In [23]:
# Bracketed form of the gold C-tree.
get_bracket(c_tree=t_c_tree, ignore_all_syms=False)

'(VROOT (S (? 11=bestehe)(PP (? 1=Bei)(? 2=Großversuchen)(PP (? 5=bei)(? 4=wie)(? 6=der)(? 7=AOK)(PP (? 8=in)(? 9=Sachsen))))(NP (? 13=Gefahr)(? 12=die)(S (? 24=werde)(? 15=daß)(NP (? 18=Druck)(? 16=der)(? 17=soziale)(PP (? 19=auf)(? 20=die)(? 21=Patienten)))(AP (? 23=groß)(AVP (? 22=so)(S (? 33=sei)(? 26=daß)(NP (? 28=Freiwilligkeit)(? 27=die))(VP (? 32=gewährleistet)(? 29=faktisch)(AVP (? 30=nicht)(? 31=mehr)))))))))(? 3=,)(? 10=,)(? 14=,)(? 25=,)(? 34=.))'