In [1]:
import importlib

from german_parser.model.words import WordCNN
from german_parser.model.words import WordEmbedding

from torch.utils.data import DataLoader, default_collate

import logging
# Show INFO-level log records (e.g. the dataset parsing progress printed further down).
logging.basicConfig(level=logging.INFO)
# Module-level logger for this notebook's own messages.
logger = logging.getLogger(__name__)

import pickle
import torch

In [2]:
from german_parser.util.const import CONSTS

In [3]:
from german_parser.util.dataloader import TigerDatasetGenerator

In [4]:
# Parse the TIGER 2.2 treebank and split it into train/dev/test trees.
# NOTE(review): (0.2, 0.1) looks like the (dev, test) fractions -- it matches the logged
# 869 / 435 out of 4346 trees -- and prop_of_tiger_to_use=0.1 presumably sub-samples the
# corpus to keep this experiment fast; confirm against TigerDatasetGenerator.
g = TigerDatasetGenerator(
    f"{CONSTS['data_dir']}/tiger/tiger_2.2_utf8.xml",
    (0.2, 0.1),
    prop_of_tiger_to_use=0.1,
)

INFO:model:Parsing dataset from '/home/james/programming/ml/german/german_parser/util/../../data/tiger/tiger_2.2_utf8.xml'...
INFO:model:Parsed 5047 sentences.
INFO:model:Generating trees...
INFO:model:4346 (86.11%) trees generated.
INFO:model:Dataset split into 3042 training, 869 dev, and 435 test trees.


In [5]:
# Training portion of the split created by the generator above.
train_dataset = g.get_training_dataset()

In [6]:
def my_fn(x):
    """Collate a list of samples and reorder the batch by descending sentence length.

    The samples are first stacked with ``default_collate``. The second collated field
    (index 1) is treated as the per-sample sentence length; every field of the batch is
    then permuted so that the longest sentence comes first, which is the order required
    by ``pack_padded_sequence(..., enforce_sorted=True)``.

    Args:
        x: list of samples as yielded by the dataset.

    Returns:
        list of tensors, one per sample field, all reordered with the same permutation.
    """
    collated = default_collate(x)
    lengths: torch.Tensor = collated[1]
    order = torch.argsort(lengths, descending=True)

    reordered = []
    for field in collated:
        reordered.append(field[order])
    return reordered

In [7]:
# Shuffled mini-batches of 32 sentences; my_fn sorts each batch by descending length.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=my_fn,
)

In [8]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F


class LSTM(nn.Module):
    """Thin wrapper around ``torch.nn.LSTM`` that is always built with ``batch_first=True``."""

    def __init__(self, input_size, hidden_size, bidirectional=True, num_layers=1, dropout=0.2):
        """Build the wrapped ``nn.LSTM``.

        Args:
            input_size (int): size of each input feature vector.
            hidden_size (int): size of the hidden/cell state per direction.
            bidirectional (bool, optional): run the LSTM in both directions. Defaults to True.
            num_layers (int, optional): number of stacked LSTM layers. Defaults to 1.
            dropout (float, optional): dropout rate between stacked layers. Defaults to 0.2.
                It is forced to 0 when ``num_layers`` is 1, because there is no
                non-final layer for it to act on.
        """
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Dropout only makes sense between layers, so switch it off for a single layer.
        if num_layers > 1:
            self.dropout = dropout
        else:
            self.dropout = 0

        self.lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
            bidirectional=self.bidirectional,
            batch_first=True,
        )

        # Frozen one-element parameter; not used by forward().
        self.dummy_param = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, *args):
        """Run the wrapped LSTM.

        Args:
            *args: passed unchanged to ``nn.LSTM`` (the input, optionally followed by the
                ``(h_0, c_0)`` initial state).

        Returns:
            tuple: ``(out, (h, c))``. For a padded input ``out`` has size
            (B, T, D * hidden_size) with D = 2 if bidirectional else 1; ``h`` and ``c``
            have size (num_layers * D, B, hidden_size).
        """
        output, final_state = self.lstm(*args)
        hidden, cell = final_state
        return output, (hidden, cell)

