# Environment

In [2]:
import torch
import joblib
import torch.nn as nn

from torchcrf import CRF
from collections import OrderedDict
from IPython.display import display, HTML
from transformers import AutoModel, XLMRobertaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


# Corpus

In [3]:
class EntityDataset:
    """Map-style dataset that converts pre-split words and per-word tags into
    fixed-size tensors for the tagger.

    Each item is a dict with:
      * ``ids``            - sub-word ids wrapped in ``<s>`` / ``</s>`` and padded
                             to ``MAX_LEN``.
      * ``mask``           - 1 for real positions, 0 for padding.
      * ``token_type_ids`` - all zeros (kept only for interface compatibility).
      * ``target_tag``     - the word's tag repeated for each of its sub-words,
                             the encoded ``"O"`` tag on the two special tokens,
                             and 0 on padding.
      * ``chars``          - ``(MAX_LEN, CHAR_MAX_LEN)`` character ids of every
                             sub-word piece, 0-padded.

    Args:
        texts: sequence of sentences, each a sequence of words.
        tags: sequence of per-word tag ids, aligned with ``texts``.
        enc_tag: object whose ``transform(["O"])`` yields the id of the "O" tag
            (presumably a fitted scikit-learn ``LabelEncoder`` - confirm).
        char_vocab: mapping from character to integer id.
    """

    # Output sizes; exposed as class attributes so they can be overridden
    # without editing __getitem__.
    MAX_LEN = 256
    CHAR_MAX_LEN = 32
    # Id used both for character padding and for characters missing from
    # ``char_vocab`` (the same fallback the unk-token branch always used).
    CHAR_PAD_ID = 0

    def __init__(self, texts, tags, enc_tag, char_vocab):
        self.texts = texts
        self.tags = tags
        self.enc_tag = enc_tag
        self.char_vocab = char_vocab
        self.tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        max_len = self.MAX_LEN
        char_max_len = self.CHAR_MAX_LEN
        char_pad = self.CHAR_PAD_ID
        unk_token = self.tokenizer.unk_token

        words = self.texts[item]
        word_tags = self.tags[item]

        ids = []
        char_ids = []
        target_tag = []

        for word_idx, word in enumerate(words):
            # Use the same string for both tokenizer calls so ids and pieces
            # are derived from identical input.
            word = str(word)
            piece_ids = self.tokenizer.encode(word, add_special_tokens=False)
            pieces = self.tokenizer.tokenize(word)
            ids.extend(piece_ids)

            # Every sub-word inherits the tag of the word it came from.
            target_tag.extend([word_tags[word_idx]] * len(piece_ids))

            # Character ids per piece. Characters that are not in the vocab map
            # to the pad id instead of raising KeyError, so sentences with
            # unseen characters can still be processed.
            # NOTE(review): assumes encode() and tokenize() return the same
            # number of pieces, which keeps char_ids aligned with ids.
            for piece in pieces:
                if piece == unk_token:
                    char_ids.append([self.char_vocab.get(piece, char_pad)])
                else:
                    char_ids.append(
                        [self.char_vocab.get(ch, char_pad) for ch in piece]
                    )

        # Leave room for the two special tokens.
        ids = ids[:max_len - 2]
        char_ids = char_ids[:max_len - 2]
        target_tag = target_tag[:max_len - 2]

        # Add special tokens: <s> and </s>
        ids = [self.tokenizer.cls_token_id] + ids + [self.tokenizer.sep_token_id]
        char_ids = [[char_pad]] + char_ids + [[char_pad]]

        o_tag = self.enc_tag.transform(["O"])[0]
        target_tag = [o_tag] + target_tag + [o_tag]

        # Masking
        mask = [1] * len(ids)
        token_type_ids = [0] * len(ids)  # Not used in XLM-R, just for compatibility

        padding_len = max_len - len(ids)

        ids = ids + [self.tokenizer.pad_token_id] * padding_len
        mask = mask + [0] * padding_len
        token_type_ids = token_type_ids + [0] * padding_len
        target_tag = target_tag + [0] * padding_len

        # Truncate / pad every piece to char_max_len characters, then pad the
        # sequence itself with all-padding rows up to max_len.
        char_ids = [piece[:char_max_len] for piece in char_ids]
        char_rows = [
            piece + [char_pad] * (char_max_len - len(piece)) for piece in char_ids
        ]
        char_rows += [[char_pad] * char_max_len] * (max_len - len(char_rows))

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "target_tag": torch.tensor(target_tag, dtype=torch.long),
            "chars": torch.tensor(char_rows, dtype=torch.long),
        }

