In [1]:
# Francisco Dominguez Mateos
# 09/05/2021
# Inspired by: https://www.tensorflow.org/tutorials/text/transformer
# Playing with attention

In [189]:
import jax
from jax import random
import jax.numpy as np
from jax.experimental.stax import Dense, serial, Relu, BatchNorm, Dropout
from jax.nn.initializers import uniform
from jax.ops import index_update
# Current convention is to import original numpy as "onp"
import numpy as onp

In [3]:
# Generate key which is used to generate random numbers
rng = random.PRNGKey(1)

In [76]:
def attention(k,q,v):
    """
    """
    kT=np.swapaxes(k,-1,-2)
    #print("q =",q.shape)
    #print("kT=",kT.shape)
    qk=np.matmul(q,kT)
    #print("qk=",qk.shape)
    dk=k.shape[-1]
    scalled_attention_logits=qk/np.sqrt(dk)
    attention_weights=jax.nn.softmax(scalled_attention_logits)
    #print("attention_weights=",attention_weights.shape)
    output=np.matmul(attention_weights,v)
    return output, attention_weights
def attention_scaled_dot_product(k,q,v,mask=None):
    """
    """
    kT=np.swapaxes(k,-1,-2)
    #print("q =",q.shape)
    #print("kT=",kT.shape)
    qk=np.matmul(q,kT)
    #print("qk=",qk.shape)
    dk=k.shape[-1]
    scalled_attention_logits=qk/np.sqrt(dk)
    #TODO: apply the mask
    attention_weights=jax.nn.softmax(scalled_attention_logits)
    #print("attention_weights=",attention_weights.shape)
    output=np.matmul(attention_weights,v)
    return output, attention_weights

In [77]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.2f}".format(x)})

In [78]:
temp_k = np.array([[10, 0, 0],
                      [0, 10, 0],
                      [0, 0, 10],
                      [0, 0, 10]])  # (4, 3)

temp_v = np.array([[1, 0],
                      [10, 0],
                      [100, 5],
                      [1000, 6]])  # (4, 2)

# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = np.array([[0, 10, 0]])  # (1, 3)
o,a=attention_scaled_dot_product(temp_k,temp_q,temp_v)
print("a=",a)
print("o=",o)

a= [[0.00 1.00 0.00 0.00]]
o= [[10.00 0.00]]


In [79]:
temp_q = np.array([[0, 0, 10]])  # (1, 3)
o,a=attention_scaled_dot_product(temp_k,temp_q,temp_v)
print("a=",a)
print("o=",o)

a= [[0.00 0.00 0.50 0.50]]
o= [[550.00 5.50]]


In [80]:
temp_q = np.array([[10, 10, 0]])  # (1, 3)
o,a=attention_scaled_dot_product(temp_k,temp_q,temp_v)
print("a=",a)
print("o=",o)

a= [[0.50 0.50 0.00 0.00]]
o= [[5.50 0.00]]


In [81]:
temp_q = np.array([[0, 0, 10],
                      [0, 10, 0],
                      [10, 10, 0]])  # (3, 3)
o,a=attention_scaled_dot_product(temp_k,temp_q,temp_v)
print("a=")
print(a)
print("o=")
print(o)

a=
[[0.00 0.00 0.50 0.50]
 [0.00 1.00 0.00 0.00]
 [0.50 0.50 0.00 0.00]]
o=
[[550.00 5.50]
 [10.00 0.00]
 [5.50 0.00]]


