In [None]:
# %pip install -q pandas scikit-learn ohmeow-blurr

# Prepare data

In [1]:
import pandas as pd
# Turn off pandas' chained-assignment warning for the rest of the notebook.
pd.options.mode.chained_assignment = None
# Read the whole training file; it is split into train/test in the next cell.
df = pd.read_csv('train.csv')

In [2]:
import ast
from sklearn.model_selection import train_test_split

# Carve out 15% of the rows (fixed random_state) as test_df.
train_df, test_df = train_test_split(df, test_size=0.15, random_state=75)

# NOTE(review): train_df is immediately replaced by the complete file, so every row of
# test_df is also a training row. Any score computed on test_df later is therefore not a
# true hold-out estimate -- confirm this is intentional (e.g. training on all data).
train_df = pd.read_csv('train.csv')

# df_converters = {'tokens': ast.literal_eval, 'labels': ast.literal_eval}
# train_df = pd.read_csv('train-ner-2.csv', converters=df_converters)

len(train_df), len(test_df)

(300000, 45000)

In [3]:
from string import punctuation
import re

def clean(s):
    """Normalise spacing around brackets and punctuation in an address string.

    Steps, in order:
      1. insert a space before '(' when it sits between two word characters;
      2. insert a space after a run of ')', ',', '.', ':' or ';' that sits
         between two word characters;
      3. turn 'x.(y' into 'x. (y';
      4. collapse every whitespace run to one space and trim both ends.

    Args:
        s: the raw address text.

    Returns:
        The cleaned string.
    """
    # The replacement templates are raw strings: '\g' inside a normal string literal is an
    # invalid escape sequence (DeprecationWarning, SyntaxWarning on newer Pythons).
    res = re.sub(r'(\w)(\()(\w)', r'\g<1> \g<2>\g<3>', s)
    res = re.sub(r'(\w)([),.:;]+)(\w)', r'\g<1>\g<2> \g<3>', res)
    res = re.sub(r'(\w)(\.\()(\w)', r'\g<1>. (\g<3>', res)
    res = re.sub(r'\s+', ' ', res)
    return res.strip()

def stripclean(arr):
    """Return a new list where each string has surrounding whitespace, then surrounding punctuation, removed."""
    cleaned = []
    for token in arr:
        cleaned.append(token.strip().strip(punctuation))
    return cleaned

def dummy(x):
    """Return a new list holding the same elements as `x` (a shallow copy)."""
    # Used with Series.apply so a derived column gets its own list objects
    # instead of sharing them with the source column.
    return list(x)

In [4]:
# 'POI/street' holds "<poi>/<street>"; split it once and reuse both halves.
target_parts = train_df['POI/street'].str.split('/')
train_df['POI'] = target_parts.str[0].apply(clean).str.split().apply(stripclean)
train_df['STR'] = target_parts.str[1].apply(clean).str.split().apply(stripclean)

# Whitespace tokens of the cleaned address, plus the per-row working columns derived from them.
train_df['tokens'] = train_df['raw_address'].apply(clean).str.split()
token_col = train_df['tokens']
train_df['strip_tokens'] = token_col.apply(stripclean)              # tokens without edge punctuation
train_df['full_tokens'] = token_col.apply(dummy)                    # independent copy of each token list
train_df['labels'] = token_col.apply(lambda toks: ['O'] * len(toks))  # one 'O' tag per token
train_df['pos_poi'] = token_col.apply(lambda _: [-1, -1])           # [start, end] of the POI span, -1 = not found
train_df['pos_str'] = token_col.apply(lambda _: [-1, -1])           # [start, end] of the street span, -1 = not found

In [5]:
train_df.head()

Unnamed: 0,id,raw_address,POI/street,POI,STR,tokens,strip_tokens,full_tokens,labels,pos_poi,pos_str
0,0,jl kapuk timur delta sili iii lippo cika 11 a ...,/jl kapuk timur delta sili iii lippo cika,[],"[jl, kapuk, timur, delta, sili, iii, lippo, cika]","[jl, kapuk, timur, delta, sili, iii, lippo, ci...","[jl, kapuk, timur, delta, sili, iii, lippo, ci...","[jl, kapuk, timur, delta, sili, iii, lippo, ci...","[O, O, O, O, O, O, O, O, O, O, O, O, O]","[-1, -1]","[-1, -1]"
1,1,"aye, jati sampurna",/,[],[],"[aye,, jati, sampurna]","[aye, jati, sampurna]","[aye,, jati, sampurna]","[O, O, O]","[-1, -1]","[-1, -1]"
2,2,setu siung 119 rt 5 1 13880 cipayung,/siung,[],[siung],"[setu, siung, 119, rt, 5, 1, 13880, cipayung]","[setu, siung, 119, rt, 5, 1, 13880, cipayung]","[setu, siung, 119, rt, 5, 1, 13880, cipayung]","[O, O, O, O, O, O, O, O]","[-1, -1]","[-1, -1]"
3,3,"toko dita, kertosono",toko dita/,"[toko, dita]",[],"[toko, dita,, kertosono]","[toko, dita, kertosono]","[toko, dita,, kertosono]","[O, O, O]","[-1, -1]","[-1, -1]"
4,4,jl. orde baru,/jl. orde baru,[],"[jl, orde, baru]","[jl., orde, baru]","[jl, orde, baru]","[jl., orde, baru]","[O, O, O]","[-1, -1]","[-1, -1]"


In [6]:
# Tokenise the test addresses exactly like train_df['tokens']: clean, then split on whitespace.
test_df['tokens'] = test_df['raw_address'].apply(clean).str.split()

In [7]:
test_df.head()

Unnamed: 0,id,raw_address,POI/street,tokens
90142,90142,lom 88 asrikaton,/,"[lom, 88, asrikaton]"
163531,163531,"varia usaha ungaran, peri kem pudakpayung",/,"[varia, usaha, ungaran,, peri, kem, pudakpayung]"
233950,233950,hutan gar no 7 20371 percut sei tuan,/gar,"[hutan, gar, no, 7, 20371, percut, sei, tuan]"
126157,126157,"wardah gor srik ton,",wardah gorden/srik ton,"[wardah, gor, srik, ton,]"
96808,96808,green puri 7 cengkareng,/green puri 7,"[green, puri, 7, cengkareng]"


# Build word list and token labels

In [8]:
# Accumulators filled by the labelling loop in the next cell.
# wordlist_raw: stripped address token -> {differing POI/street word it matched as a prefix -> count}
wordlist_raw = {}
# ids of rows with a non-empty POI that was never matched in the address tokens
POI_ERR_IDX = []
# ids of rows with a non-empty street that was never matched in the address tokens
STR_ERR_IDX = []
# ids of rows where at least one matched token differs from its POI/street word
SHORTEN_IDX = []
# ids of rows where a match wrote over a token that already carried a non-'O' label
OVERLAP_IDX = set()

In [9]:
from tqdm import tqdm

