In [1]:
import logging
# Configure the root logger with level INFO.
logging.basicConfig(level=logging.INFO)
# Logger for this module, obtained by the module's own name.
logger = logging.getLogger(__name__)
import torch
import torch.nn.functional as F

from german_parser.util.const import CONSTS
from german_parser.util.dataloader import TigerDatasetGenerator
import dill as pickle
from string import punctuation
from german_parser.model import TigerModel
from german_parser.util import BatchUnionFind

In [2]:
# Restore the seven pre-computed objects stored in "required_vars.pkl":
# three (dataloader, new-words) pairs for train/dev/test, followed by the
# character set, the character flag generators and the two inverse dictionaries.
# `pickle` is `dill` here (see the imports cell).
# SECURITY: unpickling executes arbitrary code -- only load files you created yourself.
with open("required_vars.pkl", "rb") as required_vars_file:  # context manager closes the handle
    (
        (train_dataloader, train_new_words),
        (dev_dataloader, dev_new_words),
        (test_dataloader, test_new_words),
        character_set,
        character_flag_generators,
        inverse_word_dict,
        inverse_sym_dict,
    ) = pickle.load(required_vars_file)

In [3]:
# Instantiate the parser. Vocabulary-dependent sizes are derived from the
# objects loaded from the pickle; everything else is a fixed hyper-parameter.
model = TigerModel(
    word_embedding_params=TigerModel.WordEmbeddingParams(
        char_set=character_set,
        char_flag_generators=character_flag_generators,
        char_internal_embedding_dim=100,
        char_part_embedding_dim=100,
        word_part_embedding_dim=100,
        char_internal_window_size=3,
        word_dict=inverse_word_dict,
    ),
    # Encoder: 3-layer bidirectional LSTM.
    enc_lstm_params=TigerModel.LSTMParams(hidden_size=512, bidirectional=True, num_layers=3),
    # Decoder: single-layer unidirectional LSTM.
    dec_lstm_params=TigerModel.LSTMParams(hidden_size=512, bidirectional=False, num_layers=1),
    enc_attention_mlp_dim=512,
    dec_attention_mlp_dim=512,
    enc_label_mlp_dim=128,
    dec_label_mlp_dim=128,
    num_biaffine_attention_classes=2,
    # One constituent label per entry of the inverse symbol dictionary.
    num_constituent_labels=len(inverse_sym_dict),
    enc_attachment_mlp_dim=128,
    dec_attachment_mlp_dim=64,
    # One more than the largest attachment order present in the training dataset.
    max_attachment_order=train_dataloader.dataset.attachment_orders.max() + 1,
)
# model = model.to(device="cuda") # type: ignore

In [4]:
# Pull only the first batch from the training dataloader; keep its first two
# fields as `x` and `l` and collect any remaining fields in `_`.
x, l, *_ = next(iter(train_dataloader))

In [10]:
def find_tree(model: TigerModel, input: tuple[torch.Tensor, torch.Tensor], new_words_dict: dict[int, str] | None):
    """Work-in-progress beam-search decoder over the model's arc scores.

    NOTE: indices are 1-indexed. index 0 corresponds to virtual root of D-tree (which is different from virtual root of C-tree)

    The search is unfinished: for every position the candidate scores are
    computed, but no beam pruning or cycle check is performed yet, so the
    function currently returns ``None``.

    Args:
        model (TigerModel): Parser whose ``forward`` produces the arc scores.
            ``model.beam_size`` is used as the beam width K.
        input (tuple[torch.Tensor, torch.Tensor]): ``(x, lengths)`` pair, passed
            unchanged to ``model.forward``.
        new_words_dict (dict[int, str] | None): Passed unchanged to ``model.forward``.

    Returns:
        None: decoding is not implemented yet.
    """
    # Not used yet. TODO(review): `lengths` is presumably needed to mask padded positions -- confirm.
    x, lengths = input
    self_attention, constituent_labels, attachment_orders, indices = model.forward(input, new_words_dict)
    # self_attention has size (B, T, T + 1)

    B, T, *_ = self_attention.shape
    K = model.beam_size
    # Allocate all search state next to the scores it is combined with, so the
    # function also works once the model has been moved off the CPU.
    device = self_attention.device

    uf = BatchUnionFind(B, K, N=T + 1, device=model.dummy_param.device)

    # Running score of every beam entry: a sum of log-softmax values.
    joint_logits = torch.zeros(B, K, device=device) # (B, K)

    for t in range(T):
        # Log-probabilities over the T + 1 candidate heads for position t.
        arc_probs = self_attention[:, t, :].log_softmax(dim=-1) # (B, T + 1)

        # Score of extending each beam entry with each candidate head.
        candidate_joint_logits = joint_logits[:, :, None] + arc_probs[:, None, :] # (B, K, T + 1)

        # Index tensor filled with the current position t for every candidate.
        child_idx = torch.full((B, K, T + 1), t, dtype=torch.long, device=device)

        # TODO: rank candidates (e.g. arc_probs.argsort(dim=-1, descending=True)),
        # pick candidate head indices, prune to K entries and update `uf` and `joint_logits`.

In [11]:
# Smoke-run the decoder on the single training batch fetched earlier.
find_tree(model, input=(x, l), new_words_dict=train_new_words)