In [1]:

!pip install pip
!pip install trax

! pip install jax
! pip install jaxlib

Collecting trax
  Downloading trax-1.3.9-py2.py3-none-any.whl (629 kB)
[K     |████████████████████████████████| 629 kB 5.6 MB/s 
Collecting tensorflow-text
  Downloading tensorflow_text-2.6.0-cp37-cp37m-manylinux1_x86_64.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 44.6 MB/s 
Collecting funcsigs
  Downloading funcsigs-1.0.2-py2.py3-none-any.whl (17 kB)
Collecting t5
  Downloading t5-0.9.2-py3-none-any.whl (152 kB)
[K     |████████████████████████████████| 152 kB 53.8 MB/s 
Collecting transformers>=2.7.0
  Downloading transformers-4.10.2-py3-none-any.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 42.1 MB/s 
Collecting mesh-tensorflow[transformer]>=0.1.13
  Downloading mesh_tensorflow-0.1.19-py3-none-any.whl (366 kB)
[K     |████████████████████████████████| 366 kB 43.9 MB/s 
Collecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Collecting tfds-nightly
  Downloading tfds_nightly-4.4.0.dev202109220107-py3-none-any.whl (4.

In [2]:

import random
import numpy as np
import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training
from jax import jit
import jax.numpy as jnp

!pip list | grep trax

trax                          1.3.9


In [3]:
# we will use data set from 'opus'which is available at 'tensorflow Datasets(TFDS)'. Using trax.data.TFDS, we will get python generator function.

In [4]:
train_stream_fn = trax.data.TFDS('opus/medical' ,data_dir='./data/', keys=('en', 'de'), eval_holdout_size=0.01,train=True)

  "jax.host_id has been renamed to jax.process_index. This alias "
  "jax.host_count has been renamed to jax.process_count. This alias "


[1mDownloading and preparing dataset 34.29 MiB (download: 34.29 MiB, generated: 188.85 MiB, total: 223.13 MiB) to ./data/opus/medical/0.1.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/1108752 [00:00<?, ? examples/s]

Shuffling opus-train.tfrecord...:   0%|          | 0/1108752 [00:00<?, ? examples/s]

[1mDataset opus downloaded and prepared to ./data/opus/medical/0.1.0. Subsequent calls will reuse this data.[0m


In [5]:
eval_stream_fn = trax.data.TFDS('opus/medical', data_dir='./data/', keys=('en', 'de'), eval_holdout_size=0.01, train=False)

  "jax.host_id has been renamed to jax.process_index. This alias "
  "jax.host_count has been renamed to jax.process_count. This alias "


In [6]:
# In this we use subword representation so that instead of storing ("fear", "some", "fearsome"), we store only ("fear","some")

In [7]:
train_stream = train_stream_fn()
eval_stream = eval_stream_fn()

VOCAB_FILE = 'ende_32k.subword'
VOCAB_DIR = 'data/'

tokenized_train_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)
tokenized_eval_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)

# to append EOF at end of every sentence
EOF = 1
def add_eof(function):
  for (inputs,targets) in function:
      inputs_= list(inputs) + [EOF]
      targets_ = list(targets) + [EOF] 
      yield np.array(inputs_), np.array(targets_)
tokenized_train_stream = add_eof(tokenized_train_stream)
tokenized_eval_stream = add_eof(tokenized_eval_stream)


# New section

In [8]:
# Using filter by length to restrict number of tokens in a sentence by 256.
tokenized_train_stream = trax.data.FilterByLength(max_length=256,length_keys=[0,1])(tokenized_train_stream)
tokenized_eval_stream = trax.data.FilterByLength(max_length=256,length_keys=[0,1])(tokenized_eval_stream)

In [9]:
def tokenize(string, vocab_file=None, vocab_dir=None):
    EOF = 1
    inputs =  next(trax.data.tokenize(iter([string]),vocab_file=vocab_file, vocab_dir=vocab_dir))
    inputs = list(inputs) + [EOF]
    batch = np.reshape(np.array(inputs), [1, -1])
    return batch

print("tokenize-->treatment: ", tokenize('treatment', vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))

tokenize-->treatment:  [[2248    1]]


In [10]:
def detokenize(tokens, vocab_file=None, vocab_dir=None):
  EOF=1
  index = tokens.index(EOF)
  if EOF in tokens:
    tokens = tokens[:index]
  detokenize_string = trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)
  return detokenize_string
print('str([17332 140 1]):', detokenize([2248, 1], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))

str([17332 140 1]): treatment


In [11]:
# Bucketing --> Suppose you have batch of sentences and if one sentence length is min and other sentence length is max so by padding,
# we just waste lot of memory. So, the way to cope with this problem is bucketing where we arrange a batch of sentences in such a way  
# similar length sentences are batched together
boundaries = [8, 16, 32, 128, 256, 512]
batch_size = [256, 128, 64, 32, 16, 8, 4, 2]
tokenized_train_stream = trax.data.BucketByLength(boundaries, batch_size, [0,1])(tokenized_train_stream)
tokenized_eval_stream =trax.data.BucketByLength(boundaries, batch_size, [0,1])(tokenized_eval_stream)

# Adding mask
tokenized_train_stream= trax.data.AddLossWeights(0)(tokenized_train_stream)
tokenized_eval_stream= trax.data.AddLossWeights(0)(tokenized_eval_stream)

In [12]:
input,target,mask= next(tokenized_train_stream)
print("input(English) shape-->",str(input.shape), "  input(English) datatype-->",str(type(input)))
print("target(German) shape-->",str(target.shape), "  target(German) datatype-->",str(type(target)))
print("mask shape-->",str(mask.shape), "  mask datatype-->",str(type(mask)))

input(English) shape--> (32, 128)   input(English) datatype--> <class 'numpy.ndarray'>
target(German) shape--> (32, 128)   target(German) datatype--> <class 'numpy.ndarray'>
mask shape--> (32, 128)   mask datatype--> <class 'numpy.ndarray'>


In [13]:
# input encoder function which take input tokens and pass it through embedding layer and feed the resulting output to 'n' LSTM layer
def input_encoder_fn(input, doe, n):
  encoder = tl.Serial(tl.Embedding(input,doe), [tl.LSTM(doe) for i in range(n)])
  return encoder

In [14]:
# Pre decoder layer which take target token as input and pass it through 'ShiftRight',e.g. [8, 34, 12] shifted right is [0, 8, 34, 12]
# and during training, this shift also allows the target tokens to be passed as input to do teacher forcing.this then passsed to the 
# embedding layer which takes no. of word in a vocab(input) and no. of elements in word embedding(doe). and then resulting output passed through 
# LSTM layer of size depth of embedding(doe)    
def pre_attention_decoder_fn(mode, input, doe):
  decoder = tl.Serial(tl.ShiftRight(mode=mode),tl.Embedding(input, doe), tl.LSTM(doe))
  return decoder

In [15]:
# Preparing attention input which take activations of input encoder and pre attention decoder and padded input token(pit)

def prepare_attention_input(encoder_activations, decoder_activations, pit):
  keys = encoder_activations
  values = encoder_activations
  queries = decoder_activations
  mask = jnp.where(pit!=0, x=1, y=0)
  mask = fastnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
  mask = mask + fastnp.zeros((1, 1, decoder_activations.shape[1], 1))
  return queries, keys, values, mask  

In [16]:
# Implement sequence to sequence model with attention. In this we create serial network which consist of --> 1-create copies as[input token, 
# target token, input token, target token], 2-parallel branch so that we can give input token to input_encoder_fn and target token to 
# pre_attention_decoder_fn, 3- prepare_attention_input for generating queries, keys, values and mask, 4- then giving to AttentionQKV layer
# to compute scaled dot product attention and outputs the attention weights and mask, 5-inser residual layer to  to add the output of 
# AttentionQKV with the queries input, 6-drop mask, 7- feed the attention weighted output to the LSTM decoder and determine the probabilities
# of each subword in the vocabulary using dense layer and logsoftmax layer.
def NMTA(input_vocab_size=33000, target_vocab_size=33000,doe=1024,n_encoders=2,n_decoders=2,n_heads=1, dropout=0, mode='train'):
  encoder= input_encoder_fn(input_vocab_size,doe,n_encoders)
  pre_attention_decoder= pre_attention_decoder_fn(mode, target_vocab_size, doe)
  model= tl.Serial(tl.Select([1,0,1,0]), tl.Parallel(encoder, pre_attention_decoder),
                   tl.Fn('prepare_attention_input',prepare_attention_input, n_out=4),
                   tl.Residual(tl.AttentionQKV(doe, n_heads=n_heads, dropout=dropout, mode=mode)),
                   tl.Select([0, 2]),
                   [tl.LSTM(doe) for _ in range(n_decoders)],
                   tl.Dense(target_vocab_size),
                   tl.LogSoftmax() )
  return model

In [17]:
print(NMTA())

Serial_in2_out2[
  Select[1,0,1,0]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33000_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33000_1024
      LSTM_1024
    ]
  ]
  prepare_attention_input_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33000
  LogSoftmax
]


In [18]:
# Training: consist of three part, first is TrainTask which is used to train model from tokenized_train_stream
# Second is EvalTask which is used to train model from tokenized_eval_stream, Third is Loop which is used to run epochs
train = training.TrainTask(labeled_data= tokenized_train_stream, loss_layer= tl.CrossEntropyLoss(),optimizer= trax.optimizers.Adam(.01),
                           lr_schedule= trax.lr.warmup_and_rsqrt_decay(1000, .01),n_steps_per_checkpoint= 10,)

eval = training.EvalTask(labeled_data=tokenized_eval_stream,metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],)
output_dir = 'output/'

# remove old model if it exists. restarts training.
!rm -f ~/output_dir/model.pkl.gz  

training_loop = training.Loop(NMTA(mode='train'), train, eval_tasks=[eval], output_dir=output_dir)
training_loop.run(20)

  "jax.host_id has been renamed to jax.process_index. This alias "
  "jax.host_count has been renamed to jax.process_count. This alias "



Step      1: Total number of trainable weights: 147570920
Step      1: Ran 1 train steps in 147.30 secs
Step      1: train CrossEntropyLoss |  10.28426838
Step      1: eval  CrossEntropyLoss |  10.36134434
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 585.46 secs
Step     10: train CrossEntropyLoss |  10.16161346
Step     10: eval  CrossEntropyLoss |  9.73571396
Step     10: eval          Accuracy |  0.14457831

Step     20: Ran 10 train steps in 588.25 secs
Step     20: train CrossEntropyLoss |  9.23410988
Step     20: eval  CrossEntropyLoss |  8.40290737
Step     20: eval          Accuracy |  0.13624841


In [26]:
model = NMTA(mode='eval')
model.init_from_file("/content/output/model.pkl.gz", weights_only=True)
model = tl.Accelerate(model)
print(model)

Accelerate_in2_out2[
  Serial_in2_out2[
    Select[1,0,1,0]_in2_out4
    Parallel_in2_out2[
      Serial[
        Embedding_33000_1024
        LSTM_1024
        LSTM_1024
      ]
      Serial[
        Serial[
          ShiftRight(1)
        ]
        Embedding_33000_1024
        LSTM_1024
      ]
    ]
    prepare_attention_input_in3_out4
    Serial_in4_out2[
      Branch_in4_out3[
        None
        Serial_in4_out2[
          _in4_out4
          Serial_in4_out2[
            Parallel_in3_out3[
              Dense_1024
              Dense_1024
              Dense_1024
            ]
            PureAttention_in4_out2
            Dense_1024
          ]
          _in2_out2
        ]
      ]
      Add_in2
    ]
    Select[0,2]_in3_out2
    LSTM_1024
    LSTM_1024
    Dense_33000
    LogSoftmax
  ]
]


In [27]:
# next symbol: input is model(NMTA),input_tokens(np.ndarray 1 x n_tokens) which is tokenized representation of the input sentence, 
# cur_output_tokens (list) which is tokenized representation of previously translated words, temperature for sampling from distribution
def next_symbol(NMTA,input_tokens, cur_output_tokens, temperature):
  token_length = len(cur_output_tokens)
  padded_length = 2**int(np.ceil(np.log2(token_length + 1))) 
  padded = cur_output_tokens + [0] * (padded_length - token_length)
  padded_with_batch = np.expand_dims(padded, axis=0)
  output, _ = NMTA((input_tokens, padded_with_batch))
  log_probs = output[0, token_length, :]
  symbol = int(tl.logsoftmax_sample(log_probs, temperature))
  return symbol, float(log_probs[symbol])

In [35]:
def sampling_decode(input_sentence, NMTA = None, temperature=0.0, vocab_file=None, vocab_dir=None):
  input_tokens = tokenize(input_sentence,vocab_file,vocab_dir)
  cur_output_tokens = []
  cur_output = 0
  EOF = 1
  while cur_output != EOF:
    cur_output, log_prob = next_symbol(NMTA, input_tokens, cur_output_tokens, temperature)
    cur_output_tokens.append(cur_output)
  sentence = detokenize(cur_output_tokens, vocab_file, vocab_dir)
  return cur_output_tokens, log_prob, sentence

In [36]:
def greedy_decode_test(sentence, NMTA=None, vocab_file=None, vocab_dir=None):
  _,_, translated_sentence = sampling_decode(sentence, NMTA, vocab_file=vocab_file, vocab_dir=vocab_dir)
  print("English: ", sentence)
  print("German: ", translated_sentence)
  return translated_sentence

In [None]:
your_sentence = 'This is my assignment.'
greedy_decode_test(your_sentence, model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)