def _tag_entity(row, i, entity_col, pos_col, tag):
    """Try to align row[entity_col] with the stripped address tokens starting at index i.

    A word of the entity matches a token when the word starts with that token, so
    abbreviated tokens still align. On a full match the row's lists are updated in place:
    the span is stored in row[pos_col], the tokens get S-/B-/I-/E- labels for `tag`, and
    every token that differs from its entity word is expanded in 'full_tokens', marked
    with a '-SHORT' label suffix and counted in the global `wordlist_raw`.
    Overwriting a non-'O' label records the row id in the global `OVERLAP_IDX`.

    Args:
        row: one row of train_df (its list-valued cells are mutated in place).
        i: index into row['strip_tokens'] where the match is attempted.
        entity_col: column holding the entity words ('POI' or 'STR').
        pos_col: column holding the [start, end] span ('pos_poi' or 'pos_str').
        tag: label suffix to write ('POI' or 'STR').

    Returns:
        (matched, shortened): whether the entity aligned at i, and whether any
        aligned token was an abbreviation.
    """
    entity = row[entity_col]
    strip_tokens = row['strip_tokens']
    size = len(entity)
    if size == 0 or not entity[0].startswith(strip_tokens[i]):
        return False, False
    for j in range(size):
        if i + j >= len(strip_tokens) or not entity[j].startswith(strip_tokens[i + j]):
            return False, False

    labels, full_tokens = row['labels'], row['full_tokens']
    row[pos_col][0] = i
    row[pos_col][1] = i + size - 1
    shortened = False
    for j in range(size):
        k = i + j
        if labels[k] != 'O':
            OVERLAP_IDX.add(row['id'])
        if size == 1:       prefix = 'S-'
        elif j == 0:        prefix = 'B-'
        elif j == size - 1: prefix = 'E-'
        else:               prefix = 'I-'
        labels[k] = prefix + tag
        short, full = strip_tokens[k], entity[j]
        if short != full:
            # Abbreviated token: expand it in full_tokens and remember the mapping.
            full_tokens[k] = full_tokens[k].replace(short, full)
            labels[k] += '-SHORT'
            shortened = True
            counts = wordlist_raw.setdefault(short, {})
            counts[full] = counts.get(full, 0) + 1
    return True, shortened


for idx in tqdm(range(len(train_df))):
    # iloc returns a row copy, but the list objects in its cells are the ones stored
    # in train_df, so the in-place edits made by _tag_entity land in the frame.
    row = train_df.iloc[idx]
    found_poi, found_str, shorten = False, False, False
    for i in range(len(row['strip_tokens'])):
        if row['strip_tokens'][i] == '': continue
        matched, shortened = _tag_entity(row, i, 'POI', 'pos_poi', 'POI')
        found_poi = found_poi or matched
        shorten = shorten or shortened
        matched, shortened = _tag_entity(row, i, 'STR', 'pos_str', 'STR')
        found_str = found_str or matched
        shorten = shorten or shortened

    if len(row['POI']) > 0 and not found_poi:
        POI_ERR_IDX.append(row['id'])
    if len(row['STR']) > 0 and not found_str:
        STR_ERR_IDX.append(row['id'])
    if shorten:
        SHORTEN_IDX.append(row['id'])

100%|██████████| 300000/300000 [02:26<00:00, 2044.00it/s]


In [10]:
# Counts: abbreviated tokens seen, rows with unmatched POI, rows with unmatched street,
# rows containing an abbreviation, rows with overlapping labels.
len(wordlist_raw), len(POI_ERR_IDX), len(STR_ERR_IDX), len(SHORTEN_IDX), len(OVERLAP_IDX)

(11825, 196, 96, 59011, 919)

In [11]:
# sanity check
train_df[train_df['id'].isin(SHORTEN_IDX[:10])]

Unnamed: 0,id,raw_address,POI/street,POI,STR,tokens,strip_tokens,full_tokens,labels,pos_poi,pos_str
10,10,"cikahuripan sd neg boj 02 klap boj, no 5 16877",sd negeri bojong 02/klap boj,"[sd, negeri, bojong, 02]","[klap, boj]","[cikahuripan, sd, neg, boj, 02, klap, boj,, no...","[cikahuripan, sd, neg, boj, 02, klap, boj, no,...","[cikahuripan, sd, negeri, bojong, 02, klap, bo...","[O, B-POI, I-POI-SHORT, I-POI-SHORT, E-POI, B-...","[1, 4]","[5, 6]"
11,11,"yaya atohar,",yayasan atohariyah/,"[yayasan, atohariyah]",[],"[yaya, atohar,]","[yaya, atohar]","[yayasan, atohariyah,]","[B-POI-SHORT, E-POI-SHORT]","[0, 1]","[-1, -1]"
20,20,"toko bang ajs,",toko bangunan ajs/,"[toko, bangunan, ajs]",[],"[toko, bang, ajs,]","[toko, bang, ajs]","[toko, bangunan, ajs,]","[B-POI, I-POI-SHORT, E-POI]","[0, 2]","[-1, -1]"
40,40,mar tabl metro iringmulyo metro timur,markaz tabligh metro/,"[markaz, tabligh, metro]",[],"[mar, tabl, metro, iringmulyo, metro, timur]","[mar, tabl, metro, iringmulyo, metro, timur]","[markaz, tabligh, metro, iringmulyo, metro, ti...","[B-POI-SHORT, I-POI-SHORT, E-POI, O, O, O]","[0, 2]","[-1, -1]"
44,44,sd neg 12 anggrek,sd negeri 12 anggrek/,"[sd, negeri, 12, anggrek]",[],"[sd, neg, 12, anggrek]","[sd, neg, 12, anggrek]","[sd, negeri, 12, anggrek]","[B-POI, I-POI-SHORT, I-POI, E-POI]","[0, 3]","[-1, -1]"
48,48,"rumah makan pela, raya jomb,",rumah makan pelangi/raya jomb,"[rumah, makan, pelangi]","[raya, jomb]","[rumah, makan, pela,, raya, jomb,]","[rumah, makan, pela, raya, jomb]","[rumah, makan, pelangi,, raya, jomb,]","[B-POI, I-POI, E-POI-SHORT, B-STR, E-STR]","[0, 2]","[3, 4]"
69,69,cak 11 nagasari karawang barat,/cakrad,[],[cakrad],"[cak, 11, nagasari, karawang, barat]","[cak, 11, nagasari, karawang, barat]","[cakrad, 11, nagasari, karawang, barat]","[S-STR-SHORT, O, O, O, O]","[-1, -1]","[0, 0]"
74,74,"rnd prin, gang pinak, sukarame",rnd printing/gang pinak,"[rnd, printing]","[gang, pinak]","[rnd, prin,, gang, pinak,, sukarame]","[rnd, prin, gang, pinak, sukarame]","[rnd, printing,, gang, pinak,, sukarame]","[B-POI, E-POI-SHORT, B-STR, E-STR, O]","[0, 1]","[2, 3]"
76,76,"pp minhajutt, kh abdul manan, sumberberas muncar",pp minhajutthollab/kh abdul manan,"[pp, minhajutthollab]","[kh, abdul, manan]","[pp, minhajutt,, kh, abdul, manan,, sumberbera...","[pp, minhajutt, kh, abdul, manan, sumberberas,...","[pp, minhajutthollab,, kh, abdul, manan,, sumb...","[B-POI, E-POI-SHORT, B-STR, I-STR, E-STR, O, O]","[0, 1]","[2, 4]"
77,77,"tk islam daruss,",tk islam darussalam/,"[tk, islam, darussalam]",[],"[tk, islam, daruss,]","[tk, islam, daruss]","[tk, islam, darussalam,]","[B-POI, I-POI, E-POI-SHORT]","[0, 2]","[-1, -1]"


In [12]:
# Union of all row ids flagged as unmatched POI, unmatched street, or overlapping labels.
ERR_IDX = set(POI_ERR_IDX + STR_ERR_IDX + list(OVERLAP_IDX))
len(ERR_IDX)

1211

In [13]:
# Drop every row whose id was flagged in ERR_IDX.
train_df = train_df[~train_df['id'].isin(ERR_IDX)]

