In [1]:
%cd ..
%load_ext autoreload
%autoreload 2

/home/haryoaw/documents/courses/nlp802/project/texteditalay


In [2]:
import fire
from transformers import AutoTokenizer, BertForTokenClassification, BertConfig, BertForMaskedLM
from neo_stif.components.utils import create_label_map
import pandas as pd
from neo_stif.components.train_data_preparation import prepare_data_tagging_and_pointer
import datasets
from neo_stif.lit import LitTaggerOrInsertion
from torch.utils.data import DataLoader
from neo_stif.components.collator import FelixCollator, FelixInsertionCollator
from lightning import Trainer
from lightning.pytorch.callbacks import RichProgressBar, ModelCheckpoint, EarlyStopping
from neo_stif.components.utils import compute_class_weights
from datasets import load_from_disk


MAX_MASK = 30
USE_POINTING = True


model_dict = {"koto": "indolem/indobert-base-uncased"}


LR_TAGGER = 5e-5 # due to the pre-trained nature
LR_POINTER = 1e-5 # no pre-trained
LR_INSERTION = 2e-5 # due to the pre-trained nature
VAL_CHECK_INTERVAL = 20

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_path_or_name = model_dict["koto"]

In [4]:
tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
label_dict = create_label_map(MAX_MASK, USE_POINTING)

# Callback for trainer

df_train = pd.read_csv("data/stif_indo/test_with_pointing.csv")
data_train = datasets.Dataset.from_pandas(df_train)
data_train, label_dict = prepare_data_tagging_and_pointer(
    data_train, tokenizer, label_dict
)
model_path_or_name = model_dict["koto"]

Map: 100%|██████████| 363/363 [00:00<00:00, 5050.04 examples/s]
Map: 100%|██████████| 363/363 [00:00<00:00, 1621.07 examples/s]


In [5]:
pre_trained_bert = BertForTokenClassification.from_pretrained(
        model_path_or_name, num_labels=len(label_dict)
    )



Some weights of BertForTokenClassification were not initialized from the model checkpoint at indolem/indobert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
class_weights = (
        compute_class_weights(df_train.label.apply(eval), num_classes=len(label_dict))
        if True
        else None
    )

In [12]:
pointer_network_config = BertConfig(
    vocab_size=len(label_dict) + 1,
    num_hidden_layers=2,
    hidden_size=256,
    num_attention_heads=1,
    pad_token_id=len(label_dict),
)  # + 1 as the pad token

lit_tagger = LitTaggerOrInsertion.load_from_checkpoint(
    "outputs/stif-i-f/test.ckpt",
    model=pre_trained_bert,
    lr=2e-5,
    num_classes=len(label_dict),
    class_weight=None,
    tokenizer=tokenizer,
    label_dict=label_dict,
    use_pointer=USE_POINTING,
    pointer_config=pointer_network_config,
)

In [82]:
lit_tagger = lit_tagger.cpu()
tokenizer_vocab_reverse = {v: k for k, v in tokenizer.vocab.items()}
label_dict

# reverese the label dict
label_dict_reverse = {v: k for k, v in label_dict.items()}

In [83]:
lit_tagger.eval()

