In [23]:
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import torch
import torch.nn.functional as F

from german_parser.util.const import CONSTS
from german_parser.util.dataloader import TigerDatasetGenerator
import dill as pickle
from string import punctuation
from german_parser.model import TigerModel
from german_parser.util import BatchUnionFind

In [24]:
(train_dataloader, train_new_words), (dev_dataloader, dev_new_words), (test_dataloader, test_new_words), character_set, character_flag_generators, inverse_word_dict, inverse_sym_dict = pickle.load(open("required_vars.pkl", "rb"))

In [25]:
model: TigerModel = pickle.load(open("./models/tiger_model_2023_09_05-03_04_35_PM_epoch_5.pickle", "rb"))

In [26]:
model.eval()

TigerModel(
  (word_embedding): WordEmbedding(
    (word_cnn): WordCNN(
      (embeddings): Embedding(104, 100, padding_idx=1)
      (conv): Conv1d(105, 100, kernel_size=(3,), stride=(1,))
    )
    (embeddings): Embedding(25019, 100, padding_idx=1)
  )
  (enc_lstm): LSTMSkip(
    (chain): ModuleList(
      (0): LSTM(200, 512, batch_first=True, bidirectional=True)
      (1-2): 2 x LSTM(1024, 512, batch_first=True, bidirectional=True)
    )
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (dec_lstm): LSTMSkip(
    (chain): ModuleList(
      (0): LSTM(1024, 512, batch_first=True)
    )
    (dropout): Dropout(p=0, inplace=False)
  )
  (enc_final_cell_to_dec_init_cell): Linear(in_features=1024, out_features=512, bias=True)
  (enc_attention_mlp): Linear(in_features=1024, out_features=512, bias=True)
  (dec_attention_mlp): Linear(in_features=512, out_features=512, bias=True)
  (enc_label_mlp): Linear(in_features=1024, out_features=128, bias=True)
  (dec_label_mlp): Linear(in_features=512, 

In [27]:
dl_en = enumerate(test_dataloader)

In [28]:
_, (x, l, target_ex, target_lab, target_att) = next(dl_en)

In [29]:
def find_tree(model: TigerModel, input: tuple[torch.Tensor, torch.Tensor], new_words_dict: dict[int, str] | None):
    """NOTE: indices are 1-indexed. index 0 corresponds to virtual root of D-tree (which is different from virtual root of C-tree)

    Args:
        model (TigerModel): _description_
        input (tuple[torch.Tensor, torch.Tensor]): _description_
        new_words_dict (dict[int, str] | None): _description_
    """    
    x, lengths = input
    self_attention, constituent_labels, attachment_orders, indices = model.forward(input, new_words_dict)
    # self_attention has size (B, T, T + 1)

    B, T, *_ = self_attention.shape
    uf = BatchUnionFind(B, model.beam_size, N=T + 1, device=model.dummy_param.device)

    joint_logits = torch.zeros(B, model.beam_size, device=model.dummy_param.device) # (B, K)

    edges = torch.zeros(B, model.beam_size, T, dtype=torch.long, device=model.dummy_param.device) # (B, K, T); m[b, k, t - 1] is the 1-indexed parent of 1-indexed node t, in batch b, beam k

    for t in range(T):
        relevant_batches = t < lengths # used for updating the final arcs. otherwise, we get nans in batch b if t >= sentence_length[b]

        arc_probs = self_attention[:, t, :].log_softmax(dim=-1) # (B, T + 1)

        candidate_joint_logits = joint_logits[:, :, None] + arc_probs[:, None, :] # (B, K, T + 1); the heuristic we would like to maximise

        children = torch.tensor(t + 1, device=model.dummy_param.device) # all batches and beams share the same child index (1-indexed)
        parents = torch.arange(0, T + 1, 1, device=model.dummy_param.device, dtype=torch.long).repeat(B, model.beam_size, 1) # (B, K, T + 1), where m[b, k, t] = t to represent the index of each parent 

        no_cycles_mask = uf.is_same_set(children.expand_as(parents), parents) # (B, K, T + 1); m[b, k, s + 1] is true if in batch b, beam k, joining child (t + 1) and parent (s + 1) would lead to a cycle
        no_cycles_mask[:, :, 1:] |= ~indices.unsqueeze(1).to(device=model.dummy_param.device)

        candidate_joint_logits[no_cycles_mask] = -torch.inf # mask cycles out so they can never be a maximiser
        flattened_top_candidate_idx = candidate_joint_logits.flatten(-2, -1).topk(k=model.beam_size, dim=-1).indices # (B, K) in range [0, (K * T + 1)); for each batch, find top 10 best performing parent-beam combinations
        
        top_parents = parents.flatten(-2, -1).gather(index=flattened_top_candidate_idx, dim=-1) # (B, K); for each batch and beam, get the 1-indexed id of the parent we want to attach

        batch_names = torch.arange(model.beam_size, device=model.dummy_param.device, dtype=torch.long).view(1, -1, 1).expand(B, -1, T + 1).flatten(-2, -1) # (B, K, T + 1); m[b, k, s] = k for all b, s, k
        used_batches = batch_names.gather(index=flattened_top_candidate_idx, dim=-1) # (B, K) where each element is in the range [0, K). m[b, k] tells you what the kth new beam should be

        new_data = uf.data.gather(index=used_batches.unsqueeze(-1).expand_as(uf.data), dim=1)
        new_rank = uf.rank.gather(index=used_batches.unsqueeze(-1).expand_as(uf.rank), dim=1)
        new_edges = edges.gather(index=used_batches.unsqueeze(-1).expand_as(edges), dim=1)
        new_joint_logits = candidate_joint_logits.flatten(-2, -1).gather(index=flattened_top_candidate_idx, dim=-1)

        uf.data[relevant_batches] = new_data[relevant_batches]
        uf.rank[relevant_batches] = new_rank[relevant_batches]
        edges[relevant_batches] = new_edges[relevant_batches]
        joint_logits[relevant_batches] = new_joint_logits[relevant_batches]

        uf.union(children.expand_as(top_parents), top_parents)

        edges[:, :, t] = top_parents

        pass

    num_labels = constituent_labels.shape[-1]
    num_attachment_orders = attachment_orders.shape[-1]

    best_edges = edges[torch.arange(edges.shape[0]), joint_logits.argmax(dim=-1)] # (B, T) containing elements in range [0, T + 1), where m[b, t - 1] denotes the 1-indexed parent of 1-indexed node t

    label_logits_best_edges = constituent_labels.gather(index=best_edges[:, :, None, None].expand(-1, -1, -1, num_labels), dim=2).squeeze(2)
    attachment_logits_best_edges = attachment_orders.gather(index=best_edges[:, :, None, None].expand(-1, -1, -1, num_attachment_orders), dim=2).squeeze(2)

    labels_best_edges = label_logits_best_edges.argmax(-1)
    attachment_orders_best_edges = attachment_logits_best_edges.argmax(-1)

    return best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits)

In [30]:
best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits) = find_tree(model, (x, l), test_new_words)

In [31]:
# for n, (x, l, *_) in enumerate(test_dataloader):
#     best_edges, labels_best_edges, attachment_orders_best_edges, (edges, joint_logits) = find_tree(model, (x, l), test_new_words)
#     print(f"{n} of {len(test_dataloader)}", end="\r")
#     for i in range(len(x)):
#         if not joint_logits[i].allclose(joint_logits[i,0]):
#             print("YES!")

In [32]:
s_num = 0

words = x[s_num, :l[s_num]].to("cpu")

ex = best_edges[s_num, :l[s_num]].to("cpu")
lab = labels_best_edges[s_num, :l[s_num]].to("cpu")
att = attachment_orders_best_edges[s_num, :l[s_num]].to("cpu")

t_ex = target_ex[s_num, :l[s_num]].to("cpu")
t_lab = target_lab[s_num, :l[s_num]].to("cpu")
t_att = target_att[s_num, :l[s_num]].to("cpu")

In [42]:
edges = []

for i, (i_parent, i_lab, i_att) in enumerate(zip(ex, lab, att), 1):
    edges.append(f"{{{i_parent}->{i}, \"{inverse_sym_dict[i_lab.item()]}#{i_att}\"}}")

print(f"TreePlot[{{{', '.join(edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{3->1, "NP#1"}, {3->2, "NP#1"}, {4->3, "S#1"}, {0->4, "DROOT#1"}, {4->5, "S#1"}, {5->6, "CNP#2"}, {8->7, "NP#1"}, {5->8, "CNP#2"}, {10->9, "NP#1"}, {8->10, "NP#1"}, {16->11, "NP#1"}, {4->12, "VROOT#2"}, {16->13, "S#1"}, {13->14, "PP#1"}, {16->15, "NP#1"}, {4->16, "S#1"}, {23->17, "VROOT#2"}, {34->18, "S#1"}, {33->19, "VP#1"}, {19->20, "PP#1"}, {19->21, "PP#1"}, {23->22, "VROOT#2"}, {0->23, "DROOT#1"}, {26->24, "NP#1"}, {26->25, "NP#1"}, {41->26, "S#1"}, {41->27, "S#1"}, {30->28, "PP#1"}, {30->29, "PP#1"}, {41->30, "S#1"}, {30->31, "PP#1"}, {30->32, "PP#1"}, {34->33, "S#1"}, {23->34, "S#1"}, {23->35, "VROOT#2"}, {41->36, "S#1"}, {41->37, "S#1"}, {39->38, "NP#1"}, {41->39, "S#1"}, {41->40, "S#1"}, {34->41, "S#1"}, {4->42, "VROOT#2"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [43]:
target_edges = []

for i, (i_parent, i_lab, i_att) in enumerate(zip(t_ex, t_lab, t_att), 1):
    target_edges.append(f"{{{i_parent}->{i}, \"{inverse_sym_dict[i_lab.item()]}#{i_att}\"}}")

print(f"TreePlot[{{{', '.join(target_edges)}}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]")

TreePlot[{{3->1, "NP#1"}, {3->2, "NP#1"}, {4->3, "S#1"}, {0->4, "DROOT#1"}, {4->5, "S#1"}, {5->6, "CNP#1"}, {8->7, "NP#1"}, {5->8, "CNP#1"}, {10->9, "NP#1"}, {4->10, "S#1"}, {4->11, "S#1"}, {4->12, "VROOT#3"}, {33->13, "VP#1"}, {13->14, "PP#1"}, {16->15, "NP#1"}, {13->16, "PP#1"}, {4->17, "VROOT#3"}, {13->18, "CO#2"}, {18->19, "AP#1"}, {19->20, "PP#1"}, {19->21, "PP#1"}, {4->22, "VROOT#3"}, {4->23, "CS#2"}, {26->24, "NP#1"}, {26->25, "NP#1"}, {23->26, "S#1"}, {33->27, "VP#1"}, {27->28, "PP#1"}, {27->29, "PP#1"}, {33->30, "VP#1"}, {30->31, "PP#1"}, {30->32, "PP#1"}, {34->33, "VP#1"}, {23->34, "S#1"}, {4->35, "VROOT#3"}, {41->36, "S#1"}, {41->37, "S#1"}, {39->38, "NP#1"}, {41->39, "S#1"}, {41->40, "S#1"}, {33->41, "VP#1"}, {4->42, "VROOT#3"}}, Top, 0, VertexLabels -> Automatic, DirectedEdges -> True]


In [35]:
from german_parser.util.c_and_d import Dependency, DependencyTree, ConstituentTree, Terminal
from collections import defaultdict

def create_d_tree(heads: torch.Tensor, labels: torch.Tensor, orders: torch.Tensor, word_ids: torch.Tensor, word_dict: dict[int, str]):
    num_words = len(heads)
    modifiers: dict[int, dict[int, list[Dependency]]] = defaultdict(lambda: defaultdict(lambda x=None: []))
    terminals: dict[int, Terminal] = {}

    for child, (head, label, order, word) in enumerate(zip(heads, labels, orders, word_ids), start=1):
        head = int(head.item())
        label = int(label.item())
        order = int(order.item())
        word = int(word.item())

        terminals[child] = Terminal(
            idx=child,
            word=word_dict[word]
        )

        # child with head of 0 is the root node of the d-tree; don't add it to modifiers dict
        if head == 0:
            continue
        
        modifiers[head][order].append(Dependency(head=head, modifier=child, sym=inverse_sym_dict[label]))

    return DependencyTree(modifiers=modifiers, terminals=terminals, num_words=num_words)

def create_c_tree(heads: torch.Tensor, labels: torch.Tensor, orders: torch.Tensor, word_ids: torch.Tensor, word_dict: dict[int, str]):
    d_tree = create_d_tree(heads, labels, orders, word_ids, word_dict)
    return ConstituentTree.from_d_tree(d_tree)

In [36]:
c_tree = create_c_tree(heads=ex, labels=lab, orders=att, word_ids=words, word_dict={**inverse_word_dict, **{-i: w for i, w in test_new_words.items()}})
t_c_tree = create_c_tree(heads=t_ex, labels=t_lab, orders=t_att, word_ids=words, word_dict={**inverse_word_dict, **{-i: w for i, w in test_new_words.items()}})

AssertionError: 

In [37]:
def get_bracket(c_tree: ConstituentTree, node: int|None=None, ignore_pre_terminal_sym: bool=False, ignore_non_terminal_sym: bool=False, ignore_all_syms: bool=False):
    if node is None:
        return get_bracket(c_tree,
                           node=c_tree.root,
                           ignore_pre_terminal_sym=ignore_pre_terminal_sym,
                           ignore_non_terminal_sym=ignore_non_terminal_sym,
                           ignore_all_syms=ignore_all_syms
                           )
    
    c = c_tree.constituents[node]
    if c.is_pre_terminal:
        return f"({'?' if ignore_pre_terminal_sym or ignore_all_syms else c.sym} {c.id}={c_tree.terminals[node].word})"
    
    res = f"({'?' if ignore_non_terminal_sym or ignore_all_syms else c.sym} "
    for child in c.children:
        res += get_bracket(c_tree,
                           node=child,
                           ignore_pre_terminal_sym=ignore_pre_terminal_sym,
                           ignore_non_terminal_sym=ignore_non_terminal_sym,
                           ignore_all_syms=ignore_all_syms
                           )
        
    res += ")"
    return res

In [40]:
get_bracket(c_tree, ignore_all_syms=False)

'(VZ (ADV (PPOSS (? 4=rÃ¤umten)(? 2=zweijÃ¤hrige)(NP (? 3=Bedenkzeit)(? 1=Eine))(CAC (? 5=Major)(? 6=und)(NP (? 8=Commonwealth-Kollegen)(? 7=seine)))(NP (? 10=Nigerianern)(? 9=den))(APPO (? 13=nach)(? 14=Ablauf)(NP (? 16=Frist)(? 11=ein)(? 15=dieser))))(VVINF (? 18=pÃ¼nktlich)(APPO (? 19=zum)(? 20=nÃ¤chsten)(? 21=Commonwealth-Gipfel)))(PPOSS (? 23=soll)(? 25=afrikanische)(NP (? 26=Staat)(? 24=der))(APPO (? 27=in)(? 28=aller))(? 29=Form)(VVINF (? 34=werden)(VVINF (? 33=ausgeschlossen)(APPO (? 30=aus)(? 31=der)(? 32=Gemeinschaft)(PPOSS (? 41=nachkommt)(? 36=wenn)(? 37=er)(NP (? 39=Demokratisierungs-Forderungen)(? 38=den))(? 40=nicht)))))))(? 12=-)(? 17=,)(? 22=,)(? 35=,)(? 42=.))'

In [41]:
get_bracket(t_c_tree, ignore_all_syms=False)

'(VROOT (CS (S (? 4=rÃ¤umten)(NP (? 3=Bedenkzeit)(? 1=Eine)(? 2=zweijÃ¤hrige))(CNP (? 5=Major)(? 6=und)(NP (? 8=Commonwealth-Kollegen)(? 7=seine)))(NP (? 10=Nigerianern)(? 9=den))(? 11=ein))(S (? 23=soll)(NP (? 26=Staat)(? 24=der)(? 25=afrikanische))(VP (? 34=werden)(VP (? 33=ausgeschlossen)(CO (PP (? 13=nach)(? 14=Ablauf)(NP (? 16=Frist)(? 15=dieser)))(AP (? 18=pÃ¼nktlich)(PP (? 19=zum)(? 20=nÃ¤chsten)(? 21=Commonwealth-Gipfel))))(PP (? 27=in)(? 28=aller)(? 29=Form))(PP (? 30=aus)(? 31=der)(? 32=Gemeinschaft))(S (? 41=nachkommt)(? 36=wenn)(? 37=er)(NP (? 39=Demokratisierungs-Forderungen)(? 38=den))(? 40=nicht))))))(? 12=-)(? 17=,)(? 22=,)(? 35=,)(? 42=.))'

In [18]:
class B(object):
    pass

In [19]:
a = B()

In [21]:
b = B()

In [22]:
a & b

TypeError: unsupported operand type(s) for &: 'B' and 'B'