In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from preprocessing.data.imagenet_labels import imagenet_labels

In [2]:
# GPU memory: let TensorFlow allocate GPU memory on demand instead of reserving
# the whole device up front. This is the TF2 replacement for the TF1-style
# ConfigProto(gpu_options.allow_growth=True) + InteractiveSession pattern.
# It must run before any op initialises the GPUs; TensorFlow raises
# RuntimeError if the devices are already initialised (e.g. on a cell re-run).
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Memory growth not changed for {gpu.name}: {err}")

# Seed TensorFlow's global RNG so the randomly initialised weights below
# (tf.random.normal variables, Keras layer initialisers) are reproducible.
SEED = 42
tf.random.set_seed(SEED)

In [3]:
def decode_fn(record_bytes):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    Args:
        record_bytes: scalar string tensor holding a serialized example with
            an ``"image"`` bytes feature and a ``"label"`` int64 feature.

    Returns:
        Tuple ``(image, label)``: ``image`` is the tensor recovered by
        ``tf.io.parse_tensor`` with dtype ``tf.uint8``; ``label`` is the
        scalar ``tf.int64`` label feature.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], dtype=tf.string),
        "label": tf.io.FixedLenFeature([], dtype=tf.int64),
    }
    parsed = tf.io.parse_single_example(record_bytes, feature_spec)
    # The "image" bytes are themselves a serialized TensorProto, so they need
    # a second parsing step to become a uint8 tensor.
    image = tf.io.parse_tensor(parsed["image"], out_type=tf.uint8)
    return image, parsed["label"]

In [4]:
# Training TFRecord shards. os.listdir returns entries in arbitrary order, so
# sort them to make the record order (and the sample batch below) deterministic.
path = './preprocessing/data/TFRs/train/'
files = sorted(os.path.join(path, name) for name in os.listdir(path))
ds = tf.data.TFRecordDataset(files).map(decode_fn)

In [5]:
# Take the first batch of 5 decoded examples and print their class names.
# `ex` is bound explicitly (rather than leaking out of a for-loop) because the
# next cell uses it as the sample input for the model prototype.
_ds = ds.batch(5)
ex = next(iter(_ds))
for label_id in ex[1].numpy():
    print(imagenet_labels[label_id])

beaver
pitcher, ewer
jackfruit, jak, jack
jay
dock, dockage, docking facility


In [6]:
# Rescale the uint8 sample batch to float32 values in [0, 1] for the conv stem.
images_uint8 = ex[0]
print(images_uint8.shape)
x = tf.cast(images_uint8, dtype=tf.float32) * (1. / 255)

(5, 128, 128, 3)


In [7]:
# Convolutional stem: four stride-2 'same' convolutions, each halving the
# spatial size (see the printed shapes).
# NOTE(review): no activation is applied between the convolutions, so the
# stack is a composition of affine maps — confirm this is intended.
conv1 = tf.keras.layers.Conv2D(16, (64, 64), strides=(2, 2), padding='same')(x)
print(conv1.shape)
conv2 = tf.keras.layers.Conv2D(32, (32, 32), strides=(2, 2), padding='same')(conv1)
print(conv2.shape)
conv3 = tf.keras.layers.Conv2D(64, (16, 16), strides=(2, 2), padding='same')(conv2)
print(conv3.shape)
conv4 = tf.keras.layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same')(conv3)
print(conv4.shape)

# Flatten the spatial grid into a sequence. Using -1 for the length (instead
# of hard-coding 8 * 8) keeps this valid for any input resolution.
reshape = tf.keras.layers.Reshape((-1, conv4.shape[-1]))(conv4)
print(reshape.shape)
# Swap the last two axes: (batch, positions, channels) -> (batch, channels, positions).
# NOTE(review): downstream cells therefore treat each channel as a token whose
# features are the spatial positions; ViT normally uses positions as tokens —
# confirm this is intended.
transposed = tf.keras.layers.Permute((2, 1))(reshape)
print(transposed.shape)


(5, 64, 64, 16)
(5, 32, 32, 32)
(5, 16, 16, 64)
(5, 8, 8, 128)
(5, 64, 128)
(5, 128, 64)


In [8]:
# Learnable [class] token that is prepended to the token sequence; its width
# matches the per-token feature size of `transposed`.
class_token = tf.Variable(
    tf.random.normal(shape=[1, 1, transposed.shape[-1]]),
    trainable=True,
    name="class_token",
    dtype=tf.float32,
)
print(class_token.shape)

(1, 1, 64)


In [10]:
# Repeat the class token once per example. Reading the batch size with
# tf.shape (instead of the static get_shape()[0]) also works when the batch
# dimension is unknown (None).
tiled_class_token = tf.tile(class_token, [tf.shape(transposed)[0], 1, 1])
print(tiled_class_token.shape)

(5, 1, 64)