In [27]:
from typing import Literal
from collections.abc import Callable
from pydantic import BaseModel, Field

class BiAffine(nn.Module):
    """Biaffine scorer between every (decoder position, encoder position) pair.

    For class ``c``, decoder vector ``d`` and encoder vector ``e`` the raw score is
    ``d @ Z[c] @ e + U_enc[c] @ e + U_dec[c] @ d + b[c]``.
    With ``include_attention=True`` the per-class scores are squashed with ``tanh`` and
    combined by the learned vector ``w`` into a single score per pair.
    """

    def __init__(self, num_classes: int, enc_input_size: int, dec_input_size: int, include_attention=False, check_accuracy: bool=False):
        """Create and initialise the biaffine parameters.

        Args:
            num_classes (int): number of score channels produced per (decoder, encoder) pair.
            enc_input_size (int): feature size of the encoder vectors.
            dec_input_size (int): feature size of the decoder vectors.
            include_attention (bool, optional): if True, reduce the class dimension to a single
                score via ``w @ tanh(scores)``. Defaults to False.
            check_accuracy (bool, optional): debugging aid; if True, ``forward`` recomputes every
                score with explicit loops and prints the relative error. Very slow. Defaults to False.
        """
        super().__init__()

        self.num_classes = num_classes
        self.enc_input_size = enc_input_size
        self.dec_input_size = dec_input_size

        self.Z = nn.Parameter(torch.zeros(self.num_classes, self.dec_input_size, self.enc_input_size), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
        self.U_enc = nn.Parameter(torch.zeros(self.num_classes, self.enc_input_size), requires_grad=True)
        self.U_dec = nn.Parameter(torch.zeros(self.num_classes, self.dec_input_size), requires_grad=True)

        self.include_attention = include_attention
        self.w = None
        if self.include_attention:
            self.w = nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)

        self.check_accuracy = check_accuracy
        self._reset_parameters()

    def forward(self, enc: torch.Tensor, dec: torch.Tensor):
        """Calculate the biaffine score between every decoder and encoder position.

        Args:
            enc (torch.Tensor): tensor of size (B, T + 1, enc_input_size); the first item denotes ROOT
            dec (torch.Tensor): tensor of size (B, T, dec_input_size)

        Returns:
            torch.Tensor: scores indexed by [batch_number, DECoder_index, ENCoder_index].
            Size (B, T, T + 1, num_classes), or (B, T, T + 1) when ``include_attention`` is True.
        """

        dec_brd = dec[:, :, None, None, :, None]         # (B, T, 1,     1,           dec_input_size, 1)
        Z = self.Z[None, None, None, :, :, :]            # (1, 1, 1,     num_classes, dec_input_size, enc_input_size)
        enc_brd = enc[:, None, :, None, :, None]         # (B, 1, T + 1, 1,           enc_input_size, 1)

        # The matrix product leaves two trailing singleton dims; index them away instead of
        # calling squeeze with a tuple of dims, which older PyTorch releases do not accept.
        interaction_score = (dec_brd.transpose(-1, -2) @ Z @ enc_brd)[..., 0, 0] # (B, T, T + 1, num_classes) index via [batch_number, DECoder_index, ENCoder_index]

        dec_brd = dec_brd.squeeze(3)                     # (B, T, 1,     dec_input_size, 1)
        enc_brd = enc_brd.squeeze(3)                     # (B, 1, T + 1, enc_input_size, 1)

        enc_score = (self.U_enc @ enc_brd).squeeze(-1) # (B, 1, T + 1, num_classes)
        dec_score = (self.U_dec @ dec_brd).squeeze(-1) # (B, T, 1, num_classes)

        bias = self.b[None, None, None, :]             # (1, 1, 1, num_classes)

        res = interaction_score + enc_score + dec_score + bias

        # check correctness (debug only): compare each broadcast result against the
        # straightforward per-element formula and print the relative error
        if self.check_accuracy:
            n_batches = enc.shape[0]
            seq_length = dec.shape[1]
            assert enc.shape[1] == seq_length + 1, "Encoder output must have one more item than decoder, as the first item denotes ROOT"

            for batch_num in range(n_batches):
                for c in range(self.num_classes):
                    for i in range(seq_length + 1): # encoder index
                        for j in range(seq_length): # decoder index
                            res_val = res[batch_num, j, i, c]

                            true_val = dec[batch_num, j] @ self.Z[c] @ enc[batch_num, i] + self.U_enc[c] @ enc[batch_num, i] + self.U_dec[c] @ dec[batch_num, j] + self.b[c]

                            print((res_val.item() - true_val.item()) / true_val.item())

        if self.include_attention:
            # (num_classes,) @ (B, T, num_classes, T + 1) -> one scalar score per pair
            res = self.w @ res.tanh().transpose(-1, -2) # (B, T, T + 1)

        return res


    def _reset_parameters(self):
        """Initialise every parameter uniformly in ``[-bound, bound]`` with ``bound = 1 / sqrt(fan)``.

        The bound must shrink as the input width grows (the convention ``nn.Linear`` uses);
        a bound of ``sqrt(fan)`` would make the initial scores grow with the layer width and
        saturate the ``tanh`` used by the attention variant.
        """
        with torch.no_grad():
            # geometric mean of the encoder and decoder bounds
            Zb_bound = 1 / ((self.enc_input_size ** 0.5) * (self.dec_input_size ** 0.5)) ** 0.5
            self.Z.uniform_(-Zb_bound, Zb_bound)
            self.b.uniform_(-Zb_bound, Zb_bound)

            U_enc_bound = 1 / self.enc_input_size ** 0.5
            self.U_enc.uniform_(-U_enc_bound, U_enc_bound)
            U_dec_bound = 1 / self.dec_input_size ** 0.5
            self.U_dec.uniform_(-U_dec_bound, U_dec_bound)

            if self.include_attention:
                w_bound = 1 / self.num_classes ** 0.5
                self.w.uniform_(-w_bound, w_bound)

