In [1]:
# todo create pointing network!

In [15]:
import torch

In [16]:
import transformers

In [42]:
from transformers.models.bert.modeling_bert import BertLayer, BertConfig, BertAttention
from transformers import BertForTokenClassification

In [18]:
pointer_config = BertConfig(num_attention_heads=1, hidden_size=768, num_hidden_layers=2, intermediate_size=768)

In [19]:
# pointer network
tag_embedding = torch.nn.Embedding(12,5)
pos_embedding = torch.nn.Embedding(100,5)

# Linear SWISH GELU
linear = torch.nn.Linear(5,5)
swish = torch.nn.SiLU()
gelu = torch.nn.GELU()

# 2x encoder
encoder = BertLayer(pointer_config)

# Attention Layer! (single head)
attention = BertAttention(pointer_config)


In [23]:
import pandas as pd
from transformers import AutoTokenizer
import datasets

In [191]:
data_train = pd.read_csv("train_with_pointing.csv")
data_train = datasets.Dataset.from_pandas(data_train)

In [192]:
def create_label_dict(max_mask=3, use_pointing=True):
    label_map = {'PAD': 0, 'SWAP': 1, 'KEEP': 2, 'DELETE': 3}
    # Create Insert 1 MASK to insertion N MASKS.
    for i in range(1, max_mask+1):
        label_map[f'KEEP|{i}'] = len(label_map)
    if not use_pointing:
        label_map[f'DELETE|{i}'] = len(label_map)
    return label_map

In [193]:
label_dict = create_label_dict(25)

In [194]:
from mbee import compute_edits_and_insertions
from insert_convert import InsertionConverter, get_number_of_masks
from trying import PointingConverter

In [195]:
point_converter = PointingConverter({}, False)

In [196]:
def tokenize_function_src_tgt(examples, tokenizer, src="informal", tgt="formal"):
    returned_dict = {f"{src}_{i}": j for i,j in tokenizer(examples[src]).items()}
    returned_dict.update({f"{tgt}_{i}": j for i,j in tokenizer(examples[tgt]).items()})
    return returned_dict

In [197]:
data_train = data_train.map(
    tokenize_function_src_tgt,
    batched=True,
    fn_kwargs={
        "tokenizer": AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
    },
)

Map: 100%|██████████| 1922/1922 [00:00<00:00, 9236.74 examples/s]


In [198]:
data_train

Dataset({
    features: ['informal', 'formal', 'point_indexes', 'label', 'informal_input_ids', 'informal_token_type_ids', 'informal_attention_mask', 'formal_input_ids', 'formal_token_type_ids', 'formal_attention_mask'],
    num_rows: 1922
})

In [199]:
def create_pointer_labels(points, label_map):
    labels = [t.added_phrase for t in points]
    point_indexes = [t.point_index for t in points]
    point_indexes_set = set(point_indexes)
    new_labels = []
    for i, added_phrase in enumerate(labels):
        if i not in point_indexes_set:
            new_labels.append(label_map["DELETE"])
        elif not added_phrase:
            new_labels.append(label_map["KEEP"])
        else:
            new_labels.append(label_map["KEEP|" + str(len(added_phrase.split()))])
    return new_labels

In [203]:
def generate_tokenized(examples, tokenizer, label_dict, point_converter, src="informal", tgt="formal"):
    src_tokenized = tokenizer.tokenize(examples[src], add_special_tokens=True)
    tgt_tokenized = tokenizer.tokenize(examples[tgt], add_special_tokens=True)
    points = point_converter.compute_points(src_tokenized, ' '.join(tgt_tokenized))
    label = create_pointer_labels(points, label_dict)
    point_indexes = [t.point_index for t in points] 
    # change them to torch tensors
    label = label
    point_indexes = point_indexes
    return {f"tag_labels": label, f"point_labels": point_indexes}

In [204]:
tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")

In [205]:
data_train = data_train.map(
    generate_tokenized,
    batched=False,
    fn_kwargs={
        "tokenizer": tokenizer,
        "label_dict": label_dict,
        "point_converter": point_converter
    },
)

Map: 100%|██████████| 1922/1922 [00:02<00:00, 929.97 examples/s] 


In [211]:
data_train

Dataset({
    features: ['informal', 'formal', 'point_indexes', 'label', 'informal_input_ids', 'informal_token_type_ids', 'informal_attention_mask', 'formal_input_ids', 'formal_token_type_ids', 'formal_attention_mask', 'tag_labels', 'point_labels'],
    num_rows: 1922
})

