In [1]:
from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Union

from flax import linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr

from maskgit.nets.maskgit_transformer import Mlp, Bias, Embed, MlmLayer

LAYERNORM_EPSILON = 1e-12  # Layer norm from BERT

InitializerType = Callable[[jnp.ndarray, Iterable[int], jnp.dtype], jnp.ndarray]

In [15]:
def truncated_normal(stddev: Union[float, jnp.ndarray], dtype=jnp.float32):

  def init(key: jnp.ndarray, shape: Iterable[int], dtype: jnp.dtype = dtype):
    return jax.random.truncated_normal(
        key=key, lower=-2, upper=2, shape=shape, dtype=dtype) * stddev

  return init

In [55]:
class CausalAttention(nn.Module):
  """Attention layer that is part of each Transformer layer."""
  hidden_size: int
  hidden_dropout_prob: float
  num_attention_heads: int
  attention_probs_dropout_prob: float
  hidden_dropout_prob: float
  initializer_fn: InitializerType

  @nn.compact
  def __call__(self, q: jnp.ndarray, kv: jnp.ndarray,
               attention_mask: jnp.ndarray,
               deterministic: bool) -> jnp.ndarray:

    attention_output = nn.attention.MultiHeadAttention(
        num_heads=self.num_attention_heads,
        qkv_features=self.hidden_size,
        out_features=None,
        dropout_rate=self.attention_probs_dropout_prob,
        deterministic=deterministic,
        kernel_init=self.initializer_fn,
        bias_init=jax.nn.initializers.zeros,
        name='multi_head_attention',
    )(q, kv, kv, mask=attention_mask)

    attention_output = nn.Dropout(rate=self.hidden_dropout_prob)(
        attention_output, deterministic=deterministic)
    attention_output = nn.LayerNorm(
        epsilon=LAYERNORM_EPSILON, name='attention_output_ln')(
            attention_output + q)

    return attention_output

class GenericTransformerLayer(nn.Module):
  """A single Transformer layer."""
  intermediate_size: int
  hidden_size: int
  hidden_dropout_prob: float
  num_attention_heads: int
  attention_probs_dropout_prob: float
  initializer_fn: InitializerType

  @nn.compact
  def __call__(self, q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray,
               deterministic: bool) -> jnp.ndarray:
      
    attention_output = CausalAttention(
        hidden_size=self.hidden_size,
        hidden_dropout_prob=self.hidden_dropout_prob,
        num_attention_heads=self.num_attention_heads,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        initializer_fn=self.initializer_fn)(
            q=q, kv=kv, attention_mask=mask,
            deterministic=deterministic)
      
    layer_output = Mlp(
        hidden_size=self.hidden_size,
        hidden_dropout_prob=self.hidden_dropout_prob,
        intermediate_size=self.intermediate_size,
        initializer_fn=self.initializer_fn)(
            attention_output=attention_output, deterministic=deterministic)

    return layer_output