In [14]:
def cleanshort(arr):
    """Return a copy of the label list with every '-SHORT' marker removed."""
    plain = []
    for label in arr:
        plain.append(label.replace('-SHORT', ''))
    return plain

# Build extra training rows from the rows that contain abbreviations: the tokens are
# replaced by their expanded form ('full_tokens') and the '-SHORT' suffix is removed
# from the labels, so the expanded text is tagged with the plain labels.
new_train_df = train_df[train_df['id'].isin(SHORTEN_IDX)].copy(deep=True)
new_train_df['tokens'] = new_train_df['full_tokens'].apply(dummy)
new_train_df['labels'] = new_train_df['labels'].apply(cleanshort)

In [15]:
new_train_df.head()

Unnamed: 0,id,raw_address,POI/street,POI,STR,tokens,strip_tokens,full_tokens,labels,pos_poi,pos_str
10,10,"cikahuripan sd neg boj 02 klap boj, no 5 16877",sd negeri bojong 02/klap boj,"[sd, negeri, bojong, 02]","[klap, boj]","[cikahuripan, sd, negeri, bojong, 02, klap, bo...","[cikahuripan, sd, neg, boj, 02, klap, boj, no,...","[cikahuripan, sd, negeri, bojong, 02, klap, bo...","[O, B-POI, I-POI, I-POI, E-POI, B-STR, E-STR, ...","[1, 4]","[5, 6]"
11,11,"yaya atohar,",yayasan atohariyah/,"[yayasan, atohariyah]",[],"[yayasan, atohariyah,]","[yaya, atohar]","[yayasan, atohariyah,]","[B-POI, E-POI]","[0, 1]","[-1, -1]"
20,20,"toko bang ajs,",toko bangunan ajs/,"[toko, bangunan, ajs]",[],"[toko, bangunan, ajs,]","[toko, bang, ajs]","[toko, bangunan, ajs,]","[B-POI, I-POI, E-POI]","[0, 2]","[-1, -1]"
40,40,mar tabl metro iringmulyo metro timur,markaz tabligh metro/,"[markaz, tabligh, metro]",[],"[markaz, tabligh, metro, iringmulyo, metro, ti...","[mar, tabl, metro, iringmulyo, metro, timur]","[markaz, tabligh, metro, iringmulyo, metro, ti...","[B-POI, I-POI, E-POI, O, O, O]","[0, 2]","[-1, -1]"
44,44,sd neg 12 anggrek,sd negeri 12 anggrek/,"[sd, negeri, 12, anggrek]",[],"[sd, negeri, 12, anggrek]","[sd, neg, 12, anggrek]","[sd, negeri, 12, anggrek]","[B-POI, I-POI, I-POI, E-POI]","[0, 3]","[-1, -1]"


In [16]:
# Add the expanded-token rows to the training frame. pd.concat replaces DataFrame.append,
# which was deprecated in pandas 1.4 and removed in pandas 2.0.
train_df = pd.concat([train_df, new_train_df], ignore_index=True)

In [17]:
from tqdm import tqdm

# Outputs of the span-swapping augmentation in the following cells.
# swap_parts: one tuple per swappable row -> (order flag, 5 token segments, 5 label segments)
swap_parts = []
# swap_tokens / swap_labels: token and label lists of the augmented examples
swap_tokens = []
swap_labels = []

In [18]:
for idx in tqdm(range(len(train_df))):
    old_row = train_df.iloc[idx]
    # Only rows where both a POI span and a street span were located can be swapped.
    if old_row['pos_poi'][0] == -1 or old_row['pos_str'][0] == -1: continue

    start_poi, end_poi = old_row['pos_poi']
    start_str, end_str = old_row['pos_str']

    # order 0: the POI span ends before the street span starts; order 1: every other case.
    # (a0, a1) is the span treated as first, (b0, b1) the span treated as second.
    if end_poi < start_str:
        order, (a0, a1), (b0, b1) = 0, (start_poi, end_poi), (start_str, end_str)
    else:
        order, (a0, a1), (b0, b1) = 1, (start_str, end_str), (start_poi, end_poi)

    # Segments in output order: prefix, second span, middle, first span, suffix.
    cuts = (slice(None, a0), slice(b0, b1 + 1), slice(a1 + 1, b0), slice(a0, a1 + 1), slice(b1 + 1, None))
    row_tokens, row_labels = old_row['tokens'], old_row['labels']
    token_parts = [row_tokens[cut] for cut in cuts]
    label_parts = [row_labels[cut] for cut in cuts]

    swap_tokens.append(sum(token_parts, []))
    swap_labels.append(sum(label_parts, []))
    swap_parts.append((order, *token_parts, *label_parts))

100%|██████████| 357219/357219 [00:59<00:00, 6029.65it/s]


In [19]:
import random
# Random permutation (fixed seed 4) that assigns each swap candidate i a partner swap_idx[i].
swap_idx = list(range(len(swap_parts)))
random.Random(4).shuffle(swap_idx)

In [20]:
for i in tqdm(range(len(swap_parts))):
    j = swap_idx[i]
    # A candidate paired with itself would only reproduce its own swapped example.
    if j == i: continue

    # Keep row i's prefix / middle / suffix and plug in row j's two entity segments
    # (slots 2 and 4 for tokens, 7 and 9 for the matching labels).
    own, donor = swap_parts[i], swap_parts[j]
    swap_tokens.append(own[1] + donor[2] + own[3] + donor[4] + own[5])
    swap_labels.append(own[6] + donor[7] + own[8] + donor[9] + own[10])

100%|██████████| 122193/122193 [00:00<00:00, 205567.12it/s]


In [21]:
# Collect the augmented examples into a two-column frame matching the training columns.
swap_train_df = pd.DataFrame(columns=['tokens', 'labels'], data={'tokens': swap_tokens, 'labels': swap_labels})
swap_train_df.head()

Unnamed: 0,tokens,labels
0,"[toko, bb, kids, 299, raya, samb, gede,]","[B-POI, I-POI, E-POI, O, B-STR, I-STR, E-STR]"
1,"[cikahuripan, klap, boj,, sd, neg, boj, 02, no...","[O, B-STR, E-STR, B-POI, I-POI-SHORT, I-POI-SH..."
2,"[br., junjungan,, jln., tirta, tawar,, ubud,, ...","[B-POI, E-POI, B-STR, I-STR, E-STR, O, O, O, O..."
3,"[komplek, borneo, lestari,, jl., amd,, blok, 2...","[B-POI, I-POI, E-POI, B-STR, E-STR, O, O, O, O]"
4,"[raya, jomb,, rumah, makan, pela,]","[B-STR, E-STR, B-POI, I-POI, E-POI-SHORT]"


In [22]:
# Keep only 'tokens' and 'labels'. In-place: re-running this cell raises because the
# columns are already gone.
train_df.drop(columns=['id', 'raw_address', 'POI/street', 'POI', 'STR', 'strip_tokens', 'full_tokens', 'pos_poi', 'pos_str'], inplace=True)

In [23]:
# Add the swapped / mixed examples. pd.concat replaces DataFrame.append, which was
# deprecated in pandas 1.4 and removed in pandas 2.0.
train_df = pd.concat([train_df, swap_train_df], ignore_index=True)

In [24]:
train_df.head()

Unnamed: 0,tokens,labels
0,"[jl, kapuk, timur, delta, sili, iii, lippo, ci...","[B-STR, I-STR, I-STR, I-STR, I-STR, I-STR, I-S..."
1,"[aye,, jati, sampurna]","[O, O, O]"
2,"[setu, siung, 119, rt, 5, 1, 13880, cipayung]","[O, S-STR, O, O, O, O, O, O]"
3,"[toko, dita,, kertosono]","[B-POI, E-POI, O]"
4,"[jl., orde, baru]","[B-STR, I-STR, E-STR]"


