In [211]:
from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Union

from flax import linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr

from maskgit.nets.maskgit_transformer import Mlp, Bias, Embed#, MlmLayer

LAYERNORM_EPSILON = 1e-12  # Layer norm from BERT

InitializerType = Callable[[jnp.ndarray, Iterable[int], jnp.dtype], jnp.ndarray]

In [6]:
# import flax.linen as nn
# import jax.numpy as jnp

# # Example sequence of length 4
# sequence = jnp.array([[1, 2, 3, 4]])

# # Generate causal mask
# causal_mask = nn.make_causal_mask(sequence)

# print(causal_mask)

# mask = jnp.tril(jnp.ones((4, 4)), k=-1)
# print(mask)
jnp.tril(jnp.ones((3, 3)), k=-1)

Array([[0., 0., 0.],
       [1., 0., 0.],
       [1., 1., 0.]], dtype=float32)

In [4]:
def truncated_normal(stddev: Union[float, jnp.ndarray], dtype=jnp.float32):

  def init(key: jnp.ndarray, shape: Iterable[int], dtype: jnp.dtype = dtype):
    return jax.random.truncated_normal(
        key=key, lower=-2, upper=2, shape=shape, dtype=dtype) * stddev

  return init

In [339]:
class CausalAttention(nn.Module):
  """Attention layer that is part of each Transformer layer."""
  hidden_size: int
  hidden_dropout_prob: float
  num_attention_heads: int
  attention_probs_dropout_prob: float
  hidden_dropout_prob: float
  initializer_fn: InitializerType

  @nn.compact
  def __call__(self, q: jnp.ndarray, kv: jnp.ndarray,
               attention_mask: jnp.ndarray,
               deterministic: bool) -> jnp.ndarray:

    attention_output = nn.attention.MultiHeadAttention(
        num_heads=self.num_attention_heads,
        qkv_features=self.hidden_size,
        out_features=None,
        dropout_rate=self.attention_probs_dropout_prob,
        deterministic=deterministic,
        kernel_init=self.initializer_fn,
        bias_init=jax.nn.initializers.zeros,
        name='multi_head_attention',
    )(q, kv, kv, mask=attention_mask)

    attention_output = nn.Dropout(rate=self.hidden_dropout_prob)(
        attention_output, deterministic=deterministic)
    attention_output = nn.LayerNorm(
        epsilon=LAYERNORM_EPSILON, name='attention_output_ln')(
            attention_output + q)

    return attention_output

class GenericTransformerLayer(nn.Module):
  """A single Transformer layer."""
  intermediate_size: int
  hidden_size: int
  hidden_dropout_prob: float
  num_attention_heads: int
  attention_probs_dropout_prob: float
  initializer_fn: InitializerType

  @nn.compact
  def __call__(self, q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray,
               deterministic: bool) -> jnp.ndarray:
      
    attention_output = CausalAttention(
        hidden_size=self.hidden_size,
        hidden_dropout_prob=self.hidden_dropout_prob,
        num_attention_heads=self.num_attention_heads,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        initializer_fn=self.initializer_fn)(
            q=q, kv=kv, attention_mask=mask,
            deterministic=deterministic)
      
    layer_output = Mlp(
        hidden_size=self.hidden_size,
        hidden_dropout_prob=self.hidden_dropout_prob,
        intermediate_size=self.intermediate_size,
        initializer_fn=self.initializer_fn)(
            attention_output=attention_output, deterministic=deterministic)

    return layer_output

