In [2]:
import keras
import tensorflow as tf
from datasets import load_dataset
import numpy as np
from keras.layers import Dense,Dropout,LayerNormalization,Embedding,TextVectorization

In [123]:
class MHA(keras.layers.Layer):
    def __init__(self,d_model,d_head):
        super(MHA,self).__init__()
        self.d_model=d_model
        self.d_head=d_head
        self.num_heads=int(self.d_model/self.d_head)
        assert d_model==self.num_heads*self.d_head
        self.q=Dense(self.d_model)
        self.k=Dense(self.d_model)
        self.v=Dense(self.d_model)
        self.h=Dense(self.d_model)
    def scaled_dot_product(self,q,k,v,mask=None):
        a_score=tf.matmul(q,k,transpose_b=True)/np.sqrt(self.d_head)
        if mask is not None:
            a_score=tf.where(mask==0,-np.inf,a_score)
        a_weights=tf.nn.softmax(a_score)
        a_values=tf.matmul(a_weights,v)
        return a_values,a_weights
    def split_heads(self,x):
        x_reshaped=tf.reshape(x,[x.shape[0],-1,self.num_heads,self.d_head])
        x_trans=tf.transpose(x_reshaped,[0,2,1,3])
        return x_trans
    def merge_heads(self,x):
        x_trans=tf.transpose(x,[0,2,1,3])
        x_merge=tf.reshape(x_trans,[x.shape[0],-1,self.d_model])
        return x_merge
    def call(self,input1,input2,input3,mask=None):
        q=self.q(input1)
        k=self.k(input2)
        v=self.v(input3)
        qs,ks,vs=self.split_heads(q),self.split_heads(k),self.split_heads(v)
        att_values,att_weights=self.scaled_dot_product(qs,ks,vs,mask)
        merge_att_values=self.merge_heads(att_values)
        return self.h(merge_att_values),att_weights

In [124]:
class Encoder(keras.layers.Layer):
    def __init__(self,d_model,d_head,seq_len,vocab_size,n_encoders,d_rate=0.1):
        super(Encoder,self).__init__()
        self.d_model=d_model
        self.seq_len=seq_len
        self.n_encoders=n_encoders
        self.att=MHA(d_model,d_head)
        self.drop=Dropout(d_rate)
        self.norm=LayerNormalization()
        self.token_emd=Embedding(vocab_size,d_model)
        self.pos_emd=Embedding(seq_len,d_model)
    def feed_forward(self,emd_inputs):
        feed=keras.Sequential([Dense(2*d_model,activation='relu'),Dense(d_model)])
        return feed(emd_inputs)
    def pos_indices(self,batch_size):
        pos=np.arange(self.seq_len)
        pos_batch=np.resize(pos,batch_size*self.seq_len)
        pos_reshape=np.reshape(pos_batch,(batch_size,self.seq_len))
        return pos_reshape
    def encoder_output(self,inputs,training=True,mask=None):
        a_val,a_wei=self.att(inputs,inputs,inputs,mask=mask)
        res_val=self.drop(a_val,training=training)
        norm_val=self.norm(res_val+inputs)
        feed_val=self.feed_forward(norm_val)
        res_val=self.drop(feed_val,training=training)
        norm_val=self.norm(norm_val+res_val)
        return norm_val,a_wei
    def call(self,inputs,training=True):
        mask=tf.cast(tf.math.not_equal(inputs,0),tf.float32)
        mask=mask[:,tf.newaxis,tf.newaxis,:]
        tok_emd=self.token_emd(inputs)
        batch_size=inputs.shape[0]
        pos_ind=self.pos_indices(batch_size)
        pos_emd=self.pos_emd(pos_ind)
        inp_emd=self.drop(tok_emd+pos_emd,training=training)
        for i in range(self.n_encoders):
            inp_emb,weights=self.encoder_output(inp_emd,training=training,mask=mask)
        return inp_emb,weights

