In [1]:
import sys
import os
import time
import random
import re
import json
import pickle
from typing import List, Tuple, Dict, Callable, Optional, Any, Sequence, Mapping, NamedTuple
from attrdict import AttrDict

In [2]:
import tensorflow as tf
import numpy as np
import matplotlib as plt

In [3]:
from model.attention import SelfAttention, MultiheadAttention
from model.embedding import EmbeddingSharedWeights
from model.ffn import FeedForwardNetwork
from model.layer_utils import LayerWrapper, LayerNormalization
from model import model_utils
from datasource.sample_ds import SampleDataSource

In [4]:
# Switch TensorFlow to eager mode so the cells below get concrete tensor values
# (the gradient cell further down displays numpy-backed tensors).
tf.enable_eager_execution()

In [5]:
# Hyperparameters shared by the data source and the model, declared in one
# place. Entries are read both as attributes and as `hparams['key']` below.
hparams = AttrDict({
    'num_layers': 6,
    'num_units': 256,
    'num_filter_units': 1024,
    'num_heads': 8,
    'dropout_rate': 0.1,
    'max_length': 50,
    'batch_size': 64,
    'vocab_size': 3278,
    'data_path': './data',
})

In [6]:
# Build the sample data source from the shared hyperparameters.
ds = SampleDataSource(hparams)

