# BERT for Patents

Copyright 2020 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
import collections
import math
import random
import sys
import time
from typing import Dict, List, Tuple

# Use TensorFlow 2.x
import tensorflow as tf
import numpy as np

In [None]:
# Set BigQuery application credentials
from google.cloud import bigquery
import os
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "path/to/file.json"

project_id = "your_bq_project_id"
bq_client = bigquery.Client(project=project_id)

In [None]:
# Clone the BERT repo (if it is not already present) and add it to the import path.
!test -d bert_repo || git clone https://github.com/google-research/bert bert_repo
if 'bert_repo' not in sys.path:
  sys.path += ['bert_repo']

The BERT repo was written for TensorFlow 1, and several of the functions it uses have been moved or renamed in TensorFlow 2. Before the BERT tokenizer can be used, one line in the freshly cloned repo must be updated for TensorFlow 2. Change line 125 of the BERT `tokenization.py` file as follows:

From => `with tf.gfile.GFile(vocab_file, "r") as reader:`

To => `with tf.io.gfile.GFile(vocab_file, "r") as reader:`

Once that is complete and the file is saved, the tokenization library can be imported.
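Instead of editing the file by hand, the same change can be scripted. The following is a minimal sketch that assumes the repo was cloned into `bert_repo` as above. `patch_tokenization` is a hypothetical helper, not part of the BERT repo. It rewrites every `tf.gfile.GFile` occurrence in the file, not only the one on line 125. Run it before `import tokenization`. If the module was already imported, restart the kernel or reload the module afterwards.

```python
from pathlib import Path


def patch_tokenization(path: str = "bert_repo/tokenization.py") -> int:
  """Rewrites TF1 `tf.gfile.GFile` calls as TF2 `tf.io.gfile.GFile` in place.

  Returns how many occurrences were replaced. Running it again is harmless:
  once patched, the file no longer contains the TF1 spelling.
  """
  file = Path(path)
  source = file.read_text(encoding="utf-8")
  old, new = "tf.gfile.GFile", "tf.io.gfile.GFile"
  count = source.count(old)
  if count:
    file.write_text(source.replace(old, new), encoding="utf-8")
  return count
```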

In [None]:
import tokenization

## Set Up BERT and Some Helpers

The BERT model exported here was trained on more than 100 million patent documents, using all parts of a patent (abstract, claims, description).

The BERT model exported here comes in two formats:

* [SavedModel](https://storage.googleapis.com/patents-public-data-github/saved_model.zip)

* [Checkpoint](https://storage.googleapis.com/patents-public-data-github/checkpoint.zip)

**NOTE: This notebook uses the saved model format.**

Either format can also be loaded and re-saved in another format, or only the weights can be saved.

The model has a configuration similar to BERT-Large, with a few important notes:

* The maximum input sequence length is 512 tokens, and at most 45 tokens per sequence can be masked for prediction.
* The vocabulary adds approximately 9,000 words to the standard BERT vocabulary. These represent frequently used patent terms.
* The vocabulary includes "context" tokens indicating what part of a patent the text is from (abstract, claims, summary, invention). Providing context tokens in the examples is optional.

The full BERT vocabulary can be downloaded [here](https://storage.googleapis.com/patents-public-data-github/bert_for_patents_vocab_39k.txt). The vocabulary also contains 1000 unused tokens so that more tokens can be added.
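In the original BERT vocabularies, a token's id is its line index in the vocabulary file. New terms are therefore added by overwriting reserved placeholder lines rather than appending, which keeps `vocab_size` and the embedding matrix unchanged. The following is a minimal sketch. It assumes the unused tokens in this vocabulary follow the original BERT naming (`[unused0]`, `[unused1]`, ...), which is not stated here. `add_tokens_to_vocab` is a hypothetical helper.

```python
from typing import Iterable, List


def add_tokens_to_vocab(vocab_lines: List[str], new_tokens: Iterable[str],
                        unused_prefix: str = "[unused") -> List[str]:
  """Returns a copy of `vocab_lines` with `new_tokens` written into unused slots.

  Token ids are line indices, so only reserved lines are overwritten and every
  existing token keeps its id.
  """
  lines = list(vocab_lines)
  existing = set(lines)
  free_slots = [i for i, tok in enumerate(lines) if tok.startswith(unused_prefix)]
  # Skip tokens already in the vocabulary and drop duplicates, keeping order.
  to_add = list(dict.fromkeys(t for t in new_tokens if t not in existing))
  if len(to_add) > len(free_slots):
    raise ValueError(
        f"{len(to_add)} new tokens but only {len(free_slots)} unused slots")
  for slot, token in zip(free_slots, to_add):
    lines[slot] = token
  return lines
```

The embedding rows behind the overwritten slots carry no learned meaning for the new terms, so further training would be needed before the model predicts them sensibly.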

The exact configuration for the BERT model is as follows (it can also be downloaded [here](https://storage.googleapis.com/patents-public-data-github/bert_for_patents_large_config.json)):

* attention_probs_dropout_prob: 0.1
* hidden_act: gelu
* hidden_dropout_prob: 0.1
* hidden_size: 1024
* initializer_range: 0.02
* intermediate_size: 4096
* max_position_embeddings: 512
* num_attention_heads: 16
* num_hidden_layers: 24
* vocab_size: 39859
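These settings determine the model's size. The following back-of-the-envelope count is a sketch with several assumptions: the standard BERT layer layout, a pooler layer on the `[CLS]` output, and `type_vocab_size` = 2, which this configuration does not list. It excludes the MLM output head. `num_attention_heads` does not enter the count because the heads split `hidden_size` rather than adding weights. `bert_param_count` is a hypothetical helper.

```python
def bert_param_count(vocab_size: int, hidden_size: int, intermediate_size: int,
                     num_hidden_layers: int, max_position_embeddings: int,
                     type_vocab_size: int = 2) -> int:
  """Counts encoder weights for a standard BERT layout (no MLM head)."""
  h, i = hidden_size, intermediate_size
  layer_norm = 2 * h  # gamma and beta
  # Token, position and segment embeddings, followed by one LayerNorm.
  embeddings = (vocab_size + max_position_embeddings + type_vocab_size) * h
  embeddings += layer_norm
  # Q, K, V and output projections (h x h weights plus bias), then LayerNorm.
  attention = 4 * (h * h + h) + layer_norm
  # Feed-forward h -> i -> h with biases, then LayerNorm.
  feed_forward = (h * i + i) + (i * h + h) + layer_norm
  pooler = h * h + h  # dense layer applied to the [CLS] output
  return embeddings + num_hidden_layers * (attention + feed_forward) + pooler


patents = bert_param_count(vocab_size=39859, hidden_size=1024,
                           intermediate_size=4096, num_hidden_layers=24,
                           max_position_embeddings=512)
print(patents)  # 344702976, roughly 345M
```

With the original BERT vocabulary size of 30,522, the same formula gives 335,141,888. The entire difference comes from the roughly 9,000 added vocabulary rows of 1,024 values each.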

The model requires the following inputs:
1. `input_ids`
2. `input_mask`
3. `segment_ids`
4. `mlm_positions` (the positions of the masked tokens; the helper code below builds these from lists named `mlm_ids`)

The BERT model returns the following outputs:
1. `cls_token` is the output for the `[CLS]` token, a single vector per input sequence.
2. `encoder_layer` contains the contextualized token embeddings from the last encoder layer.
3. `mlm_logits` contains the predictions (logits over the vocabulary) for any masked tokens provided to the model.
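To make the input contract concrete, here is a minimal NumPy sketch that pads id lists into the fixed-size int64 arrays passed to the serving signature. It assumes the token ids already include `[CLS]` and `[SEP]` and that a single segment is used. It replaces masked positions with id 4, the value `BertPredictor` below uses as its mask token id. `make_signature_inputs` is a hypothetical helper.

```python
from typing import Dict, List

import numpy as np


def make_signature_inputs(
    token_ids: List[List[int]], mask_positions: List[List[int]],
    max_seq_length: int = 512, max_preds: int = 45,
    mask_token_id: int = 4) -> Dict[str, np.ndarray]:
  """Pads per-example id lists into fixed-size int64 arrays."""
  batch = len(token_ids)
  input_ids = np.zeros((batch, max_seq_length), dtype=np.int64)
  input_mask = np.zeros((batch, max_seq_length), dtype=np.int64)
  segment_ids = np.zeros((batch, max_seq_length), dtype=np.int64)  # one segment
  mlm_positions = np.zeros((batch, max_preds), dtype=np.int64)
  for row, (ids, positions) in enumerate(zip(token_ids, mask_positions)):
    if len(ids) > max_seq_length or len(positions) > max_preds:
      raise ValueError(f"example {row} exceeds the size limits")
    input_ids[row, :len(ids)] = ids
    input_mask[row, :len(ids)] = 1  # 1 = real token, 0 = padding
    if positions:
      input_ids[row, positions] = mask_token_id  # hide the tokens to predict
      mlm_positions[row, :len(positions)] = positions
  return {"input_ids": input_ids, "input_mask": input_mask,
          "segment_ids": segment_ids, "mlm_positions": mlm_positions}
```

Once the model is loaded, arrays like these could be passed to the signature as keyword arguments, as `BertPredictor.predict` does with `tf.convert_to_tensor(..., dtype=tf.int64)`.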

In [None]:
# These classes and functions are copied from `run_classifier.py` in the
# cloned BERT repo, because that file also has TensorFlow 2 compatibility
# issues.

class InputFeatures(object):
  """A single set of features of data."""

  def __init__(self, input_ids, input_mask, segment_ids, label_id,
               is_real_example=True):
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.label_id = label_id
    self.is_real_example = is_real_example

class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Constructs an InputExample."""
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncates a sequence pair in place to the maximum length."""
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()

def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer):
  """Convert a set of `InputExample`s to a list of `InputFeatures`."""

  features = []
  for (ex_index, example) in enumerate(examples):
    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)
    features.append(feature)
  return features

def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
  """Converts a single `InputExample` into a single `InputFeatures`."""

  label_map = {}
  for (i, label) in enumerate(label_list):
    label_map[label] = i

  tokens_a = tokenizer.tokenize(example.text_a)
  tokens_b = None
  if example.text_b:
    tokens_b = tokenizer.tokenize(example.text_b)

  if tokens_b:
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length does not exceed the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
  else:
    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens_a) > max_seq_length - 2:
      tokens_a = tokens_a[0:(max_seq_length - 2)]

  tokens = []
  segment_ids = []
  tokens.append("[CLS]")
  segment_ids.append(0)
  for token in tokens_a:
    tokens.append(token)
    segment_ids.append(0)
  tokens.append("[SEP]")
  segment_ids.append(0)

  if tokens_b:
    for token in tokens_b:
      tokens.append(token)
      segment_ids.append(1)
    tokens.append("[SEP]")
    segment_ids.append(1)

  input_ids = tokenizer.convert_tokens_to_ids(tokens)

  # The mask has 1 for real tokens and 0 for padding tokens. Only real
  # tokens are attended to.
  input_mask = [1] * len(input_ids)

  # Zero-pad up to the sequence length.
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  label_id = label_map[example.label]

  feature = InputFeatures(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      label_id=label_id,
      is_real_example=True)
  return feature

In [None]:
def get_tokenized_input(
    texts: List[str], tokenizer: tokenization.FullTokenizer) -> List[List[str]]:
  """Returns the WordPiece tokens for each text segment."""

  return [tokenizer.tokenize(text) for text in texts]


class BertPredictor:
  """Wraps the exported BERT SavedModel for prediction."""

  def __init__(
      self, 
      model_name: str, 
      text_tokenizer: tokenization.FullTokenizer, 
      max_seq_length: int,
      max_preds_per_seq: int,
      has_context: bool = False):
    """Initializes a BertPredictor object."""

    self.tokenizer = text_tokenizer
    self.max_seq_length = max_seq_length
    self.max_preds_per_seq = max_preds_per_seq
    self.mask_token_id = 4  # Id of the [MASK] token in this vocabulary.
    # If you want to add context tokens to the input, set value to True.
    self.context = has_context

    model = tf.compat.v2.saved_model.load(export_dir=model_name, tags=['serve'])
    self.model = model.signatures['serving_default']

  def get_features_from_texts(
      self, texts: List[str]) -> Dict[str, List[List[int]]]:
    """Uses tokenizer to convert raw text into features for prediction."""

    examples = [InputExample(0, t, label='') for t in texts]
    features = convert_examples_to_features(
        examples, [''], self.max_seq_length, self.tokenizer)
    return dict(
        input_ids=[f.input_ids for f in features],
        input_mask=[f.input_mask for f in features],
        segment_ids=[f.segment_ids for f in features]
    )

  def insert_token(self, ids: List[int], token: int) -> List[int]:
    """Inserts `token` right after [CLS] and drops the last position.

    The returned list has the same length as `ids`.
    """

    return ids[:1] + [token] + ids[1:-1]

  def add_input_context(
      self, inputs: Dict[str, List[List[int]]], context_tokens: List[str]
  ) -> Dict[str, int]:
    """Adds context token to input features."""

    context_token_ids = self.tokenizer.convert_tokens_to_ids(context_tokens)
    segment_token_id = 0
    mask_token_id = 1

    for i, context_token_id in enumerate(context_token_ids):
      inputs['input_ids'][i] = self.insert_token(
          inputs['input_ids'][i], context_token_id)

      inputs['segment_ids'][i] = self.insert_token(
          inputs['segment_ids'][i], segment_token_id)

      inputs['input_mask'][i] = self.insert_token(
          inputs['input_mask'][i], mask_token_id)
    return inputs

  def create_mlm_mask(
      self, inputs: Dict[str, List[List[int]]], mlm_ids: List[List[int]]
  ) -> Tuple[Dict[str, List[List[int]]], List[List[str]]]:
    """Creates masked language model mask."""

    masked_text_tokens = []
    mlm_positions = []

    if not mlm_ids:
      inputs['mlm_ids'] = mlm_positions
      return inputs, masked_text_tokens

    for i, _ in enumerate(mlm_ids):

      masked_text = []

      # Pad mlm positions to the maximum number of predictions per sequence.
      mlm_positions.append(
          mlm_ids[i] + [0] * (self.max_preds_per_seq - len(mlm_ids[i])))

      for pos in mlm_ids[i]:
        # Retrieve the masked token.
        masked_text.extend(
            self.tokenizer.convert_ids_to_tokens([inputs['input_ids'][i][pos]]))
        # Replace the mask positions with the mask token.
        inputs['input_ids'][i][pos] = self.mask_token_id
  
      masked_text_tokens.append(masked_text)

    inputs['mlm_ids'] = mlm_positions
    return inputs, masked_text_tokens

  def predict(
      self, texts: List[str], mlm_ids: List[List[int]] = None, 
      context_tokens: List[str] = None
  ) -> Tuple[Dict[str, tf.Tensor], Dict[str, List[List[int]]], List[List[str]]]:
    """Gets BERT predictions for provided text and masks.
    
    Args:
      texts: List of texts to get BERT predictions.
      mlm_ids: List of lists corresponding to the mask positions for each input
        in `texts`.
      context_tokens: List of string contexts to prepend to input texts.

    Returns:
      response: BERT model response.
      inputs: Tokenized and modified input to BERT model.
      masked_text: Raw strings of the masked tokens.
    """

    if mlm_ids:
      assert len(mlm_ids) == len(texts), ('If mask ids are provided, there '
          'must be one list of mask ids per input text.')

    if self.context:
      # If the model uses context but none is provided, use the '[UNK]' token.
      if not context_tokens:
        context_tokens = ['[UNK]' for _ in range(len(texts))]
      assert len(context_tokens) == len(texts), ('If context tokens are '
          'provided, there must be one context token per input text.')
    
    inputs = self.get_features_from_texts(texts)

    # If using a BERT model with context, add corresponding tokens.
    if self.context:
      inputs = self.add_input_context(inputs, context_tokens)

    inputs, masked_text = self.create_mlm_mask(inputs, mlm_ids)

    response = self.model(
      segment_ids=tf.convert_to_tensor(inputs['segment_ids'], dtype=tf.int64),
      input_mask=tf.convert_to_tensor(inputs['input_mask'], dtype=tf.int64),
      input_ids=tf.convert_to_tensor(inputs['input_ids'], dtype=tf.int64),
      mlm_positions=tf.convert_to_tensor(inputs['mlm_ids'], dtype=tf.int64),
      )
    
    if mlm_ids:
      # Do a reshape of the mlm logits (batch size, num predictions, vocab).
      new_shape = (len(texts), self.max_preds_per_seq, -1)
      response['mlm_logits'] = tf.reshape(
          response['mlm_logits'], shape=new_shape)
    
    return response, inputs, masked_text 
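When `has_context=True`, `add_input_context` uses `insert_token` to put the context token directly after `[CLS]`. The standalone sketch below reproduces that slicing on toy ids. The ids are made up and do not come from the patents vocabulary, and `insert_after_cls` is a hypothetical stand-in for the method. It shows two consequences. First, every original token shifts one position to the right, which is why `get_mlm_ids_by_token` offsets positions by 2 rather than 1 when `has_context` is set. Second, the last position is dropped, so an input that already fills `max_seq_length` loses its final `[SEP]`.

```python
from typing import List


def insert_after_cls(ids: List[int], token: int) -> List[int]:
  # Same slicing as BertPredictor.insert_token: keep position 0 ([CLS]),
  # insert `token`, and drop the final position to preserve the length.
  return ids[:1] + [token] + ids[1:-1]


# Illustrative ids: 2 stands in for [CLS], 3 for [SEP], 0 for padding.
padded = [2, 10, 11, 3, 0, 0]
print(insert_after_cls(padded, 50))  # [2, 50, 10, 11, 3, 0]

full = [2, 10, 11, 3]  # no padding left
print(insert_after_cls(full, 50))    # [2, 50, 10, 11]  ([SEP] is dropped)
```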


In [None]:
# Some helper functions.

def get_mlm_ids_by_token(
    mask_token: str, tokenized_text: List[List[str]], 
    has_context: bool = False, first_occurrence: bool = True
) -> List[List[int]]:
  """Returns the input positions at which `mask_token` occurs, for masking.

  Positions are offset by 1 for the leading [CLS] token, or by 2 when a
  context token is also inserted after [CLS].
  """

  pos_add = 2 if has_context else 1
  mlm_ids = []
  for tokens in tokenized_text:
    pub_mlm_ids = []
    for j, token in enumerate(tokens):
      if token == mask_token:
        pub_mlm_ids.append(j + pos_add)
        if first_occurrence:
          break
    mlm_ids.append(pub_mlm_ids)

  return mlm_ids


def bert_topk_predictions(
    mlm_logits: tf.Tensor, mlm_ids: List[List[int]], top_k: int = 5
) -> Tuple[List[List[List[int]]], List[List[List[str]]]]:
  """Returns BERT predicted token ids and terms for masked ids.
  
  Args:
    mlm_logits: The BERT masked language logits, shaped
      (batch size, num predictions, vocab) as returned by `predict`.
    mlm_ids: The masked ids.
    top_k: Number of predictions to return for each mask.

  Returns:
    token_preds: Token predictions for each mask position.
    term_preds: Term predictions for each mask position.
  """

  token_preds = []
  term_preds = []

  # Tradeoff: one top_k call over all prediction slots (including padded,
  # non-masked ones) followed by a gather, vs. calling top_k for each mask as
  # done here. Note: this uses the global `tokenizer` defined under Load BERT.
  for i, ids in enumerate(mlm_ids):
    current_token_preds = []
    current_term_preds = []
    for j, _ in enumerate(ids):
      preds = tf.math.top_k(mlm_logits[i][j], top_k).indices.numpy().tolist()
      current_token_preds.append(preds)
      current_term_preds.append(tokenizer.convert_ids_to_tokens(preds))
    token_preds.append(current_token_preds)
    term_preds.append(current_term_preds)

  return token_preds, term_preds


def find_rankings(
    words: List[str], word_ids: List[int], mlm_logits: tf.Tensor, 
    mlm_ids: List[List[int]]
) -> Dict[str, Dict[str, float]]:
  """Returns rank statistics of the provided words in the BERT predictions.

  A rank of 0 means the word was the top prediction for a masked position.
  """
  
  word_positions = []

  # Iterate through all predictions.
  for i, _ in enumerate(mlm_ids):
    for j, _ in enumerate(mlm_ids[i]):
      # Vocabulary ids sorted from highest to lowest logit.
      ranked_ids = tf.argsort(mlm_logits[i][j], direction='DESCENDING')
      positions = tf.reshape(tf.where(tf.equal(
          tf.expand_dims(word_ids, axis=-1), ranked_ids))[:,-1], [1, -1])
      word_positions.extend(list(positions.numpy()))

  transposed = np.array(word_positions).T
  word_dict = dict()

  for i, word in enumerate(words):
    word_dict[word] = {
        'average': transposed[i].mean(),
        'max': transposed[i].max(),
        'min': transposed[i].min(),
        'std': transposed[i].std(),
    }

  return word_dict
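`find_rankings` sorts the whole vocabulary for every masked position to locate each word. Counting how many logits beat a word's logit gives the same rank without a full sort. The following is a minimal NumPy sketch for one prediction slot. It assumes the slot's logits are available as an array, e.g. `response['mlm_logits'][i][j].numpy()`. `word_ranks` is a hypothetical helper. When logits are exactly tied, it reports the best rank among the tied entries, whereas a sort would place them in arbitrary order.

```python
import numpy as np


def word_ranks(slot_logits, word_ids):
  """Rank of each word in one prediction slot (0 = top prediction).

  A word's rank is the number of vocabulary entries with a strictly higher
  logit, which equals its position in a descending sort when there are no ties.
  """
  logits = np.asarray(slot_logits)
  word_logits = logits[np.asarray(word_ids)]
  # Compare every word logit against the full vocabulary: (num_words, vocab).
  return (logits[None, :] > word_logits[:, None]).sum(axis=1)
```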

## Load BERT

In [None]:
MAX_SEQ_LENGTH = 512
MAX_PREDS_PER_SEQUENCE = 45
MODEL_DIR = 'path/to/bert/model/'
VOCAB = 'path/to/vocab.txt'

tokenizer = tokenization.FullTokenizer(VOCAB, do_lower_case=True)

bert_predictor = BertPredictor(
    model_name=MODEL_DIR,
    text_tokenizer=tokenizer,
    max_seq_length=MAX_SEQ_LENGTH,
    max_preds_per_seq=MAX_PREDS_PER_SEQUENCE,
    has_context=False)

## Masked Term Example from Patent Abstracts

Here we run a simple query against the BigQuery patents data to fetch the abstracts of 3 patents that use the word "eye". We mask the first occurrence of "eye" in each abstract and print BERT's top predictions, to see how the suggested synonyms for the same word change with each patent's subject.

In [None]:
test_pubs = ('US-8000000-B2', 'US-2007186831-A1', 'US-2009030261-A1')

query = r"""
  SELECT publication_number, abstract, url
  FROM `patents-public-data.google_patents_research.publications` 
  WHERE publication_number in {}
""".format(test_pubs)

df = bq_client.query(query).to_dataframe()

In [None]:
tokenized_inputs = get_tokenized_input(df.abstract.to_list(), tokenizer)
mlm_ids = get_mlm_ids_by_token('eye', tokenized_inputs)

In [None]:
response, inputs, masked_text = bert_predictor.predict(
    df.abstract.to_list(), mlm_ids)

token_preds, term_preds = bert_topk_predictions(response['mlm_logits'], mlm_ids)
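`bert_topk_predictions` calls `tf.math.top_k` once per masked position. Its comment mentions the alternative: a single top-k pass over every prediction slot followed by a gather. The following NumPy sketch implements that alternative. `topk_all_slots` is a hypothetical helper, and the sketch assumes the logits have been converted with `.numpy()` after the reshape in `predict`. It returns the same nested token-id lists as `token_preds`, up to ordering among exactly tied logits.

```python
from typing import List

import numpy as np


def topk_all_slots(mlm_logits, mlm_ids: List[List[int]],
                   top_k: int = 5) -> List[List[List[int]]]:
  """Top-k token ids for every real mask, from one pass over all slots.

  mlm_logits: array shaped (batch size, max predictions, vocab).
  """
  logits = np.asarray(mlm_logits)
  # argpartition puts the top_k largest logits of each slot first (unordered).
  candidates = np.argpartition(-logits, top_k - 1, axis=-1)[..., :top_k]
  # Order just those k candidates by descending logit.
  candidate_logits = np.take_along_axis(logits, candidates, axis=-1)
  order = np.argsort(-candidate_logits, axis=-1, kind="stable")
  top = np.take_along_axis(candidates, order, axis=-1)
  # Slots past len(ids) are padding and are dropped here (the "gather" step).
  return [top[i, :len(ids)].tolist() for i, ids in enumerate(mlm_ids)]
```

For example, `topk_all_slots(response['mlm_logits'].numpy(), mlm_ids)` should match `token_preds` from the cell above. `tf.math.top_k` also operates on the last axis of a batched tensor, so the same single-call approach could stay in TensorFlow.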

In [None]:
for row, terms in zip(df.values.tolist(), term_preds):
  out = 'Patent: {}. ({})\nAbstract: {}\nPredictions of term eye \n\t{}\n'
  print(out.format(row[0], row[2], row[1][:100]+'...', terms))

Patent: US-2007186831-A1. (https://patents.google.com/patent/US20070186831A1)
Abstract: A sewing machine includes a thread take-up, a thread take-up driving mechanism driving the thread ta...
Predictions of term eye 
	[['hole', 'point', 'drop', 'eye', 'tip']]

Patent: US-8000000-B2. (https://patents.google.com/patent/US8000000B2)
Abstract: A visual prosthesis apparatus and a method for limiting power consumption in a visual prosthesis app...
Predictions of term eye 
	[['eye', 'retina', 'eyes', 'brain', 'eyeball']]

Patent: US-2009030261-A1. (https://patents.google.com/patent/US20090030261A1)
Abstract: Currently, no efficient, non-invasive methods exist for delivering drugs and/or other therapeutic ag...
Predictions of term eye 
	[['eye', 'eyeball', 'eyes', 'body', 'cornea']]



## Generating Synonyms for a CPC Code

Building on the above, we can query for patents in a given CPC code whose abstracts contain a search term, and examine how the predicted synonyms for that term differ from one CPC code to another.

In [None]:
search_token = 'priming'
words = ['priming', 'cleaning', 'maintenance',  'bonding', 'subbing', 'anchor']
word_ids = tokenizer.convert_tokens_to_ids(words)

query = r"""
  SELECT publication_number, abstract, url
  FROM `patents-public-data.google_patents_research.publications`,
    UNNEST(cpc) as cpc
  WHERE 
    cpc.code = '{}' AND
    cpc.first = True AND
    abstract like '% {} %'
  LIMIT 100
"""

In [None]:
cpc = 'B41J2/165'

df = bq_client.query(query.format(cpc, search_token)).to_dataframe()

tokenized_inputs = get_tokenized_input(df.abstract.to_list(), tokenizer)
mlm_ids = get_mlm_ids_by_token(search_token, tokenized_inputs)

response, inputs, masked_text = bert_predictor.predict(
    df.abstract.to_list(), mlm_ids)

token_preds, term_preds = bert_topk_predictions(
    response['mlm_logits'], mlm_ids, top_k=10)

word_dict = find_rankings(words, word_ids, response['mlm_logits'], mlm_ids)

print('Word positions for our term list:')
for k, v in word_dict.items():
  print(k, v)

prediction_list = [x[0] for x in term_preds]
all_predictions = [item for sublist in prediction_list for item in sublist]

all_counts = collections.Counter(all_predictions)
top_10 = all_counts.most_common(10)

print('\nMost common words predicted:')
for t, _ in top_10:
  print(t)

Word positions for our term list:
priming {'average': 0.75, 'max': 6, 'min': 0, 'std': 1.984313483298443}
cleaning {'average': 2.625, 'max': 3, 'min': 0, 'std': 0.9921567416492215}
maintenance {'average': 22.875, 'max': 57, 'min': 3, 'std': 20.55138377336183}
bonding {'average': 133.75, 'max': 260, 'min': 69, 'std': 69.24729236583912}
subbing {'average': 1577.5, 'max': 1996, 'min': 1309, 'std': 242.1414875646055}
anchor {'average': 4977.875, 'max': 15669, 'min': 1834, 'std': 4156.430633292825}

Most common words predicted:
cleaning
capping
priming
sealing
filling
pumping
flushing
purging
servicing
maintenance


In [None]:
cpc = 'F04D9/041'

df = bq_client.query(query.format(cpc, search_token)).to_dataframe()

tokenized_inputs = get_tokenized_input(df.abstract.to_list(), tokenizer)
mlm_ids = get_mlm_ids_by_token(search_token, tokenized_inputs)

response, inputs, masked_text = bert_predictor.predict(
    df.abstract.to_list(), mlm_ids)

token_preds, term_preds = bert_topk_predictions(
    response['mlm_logits'], mlm_ids, top_k=10)

word_dict = find_rankings(words, word_ids, response['mlm_logits'], mlm_ids)

print('Word positions for our term list:')
for k, v in word_dict.items():
  print(k, v)

prediction_list = [x[0] for x in term_preds]
all_predictions = [item for sublist in prediction_list for item in sublist]

all_counts = collections.Counter(all_predictions)
top_10 = all_counts.most_common(10)

print('\nMost common words predicted:')
for t, _ in top_10:
  print(t)

Word positions for our term list:
priming {'average': 7.7560975609756095, 'max': 109, 'min': 0, 'std': 23.986561774180547}
cleaning {'average': 91.39024390243902, 'max': 483, 'min': 1, 'std': 133.53101001836941}
maintenance {'average': 1291.0731707317073, 'max': 5662, 'min': 4, 'std': 1455.0359879925174}
bonding {'average': 2474.0, 'max': 11616, 'min': 683, 'std': 2299.0895016474165}
subbing {'average': 5893.609756097561, 'max': 20773, 'min': 1136, 'std': 4222.55737374153}
anchor {'average': 6398.975609756098, 'max': 21392, 'min': 497, 'std': 4246.727684205259}

Most common words predicted:
priming
starting
pumping
suction
prime
vacuum
centrifugal
flushing
contained
cleaning


## Extending BERT - CPC Classifier

The trained BERT model can do much more than synonym prediction. Its outputs can be used to:
- Build classifiers for CPC codes (or anything else)
- Tune a model on top of the BERT outputs to perform autocomplete
- Measure semantic similarity by training a siamese network on the BERT outputs
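Before training a siamese network, a quick untrained baseline for semantic similarity is the cosine similarity between the `cls_token` vectors of two patents. The following is a minimal NumPy sketch. It assumes the embeddings are an (n, 1024) array such as `response['cls_token'].numpy()`. `cosine_similarity_matrix` is a hypothetical helper. Pretrained BERT embeddings that were not tuned for similarity are often a weak signal, which is one reason to train a dedicated network on top of them.

```python
import numpy as np


def cosine_similarity_matrix(embeddings):
  """Pairwise cosine similarity between the rows of an (n, d) array."""
  x = np.asarray(embeddings, dtype=np.float64)
  norms = np.linalg.norm(x, axis=1, keepdims=True)
  x = x / np.maximum(norms, 1e-12)  # guard against all-zero rows
  return x @ x.T
```

Given `sims = cosine_similarity_matrix(emb)`, `np.argsort(-sims[0])[1:6]` would list the five patents closest to the first one. Index 0 of the sort is usually the patent itself.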

Below we take the BERT `cls_token` outputs for up to 200 patents and build a tiny classifier to predict the first letter of each patent's first CPC code (its CPC section).

In [None]:
query = r'''
  #standardSQL
  SELECT DISTINCT
    substr(cpc.code, 0, 1) as cpc_class,
    res.abstract
  FROM `patents-public-data.google_patents_research.publications` res,
    UNNEST(cpc) as cpc
    INNER JOIN `patents-public-data.patents.publications` pub ON 
      res.publication_number = pub.publication_number
  WHERE 
    pub.publication_date >= 20000101 AND
    res.country = 'United States' AND
    cpc.first = True AND
    RAND() < 0.1
  LIMIT {}
'''.format(200)

df = bq_client.query(query).to_dataframe()
df = df.sample(frac=1).reset_index(drop=True)

cpc_classes = {
    'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'Y': 8}

texts = df.abstract.tolist()
classes = [cpc_classes[x] for x in df.cpc_class.tolist()]

In [None]:
response, inputs, masked_text = bert_predictor.predict(texts)

train_inputs = response['cls_token']
train_labels = tf.convert_to_tensor(classes)
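The cell above sends all of the abstracts through the model in one call, each padded to 512 tokens. This may exhaust accelerator memory as the sample grows. The following is a minimal sketch of batching the calls. `embed_in_batches` is a hypothetical helper, and the batch size of 16 is an arbitrary choice.

```python
from typing import Callable, Sequence

import numpy as np


def embed_in_batches(embed_fn: Callable[[Sequence[str]], np.ndarray],
                     texts: Sequence[str], batch_size: int = 16) -> np.ndarray:
  """Runs `embed_fn` on consecutive slices of `texts` and stacks the rows."""
  if not texts:
    raise ValueError("texts must be non-empty")
  chunks = [embed_fn(texts[start:start + batch_size])
            for start in range(0, len(texts), batch_size)]
  return np.concatenate(chunks, axis=0)
```

In this notebook, `embed_fn` could be `lambda batch: bert_predictor.predict(list(batch))[0]['cls_token'].numpy()`. The resulting array can then be passed to `model.fit` in place of `train_inputs`.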

In [None]:
num_classes = len(cpc_classes)

model = tf.keras.Sequential([
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(num_classes)
])

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

history = model.fit(
    x=train_inputs, 
    y=train_labels, 
    epochs=10, 
    validation_split=0.1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
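The training log above does not show accuracy values. With at most 200 examples and 10% held out, the validation accuracy is also noisy. A useful reference point is the accuracy of always predicting the most common class. The following is a minimal sketch. It assumes the labels are the `classes` list built earlier and mirrors the `validation_split=0.1` used in `model.fit`, where Keras takes the validation samples from the end of the data. `majority_baseline_accuracy` is a hypothetical helper.

```python
import collections
import math
from typing import Hashable, Sequence


def majority_baseline_accuracy(labels: Sequence[Hashable],
                               validation_split: float = 0.1) -> float:
  """Validation accuracy of always predicting the most common training label.

  Like Keras `validation_split`, the validation set is taken from the end of
  the (already shuffled) data.
  """
  split_at = int(math.floor(len(labels) * (1.0 - validation_split)))
  train, val = labels[:split_at], labels[split_at:]
  if not train or not val:
    raise ValueError("both splits must be non-empty")
  majority = collections.Counter(train).most_common(1)[0][0]
  return sum(1 for y in val if y == majority) / len(val)
```

A classifier whose validation accuracy does not clearly exceed `majority_baseline_accuracy(classes)` has not yet learned much from the `cls_token` features.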