In [218]:
len(data_train['point_labels'][0])

31

In [216]:
len(data_train['formal_input_ids'][0])

29

In [207]:
data_train

Dataset({
    features: ['informal', 'formal', 'point_indexes', 'label', 'informal_input_ids', 'informal_token_type_ids', 'informal_attention_mask', 'formal_input_ids', 'formal_token_type_ids', 'formal_attention_mask', 'tag_labels', 'point_labels'],
    num_rows: 1922
})

In [46]:
bert_koto = BertForTokenClassification.from_pretrained("indolem/indobert-base-uncased", num_labels=12)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at indolem/indobert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [228]:
# collator


class FelixCollator:
    def __init__(self, tokenizer, pad_label=-100):
        self.tokenizer = tokenizer
        self.pad_label = pad_label

    def __call__(self, batch):
        # batch is a list of dicts
        output_dict = {}
        informal_input_ids, informal_attention_mask = [
            [i[col] for i in batch]
            for col in [
                "informal_input_ids",
                "informal_attention_mask",
            ]
        ]
        formal_input_ids = [i["formal_input_ids"] for i in batch]

        tag_label = [i["tag_labels"] for i in batch]

        tokenized_output = self.tokenizer.pad(
            {
                "input_ids": informal_input_ids,
                "attention_mask": informal_attention_mask,
            },
            return_tensors="pt",
        )
        tokenized_output["token_type_ids"] = torch.zeros_like(
            tokenized_output["input_ids"]
        )

        output_label = self.tokenizer.pad(
            {
                "input_ids": formal_input_ids,
            },
            return_tensors="pt",
        )
        out_label = output_label["input_ids"]
        # change pad token to -100
        out_label[out_label == self.tokenizer.pad_token_id] = self.pad_label
        print(out_label)
        output_dict.update(tokenized_output)
        output_dict['labels'] = out_label
        print(output_dict)
        # add tag_label to output_dict
        # each tag_label is a list of labels (list) with different length
        # pad first
        max_len = max([len(i) for i in tag_label])
        tag_label = [i + [self.pad_label] * (max_len - len(i)) for i in tag_label]
        tag_label = torch.tensor(tag_label)

        output_dict["tag_labels"] = tag_label

        # add point_label to output_dict
        # same as above
        point_label = [i["point_labels"] for i in batch]
        max_len = max([len(i) for i in point_label])
        point_label = [i + [self.pad_label] * (max_len - len(i)) for i in point_label]
        point_label = torch.tensor(point_label)

        output_dict["point_labels"] = point_label
        return output_dict

In [229]:
from torch.utils.data import DataLoader

loader = DataLoader(data_train, batch_size=2, collate_fn=FelixCollator(tokenizer))

In [232]:

