Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nick dory patch 1 #1781

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions trax/models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Constructing T2T Models.

This directory contains T2T models, their hyperparameters, and a number
of common layers and hyperparameter settings to help construct new models.
Common building blocks are in `common_layers.py` and `common_attention.py`.
Common hyperparameters are in `common_hparams.py`. Models are imported in
`__init__.py`.

## Adding a new model.

To add a model to the built-in set, create a new file (see, e.g.,
`neural_gpu.py`), define your model class there as a subclass of `T2TModel`,
and decorate it with `registry.register_model`. Then import the new module in `__init__.py`.

Once imported, the model can be selected in the trainer binary (`t2t-trainer`)
with the `--model=model_name` flag.
153 changes: 80 additions & 73 deletions trax/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 The Trax Authors.
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,80 +13,87 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Models defined in trax."""
import gin
"""Models defined in T2T. Imports here force registration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from trax.models import atari_cnn
from trax.models import mlp
from trax.models import neural_gpu
from trax.models import resnet
from trax.models import rl
from trax.models import rnn
from trax.models import transformer
from trax.models.reformer import reformer
from trax.models.research import bert
from trax.models.research import configurable_transformer
from trax.models.research import hourglass
from trax.models.research import layerdrop_transformer
from trax.models.research import rezero
from trax.models.research import rse
from trax.models.research import terraformer
from trax.models.research import transformer2
import six

# pylint: disable=unused-import

# Ginify
def model_configure(*args, **kwargs):
  """Calls `gin.external_configurable` with the module set to 'trax.models'.

  All positional and keyword arguments are forwarded unchanged, except that
  any caller-supplied `module` keyword is overwritten with 'trax.models'.

  Returns:
    Whatever `gin.external_configurable` returns for the forwarded arguments.
  """
  kwargs.update(module='trax.models')
  return gin.external_configurable(*args, **kwargs)
from tensor2tensor.layers import modalities # pylint: disable=g-import-not-at-top
from tensor2tensor.models import basic
from tensor2tensor.models import bytenet
from tensor2tensor.models import distillation
from tensor2tensor.models import evolved_transformer
from tensor2tensor.models import image_transformer
from tensor2tensor.models import image_transformer_2d
from tensor2tensor.models import lstm
from tensor2tensor.models import neural_assistant
from tensor2tensor.models import neural_gpu
from tensor2tensor.models import resnet
from tensor2tensor.models import revnet
from tensor2tensor.models import shake_shake
from tensor2tensor.models import slicenet
from tensor2tensor.models import text_cnn
from tensor2tensor.models import transformer
from tensor2tensor.models import vanilla_gan
from tensor2tensor.models import xception
from tensor2tensor.models.neural_architecture_search import nas_model
from tensor2tensor.models.research import adafactor_experiments
from tensor2tensor.models.research import aligned
from tensor2tensor.models.research import autoencoders
from tensor2tensor.models.research import cycle_gan
from tensor2tensor.models.research import gene_expression
from tensor2tensor.models.research import neural_stack
from tensor2tensor.models.research import residual_shuffle_exchange
from tensor2tensor.models.research import rl
from tensor2tensor.models.research import shuffle_network
from tensor2tensor.models.research import similarity_transformer
from tensor2tensor.models.research import super_lm
from tensor2tensor.models.research import transformer_moe
from tensor2tensor.models.research import transformer_nat
from tensor2tensor.models.research import transformer_parallel
from tensor2tensor.models.research import transformer_revnet
from tensor2tensor.models.research import transformer_seq2edits
from tensor2tensor.models.research import transformer_sketch
from tensor2tensor.models.research import transformer_symshard
from tensor2tensor.models.research import transformer_vae
from tensor2tensor.models.research import universal_transformer
from tensor2tensor.models.video import basic_deterministic
from tensor2tensor.models.video import basic_recurrent
from tensor2tensor.models.video import basic_stochastic
from tensor2tensor.models.video import emily
from tensor2tensor.models.video import savp
from tensor2tensor.models.video import sv2p
from tensor2tensor.utils import contrib
from tensor2tensor.utils import registry

