In [1]:
# !wget https://raw.githubusercontent.com/huseinzol05/malay-dataset/master/parsing/constituency/train.txt
# !wget https://raw.githubusercontent.com/huseinzol05/malay-dataset/master/parsing/constituency/test.txt

In [2]:
from transformers import AutoTokenizer

# Tokenizer of the pretrained Malaysian nanoT5 checkpoint; its bos/eos/unk
# tokens are reused below as the START/STOP/UNK sentence markers.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='mesolitica/nanot5-base-malaysian-cased',
)

In [3]:
import pyximport
import numpy as np
import torch
import torch.nn as nn
from malaya.function import trees_newline as trees
# Make `import` compile Cython (.pyx) sources on the fly, with NumPy's C headers
# on the include path. This has to run before `import chart_helper` below, which
# is presumably a Cython module (chart_helper.pyx) — TODO confirm.
pyximport.install(setup_args={"include_dirs": np.get_include()})

# NOTE(review): chart_helper is not referenced directly in the cells shown here;
# it is presumably needed by the model's chart decoding — confirm before removing.
import chart_helper

2023-09-20 01:03:10.066690: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-20 01:03:10.151429: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-20 01:03:10.609476: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-20 01:03:10.609664: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not l

In [4]:
from transformers import T5Config
from malaya.torch_model.t5 import T5Constituency

In [5]:
# Normalisation table from PTB-style escaped bracket tokens, typographic quotes
# and en/em dashes to plain ASCII equivalents.
# NOTE(review): process_word() does not consult this table and no other visible
# cell references it; confirm whether it is still needed.
BERT_TOKEN_MAPPING = {
    "-LRB-": "(",
    "-RRB-": ")",
    "-LCB-": "{",
    "-RCB-": "}",
    "-LSB-": "[",
    "-RSB-": "]",
    "``": '"',
    "''": '"',
    "`": "'",
    '«': '"',
    '»': '"',
    '‘': "'",
    '’': "'",
    '“': '"',
    '”': '"',
    '„': '"',
    '‹': "'",
    '›': "'",
    "\u2013": "--", # en dash
    "\u2014": "--", # em dash
    }

def process_word(word, cleaned_words=None):
    """Undo treebank escaping in one word.

    Un-escapes backslash-escaped '/' and '*', and turns the bracket tokens
    -LSB-/-RSB-/-LRB-/-RRB- back into '[', ']', '(' and ')' wherever they occur
    inside the word.

    When the cleaned word is exactly "n't" and ``cleaned_words`` is non-empty,
    an "n" is appended to the previous word (``cleaned_words[-1]`` is modified
    in place) and "'t" is returned instead, e.g. ["do"] + "n't" -> ["don"] + "'t".

    Args:
        word: raw leaf string from the treebank.
        cleaned_words: words of the current sentence processed so far. If
            omitted, a module-level ``cleaned_words`` is used when one exists
            (the implicit behaviour existing callers rely on); if none exists,
            the "n't" merge is simply skipped.

    Returns:
        The cleaned word.
    """
    if cleaned_words is None:
        # Backward compatibility: callers that do not pass the list depended on a
        # global `cleaned_words`. Look it up explicitly instead of raising
        # NameError when that global has not been created yet (fresh kernel).
        cleaned_words = globals().get('cleaned_words')
    word = word.replace('\\/', '/').replace('\\*', '*')
    # Mid-token punctuation occurs in biomedical text
    word = word.replace('-LSB-', '[').replace('-RSB-', ']')
    word = word.replace('-LRB-', '(').replace('-RRB-', ')')
    if word == "n't" and cleaned_words:
        cleaned_words[-1] = cleaned_words[-1] + "n"
        word = "'t"
    return word

In [6]:
import collections

