In [None]:
# in this example we will look at 
# 1. the training task 
# 2. the evaluation task 
# 3. the training loop 
!pip install trax
import os
import shutil

import trax
import trax.layers as tl
from trax.supervised import training


In [None]:
# Stream raw (text, label) pairs from TensorFlow Datasets' IMDB reviews:
# train=True reads the training split, train=False the held-out split.
train_stream = trax.data.TFDS("imdb_reviews", keys=("text", "label"), train=True)()
eval_stream = trax.data.TFDS("imdb_reviews", keys=("text", "label"), train=False)()

# Longest tokenized review (in subword tokens) that is kept; longer reviews
# are dropped. Also used as the last bucket boundary below.
MAX_LENGTH = 2048

# Preprocessing pipeline, applied to each stream independently:
#   1. tokenize the review text (tuple element 0) with the 8k subword vocab,
#   2. shuffle examples,
#   3. drop reviews longer than MAX_LENGTH tokens,
#   4. batch reviews of similar length together; shorter sequences get larger
#      batches (512*32 == 128*128 == 32*512 == 8*2048 tokens per batch at most),
#   5. append a loss-weight array to each batch.
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file="en_8k.subword", keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=MAX_LENGTH, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[32, 128, 512, MAX_LENGTH],
                             batch_sizes=[512, 128, 32, 8, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights(),
)

train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)

In [None]:
import trax.data as data

# Bag-of-embeddings classifier: embed every subword token, average the token
# vectors along axis 1, project to two class scores and convert them to
# log-probabilities (class 1 = positive, class 0 = negative).
# NOTE(review): axis 1 is assumed to be the sequence axis of (batch, seq)
# inputs; the mean therefore also includes any padded positions in a batch.
vocab_size = data.vocab_size(vocab_file="en_8k.subword")
classifier_layers = [
    tl.Embedding(vocab_size, d_feature=256),  # token id -> 256-dim vector
    tl.Mean(axis=1),                          # average the token vectors
    tl.Dense(2),                              # one score per class
    tl.LogSoftmax(),                          # scores -> log-probabilities
]
sentiment_analysis_model = tl.Serial(*classifier_layers)
print(sentiment_analysis_model)

In [None]:

# Wire the model, the data streams and the optimiser into trax's supervised
# training loop.

# Training task: batches from the train split, cross-entropy loss, Adam with
# learning rate 0.01, and a checkpoint every 200 steps.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=200,
)

# Evaluation task: cross-entropy and accuracy measured on 20 batches of the
# held-out stream per evaluation.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20,
)

# Checkpoints and logs go to ~/output_dir/. Remove anything left there by an
# earlier run first (presumably so training.Loop does not resume from a stale
# checkpoint). shutil.rmtree is the portable, shell-free form of `rm -rf`.
output_dir = os.path.expanduser("~/output_dir/")
shutil.rmtree(output_dir, ignore_errors=True)
training_loop = training.Loop(sentiment_analysis_model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
# Run 2000 training steps.
training_loop.run(2000)

In [None]:
import numpy

# Classify one hand-written review with the freshly trained model.
example_input = "this was a decent film that I enjoyed watching. It passed some spare time nicely"

# data.tokenize consumes an iterator of strings and yields one array of
# subword ids per string; take the single result.
review_token_ids = next(iter(data.tokenize(iter([example_input]), vocab_file="en_8k.subword")))
# Prepend a batch axis of size 1: the model is applied to batches.
batched_ids = review_token_ids[None, :]
# The model ends in LogSoftmax, so it returns per-class log-probabilities.
class_log_probs = sentiment_analysis_model(batched_ids)
# exp() turns log-probabilities back into probabilities; argmax over the
# single example's row picks the predicted class (1 = positive, 0 = negative).
class_probs = numpy.exp(class_log_probs)
sentiment = numpy.argmax(class_probs[0])
polarity = "Positive" if sentiment else "Negative"
print('Input review:\n"{}"\nThe sentiment is: {}'.format(example_input, polarity))

In [None]:
# Reload the checkpoint the training loop wrote to ~/output_dir/ and resume.
# The "~" must be expanded here. trax looks the file up without expanding it,
# and load_checkpoint skips a missing file without printing anything, so a
# literal "~/output_dir/" would silently load nothing.
checkpoint_dir = os.path.expanduser("~/output_dir/")
training_loop.load_checkpoint(directory=checkpoint_dir, filename="model.pkl.gz")
# Continue from the restored step (2000) for another 200 training steps.
training_loop.run(200)

  "jax.host_id has been renamed to jax.process_index. This alias "



Step   2200: Ran 200 train steps in 11.42 secs
Step   2200: train CrossEntropyLoss |  0.28224170
Step   2200: eval  CrossEntropyLoss |  0.38194130
Step   2200: eval          Accuracy |  0.84218750


In [None]:
# Loading a pretrained model: rebuild the same architecture that was trained
# (same layers, same sizes) so the saved weights can be loaded into it.
new_model_layers = [
    tl.Embedding(data.vocab_size(vocab_file='en_8k.subword'), d_feature=256),
    tl.Mean(axis=1),
    tl.Dense(2),
    tl.LogSoftmax(),
]
new_model = tl.Serial(*new_model_layers)

import numpy
def parse_sentiment(text, new_model, vocab_file="en_8k.subword"):
  """Predict the sentiment class of a single review.

  Args:
    text: the review as a plain string.
    new_model: a trax model that maps a batch of token-id sequences to
      per-class log-probabilities (e.g. a Serial ending in LogSoftmax).
    vocab_file: subword vocabulary used for tokenization; must be the one the
      model was trained with. Defaults to "en_8k.subword".

  Returns:
    Index of the most likely class as a numpy integer: 1 for positive,
    0 for negative.
  """
  # data.tokenize yields one array of token ids per input string.
  token_ids = next(iter(data.tokenize(iter([text]), vocab_file=vocab_file)))
  # Add a leading batch axis of size 1; the model is applied to batches.
  log_probs = new_model(token_ids[None, :])
  # exp() is monotonic, so argmax over log-probabilities selects the same class
  # as argmax over probabilities; no need to exponentiate first.
  return numpy.argmax(log_probs[0])
# Initialise the rebuilt model with the trained weights from the checkpoint.
# The path is resolved from the home directory rather than hard-coding /root,
# which only matches when the notebook runs as root (e.g. on Colab).
checkpoint_file = os.path.join(os.path.expanduser("~/output_dir/"), "model.pkl.gz")
new_model.init_from_file(file_name=checkpoint_file, weights_only=True)
print("the sentiment is : ", parse_sentiment("this was a decent film that I enjoyed watching. It passed some spare time nicely", new_model))

the sentiemt is :  1