# The following models can't be imported under TF2
if not contrib.is_tf2:
# pylint: disable=g-import-not-at-top
from tensor2tensor.models.research import attention_lm
from tensor2tensor.models.research import attention_lm_moe
from tensor2tensor.models.research import glow
from tensor2tensor.models.research import lm_experiments
from tensor2tensor.models.research import moe_experiments
from tensor2tensor.models.research import multiquery_paper
from tensor2tensor.models import mtf_image_transformer
from tensor2tensor.models import mtf_resnet
from tensor2tensor.models import mtf_transformer
from tensor2tensor.models import mtf_transformer2
from tensor2tensor.models.research import vqa_attention
from tensor2tensor.models.research import vqa_recurrent_self_attention
from tensor2tensor.models.research import vqa_self_attention
from tensor2tensor.models.video import epva
from tensor2tensor.models.video import next_frame_glow
# pylint: enable=g-import-not-at-top

# pylint: disable=invalid-name
# Public model names exported by this package. Each name is bound to the
# result of passing the corresponding model constructor through
# `model_configure`, which forwards it to `gin.external_configurable` with
# module='trax.models'.
AtariCnn = model_configure(atari_cnn.AtariCnn)
AtariCnnBody = model_configure(atari_cnn.AtariCnnBody)
FrameStackMLP = model_configure(atari_cnn.FrameStackMLP)
BERT = model_configure(bert.BERT)
BERTClassifierHead = model_configure(bert.BERTClassifierHead)
BERTRegressionHead = model_configure(bert.BERTRegressionHead)
ConfigurableTerraformer = model_configure(terraformer.ConfigurableTerraformer)
ConfigurableTransformer = model_configure(
configurable_transformer.ConfigurableTransformer)
ConfigurableTransformerEncoder = model_configure(
configurable_transformer.ConfigurableTransformerEncoder)
ConfigurableTransformerLM = model_configure(
configurable_transformer.ConfigurableTransformerLM)
MLP = model_configure(mlp.MLP)
NeuralGPU = model_configure(neural_gpu.NeuralGPU)
Reformer = model_configure(reformer.Reformer)
ReformerLM = model_configure(reformer.ReformerLM)
ReformerShortenLM = model_configure(reformer.ReformerShortenLM)
Resnet50 = model_configure(resnet.Resnet50)
ReZeroTransformer = model_configure(
rezero.ReZeroTransformer)
ReZeroTransformerDecoder = model_configure(
rezero.ReZeroTransformerDecoder)
ReZeroTransformerEncoder = model_configure(
rezero.ReZeroTransformerEncoder)
ReZeroTransformerLM = model_configure(
rezero.ReZeroTransformerLM)
SkippingTransformerLM = model_configure(
layerdrop_transformer.SkippingTransformerLM)
LayerDropTransformerLM = model_configure(
layerdrop_transformer.LayerDropTransformerLM)
EveryOtherLayerDropTransformerLM = model_configure(
layerdrop_transformer.EveryOtherLayerDropTransformerLM)
Transformer = model_configure(transformer.Transformer)
TransformerDecoder = model_configure(transformer.TransformerDecoder)
TransformerEncoder = model_configure(transformer.TransformerEncoder)
TransformerLM = model_configure(transformer.TransformerLM)
Transformer2 = model_configure(
transformer2.Transformer2)
WideResnet = model_configure(resnet.WideResnet)
Policy = model_configure(rl.Policy)
PolicyAndValue = model_configure(rl.PolicyAndValue)
Value = model_configure(rl.Value)
Quality = model_configure(rl.Quality)
RNNLM = model_configure(rnn.RNNLM)
GRULM = model_configure(rnn.GRULM)
LSTMSeq2SeqAttn = model_configure(rnn.LSTMSeq2SeqAttn)
ResidualShuffleExchange = model_configure(rse.ResidualShuffleExchange)
HourglassLM = model_configure(hourglass.HourglassLM)
# pylint: disable=unused-import