class Vocabulary(object):
    """Bidirectional mapping between hashable values and contiguous integer ids.

    While unfrozen, `index` hands out the next free id to unseen values and
    counts every lookup. After `freeze()`, unseen values raise ValueError and
    the counts stop changing.
    """

    def __init__(self):
        self.frozen = False
        self.values = []  # id -> value
        self.indices = {}  # value -> id
        self.counts = collections.defaultdict(int)  # value -> lookups made while unfrozen

    @property
    def size(self):
        """Number of distinct values registered so far."""
        return len(self.values)

    def value(self, index):
        """Return the value stored under id `index` (asserts the id is in range)."""
        assert 0 <= index < len(self.values)
        return self.values[index]

    def index(self, value):
        """Return the id of `value`, registering it first while unfrozen.

        Raises:
            ValueError: if the vocabulary is frozen and `value` is unknown.
        """
        if self.frozen:
            if value in self.indices:
                return self.indices[value]
            raise ValueError("Unknown value: {}".format(value))

        self.counts[value] += 1
        if value not in self.indices:
            self.indices[value] = len(self.values)
            self.values.append(value)
        return self.indices[value]

    def index_or_unk(self, value, unk_value):
        """Return the id of `value`, or the id of `unk_value` if `value` is unknown.

        Only valid on a frozen vocabulary. Raises KeyError when `unk_value`
        itself was never registered.
        """
        assert self.frozen
        key = value if value in self.indices else unk_value
        return self.indices[key]

    def count(self, value):
        """Number of `index(value)` calls made before the vocabulary was frozen."""
        return self.counts[value]

    def freeze(self):
        """Stop accepting new values (and stop updating counts)."""
        self.frozen = True


In [7]:
# Read the bracketed training treebank, then convert every tree into the
# parse-node representation that the vocabulary cell walks.
train_treebank = trees.load_trees('train.txt')
train_parse = [
    treebank_tree.convert()
    for treebank_tree in train_treebank
]

In [8]:
# Span-label vocabulary; the empty label `()` is registered first, so it gets id 0.
label_vocab = Vocabulary()
label_vocab.index(())

# POS-tag vocabulary, seeded with the tokenizer's special tokens used as the
# START/STOP sentence markers and its unknown token.
tag_vocab = Vocabulary()
START = tokenizer.bos_token
STOP = tokenizer.eos_token
UNK = tokenizer.unk_token
TAG_UNK = "UNK"
tag_vocab.index(START)
tag_vocab.index(STOP)
tag_vocab.index(UNK)
# Tags are later looked up with `tag_vocab.index_or_unk(tag, TAG_UNK)`, which
# requires TAG_UNK itself to be in the vocabulary; without this, any tag not
# seen in train.txt raises KeyError instead of mapping to the unknown-tag id.
# (If TAG_UNK happens to equal UNK, this does not add a new id.)
tag_vocab.index(TAG_UNK)

# Depth-first (pre-order) walk of every training tree, registering the labels
# of internal nodes and the tags of leaves.
for tree in train_parse:
    nodes = [tree]
    while nodes:
        node = nodes.pop()
        if isinstance(node, trees.InternalParseNode):
            label_vocab.index(node.label)
            nodes.extend(reversed(node.children))
        else:
            tag_vocab.index(node.tag)

tag_vocab.freeze()
label_vocab.freeze()

In [9]:
config = T5Config.from_pretrained('mesolitica/nanot5-base-malaysian-cased')

# Output-space sizes and the tag-loss weight, presumably read by T5Constituency.
config.num_labels = label_vocab.size
config.num_tags = tag_vocab.size
config.tag_loss_scale = 5.0

# Store both vocabularies on the config. Span labels are tuples (e.g. `()`),
# so their keys are stringified — presumably to keep the config JSON-serialisable.
config.label_vocab = {
    str(label): label_id
    for label, label_id in label_vocab.indices.items()
}
config.tag_vocab = tag_vocab.indices

In [10]:
# Build the constituency model from `config`.
# NOTE(review): `config` came from T5Config.from_pretrained, which loads only the
# configuration. Unless T5Constituency.__init__ loads weights itself, the T5
# backbone starts from random initialisation here. If fine-tuning the mesolitica
# checkpoint is intended, something like
# `T5Constituency.from_pretrained(..., config=config)` is presumably wanted — confirm.
model = T5Constituency(config)

In [19]:
# Hand the optimiser (and gradient clipping) only the parameters that receive gradients.
trainable_parameters = list(filter(lambda param: param.requires_grad, model.parameters()))

In [20]:
# AdamW over the trainable parameters with a fixed learning rate of 2e-5.
trainer = torch.optim.AdamW(params=trainable_parameters, lr=2e-5)

