Skip to content

Commit

Permalink
Params.out_layer_ndxs added to allow outputting internal layer activat…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
kpe committed Jun 27, 2019
1 parent 7c6976c commit 3afe346
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions bert/__init__.py
Expand Up @@ -8,3 +8,4 @@
from .layer import Layer
from .model import BertModelLayer

from .loader import StockBertConfig, load_stock_weights
13 changes: 10 additions & 3 deletions bert/transformer.py
Expand Up @@ -163,7 +163,8 @@ class TransformerEncoderLayer(Layer):
"""

class Params(SingleTransformerEncoderLayer.Params):
    """Configuration for the stacked transformer encoder.

    Extends the single-layer params with the stack depth and an optional
    selection of which layers' activations to return from ``call``.
    """
    # Number of stacked encoder layers; must be set by the user/config.
    # (The duplicated `num_layers = None` assignment from the original
    # was redundant and has been removed.)
    num_layers = None
    # Optional list of layer indices whose outputs should be returned,
    # e.g. [-1] for the last layer only. When None, only the final
    # layer's output is returned (the original behavior).
    out_layer_ndxs = None  # [-1]

def _construct(self, params: Params):
self.encoder_layers = []
Expand Down Expand Up @@ -195,8 +196,14 @@ def call(self, inputs, mask=None, training=None):
layer_output = encoder_layer(layer_input, mask=mask, training=training)
layer_outputs.append(layer_output)

# return the final layer only
final_output = layer_output
if self.params.out_layer_ndxs is None:
# return the final layer only
final_output = layer_output
else:
final_output = []
for ndx in self.params.out_layer_ndxs:
final_output.append(layer_outputs[ndx])
final_output = tuple(final_output)

return final_output

Expand Down

0 comments on commit 3afe346

Please sign in to comment.