In [25]:
len(train_df)

601605

In [26]:
# Persist the augmented training set. The list-valued columns are written as text; the
# commented-out read_csv near the top of the notebook parses them back with ast.literal_eval.
train_df.to_csv('train-ner-2.csv', index=False)

In [27]:
import json

# Save the raw abbreviation counts to disk.
with open('wordlist_raw.json', 'w') as fp:
    json.dump(wordlist_raw, fp)

In [28]:
import json

# Read the raw abbreviation counts back from disk, replacing the in-memory dict.
with open('wordlist_raw.json', 'r') as fp:
    wordlist_raw = json.load(fp)

In [29]:
len(wordlist_raw)

11825

In [30]:
def get_list(raw, p, lmt):
    """Choose one expansion per abbreviated token from raw co-occurrence counts.

    Args:
        raw: mapping of short token -> {full word: count}.
        p: minimum share of the token's total count that the most frequent
            full word must hold.
        lmt: minimum total count a token needs to be included.

    Returns:
        dict mapping each qualifying short token to its most frequent full word
        (the first one encountered wins a tie).
    """
    res = {}
    for word, stats in raw.items():
        total = sum(stats.values())
        # Nothing recorded for this token: max() would raise on an empty dict and the
        # ratio below would divide by zero, so skip it.
        if not stats or total == 0:
            continue
        best = max(stats, key=stats.get)
        frac = stats[best] / total
        if total >= lmt and frac >= p:
            res[word] = best
    return res

In [31]:
# p=0 and lmt=1 apply no filtering: each token maps to its most frequent expansion.
wordlist = get_list(wordlist_raw, 0, 1)

In [97]:
len(wordlist)

11825

In [33]:
import json

# Save the final token -> expansion mapping to disk.
with open('wordlist.json', 'w') as fp:
    json.dump(wordlist, fp)

In [34]:
import json

# Read the token -> expansion mapping back from disk, replacing the in-memory dict.
with open('wordlist.json', 'r') as fp:
    wordlist = json.load(fp)

In [35]:
len(wordlist)

11825

# Training

In [36]:
from transformers import *
from fastai.text.all import *

from blurr.data.all import *
from blurr.modeling.all import *

# Fix the random seed for training. set_seed comes from one of the star imports above
# (presumably fastai -- TODO confirm which RNGs it seeds).
SEED = 42
set_seed(SEED)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [37]:
# Label vocabulary: every distinct tag found in the training labels, in sorted order.
labels = sorted({tag for row_tags in train_df.labels.tolist() for tag in row_tags})
print(labels)

['B-POI', 'B-POI-SHORT', 'B-STR', 'B-STR-SHORT', 'E-POI', 'E-POI-SHORT', 'E-STR', 'E-STR-SHORT', 'I-POI', 'I-POI-SHORT', 'I-STR', 'I-STR-SHORT', 'O', 'S-POI', 'S-POI-SHORT', 'S-STR', 'S-STR-SHORT']


In [47]:
# Load the pretrained checkpoint configured for token classification with one output
# per entry of `labels`.
task = HF_TASKS_AUTO.TokenClassification
# pretrained_model_name = 'bert-base-multilingual-uncased'
pretrained_model_name = 'indobenchmark/indobert-large-p1'
config = AutoConfig.from_pretrained(pretrained_model_name)
config.num_labels = len(labels)

# Returns the architecture name plus the config, tokenizer and model objects.
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               task=task, 
                                                                               config=config)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)

('bert',
 transformers.models.bert.configuration_bert.BertConfig,
 transformers.models.bert.tokenization_bert_fast.BertTokenizerFast,
 transformers.models.bert.modeling_bert.BertForTokenClassification)

In [48]:
# Batch-time transform for token classification; is_split_into_words=True because each
# input is already a list of word tokens rather than one string.
before_batch_tfm = HF_TokenClassBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                     is_split_into_words=True, 
                                                     tok_kwargs={ 'return_special_tokens_mask': True })

# Input block (tokenised text) and target block (one category from `labels` per token).
blocks = (
    HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_TokenClassInput), 
    HF_TokenCategoryBlock(vocab=labels)
)

def get_y(inp):
    """Pair each word's label with the number of pieces `hf_tokenizer` splits that word into."""
    targets = []
    for entity, label in zip(inp.tokens, inp.labels):
        piece_count = len(hf_tokenizer.tokenize(str(entity)))
        targets.append((label, piece_count))
    return targets

In [49]:
# Inputs come from the 'tokens' column, targets from get_y; rows are split randomly
# into train/validation with a fixed seed.
db = DataBlock(blocks=blocks, 
               splitter=RandomSplitter(seed=42),
               get_x=ColReader('tokens'),
               get_y=get_y)

In [50]:
# Build the dataloaders (batch size 128) and preview one batch of token/label pairs.
dls = db.dataloaders(train_df, bs=128)
dls.show_batch(dataloaders=dls)