# CNN

In [4]:
class CharCNN(nn.Module):
    """Character-level feature extractor: embedding -> grouped 1-D convolution
    -> max-over-characters pooling.

    Args:
        char_emb_dim: size of each character embedding.
        char_input_dim: number of rows in the character embedding table.
        char_emb_dropout: dropout probability applied to the embeddings.
        char_cnn_filter_num: filters per embedding dimension; the convolution
            outputs ``char_emb_dim * char_cnn_filter_num`` channels.
        char_cnn_kernel_size: width of the convolution over characters.
        char_cnn_dropout: dropout probability applied to the pooled features.
    """

    def __init__(self,
                 char_emb_dim,
                 char_input_dim,
                 char_emb_dropout,
                 char_cnn_filter_num,
                 char_cnn_kernel_size,
                 char_cnn_dropout
                ):
        super(CharCNN, self).__init__()

        # Character Embedding
        self.char_pad_idx = 0
        self.char_emb_dim = char_emb_dim
        self.char_emb = nn.Embedding(
            num_embeddings=char_input_dim,
            embedding_dim=char_emb_dim,
            padding_idx=self.char_pad_idx
        )

        # Initialize embedding for char padding as zero
        with torch.no_grad():
            self.char_emb.weight[self.char_pad_idx].zero_()
        self.char_emb_dropout = nn.Dropout(char_emb_dropout)

        # Char CNN
        self.char_cnn = nn.Conv1d(
            in_channels=char_emb_dim,
            out_channels=char_emb_dim * char_cnn_filter_num,
            kernel_size=char_cnn_kernel_size,
            groups=char_emb_dim  # different 1d conv for each embedding dim
        )
        self.char_cnn_dropout = nn.Dropout(char_cnn_dropout)

    def forward(self, chars):
        """Compute one character-feature vector per token position.

        Args:
            chars: integer tensor of character ids, indexed as
                ``(batch, sent_len, word_len)``.

        Returns:
            Tensor of shape ``(batch, sent_len, out_channels)`` located on the
            same device as ``chars``.
        """
        char_emb_out = self.char_emb_dropout(self.char_emb(chars))
        batch_size, sent_len, word_len, char_emb_dim = char_emb_out.shape

        # Fold the token axis into the batch axis so the convolution runs once
        # over every token instead of once per position in a Python loop. This
        # also removes the dependency on a global `device`: the result lives
        # wherever the input lives.
        flat_emb = char_emb_out.reshape(batch_size * sent_len, word_len, char_emb_dim)
        conv_out = self.char_cnn(flat_emb.permute(0, 2, 1))  # (B*S, C_out, L')

        # Max-pool over the character axis, then restore (batch, sent_len, C_out).
        pooled, _ = torch.max(conv_out, dim=2)
        char_cnn_max_out = pooled.reshape(batch_size, sent_len, self.char_cnn.out_channels)

        return self.char_cnn_dropout(char_cnn_max_out)

# Model

