In [1]:
import types
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import Model, layers, Sequential
from threading import Thread
import math

In [2]:
%load_ext tensorboard
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + current_time
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
# test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
# train_summary_writer = tf.summary.create_file_writer(train_log_dir)
# test_summary_writer = tf.summary.create_file_writer(test_log_dir)

In [3]:
# load images dataset
def load_dataset(name:str="mnist", size:int=None):
    if name == "mnist":
        (train_x, train_y), (test_x, test_y) = keras.datasets.mnist.load_data()
    elif name == "cifar10":
        (train_x, train_y), (test_x, test_y) = keras.datasets.cifar10.load_data()
    train_x, test_x = train_x/255.0, test_x/255.0

    if size:
        train_x = train_x[:size][..., tf.newaxis].astype("float32")
        test_x = test_x[:size][..., tf.newaxis].astype("float32")
        train_y, test_y = train_y[:size], test_y[:size]
    return (train_x, train_y), (test_x, test_y)

In [4]:
batch_size=32
epochs=30
(train_x, train_y), (test_x, test_y) = load_dataset("cifar10",size=5000)

In [5]:
class ResBlock(layers.Layer):

    def __init__(self, *args, **wargs):
        super().__init__(*args, **wargs)
        self.conv = layers.Conv2D(64, 3, padding="same", activation="relu")
        self.bn = layers.BatchNormalization()
        self.downconv = layers.Conv2D(64, 1, padding="same")
        self.downbn = layers.BatchNormalization()

    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)

    def call(self, inputs:np.ndarray):
        x:np.ndarray = inputs
        fx:np.ndarray = x
        try:
            fx = self.conv(fx)
        except Exception as e:
            tf.print(e)
            raise RuntimeError("conv error in ",self.name,x.shape, fx.shape, inputs.shape)
        fx = self.bn(fx)
        if fx.shape[-1] != x.shape[-1]:
            x = self.downconv(x)
            x = self.downbn(x)
        try:
            return fx + x
        except:
            raise RuntimeError(x.shape, fx.shape, inputs.shape)

In [14]:
class BothResNet(Model):
    def __init__(self, use_cache=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__blocks_num = tf.Variable(1,dtype=tf.int8, trainable=False)
        self.__frozen_blocks_num = tf.Variable(0,dtype=tf.int8, trainable=False)
        self.use_cache = tf.constant(True) if use_cache else tf.constant(False)
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks = [ResBlock(name="res_block0")]
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(10)
    
    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)
    
    def __use_last(self, x:tf.Tensor, training=None) -> tf.Tensor:
        return self.blocks[-1](x, training=training)
    
    def __use_all(self, x:tf.Tensor, training=None) -> tf.Tensor:
        raise RuntimeError("run the wrong function")
        for block in self.blocks:
            x = block(x, training=training)
        return x

    def call(self, x=None, training=None, mask=None):
        # 只用最后一个残差块进行训练
        x = tf.cond(self.use_cache, lambda: self.__use_last(x, training), lambda: self.__use_all(x, training))
        x = self.flatten(x, training=training)
        x = self.dense(x, training=training)
        return x

    @tf.autograph.experimental.do_not_convert
    def getBlocksNum(self) -> tf.int8:
        return self.__blocks_num
    
    @tf.autograph.experimental.do_not_convert
    def getLastFrozenBlock(self) -> ResBlock:
        index = self.__frozen_blocks_num.numpy()-1
        block = self.blocks[index]
        return block
    
    @tf.autograph.experimental.do_not_convert
    def freezeBlock(self):
        if self.__frozen_blocks_num < self.__blocks_num:
            index = self.__frozen_blocks_num.numpy()
            block = self.blocks[index]
            block.trainable = False
            self.__frozen_blocks_num.assign_add(1)
            print("freeze block:", self.blocks[index].name, ", total frozen blocks:", index+1)
    
    @tf.autograph.experimental.do_not_convert
    def addNewBlock(self):
        print("----------")
        print("add new block")
        i = self.__blocks_num.numpy()
        newBlock = ResBlock(name="res_block"+str(i))
        newBlock(self.blocks[-1].output)
        self.blocks.append(newBlock)
        self.__blocks_num.assign_add(1)
        print(f"this is the {i+1} added blocks, block name: {self.blocks[i].name}")

    @tf.autograph.experimental.do_not_convert
    def copyLastBlock(self):
        print("----------")
        print("copy last block")
        newBlock = ResBlock(name="res_block"+str(self.__blocks_num.numpy()))
        last_block:ResBlock = self.blocks[-1]
        newBlock(last_block.output)
        if last_block.input_shape == last_block.output_shape:
            newBlock.set_weights(last_block.get_weights())
        else:
            print("copy failed: shape different with last block")
        self.blocks.append(newBlock)
        self.__blocks_num.assign_add(1)