for current_batch in loader:
    print(current_batch.keys())
    input_to_koto = {k: v for k, v in current_batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
    print(current_batch['labels'].shape)
    for x in ['token_type_ids', 'input_ids', 'attention_mask']:
        print(input_to_koto[x].shape, x)
    tag_pred = bert_koto(**input_to_koto, labels=current_batch['tag_labels'])
    break

tensor([[    3, 11450,  1818, 10121, 21474,  8014,  1604, 10518,   962,  1843,
          2587,  4207,   933,  2460,  3716, 12411,    16,  3005,  1975,    18,
          5218,  3774, 21474, 27330, 10518,   962, 10155,    18,     4],
        [    3,  4863,  6097,  4374,    18,  1731,  2882,  3888,  6505,  1881,
         10896,  4362,  1925,  2643,  6813,    16,     6,  2022,  4942,  1560,
          2289,     6,  9153, 23120,    18,  5218,  3774,    18,     4]])
{'input_ids': tensor([[    3, 11450,  1862,   932,   945, 10121, 21474,  8014,  1604, 10518,
           962,  1843,  2587,  4207,   933, 17849, 17104, 21463, 12411,  1476,
            16, 14099, 17849,    18, 21140, 21474, 27330, 10518,   962, 10155,
             4],
        [    3,  4863,  6097,  2118,    18,  1731,  2882,  3888,  6505,  1881,
         10896,  4362,    16,  1925,  2643,  6813,     6,  2022,  4942,  1560,
          2289,     6,  9153, 23120,    18,  5218,  3774,     4,     0,     0,
             0]]), 'attention_ma

In [148]:
current_batch = next(iter(loader))

input_to_koto = {k: v for k, v in current_batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
input_to_koto['labels'] = current_batch['tag_label']
bert_koto(**input_to_koto)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


RuntimeError: The size of tensor a (29) must match the size of tensor b (31) at non-singleton dimension 1

In [125]:
next(iter(loader))

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': tensor([[    3, 11450,  1818, 10121, 21474,  8014,  1604, 10518,   962,  1843,
           2587,  4207,   933,  2460,  3716, 12411,    16,  3005,  1975,    18,
           5218,  3774, 21474, 27330, 10518,   962, 10155,    18,     4],
         [    3,  4863,  6097,  4374,    18,  1731,  2882,  3888,  6505,  1881,
          10896,  4362,  1925,  2643,  6813,    16,     6,  2022,  4942,  1560,
           2289,     6,  9153, 23120,    18,  5218,  3774,    18,     4]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([[    3, 11450,

In [None]:
## POINTER NETWORK INFERENCE

def _realize_beam_search(self, source_token_ids,
                           ordered_source_indexes,
                           tags,
                           source_length):
    """Returns realized prediction using indexes and tags.

    TODO: Refactor this function to share code with
    `_create_masked_source` from insertion_converter.py to reduce code
    duplication and to ensure that the insertion example creation is consistent
    between preprocessing and prediction.

    Args:
      source_token_ids: List of source token ids.
      ordered_source_indexes: The order in which the kept tokens should be
        realized.
      tags: a List of tags.
      source_length: How long is the source input (excluding padding).

    Returns:
      Realized predictions (with deleted tokens).
    """
    # Need to help type checker.
    self._inverse_label_map = cast(Mapping[int, str], self._inverse_label_map)

    source_token_ids_set = set(ordered_source_indexes)
    out_tokens = []
    out_tokens_with_deletes = []
    for j, index in enumerate(ordered_source_indexes):
      token = self._builder.tokenizer.convert_ids_to_tokens(
          [source_token_ids[index]])
      out_tokens += token
      tag = self._inverse_label_map[tags[index]]
      if self._use_open_vocab:
        out_tokens_with_deletes += token
        # Add the predicted MASK tokens.
        number_of_masks = insertion_converter.get_number_of_masks(tag)
        # Can not add phrases after last token.
        if j == len(ordered_source_indexes) - 1:
          number_of_masks = 0
        masks = [constants.MASK] * number_of_masks
        out_tokens += masks
        out_tokens_with_deletes += masks

        # Find the deleted tokens, which appear after the current token.
        deleted_tokens = []
        for i in range(index + 1, source_length):
          if i in source_token_ids_set:
            break
          deleted_tokens.append(source_token_ids[i])
        # Bracket the deleted tokens, between unused0 and unused1.
        if deleted_tokens:
          deleted_tokens = [constants.DELETE_SPAN_START] + list(
              self._builder.tokenizer.convert_ids_to_tokens(deleted_tokens)) + [
                  constants.DELETE_SPAN_END
              ]
          out_tokens_with_deletes += deleted_tokens
      # Add the predicted phrase.
      elif '|' in tag:
        pos_pipe = tag.index('|')
        added_phrase = tag[pos_pipe + 1:]
        out_tokens.append(added_phrase)

    if not self._use_open_vocab:
      out_tokens_with_deletes = out_tokens
    assert (
        out_tokens_with_deletes[0] == (constants.CLS)
    ), (f' {out_tokens_with_deletes} did not start/end with the correct tokens '
        f'{constants.CLS}, {constants.SEP}')
    return out_tokens_with_deletes


In [None]:
tag_embedding = self._tag_embedding_layer(edit_tags)
position_embedding = self._position_embedding_layer(tag_embedding)
edit_tagged_sequence_output = self._edit_tagged_sequence_output_layer(
    tf.keras.layers.concatenate(
        [bert_output, tag_embedding, position_embedding]))

intermediate_query_embeddings = edit_tagged_sequence_output
if self._bert_config.query_transformer:
    attention_mask = self._self_attention_mask_layer(
        intermediate_query_embeddings, input_mask)
    for _ in range(int(self._bert_config.query_transformer)):
    intermediate_query_embeddings = self._transformer_query_layer(
        [intermediate_query_embeddings, attention_mask])

query_embeddings = self._query_embeddings_layer(
    intermediate_query_embeddings)

key_embeddings = self._key_embeddings_layer(edit_tagged_sequence_output)

pointing_logits = self._attention_scores(query_embeddings, key_embeddings,
                                            tf.cast(input_mask, tf.float32))