In [1]:
%cd ..
%load_ext autoreload
%autoreload 2

/home/haryoaw/documents/courses/nlp802/project/texteditalay


In [6]:
from neo_stif.components.train_data_preparation import prepare_data_tagging_and_pointer
import datasets
from torch.utils.data import DataLoader
from neo_stif.components.collator import FelixCollator
from neo_stif.components.utils import create_label_map
from transformers import BertForMaskedLM

In [8]:
import pandas as pd
from transformers import AutoTokenizer, BertForTokenClassification, BertConfig

import neo_stif

MAX_MASK = 30
USE_POINTING = True
tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
label_dict = create_label_map(MAX_MASK, USE_POINTING)

df_train = pd.read_csv("data/stif_indo/train_with_pointing.csv")
data_train = datasets.Dataset.from_pandas(df_train)
data_train, label_dict = prepare_data_tagging_and_pointer(data_train, tokenizer, label_dict)

Map: 100%|██████████| 1922/1922 [00:00<00:00, 5907.86 examples/s]
Map: 100%|██████████| 1922/1922 [00:01<00:00, 1866.01 examples/s]


In [26]:
loader = DataLoader(data_train, batch_size=2, shuffle=True, collate_fn=FelixCollator(tokenizer, pad_label_as_input=len(label_dict)))

In [20]:
bert_koto = BertForTokenClassification.from_pretrained("indolem/indobert-base-uncased", num_labels=12)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at indolem/indobert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
from neo_stif.components.models import PointerNetwork


pointer_network_config = BertConfig(
    vocab_size=len(label_dict) + 1,
    num_hidden_layers=2,
    num_attention_heads=1,
    pad_token_id=len(label_dict),
)  # + 1 as the pad token

pointer_network = PointerNetwork(pointer_network_config)

In [22]:
import torch

In [23]:
optimizer_for_koto = torch.optim.AdamW(bert_koto.parameters(), lr=5e-5)

In [29]:
optimizer_for_pointer = torch.optim.AdamW(pointer_network.parameters(), lr=1e-3)

