In [37]:
import tensorflow as tf
from tensorflow import keras
from keras import layers, models, optimizers
import numpy as np
import datetime

In [149]:
def datagen(tasks, sampl, dimsx, dimsy):
    x = tf.random.normal([tasks, sampl, dimsx])
    
    Wo = tf.Variable(tf.random.normal([tasks, dimsx, dimsy]))
    bo = tf.Variable(tf.random.normal([tasks, 1, dimsy]))
    y = tf.nn.sigmoid(tf.add(tf.matmul(x, Wo), bo))
    
    l = y[:, -1, :]
    y = tf.concat([y[:, :-1, :], tf.zeros(bo.shape)], axis=1)
    return tf.concat([x, y], axis=2), l

def datagen_layered(tasks, sampl, dimsx, dimsy):
    x = tf.random.normal([tasks, sampl, dimsx])
    W1 = tf.Variable(tf.random.normal([tasks, dimsx, 5]))
    b1 = tf.Variable(tf.random.normal([tasks, 1, 5]))
    x1 = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))

    W2 = tf.Variable(tf.random.normal([tasks, 5, 10]))
    b2 = tf.Variable(tf.random.normal([tasks, 1, 10]))
    x2 = tf.nn.relu(tf.add(tf.matmul(x1, W2), b2))

    W3 = tf.Variable(tf.random.normal([tasks, 10, 4]))
    b3 = tf.Variable(tf.random.normal([tasks, 1, 4]))
    x3 = tf.nn.relu(tf.add(tf.matmul(x2, W3), b3))
    
    Wo = tf.Variable(tf.random.normal([tasks, 4, dimsy]))
    bo = tf.Variable(tf.random.normal([tasks, 1, dimsy]))
    y = tf.nn.sigmoid(tf.add(tf.matmul(x3, Wo), bo))
    
    l = y[:, -1, :]
    y = tf.concat([y[:, :-1, :], tf.zeros(bo.shape)], axis=1)
    return tf.concat([x, y], axis=2), l

def datagen_wide(tasks, sampl, dimsx, dimsy):
    x = tf.random.normal([tasks, sampl, dimsx])
    W1 = tf.Variable(tf.random.normal([tasks, dimsx, 20]))
    b1 = tf.Variable(tf.random.normal([tasks, 1, 20]))
    x1 = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
    
    Wo = tf.Variable(tf.random.normal([tasks, 20, dimsy]))
    bo = tf.Variable(tf.random.normal([tasks, 1, dimsy]))
    y = tf.nn.sigmoid(tf.add(tf.matmul(x1, Wo), bo))
    
    l = y[:, -1, :]
    y = tf.concat([y[:, :-1, :], tf.zeros(bo.shape)], axis=1)
    return tf.concat([x, y], axis=2), l


By default, all Transformers have a key, value, and query size of 32, 8 heads, and 4 layers, and model size of NM = 256. The model size defines the dimensionality of each token, and the MLP between layers scales this size up to a hidden representation of 4 × NM where NM corresponds to the model size.

In [92]:
class BaseAttention(layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = layers.MultiHeadAttention(**kwargs)
    # self.layernorm = layers.LayerNormalization()
    self.add = layers.Add()

class GlobalSelfAttention(BaseAttention):
  def call(self, x):
    attn_output = self.mha(
        query=x,
        value=x,
        key=x)
    x = self.add([x, attn_output])
    # x = self.layernorm(x)
    return x

class FeedForward(layers.Layer):
  def __init__(self, d_model, dff):
    super().__init__()
    self.seq = models.Sequential([
      layers.Dense(dff, activation='relu'),
      layers.Dense(d_model)
    ])
    self.add = layers.Add()

  def call(self, x):
    x = self.add([x, self.seq(x)])
    return x

class TransformLayer(layers.Layer):
  def __init__(self,*, d_model, num_heads, dff):
    super().__init__()

    self.self_attention = GlobalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model)

    self.ffn = FeedForward(d_model, dff)

  def call(self, x):
    x = self.self_attention(x)
    x = self.ffn(x)
    return x