class HollowTransformer(nn.Module):
  """Hollow transformer modified from BERT."""
  vocab_size: int
  hidden_size: int = 768
  num_hidden_layers: int = 12
  num_attention_heads: int = 12
  intermediate_size: int = 3072
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  max_position_embeddings: int = 256
  initializer_range: float = 0.02
  num_layers_per_mixed: int = 4 

  @nn.compact
  def __call__(self,
               input_ids: jnp.ndarray,
               t: float,
               deterministic: bool = True) -> Dict[Text, jnp.ndarray]:

    B, L = input_ids.shape
    # Causal attention doesn't like zero padding
    pad = jnp.zeros((B, 1))
    input_ids = jnp.concatenate([pad, input_ids, pad], axis=1)

    input_ids = input_ids.astype('int32')
    x = Embed(
        embedding_size=self.hidden_size,
        hidden_dropout_prob=self.hidden_dropout_prob,
        vocab_size=self.vocab_size,
        max_position_embeddings=self.max_position_embeddings + 2, # Including the padded values
        initializer_fn=truncated_normal(self.initializer_range))(
            input_ids=input_ids, deterministic=deterministic)
    
    # B, L, K = x.shape
    H = self.num_attention_heads
      
    forward_mask = jnp.tile(jnp.tril(jnp.ones((L, L)))[None, None], (B, H, 1, 1))
    backward_mask = jnp.tile(jnp.triu(jnp.ones((L, L)))[None, None], (B, H, 1, 1))
    mixing_mask = jnp.concatenate([forward_mask, backward_mask], axis=-1)   

    # Causal attention doesn't like zero padding
    # pad = jnp.ones((B, 1, K))
    xf = x[:,:-2] #jnp.concatenate([pad, x[:,:-1]], axis=1)
    xb = x[:,2:] #jnp.concatenate([x[:,1:], pad], axis=1)
    xm = None
      
    for i in range(self.num_hidden_layers):
      f_layer = GenericTransformerLayer(
          intermediate_size=self.intermediate_size,
          hidden_size=self.hidden_size,
          hidden_dropout_prob=self.hidden_dropout_prob,
          num_attention_heads=self.num_attention_heads,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          initializer_fn=truncated_normal(self.initializer_range))
      b_layer = GenericTransformerLayer(
          intermediate_size=self.intermediate_size,
          hidden_size=self.hidden_size,
          hidden_dropout_prob=self.hidden_dropout_prob,
          num_attention_heads=self.num_attention_heads,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          initializer_fn=truncated_normal(self.initializer_range))
      xf = f_layer(q=xf, kv=xf, mask=forward_mask, deterministic=deterministic)
      xb = b_layer(q=xb, kv=xb, mask=backward_mask, deterministic=deterministic)

      if (i + 1) % self.num_layers_per_mixed == 0:
        if xm is None:
          xm = jnp.concatenate([xf, xb], axis=2)
        xfb = jnp.concatenate([xf, xb], axis=1)
        m_layer = GenericTransformerLayer(
          intermediate_size=self.intermediate_size,
          hidden_size=self.hidden_size * 2, # since we're combining the streams
          hidden_dropout_prob=self.hidden_dropout_prob,
          num_attention_heads=self.num_attention_heads,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          initializer_fn=truncated_normal(self.initializer_range))
        xm = m_layer(q=xm, kv=xfb, mask=mixing_mask, deterministic=deterministic)

    layer_output = xm
      
    word_embedding_matrix = self.variables['params']['Embed_0'][
        'word_embeddings']['embedding']
    logits = MlmLayer(
        hidden_size=self.hidden_size,
        initializer_fn=truncated_normal(self.initializer_range))(
            last_layer=layer_output, embeddings=word_embedding_matrix)

    return logits

In [56]:
k = 8
vocab_size = 1024
seq_len = 10
model = HollowTransformer(
  vocab_size = vocab_size + 1,
  hidden_size = k,
  num_hidden_layers = 6,
  num_attention_heads = 2,
  intermediate_size = 10,
  hidden_dropout_prob = 0.0,
  attention_probs_dropout_prob = 0.0,
  max_position_embeddings = seq_len + 1,
  initializer_range = 1.,#0.02
  num_layers_per_mixed = 3,
)

In [57]:
key = jr.PRNGKey(0)
jax.config.update("jax_debug_nans", False)
jax.config.update("jax_debug_infs", False)

In [72]:
key, _ = jr.split(key)
input_sequence = jr.categorical(logits=jnp.ones((vocab_size,)), shape=(1, seq_len), key=key)
params = model.init(key, input_sequence, 0, True)

# def loss_fn(params):
#     eps = 1e-6
#     out = model.apply(params, input_sequence[0], 0, True)
#     out = out[0]
#     p = jax.nn.softmax(out, axis=-1)
#     x0_one_hot = jax.nn.one_hot(input_sequence, vocab_size + 1)
#     logits = jnp.log(p + eps)
#     return -jnp.mean(x0_one_hot * logits)
    
# grad_fn = jax.grad(loss_fn)
# grads = grad_fn(params)

l = 5

out1 = model.apply(params, input_sequence, 0, True)
input_sequence = input_sequence.at[0, l].set((input_sequence[0, l] - 1) % vocab_size)
out2 = model.apply(params, input_sequence, 0, True)
print(jnp.sum(out1[0] - out2[0], axis=1))

[  0.18707103  -1.7576226    1.7236228    1.4114002    0.58498496
   0.          -3.6051326   -0.6251073  -27.829098   -23.07043   ]
