# Autoregressive Abstracter: Card Sorting

This is a second development version of `autoregressive_abstracter_card_sorting.ipynb`. In the first version of the notebook, the transformer successfully completed the task, but the $[\text{encoder}] \to [\text{abstracter}] \to [\text{decoder}]$ models failed. Here, we try a few more configurations of the abstracter models. In particular:

1) Standard encoder with self-attention. Abstracter with self-attention and relational cross-attention ($Q=E, K=E, V=A$). Decoder with causal self-attention and relational cross-attention ($Q=A, K=A, V=E$).

2) Standard encoder with self-attention. Abstracter with self-attention and relational or symbolic cross-attention. Decoder with causal self-attention and standard cross-attention ($Q=D, K=A, V=A$), as in the first version of the notebook. The final prediction is then generated via $t_j \sim \text{Multinomial}(\text{Softmax}(D_j^\top a_1, \ldots, D_j^\top a_L))$.
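As a rough illustration of the prediction rule in configuration 2, the following NumPy sketch scores every abstracter output $a_i$ against each decoder state $D_j$ and samples a position $t_j$ from the resulting softmax. It is a pointer-network-style reading of the formula, not this notebook's implementation: the toy shapes, the random inputs, the example hand, and the assumption that $t_j$ indexes an input position whose token becomes the prediction are all our own.

```python
import numpy as np

def pointer_distribution(D, A):
    """Row j is Softmax(D_j^T a_1, ..., D_j^T a_L).

    D: (T, d) decoder states, A: (L, d) abstracter outputs -> (T, L) probabilities.
    """
    scores = D @ A.T                                      # scores[j, i] = D_j^T a_i
    scores = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    probs = np.exp(scores)
    return probs / probs.sum(axis=-1, keepdims=True)

def sample_positions(probs, rng):
    """Draw t_j ~ Multinomial(probs[j]) with a single trial, i.e. one index per row."""
    return np.array([rng.choice(len(p), p=p) for p in probs])

rng = np.random.default_rng(0)
D = rng.normal(size=(8, 16))        # toy: 8 decoder steps, d = 16
A = rng.normal(size=(9, 16))        # toy: 9 abstracter outputs (one per input token)
source = np.array([52, 40, 3, 17, 51, 0, 22, 9, 53])  # hypothetical hand with BEGIN/END tokens

probs = pointer_distribution(D, A)
t = sample_positions(probs, rng)
predicted_tokens = source[t]        # assumption: copy the token at the chosen input position
```

Because the output space under this reading is the set of input positions, such a head can only emit tokens that appear in the input, which fits sorting, where the output is a permutation of the input.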

## Set Up

In [1]:
import pydealer
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import sklearn.metrics

from hand2hand import Cards
import utils



In [2]:
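# no quotes: %env keeps quotation marks as part of the variable name, which is
# likely why wandb reports below that it failed to detect the notebook name
%env WANDB_NOTEBOOK_NAME=autoregressive_abstracter_hand_sorting_new_ideas.ipynb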

import wandb
wandb.login()

env: "WANDB_NOTEBOOK_NAME"="autoregressive_abstracter_hand_sorting_new_ideas.ipynb"


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mawni00[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
def create_callbacks(monitor='loss', log_gradients=False, save_model=True, log_weights=True,
                     train_ds=None, val_ds=None, checkpoint_path='models/'):
    # `checkpoint_path` is a placeholder: WandbModelCheckpoint requires a filepath argument.
    # `log_gradients`, `save_model`, `log_weights`, `train_ds`, and `val_ds` are only used
    # by the commented-out WandbCallback below.
    callbacks = [
#         tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, mode='auto', restore_best_weights=True),
#         tf.keras.callbacks.ReduceLROnPlateau( monitor='val_loss', factor=0.1, patience=5, verbose=1, mode='auto'),
        wandb.keras.WandbMetricsLogger(log_freq='epoch'),
        wandb.keras.WandbModelCheckpoint(filepath=checkpoint_path, monitor=monitor, mode='auto', save_freq='epoch'),
#         wandb.keras.WandbCallback(
#             monitor=monitor, log_weights=log_weights, log_gradients=log_gradients, save_model=save_model, save_graph=True,
#             training_data=train_ds, validation_data=val_ds,
#             labels=class_names, predictions=64, compute_flops=True)
        ]
    return callbacks

# metrics = [
#         tf.keras.metrics.BinaryAccuracy(name='acc'),
#         tf.keras.metrics.Precision(class_id=1, name='precision'),
#         tf.keras.metrics.Recall(class_id=1, name='recall'),
#         tf.keras.metrics.AUC(curve='ROC', multi_label=True, name='auc')
#         ]

# loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
# opt = tf.keras.optimizers.Adam()

In [4]:
import tensorflow as tf
import seq2seq_transformer
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, Model

## Define New Layers (New Variants of Cross-Attention for Decoder)

Define a new variant of the decoder with $Q=A, K=A, V=E$ cross-attention.

In [30]:
# implement a general cross-attention layer which supports all the different kinds of attention

from seq2seq_transformer import BaseAttention

class ContextualCrossAttention(BaseAttention):
    """A general cross-attention layer with configurable queries, keys, and values.

    cross_attention_type:
        'std_encoder_decoder': Q = input_seq,   K = context_seq, V = context_seq
        'symbolic':            Q = input_seq,   K = context_seq, V = input_seq
        'relational':          Q = context_seq, K = context_seq, V = input_seq

    In the 'symbolic' and 'relational' modes, keys and values come from different
    sequences, so `input_seq` and `context_seq` must have the same length.
    """

    def __init__(self, cross_attention_type='std_encoder_decoder', **kwargs):

        super(ContextualCrossAttention, self).__init__(**kwargs)

        if cross_attention_type in ('std_encoder_decoder', 'symbolic', 'relational'):
            self.cross_attention_type = cross_attention_type
        else:
            raise ValueError(f'`cross_attention_type` {cross_attention_type} is invalid')

    def call(self, input_seq, context_seq):

        if self.cross_attention_type == 'std_encoder_decoder':
            # standard encoder-decoder cross-attention of transformers
            attn_output, attn_scores = self.mha(
                query=input_seq,
                key=context_seq,
                value=context_seq,
                return_attention_scores=True)

            x = self.add([input_seq, attn_output])

            x = self.layernorm(x)

        elif self.cross_attention_type == 'symbolic':
            # 'symbolic' cross-attention.
            # input_seq is learned input-independent symbols
            attn_output, attn_scores = self.mha(
                query=input_seq,
                key=context_seq,
                value=input_seq,
                return_attention_scores=True)

            x = self.add([input_seq, attn_output]) # TODO: think about this. should we keep this skip connection?

            x = self.layernorm(x)

        elif self.cross_attention_type == 'relational':
            # 'relational' cross-attention.
            # queries and keys both come from the context sequence, thus their inner product computes relations
            # (attn_output has the length of context_seq, so the residual add requires equal lengths)
            attn_output, attn_scores = self.mha(
                query=context_seq,
                key=context_seq,
                value=input_seq,
                return_attention_scores=True)

            x = self.add([input_seq, attn_output]) # TODO: think about this. should we keep this skip connection?

            x = self.layernorm(x)

        else:
            raise ValueError('unexpected `cross_attention_type`')

        # Cache the attention scores for plotting later.
        self.last_attn_scores = attn_scores

        return x
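
To make the three `cross_attention_type` options concrete, here is a single-head NumPy sketch of the same query/key/value wiring. It drops the learned projections, multiple heads, dropout, residual connection, and layer norm of the Keras `MultiHeadAttention`-based layer above, so it only illustrates where Q, K, and V come from and which shape constraints follow. The function names `attention`, `cross_attention`, and `runs` are our own.

```python
import numpy as np

def attention(query, key, value):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    scores = query @ key.T / np.sqrt(query.shape[-1])     # (len_q, len_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ value                                # requires len_k == len_v

def cross_attention(input_seq, context_seq, kind):
    """Q/K/V wiring of the three modes of ContextualCrossAttention (no projections)."""
    if kind == 'std_encoder_decoder':  # Q = input, K = V = context
        return attention(input_seq, context_seq, context_seq)
    if kind == 'symbolic':             # Q = input (symbols), K = context, V = input
        return attention(input_seq, context_seq, input_seq)
    if kind == 'relational':           # Q = K = context, V = input
        return attention(context_seq, context_seq, input_seq)
    raise ValueError(f'unknown cross-attention type {kind!r}')

rng = np.random.default_rng(0)
dec = rng.normal(size=(8, 4))   # e.g. 8 decoder positions
enc9 = rng.normal(size=(9, 4))  # e.g. 9 encoder positions (the source length in this task)
enc8 = rng.normal(size=(8, 4))  # a context with the same length as `dec`

out_std = cross_attention(dec, enc9, 'std_encoder_decoder')  # shape (8, 4)
out_rel = cross_attention(dec, enc8, 'relational')           # shape (8, 4)
out_sym = cross_attention(dec, enc8, 'symbolic')             # shape (8, 4)

def runs(kind):
    """True if `kind` accepts an 8-token input with a 9-token context."""
    try:
        cross_attention(dec, enc9, kind)
        return True
    except ValueError:  # NumPy matmul shape mismatch
        return False
```

With a 9-token source and an 8-token decoder input, as in this task, only the standard mode runs. In the symbolic mode each output row is a convex combination of the `input_seq` rows (the symbols), so the context enters only through the attention weights; in the relational mode the weights are computed from context-context inner products alone.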

In [41]:
from seq2seq_transformer import CausalSelfAttention, FeedForward

class ContextDecoderLayer(tf.keras.layers.Layer):
    def __init__(self,
                 d_model,
                 num_heads,
                 dff,
                 cross_attention_type='std_encoder_decoder',
                 dropout_rate=0.1):

        super(ContextDecoderLayer, self).__init__()

        self.cross_attention_type = cross_attention_type

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = ContextualCrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate,
            cross_attention_type=self.cross_attention_type)

        self.ffn = FeedForward(d_model, dff)

    def call(self, input_seq, context_seq):
        x = self.causal_self_attention(x=input_seq)
        x = self.cross_attention(input_seq=x, context_seq=context_seq)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.

        return x


class ContextDecoder(tf.keras.layers.Layer):
    def __init__(
        self, 
        num_layers, 
        num_heads,
        dff,
        cross_attention_type='std_encoder_decoder',
        dropout_rate=0.1,
        name='decoder'):

        super(ContextDecoder, self).__init__(name=name)

        self.cross_attention_type = cross_attention_type
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dff = dff
        self.dropout_rate = dropout_rate

    def build(self, input_shape):

        _, self.sequence_length, self.d_model = input_shape

        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

        self.dec_layers = [
            ContextDecoderLayer(
                d_model=self.d_model, num_heads=self.num_heads,
                dff=self.dff, dropout_rate=self.dropout_rate, 
                cross_attention_type=self.cross_attention_type)
            for _ in range(self.num_layers)]

        self.last_attn_scores = None

    def call(self, input_seq, context_seq):

        x = self.dropout(input_seq)

        for i in range(self.num_layers):
            x = self.dec_layers[i](input_seq=x, context_seq=context_seq)

#             self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        return x

In [42]:
# # define standard decoder with above implementation of Cross-Attention Layer
# class Decoder(tf.keras.layers.Layer):
#     def __init__(self, num_layers, num_heads, dff,
#                dropout_rate=0.1, name='decoder'):
#         super(Decoder, self).__init__(name=name)

#         self.num_layers = num_layers
#         self.num_heads = num_heads
#         self.dff = dff
#         self.dropout_rate = dropout_rate

#     def build(self, input_shape):

#         _, self.sequence_length, self.d_model = input_shape

#         self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

#         self.dec_layers = [
#             ContextDecoderLayer(
#                 d_model=self.d_model, num_heads=self.num_heads,
#                 dff=self.dff, dropout_rate=self.dropout_rate,
#                 cross_attention_type='std_encoder_decoder')
#             for _ in range(self.num_layers)]

#         self.last_attn_scores = None

#     def call(self, input_seq, context_seq):

#         x = self.dropout(input_seq)

#         for i in range(self.num_layers):
#             x = self.dec_layers[i](input_seq=x, context_seq=context_seq)

# #             self.last_attn_scores = self.dec_layers[-1].last_attn_scores

#         return x


## Dataset

In [43]:
hand_size = 7

deck = Cards()
pydeck = pydealer.Deck()
pydeck.shuffle()

In [44]:
n = 10000
BEGIN_HAND = 52 # token for 'beginning of hand'
END_HAND = 53 # token for 'end of hand'

hands = np.zeros((n, hand_size+2), dtype=int)         # source: [BEGIN_HAND, 7 cards, END_HAND]
hands_sorted = np.zeros((n, hand_size+2), dtype=int)  # target: same cards in sorted order

for i in np.arange(n):
    hand = pydeck.deal(hand_size)
    if len(hand) < hand_size:
        #print('shuffling deck')
        pydeck = pydealer.Deck()
        pydeck.shuffle()
        hand = pydeck.deal(hand_size)
    source = list(deck.index_pyhand(hand))
    source.insert(0,BEGIN_HAND)
    source.append(END_HAND)
    hands[i] = np.array(source)
    deck.sort_pyhand(hand)
    target = list(deck.index_pyhand(hand))
    target.insert(0,BEGIN_HAND)
    target.append(END_HAND)
    hands_sorted[i] = np.array(target)


In [45]:
hands_train, hands_test, sorted_train, sorted_test = train_test_split(hands, hands_sorted, test_size=0.25)

source_train = hands_train
target_train = sorted_train[:,:-1]
labels_train = sorted_train[:,1:]

source_test = hands_test
target_test = sorted_test[:,:-1]
labels_test = sorted_test[:,1:]
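
The cell above sets up teacher forcing: the decoder reads the sorted sequence without its last token and is trained to predict the same sequence shifted left by one. A minimal NumPy sketch for a single hand follows. It assumes the sorted order is plain ascending card index; the actual order comes from `Cards.sort_pyhand` in `hand2hand`, which is not shown here, and `make_example` is a hypothetical helper.

```python
import numpy as np

BEGIN_HAND, END_HAND = 52, 53  # special tokens, as above; cards are indices 0..51

def make_example(hand):
    """Return (source, decoder_input, labels) for one hand of card indices.

    Assumption: ascending index order stands in for Cards.sort_pyhand.
    """
    source = np.array([BEGIN_HAND, *hand, END_HAND])
    target = np.array([BEGIN_HAND, *sorted(hand), END_HAND])
    decoder_input = target[:-1]   # BEGIN_HAND + sorted cards (no END_HAND)
    labels = target[1:]           # sorted cards + END_HAND (shifted left by one)
    return source, decoder_input, labels

source, decoder_input, labels = make_example([40, 3, 17, 51, 0, 22, 9])
```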

In [46]:
def evaluate_seq2seq_model(model):
    """Greedily decode the test set and print/return per-token accuracy
    (the END_HAND token is scored alongside the 7 cards)."""
    n = len(source_test)
    output = np.zeros((n, hand_size+2), dtype=int)
    output[:,0] = BEGIN_HAND
    for i in range(hand_size+1):
        # re-run the model on the prefix decoded so far and keep the prediction at position i
        predictions = model((source_test, output[:, :-1]), training=False)
        predictions = predictions[:, i, :]
        predicted_id = tf.argmax(predictions, axis=-1)
        output[:,i+1] = predicted_id

    acc = (np.sum(output[:,1:] == labels_test))/np.prod(labels_test.shape)
    print('per-card accuracy: %.2f%%' % (100*acc))

    return acc
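
`evaluate_seq2seq_model` decodes greedily: at step $i$ it re-runs the full model on the tokens decoded so far and keeps the argmax at position $i$, so each hand costs `hand_size + 1` forward passes with no caching, and the reported accuracy averages over all 8 output positions, including `END_HAND`. The loop can be checked in isolation with the NumPy sketch below, where `greedy_decode` and `oracle_model` are hypothetical names and `oracle_model` is a stand-in for a trained model that assumes ascending card index is the sorted order.

```python
import numpy as np

BEGIN_HAND, END_HAND, VOCAB_SIZE = 52, 53, 54

def greedy_decode(model_fn, source, num_steps):
    """Same loop as evaluate_seq2seq_model, with the model passed in as a function."""
    n = len(source)
    output = np.zeros((n, num_steps + 1), dtype=int)
    output[:, 0] = BEGIN_HAND
    for i in range(num_steps):
        logits = model_fn(source, output[:, :-1])          # (n, num_steps, VOCAB_SIZE)
        output[:, i + 1] = logits[:, i, :].argmax(axis=-1)
    return output[:, 1:]

def oracle_model(source, prefix):
    """Hypothetical stand-in for a trained model: one-hot logits for the
    ascending-index sort of the hand followed by END_HAND (ignores `prefix`)."""
    n = prefix.shape[0]
    answer = np.sort(source[:, 1:-1], axis=1)                    # drop BEGIN/END, sort cards
    answer = np.concatenate([answer, np.full((n, 1), END_HAND)], axis=1)
    return np.eye(VOCAB_SIZE)[answer]

source = np.array([[BEGIN_HAND, 40, 3, 17, 51, 0, 22, 9, END_HAND]])
labels = np.array([[0, 3, 9, 17, 22, 40, 51, END_HAND]])
pred = greedy_decode(oracle_model, source, num_steps=8)  # hand_size + 1 steps
per_card_accuracy = np.mean(pred == labels)
```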

## Standard Transformer

In [47]:
# from seq2seq_transformer import Encoder

# inputs = layers.Input(shape=(9,), name='token_input')
# target = layers.Input(shape=(8,), name='token_target')

# token_embedder = layers.Embedding(54, 128, name='vector_embedding')
# pos_embedding_adder_input = AddPositionalEmbedding(name='add_pos_embedding_input')
# pos_embedding_adder_target = AddPositionalEmbedding(name='add_pos_embedding_target')
# encoder = Encoder(num_layers=3, num_heads=2, dff=64, dropout_rate=0.1, name='transformer_encoder')

# decoder = Decoder(num_layers=3, num_heads=2, dff=64, dropout_rate=0.1, name='transformer_decoder')

# x = token_embedder(inputs)
# x = pos_embedding_adder_input(x)

# encoder_context = encoder(x)

# target_embedding = token_embedder(target)
# target_embedding = pos_embedding_adder_target(target_embedding)

# x = decoder(target_embedding, encoder_context)

# x = layers.Dense(54)(x)

# transformer = Model(inputs=[inputs, target], outputs=x)

In [48]:
from seq2seq_transformer import Encoder, AddPositionalEmbedding

class Transformer(tf.keras.Model):
    def __init__(self, num_layers, num_heads, dff,
            input_vocab_size, target_vocab_size, embedding_dim,
            dropout_rate=0.1, name='transformer'):
        super().__init__(name=name)
        
        self.token_embedder = layers.Embedding(input_vocab_size, embedding_dim, name='vector_embedding')
        
        self.pos_embedding_adder_input = AddPositionalEmbedding(name='add_pos_embedding_input')
        self.pos_embedding_adder_target = AddPositionalEmbedding(name='add_pos_embedding_target')

        self.encoder = Encoder(num_layers=num_layers, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate, name='encoder')
        self.decoder = ContextDecoder(num_layers=num_layers, num_heads=num_heads, dff=dff,
          dropout_rate=dropout_rate, cross_attention_type='std_encoder_decoder', name='decoder')
        self.final_layer = layers.Dense(target_vocab_size, name='final_layer')


    def call(self, inputs):
        # To use a Keras model with `.fit` you must pass all your inputs in the
        # first argument.
        source, target  = inputs
        
        x = self.token_embedder(source)
        x = self.pos_embedding_adder_input(x)

        encoder_context = self.encoder(x)

        target_embedding = self.token_embedder(target)
        target_embedding = self.pos_embedding_adder_target(target_embedding)

        x = self.decoder(input_seq=target_embedding, context_seq=encoder_context)

        # Final linear layer output.
        logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

        try:
          # Drop the keras mask, so it doesn't scale the losses/metrics.
          # b/250038731
          del logits._keras_mask
        except AttributeError:
          pass

        # Return the final output logits.
        return logits

In [49]:
transformer = Transformer(num_layers=2, num_heads=2, dff=64, 
    input_vocab_size=54, target_vocab_size=54, embedding_dim=128)
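
The next cell compiles with `masked_loss` and `masked_accuracy` from `seq2seq_transformer`, whose source is not shown here. If they follow the TensorFlow Transformer tutorial, they treat label 0 as padding and exclude it. In this dataset, however, 0 is a real card index (cards are 0..51, with 52/53 as `BEGIN_HAND`/`END_HAND`), so such a mask would silently drop that card from both the loss and the accuracy. The NumPy sketch below reproduces a tutorial-style masked accuracy (an assumption about the library, not its actual code) to show the effect; `masked_accuracy_sketch` is a hypothetical name.

```python
import numpy as np

def masked_accuracy_sketch(labels, logits, pad_id=0):
    """Accuracy over positions whose label != pad_id (assumed tutorial-style mask)."""
    pred = logits.argmax(axis=-1)
    mask = labels != pad_id
    return ((pred == labels) & mask).sum() / mask.sum()

labels = np.array([[0, 3, 9, 53]])    # a sorted sequence starting with card 0, then END_HAND
pred_ids = np.array([[5, 3, 9, 53]])  # wrong on card 0, right elsewhere
logits = np.eye(54)[pred_ids]

masked = masked_accuracy_sketch(labels, logits)    # the error on card 0 is hidden
unmasked = np.mean(logits.argmax(-1) == labels)    # the error on card 0 counts
```

If the library does mask 0, one option would be to reserve 0 for padding and shift every card and special token up by one (vocabulary size 55).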

In [50]:
from seq2seq_transformer import masked_loss, masked_accuracy

# opt.build(transformer.trainable_variables)
transformer.compile(loss=masked_loss, optimizer=tf.keras.optimizers.Adam(), metrics=masked_accuracy)
transformer((source_train, target_train))

transformer.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vector_embedding (Embedding  multiple                 6912      
 )                                                               
                                                                 
 add_pos_embedding_input (Ad  multiple                 0         
 dPositionalEmbedding)                                           
                                                                 
 add_pos_embedding_target (A  multiple                 0         
 ddPositionalEmbedding)                                          
                                                                 
 encoder (Encoder)           multiple                  298112    
                                                                 
 decoder (ContextDecoder)    multiple                  562560    
                                                       

In [56]:
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [51]:
transformer.fit((source_train, target_train), labels_train, epochs=10, batch_size=64, verbose=1)

Epoch 1/10


2023-01-29 15:38:49.983086: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x2ac644012810 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-01-29 15:38:49.983174: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
2023-01-29 15:38:50.266978: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-01-29 15:38:50.346269: W tensorflow/compiler/xla/service/gpu/nvptx_helper.cc:56] Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may result in compilation or runtime failures, if the program we try to run uses routines from libdevice.
Searched for CUDA in the following directories:
  ./cuda_sdk_lib
  /usr/local/cuda-11.2
  /usr/local/cuda
  .
You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For m

InternalError: Graph execution error:

Detected at node 'StatefulPartitionedCall_86' defined at (most recent call last):
    File "/tmp/tmp.YpQxWfuT6Q/ipykernel_22645/144498288.py", line 1, in <module>
      transformer.fit((source_train, target_train), labels_train, epochs=10, batch_size=64, verbose=1)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1650, in fit
      tmp_logs = self.train_function(iterator)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1249, in train_function
      return step_function(self, iterator)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1233, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1222, in run_step
      outputs = model.train_step(data)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1027, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 527, in minimize
      self.apply_gradients(grads_and_vars)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1140, in apply_gradients
      return super().apply_gradients(grads_and_vars, name=name)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 634, in apply_gradients
      iteration = self._internal_apply_gradients(grads_and_vars)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1166, in _internal_apply_gradients
      return tf.__internal__.distribute.interim.maybe_merge_call(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1216, in _distributed_apply_gradients_fn
      distribution.extended.update(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1211, in apply_grad_to_update_var
      return self._update_step_xla(grad, var, id(self._var_key(var)))
Node: 'StatefulPartitionedCall_86'
libdevice not found at ./libdevice.10.bc
	 [[{{node StatefulPartitionedCall_86}}]] [Op:__inference_train_function_14144]

In [17]:
evaluate_seq2seq_model(transformer);

per-card accuracy: 94.40%


## Autoregressive Abstracter with 'Symbolic' Cross-Attention $(Q=A, K=E, V=A)$

In [20]:
from symbolic_decoder import SymbolicDecoder

class AutoregressiveSymbolicAbstracter(tf.keras.Model):
    def __init__(self, num_layers, num_heads, dff,
            input_vocab_size, target_vocab_size, embedding_dim,
            dropout_rate=0.1, name='transformer'):
        super().__init__(name=name)
        
        self.token_embedder = layers.Embedding(input_vocab_size, embedding_dim, name='vector_embedding')
        
        self.pos_embedding_adder_input = AddPositionalEmbedding(name='add_pos_embedding_input')
        self.pos_embedding_adder_target = AddPositionalEmbedding(name='add_pos_embedding_target')

        self.encoder = Encoder(num_layers=num_layers, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate, name='encoder')
        self.abstracter = SymbolicDecoder(num_layers=num_layers, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate, name='abstracter')
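        # standard cross-attention (Q=D, K=A, V=A); `Decoder` is not defined in this notebook
        self.decoder = ContextDecoder(num_layers=num_layers, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate, cross_attention_type='std_encoder_decoder', name='decoder')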
        self.final_layer = layers.Dense(target_vocab_size, name='final_layer')


    def call(self, inputs):
        # To use a Keras model with `.fit` you must pass all your inputs in the
        # first argument.
        source, target  = inputs
        
        x = self.token_embedder(source)
        x = self.pos_embedding_adder_input(x)

        encoder_context = self.encoder(x)

        abstracted_context = self.abstracter(encoder_context)
        
        target_embedding = self.token_embedder(target)
        target_embedding = self.pos_embedding_adder_target(target_embedding)

        x = self.decoder(target_embedding, abstracted_context)

        # Final linear layer output.
        logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

        try:
          # Drop the keras mask, so it doesn't scale the losses/metrics.
          # b/250038731
          del logits._keras_mask
        except AttributeError:
          pass

        # Return the final output logits.
        return logits

In [21]:
autoregressive_symbolic_abstracter = AutoregressiveSymbolicAbstracter(num_layers=2, num_heads=2, dff=64, 
    input_vocab_size=54, target_vocab_size=54, embedding_dim=128)

In [22]:
from seq2seq_transformer import masked_loss, masked_accuracy

autoregressive_symbolic_abstracter.compile(loss=masked_loss, optimizer=tf.keras.optimizers.Adam(), metrics=masked_accuracy)
autoregressive_symbolic_abstracter((source_train, target_train))

autoregressive_symbolic_abstracter.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vector_embedding (Embedding  multiple                 6912      
 )                                                               
                                                                 
 add_pos_embedding_input (Ad  multiple                 0         
 dPositionalEmbedding)                                           
                                                                 
 add_pos_embedding_target (A  multiple                 0         
 ddPositionalEmbedding)                                          
                                                                 
 encoder (Encoder)           multiple                  298112    
                                                                 
 abstracter (SymbolicDecoder  multiple                 563712    
 )                                                     

In [23]:
autoregressive_symbolic_abstracter.fit((source_train, target_train), labels_train, epochs=10, batch_size=64, verbose=1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7ff0acda1c50>

In [24]:
evaluate_seq2seq_model(autoregressive_symbolic_abstracter);

per-card accuracy: 15.85%


## Autoregressive Abstracter with 'Episodic' Cross-Attention $(Q=E, K=E, V=E)$

In [25]:
from seq2seq_transformer import EpisodicDecoder

class AutoregressiveEpisodicAbstracter(tf.keras.Model):
    def __init__(self, num_layers, num_heads, dff,
            input_vocab_size, target_vocab_size, embedding_dim,
            dropout_rate=0.1, name='autoregressive_episodic_abstracter'):
        super().__init__(name=name)
        
        self.token_embedder = layers.Embedding(input_vocab_size, embedding_dim, name='vector_embedding')
        
        self.pos_embedding_adder_input = AddPositionalEmbedding(name='add_pos_embedding_input')
        self.pos_embedding_adder_target = AddPositionalEmbedding(name='add_pos_embedding_target')

        self.encoder = Encoder(num_layers=num_layers, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate, name='encoder')
        self.abstracter = EpisodicDecoder(num_layers=num_layers, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate, name='abstracter')
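        # standard cross-attention (Q=D, K=A, V=A); `Decoder` is not defined in this notebook
        self.decoder = ContextDecoder(num_layers=num_layers, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate, cross_attention_type='std_encoder_decoder', name='decoder')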
        self.final_layer = layers.Dense(target_vocab_size, name='final_layer')


    def call(self, inputs):
        # To use a Keras model with `.fit` you must pass all your inputs in the
        # first argument.
        source, target  = inputs
        
        x = self.token_embedder(source)
        x = self.pos_embedding_adder_input(x)

        encoder_context = self.encoder(x)

        abstracted_context = self.abstracter(encoder_context)
        
        target_embedding = self.token_embedder(target)
        target_embedding = self.pos_embedding_adder_target(target_embedding)

        x = self.decoder(target_embedding, abstracted_context)

        # Final linear layer output.
        logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

        try:
          # Drop the keras mask, so it doesn't scale the losses/metrics.
          # b/250038731
          del logits._keras_mask
        except AttributeError:
          pass

        # Return the final output logits.
        return logits

In [26]:
autoregressive_episodic_abstracter = AutoregressiveEpisodicAbstracter(num_layers=2, num_heads=2, dff=64, 
    input_vocab_size=54, target_vocab_size=54, embedding_dim=128)

In [27]:
from seq2seq_transformer import masked_loss, masked_accuracy, CustomSchedule

learning_rate = CustomSchedule(d_model=128)
autoregressive_episodic_abstracter.compile(
    loss=masked_loss, optimizer=tf.keras.optimizers.Adam(learning_rate), metrics=masked_accuracy)
autoregressive_episodic_abstracter((source_train, target_train))

autoregressive_episodic_abstracter.summary()

Model: "autoregressive_episodic_abstracter"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vector_embedding (Embedding  multiple                 6912      
 )                                                               
                                                                 
 add_pos_embedding_input (Ad  multiple                 0         
 dPositionalEmbedding)                                           
                                                                 
 add_pos_embedding_target (A  multiple                 0         
 ddPositionalEmbedding)                                          
                                                                 
 encoder (Encoder)           multiple                  298112    
                                                                 
 abstracter (EpisodicDecoder  multiple                 563712    
 )                              
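
`CustomSchedule` is imported from `seq2seq_transformer`, and its definition is not shown. If it follows the warmup schedule of the original Transformer, as in the TensorFlow Transformer tutorial (default `warmup_steps=4000`), then $\text{lr}(s) = d_{\text{model}}^{-0.5} \min(s^{-0.5},\, s \cdot \text{warmup}^{-1.5})$. The pure-Python sketch below evaluates that assumed schedule for this run: 7,500 training hands at batch size 64 give 118 steps per epoch, so 10 epochs end at step 1,180, still inside the warmup phase. `transformer_lr` is a hypothetical name.

```python
import math

def transformer_lr(step, d_model=128, warmup_steps=4000):
    """Assumed schedule: d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    step = max(step, 1)  # guard against step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

steps_per_epoch = math.ceil(7500 / 64)  # 10,000 hands with test_size=0.25 -> 7,500 for training
total_steps = 10 * steps_per_epoch      # epochs=10 in the fit call below
peak_lr = transformer_lr(4000)          # the maximum is reached at step == warmup_steps
final_lr = transformer_lr(total_steps)
```

Under this assumption the learning rate never reaches its peak during the 10-epoch run.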

In [28]:
autoregressive_episodic_abstracter.fit((source_train, target_train), labels_train,
    epochs=10, batch_size=64, verbose=1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7ff08d17bc50>

In [29]:
evaluate_seq2seq_model(autoregressive_episodic_abstracter);

per-card accuracy: 91.11%


## Multi-Abstracter Model

$$\text{Encoder} \to \text{Abstracter} \to \cdots \to \text{Abstracter} \to \text{Decoder}$$

...