In [5]:
class EntityModel(nn.Module):
    """Sequence tagger: XLM-RoBERTa word features concatenated with CharCNN
    character features, fed through a BiLSTM, a residual multi-head
    self-attention block with layer norm, a linear projection and a CRF.

    Args:
        num_tag: number of tag classes (size of the CRF / projection output).
        char_emb_dim, char_input_dim, char_emb_dropout, char_cnn_filter_num,
        char_cnn_kernel_size, char_cnn_dropout: forwarded to ``CharCNN``.
        input_dim: LSTM input size. It must equal the encoder hidden size plus
            ``char_emb_dim * char_cnn_filter_num`` because the two feature sets
            are concatenated (the default 916 is presumably 768 + 37 * 4 -
            confirm if the encoder or CNN sizes change).
        lstm_hidden_dim: hidden size per LSTM direction.
        lstm_layers: number of stacked LSTM layers.
        attn_heads: number of attention heads.
        attn_dropout: dropout probability inside the attention layer.
    """

    def __init__(self,
                 num_tag,
                 char_emb_dim=37,
                 char_input_dim=0,
                 char_emb_dropout=0.25,
                 char_cnn_filter_num=4,
                 char_cnn_kernel_size=3,
                 char_cnn_dropout=0.25,
                 input_dim=916,
                 lstm_hidden_dim=64,
                 lstm_layers=1,
                 attn_heads=4,
                 attn_dropout=0.25
                ):
        super(EntityModel, self).__init__()
        self.num_tag = num_tag

        # NOTE: the attribute names below are the keys of saved state dicts;
        # renaming any of them breaks load_state_dict for existing checkpoints.

        # XLM-RoBERTa - Word Embedding
        self.xlm_roberta = AutoModel.from_pretrained('xlm-roberta-base')

        # CNN - Character Embedding
        self.char_cnn = CharCNN(
            char_emb_dim=char_emb_dim,
            char_input_dim=char_input_dim,
            char_emb_dropout=char_emb_dropout,
            char_cnn_filter_num=char_cnn_filter_num,
            char_cnn_kernel_size=char_cnn_kernel_size,
            char_cnn_dropout=char_cnn_dropout
        )

        # BiLSTM
        self.bilstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_layers,
            bidirectional=True,
            batch_first=True
        )

        # Multihead Attention (sequence-first layout, see _emissions)
        self.attn = nn.MultiheadAttention(
            embed_dim=lstm_hidden_dim * 2,
            num_heads=attn_heads,
            dropout=attn_dropout
        )
        self.attn_layer_norm = nn.LayerNorm(lstm_hidden_dim * 2)

        # CRF
        self.dropout_tag = nn.Dropout(0.3)
        self.hidden2tag_tag = nn.Linear(lstm_hidden_dim * 2, self.num_tag)
        self.crf_tag = CRF(self.num_tag, batch_first=True)

    def _emissions(self, ids, mask, chars):
        """Run the pipeline shared by ``forward`` and ``encode``.

        Returns:
            ``(emissions, crf_mask, stages)`` where ``emissions`` are the
            per-token tag scores, ``crf_mask`` is the boolean version of
            ``mask`` and ``stages`` is a list of ``(title, tensor)`` pairs of
            the intermediate outputs, ordered by their ``show_output`` number
            (1..5).
        """
        # XLM-RoBERTa - Word Embedding
        encoded_layers = self.xlm_roberta(ids, attention_mask=mask).last_hidden_state

        # CNN - Character Embedding
        char_cnn_p = self.char_cnn(chars)

        # Concat XLM-RoBERTa & CNN
        word_features = torch.cat((encoded_layers, char_cnn_p), dim=2)

        # BiLSTM
        h, _ = self.bilstm(word_features)

        # Self-attention with residual connection and layer norm. The attention
        # layer was built without batch_first, so move the sequence axis to the
        # front for the call and back afterwards. Padding positions are
        # excluded as keys.
        h_ = h.permute(1, 0, 2)
        attn_out, _ = self.attn(h_, h_, h_, key_padding_mask=(mask == 0))
        normed_h_plus_attn = self.attn_layer_norm(h + attn_out.permute(1, 0, 2))

        # Projection to tag scores
        emissions = self.hidden2tag_tag(self.dropout_tag(normed_h_plus_attn))
        crf_mask = mask == 1

        stages = [
            ('Word Embedding XLM-RoBERTa', encoded_layers),
            ('Char Embedding CNN', char_cnn_p),
            ('Char & Word Embedding', word_features),
            ('BiLSTM', h),
            ('Multihead Attention', normed_h_plus_attn),
        ]
        return emissions, crf_mask, stages

    @staticmethod
    def _print_outputs(show_output, stages, crf_output):
        """Print the intermediate outputs whose stage number is in ``show_output``.

        Stage numbers 1..5 index ``stages``; 6 is the CRF result (the loss in
        ``forward``, the decoded tags in ``encode``).
        """
        print('Input / Output Data')
        for number, (title, value) in enumerate(stages, start=1):
            if number in show_output:
                print(title)
                print(f'Shape: {value.shape}')
                print(f'Output: {value}\n')
        if 6 in show_output:
            print('CRF')
            print(f'Output: {crf_output}\n')

    def forward(self, ids, mask, token_type_ids, target_tag, chars, show_output=None):
        """Compute the training loss.

        Args:
            ids: sub-word ids, ``(batch, seq_len)``.
            mask: 1 for real tokens, 0 for padding, ``(batch, seq_len)``.
            token_type_ids: accepted for interface compatibility; unused.
            target_tag: gold tag ids, ``(batch, seq_len)``.
            chars: character ids passed to the CharCNN.
            show_output: optional collection of stage numbers (1-6) to print.

        Returns:
            Tensor of shape ``(1,)`` holding the negative CRF log-likelihood
            (``token_mean`` reduction).
        """
        emissions, crf_mask, stages = self._emissions(ids, mask, chars)

        # Only the likelihood is needed here; the Viterbi decode that used to
        # run on every forward pass produced a value that was never used.
        loss = -self.crf_tag(emissions, target_tag, mask=crf_mask, reduction='token_mean')

        if show_output is not None:
            self._print_outputs(show_output, stages, loss)

        return loss.unsqueeze(0)

    def encode(self, ids, mask, token_type_ids, target_tag, chars, show_output=None):
        """Predict tags.

        Takes the same arguments as ``forward``; ``token_type_ids`` and
        ``target_tag`` are accepted but unused.

        Returns:
            The result of ``crf_tag.decode``: for each batch element, the
            predicted tag ids of its unmasked positions.
        """
        emissions, crf_mask, stages = self._emissions(ids, mask, chars)
        tag = self.crf_tag.decode(emissions, mask=crf_mask)

        if show_output is not None:
            self._print_outputs(show_output, stages, tag)

        return tag

