<a href="https://colab.research.google.com/github/nishikaz/PlayGround/blob/master/LUKEonColab_EntityClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Running Entity Classification with LUKE via HuggingFace Transformers

This is mostly a line-by-line copy of [this notebook](https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_open_entity.ipynb#scrollTo=tLzX8LIS127b), written for my own understanding.

In [1]:
# Install transformers. At the time of writing, LUKE appears to be available only on the master branch.
!pip install git+https://github.com/huggingface/transformers.git

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-hl7m1yfv
  Running command git clone -q https://github.com/huggingface/transformers.git /tmp/pip-req-build-hl7m1yfv
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
     |████████████████████████████████| 901kB 4.3MB/s 
Collecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72acbcbeac6b25e86/huggingface_hub-0.0.8-py3-none-any.whl
Collecting tokenizers<0.11,>=0.10.1
  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d

In [2]:
import json
import torch
from tqdm import trange
from transformers import LukeTokenizer, LukeForEntityClassification

## Preparing the dataset

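We use OpenEntity, which is included in the ERNIE dataset presented at ACL 2019.
The data is stored in JSON format: each record holds the sentence (`sent`), the character offsets of the target entity (`start`, `end`), and its entity type labels (`labels`).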

In [3]:
# gdown downloads a file from Google Drive directly into the storage mounted by Colab
!gdown --id 1HlWw7Q6-dFSm9jNSCh4VaBf1PlGqt9im
!tar xzf /content/data.tar.gz

# Place test.json in the workspace
!cp data/OpenEntity/test.json .

Downloading...
From: https://drive.google.com/uc?id=1HlWw7Q6-dFSm9jNSCh4VaBf1PlGqt9im
To: /content/data.tar.gz
322MB [00:02, 152MB/s]


In [4]:
def load_examples(dataset_file):
    with open(dataset_file, 'r') as f:
        data = json.load(f)
    
    examples = []
    for item in data:
        examples.append(dict(
            text         = item['sent'],                    # sentence text
            entity_spans = [(item['start'], item['end'])],  # character span of the entity
            label        = item['labels']                   # entity type labels
        ))
    
    return examples

In [10]:
test_examples = load_examples('test.json')
test_examples[:3]

[{'entity_spans': [(3, 20)],
  'label': ['time'],
  'text': 'On late Monday night , 30th Nov 2009 , Bangladesh Police arrested Rajkhowa somewhere near Dhaka .'},
 {'entity_spans': [(111, 123)],
  'label': ['event'],
  'text': 'Leo W. Gerard , president of the steelworkers union , said he and several leaders of the AFL-CIO had organized joint events this week with the Sierra Club and the Alliance for Climate Protection .'},
 {'entity_spans': [(76, 78)],
  'label': ['person'],
  'text': 'Peace agreements will only bring further losses and push back our cause , " he added , pointing out that Abbas \'s Fatah party also maintains its own armed wing , the loosely affiliated Al - Aqsa Martyrs Brigades .'}]
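
The `entity_spans` values are character offsets into `text` (start inclusive, end exclusive), so slicing the text with a span should recover the mention. A minimal check, using the first two examples printed above copied in as literals; the helper name `mention_texts` is made up for this sketch:

```python
# Two examples copied from the output above (text and character span only).
examples = [
    {
        "text": "On late Monday night , 30th Nov 2009 , Bangladesh Police "
                "arrested Rajkhowa somewhere near Dhaka .",
        "entity_spans": [(3, 20)],
    },
    {
        "text": "Leo W. Gerard , president of the steelworkers union , said he "
                "and several leaders of the AFL-CIO had organized joint events "
                "this week with the Sierra Club and the Alliance for Climate "
                "Protection .",
        "entity_spans": [(111, 123)],
    },
]


def mention_texts(example):
    """Return the substring covered by each (start, end) character span."""
    return [example["text"][start:end] for start, end in example["entity_spans"]]


mentions = [mention_texts(example) for example in examples]
print(mentions)
```

Running the same check over all of `test_examples` is a quick way to catch off-by-one errors in the offsets before passing them to the tokenizer.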

## Downloading the pretrained model and tokenizer

In [12]:
model = LukeForEntityClassification.from_pretrained('studio-ousia/luke-large-finetuned-open-entity')
model.eval()
model.to('cuda')

tokenizer = LukeTokenizer.from_pretrained('studio-ousia/luke-large-finetuned-open-entity')

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-open-entity were not used when initializing LukeForEntityClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Running entity classification on the test data

In [113]:
batch_size = 128

num_predicted = 0
num_gold = 0
num_correct = 0

all_predictions = []
all_labels = []

for batch_start_idx in trange(0, len(test_examples), batch_size):
    batch_examples = test_examples[batch_start_idx:batch_start_idx+batch_size]
    
    texts = [example['text'] for example in batch_examples]
    entity_spans = [example['entity_spans'] for example in batch_examples]
    gold_labels = [example['label'] for example in batch_examples]

    inputs = tokenizer(texts, entity_spans=entity_spans, return_tensors='pt', padding=True)
    inputs = inputs.to('cuda')

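    # dict_keys(['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask'])
    # input_ids             : token ids (probably at the sub-word level)
    # entity_ids            : probably the id of the entity to be classified
    # entity_position_ids   : positions within input_ids of the tokens that make up the entity
    # attention_mask        : mask that separates real tokens from padding
    # entity_attention_mask : presumably the same kind of mask for the entity sequence (real entities vs. padding)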

    with torch.no_grad():
        outputs = model(**inputs)
    
    # odict_keys(['logits'])
    # logits : raw score for each label (higher means more likely)
    
    # Count every gold label contained in the batch
    num_gold += sum(len(l) for l in gold_labels)

    for logits, labels in zip(outputs.logits, gold_labels):
        for index, logit in enumerate(logits):
            if logit > 0:
                num_predicted += 1
                predicted_label = model.config.id2label[index]
                if predicted_label in labels:
                    num_correct += 1

precision = num_correct / num_predicted
recall = num_correct / num_gold
f1 = 2 * precision * recall / (precision + recall)

print(f"\n\nprecision: {precision} recall: {recall} f1: {f1}")



100%|██████████| 16/16 [01:16<00:00,  4.75s/it]



precision: 0.7980295566502463 recall: 0.7657563025210085 f1: 0.781559903511123
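
The loop above computes micro-averaged precision, recall, and F1 over (example, label) pairs: every label with a positive logit counts as a prediction, and a prediction is correct when it appears in that example's gold labels. The same arithmetic on a toy input, as a sketch. The function name `micro_prf` and the toy labels are made up for illustration, and the set intersection assumes each example's labels are unique:

```python
def micro_prf(predicted, gold):
    """Micro-averaged precision, recall, and F1 for multi-label predictions.

    predicted, gold: lists of label lists, one entry per example.
    Assumes labels are unique within each example.
    """
    num_predicted = sum(len(p) for p in predicted)
    num_gold = sum(len(g) for g in gold)
    num_correct = sum(len(set(p) & set(g)) for p, g in zip(predicted, gold))

    # Guard against empty predictions or empty gold labels.
    precision = num_correct / num_predicted if num_predicted else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Toy data: 3 predicted labels, 3 gold labels, 2 of them matching.
predicted = [["person"], ["location", "place"], []]
gold = [["person"], ["place"], ["time"]]

precision, recall, f1 = micro_prf(predicted, gold)
print(f"precision: {precision} recall: {recall} f1: {f1}")
```

The zero-division guards are an addition of this sketch; the loop above divides directly and would raise if the model predicted no labels at all.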





## Classifying an entity in manually entered text

In [119]:
text = 'I have been in N.Y. city for 2 months.'
entity_spans = [(15, 24)]

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors='pt')
inputs = inputs.to('cuda')
# Inference only, so disable gradient tracking
with torch.no_grad():
    outputs = model(**inputs)

predicted_indices = [index for index, logit in enumerate(outputs.logits[0]) if logit > 0]
print('Predicted entity type is {:}'.format([model.config.id2label[index] for index in predicted_indices]))

Predicted entity type is ['location', 'place']
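
The `logit > 0` test used here is the same as thresholding the sigmoid probability at 0.5, since sigmoid(0) = 0.5. A small sketch with made-up logits and a hypothetical three-label `id2label` mapping (the real mapping lives in `model.config.id2label`); `select_labels` is a name invented for this example:

```python
import math


def sigmoid(x):
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def select_labels(logits, id2label, prob_threshold=0.5):
    """Return every label whose sigmoid probability exceeds the threshold."""
    return [
        id2label[index]
        for index, logit in enumerate(logits)
        if sigmoid(logit) > prob_threshold
    ]


# Hypothetical mapping and logits, for illustration only.
id2label = {0: "person", 1: "location", 2: "place"}
logits = [-3.2, 1.7, 0.4]

by_probability = select_labels(logits, id2label)
by_logit_sign = [id2label[i] for i, logit in enumerate(logits) if logit > 0]
print(by_probability, by_logit_sign)
```

Raising `prob_threshold` above 0.5 would keep only the more confident labels, trading recall for precision.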


In [122]:
text = 'Sakura Ayane is one of the most famous voice actress in Japan.'
entity_spans = [(0, 12)]

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors='pt')
inputs = inputs.to('cuda')
# Inference only, so disable gradient tracking
with torch.no_grad():
    outputs = model(**inputs)

predicted_indices = [index for index, logit in enumerate(outputs.logits[0]) if logit > 0]
print('Predicted entity type is {:}'.format([model.config.id2label[index] for index in predicted_indices]))

Predicted entity type is ['person']
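
Counting character offsets by hand, as in the two cells above, is error-prone. One possible alternative is to look the mention up in the text. This is a sketch: `find_span` is a hypothetical helper, not part of transformers, and it only returns the first occurrence of the mention:

```python
def find_span(text, mention):
    """Return the (start, end) character span of the first occurrence of mention."""
    start = text.find(mention)
    if start == -1:
        raise ValueError(f"{mention!r} not found in text")
    return (start, start + len(mention))


span_ny = find_span('I have been in N.Y. city for 2 months.', 'N.Y. city')
span_sakura = find_span(
    'Sakura Ayane is one of the most famous voice actress in Japan.',
    'Sakura Ayane',
)
print(span_ny, span_sakura)
```

The resulting tuples match the spans typed in manually above, so they can be passed as `entity_spans=[span_ny]` in the same way.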