# pylint: enable=unused-import


def model(name):
  """Looks up `name` via `registry.model` and returns the result.

  Args:
    name: the identifier passed straight through to `registry.model`.

  Returns:
    Whatever `registry.model(name)` returns.
  """
  registered = registry.model(name)
  return registered
58 changes: 58 additions & 0 deletions trax/models/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic models for testing simple tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow.compat.v1 as tf


@registry.register_model
class BasicFcRelu(t2t_model.T2TModel):
  """Stack of dense -> dropout -> ReLU layers over flattened inputs."""

  def body(self, features):
    """Builds the fully-connected stack.

    Args:
      features: dict; only `features["inputs"]` is read. It is reshaped using
        dimensions 1..3 of its shape, so it is presumably a 4-D tensor
        (TODO confirm against callers).

    Returns:
      The final activations with two singleton axes inserted at position 1.
    """
    hp = self.hparams
    inputs = features["inputs"]
    dims = common_layers.shape_list(inputs)
    # Collapse every non-batch dimension into a single feature axis.
    flat_size = dims[1] * dims[2] * dims[3]
    net = tf.reshape(inputs, [-1, flat_size])
    for layer_idx in range(hp.num_hidden_layers):
      net = tf.layers.dense(net, hp.hidden_size, name="layer_%d" % layer_idx)
      net = tf.nn.dropout(net, keep_prob=1.0 - hp.dropout)
      net = tf.nn.relu(net)
    # T2T expects a 4D output, so add two length-1 axes after the batch axis.
    with_one_axis = tf.expand_dims(net, axis=1)
    return tf.expand_dims(with_one_axis, axis=1)


@registry.register_hparams
def basic_fc_small():
  """Hyperparameters for a small fully connected model.

  Starts from `common_hparams.basic_params1()` and overrides the fields listed
  in `overrides` below.

  Returns:
    The modified hparams object.
  """
  overrides = (
      ("learning_rate", 0.1),
      ("batch_size", 128),
      ("hidden_size", 256),
      ("num_hidden_layers", 2),
      ("initializer", "uniform_unit_scaling"),
      ("initializer_gain", 1.0),
      ("weight_decay", 0.0),
      ("dropout", 0.0),
  )
  small = common_hparams.basic_params1()
  for field, value in overrides:
    setattr(small, field, value)
  return small
51 changes: 51 additions & 0 deletions trax/models/basic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic nets tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np

from tensor2tensor.data_generators import mnist # pylint: disable=unused-import
from tensor2tensor.models import basic
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator


class BasicTest(tf.test.TestCase):
  """Shape test for the models defined in `basic`."""

  def testBasicFcRelu(self):
    """Runs BasicFcRelu on one random input and checks the logits shape."""
    images = np.random.randint(256, size=(1, 28, 28, 1))
    labels = np.random.randint(10, size=(1, 1))
    hparams = trainer_lib.create_hparams(
        "basic_fc_small", problem_name="image_mnist", data_dir=".")
    with self.test_session() as sess:
      features = dict(
          inputs=tf.constant(images, dtype=tf.int32),
          targets=tf.constant(labels, dtype=tf.int32),
      )
      net = basic.BasicFcRelu(hparams, tf_estimator.ModeKeys.TRAIN)
      # The second element of the model's return value is not needed here.
      logits, _ = net(features)
      sess.run(tf.global_variables_initializer())
      logits_val = sess.run(logits)
    expected_shape = (1, 1, 1, 1, 10)
    self.assertEqual(logits_val.shape, expected_shape)


# Run the TensorFlow test runner when this file is executed as a script.
if __name__ == "__main__":
tf.test.main()
Loading