class Transformer(models.Model):
  def __init__(self, *, num_layers, d_model, num_heads, dff):
    super().__init__()

    self.layerstack = [TransformLayer(d_model=d_model, num_heads=num_heads, dff=dff) for _ in range(num_layers)]

  def call(self, x):
    for layer in self.layerstack:
      x = layer(x)
    x = tf.nn.sigmoid(x)
    return x[:, -1, -1]


In [93]:
sampl = 100
dx = 15
dy = 1

mdim = dx + dy
head = 2
mlay = 2
assert mdim % head == 0



In [42]:

%load_ext tensorboard


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [153]:

gpicl_standard = Transformer(num_layers=mlay, d_model=mdim, num_heads=head, dff=mdim*4)
gpicl_layered = Transformer(num_layers=mlay, d_model=mdim, num_heads=head, dff=mdim*4)
gpicl_wide = Transformer(num_layers=mlay, d_model=mdim, num_heads=head, dff=mdim*4)

class DisplayProgress(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print('Epoch {}/{}'.format(epoch + 1, eps), end='\r')

def loss_fn(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))


opt = optimizers.SGD(learning_rate=0.01)
gpicl_standard.compile(optimizer=opt, loss=loss_fn, metrics=['mae'])
gpicl_layered.compile(optimizer=opt, loss=loss_fn, metrics=['mae'])
gpicl_wide.compile(optimizer=opt, loss=loss_fn, metrics=['mae'])

eps = 200

train_s_standard, train_l_standard = datagen(400, sampl, dx, dy)
valid_s_standard, valid_l_standard = datagen(100, sampl, dx, dy)
train_s_layered, train_l_layered = datagen_layered(400, sampl, dx, dy)
valid_s_layered, valid_l_layered = datagen_layered(100, sampl, dx, dy)
train_s_wide, train_l_wide = datagen_wide(400, sampl, dx, dy)
valid_s_wide, valid_l_wide = datagen_wide(100, sampl, dx, dy)

with tf.device('/cpu:0'):
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
    gpicl_standard.fit(train_s_standard, train_l_standard, epochs=eps, batch_size=100, validation_data=(valid_s, valid_l), callbacks=[tensorboard_callback, DisplayProgress()], verbose=0)
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
    gpicl_layered.fit(train_s_layered, train_l_layered, epochs=eps, batch_size=100, validation_data=(valid_s, valid_l), callbacks=[tensorboard_callback, DisplayProgress()], verbose=0)
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
    gpicl_wide.fit(train_s_wide, train_l_wide, epochs=eps, batch_size=100, validation_data=(valid_s, valid_l), callbacks=[tensorboard_callback, DisplayProgress()], verbose=0)


2023-03-30 11:32:28.046443: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2023-03-30 11:32:28.291454: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 200/200

2023-03-30 11:33:01.990966: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2023-03-30 11:33:02.230537: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 200/200

2023-03-30 11:33:35.674864: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2023-03-30 11:33:35.927144: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 200/200

In [133]:
print(tf.math.reduce_mean(train_l).numpy())
print(tf.math.reduce_std(train_l).numpy())

tf.summary.histogram(
    'labels', train_l
)

gpicl.summary()

0.4940679
0.42865488
Model: "transformer_43"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 transform_layer_208 (Transf  multiple                 4288      
 ormLayer)                                                       
                                                                 
 transform_layer_209 (Transf  multiple                 4288      
 ormLayer)                                                       
                                                                 
Total params: 8,576
Trainable params: 8,576
Non-trainable params: 0
_________________________________________________________________


In [134]:
print(tf.math.reduce_mean(tf.math.abs(train_l[:1000] - valid_l)))
for i in range(20, 50):
    print(valid_l[i].numpy(), gpicl(valid_s[i:i+1]).numpy())

tf.Tensor(0.47471383, shape=(), dtype=float32)


In [154]:
%tensorboard --logdir logs/fit