In [172]:
def MultiHeadLayer(embed_dim, num_heads=8):
    assert embed_dim % num_heads==0
    depth=embed_dim//num_heads
    init_wk ,wk =Dense(embed_dim)
    init_wq ,wq =Dense(embed_dim)
    init_wv ,wv =Dense(embed_dim)
    init_lin,lin=Dense(embed_dim)
    
    def init_fun(rng,input_shape):
        rng_k,rng_q,rng_v,rng_lin=random.split(rng,4)
        shape, param_k   =init_wk (rng_k  ,input_shape)
        shape, param_q   =init_wq (rng_q  ,input_shape)
        shape, param_v   =init_wv (rng_v  ,input_shape)
        shape, param_lin =init_lin(rng_lin,input_shape)
        return shape,(param_k,param_q,param_v,param_lin)

    def split_heads(x,batch_size):
        x=np.reshape(x,(batch_size,-1,num_heads,depth))
        x=np.transpose(x,(0,2,1,3))
        return x
    
    def apply_fun(params,x,mask=None):
        batch_size=x[0].shape[0]
        k,q,v=x
        param_k,param_q,param_v,param_lin=params
        k=wk(param_k,k)
        q=wq(param_q,q)
        v=wv(param_v,v)
        #print("k=",k.shape)
        #print("q=",q.shape)
        #print("v=",v.shape)
        k=split_heads(k,batch_size)
        q=split_heads(q,batch_size)
        v=split_heads(v,batch_size)
        #print("k=",k.shape)
        #print("q=",q.shape)
        #print("v=",v.shape)
        a,aw=attention_scaled_dot_product(k,q,v)
        sa=np.transpose(a,(0,2,1,3))
        ca=np.reshape(sa,(batch_size,-1,embed_dim))
        output=lin(param_lin,ca)
        return output,aw
    
    return init_fun,apply_fun
        

In [173]:
in_shape = (1,60,512)
init_mha,mha=MultiHeadLayer(in_shape[2])
output_shape, params = init_mha(rng, in_shape)
fake_data=random.uniform(rng,in_shape)
print("output_shape=",output_shape)
o,aw=mha(params,(fake_data,fake_data,fake_data))
print("o =",o.shape)
print("aw=",aw.shape)

output_shape= (1, 60, 512)
o = (1, 60, 512)
aw= (1, 8, 60, 60)


In [174]:
#Normalizing ..... to implement layer normalization
temp_a = np.array([[[0, 0, 10],
                   [0, 10, 0],
                   [0, 10, 0],
                   [10, 10, 0]],
                  [[0, 0, 5],
                   [0, 5, 0],
                   [0, 6, 0],
                   [5, 5, 0]]])
print("temp_a=",temp_a.shape)
print(jax.nn.normalize(temp_a,axis=0))
print(jax.nn.normalize(temp_a,axis=1))
print(jax.nn.normalize(temp_a,axis=2))
print(jax.nn.normalize(temp_a,axis=(1,2)))

temp_a= (2, 4, 3)
[[[0.00 0.00 1.00]
  [0.00 1.00 0.00]
  [0.00 1.00 0.00]
  [1.00 1.00 0.00]]

 [[0.00 0.00 -1.00]
  [0.00 -1.00 0.00]
  [0.00 -1.00 0.00]
  [-1.00 -1.00 0.00]]]
[[[-0.58 -1.73 1.73]
  [-0.58 0.58 -0.58]
  [-0.58 0.58 -0.58]
  [1.73 0.58 -0.58]]

 [[-0.58 -1.71 1.73]
  [-0.58 0.43 -0.58]
  [-0.58 0.85 -0.58]
  [1.73 0.43 -0.58]]]
[[[-0.71 -0.71 1.41]
  [-0.71 1.41 -0.71]
  [-0.71 1.41 -0.71]
  [0.71 0.71 -1.41]]

 [[-0.71 -0.71 1.41]
  [-0.71 1.41 -0.71]
  [-0.71 1.41 -0.71]
  [0.71 0.71 -1.41]]]
[[[-0.85 -0.85 1.18]
  [-0.85 1.18 -0.85]
  [-0.85 1.18 -0.85]
  [1.18 1.18 -0.85]]

 [[-0.84 -0.84 1.10]
  [-0.84 1.10 -0.84]
  [-0.84 1.49 -0.84]
  [1.10 1.10 -0.84]]]


In [175]:
def PointWiseFeedForwardNetwork(embeb_dim,dff):
    return serial(Dense(dff),Relu,Dense(embeb_dim))

In [176]:
in_shape = (64,50,512)
init_dff,dff=PointWiseFeedForwardNetwork(512,2048)
output_shape, params = init_dff(rng, in_shape)
fake_data=random.uniform(rng,in_shape)
print("output_shape=",output_shape)
o=dff(params,fake_data)
print("o =",o.shape)