# Inference

In [6]:
from IPython.display import display, HTML

# Background colour (CSS hex string) used when highlighting each entity type.
colors = dict(
    PERSON="#ffff00",
    PERSONCOREF="#9932cc",
    ROLE="#ff00ff",
    AFFILIATION="#00ff7f",
    CUE="#ff6347",
    CUECOREF="#00bfff",
    STATEMENT="#ffa500",
    ISSUE="#7fffd4",
    DATETIME="#ffdab9",
    LOCATION="#adff2f",
    EVENT="#d2b48c",
)

# The set of entity type names that have a colour assigned.
entity_types = set(colors)

def strip_prefix(tag):
    """Remove a leading B-, I-, L- or U- prefix from ``tag`` if present.

    Tags without one of these prefixes are returned unchanged.
    """
    has_scheme_prefix = tag.startswith(('B-', 'I-', 'L-', 'U-'))
    # Every recognised prefix is exactly two characters long.
    return tag[2:] if has_scheme_prefix else tag

def _render_entity_mark(words, label):
    """Return the ``<mark>`` HTML snippet for one entity span (with trailing space)."""
    entity_text = " ".join(words)
    color = colors[label]
    return (
        f'<mark style="background-color: {color}; padding:2px 4px; border-radius:3px; margin-right:4px;">'
        f'{entity_text} <strong style="color:black;">[{label}]</strong>'
        f'</mark> '
    )