In [15]:
class MyResNet(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__blocks_num = tf.Variable(1,dtype=tf.int8, trainable=False)
        self.__frozen_blocks_num = tf.Variable(0,dtype=tf.int8, trainable=False)
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks = [ResBlock(name="res_block0")]
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(10)
    
    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)
    
    def call(self, x=None, training=None, mask=None):
        for block in self.blocks:
            x = block(x, training=training)
        x = self.flatten(x, training=training)
        x = self.dense(x, training=training)
        return x

    @tf.autograph.experimental.do_not_convert
    def getBlocksNum(self) -> tf.int8:
        return self.__blocks_num
    
    @tf.autograph.experimental.do_not_convert
    def getLastFrozenBlock(self) -> ResBlock:
        index = self.__frozen_blocks_num.numpy()-1
        block = self.blocks[index]
        return block
    
    @tf.autograph.experimental.do_not_convert
    def freezeBlock(self):
        if self.__frozen_blocks_num < self.__blocks_num:
            index = self.__frozen_blocks_num.numpy()
            block = self.blocks[index]
            block.trainable = False
            self.__frozen_blocks_num.assign_add(1)
            print("freeze block:", self.blocks[index].name, ", total frozen blocks:", index+1)
    
    @tf.autograph.experimental.do_not_convert
    def addNewBlock(self):
        print("----------")
        print("add new block")
        i = self.__blocks_num.numpy()
        newBlock = ResBlock(name="res_block"+str(i))
        newBlock(self.blocks[-1].output)
        self.blocks.append(newBlock)
        self.__blocks_num.assign_add(1)
        print(f"this is the {i+1} added blocks, block name: {self.blocks[i].name}")

    @tf.autograph.experimental.do_not_convert
    def copyLastBlock(self):
        print("----------")
        print("copy last block")
        newBlock = ResBlock(name="res_block"+str(self.__blocks_num.numpy()))
        last_block:ResBlock = self.blocks[-1]
        newBlock(last_block.output)
        if last_block.input_shape == last_block.output_shape:
            newBlock.set_weights(last_block.get_weights())
        else:
            print("copy failed: shape different with last block")
        self.blocks.append(newBlock)
        self.__blocks_num.assign_add(1)

In [16]:
class CachedResNet(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__blocks_num = tf.Variable(1,dtype=tf.int8, trainable=False)
        self.__frozen_blocks_num = tf.Variable(0,dtype=tf.int8, trainable=False)
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks = [ResBlock(name="res_block0")]
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(10)
    
    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)

    def call(self, x=None, training=None, mask=None):
        # 只用最后一个残差块进行训练
        x = self.blocks[-1](x, training=training)
        x = self.flatten(x, training=training)
        x = self.dense(x, training=training)
        return x

    @tf.autograph.experimental.do_not_convert
    def getBlocksNum(self) -> tf.int8:
        return self.__blocks_num
    
    @tf.autograph.experimental.do_not_convert
    def getLastFrozenBlock(self) -> ResBlock:
        index = self.__frozen_blocks_num.numpy()-1
        block = self.blocks[index]
        return block
    
    @tf.autograph.experimental.do_not_convert
    def freezeBlock(self):
        if self.__frozen_blocks_num < self.__blocks_num:
            index = self.__frozen_blocks_num.numpy()
            block = self.blocks[index]
            block.trainable = False
            self.__frozen_blocks_num.assign_add(1)
            print("freeze block:", self.blocks[index].name, ", total frozen blocks:", index+1)
    
    @tf.autograph.experimental.do_not_convert
    def addNewBlock(self):
        print("----------")
        print("add new block")
        i = self.__blocks_num.numpy()
        newBlock = ResBlock(name="res_block"+str(i))
        newBlock(self.blocks[-1].output)
        self.blocks.append(newBlock)
        self.__blocks_num.assign_add(1)
        print(f"this is the {i+1} added blocks, block name: {self.blocks[i].name}")

    @tf.autograph.experimental.do_not_convert
    def copyLastBlock(self):
        print("----------")
        print("copy last block")
        newBlock = ResBlock(name="res_block"+str(self.__blocks_num.numpy()))
        last_block:ResBlock = self.blocks[-1]
        newBlock(last_block.output)
        if last_block.input_shape == last_block.output_shape:
            newBlock.set_weights(last_block.get_weights())
        else:
            print("copy failed: shape different with last block")
        self.blocks.append(newBlock)
        self.__blocks_num.assign_add(1)