In [125]:
class Decoder(keras.layers.Layer):
    def __init__(self, d_model, d_head, seq_len, vocab_size, n_decoders, d_rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.n_decoders = n_decoders
        self.att1 = MHA(d_model, d_head)
        self.att2 = MHA(d_model, d_head)
        self.drop = Dropout(d_rate)
        self.norm = LayerNormalization()
        self.token_emd = Embedding(vocab_size, d_model)
        self.pos_emd = Embedding(seq_len, d_model)

    def feed_forward(self, emd_inputs):
        feed = keras.Sequential([Dense(2 * self.d_model, activation='relu'), Dense(self.d_model)])
        return feed(emd_inputs)

    def pos_indices(self, batch_size):
        pos = np.arange(self.seq_len)
        pos_batch = np.resize(pos, batch_size * self.seq_len)
        pos_reshape = np.reshape(pos_batch, (batch_size, self.seq_len))
        return pos_reshape

    def decoder_output(self, encoder_output, target, training=True, mask1=None, mask2=None):
        a_val, a_wei = self.att1(target, target, target, mask=mask1)
        res_val = self.drop(a_val, training=training)
        norm_val = self.norm(res_val + target)
        a_val, a_wei = self.att2(norm_val, encoder_output, encoder_output, mask=mask2)
        res_val = self.drop(a_val, training=training)
        norm_val = self.norm(norm_val + res_val)
        feed_val = self.feed_forward(norm_val)
        res_val = self.drop(feed_val, training=training)
        norm_val = self.norm(norm_val + res_val)
        return norm_val, a_wei

    def call(self, enc_inp, target, training=True):
        mask1 = tf.cast(tf.math.not_equal(target, 0), tf.float32)
        mask2 = mask1[:, tf.newaxis, tf.newaxis, :]
        l_tri = tf.linalg.band_part(tf.ones((self.seq_len, self.seq_len)), -1, 0)
        mask1 = tf.minimum(mask2, l_tri)
        tok_emd = self.token_emd(target)
        batch_size = target.shape[0]
        pos_ind = self.pos_indices(batch_size)
        pos_emd = self.pos_emd(pos_ind)
        inp_emd = self.drop(tok_emd + pos_emd, training=training)
        for i in range(self.n_decoders):
            inp_emb, weights = self.decoder_output(enc_inp, inp_emd, training=training, mask1=mask1, mask2=mask2)
        return inp_emb, weights


In [126]:
class Transformer(keras.models.Model):
    def __init__(self,num_blocks,d_model,d_head,seq_len_inp,seq_len_tar,vocab_size_inp,vocab_size_tar,dropout_rate=0.1):
        super(Transformer,self).__init__()
        self.encoder=Encoder(d_model,d_head,seq_len_inp,vocab_size_inp,num_blocks,d_rate=dropout_rate)
        self.decoder=Decoder(d_model,d_head,seq_len_tar,vocab_size_tar,num_blocks,d_rate=dropout_rate)
        self.classifier=Dense(vocab_size_tar)
    def call(self,inps,training=True):
        (inp_seqs,tar_seqs)=inps
        inp_att_val,inp_att_weights=self.encoder(inp_seqs,training=training)
        tar_att_val,tar_att_weights=self.decoder(inp_att_val,tar_seqs,training=training)
        out=self.classifier(tar_att_val)
        return out,inp_att_weights,tar_att_weights

In [7]:
ds=load_dataset('tatoeba',lang1='en',lang2='te')

In [8]:
train_ds=ds['train'][:]
source=[]
tar=[]
for t in train_ds['translation']:
    source.append(t['en'])
    tar.append(t['te'])
source[:3],tar[:3]

(["I don't speak Japanese.",
  'What will you have?',
  'Tell me about your daily life.'],
 ['నేను జపనీస్ మాట్లాడను',
  'నువ్వు ఏమి తీసుకుంటావ్?',
  'నీ రోజువారీ జీవితం గురించి చెప్పు'])

In [9]:
pre_tar=['<sos> '+t+' <eos>'for t in tar]

In [10]:
vec_en=TextVectorization()
vec_tel=TextVectorization()
vec_tel.adapt(pre_tar)
vec_en.adapt(source)
source_inp=vec_en(source)
tar_inp=vec_tel(pre_tar)





In [11]:
s_vocab=vec_en.vocabulary_size()
t_vocab=vec_tel.vocabulary_size()
print(s_vocab)
print(t_vocab)

530
634


In [12]:
print(len(' '.join(source).split(' ')))
print(len(' '.join(tar).split(' ')))

1402
1086


In [13]:
len(source)

262

In [58]:
z=tf.zeros((tar_inp.shape[0],4),dtype=tf.int64)
tar_inps=tf.concat([tar_inp,z],axis=1)

In [59]:
data=tf.data.Dataset.from_tensor_slices((source_inp,tar_inps)).shuffle(len(source)).batch(8)

In [60]:
data_pre=data.map(lambda c,t:((c,t[:,:-1]),t[:,1:]),tf.data.AUTOTUNE)

In [61]:
for c,t in data_pre.take(1):
    print(c[0][:2],c[1][:2])
    print(t[:2])

tf.Tensor(
[[  6 100   2  78 278   0   0   0   0   0   0   0   0   0   0   0   0]
 [ 76   7  20   0   0   0   0   0   0   0   0   0   0   0   0   0   0]], shape=(2, 17), dtype=int64) tf.Tensor(
[[  2  12 306 201   3   0   0   0   0   0   0   0   0   0   0   0   0]
 [  2  32 560   3   0   0   0   0   0   0   0   0   0   0   0   0   0]], shape=(2, 17), dtype=int64)
tf.Tensor(
[[ 12 306 201   3   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [ 32 560   3   0   0   0   0   0   0   0   0   0   0   0   0   0   0]], shape=(2, 17), dtype=int64)


In [127]:
num_blocks,d_model,d_head,seq_len_inp,seq_len_out,vocab_size_inp,vocab_size_tar=4,128,32,17,17,s_vocab,t_vocab

In [128]:
trf=Transformer(num_blocks,d_model,d_head,seq_len_inp,seq_len_out,vocab_size_inp,vocab_size_tar)

In [129]:
o,w1,w2=trf(c)
print(o.shape)
print(w1.shape)
print(w2.shape)

(8, 17, 634)
(8, 4, 17, 17)
(8, 4, 17, 17)


In [130]:
trf.compile('adam',loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True))

In [143]:
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True)
with tf.GradientTape(persistent=True) as tape:
    logi,w1,w2=trf(c)
    l=loss(t,o)  
tape.gradient(l,trf.trainable_weights)

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]