Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,32 @@


def create_estimator(run_config, model_config):
# t2t expects these keys in run_config
run_config.data_parallelism = None
run_config.t2t_device_info = {"num_async_replicas": 1}

hparams = trainer_lib.create_hparams("transformer_base_single_gpu")

# SentimentIMDBCortex subclasses SentimentIMDB
problem = SentimentIMDBCortex(list(model_config["aggregates"]["reviews_vocab"]))
p_hparams = problem.get_hparams(hparams)
hparams.problem = problem
hparams.problem_hparams = p_hparams
hparams.problem_hparams = problem.get_hparams(hparams)

# metrics specific to the sentiment problem
problem.eval_metrics = lambda: [
metrics.Metrics.ACC_TOP5,
metrics.Metrics.ACC_PER_SEQ,
metrics.Metrics.NEG_LOG_PERPLEXITY,
]

# t2t expects this key
hparams.warm_start_from = None

# reduce memory load
hparams.num_hidden_layers = 2
hparams.hidden_size = 32
hparams.filter_size = 32
hparams.num_heads = 2

estimator = trainer_lib.create_estimator("transformer", hparams, run_config)
return estimator
# t2t expects these keys
hparams.warm_start_from = None
run_config.data_parallelism = None
run_config.t2t_device_info = {"num_async_replicas": 1}

return trainer_lib.create_estimator("transformer", hparams, run_config)


def transform_tensorflow(features, labels, model_config):
Expand Down
2 changes: 1 addition & 1 deletion examples/reviews/resources/apis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@

- kind: api
name: sentiment-t2t
model_name: t2t_transformer
model_name: transformer
compute:
replicas: 1
2 changes: 1 addition & 1 deletion examples/reviews/resources/models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
num_steps: 5000

- kind: model
name: t2t_transformer
name: transformer
type: classification
target_column: label_indexed
feature_columns:
Expand Down