diff --git a/bert/__init__.py b/bert/__init__.py
index 63f7062..d0e3123 100644
--- a/bert/__init__.py
+++ b/bert/__init__.py
@@ -8,3 +8,4 @@
 from .layer import Layer
 from .model import BertModelLayer
+from .loader import StockBertConfig, load_stock_weights
 
diff --git a/bert/transformer.py b/bert/transformer.py
index 97942d8..52322b1 100644
--- a/bert/transformer.py
+++ b/bert/transformer.py
@@ -163,7 +163,8 @@ class TransformerEncoderLayer(Layer):
     """
 
     class Params(SingleTransformerEncoderLayer.Params):
-        num_layers = None
+        num_layers     = None
+        out_layer_ndxs = None   # [-1]
 
     def _construct(self, params: Params):
         self.encoder_layers = []
@@ -195,8 +196,14 @@ def call(self, inputs, mask=None, training=None):
             layer_output = encoder_layer(layer_input, mask=mask, training=training)
             layer_outputs.append(layer_output)
 
-        # return the final layer only
-        final_output = layer_output
+        if self.params.out_layer_ndxs is None:
+            # return the final layer only
+            final_output = layer_output
+        else:
+            final_output = []
+            for ndx in self.params.out_layer_ndxs:
+                final_output.append(layer_outputs[ndx])
+            final_output = tuple(final_output)
 
         return final_output
 