Unnamed: 0,token / target label
0,"[('dusun', 'B-POI'), ('7', 'I-POI'), ('1', 'I-POI'), ('=', 'I-POI'), ('c', 'I-POI'), ('dalam', 'I-POI'), ('patokan', 'I-POI'), ('masuk', 'I-POI'), ('dari', 'I-POI'), ('pasar', 'I-POI'), ('1', 'I-POI'), ('paya', 'E-POI'), ('bakung', 'O'), ('masuk,', 'O'), ('sumber', 'O'), ('melati', 'B-STR'), ('diski', 'I-STR'), ('jl', 'I-STR'), ('imp', 'I-STR'), ('sebe', 'I-STR'), ('mush', 'I-STR'), ('bau', 'I-STR'), ('muk', 'I-STR'), ('di', 'I-STR'), ('ping', 'I-STR'), ('sun', 'I-STR'), ('kecil', 'I-STR'), ('buka', 'I-STR'), ('08.', 'E-STR'), ('00', 'O')]"
1,"[('swadaya', 'O'), ('dalam', 'O'), ('no', 'O'), (':', 'O'), ('31', 'O'), ('rt', 'O'), (':', 'O'), ('05', 'O'), ('rw', 'O'), (':', 'O'), ('06', 'O'), ('cawang', 'O'), ('kapling', 'O'), ('-', 'O'), ('tanah', 'O'), ('manisan', 'O'), ('-', 'O'), ('jakarta', 'O'), ('-', 'O'), ('timur', 'O')]"
2,"[('baba', 'B-STR'), ('sayu', 'I-STR'), ('gg', 'I-STR'), ('sete', 'I-STR'), ('pasar', 'I-STR'), ('sayung', 'I-STR'), ('dari', 'I-STR'), ('sema', 'I-STR'), ('kanan,', 'E-STR'), ('sayung', 'O'), ('toko', 'B-POI'), ('arsya', 'I-POI'), ('pak', 'I-POI'), ('suwarno', 'I-POI'), ('sayung', 'I-POI'), ('wetan', 'E-POI'), ('rt', 'O'), ('5', 'O'), ('1', 'O'), ('sayung', 'O')]"
3,"[('sate', 'B-POI'), ('ayam', 'I-POI'), ('dan', 'I-POI'), ('kambing', 'I-POI'), ('pak', 'I-POI'), ('de', 'I-POI'), ('cabang', 'I-POI'), ('gad', 'I-POI-SHORT'), ('serpong,', 'E-POI'), ('no', 'O'), ('a', 'O'), ('15', 'O'), ('boulevard', 'B-STR'), ('raya', 'I-STR'), ('gad', 'I-STR'), ('serp,', 'E-STR'), ('kelapa', 'O'), ('dua', 'O'), ('kelapa', 'O'), ('dua', 'O')]"
4,"[('pang', 'B-STR'), ('anta,', 'E-STR'), ('east', 'B-POI'), ('kalimantan', 'I-POI'), ('center', 'I-POI'), ('-', 'I-POI'), ('pusat', 'I-POI'), ('oleh', 'I-POI'), ('-', 'E-POI'), ('oleh', 'O'), ('-', 'O'), ('kaltim,', 'O'), ('lorong', 'O'), ('a', 'O'), ('no.', 'O'), ('63', 'O'), ('rt006', 'O'), ('rw09,', 'O'), ('koja,', 'O')]"
5,"[('jl.', 'B-STR'), ('mayor', 'I-STR'), ('salim', 'I-STR'), ('batubara', 'E-STR'), ('d.', 'B-POI'), ('i,', 'E-POI'), ('no.', 'O'), ('6452,', 'O'), ('20', 'O'), ('ilir', 'O'), ('d', 'O'), ('ii,', 'O'), ('kec.', 'O'), ('ilir', 'O'), ('tim.', 'O'), ('i,', 'O'), ('kota', 'O'), ('palembang,', 'O'), ('su', 'O')]"
6,"[('ycab', 'O'), ('foundation,', 'O'), ('kedai', 'B-POI'), ('pak', 'I-POI'), ('waji', 'I-POI'), ('dari', 'I-POI'), ('jam', 'I-POI'), ('7', 'I-POI'), ('-', 'E-POI'), ('10', 'B-STR'), ('malam', 'I-STR'), ('jl.', 'I-STR'), ('r', 'I-STR'), ('panj', 'I-STR'), ('oku', 'I-STR'), ('dekat', 'E-STR'), ('sd', 'O'), ('016', 'O'), ('no', 'O')]"
7,"[('rawa', 'O'), ('buaya', 'O'), ('ruko', 'O'), ('inter', 'O'), ('kota', 'O'), ('blok', 'O'), (':', 'O'), ('no', 'O'), ('b', 'O'), ('8', 'O'), ('rt', 'O'), ('rw', 'O'), ('07', 'O'), ('09', 'O'), ('sebrang', 'B-POI'), ('gedung', 'E-POI'), ('ot', 'O'), ('samping', 'O'), ('charllie', 'O')]"
8,"[('rental', 'B-POI'), ('mobil', 'I-POI'), ('sutikno', 'I-POI'), ('pekarungan', 'E-POI'), ('pekarungan', 'O'), ('nan', 'B-STR'), ('iii,', 'I-STR'), ('gg', 'I-STR'), ('tam', 'I-STR'), ('sari', 'I-STR'), ('suko', 'I-STR'), ('sido', 'I-STR'), ('jawa', 'I-STR'), ('timur', 'I-STR'), ('indo', 'E-STR'), ('rt', 'O'), ('11', 'O'), ('4', 'O'), ('sukodono', 'O')]"


In [51]:
@delegates()
class TokenCrossEntropyLossFlat(BaseLoss):
    "Same as `CrossEntropyLossFlat`, but for multiple tokens output"
    y_int = True
    # Wraps nn.CrossEntropyLoss; targets equal to ignore_index (-100 by default) are ignored by the loss.
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
    # Per item in x: index of the highest score along `axis`.
    def decodes(self, x):    return L([ i.argmax(dim=self.axis) for i in x ])
    # Per item in x: softmax over `axis`.
    def activation(self, x): return L([ F.softmax(i, dim=self.axis) for i in x ])

In [52]:
# Pieces shared by the Learner construction below.
model = HF_BaseModelWrapper(hf_model)
loss_func = TokenCrossEntropyLossFlat()
opt_func = partial(Adam)
# Callback attached to the Learner itself.
learn_cbs = [HF_BaseModelCallback]
# Callback passed only to fit calls; provides the token-classification metrics.
fit_cbs = [HF_TokenClassMetricsCallback()]
splitter = hf_splitter

In [53]:
# NOTE(review): only learn2 is created here; learn1 exists solely in the commented-out
# line, yet later cells call learn1.*. Those cells fail on a fresh kernel unless this
# setup is re-run with the other checkpoint -- TODO make both learners reproducible.
# learn1 = Learner(dls, model, loss_func=loss_func, opt_func=opt_func, splitter=splitter, cbs=learn_cbs).to_fp16()
learn2 = Learner(dls, model, loss_func=loss_func, opt_func=opt_func, splitter=splitter, cbs=learn_cbs).to_fp16()

In [54]:
# Create the optimizer explicitly before unfreezing.
# learn1.create_opt()
learn2.create_opt()

In [55]:
# Make all parameter groups trainable.
# learn1.unfreeze()
learn2.unfreeze()

In [44]:
# NOTE(review): learn1 is not defined by any uncommented line above, and this cell's
# execution count is lower than the setup cells before it -- it appears to come from an
# earlier run with a different checkpoint. Verify before relying on Run All.
learn1.fit_one_cycle(5, 1e-4, moms=(0.8, 0.7, 0.8), cbs=fit_cbs)

epoch,train_loss,valid_loss,accuracy,precision,recall,f1,time
0,0.34703,0.363953,0.894492,0.779272,0.714002,0.745211,12:20
1,0.25975,0.255239,0.929252,0.844856,0.8348,0.839798,12:29
2,0.213813,0.225695,0.938303,0.870825,0.855771,0.863233,12:20
3,0.167972,0.216481,0.943253,0.88442,0.868227,0.876249,12:19
4,0.1484,0.223286,0.944264,0.886923,0.87176,0.879276,12:21


  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Train learn2 for 5 epochs with the one-cycle schedule (max lr 1e-4), reporting token-level metrics.
learn2.fit_one_cycle(5, 1e-4, moms=(0.8, 0.7, 0.8), cbs=fit_cbs)

epoch,train_loss,valid_loss,accuracy,precision,recall,f1,time
0,0.310868,0.309791,0.916442,0.814347,0.797192,0.805678,22:56
1,0.235226,0.239664,0.935117,0.866598,0.846518,0.85644,22:58
2,0.20466,0.211123,0.943851,0.884705,0.86918,0.876874,22:49
3,0.157571,0.204892,0.94787,0.899703,0.880835,0.890169,22:33


In [None]:
# Loss curves of both learners (learn1 must exist from an earlier run -- see the review note above its fit cell).
learn1.recorder.plot_loss()
learn2.recorder.plot_loss()

In [None]:
# Per-label report; the attribute is presumably set by HF_TokenClassMetricsCallback during fit -- TODO confirm.
print(learn1.token_classification_report)
print(learn2.token_classification_report)

In [None]:
# Save both learners under the names that the evaluation section loads.
learn1.save('bert-multi-2')
learn2.save('bert-indo-2')

# Evaluation

In [56]:
# Restore the saved weights. Both Learner objects must already exist; learn1 is not
# constructed anywhere in the uncommented code of this notebook.
learn1.load('bert-multi-2')
learn2.load('bert-indo-2')

<fastai.learner.Learner at 0x7f4dec2aecd0>