class HollowTransformer(nn.Module):
  """Hollow transformer modified from BERT."""
  vocab_size: int
  hidden_size: int = 768
  num_hidden_layers: int = 12
  num_attention_heads: int = 12
  intermediate_size: int = 3072
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  max_position_embeddings: int = 256
  initializer_range: float = 0.02
  num_layers_per_mixed: int = 4 

  @nn.compact
  def __call__(self,
               input_ids: jnp.ndarray,
               deterministic: bool = True) -> Dict[Text, jnp.ndarray]:

    input_ids = input_ids.astype('int32')
    x = Embed(
        embedding_size=self.hidden_size,
        hidden_dropout_prob=self.hidden_dropout_prob,
        vocab_size=self.vocab_size,
        max_position_embeddings=self.max_position_embeddings,
        initializer_fn=truncated_normal(self.initializer_range))(
            input_ids=input_ids, deterministic=deterministic)
    
    B, L, K = x.shape
      
    forward_mask = jnp.tile(jnp.tril(jnp.ones((L, L)))[None], (B, 1))
    backward_mask = jnp.tile(jnp.triu(jnp.ones((L, L)))[None], (B, 1))
    mixing_mask = jnp.concatenate([forward_mask, backward_mask], axis=2)      

    # Causal attention doesn't like zero padding
    pad = jnp.ones((B, 1, K))
    xf = jnp.concatenate([pad, x[:,:-1]], axis=1)
    xb = jnp.concatenate([x[:,1:], pad], axis=1)
    xm = None
      
    for i in range(self.num_hidden_layers):
      f_layer = GenericTransformerLayer(
          intermediate_size=self.intermediate_size,
          hidden_size=self.hidden_size,
          hidden_dropout_prob=self.hidden_dropout_prob,
          num_attention_heads=self.num_attention_heads,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          initializer_fn=truncated_normal(self.initializer_range))
      b_layer = GenericTransformerLayer(
          intermediate_size=self.intermediate_size,
          hidden_size=self.hidden_size,
          hidden_dropout_prob=self.hidden_dropout_prob,
          num_attention_heads=self.num_attention_heads,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          initializer_fn=truncated_normal(self.initializer_range))
      xf = f_layer(q=xf, kv=xf, mask=forward_mask, deterministic=deterministic)
      xb = b_layer(q=xb, kv=xb, mask=backward_mask, deterministic=deterministic)

      if (i + 1) % self.num_layers_per_mixed == 0:
        if xm is None:
          xm = jnp.concatenate([xf, xb], axis=2)
        xfb = jnp.concatenate([xf, xb], axis=1)
        m_layer = GenericTransformerLayer(
          intermediate_size=self.intermediate_size,
          hidden_size=self.hidden_size * 2, # since we're combining the streams
          hidden_dropout_prob=self.hidden_dropout_prob,
          num_attention_heads=self.num_attention_heads,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          initializer_fn=truncated_normal(self.initializer_range))
        xm = m_layer(q=xm, kv=xfb, mask=mixing_mask, deterministic=deterministic)

    layer_output = xm
      
    word_embedding_matrix = self.variables['params']['Embed_0'][
        'word_embeddings']['embedding']
    logits = MlmLayer(
        hidden_size=self.hidden_size,
        initializer_fn=truncated_normal(self.initializer_range))(
            last_layer=layer_output, embeddings=word_embedding_matrix)

    return logits

In [340]:
k = 8
model = HollowTransformer(
  vocab_size = 4,
  hidden_size = k,
  num_hidden_layers = 3,
  num_attention_heads = 2,
  intermediate_size = 10,
  hidden_dropout_prob = 0.0,
  attention_probs_dropout_prob = 0.0,
  max_position_embeddings = 10,
  initializer_range = 1.,#0.02
  num_layers_per_mixed = 3,
)

In [341]:
key = jr.PRNGKey(0)

In [342]:
key, _ = jr.split(key)
input_sequence = jr.normal(shape=(1, 10), key=key)
params = model.init(key, input_sequence, True)
model.apply(params, input_sequence, True)

Array([[[ 0.68474483, -3.8607855 ,  4.3750286 , -0.04215145],
        [ 0.2565856 , -3.9484396 ,  4.4713316 ,  0.37558365],
        [-0.09743834, -3.1910896 , -0.19396186, -1.9386406 ],
        [-0.39917445, -4.6603518 ,  2.432427  ,  0.95605516],
        [ 0.42935085, -4.712538  ,  3.9293957 ,  1.0363998 ],
        [ 1.1949406 , -3.6340714 ,  3.7268753 ,  1.0710907 ],
        [ 0.47236824, -4.36226   ,  4.489441  ,  0.17659187],
        [ 0.38882256, -4.580062  ,  4.3808365 , -0.0818696 ],
        [ 0.00630283, -4.956742  ,  3.8312683 , -0.01543999],
        [-0.7918043 , -4.4137115 ,  4.0304804 ,  0.31813812]]],      dtype=float32)

In [343]:
# To test the "hollowness" we need to replace the gradient with finite differences
# since the embedding step is not differentiable

# l = 7
# grad_fn = jax.grad(lambda x: jnp.sum(model.apply(params, x, True)[0, l, 0]))
# grad_fn(input_sequence)