def highlight_entities(tokens_with_tags):
    """Render ``(token, tag)`` pairs as an HTML string with entities highlighted.

    Consecutive tokens whose tag (after ``strip_prefix``) is the same known
    entity type are merged into a single ``<mark>`` element followed by the
    type name in brackets. All other tokens are emitted as plain text.

    Token text is HTML-escaped (``&``, ``<``, ``>``) before being inserted, so
    markup contained in the input sentence is displayed literally instead of
    being interpreted by the browser.

    Args:
        tokens_with_tags: iterable of ``(token, tag)`` string pairs.

    Returns:
        The HTML fragment, stripped of leading/trailing whitespace.
    """
    # Imported locally because `html` is not imported at the top of the file.
    from html import escape

    pieces = []
    buffer = []          # escaped tokens of the entity span currently open
    current_tag = None   # entity type of that span, or None when no span is open

    for token, tag in tokens_with_tags:
        tag_clean = strip_prefix(tag)
        safe_token = escape(token, quote=False)

        if tag_clean in entity_types:
            if tag_clean != current_tag:
                # A different entity type starts: close the open span first.
                if buffer:
                    pieces.append(_render_entity_mark(buffer, current_tag))
                buffer = []
                current_tag = tag_clean
            buffer.append(safe_token)
        else:
            if buffer:
                pieces.append(_render_entity_mark(buffer, current_tag))
                buffer = []
                current_tag = None
            pieces.append(safe_token + " ")

    # Flush the span that is still open at the end of the input.
    if buffer and current_tag:
        pieces.append(_render_entity_mark(buffer, current_tag))

    return "".join(pieces).strip()

def predict_sentence(model, sentence, enc_tag, chars, show_output=None):
    """Predict one tag label per sub-word position of ``sentence``.

    The sentence is split on whitespace, encoded with ``EntityDataset`` and
    passed through ``model.encode``. Inference runs in evaluation mode so that
    dropout is disabled and repeated calls give the same result; the model's
    previous training/eval mode is restored afterwards.

    Args:
        model: tagger exposing ``encode(ids, mask, token_type_ids, target_tag,
            chars, show_output=...)``.
        sentence: raw input string.
        enc_tag: label encoder used for the dataset and for
            ``inverse_transform`` of the predicted ids.
        chars: character vocabulary passed to the dataset as ``char_vocab``.
        show_output: forwarded to ``model.encode``.

    Returns:
        The result of ``enc_tag.inverse_transform`` on the predicted tag ids of
        the first (only) batch element. It includes the positions of the two
        special tokens added by the dataset.
    """
    words = sentence.split()
    test_dataset = EntityDataset(
        texts=[words],
        tags=[[0] * len(words)],  # dummy tags; only needed to satisfy the dataset
        enc_tag=enc_tag,
        char_vocab=chars
    )

    # Put the inputs where the model's weights are instead of relying on a
    # global `device` defined in another cell.
    model_device = next(model.parameters()).device

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            batch = {
                key: value.to(model_device).unsqueeze(0)
                for key, value in test_dataset[0].items()
            }
            tag = model.encode(**batch, show_output=show_output)
            tag = enc_tag.inverse_transform(tag[0])
    finally:
        model.train(was_training)

    return tag

def reverse_tokenize(ids, tags, tokenizer):
    """Merge sub-word pieces back into words, each paired with a single tag.

    ``ids`` and ``tags`` are walked in parallel. Pieces equal to ``<s>`` or
    ``</s>`` are skipped. A piece starting with "▁" opens a new word (with the
    marker removed) and contributes that word's tag; any other piece is
    appended to the word currently being built and its own tag is ignored.
    Continuation pieces seen before any word has been opened are dropped.

    Returns:
        List of ``(word, tag)`` tuples.
    """
    merged = []  # mutable [word, tag] entries, frozen into tuples at the end

    for token_id, tag in zip(ids, tags):
        piece = tokenizer.convert_ids_to_tokens([token_id])[0]

        # Skip the special tokens <s> and </s>
        if piece == '<s>' or piece == '</s>':
            continue

        if piece.startswith("▁"):
            # A leading "▁" marks the start of a new word
            merged.append([piece.replace("▁", ""), tag])
        elif merged:
            # Sub-word: glue onto the previous word, which keeps its own tag
            merged[-1][0] += piece

    return [(word, word_tag) for word, word_tag in merged]