In [57]:
@patch
def blurr_predict(self:Learner, items, rm_type_tfms=None):
    """Predict on `items` and return one `(decoded_labels, decoded_preds, probs)` tuple per item.

    Attached to `Learner` through `@patch`.

    Args:
        items: a DataFrame, a list of inputs, or a single input (wrapped into a list).
        rm_type_tfms: forwarded unchanged to `self.dls.test_dl`.

    Returns:
        list of 3-tuples; the second and third elements are one-element lists holding
        the decoded prediction and the probabilities for that item.
    """
    hf_before_batch_tfm = get_blurr_tfm(self.dls.before_batch)
    # True when the transform expects pre-split words and `items` is itself one such word list.
    is_split_str = hf_before_batch_tfm.is_split_into_words and isinstance(items[0], str)
    is_df = isinstance(items, pd.DataFrame)
    # Wrap a single example so the code below always iterates over a collection of items.
    if (not is_df and (is_split_str or not is_listy(items))): items = [items]
    dl = self.dls.test_dl(items, rm_type_tfms=rm_type_tfms, num_workers=0)
    with self.no_bar(): probs, _, decoded_preds = self.get_preds(dl=dl, with_input=False, with_decoded=True)
    # Transforms that belong to the target(s), used to map predicted ids back to labels.
    trg_tfms = self.dls.tfms[self.dls.n_inp:]
    outs = []
    probs, decoded_preds = L(probs), L(decoded_preds)
    for i in range(len(items)):
        item_probs = [probs[i]]
        item_dec_preds = [decoded_preds[i]]
        item_dec_labels = tuplify([tfm.decode(item_dec_preds[tfm_idx]) for tfm_idx, tfm in enumerate(trg_tfms)])
        outs.append((item_dec_labels, item_dec_preds, item_probs))
    return outs

In [71]:
from string import punctuation

def reconstruct(num, pred, raw_tokens, raw_address):
    """Turn per-token label predictions into 'POI/street' answer strings.

    Note: a later cell defines another `reconstruct` with a different signature,
    which replaces this one once that cell has run.

    Args:
        num: number of rows to process.
        pred: per row, a sequence whose element 0 is used only for its length and
            whose element 1 holds one label string per token.
        raw_tokens: per row, the whitespace tokens of the cleaned address.
        raw_address: per row, the original address string.

    Returns:
        list of `num` strings of the form '<poi>/<street>'; a part is empty when no
        token carries the corresponding label.

    Raises:
        AssertionError: when the tokens cannot be lined up with the raw address text.
    """
    def complete_word(x):
        # Expand an abbreviated token via the global `wordlist`, keeping its edge punctuation.
        core = x.strip().strip(punctuation)
        return x.replace(core, wordlist[core]) if core in wordlist else x

    def normalize_bracket(x):
        # Add the missing partner when only one kind of bracket is present.
        opened, closed = '(' in x, ')' in x
        if opened and not closed:
            return x + ')'
        if closed and not opened:
            return '(' + x
        return x

    def extract(idx, res, first, final):
        # Cut tokens first..final out of the raw address and expand the '-SHORT' ones.
        if first == -1:
            return ''
        text = raw_address[idx]
        tokens = raw_tokens[idx]
        # Peel the leading tokens off the front and the trailing tokens off the back.
        for k in range(first):
            text = text[len(tokens[k]):].strip()
        for k in range(len(tokens) - 1, final, -1):
            text = text[:-len(tokens[k])].strip()

        expected = ''.join(tokens[first:final + 1]).replace(' ', '')
        assert text.replace(' ', '') == expected

        # Walk the span right to left so a replacement never shifts positions still to visit.
        cursor = len(text)
        for k in range(final, first - 1, -1):
            while cursor > 0 and text[cursor - 1] == ' ':
                cursor -= 1
            width = len(tokens[k])
            assert cursor >= width
            cursor -= width
            if 'SHORT' in res[1][k]:
                text = text[:cursor] + complete_word(tokens[k]) + text[cursor + width:]
        return text

    ans = ['/'] * num
    for idx in range(num):
        res = pred[idx]
        # First and last token index carrying each entity tag (-1 when absent).
        bounds = {'POI': [-1, -1], 'STR': [-1, -1]}
        for i in range(len(res[0])):
            for tag, span in bounds.items():
                if tag in res[1][i]:
                    if span[0] == -1:
                        span[0] = i
                    span[1] = i

        poi_text = extract(idx, res, bounds['POI'][0], bounds['POI'][1])
        str_text = extract(idx, res, bounds['STR'][0], bounds['STR'][1])

        poi_text = normalize_bracket(poi_text.strip(punctuation))
        str_text = normalize_bracket(str_text.strip(punctuation))

        ans[idx] = poi_text + '/' + str_text

    return ans

In [72]:
def show_diff(df):
    """Print up to 50 rows whose 'pred' differs from 'POI/street'.

    Each printed line holds: positional row index, id, expected value, 'vs', predicted value.
    """
    MAX_ROWS = 50
    shown = 0
    columns = zip(df['id'], df['POI/street'], df['pred'])
    for pos, (row_id, expected, predicted) in enumerate(columns):
        if shown == MAX_ROWS:
            break
        if expected == predicted:
            continue
        shown += 1
        print(pos, row_id, expected, 'vs', predicted)

In [73]:
def calc_acc(df):
    """Return the fraction of rows in `df` whose 'pred' equals 'POI/street'.

    Args:
        df: frame with 'id', 'pred' and 'POI/street' columns.

    Returns:
        Matching rows (counted through non-null 'id' values) divided by len(df).
    """
    # Build the mask from the frame that was passed in, not from the global test_df:
    # the old form only worked when df happened to be test_df itself.
    return df.loc[df['pred'] == df['POI/street'], 'id'].count() / len(df)

In [61]:
# Plain Python lists of the test tokens and raw address strings, indexed by position.
raw_tokens = list(test_df['tokens'])
raw_address = list(test_df['raw_address'])

In [62]:
# Token-level predictions from both learners. blurr_predict_tokens is not defined in
# this notebook (presumably provided by blurr -- TODO confirm it uses the patched blurr_predict above).
raw_pred1 = learn1.blurr_predict_tokens(raw_tokens)
raw_pred2 = learn2.blurr_predict_tokens(raw_tokens)

In [63]:
# Ensemble: take the last element of every prediction from each learner (assumed to be
# the class probabilities -- TODO confirm), add the two element-wise per example, and
# look up the highest-scoring class along dim 1 in dls.vocab.
prob1 = [i[-1] for i in raw_pred1]
prob2 = [i[-1] for i in raw_pred2]
prob = [sum(x) for x in zip(prob1, prob2)]
raw_preds = [dls.vocab[x.argmax(dim=1)] for x in prob]

In [64]:
# First element of each learn1 prediction (presumably the token texts -- TODO confirm).
raw_txts = [i[0] for i in raw_pred1]

In [74]:
# Per-model answer strings. These calls use the four-argument reconstruct(num, pred,
# raw_tokens, raw_address); the next cell redefines reconstruct with an extra argument,
# so this cell must run before that redefinition.
pred1 = reconstruct(len(test_df), raw_pred1, raw_tokens, raw_address)
pred2 = reconstruct(len(test_df), raw_pred2, raw_tokens, raw_address)

In [75]:
from string import punctuation

