In [1]:
# Work from the parent directory; the relative data path used later
# ("data/stif_indo/...") is resolved from there.
%cd ..
# Autoreload mode 2 re-imports modules before each cell runs, so edits to the
# project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

/home/haryoaw/documents/courses/nlp802/project/texteditalay


In [3]:
import ast

import datasets
import fire
import pandas as pd
from datasets import load_from_disk
from lightning import Trainer
from lightning.pytorch.callbacks import RichProgressBar, ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertForTokenClassification, BertConfig, BertForMaskedLM

from neo_stif.components.collator import FelixCollator, FelixInsertionCollator
from neo_stif.components.train_data_preparation import prepare_data_tagging_and_pointer
from neo_stif.components.utils import create_label_map
from neo_stif.components.utils import compute_class_weights
from neo_stif.lit import LitTaggerOrInsertion


# Arguments handed to create_label_map when the tagging label map is built.
MAX_MASK = 30
USE_POINTING = True


# Short alias -> Hugging Face checkpoint name.
model_dict = {"koto": "indolem/indobert-base-uncased"}


# Per-component learning rates.
# NOTE(review): none of these are referenced in the cells visible here (the
# tagger is constructed with a literal lr=2e-5) — confirm which value is intended.
LR_TAGGER = 5e-5 # tagger starts from a pre-trained encoder
LR_POINTER = 1e-5 # pointer network is not pre-trained
LR_INSERTION = 2e-5 # insertion model starts from a pre-trained encoder
# Presumably meant for the Lightning Trainer's validation interval; unused in the
# cells visible here — TODO confirm.
VAL_CHECK_INTERVAL = 20

In [4]:
# Checkpoint name used when loading the tagger below; "koto" is the only alias in model_dict.
model_path_or_name = model_dict["koto"]

In [5]:
# Load the tokenizer from the same checkpoint variable as the model so the two
# cannot silently drift apart when model_path_or_name is changed.
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
label_dict = create_label_map(MAX_MASK, USE_POINTING)

# NOTE(review): the frame is named df_train but is read from the *test* split
# (test_with_pointing.csv) — presumably to keep this debugging notebook small; confirm.
df_train = pd.read_csv("data/stif_indo/test_with_pointing.csv")
data_train = datasets.Dataset.from_pandas(df_train)
# The helper returns the processed dataset together with a label map, which
# replaces the label map created above.
data_train, label_dict = prepare_data_tagging_and_pointer(
    data_train, tokenizer, label_dict
)

Map: 100%|██████████| 363/363 [00:00<00:00, 5068.42 examples/s]
Map: 100%|██████████| 363/363 [00:00<00:00, 1916.54 examples/s]


In [6]:
# Token-classification model on top of the selected checkpoint, with one output
# class per entry in label_dict.
pre_trained_bert = BertForTokenClassification.from_pretrained(
    model_path_or_name,
    num_labels=len(label_dict),
)



Some weights of BertForTokenClassification were not initialized from the model checkpoint at indolem/indobert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# The "label" column stores each label sequence as a stringified Python list.
# Parse it with ast.literal_eval rather than eval: literal_eval only accepts
# Python literals, so a malformed or malicious CSV cell cannot execute code.
# (The previous `if True else None` toggle was dead code and has been dropped.)
class_weights = compute_class_weights(
    df_train["label"].apply(ast.literal_eval),
    num_classes=len(label_dict),
)

In [8]:
# Configuration of the small BERT used as the pointer network. Its vocabulary is
# the tag set plus one extra id, len(label_dict), which serves as the pad token.
pointer_network_config = BertConfig(
    vocab_size=len(label_dict) + 1,  # + 1 for the pad token
    pad_token_id=len(label_dict),
    hidden_size=100,
    num_hidden_layers=2,
    num_attention_heads=1,
)

# Wrap the tagger and the pointer configuration in the Lightning module.
lit_tagger = LitTaggerOrInsertion(
    model=pre_trained_bert,
    tokenizer=tokenizer,
    label_dict=label_dict,
    num_classes=len(label_dict),
    class_weight=class_weights,
    lr=2e-5,
    use_pointer=USE_POINTING,
    pointer_config=pointer_network_config,
)

In [10]:
# Take the first prepared example as the single sample inspected below.
data_0 = data_train[0]

In [11]:
# Tokenize the raw informal sentence into PyTorch tensors (return_tensors="pt")
# that can be unpacked straight into the model call.
inp_to_model = tokenizer(data_0['informal'], return_tensors="pt")

In [12]:
# Run this inspection on the CPU.
lit_tagger = lit_tagger.cpu()

