<a href="https://colab.research.google.com/github/gecco-evojax/evojax/blob/main/notebooks/Seq2SeqTask.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Shell out to `nvidia-smi` to show which GPU (if any) backs this runtime.
!nvidia-smi

Fri Jan  7 21:00:37 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 495.46       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    36W / 300W |      0MiB / 16160MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [2]:
# @title Install Packages

from IPython.display import clear_output

# `%pip` (rather than `!pip`) installs into the environment of the running
# kernel, so the package is importable from the cells below.
%pip install git+https://github.com/gecco-evojax/evojax.git@main

# Hide the verbose installer log. NOTE: this also hides installation errors;
# comment the call out when debugging a failed install.
clear_output()

In [3]:
# @title Import Libraries

import time
import numpy as np
from IPython.display import Image

import jax
import jax.numpy as jnp
from jax import random

from evojax import Trainer
from evojax.algo import PGPE
from evojax.policy import Seq2seqPolicy
from evojax.task.seq2seq import Seq2seqTask
from evojax.util import create_logger

# Directory handed to both the logger here and the Trainer further down.
log_dir = '/tmp/seq2seq_task'
logger = create_logger(log_dir=log_dir, name='Seq2SeqDemo')

# Record which devices JAX can see, so the log shows the backend in use.
devices_msg = 'jax.devices(): {}'.format(jax.devices())
logger.info(devices_msg)

absl: 2022-01-07 21:00:45,844 [INFO] Starting the local TPU driver.
absl: 2022-01-07 21:00:45,845 [INFO] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
absl: 2022-01-07 21:00:46,702 [INFO] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
Seq2SeqDemo: 2022-01-07 21:00:46,703 [INFO] jax.devices(): [GpuDevice(id=0, process_index=0)]


# Description

**About the Task**  
This task is based on the seq2seq example from this [repo](https://github.com/google/flax/tree/fb4f2a455d4cae383f11d77707bf0d0f18a45e70/examples/seq2seq). In this task, the agent sees an embedding query that stands for a simple addition, and is required to output an embedding that represents the result.
For example, if the agent sees "012+345", it should output "357". Queries are sampled at random, subject to the maximum number of digits, during both training and testing.  

**About the Policy**  
The policy network is a sequence-to-sequence model built from two LSTMs, one serving as the encoder and the other as the decoder. During training, we do NOT use teacher forcing, so the decoder LSTM has to rely on its own previous outputs when making subsequent predictions.

# Learn to solve the task

In [4]:
# @title Set hyper-parameters

# Every line below carries a Colab `# @param` form annotation so that the
# value can be edited from the form view of the cell.

# Passed to Seq2seqTask as `max_len_query_digit`.
max_num_digits = 3  # @param
# Population size of the PGPE solver.
pop_size = 1024  # @param
# Batch size of both the training and the test task.
batch_size = 128  # @param
# Hidden size of the Seq2seqPolicy.
hidden_size = 128  # @param
# PGPE learning rate for the distribution center.
center_lr = 0.01  # @param
# PGPE learning rate for the standard deviation.
stdev_lr = 0.03  # @param
# Initial standard deviation of the PGPE solver.
init_stdev = 0.05  # @param
# Number of training iterations given to the Trainer.
max_iters = 50000  # @param
# Random seed shared by the solver and the trainer.
seed = 42  # @param

In [5]:
# @title Training
# @markdown In training the scores are cross-entropy losses, and in tests the scores are accuracies.

# The policy whose flat parameter vector the solver optimizes.
policy = Seq2seqPolicy(logger=logger, hidden_size=hidden_size)

# The training and test tasks share every setting except the `test` flag.
task_kwargs = dict(max_len_query_digit=max_num_digits, batch_size=batch_size)
train_task = Seq2seqTask(test=False, **task_kwargs)
test_task = Seq2seqTask(test=True, **task_kwargs)

# PGPE solver sized to the number of policy parameters.
solver_kwargs = dict(
    pop_size=pop_size,
    param_size=policy.num_params,
    optimizer='adam',
    center_learning_rate=center_lr,
    stdev_learning_rate=stdev_lr,
    init_stdev=init_stdev,
    logger=logger,
    seed=seed,
)
solver = PGPE(**solver_kwargs)

# The trainer ties policy, solver and tasks together; it logs every 100
# iterations and runs a test every 500 iterations.
trainer_kwargs = dict(
    policy=policy,
    solver=solver,
    train_task=train_task,
    test_task=test_task,
    max_iter=max_iters,
    log_interval=100,
    test_interval=500,
    n_repeats=1,
    n_evaluations=1,
    seed=seed,
    log_dir=log_dir,
    logger=logger,
)
trainer = Trainer(**trainer_kwargs)

# The return value is discarded on purpose; `_ =` keeps it out of the cell output.
_ = trainer.run()

Seq2SeqDemo: 2022-01-07 21:00:57,082 [INFO] Seq2seqPolicy.num_params = 149391
Seq2SeqDemo: 2022-01-07 21:00:57,285 [INFO] Start to train for 50000 iterations.
Seq2SeqDemo: 2022-01-07 21:02:01,203 [INFO] Iter=100, size=1024, max=-1.3815, avg=-1.3996, min=-1.4349, std=0.0078
Seq2SeqDemo: 2022-01-07 21:02:06,950 [INFO] Iter=200, size=1024, max=-1.3683, avg=-1.3846, min=-1.4169, std=0.0070
Seq2SeqDemo: 2022-01-07 21:02:12,699 [INFO] Iter=300, size=1024, max=-1.3287, avg=-1.3544, min=-1.4088, std=0.0110
Seq2SeqDemo: 2022-01-07 21:02:18,446 [INFO] Iter=400, size=1024, max=-1.2985, avg=-1.3233, min=-1.3706, std=0.0110
Seq2SeqDemo: 2022-01-07 21:02:24,190 [INFO] Iter=500, size=1024, max=-1.2313, avg=-1.2571, min=-1.3164, std=0.0118
Seq2SeqDemo: 2022-01-07 21:03:18,340 [INFO] [TEST] Iter=500, #tests=1, max=0.0078 avg=0.0078, min=0.0078, std=0.0000
Seq2SeqDemo: 2022-01-07 21:03:24,205 [INFO] Iter=600, size=1024, max=-1.1912, avg=-1.2233, min=-1.2916, std=0.0129
Seq2SeqDemo: 2022-01-07 21:03:29,9

In [6]:
# @title Visualize the trained policy

# JIT-compile the policy's action function once for the evaluation below.
act_fn = jax.jit(policy.get_actions)

# Reset the test task with a single PRNG key carrying a leading axis.
reset_key = random.PRNGKey(0)[None, :]
state = test_task.reset(reset_key)

# Repeat the best parameters found by the solver once per leading entry of
# the observations, then run the policy on them.
tiled_params = jnp.repeat(
    solver.best_params[None, :], state.obs.shape[0], axis=0)
logits = act_fn(state.obs, tiled_params)

# Greedy decoding: take the arg-max class and re-encode it as a one-hot vector.
preds = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1])

