In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
import datetime

#1. get mnist from tensorflow_datasets
mnist, info = tfds.load("mnist", split =["train","test"], as_supervised=True, with_info=True)
ds_train = mnist[0]
ds_val = mnist[1]

# print(info)
# tfds.show_examples(train, info) 

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


In [3]:
#2. write function to create the dataset that we want
def preprocess(data, batch_size) :
    #Image should be float
    data = data.map(lambda x, t: (tf.cast(x, tf.float32), tf.cast(t, tf.float32)))
    #Image should be flattened
    data = data.map(lambda x, t: (tf.reshape(x, (-1,)), t))
    #Image vector will here have values between -1 and 1
    data = data.map(lambda x, t: ((x/128.)-1., t))
    #We want to have two mnist images in each example
    #This leads to a single example being ((x1,y1),(x2,y2))
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), 
                                     data.shuffle(2000)))
    
    #Map ((x1,y1),(x2,y2)) to ( x1, x2, t1 - t2, ((t1 + t2) >= 5) ) (*boolean and int)
    zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1] - x2[1], (x1[1] + x2[1] >= 5) ))
    #Transform boolean target to int
    zipped_ds = zipped_ds.map(lambda x1, x2, t1, t2: (x1, x2, tf.cast(t1, tf.int32), tf.cast(t2, tf.int32) ))
    #Batch the dataset
    zipped_ds = zipped_ds.batch(batch_size)
    #Prefetch
    zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    return zipped_ds

train = preprocess(ds_train, batch_size=32) #ds_train.apply(preprocess)
val = preprocess(ds_val, batch_size=32) #ds_val.apply(preprocess)

# check the contents of the dataset
for img1, img2, label_1, label_2 in train:
    print(img1.shape, img2.shape, label_1.shape, label_2.shape)
    break

(32, 784) (32, 784) (32,) (32,)


In [4]:
from tensorflow.python.util import tf_inspect
class MNISTCalc(tf.keras.Model) :
  ## 1. constructor
  def __init__(self) :
    super().__init__()

    ## optimizer, loss function and metrics
    self.metrics_list = [ tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Mean(name="loss") ]
    self.optimizer = tf.keras.optimizers.get("Adam")
    self.loss_function_0 = tf.keras.losses.BinaryCrossentropy()
    self.loss_function_1 = tf.keras.losses.MeanSquaredError()

    ## layers to encode the images (both layers used for both images)
    self.dense1 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
    self.dense3 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
    self.out_layer = tf.keras.layers.Dense(1,activation=tf.nn.softmax)

    self.calc_1 = tf.keras.layers.Dense(1)
    self.calc_2 = tf.keras.layers.Dense(1,activation=tf.nn.sigmoid) 

  
  ## 2.1 call method (forward computation for mnist image classification)
  # @tf.function
  def classify(self, img) :
    img = self.dense1(img)
    img = self.dense2(img)
    img = self.dense3(img)
    output = self.out_layer(img)
    return output


  ## 2.2 call method (forward computation)
  @tf.function
  def call(self, imgs, training=False) :
    im_0, im_1 = imgs
    im_0 = self.classify(im_0)
    im_1 = self.classify(im_1)

    imgs = tf.concat([im_0, im_1], axis=1)
    result_1 = self.calc_1(imgs)
    result_2 = self.calc_2(imgs)

    ## Two output activations for respective tasks
    return result_1, result_2

  ## 3. metrics property
  @ property
  def metrics(self):
    ## return a list with all metrics in the model
    return self.metrics_list

  def reset_metrics(self):
    for metric in self.metrics:
      metric.reset_state()

  ## 5. train step method
  @tf.function
  def train_step(self, data):
    img_0, img_1, label_0, label_1 = data
    with tf.GradientTape() as tape:
      output_0, output_1 = self((img_0, img_1), training=True)
      loss_0 = self.loss_function_0(label_0, output_0)
      loss_1 = self.loss_function_1(label_1, output_1)
      loss = loss_0 + loss_1
        
    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    
    ## update the state of the metrics according to loss
    self.metrics[0].update_state(label_0, output_0)
    self.metrics[1].update_state(loss)
    ## return a dictionary with metric names as keys and metric results as values
    return {m.name : m.result() for m in self.metrics}

    ## 6. test_step method
    @tf.function
    def test_step(self, data):
      img_0, img_1, label_0, label_1 = data
      ## same as train step (without parameter updates)
      output_0, output_1 = self((img_0, img_1), training=False)
      loss_0 = self.loss_function_0(label_0, output_0)
      loss_1 = self.loss_function_1(label_1, output_1)
      loss = loss_0 + loss_2

      self.metrics[0].update_state(label_0, output_0)
      self.metrics[1].update_state(loss)
      return {m.name : m.result() for m in self.metrics}

