In [1]:
import csv
from collections import Counter

import pandas as pd
import torch
from sklearn.metrics import accuracy_score
from torchtext.vocab import vocab
from tqdm.notebook import tqdm
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer




In [2]:
# Options shared by every read below. The files are header-less, tab-separated, and
# the text column contains bare double-quote tokens (e.g. `" We do n't support ... "`).
# Quote handling is therefore switched off: with the default setting a field that
# *starts* with `"` is parsed as a quoted field, its quote characters are dropped,
# and the tokens no longer line up one-to-one with the label sequence.
TSV_OPTIONS = dict(sep='\t', header=None, quoting=csv.QUOTE_NONE)

# Training data: one row per document, "<labels>\t<text>".
train_set = pd.read_csv('./train/train.tsv', names=['labels', 'text'], **TSV_OPTIONS)

# Dev data: labels and text are stored in two separate files and joined by row position.
val_set = pd.read_csv('./dev-0/expected.tsv', names=['labels'], **TSV_OPTIONS)
# Select the 'text' Series explicitly rather than assigning a whole DataFrame to a column.
val_set['text'] = pd.read_csv('./dev-0/in_dev.tsv', names=['text'], **TSV_OPTIONS)['text']

# Test data: text only, no labels.
test_set = pd.read_csv('./test-A/in.tsv', names=['text'], **TSV_OPTIONS)

In [13]:
# Pretrained token-classification checkpoint; the model and its tokenizer are both
# loaded from the same identifier.
CHECKPOINT = "dslim/bert-base-NER"