def reconstruct(num, txt, pred, raw_tokens, raw_address):
    """Rebuild 'POI/street' answer strings from per-token labels.

    For each of the first `num` rows, the first and last token positions whose
    label contains 'POI' (and, independently, 'STR') delimit a span. The span
    is cut out of the raw address, tokens whose label contains 'SHORT' are
    expanded through `wordlist` (a mapping defined outside this function),
    surrounding punctuation is stripped and a lone bracket is balanced.

    Args:
        num: number of rows to process.
        txt: per-row sequences; only their length is used, to bound the scan
            over label positions.
        pred: per-row sequences of label strings, indexed by token position.
        raw_tokens: per-row token lists. The slicing assumes that the tokens,
            joined without spaces, reproduce the address without spaces.
        raw_address: per-row original address strings.

    Returns:
        A list of `num` strings '<poi>/<street>'; a side is empty when no
        position carries the corresponding label.

    Raises:
        AssertionError: if the text cut from the address does not match the
            concatenated tokens of the span once spaces are ignored.
    """
    def complete_word(x):
        # Expand the token's core (whitespace and punctuation stripped) through
        # `wordlist`, leaving any attached punctuation where it was.
        y = x.strip().strip(punctuation)
        # Skip an empty core: str.replace('', s) would insert s between every
        # character of x.
        if y != '' and y in wordlist:
            x = x.replace(y, wordlist[y])
        return x

    def normalize_bracket(x):
        # Balance a single unmatched bracket left over after trimming.
        if '(' in x and ')' not in x:
            x = x + ')'
        elif ')' in x and '(' not in x:
            x = '(' + x
        return x

    def extract_span(idx, labels, first, final):
        # Cut tokens first..final (inclusive) of row idx out of the raw address
        # and expand the ones labelled SHORT.
        span = raw_address[idx]
        tokens = raw_tokens[idx]

        # Peel off the tokens before and after the span by length; strip()
        # removes whatever whitespace separated them in the address.
        for i in range(first):
            span = span[len(tokens[i]):].strip()
        for i in range(len(tokens) - 1, final, -1):
            span = span[:-len(tokens[i])].strip()

        # Sanity check: what is left must be exactly the span's tokens.
        expected = ''.join(tokens[first:final + 1]).replace(' ', '')
        assert span.replace(' ', '') == expected

        # Walk the span right to left so replacements never shift the part
        # of the string that is still to be visited.
        last = len(span)
        for i in range(final, first - 1, -1):
            while last > 0 and span[last - 1] == ' ':
                last -= 1
            width = len(tokens[i])
            assert last >= width
            last -= width
            if 'SHORT' in labels[i]:
                span = span[:last] + complete_word(tokens[i]) + span[last + width:]
        return span

    ans = ['/'] * num
    for idx in range(num):
        res = pred[idx]
        start_poi, end_poi = -1, -1
        start_str, end_str = -1, -1
        for i in range(len(txt[idx])):
            if 'POI' in res[i]:
                if start_poi == -1: start_poi = i
                end_poi = i
            if 'STR' in res[i]:
                if start_str == -1: start_str = i
                end_str = i

        txt1 = extract_span(idx, res, start_poi, end_poi) if start_poi != -1 else ''
        txt2 = extract_span(idx, res, start_str, end_str) if start_str != -1 else ''

        txt1 = normalize_bracket(txt1.strip(punctuation))
        txt2 = normalize_bracket(txt2.strip(punctuation))

        ans[idx] = (txt1 + '/' + txt2)

    return ans

In [76]:
# Decode the combined labels (raw_preds) with the five-argument reconstruct;
# raw_txts is passed only to give the per-row scan length.
pred = reconstruct(len(test_df), raw_txts, raw_preds, raw_tokens, raw_address)

In [81]:
# Attach one prediction list for scoring. The trailing comment lists the other
# candidates that can be swapped in (pred1, pred2).
test_df['pred'] = pred # pred1 pred2 
test_df.head()

Unnamed: 0,id,raw_address,POI/street,tokens,pred
90142,90142,lom 88 asrikaton,/,"[lom, 88, asrikaton]",/lom
163531,163531,"varia usaha ungaran, peri kem pudakpayung",/,"[varia, usaha, ungaran,, peri, kem, pudakpayung]",varia usaha/peri kem
233950,233950,hutan gar no 7 20371 percut sei tuan,/gar,"[hutan, gar, no, 7, 20371, percut, sei, tuan]",/gar
126157,126157,"wardah gor srik ton,",wardah gorden/srik ton,"[wardah, gor, srik, ton,]",wardah goreng/srik ton
96808,96808,green puri 7 cengkareng,/green puri 7,"[green, puri, 7, cengkareng]",/green puri 7


In [78]:
# Exact-match rate of test_df['pred'] against test_df['POI/street'].
# NOTE(review): this call appears in three consecutive cells with different
# outputs; presumably each was run after assigning a different prediction list
# to test_df['pred'] -- confirm which output belongs to which list.
calc_acc(test_df)

0.7422666666666666

In [80]:
# Exact-match rate again; the result depends on which prediction list was last
# assigned to test_df['pred'] (not recorded in this cell).
calc_acc(test_df)

0.7648888888888888

In [82]:
# Exact-match rate once more; as above, the scored list is whatever
# test_df['pred'] held when the cell was executed.
calc_acc(test_df)

0.76

In [None]:
# Print the first mismatching rows (expected vs predicted) for error analysis.
show_diff(test_df)

# Final inference

In [83]:
# Load the unlabeled test set and tokenize it: clean() each raw address, then
# split on whitespace.
real_test_df = pd.read_csv('test.csv')
real_test_df['tokens'] = real_test_df['raw_address'].apply(clean).str.split()
real_test_df.head()

Unnamed: 0,id,raw_address,tokens
0,0,s. par 53 sidanegara 4 cilacap tengah,"[s., par, 53, sidanegara, 4, cilacap, tengah]"
1,1,"angg per, baloi indah kel. lubuk baja","[angg, per,, baloi, indah, kel., lubuk, baja]"
2,2,"asma laun, mand imog,","[asma, laun,, mand, imog,]"
3,3,"ud agung rej, raya nga sri wedari karanganyar","[ud, agung, rej,, raya, nga, sri, wedari, karanganyar]"
4,4,"cut mutia, 35 baiturrahman","[cut, mutia,, 35, baiturrahman]"


In [84]:
# Rebind raw_tokens / raw_address to the real test set; from this cell on they
# no longer refer to the held-out split.
raw_tokens = list(real_test_df['tokens'])
raw_address = list(real_test_df['raw_address'])

In [85]:
# Predict with both learners on the real test tokens.
raw_pred1 = learn1.blurr_predict_tokens(raw_tokens)
raw_pred2 = learn2.blurr_predict_tokens(raw_tokens)

# Combine the models: sum the last element of each prediction (presumably the
# per-token class scores -- TODO confirm) and map the argmax over dim 1 to
# label names through dls.vocab.
prob1 = [i[-1] for i in raw_pred1]
prob2 = [i[-1] for i in raw_pred2]
prob = [sum(x) for x in zip(prob1, prob2)]
raw_preds = [dls.vocab[x.argmax(dim=1)] for x in prob]

# First element of each model-1 prediction; only its per-row length is used by
# the five-argument reconstruct().
raw_txts = [i[0] for i in raw_pred1]

In [88]:
from string import punctuation