In [6]:
import datetime
def create_summary_writers(config_name):
    
    # Define where to save the logs
    # along with this, you may want to save a config file with the same name so you know what the hyperparameters were used
    # alternatively make a copy of the code that is used for later reference
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_path = f"logs/{config_name}/{current_time}/train"
    val_log_path = f"logs/{config_name}/{current_time}/val"

    # log writer for training metrics
    train_summary_writer = tf.summary.create_file_writer(train_log_path)
    # log writer for validation metrics
    val_summary_writer = tf.summary.create_file_writer(val_log_path)
    return train_summary_writer, val_summary_writer

train_summary_writer, val_summary_writer = create_summary_writers(config_name="RUN1")

In [11]:
import pprint
import tqdm

def training_loop(model, optimizer, n_epochs, train, test, train_summary_writer, val_summary_writer, save_path) :

  for e in range(n_epochs) :
    print(f"Epoch {e}:")

    for batch in tqdm.tqdm(train, position=0, leave=True):
      metrics = model.train_step(batch)
      
      # logging the validation metrics to the log file which is used by tensorboard
      with train_summary_writer.as_default():
        for metric in model.metrics:
          tf.summary.scalar(f"{metric.name}", metric.result(), step=e)

    print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])
    
    # 4. reset metric objects
    model.reset_metrics()

    # 5. evaluate on validation data
    for batch in test:
      metrics = model.test_step(batch)
      # 6. log validation metrics
      with val_summary_writer.as_default():
          # for scalar metrics:
          for metric in model.metrics:
                  tf.summary.scalar(f"{metric.name}", metric.result(), step=e)
        
    print([f"val_{key}: {value.numpy()}" for (key, value) in metrics.items()])
    # 7. reset metric objects
    model.reset_metrics()

  # 8. save model weights if save_path is given
  if save_path:
    model.save_weights(save_path)
  


  #####################################################################
  # if (task == "regression") :
  #   pass
  # elif (task == "classification") :
  #   pass

In [12]:
def training(optimizers, n_epochs, task = "calc") :
  calculator = MNISTCalc()
  empty_imgs = (np.zeros((1, 28**2)), np.zeros((1, 28**2)))
  calculator(empty_imgs)
  train_summary_writer, val_summary_writer = create_summary_writers(config_name="RUN1")
  save_path = "trained_model_RUN1"

  for optimizer in optimizers : 
    training_loop(calculator, optimizer, n_epochs, 
            train, val, 
            train_summary_writer, val_summary_writer, 
            save_path)
    
  if (task == "calc") : pass 
  elif (task == "comp") : pass

In [13]:
calculator = MNISTCalc()
list_optimizers = ['Adam'] #['Adam', 'SGD']

training(list_optimizers, 100, task="calc")

Epoch 0:


100%|██████████| 1875/1875 [00:35<00:00, 52.22it/s]


['binary_accuracy: 0.09003333002328873', 'loss: 15.382424354553223']


ValueError: ignored