In [32]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from keras.layers import Conv2D,Input,Flatten,Layer,Dense,Reshape,Conv2DTranspose,Lambda,LayerNormalization,Embedding,MultiHeadAttention,Flatten,Dropout
from keras.models import Sequential,Model


In [33]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [39]:
x_train = tf.expand_dims(x_train,-1)
x_test = tf.expand_dims(x_test,-1)

In [34]:
from google.colab import patches
class GeneratePatches(Layer):
  def __init__(self,patch_shape):
    super().__init__()
    self.patch_shape = patch_shape
  def extract_patches(self,images,patch_size,stride=None):
    if(stride):
      raise NotImplementedError("not implemented")
    else:
      # keep stride equal to patch size
      if(len(images.shape)==3):
        images = images.expand_dims(0)

      B,H,W,C = images.shape
      h,w = patch_size
      Patches = []

      for y in range(H//h):
        y1 = y*h 
        y2 = (y+1)*h
        for x in range(W//w):
          x1 = x*w
          x2 = (x+1)*w       
          
          Patches.append(images[:,y1:y2,x1:x2,:])
      Patches = tf.stack(Patches,1)
      return Patches

  def call(self,images):

    bz,H,W,dims = images.shape
    patches = self.extract_patches(images,self.patch_shape)
    #patches = tf.reshape(patches,(bz,-1,ps[0]*ps[1]*dims))
    
    return patches

    
    

In [35]:
"""PatchGen = GeneratePatches((16,16))
op = PatchGen(x_test).numpy()
from google.colab.patches import cv2_imshow
print(x_test.shape,op.shape,op.dtype)

for i in range(5):
  img = op[i]
  cv2_imshow(x_test[i])
  for j in range(img.shape[0]):
    cv2_imshow(img[j].reshape(16,16,3))
"""

'PatchGen = GeneratePatches((16,16))\nop = PatchGen(x_test).numpy()\nfrom google.colab.patches import cv2_imshow\nprint(x_test.shape,op.shape,op.dtype)\n\nfor i in range(5):\n  img = op[i]\n  cv2_imshow(x_test[i])\n  for j in range(img.shape[0]):\n    cv2_imshow(img[j].reshape(16,16,3))\n'

In [36]:
def MultiLAyerPerceptron(inDim,Dim):

  MLP = Sequential()
  MLP.add(Dense(units=inDim, activation=tf.nn.gelu))
  MLP.add(Dropout(rate=0.1)) # dropout rate is from original paper,
  MLP.add(Dense(units=Dim, activation=tf.nn.gelu))
  MLP.add(Dropout(rate=0.1))

  return MLP

In [37]:
class ENCODER(Layer):
  def __init__(self,**HP): # HyperParameter : HP

     super().__init__()

     self.LN1 = LayerNormalization()
     self.LN2 = LayerNormalization()
     self.MLP = MultiLAyerPerceptron(HP["dim"]*3,HP["dim"])
     self.Attn = MultiHeadAttention(num_heads=HP["heads"], key_dim=HP["dim"], dropout=0.1)


  def call(self,X):

    y = self.LN1(X)
    y = self.Attn(y,y)

    y1 = y + X

    y = self.LN2(y1)
    y = self.MLP(y)

    y = y + y1

    return y

In [41]:
patch_size = (4,4)
_,H,W,C = x_test.shape
no_patches = H//patch_size[0] * W//patch_size[1]

HP = {
    
    "dim":128,
    "heads":4,
    "no_blocks":6,
    "patch_size":patch_size,
    "no_patches":no_patches,
    "classes":10,
    "color_channel":C

}

In [42]:
#input_sequence =  tf.range(start=0, limit=HP["no_patches"], delta=1)# The embeddings to be passed for all positions in the sequence : which is range(of 0 to len(Seq))
Transformer = Sequential([ENCODER(**HP) for _ in range(HP["no_blocks"])])

In [43]:

inputs = Input(shape = (HP["no_patches"],HP["patch_size"][0],HP["patch_size"][1],HP["color_channel"]),name="patches")
input_sequence  = Input(shape = (HP["no_patches"]),name="position")

patch = Conv2D(filters=HP["dim"], kernel_size=HP["patch_size"], strides=HP["patch_size"], padding='valid')(inputs)
patch = Reshape(target_shape = (HP["no_patches"],HP["dim"]))(patch)
pos_emb = Embedding(input_dim=HP["no_patches"], output_dim=HP["dim"])(input_sequence)
print(patch.shape,pos_emb.shape,inputs.shape,input_sequence.shape)

y = patch + pos_emb

for _ in range(HP["no_blocks"]):
   y = ENCODER(**HP)(y)

y = LayerNormalization(epsilon=1e-6)(y)
y = Dropout(0.25)(y)
y = y[:,0]

y = MultiLAyerPerceptron(HP["dim"]*3,HP["dim"])(y)
y = Dense(HP["classes"],activation="softmax")(y)

model =Model(inputs = [inputs,input_sequence], outputs = y)


(None, 49, 128) (None, 49, 128) (None, 49, 4, 4, 1) (None, 49)


In [44]:

def get_input_sequence(shape):
  
  input_sequence =  tf.range(start=0, limit=HP["no_patches"], delta=1)# The embeddings to be passed for all positions in the sequence : which is range(of 0 to len(Seq))
  input_sequence = tf.broadcast_to(input_sequence, [shape[0], HP["no_patches"]])

  return input_sequence


In [45]:
PatchGen = GeneratePatches(HP["patch_size"])


In [46]:
Data = {"patches": PatchGen(x_train),"position":get_input_sequence(x_train.shape)}
Data_val = {"patches": PatchGen(x_test),"position":get_input_sequence(x_test.shape)}

for v in Data.values():
  print(v.shape)

(60000, 49, 4, 4, 1)
(60000, 49)


In [47]:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

model.fit(Data,y_train,batch_size=1024,epochs=70)

Epoch 1/70
Epoch 2/70
Epoch 3/70
Epoch 4/70
Epoch 5/70
Epoch 6/70
Epoch 7/70
Epoch 8/70
Epoch 9/70
Epoch 10/70
Epoch 11/70
Epoch 12/70
Epoch 13/70
Epoch 14/70
Epoch 15/70
Epoch 16/70
Epoch 17/70
Epoch 18/70
Epoch 19/70
Epoch 20/70
Epoch 21/70
Epoch 22/70
Epoch 23/70
Epoch 24/70
Epoch 25/70
Epoch 26/70
Epoch 27/70
Epoch 28/70
Epoch 29/70
Epoch 30/70
Epoch 31/70
Epoch 32/70
13/59 [=====>........................] - ETA: 27s - loss: 0.0623 - sparse_categorical_accuracy: 0.9796

KeyboardInterrupt: ignored

In [48]:
model.evaluate(Data_val,y_test,batch_size=128)



[0.2373482882976532, 0.9412999749183655]