In [11]:
class BatchIndices:
    """
    Batch indices container class (used to implement packed batches).

    `batch_idxs_np` gives, for every position of a packed sequence, the number
    of the sentence that position belongs to. Positions of one sentence must be
    contiguous: the assert below checks that the number of contiguous runs
    equals 1 + the largest sentence number.

    Attributes:
        batch_idxs_np: the input array.
        batch_idxs_torch: CPU tensor sharing memory with `batch_idxs_np`.
        batch_size: number of sentences (1 + max sentence number).
        boundaries_np: start offset of every sentence, followed by the total length.
        seq_lens_np: length of every sentence.
        max_len: length of the longest sentence, as a Python int.
    """
    def __init__(self, batch_idxs_np):
        self.batch_idxs_np = batch_idxs_np
        # torch.from_numpy shares memory with the array and stays on the CPU;
        # move it explicitly (e.g. `.to(device)`) if it is needed elsewhere.
        self.batch_idxs_torch = torch.from_numpy(batch_idxs_np)

        self.batch_size = int(1 + np.max(batch_idxs_np))

        # Sentinels at both ends make the first position and the end of the
        # array count as boundaries: boundaries_np = [0, start_1, ..., total].
        batch_idxs_np_extra = np.concatenate([[-1], batch_idxs_np, [-1]])
        self.boundaries_np = np.nonzero(batch_idxs_np_extra[1:] != batch_idxs_np_extra[:-1])[0]
        self.seq_lens_np = self.boundaries_np[1:] - self.boundaries_np[:-1]
        assert len(self.seq_lens_np) == self.batch_size
        self.max_len = int(np.max(self.seq_lens_np))

In [23]:
def split_batch(sentences, golds, subbatch_max_tokens=3000, tokenize=None):
    """Yield (sentences, golds) sub-batches ordered by estimated token length.

    Each sentence's length is estimated as the number of sub-word tokens of its
    space-joined words plus 2 (for START/STOP). Sentences are visited shortest
    first. A sub-batch is closed when it already holds every remaining sentence,
    or when its current size times the next sentence's length would exceed
    `subbatch_max_tokens`.

    Args:
        sentences: list of sentences, each a list of (tag, word) pairs.
        golds: per-sentence gold objects, aligned with `sentences`.
        subbatch_max_tokens: token budget used to close a sub-batch.
        tokenize: callable str -> list of tokens; defaults to the global
            `tokenizer.tokenize`.

    Yields:
        (list of sentences, list of golds) for each sub-batch.
    """
    if tokenize is None:
        tokenize = tokenizer.tokenize
    lens = np.asarray(
        [len(tokenize(' '.join(word for _, word in sentence))) + 2 for sentence in sentences],
        dtype=int,
    )
    lens_argsort = np.argsort(lens).tolist()

    subbatch_size = 1
    while lens_argsort:
        if (subbatch_size == len(lens_argsort)) or (subbatch_size * lens[lens_argsort[subbatch_size]] > subbatch_max_tokens):
            chosen = lens_argsort[:subbatch_size]
            yield [sentences[i] for i in chosen], [golds[i] for i in chosen]
            lens_argsort = lens_argsort[subbatch_size:]
            subbatch_size = 1
        else:
            subbatch_size += 1
            
def pad_sentence_batch(sentence_batch, pad_int):
    """Right-pad every sequence in `sentence_batch` to the length of the longest.

    Args:
        sentence_batch: list of lists; the inputs are not modified.
        pad_int: value appended to shorter sequences.

    Returns:
        (padded_seqs, seq_lens): padded copies of the sequences and their
        original lengths. An empty batch yields ([], []).
    """
    # default=0 keeps an empty batch from raising ValueError in max().
    max_sentence_len = max((len(sentence) for sentence in sentence_batch), default=0)
    padded_seqs = []
    seq_lens = []
    for sentence in sentence_batch:
        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))
        seq_lens.append(len(sentence))
    return padded_seqs, seq_lens

In [26]:
start_index = 0
batch_size = 4

# Fresh gradients and loss accumulator for this batch.
trainer.zero_grad()
batch_loss_value = 0.0