class TigerModel(nn.Module):
    """Encoder-decoder model that scores head attachments and attachment labels for a sentence.

    Pipeline (see ``forward``): words are embedded with ``WordEmbedding``, fed through a
    bidirectional encoder LSTM, the encoder output is fed through a unidirectional decoder
    LSTM, and two ``BiAffine`` layers score every (decoder position, encoder position) pair:
    one produces attention scores over candidate heads, the other produces label logits.
    """

    class WordEmbeddingParams(BaseModel):
        """Settings passed one-to-one to ``WordEmbedding`` by ``TigerModel.__init__``."""
        char_set: dict[str, int]
        char_flag_generators: list[Callable[[str], Literal[1, 0]]]
        char_internal_embedding_dim: int
        # char_part_embedding_dim + word_part_embedding_dim is the encoder LSTM input size
        char_part_embedding_dim: int
        word_part_embedding_dim: int
        char_internal_window_size: int
        word_dict: dict[int, str]

    class LSTMParams(BaseModel):
        """Hyper-parameters for one ``LSTM`` wrapper; the input size is derived by ``TigerModel``."""
        hidden_size: int
        bidirectional: bool = Field(default=False)
        num_layers: int = Field(default=1)
        dropout: float = Field(default=0.2)

    def __init__(self, word_embedding_params: WordEmbeddingParams, enc_lstm_params: LSTMParams, dec_lstm_params: LSTMParams, num_biaffine_attention_classes=2, num_constituent_labels=10):
        """Build the embedding, encoder, decoder and biaffine layers.

        Args:
            word_embedding_params (WordEmbeddingParams): settings for the word embedding layer.
            enc_lstm_params (LSTMParams): encoder settings; must have ``bidirectional=True``.
            dec_lstm_params (LSTMParams): decoder settings; must have ``bidirectional=False``.
            num_biaffine_attention_classes (int, optional): number of channels of the attention
                biaffine layer before they are reduced to one score. Defaults to 2.
            num_constituent_labels (int, optional): number of classes of the label classifier. Defaults to 10.

        Raises:
            AssertionError: if the encoder is not bidirectional or the decoder is bidirectional.
        """
        super().__init__()
        self.dummy_param = nn.Parameter(torch.zeros(1), requires_grad=False) # to get self device
        
        # create word embeddor
        self.word_embedding_params = word_embedding_params
        self.word_embedding = WordEmbedding(
            char_set=self.word_embedding_params.char_set,
            char_flag_generators=self.word_embedding_params.char_flag_generators,
            char_internal_embedding_dim=self.word_embedding_params.char_internal_embedding_dim,
            char_part_embedding_dim=self.word_embedding_params.char_part_embedding_dim,
            word_part_embedding_dim=self.word_embedding_params.word_part_embedding_dim,
            char_internal_window_size=self.word_embedding_params.char_internal_window_size,
            word_dict=self.word_embedding_params.word_dict
        )

        # define encoder; its input is the concatenated char-part and word-part embedding
        self.enc_lstm_params = enc_lstm_params
        assert enc_lstm_params.bidirectional == True, "Encoder must be bidirectional"
        self.enc_lstm = LSTM(
            input_size=word_embedding_params.char_part_embedding_dim + word_embedding_params.word_part_embedding_dim,
            hidden_size=self.enc_lstm_params.hidden_size,
            num_layers=self.enc_lstm_params.num_layers,
            bidirectional=self.enc_lstm_params.bidirectional,
            dropout=self.enc_lstm_params.dropout
        )

        # define decoder; it consumes the encoder output, hence input size 2 * enc_hidden_size
        self.dec_lstm_params = dec_lstm_params
        assert self.dec_lstm_params.bidirectional == False, "Decoder must not be bidirectional"
        self.dec_lstm = LSTM(
            input_size=2 * self.enc_lstm_params.hidden_size,
            hidden_size=self.dec_lstm_params.hidden_size,
            num_layers=self.dec_lstm_params.num_layers,
            bidirectional=self.dec_lstm_params.bidirectional,
            dropout=self.dec_lstm_params.dropout
        )

        # define learnable initial encoder cell state, shared by all sentences
        # (size (2 * enc_num_layers, 1, enc_hidden_size); repeated along the batch dim in forward)
        self.enc_init_state = nn.Parameter(
            torch.zeros(2 * self.enc_lstm_params.num_layers, 1, self.enc_lstm_params.hidden_size),
            requires_grad=True
        )

        # define dense layer to convert encoder final cell state into decoder initial cell state
        self.enc_final_cell_to_dec_init_cell = nn.Linear(
            2 * self.enc_lstm_params.hidden_size,
            self.dec_lstm_params.hidden_size
        )

        # define biaffine layer for attention
        self.biaffine_attention = BiAffine(
            num_classes=num_biaffine_attention_classes,
            enc_input_size=2 * self.enc_lstm_params.hidden_size,
            dec_input_size=self.dec_lstm_params.hidden_size,
            include_attention=True
        )

        # define biaffine layer for classification of constituent labels
        self.biaffine_constituent_classifier = BiAffine(
            num_classes=num_constituent_labels,
            enc_input_size=2 * self.enc_lstm_params.hidden_size,
            dec_input_size=self.dec_lstm_params.hidden_size,
            include_attention=False
        )

    def _get_final_concatenated_enc_hidden_state(self, c: torch.Tensor):
        """Concatenate the last two slices of an encoder state into one vector per sentence.

        The state has size (enc_num_layers * 2, B, enc_hidden_size); the last two slices along
        dim 0 are taken. NOTE(review): for a bidirectional ``nn.LSTM`` these should be the two
        directions of the top layer rather than two different layers -- confirm against the
        PyTorch state layout.

        Args:
            c (torch.Tensor): cell state or hidden state, of size (enc_num_layers * 2, B, enc_hidden_size)

        Returns:
            torch.Tensor: the last two slices concatenated along the feature dim. size (B, 1, 2 * enc_hidden_size)
        """
        _, B, _ = c.shape
        res = c[-2:] # take the last two slices (2, B, enc_hidden_size)
        res = res.transpose(0, 1).contiguous() # (B, 2, enc_hidden_size); contiguous() is required before view
        res = res.view(B, 1, 2 * self.enc_lstm_params.hidden_size) # (B, 1, 2 * enc_hidden_size)
        return res

    def _get_decoder_init_state(self, encoder_final_hc: tuple[torch.Tensor, torch.Tensor]):
        """Convert the final encoder state into the initial (hidden, cell) state of the decoder.

        Only the encoder's final *cell* state is used: it is projected by a linear layer to give
        the cell state of the first decoder layer. Any further decoder layers start from a zero
        cell state. The initial hidden state is ``tanh`` of the initial cell state.

        Args:
            encoder_final_hc (tuple[torch.Tensor, torch.Tensor]): tuple of tensors, each with size (enc_num_layers * 2, B, enc_hidden_size)

        Returns:
            tuple[torch.Tensor, torch.Tensor]: tuple of initial decoder state tensors, each with size (dec_num_layers, B, dec_hidden_size). First is initial hidden state, second is initial cell state for the decoder
        """
        h, c = encoder_final_hc
        _, B, _ = h.shape

        c = self._get_final_concatenated_enc_hidden_state(c)
        c = c.transpose(0, 1) # (1, B, 2 * enc_hidden_size)
        
        c_dec: torch.Tensor = self.enc_final_cell_to_dec_init_cell(c) # (1, B, dec_hidden_size)
        if self.dec_lstm_params.num_layers > 1:
            # pad with zeros for the remaining decoder layers -> (dec_num_layers, B, dec_hidden_size)
            c_dec = torch.cat([c_dec, c_dec.new_zeros((self.dec_lstm_params.num_layers - 1, B, self.dec_lstm.hidden_size))], dim=0)
        
        h_dec = c_dec.tanh()

        return (h_dec, c_dec)
        

    def forward(self, input: tuple[torch.Tensor, torch.Tensor], new_words_dict: dict[int, str] | None):
        """Score head attachments and attachment labels for a batch of sentences.

        Args:
            input (tuple[torch.Tensor, torch.Tensor]): tuple of (data, sentence_lengths), where data is a tensor of size (B, T) and sentence_lengths is a tensor of size (B,). B is batch size, T is max(sentence_length) across all batches. The input must be sorted in descending order of sentence length
            new_words_dict (dict[int, str] | None): dictionary of new words. positive indices in new_words_dict correspond to negative indices in input[0] (data). If None, then all unknown words must be coded as 0

        Returns:
            tuple: ``(self_attention, constituent_labels, indices)`` where, with T now being the
            longest sentence within the batch,
            ``self_attention`` has size (B, T, T + 1) and holds head scores per dependent,
            ``constituent_labels`` has size (B, T, T + 1, num_constituent_labels) and holds label logits,
            ``indices`` is a boolean mask of size (B, T) that is True at real (non-padding) word positions.
            Padding positions of the first two tensors are set to -inf.
        """

        # transfer to current device. avoid making a copy if possible
        x, lengths = input     
        x = torch.as_tensor(x, device=self.dummy_param.device)

        B = len(lengths)

        # create packed embedding sequences
        x_embedded = self.word_embedding(x, new_words_dict) # (B, T, E) where B is batch_size, T is max(sentence_length), E is embedding dimension (char_part_embedding_dim + word_part_embedding_dim)
        x_embedded_packed = rnn_utils.pack_padded_sequence(x_embedded, lengths, batch_first=True, enforce_sorted=True)

        # henceforth, T refers to max(sentence_length) within the batch, rather than across all batches

        # define initial encoder state: the learned cell state is shared by every sentence in the batch
        c_init = self.enc_init_state.repeat(1, B, 1) # (enc_num_layers * 2, B, enc_hidden_size)
        h_init = c_init.tanh()

        # feed through encoder
        enc_out, enc_final_state = self.enc_lstm(x_embedded_packed, (h_init, c_init)) # enc_out is a packed sequence; once padded it has size (B, T, 2 * enc_hidden_size)  
        enc_out_pad, _ = rnn_utils.pad_packed_sequence(enc_out, batch_first=True)

        # feed through decoder (it consumes the still-packed encoder output)
        dec_init_state = self._get_decoder_init_state(enc_final_state) # tuple of tensors, each with size (dec_num_layers, B, dec_hidden_size)
        dec_out, _ = self.dec_lstm(enc_out, (dec_init_state[0], dec_init_state[1])) # packed sequence; once padded it has size (B, T, hidden_size)

        # TODO: apply dropout to enc_out

        # prepend one extra position (index 0, used as root) to the padded encoder output:
        # concatenate the last two slices of the initial encoder hidden state with the output of the encoder
        h_init_res = self._get_final_concatenated_enc_hidden_state(h_init) # (B, 1, 2 * enc_hidden_size)
        enc_out_pad = torch.cat((h_init_res, enc_out_pad), dim=1) # (B, T + 1, 2 * enc_hidden_size)

        # unpad decoder output (B, T, hidden_size)
        dec_out_pad, _ = rnn_utils.pad_packed_sequence(dec_out, batch_first=True)

        # henceforth, indices are 0-indexed in the comments. Effectively, head indices are 1-indexed (0 indicates root), and dependency indices are 0-indexed
        # TASK 1: predict HEAD words
        # for batch b and word index j, argmax(self_attention[b, j]) gives a pointer i to HEAD of word j
        self_attention = self.biaffine_attention(enc_out_pad, dec_out_pad) # size (B, T, T + 1). index by (batch_num, decoder_index, encoder_index + 1), which represents (batch_num, dependency_index, head_index + 1)

        # TASK 2: predict ATTACHMENT labels
        # for batch b and dependency index j and head index i, constituent_labels[b, j, i] gives logits to classify the label of the dependency from word j to HEAD word i
        constituent_labels = self.biaffine_constituent_classifier(enc_out_pad, dec_out_pad) # size (B, T, T + 1, num_constituent_labels). index by (batch_num, decoder_index, encoder_index + 1, label_index), which represents (batch_num, dependency_index, head_index + 1, label_index)

        # TASK 3: predict attachment ORDER (not implemented yet)

        # overwrite scores that involve padding positions with -inf (in place)
        self._mask_out(self_attention, lengths)
        self._mask_out(constituent_labels, lengths)

        # TODO: TASK 4: predict DEPENDENCY labels (according to GM 2022, this will improve overall performance in a multitask setting)

        # boolean (B, T) mask selecting the real word positions of each sentence
        indices = self._get_batch_indices(lengths)

        return self_attention, constituent_labels, indices

    def _mask_out(self, out: torch.Tensor, lengths: torch.Tensor):
        """mask out unneeded output elements IN PLACE (set to -inf), given sentence lengths

        Two sets of entries are overwritten for every sentence b:
        dependency (dim 1) positions >= lengths[b], and head (dim 2) positions > lengths[b]
        (head position 0 is the extra root position, so lengths[b] + 1 head positions are kept).

        Args:
            out (torch.Tensor): size (B, T, T + 1), optionally followed by further dimensions (e.g. a label dimension), which are masked as a whole
            lengths (torch.Tensor): sorted tensor in descending order where len(lengths) = B and lengths[0] = T
        """
        B, T, *_ = out.shape
        # row r of the upper-triangular matrix is True at columns >= r, so row lengths[b] flags the padded dependency positions of sentence b
        dependency_index_mask = torch.triu(torch.full((T + 1, T), True))[lengths].unsqueeze(-1).repeat(1, 1, T + 1) # (B, T, T + 1)
        out[dependency_index_mask] = -torch.inf

        # same trick shifted by one for the head dimension, which has the extra root position at index 0
        head_index_mask = torch.triu(torch.full((T + 2, T + 1), True))[lengths + 1].unsqueeze(-2).repeat(1, T, 1) # (B, T, T + 1)
        out[head_index_mask] = -torch.inf

    def _get_batch_indices(self, lengths: torch.Tensor):
        """Get a boolean mask of the non-padding positions for a given set of sentence lengths.
           Suppose x is a tensor of size (B, T, *), which has been masked out
           Some elements of x are not needed, as they correspond to padding. This function returns a mask selecting the elements that are needed
           x[indices] will give you a tensor of size (N, *), where N = sum(lengths)

        Args:
            lengths (torch.Tensor): sorted tensor in descending order where len(lengths) = B and lengths[0] = T, the longest sentence _within the batch_

        Returns:
            torch.Tensor: boolean mask of size (B, T); entry [b, t] is True iff t < lengths[b]
        """

        T = lengths[0].item() # max sentence length within batch
        # row lengths[b] of the upper-triangular matrix flags padding (columns >= lengths[b]); invert it to keep real words
        indices = ~torch.triu(torch.full((T + 1, T), True))[lengths] # type: ignore
        return indices
        