decoded_problems = test_task.decode_embeddings(state.obs[0])
decoded_answers = test_task.decode_embeddings(state.labels[0])
decoded_preds = test_task.decode_embeddings(preds[0])

for i in range(batch_size):
  predicted = decoded_preds[i]
  expected = decoded_answers[i]
  # The first character of the decoded answer is skipped in the comparison
  # (presumably a leading '=' marker -- TODO confirm against Seq2seqTask).
  correct_pred = predicted == expected[1:]
  if correct_pred:
    verdict, shown_answer = 'CORRECT', ''
  else:
    # Show the expected answer only when the prediction is wrong.
    verdict, shown_answer = 'INCORRECT', expected
  print('{} ={} ({}) {}'.format(
      decoded_problems[i], predicted, verdict, shown_answer))

43+033 =076 (CORRECT) 
00+579 =579 (CORRECT) 
39+694 =733 (CORRECT) 
22+230 =252 (CORRECT) 
01+428 =429 (CORRECT) 
35+116 =151 (CORRECT) 
33+283 =316 (CORRECT) 
30+772 =802 (CORRECT) 
89+845 =934 (CORRECT) 
22+150 =172 (CORRECT) 
63+387 =450 (CORRECT) 
16+128 =144 (CORRECT) 
13+044 =057 (CORRECT) 
85+208 =293 (CORRECT) 
26+660 =686 (CORRECT) 
94+766 =860 (CORRECT) 
66+705 =771 (CORRECT) 
86+744 =830 (CORRECT) 
23+647 =670 (CORRECT) 
12+155 =167 (CORRECT) 
36+615 =651 (CORRECT) 
98+141 =239 (CORRECT) 
71+097 =168 (CORRECT) 
05+698 =703 (CORRECT) 
63+353 =416 (CORRECT) 
15+665 =680 (CORRECT) 
13+910 =923 (CORRECT) 
28+760 =788 (CORRECT) 
25+018 =043 (CORRECT) 
49+244 =293 (CORRECT) 
08+610 =618 (CORRECT) 
70+509 =579 (CORRECT) 
68+403 =471 (CORRECT) 
18+342 =360 (CORRECT) 
26+450 =476 (CORRECT) 
34+667 =701 (CORRECT) 
22+723 =745 (CORRECT) 
78+976 =054 (CORRECT) 
32+057 =089 (CORRECT) 
01+524 =525 (CORRECT) 
41+711 =752 (CORRECT) 
33+180 =213 (CORRECT) 
78+493 =571 (CORRECT) 
71+617 =688