In [13]:
# Run the tagger on the sample. The module instance is called instead of
# .forward() so that any registered hooks are executed as well.
# output_hidden_states=True is required because the last hidden state is passed
# to the pointer network later on.
# NOTE: despite its name, out_logits is the full model output; later cells read
# both .logits and .hidden_states from it.
out_logits = lit_tagger(**inp_to_model, output_hidden_states=True)

In [14]:
# Invert the tokenizer vocabulary (token -> id) into an id -> token lookup.
tokenizer_vocab_reverse = {
    token_id: token for token, token_id in tokenizer.vocab.items()
}

In [15]:
# Reverse the label dict (label name -> id) into an id -> label name lookup.
label_dict_reverse = {
    label_id: label_name for label_name, label_id in label_dict.items()
}

In [16]:
# Translate the sample's token ids and the per-token argmax predictions back
# into readable strings.
decoded_seq = [
    tokenizer_vocab_reverse[token_id.item()]
    for token_id in inp_to_model['input_ids'][0]
]
decoded_label = [
    label_dict_reverse[label_id.item()]
    for label_id in out_logits.logits.argmax(-1)[0]
]

In [17]:
# Show each input token next to the label predicted for it.
list(zip(decoded_seq, decoded_label))

[('[CLS]', 'KEEP|10'),
 ('belum', 'KEEP|10'),
 ('ada', 'KEEP|10'),
 ('konfirmasi', 'KEEP|10'),
 ('lagi', 'KEEP|6'),
 ('kah', 'KEEP|30'),
 ('min', 'SWAP'),
 ('?', 'KEEP|10'),
 ('[SEP]', 'KEEP|10')]

In [18]:
# Display the raw sample, including its tag and pointer labels, for comparison
# with the predictions above.
data_0

{'informal': 'belum ada konfirmasi lagi kah min ?',
 'formal': 'belum ada konfirmasi lagikah , admin ?',
 'point_indexes': '[1, 2, 3, 4, 7, 0, 0, 8, 0]',
 'label': '[2, 2, 2, 2, 6, 3, 3, 2, 2]',
 'informal_input_ids': [3, 2077, 1684, 13243, 1975, 8478, 2118, 35, 4],
 'informal_token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0],
 'informal_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1],
 'formal_input_ids': [3, 2077, 1684, 13243, 1975, 3251, 16, 4374, 35, 4],
 'formal_token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'formal_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'tag_labels': [2, 2, 2, 2, 6, 3, 3, 2, 2],
 'point_labels': [1, 2, 3, 4, 7, 0, 0, 8, 0]}

In [20]:
import torch

In [22]:
# Batch-of-one int64 tensor holding the sample's tag labels.
inp_tag = torch.tensor([data_0['tag_labels']], dtype=torch.long)

In [26]:
# Run the pointer head on the sample: the tag labels are passed as its
# "input_ids", together with the tagger's last hidden states and the masks
# produced by the tokenizer.
lit_tagger.forward_pointer(
    input_ids=inp_tag,
    previous_last_hidden=out_logits.hidden_states[-1],
    attention_mask=inp_to_model["attention_mask"],
    token_type_ids=inp_to_model["token_type_ids"],
)

torch.Size([1, 9, 100])
torch.Size([1, 9, 100])


(None,
 tensor([[[[0.1087, 0.1092, 0.1196, 0.1210, 0.1034, 0.1153, 0.1059, 0.1110,
            0.1061],
           [0.1094, 0.1159, 0.1187, 0.1196, 0.0992, 0.1115, 0.1058, 0.1119,
            0.1079],
           [0.1057, 0.1021, 0.1194, 0.1207, 0.1048, 0.1110, 0.1175, 0.1106,
            0.1081],
           [0.1104, 0.1099, 0.1126, 0.1176, 0.1064, 0.1075, 0.1091, 0.1140,
            0.1124],
           [0.1033, 0.1113, 0.1220, 0.1183, 0.1162, 0.1024, 0.1074, 0.1128,
            0.1063],
           [0.1116, 0.1023, 0.1193, 0.1157, 0.1017, 0.1054, 0.1101, 0.1217,
            0.1121],
           [0.1159, 0.1061, 0.1164, 0.1121, 0.1069, 0.1162, 0.0958, 0.1122,
            0.1183],
           [0.1100, 0.0982, 0.1195, 0.1150, 0.1072, 0.1090, 0.1119, 0.1160,
            0.1132],
           [0.1133, 0.1087, 0.1178, 0.1178, 0.1034, 0.1105, 0.1033, 0.1119,
            0.1130]]]], grad_fn=<SoftmaxBackward0>))