output_shape= (64, 50, 512)
o = (64, 50, 512)


In [177]:
def EncoderLayer(embed_dim,num_heads,dff,rate=0.1):
    init_mha,mha=MultiHeadLayer(embed_dim,num_heads)
    init_ffn,ffn=PointWiseFeedForwardNetwork(embed_dim,dff)
    init_ln1,ln1=BatchNorm()
    init_ln2,ln2=BatchNorm()
    init_do1,do1=Dropout(rate)
    init_do2,do2=Dropout(rate)
    def init_fun(rnd,input_shape):
        rng_mha,rng_ffn,rng_ln1,rng_ln2,rng_do1,rng_do2=random.split(rng,6)
        shape, param_mha   =init_mha (rng_mha  ,input_shape)
        shape, param_ffn   =init_ffn (rng_ffn  ,input_shape)
        shape, param_ln1   =init_ln1 (rng_ln1  ,input_shape)
        shape, param_ln2   =init_ln2 (rng_ln2  ,input_shape)
        shape, param_do1   =init_do1 (rng_do1  ,input_shape)
        shape, param_do2   =init_do2 (rng_do2  ,input_shape)
        return in_shape,(param_mha,param_ffn,param_ln1,param_ln2,param_do1,param_do2)
    def apply_fun(params,x,rng,training=True,mask=None):
        param_mha,param_ffn,param_ln1,param_ln2,param_do1,param_do2=params
        attn_output, _=mha(param_mha,(x,x,x),mask)
        attn_output=do1(param_do1,attn_output,rng=rng,mode=training)
        #print("x=",x.shape)
        #print("attn_output",attn_output.shape)
        out1=ln1(param_ln1,x+attn_output)
        
        ffn_output=ffn(param_ffn,out1)
        ffn_output=do2(param_do2,ffn_output,rng=rng,mode=training)
        out2=ln2(param_ln2,out1+ffn_output)
        return out2
    return init_fun,apply_fun

In [183]:
in_shape = (64,43,512)
init_enl,enl=EncoderLayer(512,8,2048)
output_shape, params = init_enl(rng, in_shape)
fake_data=random.uniform(rng,in_shape)
print("output_shape=",output_shape)
oen=enl(params,fake_data,rng)
print("oen=",oen.shape)

output_shape= (64, 43, 512)
oen= (64, 43, 512)


In [185]:
def DecoderLayer(embed_dim,num_heads,dff,rate=0.1):
    init_ma1,ma1=MultiHeadLayer(embed_dim,num_heads)
    init_ma2,ma2=MultiHeadLayer(embed_dim,num_heads)
    init_ffn,ffn=PointWiseFeedForwardNetwork(embed_dim,dff)
    init_ln1,ln1=BatchNorm()
    init_ln2,ln2=BatchNorm()
    init_ln3,ln3=BatchNorm()
    init_do1,do1=Dropout(rate)
    init_do2,do2=Dropout(rate)
    init_do3,do3=Dropout(rate)
    def init_fun(rnd,input_shape):
        rng_ma1,rng_ma2,rng_ffn,rng_ln1,rng_ln2,rng_ln3,rng_do1,rng_do2,rng_do3=random.split(rng,9)
        shape, param_ma1   =init_ma1 (rng_ma1  ,input_shape)
        shape, param_ma2   =init_ma2 (rng_ma2  ,input_shape)
        shape, param_ffn   =init_ffn (rng_ffn  ,input_shape)
        shape, param_ln1   =init_ln1 (rng_ln1  ,input_shape)
        shape, param_ln2   =init_ln2 (rng_ln2  ,input_shape)
        shape, param_ln3   =init_ln3 (rng_ln3  ,input_shape)
        shape, param_do1   =init_do1 (rng_do1  ,input_shape)
        shape, param_do2   =init_do2 (rng_do2  ,input_shape)
        shape, param_do3   =init_do3 (rng_do3  ,input_shape)
        return in_shape,(param_ma1,param_ma2,param_ffn,param_ln1,param_ln2,param_ln3,param_do1,param_do2,param_do3)
    def apply_fun(params,x,enc_output,rng,training=True,mask_look_ahead=None,mask_padding=None):
        param_ma1,param_ma2,param_ffn,param_ln1,param_ln2,param_ln3,param_do1,param_do2,param_do3=params
        att1, att1_w=ma1(param_ma1,(x,x,x),mask_look_ahead)
        att1        =do1(param_do1,att1,rng=rng,mode=training)
        out1        =ln1(param_ln1,x+att1)
        
        att2, att2_w=ma2(param_ma2,(enc_output,enc_output,out1),mask_padding)
        att2        =do2(param_do2,att2,rng=rng,mode=training)
        out2        =ln2(param_ln2,out1+att2)
        
        ffn_output=ffn(param_ffn,out2)
        ffn_output=do3(param_do3,ffn_output,rng=rng,mode=training)
        out3=ln3(param_ln3,out2+ffn_output)
        return out2, att1_w, att2_w
    return init_fun,apply_fun