In [24]:
# TEST TAGGER
for current_batch in loader:
    input_to_koto = {k: v for k, v in current_batch.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
    
    # Tagger
    tag_pred = bert_koto(**input_to_koto, labels=current_batch['tag_labels'])
    loss = tag_pred.loss
    loss.backward()
    optimizer_for_koto.step()
    optimizer_for_koto.zero_grad()
    break

In [54]:
# TEST POINTER
for current_batch in loader:    
    # Pointer
    input_to_pointer_real = {
        k: v for k, v in current_batch.items() if k in ['tag_labels_input', 'attention_mask', 'token_type_ids']
    }
    input_to_pointer_real['input_ids'] = input_to_pointer_real.pop('tag_labels_input')
    
    point_pred = pointer_network(**input_to_pointer_real, labels=current_batch['point_labels'])

    loss, att = point_pred
    loss.backward()
    print(loss)
    optimizer_for_pointer.step()
    optimizer_for_pointer.zero_grad()
    break

tensor(2.1499, grad_fn=<NllLossBackward0>)


## Insertion Training

tokenizer

In [10]:
from neo_stif.components.extract_insertion import create_masked_source

In [11]:
df_train.point_indexes.iloc[0][2]

','

In [12]:
train_processed, test_processed = create_masked_source(df_train.informal, df_train.formal, df_train )

TypeError: create_masked_source() missing 2 required positional arguments: 'target_tokens' and 'label_map'

In [80]:
df_train = pd.read_csv("data/stif_indo/train_with_pointing.csv")
data_train_insert = datasets.Dataset.from_pandas(df_train)

In [13]:
df_dev = pd.read_csv("data/stif_indo/dev_with_pointing.csv")
data_dev_insert = datasets.Dataset.from_pandas(df_dev)

In [14]:
data_dev_insert, label_dict = prepare_data_tagging_and_pointer(data_train, tokenizer, label_dict)

Map: 100%|██████████| 1922/1922 [00:00<00:00, 11039.53 examples/s]
Map: 100%|██████████| 1922/1922 [00:01<00:00, 1842.86 examples/s]


In [86]:
df_train = pd.read_csv("data/stif_indo/train_with_pointing.csv")
data_train = datasets.Dataset.from_pandas(df_train)
data_train, label_dict = prepare_data(data_train, tokenizer, label_dict, max_mask=35)

Map: 100%|██████████| 1922/1922 [00:00<00:00, 13603.94 examples/s]
Map: 100%|██████████| 1922/1922 [00:00<00:00, 2072.35 examples/s]


In [93]:
data_train

Dataset({
    features: ['informal', 'formal', 'point_indexes', 'label', 'len_label', 'informal_input_ids', 'informal_token_type_ids', 'informal_attention_mask', 'formal_input_ids', 'formal_token_type_ids', 'formal_attention_mask', 'tag_labels', 'point_labels'],
    num_rows: 1922
})

In [111]:
informal = tokenizer.tokenize(data_train['informal'][0], add_special_tokens=True)
formal = tokenizer.tokenize(data_train['formal'][0], add_special_tokens=True)
point_indexes = data_train['point_labels'][0]
tag_label = data_train['tag_labels'][0]

In [15]:
def process_masked_source(x, tokenizer, label_map):
    dict_return = {}
    informal = tokenizer.tokenize(x["informal"], add_special_tokens=True)
    formal = tokenizer.tokenize(x["formal"], add_special_tokens=True)
    point_indexes = x["point_labels"]
    tag_label = x["tag_labels"]
    masked_tokens, target_tokens = create_masked_source(
        informal, tag_label, point_indexes, formal, label_map
    )
    masked_tokens_ids = [tokenizer.vocab[i] for i in masked_tokens]
    target_tokens_ids = [tokenizer.vocab[i] for i in target_tokens]
    attention_mask = [1] * len(masked_tokens_ids)
    token_type_ids = [0] * len(masked_tokens_ids)

    dict_return["input_ids"] = torch.LongTensor(masked_tokens_ids)
    dict_return["attention_mask"] = torch.LongTensor(attention_mask)
    dict_return["token_type_ids"] = torch.LongTensor(token_type_ids)
    dict_return["labels"] = torch.LongTensor(target_tokens_ids)
    return dict_return

In [156]:
masked_tokens, target_tokens = create_masked_source(
    informal, tag_label, point_indexes, formal, label_dict
)

In [157]:
data_train

Dataset({
    features: ['informal', 'formal', 'point_indexes', 'label', 'len_label', 'informal_input_ids', 'informal_token_type_ids', 'informal_attention_mask', 'formal_input_ids', 'formal_token_type_ids', 'formal_attention_mask', 'tag_labels', 'point_labels'],
    num_rows: 1922
})

In [None]:
data_dev_insertion = data_dev_insertion.map(
    process_masked_source,
    batched=False,
    fn_kwargs={"tokenizer": tokenizer, "label_map": label_dict},
)

In [158]:
data_train_insertion = data_train.map(
    process_masked_source,
    batched=False,
    fn_kwargs={"tokenizer": tokenizer, "label_map": label_dict},
)

Map:   0%|          | 0/1922 [00:00<?, ? examples/s]

Map: 100%|██████████| 1922/1922 [18:16<00:00,  1.75 examples/s]


In [159]:
# save data_train_insertion parquet
data_train_insertion.save_to_disk("data/stif_indo/train_insertion")

Saving the dataset (1/1 shards): 100%|██████████| 1922/1922 [00:00<00:00, 89321.59 examples/s]


In [173]:
tokenizer.decode(data_train_insertion['labels'][14])

'[CLS] tag [UNK] nag [PAD]ih hutang [UNK] utang [PAD] ke teman saat dia terlihat sedang [UNK] temen pas doi keliatan lagi [PAD] kaya. [SEP]'

In [164]:
## MUST LOAD!
testing = load_from_disk('data/stif_indo/train_insertion')

In [123]:
tokenizer.vocab['[MASK]']

2

In [9]:
from neo_stif.components.collator import FelixInsertionCollator

In [176]:
bert_for_insertion = BertForMaskedLM.from_pretrained("indolem/indobert-base-uncased")
insert_collator = FelixInsertionCollator(tokenizer)
loader = DataLoader(data_train_insertion, batch_size=2, shuffle=True, collate_fn=insert_collator)
current_batch = next(iter(loader))

In [185]:
optimizer_for_insertion = torch.optim.AdamW(pointer_network.parameters(), lr=5e-5)

In [216]:
# TEST INSERTION
for current_batch in loader:    
    # Pointer
    output = bert_for_insertion(**current_batch)
    loss = output.loss
    loss.backward()
    optimizer_for_insertion.step()
    optimizer_for_insertion.zero_grad()
    print(loss)
    break

tensor(6.6082, grad_fn=<NllLossBackward0>)


In [192]:
current_batch

{'input_ids': tensor([[    3,  1731,  1798,  9986, 10896, 18909,     2,     1,  3316,    16,
              0,     4,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0],
         [    3,     2, 13234,     2,     2,  1975,     1,  3353,  1522,  1975,
              0,    35,     2,     1,  5311,     0,  1684,  8168,  4143,  1959,
           2268,     2,     4]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,    18,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100],
         [ -100,  2811,  -100,  2447, 2

In [121]:
[tokenizer.vocab[i] for i in masked_tokens]

[3,
 11450,
 2,
 1,
 1862,
 932,
 945,
 0,
 10121,
 66,
 8014,
 1604,
 962,
 1843,
 2587,
 4207,
 933,
 2,
 2,
 1,
 17849,
 17104,
 21463,
 0,
 12411,
 1,
 1476,
 0,
 16,
 2,
 2,
 1,
 14099,
 17849,
 0,
 18,
 2,
 2,
 1,
 21140,
 0,
 18070,
 6359,
 962,
 10155,
 2,
 4]

In [118]:
list(zip(masked_tokens, target_tokens))

[('[CLS]', '[CLS]'),
 ('alhamdulillah', 'alhamdulillah'),
 ('[MASK]', 'setelah'),
 ('[UNK]', '[UNK]'),
 ('st', 'st'),
 ('##l', '##l'),
 ('##h', '##h'),
 ('[PAD]', '[PAD]'),
 ('libur', 'libur'),
 ('x', 'x'),
 ('##num', '##num'),
 ('##ber', '##ber'),
 ('##x', '##x'),
 ('hari', 'hari'),
 ('on', 'on'),
 ('##bi', '##bi'),
 ('##d', '##d'),
 ('[MASK]', 'langsung'),
 ('[MASK]', 'diberi'),
 ('[UNK]', '[UNK]'),
 ('lg', 'lg'),
 ('##sg', '##sg'),
 ('dikasih', 'dikasih'),
 ('[PAD]', '[PAD]'),
 ('order', 'order'),
 ('[UNK]', '[UNK]'),
 ('##an', '##an'),
 ('[PAD]', '[PAD]'),
 (',', ','),
 ('[MASK]', 'makanan'),
 ('[MASK]', 'lagi'),
 ('[UNK]', '[UNK]'),
 ('food', 'food'),
 ('lg', 'lg'),
 ('[PAD]', '[PAD]'),
 ('.', '.'),
 ('[MASK]', 'terima'),
 ('[MASK]', 'kasih'),
 ('[UNK]', '[UNK]'),
 ('thanks', 'thanks'),
 ('[PAD]', '[PAD]'),
 ('xu', 'xu'),
 ('##ser', '##ser'),
 ('##x', '##x'),
 ('cc', 'cc'),
 ('[MASK]', '.'),
 ('[SEP]', '[SEP]')]