In [51]:
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer built from the project's layer modules.

    The model owns one shared embedding layer (used for encoder inputs,
    decoder inputs and the output projection via ``linear``), an
    ``EncoderStack``, a ``DecoderStack`` and one dropout layer per side that
    is applied to the position-encoded embeddings.

    Args:
        hparams: mapping providing at least ``vocab_size``, ``num_units`` and
            ``dropout_rate`` (plus whatever the encoder/decoder stacks read).
        is_train: whether dropout should be active.
    """

    def __init__(self, hparams, is_train):
        super().__init__()
        self.hparams = hparams
        self.is_train = is_train
        self.embedding_layer = EmbeddingSharedWeights(hparams['vocab_size'], hparams['num_units'])
        self.encoder_stack = EncoderStack(hparams, is_train)
        self.encoder_embedding_dropout = tf.keras.layers.Dropout(hparams['dropout_rate'])

        self.decoder_stack = DecoderStack(hparams, is_train)
        self.decoder_embedding_dropout = tf.keras.layers.Dropout(hparams['dropout_rate'])

    def call(self, inputs, targets: Optional[np.ndarray] = None):
        """Run the model.

        Args:
            inputs: source token ids.
            targets: target token ids. When given, the decoder is run on them
                and logits are returned; when ``None``, ``predict`` is used.

        Returns:
            Logits from ``_decode`` when ``targets`` is given, otherwise the
            result of ``predict`` (currently ``None``, see ``predict``).
        """
        attention_bias = model_utils.get_padding_bias(inputs)
        encoder_outputs = self._encode(inputs, attention_bias)

        if targets is None:
            return self.predict(encoder_outputs, attention_bias)
        return self._decode(encoder_outputs, targets, attention_bias)

    def loss(self, inputs, targets):
        """Mean cross entropy over the non-padding target positions.

        Target id 0 is treated as padding and excluded from the average.

        Returns:
            A scalar tensor.
        """
        pad = tf.to_float(tf.not_equal(targets, 0))
        onehot_targets = tf.one_hot(targets, self.hparams['vocab_size'])
        logits = self(inputs, targets)
        # Per-position cross entropy. `tf.losses.softmax_cross_entropy` would
        # return an already-reduced scalar, which makes the padding mask below
        # a no-op and lets padding positions contribute to the loss.
        cross_ents = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=onehot_targets,
            logits=logits
        )
        # Guard the denominator so an all-padding batch yields 0 instead of NaN.
        return tf.reduce_sum(cross_ents * pad) / tf.maximum(tf.reduce_sum(pad), 1.0)

    def grads(self, inputs, targets):
        """Return the gradient of ``loss`` w.r.t. each entry of ``self.variables``."""
        with tf.GradientTape() as tape:
            loss = self.loss(inputs, targets)
        return tape.gradient(loss, self.variables)

    def predict(self, encoder_outputs, bias):
        """Inference-time decoding. Not implemented yet; returns ``None``.

        NOTE(review): this name shadows ``tf.keras.Model.predict`` with a
        different signature; consider renaming when it is implemented.
        """
        # TODO: implement autoregressive decoding.
        return None

    def _encode(self, inputs, attention_bias):
        """Embed ``inputs``, add positional encoding and run the encoder stack."""
        embedded_inputs = self.embedding_layer(inputs)
        # Use the actual sequence length and the model's own hparams (not the
        # notebook-global `hparams`), mirroring `_decode`.
        length = embedded_inputs.shape[1]
        embedded_inputs += model_utils.get_position_encoding(length, self.hparams['num_units'])
        inputs_padding = model_utils.get_padding(inputs)

        # Pass `training` explicitly: the result is defined in both modes
        # (previously `encoder_inputs` was unbound when `is_train` was False)
        # and dropout does not depend on the global Keras learning phase.
        encoder_inputs = self.encoder_embedding_dropout(embedded_inputs, training=self.is_train)
        return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)

    def _decode(self, encoder_outputs, targets, attention_bias):
        """Run the decoder on ``targets`` shifted right by one and return logits."""
        decoder_inputs = self.embedding_layer(targets)
        # Shift right: prepend a zero vector along the sequence axis and drop
        # the last position, so position t only sees targets before t.
        decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        # add positional encoding
        length = decoder_inputs.shape[1]
        decoder_inputs += model_utils.get_position_encoding(length, self.hparams['num_units'])

        # Explicit `training` flag, see `_encode`.
        decoder_inputs = self.decoder_embedding_dropout(decoder_inputs, training=self.is_train)

        decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(length)
        outputs = self.decoder_stack(decoder_inputs, encoder_outputs, decoder_self_attention_bias, attention_bias)
        logits = self.embedding_layer.linear(outputs)
        return logits
        

In [35]:
class EncoderStack(tf.keras.Model):
    """``hparams['num_layers']`` encoder blocks followed by layer normalization.

    Each block is a pair ``[self-attention, feed-forward]``, with both
    sub-layers wrapped in ``LayerWrapper``.
    """

    def __init__(self, hparams, is_train):
        super(EncoderStack, self).__init__()
        units = hparams['num_units']
        dropout = hparams['dropout_rate']

        self.my_layers = []
        for _ in range(hparams['num_layers']):
            attention = SelfAttention(units, hparams['num_heads'], dropout, is_train)
            feed_forward = FeedForwardNetwork(units, hparams['num_filter_units'], dropout, is_train)
            wrapped_attention = LayerWrapper(attention, units, dropout, is_train)
            wrapped_feed_forward = LayerWrapper(feed_forward, units, dropout, is_train)
            self.my_layers.append([wrapped_attention, wrapped_feed_forward])

        self.output_norm = LayerNormalization(units)

    def call(self, encoder_inputs, attention_bias, inputs_padding):
        """Apply every block in order, then the final normalization.

        ``inputs_padding`` is accepted for interface compatibility but is not
        used by this implementation.
        """
        hidden = encoder_inputs
        for attention_block, feed_forward_block in self.my_layers:
            hidden = attention_block(hidden, attention_bias)
            hidden = feed_forward_block(hidden)
        return self.output_norm(hidden)

In [36]:
class DecoderStack(tf.keras.Model):
    """``hparams['num_layers']`` decoder blocks followed by layer normalization.

    Each block is a triple ``[self-attention, encoder-decoder attention,
    feed-forward]``, with every sub-layer wrapped in ``LayerWrapper``.
    """

    def __init__(self, hparams, is_train):
        super(DecoderStack, self).__init__()
        units = hparams['num_units']
        heads = hparams['num_heads']
        dropout = hparams['dropout_rate']

        self.my_layers = []
        for _ in range(hparams['num_layers']):
            self_attention = SelfAttention(units, heads, dropout, is_train)
            cross_attention = MultiheadAttention(units, heads, dropout, is_train)
            feed_forward = FeedForwardNetwork(units, hparams['num_filter_units'], dropout, is_train)
            wrapped_self_attention = LayerWrapper(self_attention, units, dropout, is_train)
            wrapped_cross_attention = LayerWrapper(cross_attention, units, dropout, is_train)
            wrapped_feed_forward = LayerWrapper(feed_forward, units, dropout, is_train)
            self.my_layers.append([
                wrapped_self_attention,
                wrapped_cross_attention,
                wrapped_feed_forward,
            ])

        self.output_norm = LayerNormalization(units)

    def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias, attention_bias):
        """Apply every block in order, then the final normalization.

        Within a block, self-attention uses ``decoder_self_attention_bias``
        and the encoder-decoder attention attends over ``encoder_outputs``
        using ``attention_bias``.
        """
        hidden = decoder_inputs
        for self_attention_block, cross_attention_block, feed_forward_block in self.my_layers:
            hidden = self_attention_block(hidden, decoder_self_attention_bias)
            hidden = cross_attention_block(hidden, encoder_outputs, attention_bias)
            hidden = feed_forward_block(hidden)
        return self.output_norm(hidden)

In [11]:
# Fetch batches of `hparams['batch_size']` examples from the data source.
# NOTE(review): the meaning of the positional `None` and `True` arguments is
# defined by SampleDataSource.feed_dict, which is not shown here — confirm.
batch = ds.feed_dict(None, hparams['batch_size'], True)

In [12]:
# Keep only the first returned batch for the experiments below.
one_batch = batch[0]

In [21]:
# Sanity-check the shape of the element that is later passed as `targets`.
one_batch[2].shape

(64, 50)

In [52]:
# Instantiate the model with is_train=True.
model = Transformer(hparams, True)

In [55]:
# Compute the gradients of the loss for one batch: element 0 is used as the
# source inputs and element 2 as the targets.
yay = model.grads(one_batch[0], one_batch[2])
# Displays the full list of gradients (one entry per model variable).
yay

[<tensorflow.python.framework.ops.IndexedSlices at 0x7ff5a7cf82b0>,
 <tf.Tensor: id=58795, shape=(256, 256), dtype=float32, numpy=
 array([[ 5.4954155e-04,  2.8448703e-04,  2.1747116e-04, ...,
          1.2337319e-04,  1.2365717e-04,  6.4111350e-04],
        [ 2.6563415e-04,  1.3207187e-04,  1.3624194e-04, ...,
          1.0584666e-04, -1.8340190e-05,  5.5580295e-04],
        [ 2.3870365e-04,  1.3556868e-04, -9.9358822e-06, ...,
          2.0677716e-04, -1.1583711e-04,  3.5644736e-04],
        ...,
        [-4.7855478e-04, -1.7758833e-04, -1.1864794e-04, ...,
         -2.8152665e-04, -1.1523189e-05, -8.2780374e-04],
        [-3.6420461e-04,  7.9749661e-05, -1.9492365e-04, ...,
         -2.1017529e-04,  3.3695742e-05, -5.3479307e-04],
        [-6.0539955e-04, -4.5458215e-05, -3.5313083e-04, ...,
         -2.4158103e-04,  4.7240304e-05, -8.6413173e-04]], dtype=float32)>,
 <tf.Tensor: id=58769, shape=(256, 256), dtype=float32, numpy=
 array([[ 2.48188328e-04, -4.79181239e-04,  5.84773006e

In [None]:
# `yay` is a Python list with one gradient per model variable, so it has no
# `.numpy()` of its own; summarize the shape of each entry instead.
def _grad_shape(grad):
    """Return the shape of one gradient entry as a list, or None if it is missing."""
    if grad is None:
        # Variables that did not contribute to the loss get no gradient.
        return None
    if isinstance(grad, tf.IndexedSlices):
        # Sparse gradients (e.g. from the embedding lookup) carry `dense_shape`.
        return grad.dense_shape.numpy().tolist()
    return grad.shape.as_list()

[_grad_shape(grad) for grad in yay]