This notebook shows how to complete a training process from scratch with `tf.keras` in TensorFlow 2.x. Keras in TensorFlow, `tf.keras`, is different from the native Keras (https://github.com/keras-team/keras/). The latter was built on top of TensorFlow 1.x and, at the time of writing, did not support TensorFlow 2.x. Above all, the `tf.keras` APIs are optimized for using TensorFlow as the backend, for example for parallel training tasks.

You can use this notebook as a template for running the whole process of training a model from scratch.

In [0]:
!pip install --force tensorflow-gpu==2.0.0

In [0]:
import tensorflow as tf
import tensorflow.keras as keras
import os
import matplotlib.pyplot as plt
import cv2
import numpy as np
from tensorflow.python.client import device_lib

print("Tensorflow Version: {}".format(tf.__version__))
print("Tensorflow.Keras Version: {}".format(keras.__version__))
print("Devices: {}".format(device_lib.list_local_devices()))

# Using TF2.Keras

## Data Processing

In this tutorial, we use a flower dataset downloaded from Google storage (https://github.com/googlecodelabs/tensorflow-for-poets-2). The dataset contains 5 types of flowers, each with roughly 600 to 900 images.

### Downloading an Image Dataset

In [0]:
!curl http://download.tensorflow.org/example_images/flower_photos.tgz -o flower_photos.tgz

In [0]:
!tar zxf flower_photos.tgz

### Splitting Datasets

First of all, the flower dataset is split into three parts, training, validation and test sub-datasets in the ratio of 80:10:10.
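
As a rough illustration of the 80:10:10 arithmetic, the sketch below computes how many files of one category land in each sub-dataset. The helper name `split_counts` is hypothetical. It assumes that the training and validation counts are truncated to integers and that the test split receives the remainder, so the test share can end up slightly above 10%.

```python
def split_counts(total, ratio=None):
    """Return the number of files per sub-dataset for one category folder."""
    if ratio is None:
        ratio = {"train": 0.8, "val": 0.1, "test": 0.1}
    train = int(total * ratio["train"])   # truncated, not rounded
    val = int(total * ratio["val"])
    test = total - train - val            # the remainder goes to the test split
    return {"train": train, "val": val, "test": test}


# A folder holding 633 images (an assumed count, used only for illustration).
counts = split_counts(633)
print(counts)
```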

A file (`./flower_photos/file_list.csv`) stores one record per image. Despite the `.csv` extension, the fields are separated by single spaces, in the following format.
```text
image_path image_category subdataset_type
```

Another file (`./flower_photos/label_list.txt`) is used to record the categories.
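
To make the record layout concrete, here is a minimal sketch of parsing one line of the file list. The function name `parse_file_list_line` and the sample path are hypothetical. Unlike a plain `split(" ")`, the sketch splits from the right so that a path containing spaces would still parse; this assumes that category names never contain spaces.

```python
def parse_file_list_line(line):
    """Split one record into (image_path, category, subdataset_type)."""
    # rsplit with maxsplit=2 keeps any spaces inside the path intact.
    image_path, category, subdataset = line.strip().rsplit(" ", 2)
    return image_path, category, subdataset


# The file name below is made up for illustration.
record = parse_file_list_line("./flower_photos/roses/example.jpg roses train\n")
print(record)
```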

In [0]:
def split_datasets(path, 
                   output_file_path, 
                   output_label_path,
                   ratio={'train': 0.8, 'val': 0.1, 'test': 0.1},):
  assert os.path.exists(path) and output_file_path is not None, "Missing necessary information."
  folders = next(os.walk(path))[1]

  for folder in folders:
    if folder[0] == ".": continue

    folder_path = os.path.join(path, folder)
    filenames = next(os.walk(folder_path))[2]
    np.random.shuffle(filenames)
    total_file_nums = len(filenames)
    train_end = int(total_file_nums * ratio["train"])
    val_end = int(total_file_nums * ratio['val'])

    with open(output_file_path, "a") as fout:
      for name in filenames[:train_end]:
        file_path = os.path.join(folder_path, name)
        fout.write("{} {} {}\n".format(file_path, folder, "train"))
      for name in filenames[train_end:(train_end + val_end)]:
        file_path = os.path.join(folder_path, name)
        fout.write("{} {} {}\n".format(file_path, folder, "val"))
      for name in filenames[(train_end + val_end):]:
        file_path = os.path.join(folder_path, name)
        fout.write("{} {} {}\n".format(file_path, folder, "test"))
      
  with open(output_label_path, "w") as fout:
    for folder in folders:
      if folder[0] == ".": continue
      fout.write("{}\n".format(folder))


file_list = "./flower_photos/file_list.csv"
label_list = "./flower_photos/label_list.txt"
if not os.path.exists(file_list):
  split_datasets("./flower_photos", file_list, label_list)  # returns nothing; it only writes the two files
  print("Generated both the file list and the label list.")
else:
  print("Using a previously generated file.")

In [0]:
!head -n 5 ./flower_photos/file_list.csv
!tail -n 800 ./flower_photos/file_list.csv | head -n 5
!wc -l ./flower_photos/file_list.csv
!head -n 10 ./flower_photos/label_list.txt
!wc -l ./flower_photos/label_list.txt

### Data Generator

Once the file list is generated, a data generator that reads it must be instantiated.

Because there are three sub-datasets, three data generators are instantiated, one for each of them.

The data generator also implements several image augmentation methods, including flipping, cropping, and shifting.

In short, you are going to instantiate three data generators, one per sub-dataset. During training or testing, each generator returns a batch of (optionally augmented) images together with their correct labels.
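
The batching contract of such a generator can be sketched without TensorFlow. The class below, `ToyBatchGenerator`, is a hypothetical pure-Python stand-in and not the Keras class itself. It only shows the three methods a Keras-style sequence is expected to provide: `__len__` (number of batches), `__getitem__` (one batch by index) and `on_epoch_end` (reshuffling).

```python
import random


class ToyBatchGenerator:
    """Pure-Python stand-in for the batching logic of a Keras-style sequence."""

    def __init__(self, samples, batch_size=4, shuffle=True, seed=0):
        self.samples = list(samples)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = random.Random(seed)
        self.on_epoch_end()

    def __len__(self):
        # Only full batches are served; a trailing partial batch is dropped.
        return len(self.samples) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        picked = self.indexes[start:start + self.batch_size]
        return [self.samples[i] for i in picked]

    def on_epoch_end(self):
        # Rebuild the index order, reshuffling it if requested.
        self.indexes = list(range(len(self.samples)))
        if self.shuffle:
            self._rng.shuffle(self.indexes)


gen = ToyBatchGenerator(range(10), batch_size=4, shuffle=False)
batches = [gen[i] for i in range(len(gen))]
print(batches)
```

With 10 samples and a batch size of 4, the last two samples are never served in an unshuffled epoch, which is the price of flooring the batch count.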

In [0]:
class Image_DataGenerator(keras.utils.Sequence):

  __indexes = 0
  __file_nums = 0
  __list_objs = []  # file_path, category, type
  __label_list = []

  def __init__(self, file_path=None, label_path=None, datatype='train',
               batch_size=32, output_shape=(224, 224),
               shuffle=True, aug=True):
    """(Necessary) Constructor of this generator."""
    self.__indexes = 0
    self.__file_nums = 0
    self.__list_objs = []
    self.__label_list = []

    self.batch_size = batch_size
    self.output_shape = output_shape
    self.shuffle = shuffle
    self.aug = aug
    self.file_path = file_path
    self.label_path = label_path
    self.datatype = datatype

    self.__load_file_list()
    self.on_epoch_end()  # for shuffling

  def __load_file_list(self):
    """(Optional) Load the sub-dataset file."""
    with open(self.file_path, "r") as fin:
      for line in fin:
        tmp = line.strip().split(" ")
        if tmp[2] == self.datatype:
          self.__list_objs.append([tmp[0], tmp[1]])
          self.__file_nums += 1

    with open(self.label_path, "r") as fin:
      for line in fin:
        self.__label_list.append(line.strip())
      self.__label_list = np.array(self.__label_list)
    
  def __len__(self):
    """(Necessary) Returns how many batches to the dataset."""
    return int(np.floor(self.__file_nums / self.batch_size))
  
  def __getitem__(self, index):
    """(Necessary) Generate one batch of data"""
    indexes = self.__indexes[index*self.batch_size : (index+1)*self.batch_size]
    file_obj_list = [self.__list_objs[k] for k in indexes]
    X, y = self.__data_generation(file_obj_list)
    return X, y
  
  def on_epoch_end(self):
    """(Necessary) Updates indexes after each epoch"""
    self.__indexes = np.arange(self.__file_nums)
    if self.shuffle == True:
      np.random.shuffle(self.__indexes)

  def example_take(self, count):
    """(Optional) Fetch an example dataset."""
    return self.__getitem__(count)
      
  def __data_generation(self, file_obj_list):
    """(Necessary) Generates data containing batch_size samples
    # X : (n_samples, *dim, n_channels)
    """
    # Initialization
    X = np.empty((self.batch_size, *self.output_shape, 3))
    y = np.empty((self.batch_size, len(self.__label_list)), dtype=float)

    # Generate data
    for i, obj in enumerate(file_obj_list):
      try:
        img_path, img_label = obj

        # process the image
        X[i,] = self.__preprocess_image_data(img_path)

        # preprocess the image info
        y[i] = self.__process_data_label(img_label)
      except Exception as e:
        tf.print("\nFailed to load image {}: {}\n".format(obj, e))
        continue
      
    return X, y

  def __brightness(self, image, bright=0):
    # RGB -> HSV (Hue, Saturation, Value)
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    hsv = hsv.astype(np.uint16)
    hsv[:,:,2] += bright
    hsv[:,:,2] = np.minimum(hsv[:,:,2], 255)
    hsv[:,:,2] = np.maximum(hsv[:,:,2], 0)
    hsv = hsv.astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
  
  def __contrast(self, image, value=3.0):
    # LAB channel 
    lab= cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

    # Splitting the LAB image to different channels
    l, a, b = cv2.split(lab)

    # Applying CLAHE to L-channel
    clahe = cv2.createCLAHE(clipLimit=value, tileGridSize=(8,8))
    cl = clahe.apply(l)

    # Merge the CLAHE enhanced L-channel with the a and b channel
    clab = cv2.merge((cl,a,b))

    return cv2.cvtColor(clab, cv2.COLOR_LAB2BGR)
  
  def __tone(self, image, color_value=0):
    # RGB -> HSV (Hue, Saturation, Value); the image is already RGB at this point
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    hsv = hsv.astype(np.int32)
    # OpenCV stores 8-bit hue in [0, 179], so wrap around instead of clipping at 255
    hsv[:,:,0] = (hsv[:,:,0] + color_value) % 180
    hsv = hsv.astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
  
  def __flip(self, img, opt="hor"):
    if opt == "hor":
      return cv2.flip(img, 1)
    elif opt == "ver":
      return cv2.flip(img, 0)
    elif opt == "both":
      return cv2.flip(img, -1)
    else:
      return img

  def __crop(self, image, scale=0.5):
    """
    description: randomly crop a window covering `scale` (0 < scale <= 1) of each side
    """
    (h, w) = image.shape[:2]
    # use `scale` instead of a hard-coded 0.5 for the window size
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    # pick the top-left corner so that the window stays inside the image
    rand_x = np.random.randint(0, w - new_w + 1)
    rand_y = np.random.randint(0, h - new_h + 1)
    return image[rand_y : (rand_y + new_h), rand_x : (rand_x + new_w)]

  def __shift(self, image, x, y):
    m = np.float32([[1,0,x],[0,1,y]])
    return cv2.warpAffine(image, m, (image.shape[1], image.shape[0]))

  def __preprocess_image_data(self, image_path):
    """(Optional) Preprocess the image."""
    try:
      img = cv2.imread(image_path)
      img = img[:,:,::-1]

      if self.aug:
        #brightness = np.random.randint(0,60)
        #img = self.__brightness(img, bright=brightness)

        #contrast = np.random.rand()
        #img = self.__contrast(img, value=contrast)

        #tone = np.random.randint(0,30)
        #img = self.__tone(img, tone)  

        # random crop
        img = self.__crop(img, 0.95)

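        # random shift (x is the horizontal offset, y the vertical one)
        hor, ver = np.random.randint(-8,8), np.random.randint(-8,8)
        img = self.__shift(img, hor, ver)

        # random flip; without `size`, np.random.choice returns a single string
        choice = np.random.choice(["hor", "ver", "both"])
        img = self.__flip(img, choice)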

      img = cv2.resize(img, self.output_shape)

      img = (img - 0.0) / 255.0
      return img
    except Exception as e:
      print("Failed in processing image {}. ({})".format(image_path, e))
      return np.zeros((*self.output_shape, 3), dtype=float)

  def __process_data_label(self, img_label):
    """(Optional) Convert a category name into a one-hot vector."""
    return (self.__label_list == img_label).astype(float)


### An example fetching datasets

The following example shows how to use the data generator.

In [0]:
imggen = Image_DataGenerator(file_path=file_list, label_path=label_list, datatype='test', aug=True)

In [0]:
imgs, labels = imggen.example_take(1)

idx = 0

plt.imshow((imgs[idx] * 255.0).astype(int))
plt.show()

print("label: {}".format(labels[idx]))

### Data Generator Instances

In [0]:
train_params = {"file_path": file_list, "label_path": label_list, "datatype": 'train', "aug": True}
val_params = {"file_path": file_list, "label_path": label_list, "datatype": 'val', "aug": False}
test_params = {"file_path": file_list, "label_path": label_list, "datatype": 'test', "aug": False}

train_generator = Image_DataGenerator(**train_params)
val_generator = Image_DataGenerator(**val_params)
test_generator = Image_DataGenerator(**test_params)

## Building a Model

After preprocessing the datasets, it is time to build a neural network. Here, you are going to build a CNN model in `tf.keras` with a custom layer that implements a skip connection.
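
An element-wise addition only works when both branches produce tensors of the same shape, so it helps to check the spatial sizes before wiring a skip connection. The sketch below is plain arithmetic, not Keras code, and the helper names are hypothetical. It assumes `padding='same'`, for which the output size is the input size divided by the stride, rounded up, and compares a main branch of four stride-2 convolutions with a skip branch of two stride-4 convolutions on a 224-pixel input.

```python
import math


def same_padding_output(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)


def chain(size, strides):
    """Apply a sequence of strided 'same' convolutions to a spatial size."""
    for stride in strides:
        size = same_padding_output(size, stride)
    return size


main_branch = chain(224, [2, 2, 2, 2])   # 224 -> 112 -> 56 -> 28 -> 14
skip_branch = chain(224, [4, 4])         # 224 -> 56 -> 14
print(main_branch, skip_branch)
```

Both branches must also end with the same number of filters before they can be added.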

In [0]:
class Skipconnection(keras.layers.Layer):
  """Skipconnection implements a connection across different layers.
  
  The **`Custom Layer` is not recommended for creating complex or sequential layers.**
  """

  target_shape = None

  def __init__(self, target_shape, **kwargs):
    """Necessary"""
    super(Skipconnection, self).__init__(**kwargs)
    self.target_shape = target_shape
    self.conv1 = keras.layers.Conv2D(filters=30,
                                     kernel_size=(4,4), 
                                     strides=(4,4), 
                                     padding='same')
    self.conv2 = keras.layers.Conv2D(filters=60, 
                                     kernel_size=(4,4), 
                                     strides=(4,4), 
                                     padding='same')   
    self.bn1 = keras.layers.BatchNormalization()
    self.bn2 = keras.layers.BatchNormalization()

  def get_config(self):
    """Necessary in exporting models.
    
    Must be available in Tensorflow 2.0.0.
    """
    config = {"target_shape": self.target_shape}
    base_config = super(Skipconnection, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, upstream, downstream, **kwargs):
    """Necessary"""

    x = self.conv1(upstream)
    x = self.bn1(x)
    x = self.conv2(x)
    x = self.bn2(x)
    sc = keras.layers.Add()([x, downstream])
    return sc

In [0]:
def build_model(inputs):
  """The function is built for image classification.
  inputs: [None, 224, 224, 3]
  """ 

  # [None, 224, 224, 16]
  x = keras.layers.Conv2D(filters=16, kernel_size=(3,3), strides=(1,1), padding='same')(inputs)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)

  # [None, 224, 224, 24]
  x = keras.layers.Conv2D(filters=24, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  # for skip connection
  sc_upstream = keras.layers.Activation('relu')(x)

  # [none, 112, 112, 30]
  x = keras.layers.Conv2D(filters=30, kernel_size=(3,3), strides=(2,2), padding='same')(sc_upstream)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)

  # [none, 112, 112, 36]
  x = keras.layers.Conv2D(filters=36, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)    
  
  # [none, 112, 112, 42]
  x = keras.layers.Conv2D(filters=42, kernel_size=(3,3), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)   

  # [none, 56, 56, 42]
  x = keras.layers.Conv2D(filters=42, kernel_size=(3,3), strides=(2,2), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)      

  # [none, 56, 56, 48]
  x = keras.layers.Conv2D(filters=48, kernel_size=(3,3), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)     

  # [none, 56, 56, 48]
  x = keras.layers.Conv2D(filters=48, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)     

  # [none, 56, 56, 48]
  x = keras.layers.Conv2D(filters=48, kernel_size=(3,3), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)    

  # [none, 56, 56, 48]
  x = keras.layers.Conv2D(filters=48, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)     

  # [none, 28, 28, 48]
  x = keras.layers.Conv2D(filters=48, kernel_size=(3,3), strides=(2,2), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)      

  # [none, 28, 28, 48]
  x = keras.layers.Conv2D(filters=48, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)        

  # [none, 28, 28, 60]
  x = keras.layers.Conv2D(filters=60, kernel_size=(3,3), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)   

  # [none, 28, 28, 60]
  x = keras.layers.Conv2D(filters=60, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)   

  # [none, 14, 14, 60]
  x = keras.layers.Conv2D(filters=60, kernel_size=(3,3), strides=(2,2), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)   

  # [none, 14, 14, 60]
  x = keras.layers.Conv2D(filters=60, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)   

  # implements a skip connection
  # sc = keras.layers.Conv2D(filters=30, kernel_size=(4,4), strides=(4,4), padding='same')(sc_upstream)
  # sc = keras.layers.BatchNormalization()(sc)
  # sc = keras.layers.Conv2D(filters=60, kernel_size=(4,4), strides=(4,4), padding='same')(sc) 
  # sc = keras.layers.BatchNormalization()(sc)
  # sc = keras.layers.Add()([sc, x])

  # custom layer; both branches are [none, 14, 14, 60] at this point
  sc = Skipconnection([14,14,60])(sc_upstream, x)

  # [none, 7, 7, 60]
  x = keras.layers.Conv2D(filters=60, kernel_size=(3,3), strides=(2,2), padding='same')(sc)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x) 

  # [none, 7, 7, 60]
  x = keras.layers.Conv2D(filters=60, kernel_size=(5,5), strides=(1,1), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x)   

  # [None, 7, 7, 72]
  x = keras.layers.Conv2D(filters=72, kernel_size=(1,1), strides=(1,1), padding='same')(x)
  x = tf.keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation('relu')(x) 

  # one-dimensional feature
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(4096)(x)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.Dense(2048)(x)
  feature_extraction = tf.keras.layers.BatchNormalization()(x)

  outputs = tf.keras.layers.Dense(5)(feature_extraction)
  prob = tf.nn.softmax(outputs)  

  return prob

After you build a neural network architecture, you need to assign both input and output specifications and then wrap them as a model.

In [0]:
inputs = tf.keras.Input(shape=(224, 224, 3))
outputs = build_model(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary() 

### A Model built with TF.Hub

In TensorFlow 2, you can easily build a model on top of a pre-trained model downloaded from TF Hub. You only need to add a final layer, usually a dense layer whose size equals the number of categories. You can also choose whether to train the downloaded model via the parameter `trainable=[False|True]`.


In [0]:
used_hub = False

In [0]:
if used_hub:

  !pip install tensorflow_hub

  import tensorflow_hub as hub

  model = tf.keras.Sequential([
      hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4", 
                    output_shape=[1280], trainable=False, dtype=float), 
      tf.keras.layers.Dense(5, activation="softmax")
  ])
  model.build([None, 224, 224, 3])

### Let's try the inference.

As a quick test, we can run inference (forward propagation) on the model we just wrapped.

In [0]:
val_data = val_generator.example_take(1)
val_data[0].shape, val_data[1].shape

In [0]:
_untrain_result = model.predict(val_data[0])  # pass only the images, not the (images, labels) tuple
_untrain_result.shape

### Loss function

Model training requires an objective, or loss, function to drive backpropagation. In TensorFlow 2, you can define the loss function as a plain Python function decorated with `@tf.function`. You can then pass NumPy arrays or Python floats directly to the function to get the loss, which is more convenient than in TensorFlow 1.
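
For intuition about the value such a loss returns, the following sketch computes the mean categorical cross-entropy by hand in pure Python. It is not the TensorFlow implementation; the function name and the clipping constant `eps` are assumptions made for illustration. For a model that predicts a uniform distribution over 5 classes, the loss equals ln(5), which is roughly what an untrained classifier tends to start from.

```python
import math


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean cross-entropy over a batch of one-hot labels and probabilities."""
    losses = []
    for true_row, pred_row in zip(y_true, y_pred):
        # Clip probabilities away from zero so that log() stays finite.
        loss = -sum(t * math.log(max(p, eps)) for t, p in zip(true_row, pred_row))
        losses.append(loss)
    return sum(losses) / len(losses)


y_true = [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]
uniform = [[0.2] * 5, [0.2] * 5]
loss = categorical_crossentropy(y_true, uniform)
print(loss, math.log(5))
```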

In [0]:
@tf.function
def category_loss(y_true, y_pred):
  return tf.reduce_mean(tf.losses.categorical_crossentropy(y_true, y_pred))

In [0]:
category_loss(val_data[1], _untrain_result)

### Metrics Function

You need to define metrics to evaluate the training performance. Here accuracy is used as the metric.
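
The accuracy computation can be sketched in pure Python as follows. The names `argmax` and `category_accuracy_py` are hypothetical helpers, not TensorFlow functions. A prediction counts as correct when the index of its highest probability matches the index of the 1 in the one-hot label.

```python
def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=lambda i: row[i])


def category_accuracy_py(y_true, y_pred):
    """Fraction of samples whose predicted class matches the true class."""
    hits = sum(argmax(t) == argmax(p) for t, p in zip(y_true, y_pred))
    return hits / len(y_true)


y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
y_pred = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.2, 0.6]]
accuracy = category_accuracy_py(y_true, y_pred)
print(accuracy)
```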

In [0]:
@tf.function
def category_accuracy(y_true, y_pred):
  true_cls = tf.argmax(y_true, axis=1)
  pred_cls = tf.argmax(y_pred, axis=1)
  correct = tf.reduce_sum(tf.cast(tf.equal(true_cls, pred_cls), float))
  total = tf.cast(tf.shape(y_true)[0], float)
  return correct / total

In [0]:
category_accuracy(val_data[1], _untrain_result)

### Training Parameters

In [0]:
learning_rate = 1e-3
epochs = 101

During a long training process, it is usually better to decrease the learning rate gradually so that the weights can settle. Here a step decay function for the learning rate is defined.
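
To see what a step decay produces, the sketch below evaluates the formula `initial_lr * decay_factor ** floor(epoch / step_size)` for a few epochs in pure Python. The function name `step_decay` is hypothetical, and the default values (initial rate 1e-3, factor 0.75, step of 5 epochs) are assumptions chosen for illustration.

```python
import math


def step_decay(epoch, initial_lr=1e-3, decay_factor=0.75, step_size=5):
    """Learning rate after `epoch` epochs under a step decay schedule."""
    return initial_lr * decay_factor ** math.floor(epoch / step_size)


# The rate stays constant within a step and drops by 25% at each boundary.
for epoch in (0, 4, 5, 10):
    print(epoch, step_decay(epoch))
```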

In [0]:
def step_decay_schedule(initial_lr=1e-3, decay_factor=0.75, step_size=15):
    def schedule(epoch):
        return initial_lr * (decay_factor ** np.floor(epoch/step_size))
    return tf.keras.callbacks.LearningRateScheduler(schedule)

lr_sched = step_decay_schedule(initial_lr=learning_rate, decay_factor=0.75, step_size=5)

The early stopping callback halts training when the monitored validation loss stops improving, which helps to avoid overfitting.
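
The patience logic behind early stopping can be sketched in a few lines of pure Python. This is a simplified model, not the Keras callback: the function name `early_stop_epoch` is hypothetical, it assumes that a lower validation loss is better, and it ignores options such as a minimum improvement threshold.

```python
def early_stop_epoch(val_losses, patience):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss       # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1         # no improvement in this epoch
            if wait >= patience:
                return epoch
    return None


history = [1.0, 0.8, 0.9, 0.85, 0.95]
print(early_stop_epoch(history, patience=2))
```

A larger patience tolerates longer plateaus at the cost of more epochs spent after the best result.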

In [0]:
early_stopping = tf.keras.callbacks.EarlyStopping(patience=epochs // 3)  # integer patience

In [0]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

You can also use TensorBoard to monitor the training process. In Colab, you can load the magic function to view TensorBoard directly inside the Colab environment.

In [0]:
tfb_visual = keras.callbacks.TensorBoard(
    './log', histogram_freq=0, write_graph=True, write_grads=False, write_images=False)

In [0]:
!rm -rf ./log

In [0]:
# magic function in colab
%reload_ext tensorboard
%tensorboard --logdir "./log"

After data preprocessing, model building, and defining the loss and metrics functions, you can now wrap them together and compile the model. The following two calls have the same effect, but the second one uses custom objects for demonstration.

In [0]:
# predefined loss function and metrics function
#model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# an advanced operation with custom objects
model.compile(optimizer=optimizer, loss=category_loss, metrics=[category_accuracy])

The method `fit_generator` or `fit` starts the training process. Later TensorFlow releases deprecate `fit_generator` because `fit` accepts generators directly.

In [0]:
model.fit_generator(train_generator, epochs=epochs, verbose=1, 
                    validation_data=val_generator, 
                    callbacks=[tfb_visual, lr_sched, early_stopping])

You can simply use the `evaluate()` method of the model to evaluate it on the test dataset.

In [0]:
model.evaluate(test_generator)

## Saving and Loading a Model

Because the backend of Keras is TensorFlow, multiple ways to export or load the model are available on both sides. This section covers several ways to save and load models in Keras and how they interact with the TensorFlow environment.

**[Notice] In the released version of TensorFlow 2.0, a custom layer must define `get_config()` so that the model can be exported or imported.**
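
The idea behind `get_config()` can be shown with a small pure-Python sketch. `ToyLayer` is a hypothetical class and not a Keras layer. It only illustrates the pattern: `get_config()` returns every constructor argument as plain data, so that an equivalent object can be rebuilt from a dictionary or from JSON.

```python
import json


class ToyLayer:
    """Hypothetical stand-in for a custom layer; not a Keras class."""

    def __init__(self, target_shape, name="toy"):
        self.target_shape = list(target_shape)
        self.name = name

    def get_config(self):
        # Everything the constructor needs must appear here.
        return {"target_shape": self.target_shape, "name": self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent object from the plain-data description.
        return cls(**config)


original = ToyLayer([14, 14, 60])
serialized = json.dumps(original.get_config())
restored = ToyLayer.from_config(json.loads(serialized))
print(restored.get_config())
```

If a constructor argument is missing from the returned dictionary, the rebuilt object cannot be created with the same settings, which is why the method has to cover all of them.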



### Whole Model

[**Keras to Keras**] The first way to export and load models in `tf.keras` is inherited from the native Keras. You can export and import the whole model as a single `.h5` file. This format is mainly used within the Keras runtime environment, either `tf.keras` or native Keras.

In [0]:
latest_model_path = "./whole_model.h5"

In [0]:
tf.keras.models.save_model(model, latest_model_path)  

In [0]:
model_load_1 = tf.keras.models.load_model(
    latest_model_path,
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})  

In [0]:
imgs, labels = test_generator.example_take(1)
labels[:1], model_load_1.predict(imgs[:1])

### Partial Save

#### Architecture-only saving

[**Keras to Keras**] The second way is to save the model architecture and the model weights separately instead of in one whole-model file.

You can save the model architecture in two different formats: a `dict` object or a `json` string.

* dict-based saving

In [0]:
config = model.get_config()

In [0]:
dict_model = keras.Model.from_config(
    config, 
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})

* json-based saving

In [0]:
config_json = model.to_json()

In [0]:
json_model = keras.models.model_from_json(
    config_json, 
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})

In [0]:
json_model.predict(imgs[:1])

#### Weights-only saving

For model weights, you can choose to save them in the Keras `.h5` format or in the TensorFlow checkpoint format (`save_format="tf"`).

In [0]:
weights = model.get_weights()

You can combine the two parts, the model configuration (`get_config()`/`from_config()`) and the model state (`get_weights()`/`set_weights()`), to recreate the model.
However, you have to call `compile()` again before using the model for training.

In [0]:
config = model.get_config()
weights = model.get_weights()

model_load_2 = keras.Model.from_config(
    config, 
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})
model_load_2.set_weights(weights)

In [0]:
model_load_2.predict(imgs[:1])

#### Both architecture and weights

After you save both parts of the model architecture and weights, you can load both of them and recover the model back.

* Save to the disk (local files).

In [0]:
json_config = model.to_json()
model_arch_path = os.path.join(".", "arch.json")
with open(model_arch_path, "w") as fout:
  fout.write(json_config)

# save in h5 format
model_weights_path = os.path.join(".", "weights.h5")
model.save_weights(model_weights_path)

# save in the TensorFlow checkpoint format
model_weights_path_tf = os.path.join(".", "model", "weights_tf")
model.save_weights(model_weights_path_tf, save_format="tf")

* Load both architecture and weights from the disk.



In [0]:
json_config_local = ""
with open(model_arch_path, "r") as fin:
  json_config_local = fin.read()
  
# load the weights saved in h5 format
model_load_3 = keras.models.model_from_json(
    json_config_local, 
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})
model_load_3.load_weights(model_weights_path)

# load the weights saved in the TensorFlow checkpoint format
model_load_4 = keras.models.model_from_json(
    json_config_local, 
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})
model_load_4.load_weights(model_weights_path_tf)

In [0]:
imgs, labels = test_generator.example_take(1)
model_load_3.predict(imgs[:1])

In [0]:
model_load_4.predict(imgs[:1])

### SavedModel

[**Keras to Tensorflow**] The third way is slightly different from the previous two. You can export a model trained in the Keras runtime to TensorFlow. In this version, the only TensorFlow format available for saving Keras models is the `SavedModel` format.

In [0]:
SavedModel_path = os.path.join(".", "saved_models")

In [0]:
keras.experimental.export_saved_model(
    model, 
    SavedModel_path, 
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})

In [0]:
saved_model = keras.experimental.load_from_saved_model(
    SavedModel_path, 
    custom_objects={'Skipconnection': Skipconnection,
                    'category_loss': category_loss, 
                    'category_accuracy': category_accuracy})

In [0]:
imgs, labels = test_generator.example_take(1)
saved_model.predict(imgs[:1])

In [0]:
!tar -cf savedmodels.tar saved_models