In [28]:
# Small model instance: 10-dimensional embedding parts and 10 hidden units per LSTM.
m = TigerModel(
    TigerModel.WordEmbeddingParams(
        char_set=g.character_set,
        char_flag_generators=g.character_flag_generators,
        char_internal_embedding_dim=10,
        char_part_embedding_dim=10,
        word_part_embedding_dim=10,
        char_internal_window_size=3,
        word_dict=g.inverse_word_dict,
    ),
    TigerModel.LSTMParams(hidden_size=10, bidirectional=True),   # encoder
    TigerModel.LSTMParams(hidden_size=10, bidirectional=False),  # decoder
)

In [99]:
# Pull one (shuffled, length-sorted) batch from the training loader.
# NOTE(review): the name `input` shadows the Python builtin for the rest of the notebook.
input = next(iter(train_dataloader))

In [100]:
# Forward pass: input[0] is the word-id tensor and input[1] the sentence lengths.
# Returns head scores, label logits and the boolean mask of real word positions.
self_attention, labels, indices = m((input[0], input[1]), train_dataset.get_new_words_dict())

In [101]:
# Longest sentence in the batch (lengths are sorted in descending order by my_fn).
# T is a 0-dim tensor, not a Python int; it is used as a slice bound below.
T = input[1][0]

In [102]:
# Inspect the sentence lengths of the batch (expected to be in descending order).
input[1]

tensor([40, 34, 34, 33, 29, 29, 27, 25, 23, 23, 22, 22, 21, 20, 19, 18, 17, 16,
        16, 15, 15, 13, 12, 11, 11, 10, 10,  9,  8,  6,  3,  2])

In [103]:
# Cross-entropy over the head scores of every real word, then back-propagate.
# NOTE(review): input[2] is presumably the gold head index per word, truncated here to the
# batch's longest sentence -- confirm against the dataset definition.
F.cross_entropy(
    self_attention[indices],
    input[2][:, 0:T][indices],
).backward()