In [22]:
from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast

import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

In [4]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Fine-tuned token-classification weights are read from a local directory;
# `.to()` and `.eval()` both return the module itself, so they can be chained.
model = (
    DistilBertForTokenClassification
    .from_pretrained('weights/ner_distil_bert_indo')
    .to(device)
    .eval()
)

# Tokenizer is loaded by repo id (presumably from the Hugging Face Hub) —
# NOTE(review): confirm it matches the vocabulary the weights were trained with.
tokenizer = DistilBertTokenizerFast.from_pretrained('cahya/distilbert-base-indonesian')

In [7]:
df = pd.read_csv('data/formatted_train.csv')

# Hold out 20% of the rows for evaluation; the fixed random_state keeps the
# split identical across kernel restarts. The training part is bound to a real
# name rather than `_`: assigning `_` in IPython shadows the output-history
# variable and stops IPython from updating `_`/`__`/`___`.
train_df, eval_df = train_test_split(df, test_size=0.2, random_state=2021)

In [9]:
def tokenize_text(text, max_length):
    """Encode a single text for the token-classification model.

    Args:
        text: Either a raw string, which is split on whitespace into words,
            or an iterable of already-split words.
        max_length: Sequence length after truncation/padding, including the
            special tokens added by the tokenizer.

    Returns:
        The tokenizer's encoding for one sequence (``input_ids``,
        ``attention_mask``, ...), truncated and padded to ``max_length``.
        No ``return_tensors`` is requested, so values are plain lists.
    """
    # Hand the tokenizer real words so each sub-token's word index
    # (`word_ids()`) points at the word it came from. Wrapping the whole
    # sentence as one "word" ([text]) would map every sub-token to word 0.
    if isinstance(text, str):
        # Empty / whitespace-only input keeps the original single-item form.
        words = text.split() or [text]
    else:
        words = list(text)
    return tokenizer(
        words,
        is_split_into_words=True,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,
        padding='max_length',
    )

In [38]:
# Inspect one held-out row (positional index 3, all columns).
eval_df.iloc[3, :]

text           pastori gereja baptis indonesia citarum
label                                              POI
text_length                                         39
Name: 213502, dtype: object

In [39]:
# Encode the same held-out row, padded/truncated to 150 tokens.
inputs = tokenize_text(eval_df['text'].iloc[3], 150)

# Turn both encoding fields into int64 tensors on the target device.
input_ids, attention_mask = (
    torch.tensor(inputs[field], dtype=torch.long).to(device)
    for field in ('input_ids', 'attention_mask')
)

In [40]:
# Inference only: no_grad skips autograd bookkeeping. The batch dimension is
# added with out-of-place `unsqueeze`, so input_ids / attention_mask are left
# unmodified and the cell can be re-run; the in-place `unsqueeze_` grew them by
# one dimension on every execution and broke the second run.
with torch.no_grad():
    outputs = model(input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0))