In [11]:
# Prepend the class token to the token sequence along the sequence axis (axis 1).
sequence_parts = [tiled_class_token, transposed]
x_class = tf.concat(sequence_parts, axis=1)
print(x_class.shape)

(5, 129, 64)


In [12]:
# Learnable position embeddings: one row per token (class token included),
# shared across the batch.
E_pos = tf.Variable(
    tf.random.normal(shape=[1, x_class.shape[1], x_class.shape[2]]),
    trainable=True,
    name="E_pos",
    dtype=tf.float32,
)
# Tile to x_class's own (possibly dynamic) batch size instead of borrowing the
# static batch size of a different tensor.
tiled_E_pos = tf.tile(E_pos, [tf.shape(x_class)[0], 1, 1])
z_0 = x_class + tiled_E_pos

print(z_0.shape)

(5, 129, 64)


In [13]:
# Pre-attention layer normalisation over the feature (last) axis.
layer_norm_1 = tf.keras.layers.LayerNormalization(axis=-1)
LN = layer_norm_1(z_0)
print(LN.shape)

(5, 129, 64)


In [21]:
# Query/key/value projection weights of width d_head = d_model // h, stored
# side by side in one (1, d_model, 3 * d_head) matrix and sliced apart below.
# NOTE(review): only a single head of width d_head is computed here; the h
# parallel heads of multi-head self-attention are not materialised.
h = 8
d_model = LN.shape[-1]
assert d_model % h == 0, f"d_model={d_model} must be divisible by h={h}"
d_head = d_model // h

U_qkv = tf.Variable(
    tf.random.normal(shape=[1, d_model, 3 * d_head]),
    trainable=True,
    name="U_qkv",
    dtype=tf.float32,
)
tiled_U_qkv = tf.tile(U_qkv, [tf.shape(LN)[0], 1, 1])
print(tiled_U_qkv.shape)
U_query = tiled_U_qkv[:, :, 0 * d_head:1 * d_head]
U_key   = tiled_U_qkv[:, :, 1 * d_head:2 * d_head]
U_value = tiled_U_qkv[:, :, 2 * d_head:3 * d_head]
print(U_query.shape)
print(U_key.shape)
print(U_value.shape)

(5, 64, 24)
(5, 64, 8)
(5, 64, 8)
(5, 64, 8)


In [23]:
# Project the normalised tokens into query/key/value spaces.
Query = tf.matmul(LN, U_query)
Key   = tf.matmul(LN, U_key)
Value = tf.matmul(LN, U_value)
print(Query.shape)
print(Key.shape)
print(Value.shape)

# Scaled dot-product attention. Keras' Attention layer uses raw dot products
# (its `use_scale` factor is a learnable scalar initialised to 1), so the
# 1/sqrt(d_k) factor is applied to the queries explicitly.
# Keras expects the inputs in the order [query, value, key].
d_k = Key.shape[-1]
SA = tf.keras.layers.Attention(use_scale=True)([Query * d_k ** -0.5, Value, Key])
print(SA.shape)

(5, 129, 8)
(5, 129, 8)
(5, 129, 8)
(5, 129, 8)


In [26]:
# Output projection from the attention width back to the model width.
U_msa = tf.Variable(
    tf.random.normal(shape=[1, SA.shape[-1], LN.shape[-1]]),
    trainable=True,
    name="U_msa",
    dtype=tf.float32,
)
tiled_U_msa = tf.tile(U_msa, [tf.shape(SA)[0], 1, 1])
print(tiled_U_msa.shape)

# Residual connection: in a pre-norm transformer block the residual adds the
# block input z_0, not its layer-normalised copy LN
# (z'_l = MSA(LN(z_{l-1})) + z_{l-1}).
MSA = tf.matmul(SA, tiled_U_msa) + z_0
print(MSA.shape)

(5, 8, 64)
(5, 129, 64)


In [44]:
# Pre-MLP layer normalisation of the attention-block output, over the last axis.
layer_norm_2 = tf.keras.layers.LayerNormalization(axis=-1)
LN_2 = layer_norm_2(MSA)
print(LN_2.shape)

(5, 129, 64)


In [52]:
# Transformer MLP, applied to every token independently: a Dense layer on a
# rank-3 tensor acts on the last (feature) axis only. Applying Dense to the
# flattened sequence instead would mix tokens and need two
# (129*64) x (129*64) weight matrices (~136M parameters).
# Hidden width uses the common 4x expansion of the model width.
mlp_hidden = 4 * LN_2.shape[-1]
F = tf.keras.layers.Dense(mlp_hidden,
                          activation=tf.keras.activations.gelu)(LN_2)

# Project back to the model width and add the residual. In a pre-norm block
# the residual is the block input MSA, not the normalised LN_2
# (z_l = MLP(LN(z'_l)) + z'_l).
FF = tf.keras.layers.Dense(LN_2.shape[-1])(F) + MSA
print(FF.shape)

(5, 129, 64)