def reconstruct(num, pred, raw_tokens, raw_address):
    """Turn raw per-row predictions into 'POI/street' answer strings.

    Each entry of `pred` is indexed as `entry[0]` (a sequence whose length
    bounds the scan) and `entry[1]` (the label string for every position).
    The first and last positions whose label contains 'POI' -- and likewise
    'STR' -- mark a span; the span is sliced out of the raw address, tokens
    labelled 'SHORT' are expanded via `wordlist` (defined outside this
    function), outer punctuation is stripped and a lone bracket is closed.

    Args:
        num: how many rows to decode.
        pred: per-row prediction entries as described above.
        raw_tokens: per-row token lists.
        raw_address: per-row original address strings.

    Returns:
        A list with `num` strings of the form '<poi>/<street>'.

    Raises:
        AssertionError: when the sliced text and the span's tokens disagree
            after removing spaces.
    """
    def expand_abbreviation(token):
        # Look up the punctuation-free core; an empty core is never replaced.
        core = token.strip().strip(punctuation)
        if core == '' or core not in wordlist:
            return token
        return token.replace(core, wordlist[core])

    def balance_bracket(text):
        # Add the missing partner when exactly one kind of bracket is present.
        has_open, has_close = '(' in text, ')' in text
        if has_open and not has_close:
            return text + ')'
        if has_close and not has_open:
            return '(' + text
        return text

    def cut_span(row, tags, first, final):
        # Slice tokens first..final (inclusive) of the given row out of its
        # raw address, expanding SHORT-tagged tokens on the way.
        piece = raw_address[row]
        words = raw_tokens[row]

        # Drop leading, then trailing, tokens by their length.
        for pos in range(first):
            piece = piece[len(words[pos]):].strip()
        for pos in range(len(words) - 1, final, -1):
            piece = piece[:-len(words[pos])].strip()

        # The remainder has to equal the span's tokens, spaces aside.
        joined = ''.join(words[first:final + 1]).replace(' ', '')
        assert piece.replace(' ', '') == joined

        # Right-to-left pass: `cursor` always sits at the start of the token
        # being examined.
        cursor = len(piece)
        for pos in range(final, first - 1, -1):
            while cursor > 0 and piece[cursor - 1] == ' ':
                cursor -= 1
            size = len(words[pos])
            assert cursor >= size
            cursor -= size
            if 'SHORT' in tags[pos]:
                piece = piece[:cursor] + expand_abbreviation(words[pos]) + piece[cursor + size:]
        return piece

    results = []
    for row in range(num):
        res = pred[row]
        poi_first = poi_last = -1
        str_first = str_last = -1
        for pos in range(len(res[0])):
            tag = res[1][pos]
            if 'POI' in tag:
                if poi_first == -1:
                    poi_first = pos
                poi_last = pos
            if 'STR' in tag:
                if str_first == -1:
                    str_first = pos
                str_last = pos

        poi_text = cut_span(row, res[1], poi_first, poi_last) if poi_first != -1 else ''
        street_text = cut_span(row, res[1], str_first, str_last) if str_first != -1 else ''

        poi_text = balance_bracket(poi_text.strip(punctuation))
        street_text = balance_bracket(street_text.strip(punctuation))

        results.append(poi_text + '/' + street_text)

    return results

In [89]:
# Decode each model's own predictions for the real test set.
# NOTE(review): four-argument calls -- they rely on the four-argument
# reconstruct being the active definition when this cell runs.
pred1 = reconstruct(len(real_test_df), raw_pred1, raw_tokens, raw_address)
pred2 = reconstruct(len(real_test_df), raw_pred2, raw_tokens, raw_address)

In [90]:
from string import punctuation

def reconstruct(num, txt, pred, raw_tokens, raw_address):
    """Rebuild 'POI/street' answer strings from per-token labels.

    For each of the first `num` rows, the first and last token positions whose
    label contains 'POI' (and, independently, 'STR') delimit a span. The span
    is cut out of the raw address, tokens whose label contains 'SHORT' are
    expanded through `wordlist` (a mapping defined outside this function),
    surrounding punctuation is stripped and a lone bracket is balanced.

    Args:
        num: number of rows to process.
        txt: per-row sequences; only their length is used, to bound the scan
            over label positions.
        pred: per-row sequences of label strings, indexed by token position.
        raw_tokens: per-row token lists. The slicing assumes that the tokens,
            joined without spaces, reproduce the address without spaces.
        raw_address: per-row original address strings.

    Returns:
        A list of `num` strings '<poi>/<street>'; a side is empty when no
        position carries the corresponding label.

    Raises:
        AssertionError: if the text cut from the address does not match the
            concatenated tokens of the span once spaces are ignored.
    """
    def complete_word(x):
        # Expand the token's core (whitespace and punctuation stripped) through
        # `wordlist`, leaving any attached punctuation where it was.
        y = x.strip().strip(punctuation)
        # Skip an empty core: str.replace('', s) would insert s between every
        # character of x.
        if y != '' and y in wordlist:
            x = x.replace(y, wordlist[y])
        return x

    def normalize_bracket(x):
        # Balance a single unmatched bracket left over after trimming.
        if '(' in x and ')' not in x:
            x = x + ')'
        elif ')' in x and '(' not in x:
            x = '(' + x
        return x

    def extract_span(idx, labels, first, final):
        # Cut tokens first..final (inclusive) of row idx out of the raw address
        # and expand the ones labelled SHORT.
        span = raw_address[idx]
        tokens = raw_tokens[idx]

        # Peel off the tokens before and after the span by length; strip()
        # removes whatever whitespace separated them in the address.
        for i in range(first):
            span = span[len(tokens[i]):].strip()
        for i in range(len(tokens) - 1, final, -1):
            span = span[:-len(tokens[i])].strip()

        # Sanity check: what is left must be exactly the span's tokens.
        expected = ''.join(tokens[first:final + 1]).replace(' ', '')
        assert span.replace(' ', '') == expected

        # Walk the span right to left so replacements never shift the part
        # of the string that is still to be visited.
        last = len(span)
        for i in range(final, first - 1, -1):
            while last > 0 and span[last - 1] == ' ':
                last -= 1
            width = len(tokens[i])
            assert last >= width
            last -= width
            if 'SHORT' in labels[i]:
                span = span[:last] + complete_word(tokens[i]) + span[last + width:]
        return span

    ans = ['/'] * num
    for idx in range(num):
        res = pred[idx]
        start_poi, end_poi = -1, -1
        start_str, end_str = -1, -1
        for i in range(len(txt[idx])):
            if 'POI' in res[i]:
                if start_poi == -1: start_poi = i
                end_poi = i
            if 'STR' in res[i]:
                if start_str == -1: start_str = i
                end_str = i

        txt1 = extract_span(idx, res, start_poi, end_poi) if start_poi != -1 else ''
        txt2 = extract_span(idx, res, start_str, end_str) if start_str != -1 else ''

        txt1 = normalize_bracket(txt1.strip(punctuation))
        txt2 = normalize_bracket(txt2.strip(punctuation))

        ans[idx] = (txt1 + '/' + txt2)

    return ans

In [93]:
# Decode the combined labels (raw_preds) for the real test set; raw_txts only
# supplies the per-row scan length.
pred = reconstruct(len(real_test_df), raw_txts, raw_preds, raw_tokens, raw_address)

In [102]:
# Use model 2's decoded predictions as the submission column; the trailing
# comment lists the alternatives that can be swapped in.
real_test_df['POI/street'] = pred2 # pred1 pred2

In [103]:
# Preview just the submission columns without mutating real_test_df. An
# in-place drop of 'raw_address' and 'tokens' is not re-runnable (the second
# run fails because the columns are already gone), and with it commented out
# the preview would not match the two-column output shown below.
real_test_df[['id', 'POI/street']].head()

Unnamed: 0,id,POI/street
0,0,/s. par
1,1,/angg per
2,2,asma laundry/mand imog
3,3,ud agung rejeki/raya nga
4,4,/cut mutia


In [104]:
# Select the submission columns explicitly: real_test_df still carries
# 'raw_address' and 'tokens' on a fresh run, and they must not leak into the
# output file.
real_test_df[['id', 'POI/street']].to_csv('bert-indo-final.csv', index=False)