In [22]:
class dynamicResNet:
    def __init__(self, condition: types.FunctionType = None, max_blocks_num:int = 2, cache:bool=True, copy_last_block:bool = False,*args, **wargs) -> None:
        """
            condition: A function, which will be called in every epoch and returns a boolean value representing whether to add a new block.
                        一个函数，每个epoch会被调用一次，返回值为布尔类型，代表是否添加新的块
        """
        super(dynamicResNet, self).__init__(*args, **wargs)
        if condition is None:
            self.add_condition = self.set_epochs
            self.add_condition()
        else:
            if callable(condition):
                self.add_condition = condition
            else:
                raise ValueError("'condition' must be a function")
        self.max_blocks_num = max_blocks_num
        self.use_cache = cache
        self.copy_last_block = copy_last_block
        # build model //创建模型
        self.model = CachedResNet() if cache else MyResNet()
        self.compiled = False

    def compile(self,
                optimizer="rmsprop",
                loss=None,
                metrics=None,
                loss_weights=None,
                weighted_metrics=None,
                run_eagerly=None,
                steps_per_execution=None,
                **kwargs
    ):
        self.complieArgs = [optimizer, loss, metrics, loss_weights, weighted_metrics, run_eagerly, steps_per_execution]
        self.complieKwargs = kwargs
        self.model.compile(*self.complieArgs, **kwargs)
        self.compiled = True

    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose="auto",
            callbacks=None,
            validation_split=0.0,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False
    ):
        if not self.compiled:
            raise RuntimeError("model should be compiled before fit")
        self.epochs = epochs
        self.fitArgs = [x,y,batch_size,1,verbose,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,validation_batch_size,validation_freq,max_queue_size,workers,use_multiprocessing]
        return self.call(training=True)
    
    def predict(self,
                x,
                batch_size=None,
                verbose="auto",
                steps=None,
                callbacks=None,
                max_queue_size=10,
                workers=1,
                use_multiprocessing=False
    ):
        if not self.compiled:
            raise RuntimeError("model should be compiled before predict")
        return self.model.predict( x,
                                    batch_size=batch_size,
                                    verbose=verbose,
                                    steps=steps,
                                    callbacks=callbacks,
                                    max_queue_size=max_queue_size,
                                    workers=workers,
                                    use_multiprocessing=use_multiprocessing
                                 )

    def call(self, x=None, training=False):
        if training:
            if x:
                raise ValueError("Please use 'fit' when training.")
            def fit_epoch():
                # 满足条件动态添加新残差块
                if self.model.getBlocksNum() < self.max_blocks_num and self.add_condition():
                    self.model.freezeBlock()
                    if self.copy_last_block:
                        self.model.copyLastBlock()
                    else:
                        self.model.addNewBlock()
                    if self.use_cache:
                        print("caching")
                        block = self.model.getLastFrozenBlock()
                        cache_model = keras.Model(block.input, block.output)
                        self.fitArgs[0] = cache_model.predict(self.fitArgs[0], batch_size=self.fitArgs[2])
                        print("cached")
                    print("compiling")
                    self.model.compile(*self.complieArgs, **self.complieKwargs)
                    print("compiled")
                tf.print(self.fitArgs[0].shape)
                self.model.fit(*self.fitArgs)
            for epoch in range(self.epochs):
                print(f"Epoch {epoch+1}/{self.epochs}")
                # 使用多线程的方式可以释放显存
                p = Thread(target=fit_epoch)
                p.start()
                p.join()
#                 fit_epoch()
        else:
            return self.model.predict(x)

    def set_epochs(self, interval_of_epochs:int = None) -> None:
        self.epoch = 0
        self.last_change_epoch = 1
        if interval_of_epochs is None:
            self.interval = 1
        else:
            self.interval = interval_of_epochs
        self.add_condition = self.__num_of_epochs

    def __num_of_epochs(self) -> bool:
        self.epoch += 1
        if self.epoch - self.last_change_epoch == self.interval:
            self.last_change_epoch = self.epoch
            return True
        return False

