In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from transformers import DistilBertForTokenClassification, DistilBertModel, DistilBertTokenizerFast

In [2]:
def tokenize_text(text, max_length, tok=None):
    """Encode one raw address string into fixed-length model inputs.

    Args:
        text: The address string. It is wrapped in a one-element list and
            passed with ``is_split_into_words=True``, so the tokenizer is
            handed the whole string as a single pre-split "word".
            NOTE(review): presumably this mirrors how the training inputs
            were built — TODO confirm before changing it.
        max_length: Length to truncate/pad to, special tokens included.
        tok: Tokenizer to use. Defaults to the notebook-level ``tokenizer``,
            which is only created in the model-loading cell, so the default
            works only after that cell has run.

    Returns:
        The result of ``tok.encode_plus``; callers read its ``input_ids``
        and ``attention_mask`` entries, each padded to ``max_length``.
    """
    if tok is None:
        # Fall back to the notebook global for backward compatibility.
        tok = tokenizer
    return tok.encode_plus(
        [text],
        is_split_into_words=True,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,
        padding='max_length',
    )

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
class IndoDistilBertExtractor(nn.Module):
    """Per-token classifier on top of an Indonesian DistilBERT encoder.

    Keep the class name and the ``encoder`` / ``dropout`` / ``classifier``
    attribute names stable: the checkpoint loaded later with ``torch.load``
    is presumably a pickled instance of this class, and unpickling restores
    those attributes by name without calling ``__init__``.

    Args:
        model_name: Hugging Face id of the pretrained encoder.
        num_labels: Number of tag classes scored for every token.
        dropout: Dropout probability applied to the encoder output.
    """

    def __init__(self, model_name='cahya/distilbert-base-indonesian', num_labels=5, dropout=0.3):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Take the width from the encoder config instead of hard-coding 768,
        # so differently sized encoders plug in unchanged.
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, ids, attn_mask):
        """Return per-token logits.

        Args:
            ids: Token-id tensor, presumably (batch, seq_len) — TODO confirm.
            attn_mask: Attention mask aligned with ``ids``.

        Returns:
            Classifier logits for every position of the encoder's first output
            (its last hidden state); the last dimension is ``num_labels``.
        """
        hidden = self.encoder(ids, attention_mask=attn_mask)[0]
        return self.classifier(self.dropout(hidden))

In [7]:
# Load the fine-tuned NER model and its tokenizer.
# NOTE: torch.load unpickles a whole nn.Module here, so the
# IndoDistilBertExtractor class cell must have run first, and the file must be
# trusted (unpickling can execute arbitrary code). On torch >= 2.6 the
# `weights_only` argument defaults to True, which rejects full-module pickles;
# pass weights_only=False there (omitted to stay compatible with older torch).
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load('weights/ner_distil_indo.pt', map_location=device)
model.to(device)
model.eval()

tokenizer = DistilBertTokenizerFast.from_pretrained('cahya/distilbert-base-indonesian')

In [8]:
# Test addresses to tag; the inference loop reads the `raw_address` column.
df = pd.read_csv(filepath_or_buffer='data/test.csv')
# Tag vocabulary in classifier-output order: predicted class id i maps to targets[i].
targets = ['B-STR', 'I-STR', 'B-POI', 'I-POI', 'O']
tags = dict(enumerate(targets))  # {0: 'B-STR', 1: 'I-STR', ...}

In [10]:
# Inference: tag each word of every test address and build its "POI/street" answer.
# NOTE(review): predictions are read by *position* — the i-th space-separated word
# takes preds[i] — while position 0 of the encoded sequence is the [CLS] token and
# later positions are subword pieces, not words. This only lines up if training
# aligned labels the same way; TODO confirm (the usual approach encodes the word
# list and maps first-subword predictions back with `word_ids()`).
MAX_LEN = 100  # must equal the max_length used at training time — TODO confirm

submission = []
for text in df['raw_address']:
    # read_csv yields NaN (a float) for an empty cell; emit an empty answer
    # instead of crashing on float.split.
    if not isinstance(text, str):
        submission.append({'POI/street': '/'})
        continue

    tokens = text.split(' ')
    encoded = tokenize_text(text, MAX_LEN)
    input_ids = torch.tensor(encoded['input_ids'], dtype=torch.long).unsqueeze(0).to(device)
    attn_mask = torch.tensor(encoded['attention_mask'], dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_ids, attn_mask)

    # Flatten the batch of one to (positions, num_labels); best label id per position.
    preds = outputs.view(-1, outputs.shape[-1]).cpu().numpy().argmax(axis=1)

    street_tokens = []
    poi_tokens = []
    # zip stops at the shorter sequence: words beyond the model's MAX_LEN positions
    # are left untagged instead of raising IndexError.
    for token, pred in zip(tokens, preds):
        label = tags[pred]
        if 'STR' in label:
            street_tokens.append(token)
        if 'POI' in label:
            poi_tokens.append(token)

    submission.append({'POI/street': '{}/{}'.format(' '.join(poi_tokens), ' '.join(street_tokens))})

In [11]:
# One row per test address, with a single 'POI/street' column.
submit_df = pd.DataFrame(data=submission)

In [12]:
# Expose the 0-based row position as the `id` column, followed by 'POI/street'.
# NOTE(review): assumes row order matches the ids expected for data/test.csv
# (true if those ids are 0..n-1 in file order) — TODO confirm, or carry df's own
# id column through instead.
# Selecting 'POI/street' first makes this cell safe to re-run: a bare
# reset_index() adds another column on every run, and the subsequent 2-name
# column assignment then raises.
submit_df = (
    submit_df[['POI/street']]
    .rename_axis('id')
    .reset_index()
)

In [14]:
# Write the submission; index=False because `id` is already a column.
# The file name keeps its original spelling so anything that reads it still finds it.
os.makedirs('output', exist_ok=True)  # to_csv does not create missing directories
submit_df.to_csv('output/submmission.csv', index=False)