LitTaggerOrInsertion(
  (model): BertForTokenClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(31923, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=7

In [85]:
from pprint import pprint

In [100]:
data_0 = data_train[12]

In [101]:
import torch
with torch.no_grad():
    inp_to_model = tokenizer(data_0['informal'], return_tensors="pt")
    out_logits = lit_tagger.forward(**inp_to_model, output_hidden_states=True)
    decoded_seq = [tokenizer_vocab_reverse[x.item()] for x in inp_to_model['input_ids'][0]]
    decoded_label = [label_dict_reverse[x.item()] for x in out_logits.logits.argmax(-1)[0]]
    inp_tag = torch.LongTensor([data_0['tag_labels']])
    _, out_att = lit_tagger.forward_pointer(
        input_ids=inp_tag,
        attention_mask=inp_to_model["attention_mask"],
        token_type_ids=inp_to_model["token_type_ids"],
        previous_last_hidden=out_logits.hidden_states[-1],
    )
    att_output = out_att.argmax(-1)
    pprint(list(zip(list(range(len(decoded_seq))), decoded_seq, decoded_label, att_output[0][0].numpy(), data_0['point_labels'])))

[(0, '[CLS]', 'KEEP', 1, 1),
 (1, 'hal', 'KEEP', 2, 2),
 (2, 'apa', 'KEEP', 3, 3),
 (3, 'yang', 'KEEP', 4, 4),
 (4, 'lebih', 'KEEP|1', 7, 7),
 (5, 'cep', 'DELETE', 0, 0),
 (6, '##et', 'DELETE', 0, 0),
 (7, 'dari', 'KEEP', 8, 8),
 (8, 'gund', 'KEEP', 9, 9),
 (9, '##ala', 'KEEP', 10, 10),
 (10, '?', 'KEEP', 11, 11),
 (11, 'ketika', 'KEEP', 7, 13),
 (12, 'driver', 'KEEP', 0, 0),
 (13, 'dan', 'KEEP', 14, 15),
 (14, 'cs', 'KEEP', 0, 0),
 (15, 'sama', 'KEEP', 8, 16),
 (16, '-', 'KEEP', 17, 17),
 (17, 'sama', 'KEEP', 4, 21),
 (18, 'nge', 'DELETE', 0, 0),
 (19, '##cha', 'DELETE', 0, 0),
 (20, '##t', 'DELETE', 0, 0),
 (21, '"', 'KEEP', 24, 24),
 (22, 'oke', 'KEEP|1', 0, 0),
 (23, '"', 'KEEP|1', 0, 0),
 (24, '[SEP]', 'KEEP', 0, 0)]


In [69]:
list(zip(decoded_seq, decoded_label))

[('[CLS]', 'KEEP'),
 ('hah', 'KEEP'),
 ('##aha', 'KEEP'),
 ('##ha', 'DELETE'),
 ('nih', 'DELETE'),
 ('solo', 'KEEP'),
 ('x', 'KEEP'),
 ('##num', 'KEEP'),
 ('##ber', 'KEEP'),
 ('##x', 'KEEP'),
 ('x', 'KEEP'),
 ('##num', 'KEEP'),
 ('##ber', 'KEEP'),
 ('##x', 'KEEP'),
 ('k', 'DELETE'),
 ('jadi', 'KEEP'),
 ('sekian', 'KEEP'),
 ('.', 'KEEP'),
 ('[SEP]', 'KEEP')]

In [75]:
import torch

[(0, '[CLS]', 'KEEP', 1),
 (1, 'hah', 'KEEP', 2),
 (2, '##aha', 'KEEP', 5),
 (3, '##ha', 'DELETE', 0),
 (4, 'nih', 'DELETE', 0),
 (5, 'solo', 'KEEP', 6),
 (6, 'x', 'KEEP', 7),
 (7, '##num', 'KEEP', 8),
 (8, '##ber', 'KEEP', 9),
 (9, '##x', 'KEEP', 10),
 (10, 'x', 'KEEP', 11),
 (11, '##num', 'KEEP', 12),
 (12, '##ber', 'KEEP', 13),
 (13, '##x', 'KEEP', 14),
 (14, 'k', 'DELETE', 15),
 (15, 'jadi', 'KEEP', 16),
 (16, 'sekian', 'KEEP', 17),
 (17, '.', 'KEEP', 18),
 (18, '[SEP]', 'KEEP', 0)]

[(0, '[CLS]', 'KEEP', 1),
 (1, 'belum', 'KEEP', 2),
 (2, 'ada', 'KEEP', 3),
 (3, 'konfirmasi', 'KEEP', 4),
 (4, 'lagi', 'KEEP|1', 7),
 (5, 'kah', 'DELETE', 0),
 (6, 'min', 'DELETE', 0),
 (7, '?', 'KEEP', 8),
 (8, '[SEP]', 'KEEP', 0)]