model = AutoModelForTokenClassification.from_pretrained(CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

# `ner` is the callable used by every prediction cell below.
ner = pipeline(task="ner", model=model, tokenizer=tokenizer)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Display the text of the first training row.
train_set["text"][0]

'EU rejects German call to boycott British lamb . </S> Peter Blackburn </S> BRUSSELS 1996-08-22 </S> The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep . </S> Germany \'s representative to the European Union \'s veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer . </S> " We do n\'t support any such recommendation because we do n\'t see any grounds for it , " the Commission \'s chief spokesman Nikolaus van der Pas told a news briefing . </S> He said further scientific study was required and if it was found that action was needed it should be taken by the European Union . </S> He said a proposal last month by EU Farm Commissioner Franz Fischler to ban sheep brains , spleens and spinal cords from the human and animal food chains was a highly specif

In [55]:
# Display the label string of the first training row (one tag per space-separated token).
train_set['labels'][0]


'B-ORG O B-MISC O O O B-MISC O O O B-PER I-PER O B-LOC O O O B-ORG I-ORG O O O O O O B-MISC O O O O O B-MISC O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG O O O B-PER I-PER O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG I-ORG O O O O O O O O O B-ORG O O B-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-PER O B-MISC O O O O B-LOC O B-LOC O O O O O O O B-MISC I-MISC I-MISC O B-MISC O O O O O O O O B-PER O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O B-PER I-PER I-PER O O O B-PER O O B-ORG O O O O O O O O O O O O O O O O O O B-LOC O B-LOC O B-PER O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O B-MISC O O O O O B-LOC O O O O O O O O O O O O

In [14]:
# Sanity check on the first training document: split it on single spaces and run
# the pipeline on the resulting list of words.
# NOTE(review): the list is passed as-is, so each word appears to be classified as
# its own input without sentence context -- the prediction shown below contains no
# I- tags while the gold labels do. Worth confirming against the pipeline docs and
# considering word-aligned tokenization (`is_split_into_words=True`).
example = train_set['text'][0].split(" ")
ner_results = ner(example)

# One tag per word: the 'entity' of the first entry reported for that word, or 'O'
# when the pipeline reported nothing for it.
result = [
    word_hits[0]['entity'] if len(word_hits) > 0 else 'O'
    for word_hits in ner_results
]

" ".join(result)

'B-ORG O B-MISC O O O B-MISC O O O B-PER B-ORG O B-MISC O O O B-MISC B-ORG O O O O O O B-MISC O O O O O B-MISC O O O O O O O O O O O O O O O B-LOC O O O O B-MISC B-ORG O O O B-PER B-PER O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O B-PER B-PER O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC B-ORG O O O O O O O O O B-ORG O O B-PER B-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-PER O B-MISC O O O O B-LOC O B-LOC O O O O O O O B-PER B-MISC B-MISC O B-ORG O O O O O O O O B-PER O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O B-ORG O B-PER O O O B-PER O O B-ORG O O O O O O O O O O O O O O O O O O B-LOC O B-LOC O B-PER O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O B-MISC O O O O O B-LOC O O O O O O O O O O O O O O O 

In [4]:
def filter_output(row):
    """Turn per-word pipeline output into a single space-separated tag string.

    Args:
        row: sequence with one element per word; each element is a (possibly
            empty) sequence of dicts carrying an 'entity' key.

    Returns:
        A string with one tag per word, joined by single spaces. A word with no
        entries is tagged 'O'; otherwise the 'entity' of its first entry is used.
    """
    tags = []
    for word_entities in row:
        if len(word_entities) == 0:
            # Nothing reported for this word -> outside any entity.
            tags.append('O')
            continue
        tags.append(word_entities[0]['entity'])
    return " ".join(tags)

In [15]:
# Tag every dev-0 document. Each entry of `results` is the space-joined tag string
# for one document, in the same order as `val_set`.
# NOTE(review): words are passed to the pipeline as a list, so each word appears to
# be tagged without its surrounding context (the first prediction below has no I-
# tags, unlike the gold labels) -- confirm this is intended.
results = [
    filter_output(ner(document.split(" ")))
    for document in tqdm(val_set['text'])
]

  0%|          | 0/215 [00:00<?, ?it/s]

In [16]:
# Display the predicted tag string for the first dev-0 document.
results[0]

'O O O O O B-ORG O O O O O O B-LOC O O B-LOC B-MISC O B-PER B-PER O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O B-ORG O B-ORG O O O O O O B-ORG O O O O O O O O O O B-ORG O O O O B-ORG O O O O O O O O B-PER B-LOC O B-ORG O O O O O O O O O O O O O O B-LOC O B-PER B-PER O O O O O O O O O O B-ORG O O O O O O O O O B-PER O O O O O O O O O O B-ORG O O O O O O O O O O O B-PER B-PER O B-PER O O O O O O O O O O B-ORG O B-LOC O O B-PER O O O O B-LOC O O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC O B-ORG O B-PER B-PER O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O O O O O B-LOC O B-PER O O O O O B-ORG O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O B-MISC B-PER B-ORG O O O O O B-PER B-PER O O O O B-PER B-PER O O O O B-ORG O O O O O O O O O O O O O O

In [17]:
# Display the gold tag string for the first dev-0 document, for comparison with the prediction.
val_set['labels'][0]

'O O B-ORG O O O O O O O O O B-LOC O O B-MISC I-MISC O B-PER I-PER O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O B-ORG O B-ORG O O O O O O B-ORG O O O O O O O O O O B-ORG O O O O B-ORG O O O O O O O O B-LOC I-LOC O B-ORG O O O O O O O O O O O O O O B-LOC O B-PER I-PER O O O O O O O O O O B-ORG O O O O O O O O O B-PER O O O O O O O O O O B-ORG O O O O O O O O O O O B-PER I-PER O B-PER I-PER O O O O O O O O O B-ORG O B-LOC O O B-PER O O O O B-LOC O O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O B-ORG O O O O O O O O O B-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC O B-ORG O B-PER I-PER O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O O O O O B-LOC O B-PER I-PER O O O O B-ORG O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O B-MISC B-PER I-PER O O O O O B-PER I-PER O O O O B-PER I-PER O O O O B-ORG O O O O O O O 

In [None]:
# Tag every test-A document the same way as the dev set. Each entry of
# `results_test` is the space-joined tag string for one document, in `test_set` order.
results_test = [
    filter_output(ner(document.split(" ")))
    for document in tqdm(test_set['text'])
]

In [18]:
def save_prediction(test_pred, file_name):
    """Write predictions to a text file, one prediction per line.

    Args:
        test_pred: iterable of predictions (list, generator, pandas Series, ...).
            Each item is formatted with an f-string and followed by a newline.
        file_name: path of the output file; an existing file is overwritten.
    """
    # Iterate over the items directly rather than indexing by position, so inputs
    # that are not sized or not 0-based indexable (generators, a Series with a
    # non-default index) are written correctly.
    # UTF-8 is pinned so the output does not depend on the platform's default encoding.
    with open(file_name, 'w', encoding='utf-8') as f:
        f.writelines(f'{pred}\n' for pred in test_pred)

In [20]:
# Write the dev-0 predictions first, then the test-A predictions.
# `results_test` must already exist, i.e. the test-set prediction cell has to be run before this one.
save_prediction(results, 'dev-0/out.tsv')
save_prediction(results_test, 'test-A/out.tsv')