In [7]:
# Load Tokenizer
tokenizer = XLMRobertaTokenizer.from_pretrained(
    "xlm-roberta-base",
    cache_dir="models/transformers_cache",
    local_files_only=False
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_quote_model(meta_path, weights_path, device):
    """Load one tagger and its metadata, ready for inference.

    Args:
        meta_path: joblib file holding a dict with the keys ``"enc_tag"`` and
            ``"chars"``.
        weights_path: file holding the model's state dict.
        device: device the model is moved to.

    Returns:
        ``(model, meta)`` with the model on ``device`` and in eval mode.
    """
    # SECURITY: joblib.load unpickles arbitrary objects and can execute code;
    # only load metadata files from a trusted source.
    meta = joblib.load(meta_path)

    model = EntityModel(
        num_tag=len(meta["enc_tag"].classes_),
        char_input_dim=len(meta["chars"])
    )
    state_dict = torch.load(weights_path, map_location=device)
    # Drop the "module." that appears in the saved keys
    # (presumably added by nn.DataParallel during training - confirm).
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.to(device)
    # Modules start in training mode; switch to eval so dropout is disabled
    # and predictions are deterministic.
    model.eval()
    return model, meta


# Load Model Direct
model_direct, meta_direct = load_quote_model(
    "models/direct-quotes/meta_direct2.bin",
    "models/direct-quotes/direct-quotes.bin",
    device
)
enc_tag_direct = meta_direct["enc_tag"]
chars_direct = meta_direct["chars"]

# Load Model Indirect
model_indirect, meta_indirect = load_quote_model(
    "models/indirect-quotes/meta_indirect2.bin",
    "models/indirect-quotes/indirect-quotes.bin",
    device
)
enc_tag_indirect = meta_indirect["enc_tag"]
chars_indirect = meta_indirect["chars"]

# Labels for which the indirect model's prediction wins over the direct one.
# Built programmatically: the hand-written set was missing a comma, which
# fused "U-CUECOREF" and "B-STATEMENT" into one bogus string so that neither
# label was actually in the set.
prefer_indirect_labels = {
    f"{prefix}-{entity}"
    for entity in ("PERSON", "PERSONCOREF", "CUE", "CUECOREF", "STATEMENT")
    for prefix in ("B", "I", "L", "U")
}

sentences = [
    'JAKARTA, KOMPAS — Pemerintah serius mewujudkan transformasi energi menuju energi baru dan terbarukan, termasuk penggunaan kendaraan listrik.',
    '"Pemerintah sangat serius untuk masuk pada energi baru terbarukan, termasuk di dalamnya adalah menuju pada kendaraan listrik. Oleh sebab itu, saya sangat menghargai keberanian perusahaan-perusahaan yang tadi saya sebut para CEO-nya masuk dari hulu sampai hilir untuk memulai membangun ekosistem kendaraan listrik," ujar Presiden Joko Widodo saat memberikan sambutan pada acara Peluncuran Kolaborasi Pengembangan Ekosistem Kendaraan Listrik yang digelar di Stasiun Pengisian Bahan Bakar untuk Umum (SPBU) MT Haryono, di Jakarta, Selasa (22/2/2022).',
    'Alfitra menyatakan bahwa teradu Sophia Marlinda Djami diberhentikan secara tetap dari jabatannya sebagai Ketua KPU Kabupaten Sumba Barat sejak putusan tersebut dibacakan.',
]

# --- Combine the predictions of both models ---
full_paragraph = ""
for sentence in sentences:
    # Tokenize (encode -> ids); needed later to map tags back onto words
    tokenized_input = tokenizer.encode(sentence)

    # Predictions from both models
    tags_direct = predict_sentence(model_direct, sentence, enc_tag_direct, chars_direct)
    tags_indirect = predict_sentence(model_indirect, sentence, enc_tag_indirect, chars_indirect)

    # Use the indirect model's label when it is in the preferred set,
    # otherwise fall back to the direct model's label
    final_tags = [
        tag_i if tag_i in prefer_indirect_labels else tag_d
        for tag_d, tag_i in zip(tags_direct, tags_indirect)
    ]

    # Turn sub-word ids and tags back into (word, tag) pairs
    token_tag_pairs = reverse_tokenize(tokenized_input, final_tags, tokenizer)
    highlighted = highlight_entities(token_tag_pairs)
    full_paragraph += highlighted + " "

# Show the result
display(HTML(f"<p>{full_paragraph.strip()}</p>"))

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
