Derived from the <a href="https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html">Trax introduction tutorial</a> in the official Trax documentation.

In [1]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
import trax

In [3]:
# Keyword arguments for the Transformer, gathered in one place so the
# architecture is readable at a glance.
# NOTE(review): these values presumably mirror the configuration of the
# pretrained checkpoint used with this model -- confirm before changing any.
transformer_kwargs = dict(
    input_vocab_size=33300,
    d_model=512,
    d_ff=2048,
    n_heads=8,
    n_encoder_layers=6,
    n_decoder_layers=6,
    max_len=2048,
    mode='predict',
)

# Build the model object; no weights are loaded by this statement.
mod = trax.models.Transformer(**transformer_kwargs)

In [4]:
# Location of the gzipped, pickled checkpoint on Google Cloud Storage.
checkpoint_path = 'gs://trax-ml/models/translation/ende_wmt32k.pkl.gz'

# Initialise `mod` from the checkpoint file, requesting weights only.
mod.init_from_file(checkpoint_path, weights_only=True)



In [5]:
# Sentence to translate.
sent = ('How much wood would a woodchuck chuck if a woodchuck could chuck '
        'wood?')

# `trax.data.tokenize` consumes an iterator of strings, so wrap the single
# sentence in a one-element iterator and collect everything it yields.
token_arrays = list(
    trax.data.tokenize(
        iter([sent]),
        vocab_dir='gs://trax-ml/vocabs',
        vocab_file='ende_32k.subword',
    )
)

# Only one sentence went in, so keep the first (and only) result.
tokenized = token_arrays[0]

In [6]:
# Show the token ids produced for `sent` by the tokenizer.
tokenized

array([ 1670,   276,  7610,    98,    13, 13081,  7851,  1574,  7851,
        1574,   175,    13, 13081,  7851,  1574,   252,  7851,  1574,
        7610,   102])

In [7]:
# Add a leading batch dimension before handing the tokens to the sampler.
# The batched array gets its own name instead of overwriting `tokenized`, so
# re-running this cell does not keep stacking extra axes onto the input.
tokenized_batch = tokenized[None, :]

# Decode autoregressively with the model built and initialised earlier.
# temperature=0. is passed here; higher temp -> more diverse results.
# NOTE(review): per the Trax docs a temperature of 0.0 should correspond to
# picking the most likely token at each step -- confirm against the API docs.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    mod, tokenized_batch, temperature=0.)

In [8]:
# Detokenize
# Take the first (only) sequence of the batch and drop its last token, which
# is the EOS marker. The result is stored under a new name rather than
# overwriting `tokenized_translation`, so this cell can be re-run safely:
# rebinding the same name would index an already-stripped array on a second run.
translation_ids = tokenized_translation[0][:-1]

# Map the token ids back to text with the same subword vocabulary file that
# was used for tokenization.
translation = trax.data.detokenize(translation_ids,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
translation

'Wie viel Holz würde ein Waldchuck springen, wenn ein Waldschock Holz schütteln könnte.'