In [9]:
dynamic_model = dynamicResNet(max_blocks_num=5, cache=False, copy_last_block=False)
dynamic_model.set_epochs(5)
def fit_dinamic_model():
    dynamic_model.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    dynamic_model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])
p = Thread(target=fit_dinamic_model)
p.start()
p.join()
dynamic_model.model.summary()

Epoch 1/30
Cause: mangled names are not yet supported
Cause: mangled names are not yet supported
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
freeze blocks: 1 , total frozen blocks: 1
----------
add new block
this is the 2 added blocks, block name: res_block1
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
freeze blocks: 1 , total frozen blocks: 2
----------
add new block
this is the 3 added blocks, block name: res_block2
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
freeze blocks: 1 , total frozen blocks: 3
----------
add new block
this is the 4 added blocks, block name: res_block3
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
freeze blocks: 1 , total frozen blocks: 4
----------
add new block
this is the 5 added blocks, block name: res_block4
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
Model: "my_res_net"
_________________________________________________________________
Layer

In [23]:
dynamic_model_cache = dynamicResNet(max_blocks_num=5, cache=True, copy_last_block=False)
dynamic_model_cache.set_epochs(1)
def fit_dinamic_model_cache():
    dynamic_model_cache.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    dynamic_model_cache.fit(train_x, train_y, batch_size=batch_size, epochs=3, callbacks=[tensorboard_callback])
p = Thread(target=fit_dinamic_model_cache)
p.start()
p.join()
dynamic_model_cache.model.summary()

Epoch 1/3
(5000, 32, 32, 3, 1)
Epoch 2/3
freeze block: res_block0 , total frozen blocks: 1
----------
add new block
this is the 2 added blocks, block name: res_block1
caching
cached
compiling
compiled
(5000, 32, 32, 3, 64)
Epoch 3/3
freeze block: res_block1 , total frozen blocks: 2
----------
add new block
this is the 3 added blocks, block name: res_block2
caching
Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.
The tensor that caused the issue was: res_block0/add:0
Model: "cached_res_net_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
res_block0 (ResBlock)        (None, 32, 32, 3, 64)     1280      
_________________________________________________________________
res_block1 (ResBlock)        (None, 32, 32, 3, 64)     37184     
_________________________________________________________________
res_block2 (ResBlock)        (None, 32, 32, 3, 64)     37184     
___________

Exception in thread Thread-19:
Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\lib\threading.py", line 916, in _bootstrap_inner
    self.run()
  File "C:\ProgramData\Anaconda3\lib\threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-22-bf8b7a2b7f0a>", line 102, in fit_epoch
    cache_model = keras.Model(block.input, block.output)
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 522, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\functional.py", line 115, in __init__
    self._init_graph_network(inputs, outputs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 522, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\functional.py"

In [None]:
dynamic_model_copy = dynamicResNet(max_blocks_num=5, copy_last_block=True)
dynamic_model_copy.set_epochs(5)

def fit_dinamic_model_copy():
    dynamic_model_copy.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    dynamic_model_copy.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])

p = Thread(target=fit_dinamic_model_copy)
p.start()
p.join()
dynamic_model_copy.model.summary()

In [None]:
dynamic_model_cache_copy = dynamicResNet(max_blocks_num=5, copy_last_block=False)
dynamic_model_cache_copy.set_epochs(5)
def fit_dynamic_model_cache_copy():
    dynamic_model_cache_copy.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    dynamic_model_cache_copy.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])
p = Thread(target=fit_dynamic_model_cache_copy)
p.start()
p.join()
dynamic_model_cache.model.summary()

In [None]:
static_model = Sequential([ResBlock(), ResBlock(), ResBlock(), ResBlock(), ResBlock(), layers.Flatten(), layers.Dense(10)])
def fit_static_model():
    static_model.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    static_model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])
p = Thread(target=fit_static_model)
p.start()
p.join()
static_model.summary()

In [None]:
%tensorboard --logdir logs/fit

In [15]:
test_model = MyResNet(use_cache=False)
test_model.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
test_model.fit(train_x, train_y, batch_size=batch_size, epochs=2, callbacks=[tensorboard_callback])

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x23517da3e48>

In [16]:
test2 = Model(test_model.layers[0].input, test_model.layers[0].output)

In [17]:
ty=test2.predict(test_x,batch_size=batch_size)

In [11]:
str(dynamic_model_cache.model.getBlocksNum().numpy())

'1'