<a href="https://colab.research.google.com/github/lclazx/nlp_learning/blob/master/extract_data_with_bert_hub.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 步骤一：引用所需资源

In [1]:
!pip install bert-tensorflow
import tensorflow as tf
import tensorflow_hub as hub
import bert
from bert import optimization
from bert import tokenization

Collecting bert-tensorflow
[?25l  Downloading https://files.pythonhosted.org/packages/a6/66/7eb4e8b6ea35b7cc54c322c816f976167a43019750279a8473d355800a93/bert_tensorflow-1.0.1-py2.py3-none-any.whl (67kB)
[K     |████▉                           | 10kB 18.7MB/s eta 0:00:01[K     |█████████▊                      | 20kB 3.0MB/s eta 0:00:01[K     |██████████████▋                 | 30kB 4.4MB/s eta 0:00:01[K     |███████████████████▍            | 40kB 2.9MB/s eta 0:00:01[K     |████████████████████████▎       | 51kB 3.5MB/s eta 0:00:01[K     |█████████████████████████████▏  | 61kB 4.2MB/s eta 0:00:01[K     |████████████████████████████████| 71kB 2.8MB/s 
Installing collected packages: bert-tensorflow
Successfully installed bert-tensorflow-1.0.1





步骤二：准备数据

In [0]:
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 input_ids,
                 input_mask,
                 segment_ids,
                 token_label_ids,
                 predicate_label_id,
                 is_real_example=True):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.token_label_ids = token_label_ids
        self.predicate_label_id = predicate_label_id
        self.is_real_example = is_real_example

In [0]:
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncates a sequence pair in place to the maximum length."""

  # This is a simple heuristic which will always truncate the longer sequence
  # one token at a time. This makes more sense than truncating an equal percent
  # of tokens from each, since if one sequence is very short then each token
  # that's truncated likely contains more information than a longer sequence.
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()

In [0]:
def convert_single_example(ex_index, example, token_label_list, predicate_label_list, max_seq_length, tokenizer):
  if isinstance(example, PaddingInputExample):
    return InputFeatures(
        input_ids=[0]*max_seq_length,
        input_mask=[0]*max_seq_length,
        segment_ids=[0]*max_seq_length,
        token_label_ids=[0]*max_seq_length,
        predicate_label_id=[0],
        is_real_example=False
    )

  token_label_map={}
  for(i, label) in enumerate(token_label_list):
    token_label_map[label]=i
  predicate_label_map={}
  for(i, label) in enumerate(predicate_label_list):
    predicate_label_map[label] = i
  
  # text_token = example.text_token.split['\t'][0].split(' ')
  # if example.token_label is not None:
  #   token_label = example.token_label.split('\t')[0].split(' ')
  # else:
  #   token_label = ["O"]*len(text_token)
  # assert len(text_token) == len(token_label)

  # text_predicate = example.text_token.split('\t')[1]
  # if example.token_label is not None:
  #   token_predicate = example.token_label.split('\t')[1]
  # else:
  #   token_predicate = text_predicate
  text_token = example.text_token
  token_label = text_token.pop()
  text_predicate = example.token_label
  token_predicate = text_predicate.pop()
  assert token_predicate == text_predicate

  tokens_b = [text_predicate]*len(text_token)
  predicate_id = predicate_label_map[text_predicate]

  _truncate_seq_pair(text_token, tokens_b, max_seq_length-3)

  tokens = []
  token_label_ids = []
  segment_ids = []
  tokens.append('[CLS]')
  token_label_ids.append(token_label_map["[CLS]"])
  segment_ids.append(0)

  for token, label in zip(text_token, token_label):
    tokens.append(text_token)
    token_label_ids.append(token_label)
    segment_ids.append(0)

  tokens.append('[SEP]')
  token_label_ids.append(token_label_map['[SEP]'])
  segment_ids.append(0)

  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  bias = 1
  for token in tokens_b:
    input_ids.append(predicate_id + bias)
    segment_ids.append(1)
    token_label_ids.append(token_label_map['[category]'])

  input_ids.append(tokenizer.convert_tokens_to_ids(['[SEP]'])[0])
  segment_ids.append(1)
  token_label_ids.append(token_label_map('[SEP]'))

  input_mask = [1]*len(input_ids)
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)
    token_label_ids.append(0)
    tokens.append('[Padding]')

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length
  assert len(token_label_ids) == max_seq_length

  if ex_index<5:
    tf.logging.info("*** Example ***")
    tf.logging.info("guid:%s" % (example.guid))
    tf.logging.info("tokens: %s" % ''.join([tokenization.printable_text(x) for x in tokens]))
    tf.logging.info('input_ids: %s' % ' '.join([str(x) for x in input_ids]))
    tf.logging.info('input_mask: %s' % ' '.join([str(x) for x in input_mask]))
    tf.logging.info("segment_ids: %s" % ' '.join([str(x) for x in segment_ids]))
    tf.logging.info('token_label_ids: %s' % ' '.join([str(x) for x in token_label_ids]))
    tf.logging.info('predicate_id: %s' % str(predicate_id))

  feature = InputFeatures(
      input_ids = input_ids,
      input_mask = input_mask,
      segment_ids = segment_ids,
      token_label_ids = token_label_ids,
      predicate_label_id = [predicate_id],
      is_real_example=True
  )

  return feature
  

In [0]:
class InputExample(object):
  def __init__(self, guid, text_token, token_labels):
    self.guid=guid
    self.text_token=text_token
    self.token_label=token_labels

class PaddingInputExample(object):
  """Truncates a sequence pair in place to the maximum length."""


In [0]:
from google.colab import auth
auth.authenticate_user()
from googleapiclient.discovery import build
gcs_service = build('storage', 'v1')

OUTPUT_DIR = 'extract_training_output' #@param
DO_DELETE =  False#@param
USE_BUCKET = True #@param
BUCKET = 'bert_classification'#@param

if USE_BUCKET:
  OUTPUT_DIR = 'gs://{}/{}'.format(BUCKET, OUTPUT_DIR)
  
if DO_DELETE:
  try:
    tf.gfile.DeleteRecursively(OUTPUT_DIR)
  except:
    pass
tf.gfile.MakeDirs(OUTPUT_DIR)
  

In [0]:
from apiclient.http import MediaIoBaseDownload
def download_file(output_dir, source_file):
  with open(output_dir, 'wb') as f:
    request = gcs_service.objects().get_media(bucket=BUCKET, object=source_file)
    media = MediaIoBaseDownload(f, request)
    done = False
    while not done:
      _, done = media.next_chunk()

download_file(output_dir='/tmp/train_data.json', source_file='raw_data/train_data.json/train_data.json')
download_file(output_dir='/tmp/test_data.json', source_file='raw_data/test_data_postag.json/test_data_postag.json')
download_file(output_dir='/tmp/dev_data.json', source_file='raw_data/dev_data.json/dev_data.json')

print('Downloaded')

In [9]:
BERT_MODEL_HUB = 'https://tfhub.dev/google/bert_chinese_L-12_H-768_A-12/1'
TRAINABLE = True

def create_tokenizer_from_hub_module():
  with tf.Graph().as_default():
    bert_module = hub.Module(BERT_MODEL_HUB)
    tokenization_info = bert_module(signature='tokenization_info', as_dict=True)
    with tf.Session() as sess:
      vocab_file, do_lower_case = sess.run([tokenization_info['vocab_file'], tokenization_info['do_lower_case']])
    return bert.tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)

tokenizer = create_tokenizer_from_hub_module()

INFO:tensorflow:Saver not created because there are no variables in the graph to restore


INFO:tensorflow:Saver not created because there are no variables in the graph to restore








In [0]:
def find_bio_in_text_token(so_object_text, text_token):
  for i, token in enumerate(text_token):
    trimmed_token = token.replace('##', '')
    if so_object_text.startswith(trimmed_token):
      for j in range(i+1,len(text_token)):
        concatenated_token = ''.join([x.replace("##",'') for x in text_token[i:j]])
        if so_object_text == concatenated_token:
          return i, j
        elif so_object_text.startswith(concatenated_token):
          continue
        else:
          break
  return -1,-1

In [0]:
def build_SPO_Example(data):
  text = data['text']
  spo_list = data['spo_list']
  text_token = tokenizer.tokenize(text)
  examples = []
  for spo in spo_list:
    subject = spo['subject']
    obj = spo['object']
    predicate = spo['predicate']
    subject_start_index, subject_end_index = find_bio_in_text_token(subject, text_token)
    obj_start_index, obj_end_index = find_bio_in_text_token(obj, text_token)
    text_token_with_prediction=[]
    for token in text_token:
      text_token_with_prediction.append(token)
    token_labels = ['O']*len(text_token_with_prediction)
    text_token_with_prediction.append(predicate)
    token_labels.append(predicate)

    if subject_start_index>-1:
      token_labels[subject_start_index] = 'B-SUB'
      for i in range(subject_start_index+1, subject_end_index):
        token_labels[i] = 'I-SUB'
    
    if obj_start_index > -1:
      token_labels[obj_start_index] = 'B-OBJ'
      for i in range(obj_start_index+1, obj_end_index):
        token_labels[i] = 'I-OBJ'

    #examples.append({'text_token': text_token_with_prediction, 'token_labels': token_labels})
    example = InputExample(guid=None, text_token=text_token_with_prediction, token_labels=token_labels)
    examples.append(example)
  return examples
      


In [23]:
test_data = {'text': '查尔斯·阿兰基斯（Charles Aránguiz），1989年4月17日出生于智利圣地亚哥，智利职业足球运动员，司职中场，效力于德国足球甲级联赛勒沃库森足球俱乐部',
             'spo_list':[
                       {
                          'predicate':'出生地',
                          'object':'圣地亚哥',
                          'subject':'查尔斯·阿兰基斯'
                       },
                       {
                          'predicate':'出生日期',
                          'object':'1989年4月17日',
                          'subject':'查尔斯·阿兰基斯'

                       }
                  ]
        }

examples = build_SPO_Example(test_data)
examples

[<__main__.InputExample at 0x7f92a6558438>,
 <__main__.InputExample at 0x7f92a65584a8>]

In [0]:
def get_predicate_labels():
  return ['丈夫', '上映时间', '专业代码', '主持人', '主演', '主角', '人口数量', '作曲', '作者', '作词', '修业年限', '出品公司', '出版社', '出生地', '出生日期', '创始人', '制片人', '占地面积', '号', '嘉宾', '国籍', '妻子', '字', '官方语言', '导演', '总部地点', '成立日期', '所在城市', '所属专辑', '改编自', '朝代', '歌手', '母亲', '毕业院校', '民族', '气候', '注册资本', '海拔', '父亲', '目', '祖籍', '简称', '编剧', '董事长', '身高', '连载网站', '邮政编码', '面积', '首都']
def get_token_labels():
  BIO_token_labels = ["[Padding]", "[category]", "[##WordPiece]", "[CLS]", "[SEP]", "B-SUB", "I-SUB", "B-OBJ", "I-OBJ", "O"]  #id 0 --> [Paddding]
  return BIO_token_labels


In [26]:
examples[0].text_token

['查',
 '尔',
 '斯',
 '·',
 '阿',
 '兰',
 '基',
 '斯',
 '（',
 '[UNK]',
 '[UNK]',
 '）',
 '，',
 '1989',
 '年',
 '4',
 '月',
 '17',
 '日',
 '出',
 '生',
 '于',
 '智',
 '利',
 '圣',
 '地',
 '亚',
 '哥',
 '，',
 '智',
 '利',
 '职',
 '业',
 '足',
 '球',
 '运',
 '动',
 '员',
 '，',
 '司',
 '职',
 '中',
 '场',
 '，',
 '效',
 '力',
 '于',
 '德',
 '国',
 '足',
 '球',
 '甲',
 '级',
 '联',
 '赛',
 '勒',
 '沃',
 '库',
 '森',
 '足',
 '球',
 '俱',
 '乐',
 '部']

In [25]:
convert_single_example(0, examples[0], get_token_labels(), get_predicate_labels(), 256, tokenizer)

AssertionError: ignored