
Commit

initial
kpe committed May 22, 2019
1 parent 5dc0992 commit 3daae2e
Showing 23 changed files with 2,796 additions and 2 deletions.
9 changes: 9 additions & 0 deletions .gitignore
@@ -0,0 +1,9 @@
.idea
.python-version
.venv
**/__pycache__/
*.egg-info/
*.pyc
*.coverage
build/
dist/
57 changes: 57 additions & 0 deletions .travis.yml
@@ -0,0 +1,57 @@
sudo: false
language: python
python:
- "2.7"
- "3.4"
- "3.5"
- "3.6"
dist: trusty

# Enable 3.7 without globally enabling sudo and dist: xenial for other build jobs
matrix:
include:
- python: 3.7
dist: xenial
sudo: true

env:
- PEP8_IGNORE="E221,E501,W504,W391,E241"

# command to install dependencies
install:
- pip install --upgrade pip
- pip install -r requirements.txt
- pip install coverage pep8 nose

# command to run tests
# enforce a minimum coverage threshold (not including test files) for the Travis CI build
# To skip pypy: - if [[ $TRAVIS_PYTHON_VERSION != 'pypy' ]]; then DOSTUFF ; fi
script:
- export MAJOR_PYTHON_VERSION=`echo $TRAVIS_PYTHON_VERSION | cut -c 1`
- coverage run --source=params_flow $(which nosetests)
--with-doctest tests/
- if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then coverage report --show-missing --fail-under=10 ; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then pep8 --ignore=$PEP8_IGNORE --exclude=tests,.venv -r --show-source . ; fi
# For convenience, make sure simple test commands work
- python setup.py develop
- py.test
- nosetests

# load coverage status to https://coveralls.io
after_success:
- 'echo travis-python: $TRAVIS_PYTHON_VERSION'
- if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then pip install coveralls; COVERALLS_REPO_TOKEN=$COVERALLS_REPO_TOKEN coveralls ; fi

notifications:
email: false

deploy:
provider: pypi
username: kpe
password:
secure: e5wV98e/MPIZV9kRfXVkmmNwvEJRzDsKAFxSlfkpOI9BqI2aBN4IBrekpn0LwCrIQ4mbUYXd8FOc5BMBqISzwdHC2azKtRk+XDQI3vc6XpHwW7m/HLXTL1mrVQ3rumJnbk+1fitplvJhGCY46zNy+D6FQzWhpThA0Q3T1F9mVsLtZblmwLi450NrnXiqLEGjM4CbK8ROvbd1G8PrlpNwHQW9/TgoMBE0PQc5vKU3TlzvbdqahVoaDg0o3cFISfXk7JKmNrA49kkKoaFBzcatqUy2DXoaOx43++4yuRnE0m9juL7tJSBJYnzGrih14zuU+kKBUCLM5ty7b5s4r2+LvQZpWy+WuTKG+9CqSPtiLlBlULHQNdH6j50qpIW5kJd+UhBRB+1KrpXC73DvWErXxpk0X9FIdAlw+sIxQMxWX7xLXyk88Gsh8m5WKxNwV/19fL3DiHHfD37R56wvvr5A0IU2VXTyLwy4HWIiCy9exwezOv3T+Gjsne0/i7vsoQXp9Nohf5T5lKANEP7EcWs+s1anFiZgYJgzCA7Wa9ppCR0WK98MWPRnaJ3kMZn5XyLkwl8+DTKbT+HfYZWBvAnlf18Kv6C2z7BYvqQNroH6VNurxscfS1mrM4+vNEcYz1yJo6yxRHWmxYqoue5e3bn1EltUxRClcyVBpr6LKxdbk8k=
on:
tags: true
branch: master
condition: "$TRAVIS_PYTHON_VERSION = 3.7"
skip_existing: true
2 changes: 0 additions & 2 deletions README.md

This file was deleted.

22 changes: 22 additions & 0 deletions README.rst
@@ -0,0 +1,22 @@
BERT for TensorFlow v2
======================


This repo contains a TensorFlow v2 Keras implementation of `google-research/bert`_,
with support for loading the original `pre-trained weights`_
and producing numerically identical activations.



Resources
---------

- `BERT`_ - BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
- `google-research/bert`_ - the original BERT implementation
- `kpe/params-flow`_ - utilities for reducing Keras boilerplate code in custom layers

.. _`pre-trained weights`: https://github.com/google-research/bert#pre-trained-models
.. _`google-research/bert`: https://github.com/google-research/bert
.. _`BERT`: https://arxiv.org/abs/1810.04805
.. _`kpe/params-flow`: https://github.com/kpe/params-flow
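
A minimal usage sketch of the layer this commit introduces (not part of the committed files). The Params field names used below (vocab_size, hidden_size, num_layers, num_heads) are assumptions for illustration: the authoritative fields are declared in bert/embeddings.py and bert/transformer.py, which are not displayed in this diff.

    import tensorflow as tf
    from bert import BertModelLayer

    # NOTE: the Params field names below are placeholders for illustration only;
    # the real field names are declared in bert/embeddings.py and bert/transformer.py
    bert_params = BertModelLayer.Params(vocab_size=30522, hidden_size=768,
                                        num_layers=12, num_heads=12)
    l_bert = BertModelLayer.from_params(bert_params, name="bert")

    max_seq_len = 128
    input_ids  = tf.keras.layers.Input(shape=(max_seq_len,), dtype="int32", name="input_ids")
    seq_output = l_bert(input_ids)                    # [batch_size, max_seq_len, hidden_size]
    model = tf.keras.Model(inputs=input_ids, outputs=seq_output)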

12 changes: 12 additions & 0 deletions bert/__init__.py
@@ -0,0 +1,12 @@
# coding=utf-8
#
# created by kpe on 15.Mar.2019 at 15:28
#
from __future__ import division, absolute_import, print_function

from .layer import Layer

from .attention import AttentionLayer
from .bert import BertModelLayer

__version__ = '0.0.1'
145 changes: 145 additions & 0 deletions bert/attention.py
@@ -0,0 +1,145 @@
# coding=utf-8
#
# created by kpe on 15.Mar.2019 at 12:52
#

from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras import backend as K

import bert


class AttentionLayer(bert.Layer):
class Params(bert.Layer.Params):
num_heads = None
size_per_head = None
initializer_range = 0.02
query_activation = None
key_activation = None
value_activation = None
attention_dropout = 0.1
negative_infinity = -10000.0 # used for attention scores before softmax

@staticmethod
def create_attention_mask(from_shape, input_mask):
"""
        Creates a 3D attention mask from a 2D input mask.
:param from_shape: [batch_size, from_seq_len, ...]
:param input_mask: [batch_size, seq_len]
:return: [batch_size, from_seq_len, seq_len]
"""

mask = tf.cast(tf.expand_dims(input_mask, axis=1), tf.float32) # [B, 1, T]
ones = tf.expand_dims(tf.ones(shape=from_shape[:2], dtype=tf.float32), axis=-1) # [B, F, 1]
mask = ones * mask # broadcast along two dimensions

return mask # [B, F, T]

def _construct(self, params: Params):

self.query_activation = params.query_activation
self.key_activation = params.key_activation
self.value_activation = params.value_activation

self.query_layer = None
self.key_layer = None
self.value_layer = None

self.supports_masking = True

# noinspection PyAttributeOutsideInit
def build(self, input_shape):
self.input_spec = keras.layers.InputSpec(shape=input_shape)

dense_units = self.params.num_heads * self.params.size_per_head # N*H
#
# B, F, T, N, H - batch, from_seq_len, to_seq_len, num_heads, size_per_head
#
self.query_layer = keras.layers.Dense(units=dense_units, activation=self.query_activation,
kernel_initializer=self.create_initializer(),
name="query")
self.key_layer = keras.layers.Dense(units=dense_units, activation=self.key_activation,
kernel_initializer=self.create_initializer(),
name="key")
self.value_layer = keras.layers.Dense(units=dense_units, activation=self.value_activation,
kernel_initializer=self.create_initializer(),
name="value")
self.dropout_layer = keras.layers.Dropout(self.params.attention_dropout)

super(AttentionLayer, self).build(input_shape)

def compute_output_shape(self, input_shape):
from_shape = input_shape

# from_shape # [B, F, W] [batch_size, from_seq_length, from_width]
# input_mask_shape # [B, F]

output_shape = [from_shape[0], from_shape[1], self.params.num_heads * self.params.size_per_head]

return output_shape # [B, F, N*H]

# noinspection PyUnusedLocal
def call(self, inputs, mask=None, training=None, **kwargs):
from_tensor = inputs
to_tensor = inputs
if mask is None:
sh = self.get_shape_list(from_tensor)
mask = tf.ones(sh[:2], dtype=tf.int32)
attention_mask = AttentionLayer.create_attention_mask(tf.shape(from_tensor), mask)

# from_tensor shape - [batch_size, from_seq_length, from_width]
input_shape = tf.shape(from_tensor)
batch_size, from_seq_len, from_width = input_shape[0], input_shape[1], input_shape[2]
to_seq_len = from_seq_len

# [B, F, N*H] -> [B, N, F, H]
def transpose_for_scores(input_tensor, seq_len):
output_shape = [batch_size, seq_len,
self.params.num_heads, self.params.size_per_head]
output_tensor = K.reshape(input_tensor, output_shape)
return tf.transpose(output_tensor, [0, 2, 1, 3]) # [B,N,F,H]

query = self.query_layer(from_tensor) # [B,F, N*H] [batch_size, from_seq_len, N*H]
key = self.key_layer(to_tensor) # [B,T, N*H]
value = self.value_layer(to_tensor) # [B,T, N*H]

query = transpose_for_scores(query, from_seq_len) # [B, N, F, H]
key = transpose_for_scores(key, to_seq_len) # [B, N, T, H]

attention_scores = tf.matmul(query, key, transpose_b=True) # [B, N, F, T]
attention_scores = attention_scores / tf.sqrt(float(self.params.size_per_head))

if attention_mask is not None:
attention_mask = tf.expand_dims(attention_mask, axis=1) # [B, 1, F, T]
# {1, 0} -> {0.0, -inf}
adder = (1.0 - tf.cast(attention_mask, tf.float32)) * self.params.negative_infinity
            attention_scores += adder  # adding a large negative value before the softmax effectively removes the masked positions

# scores to probabilities
attention_probs = tf.nn.softmax(attention_scores) # [B, N, F, T]

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout_layer(attention_probs,
training=training) # [B, N, F, T]

# [B,T,N,H]
value = tf.reshape(value, [batch_size, to_seq_len,
self.params.num_heads, self.params.size_per_head])
value = tf.transpose(value, [0, 2, 1, 3]) # [B, N, T, H]

context_layer = tf.matmul(attention_probs, value) # [B, N, F, H]
context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) # [B, F, N, H]

output_shape = [batch_size, from_seq_len,
self.params.num_heads * self.params.size_per_head]
context_layer = tf.reshape(context_layer, output_shape)
return context_layer # [B, F, N*H]

# noinspection PyUnusedLocal
def compute_mask(self, inputs, mask=None):
return mask # [B, F]
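
A quick eager-mode check (TensorFlow 2.x eager execution assumed, values chosen arbitrarily) of what create_attention_mask produces and how the {1, 0} mask is consumed in call() above:

    import tensorflow as tf
    from bert.attention import AttentionLayer

    # batch of 2 sequences of length 4; the second sequence is padded after 2 tokens
    input_mask = tf.constant([[1, 1, 1, 1],
                              [1, 1, 0, 0]], dtype=tf.int32)
    from_shape = tf.constant([2, 4, 768])      # [batch_size, from_seq_len, width]

    mask = AttentionLayer.create_attention_mask(from_shape, input_mask)
    print(mask.shape)                          # (2, 4, 4) - [B, F, T]
    # every query row of the second example is [1., 1., 0., 0.]; in call() those zeros
    # become a -10000.0 adder on the attention scores, i.e. near-zero softmax weight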

63 changes: 63 additions & 0 deletions bert/bert.py
@@ -0,0 +1,63 @@
# coding=utf-8
#
# created by kpe on 28.Mar.2019 at 12:33
#

from __future__ import absolute_import, division, print_function

from tensorflow.python import keras

import bert
from bert.embeddings import BertEmbeddingsLayer
from bert.transformer import TransformerEncoderLayer


class BertModelLayer(bert.Layer):
"""
BERT Model (arXiv:1810.04805).
See: https://arxiv.org/pdf/1810.04805.pdf
"""
class Params(BertEmbeddingsLayer.Params,
TransformerEncoderLayer.Params):
pass

# noinspection PyUnusedLocal
def _construct(self, params: Params):
self.embeddings_layer = None
self.encoders_layer = None

        self.supports_masking = True

def build(self, input_shape):
if isinstance(input_shape, list):
assert len(input_shape) == 2
input_ids_shape, token_type_ids_shape = input_shape
self.input_spec = [keras.layers.InputSpec(shape=input_ids_shape),
keras.layers.InputSpec(shape=token_type_ids_shape)]
else:
input_ids_shape = input_shape
self.input_spec = keras.layers.InputSpec(shape=input_ids_shape)

self.embeddings_layer = BertEmbeddingsLayer.from_params(
self.params,
name="embeddings"
)

# create all transformer encoder sub-layers
self.encoders_layer = TransformerEncoderLayer.from_params(
self.params,
name="encoder"
)

super(BertModelLayer, self).build(input_shape)

def call(self, inputs, mask=None, training=None):
if mask is None:
mask = self.embeddings_layer.compute_mask(inputs)

embedding_output = self.embeddings_layer(inputs, mask=mask, training=training)
output = self.encoders_layer(embedding_output, mask=mask, training=training)
return output # [B, seq_len, hidden_size]
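
A short sketch of the two input forms accepted by build()/call() above: a single input_ids tensor, or an [input_ids, token_type_ids] pair. How a missing token_type_ids is substituted is decided inside BertEmbeddingsLayer, which is not shown in this diff; bert_params is assumed to be configured as in the README sketch above.

    import tensorflow as tf
    from bert import BertModelLayer

    l_bert = BertModelLayer.from_params(bert_params, name="bert")   # bert_params as in the README sketch

    max_seq_len = 128
    input_ids      = tf.keras.layers.Input(shape=(max_seq_len,), dtype="int32", name="input_ids")
    token_type_ids = tf.keras.layers.Input(shape=(max_seq_len,), dtype="int32", name="token_type_ids")

    # single-input form
    seq_out = l_bert(input_ids)                      # [B, max_seq_len, hidden_size]

    # two-input form - matches the isinstance(input_shape, list) branch in build()
    seq_out = l_bert([input_ids, token_type_ids])    # [B, max_seq_len, hidden_size]
    model = tf.keras.Model(inputs=[input_ids, token_type_ids], outputs=seq_out)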


