In [2]:
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import torch
import torch.nn.functional as F

from german_parser.util.const import CONSTS
from german_parser.util.dataloader import TigerDatasetGenerator
import dill as pickle
from string import punctuation
from german_parser.model import TigerModel
from german_parser.util import BatchUnionFind

In [3]:
(train_dataloader, train_new_words), (dev_dataloader, dev_new_words), (test_dataloader, test_new_words), character_set, character_flag_generators, inverse_word_dict, inverse_sym_dict = pickle.load(open("required_vars.pkl", "rb"))

In [9]:
# model = TigerModel(
#     word_embedding_params=TigerModel.WordEmbeddingParams(char_set=character_set, char_flag_generators=character_flag_generators, char_internal_embedding_dim=100,
#                                    char_part_embedding_dim=100, 
#                                    word_part_embedding_dim=100, 
#                                    char_internal_window_size=3,
#                                    word_dict=inverse_word_dict),
#     enc_lstm_params=TigerModel.LSTMParams(
#         hidden_size=512,
#         bidirectional=True,
#         num_layers=3),
#     dec_lstm_params=TigerModel.LSTMParams(
#         hidden_size=512,
#         bidirectional=False,
#         num_layers=1
#         ),
#         enc_attention_mlp_dim=512,
#         dec_attention_mlp_dim=512,
#         enc_label_mlp_dim=128,
#         dec_label_mlp_dim=128,
#         num_biaffine_attention_classes=2,
#         num_constituent_labels=len(inverse_sym_dict),
#         enc_attachment_mlp_dim=128,
#         dec_attachment_mlp_dim=64,
#         max_attachment_order=train_dataloader.dataset.attachment_orders.max() + 1
#     )
# model = model.to(device="cuda") # type: ignore

model: TigerModel = pickle.load(open("./models/tiger_model_2023_09_04-02_35_02_PM_epoch_24.pickle", "rb"))

In [78]:
_, (x, l, *_) = next(enumerate(test_dataloader))

In [74]:
def find_tree(model: TigerModel, input: tuple[torch.Tensor, torch.Tensor], new_words_dict: dict[int, str] | None):
    """NOTE: indices are 1-indexed. index 0 corresponds to virtual root of D-tree (which is different from virtual root of C-tree)

    Args:
        model (TigerModel): _description_
        input (tuple[torch.Tensor, torch.Tensor]): _description_
        new_words_dict (dict[int, str] | None): _description_
    """    
    x, lengths = input
    self_attention, constituent_labels, attachment_orders, indices = model.forward(input, new_words_dict)
    # self_attention has size (B, T, T + 1)

    B, T, *_ = self_attention.shape
    uf = BatchUnionFind(B, model.beam_size, N=T + 1, device=model.dummy_param.device)

    joint_logits = torch.zeros(B, model.beam_size, device=model.dummy_param.device) # (B, K)

    edges = torch.zeros(B, model.beam_size, T, dtype=torch.long, device=model.dummy_param.device) # (B, K, T); m[b, k, t - 1] is the 1-indexed parent of 1-indexed node t, in batch b, beam k

    for t in range(T):
        relevant_batches = t < lengths # used for updating the final arcs. otherwise, we get nans in batch b if t >= sentence_length[b]

        arc_probs = self_attention[:, t, :].log_softmax(dim=-1) # (B, T + 1)

        candidate_joint_logits = joint_logits[:, :, None] + arc_probs[:, None, :] # (B, K, T + 1); the heuristic we would like to maximise

        children = torch.tensor(t + 1, device=model.dummy_param.device) # all batches and beams share the same child index (1-indexed)
        parents = torch.arange(0, T + 1, 1, device=model.dummy_param.device, dtype=torch.long).repeat(B, model.beam_size, 1) # (B, K, T + 1), where m[b, k, t] = t to represent the index of each parent 

        no_cycles_mask = uf.is_same_set(children.expand_as(parents), parents) # (B, K, T + 1); m[b, k, s + 1] is true if in batch b, beam k, joining child (t + 1) and parent (s + 1) would lead to a cycle
        no_cycles_mask[:, :, 1:] |= ~indices.unsqueeze(1).to(device=model.dummy_param.device)

        candidate_joint_logits[no_cycles_mask] = -torch.inf # mask cycles out so they can never be a maximiser
        flattened_top_candidate_idx = candidate_joint_logits.flatten(-2, -1).argsort(dim=-1, descending=True)[:, :model.beam_size] # (B, K) in range [0, (K * T + 1)); for each batch, find top 10 best performing parent-beam combinations
        
        top_parents = parents.flatten(-2, -1).gather(index=flattened_top_candidate_idx, dim=-1) # (B, K); for each batch and beam, get the 1-indexed id of the parent we want to attach

        batch_names = torch.arange(model.beam_size, device=model.dummy_param.device, dtype=torch.long).view(1, -1, 1).expand(B, -1, T + 1).flatten(-2, -1) # (B, K, T + 1); m[b, k, s] = k for all b, s, k
        used_batches = batch_names.gather(index=flattened_top_candidate_idx, dim=-1) # (B, K) where each element is in the range [0, K). m[b, k] tells you what the kth new beam should be

        new_data = uf.data.gather(index=used_batches.unsqueeze(-1).expand_as(uf.data), dim=1)
        new_rank = uf.rank.gather(index=used_batches.unsqueeze(-1).expand_as(uf.rank), dim=1)
        new_edges = edges.gather(index=used_batches.unsqueeze(-1).expand_as(edges), dim=1)
        new_joint_logits = candidate_joint_logits.flatten(-2, -1).gather(index=flattened_top_candidate_idx, dim=-1)

        uf.data[relevant_batches] = new_data[relevant_batches]
        uf.rank[relevant_batches] = new_rank[relevant_batches]
        edges[relevant_batches] = new_edges[relevant_batches]
        joint_logits[relevant_batches] = new_joint_logits[relevant_batches]

        uf.union(children.expand_as(top_parents), top_parents)

        edges[:, :, t] = top_parents

        pass

    return edges, joint_logits

In [79]:
edges, joint_logits = find_tree(model, (x, l), test_new_words)

  result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,


In [82]:
edges[31, :]

tensor([[0, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 3, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [0, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [0, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 3, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [0, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [0, 3, 1, 0, 0, 0, 0, 0,

In [91]:
for n, (x, l, *_) in enumerate(test_dataloader):
    edges, joint_logits = find_tree(model, (x, l), test_new_words)
    print(f"{n} of {len(test_dataloader)}", end="\r")
    for i in range(len(x)):
        if not joint_logits[i].allclose(joint_logits[i,0]):
            print("YES!")

  result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,


132 of 133

In [None]:
_, (x, l, *_) = next(enumerate(test_dataloader))