batch_trees = train_parse[start_index:start_index + batch_size]
# (tag, word) pairs for every leaf of every tree in the batch.
batch_sentences = [
    [(leaf.tag, leaf.word) for leaf in batch_tree.leaves()]
    for batch_tree in batch_trees
]
batch_num_tokens = sum(map(len, batch_sentences))

In [27]:
for subbatch_sentences, subbatch_trees in split_batch(batch_sentences, batch_trees):
    sentences = subbatch_sentences
    golds = subbatch_trees

    # Packed (flattened) gold tag ids for the whole sub-batch. Every sentence is
    # wrapped in START ... STOP, so it occupies len(sentence) + 2 positions;
    # batch_idxs_np[k] is the number of the sentence that position k belongs to.
    packed_len = sum(len(sentence) + 2 for sentence in sentences)
    tag_idxs = np.zeros(packed_len, dtype=int)
    batch_idxs_np = np.zeros(packed_len, dtype=int)
    i = 0
    for snum, sentence in enumerate(sentences):
        for (tag, word) in [(START, START)] + sentence + [(STOP, STOP)]:
            tag_idxs[i] = tag_vocab.index_or_unk(tag, TAG_UNK)
            batch_idxs_np[i] = snum
            i += 1

    batch_idxs = BatchIndices(batch_idxs_np)
    gold_tag_idxs = torch.from_numpy(tag_idxs)

    # Sub-word ids per sentence, plus masks marking the first / last sub-word of
    # every word (START and STOP count as one-token words).
    all_input_ids = []
    all_word_start_mask = []
    all_word_end_mask = []
    for sentence in sentences:
        tokens = [START]
        word_start_mask = [1]
        word_end_mask = [1]

        # Deliberately a top-level name: process_word() may read a global
        # `cleaned_words` to merge "n't" into the previous word.
        cleaned_words = []
        for _, word in sentence:
            cleaned_words.append(process_word(word))

        for word in cleaned_words:
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # A word that tokenizes to nothing would get no start/end
                # position, leaving the masks out of step with the words.
                word_tokens = [UNK]
            word_start_mask.extend([1] + [0] * (len(word_tokens) - 1))
            word_end_mask.extend([0] * (len(word_tokens) - 1) + [1])
            tokens.extend(word_tokens)
        tokens.append(STOP)
        word_start_mask.append(1)
        word_end_mask.append(1)

        all_input_ids.append(tokenizer.convert_tokens_to_ids(tokens))
        all_word_start_mask.append(word_start_mask)
        all_word_end_mask.append(word_end_mask)

    all_word_start_mask = torch.from_numpy(np.array(pad_sentence_batch(all_word_start_mask, 0)[0]))
    all_word_end_mask = torch.from_numpy(np.array(pad_sentence_batch(all_word_end_mask, 0)[0]))

    # Model inputs shared by the training pass and the decoding pass.
    padded = tokenizer.pad({'input_ids': all_input_ids}, return_tensors='pt')
    padded['sentences'] = sentences
    padded['batch_idxs'] = batch_idxs
    padded['all_word_start_mask'] = all_word_start_mask
    padded['all_word_end_mask'] = all_word_end_mask

    # Training pass: given gold tags and trees, the model returns two losses.
    loss, tag_loss = model(**padded, gold_tag_idxs=gold_tag_idxs, golds=golds)
    # Normalise by the WHOLE batch (not this sub-batch), so gradients accumulated
    # over several sub-batches average over the batch, consistently with
    # batch_num_tokens.
    # NOTE(review): benepar divides the span loss by the sentence count and the
    # tag loss by the token count; this keeps the original pairing
    # (tag_loss / sentences, loss / tokens) — confirm against the order in which
    # T5Constituency returns its two losses.
    loss = tag_loss / len(batch_trees) + loss / batch_num_tokens
    loss_value = loss.item()
    batch_loss_value += loss_value
    if loss_value > 0:
        loss.backward()

    # Decoding pass (no golds) as a sanity check. It needs no autograd graph.
    # The results are not bound to `trees`, because that would shadow the
    # `trees` module used by the data-loading and vocabulary cells.
    with torch.no_grad():
        predicted_trees, predicted_scores = model(**padded)

grad_norm = torch.nn.utils.clip_grad_norm_(trainable_parameters, 1.0)
trainer.step()
batch_loss_value

193.48486328125