In [186]:
in_shape = (64,43,512)
init_dcl,dcl=DecoderLayer(512,8,2048)
output_shape, params = init_dcl(rng, in_shape)
fake_data=random.uniform(rng,in_shape)
print("output_shape=",output_shape)
o,_,_=dcl(params,fake_data,oen,rng)
print("o =",o.shape)

output_shape= (64, 43, 512)
o = (64, 43, 512)


In [192]:
# I don't know if this works fine
# from: https://github.com/google/jax/pull/2157/files
def Embedding(vocab_size,
              embedding_size,
              padding_idx=None,
              embedding_init=uniform()):
  """Layer construction function for an embedding layer."""

  def init_fun(rng, input_shape):
    embedding_shape = (vocab_size, embedding_size)
    embedding_table = embedding_init(rng, embedding_shape)
    if padding_idx is not None:
      embedding_table = index_update(embedding_table, padding_idx, 0.)
    output_shape = input_shape + (embedding_size,)
    return output_shape, (embedding_table,)

  def apply_fun(params, inputs, **kwargs):
    embedding_table = params[0]
    return embedding_table[inputs]

  return init_fun, apply_fun



In [204]:
def Encoder(num_layers,embed_dim,num_heads,dff,input_vocab_size,maximum_position_encoding,rate=0.1):
    init_emb,emb=Embedding(input_vocab_size,embed_dim)
    #init_pos,pos=PositionalEncoding(maximum_position_encoding,embed_dim)
    enc_layers=[EncoderLayer(embed_dim,num_heads,dff,rate) for _ in range(num_layers)]
    init_dou,dou=Dropout(rate)
    def init_fun(rng,input_shape):
        rng_emb,rng_pos,rng_dou=random.split(rng,3)
        shape, param_emb=init_emb(rng_emb,input_shape)
        #shape, param_pos=init_pos(rng_pos,input_shape)
        param_pos=None
        shape, param_dou=init_dou(rng_dou,input_shape)
        params_enc=[]
        for init_enc,_ in enc_layers:
            rng_enc,rng=random.split(rng)
            shape, param_enc=init_enc(rng_enc,input_shape)
            params_enc.append(param_enc)
        return input_shape,(param_emb,param_pos,params_enc,param_dou)
    def apply_fun(params,x,rng,training,mask=None):
        seq_len=x.shape[1]
        param_emb,param_pos,params_enc,param_dou=params
        x =emb(param_emb,x)
        x*=np.sqrt(embed_dim)
        #x+=pos(param_pos,x)#?????????????????????????????????
        x=dou(param_dou,x,rng,mode=training)
        for i,_,enc in enumerate(enc_layers):
            x=enc(params_enc[i],x,training,mask)
        return x
    return init_fun,apply_fun

In [206]:
in_shape = (64,62)
init_enc,enc=Encoder(2,512,8,2048,8500,10000)
output_shape, params = init_enc(rng, in_shape)
fake_data=random.uniform(rng,in_shape,dtype=np.int64)
print("output_shape=",output_shape)
o,_,_=enc(params,fake_data,oen,rng)
print("o =",o.shape)

ValueError: dtype argument to `uniform` must be a float dtype, got <class 'jax._src.numpy.lax_numpy.int64'>