In [1]:
import types
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import Model, layers
from threading import Thread

In [2]:
%load_ext tensorboard
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + current_time
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
# test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
# train_summary_writer = tf.summary.create_file_writer(train_log_dir)
# test_summary_writer = tf.summary.create_file_writer(test_log_dir)

In [3]:
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)

In [4]:
# load images dataset
def load_dataset(name:str="mnist", size:int=-1):
    if name == "mnist":
        (train_x, train_y), (test_x, test_y) = keras.datasets.mnist.load_data()
    elif name == "cifar10":
        (train_x, train_y), (test_x, test_y) = keras.datasets.cifar10.load_data()
    train_x, test_x = train_x/255.0, test_x/255.0

    train_x = train_x[:size][..., tf.newaxis].astype("float32")
    test_x = test_x[:size][..., tf.newaxis].astype("float32")
    train_y, test_y = train_y[:size], test_y[:size]
    return (train_x, train_y), (test_x, test_y)

In [5]:
batch_size=32
epochs=30
(train_x, train_y), (test_x, test_y) = load_dataset("cifar10", size=1000)

In [11]:
class ResBlock(layers.Layer):

    def __init__(self, *args, **wargs):
        super().__init__(*args, **wargs)
        self.conv = layers.Conv2D(64, 3, padding="same", activation="relu")
        self.bn = layers.BatchNormalization()
        self.downconv = layers.Conv2D(64, 1, padding="same")
        self.downbn = layers.BatchNormalization()

    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)

    def call(self, inputs:np.ndarray):
        x:np.ndarray = inputs
        fx:np.ndarray = x
        fx = self.conv(fx)
        fx = self.bn(fx)
        if fx.shape[-1] != x.shape[-1]:
            x = self.downconv(x)
            x = self.downbn(x)
        try:
            # print(self.name, x.shape, fx.shape, inputs.shape)
            return fx + x
        except:
            raise RuntimeError(x.shape, fx.shape, inputs.shape)

    # def get_weights(self):
    #     return [self.conv.get_weights(), self.bn.get_weights()]

    # def set_weights(self, weights:list):
    #     self.conv.set_weights(weights[0])
    #     self.bn.set_weights(weights[1])
    #     return super().set_weights(weights)

In [7]:
class MyResNet(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__blocks_num = 1
        self.__frozen_blocks_num = 0
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks = [ResBlock(name="res_block0")]
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(10)
    
    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)

    def call(self, x=None, training=None, mask=None):
        for i in range(self.__blocks_num):
            x = self.blocks[i](x, training=training)
        x = self.flatten(x, training=training)
        x = self.dense(x, training=training)
        return x

    def getBlocksNum(self):
        return self.__blocks_num
    
    def freezeBlocks(self, num=1):
        for i in range(self.__frozen_blocks_num, min(self.__frozen_blocks_num+num, self.__blocks_num)):
            self.blocks[i].trainable = False
        self.__frozen_blocks_num = min(self.__frozen_blocks_num+num, self.__blocks_num)
        print("freeze blocks:", num, ", total frozen blocks:", self.__frozen_blocks_num)
    
    def addNewBlock(self):
        print("----------")
        print("add new block")
        # self.freezeBlocks(1)
        newBlock = ResBlock(name="res_block"+str(self.__blocks_num))
        newBlock(self.blocks[-1].output)
        self.blocks.append(newBlock)
        self.__blocks_num += 1
    
    def copyLastBlock(self):
        print("----------")
        print("copy last block")
        # self.freezeBlocks(1)
        newBlock = ResBlock(name="res_block"+str(self.__blocks_num))
        last_block:ResBlock = self.blocks[-1]
        newBlock(last_block.output)
        if last_block.input_shape == last_block.output_shape:
            newBlock.set_weights(last_block.get_weights())
        else:
            print("copy failed: shape different with last block")
        self.blocks.append(newBlock)
        self.__blocks_num += 1

In [12]:
class dynamicResNet:
    def __init__(self, condition: types.FunctionType = None, max_blocks_num:int = 2, copy_last_block:bool = False,*args, **wargs) -> None:
        """
            condition: A function, which will be called in every epoch and returns a boolean value representing whether to add a new block.
                        一个函数，每个epoch会被调用一次，返回值为布尔类型，代表是否添加新的块
        """
        super(dynamicResNet, self).__init__(*args, **wargs)
        if condition is None:
            self.add_condition = self.set_epochs
            self.add_condition()
        else:
            if callable(condition):
                self.add_condition = condition
            else:
                raise ValueError("'condition' must be a function")
        self.max_blocks_num = max_blocks_num
        self.copy_last_block = copy_last_block
        # build model //创建模型
        self.model = MyResNet()
        self.compiled = False

    def compile(self,
                optimizer="rmsprop",
                loss=None,
                metrics=None,
                loss_weights=None,
                weighted_metrics=None,
                run_eagerly=None,
                steps_per_execution=None,
                **kwargs
    ):
        self.complieArgs = [optimizer, loss, metrics, loss_weights, weighted_metrics, run_eagerly, steps_per_execution]
        self.complieKwargs = kwargs
        self.model.compile(*self.complieArgs, **kwargs)
        self.compiled = True

    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose="auto",
            callbacks=None,
            validation_split=0.0,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False
    ):
        if not self.compiled:
            raise RuntimeError("model should be compiled before fit")
        self.epochs = epochs
        self.fitArgs = [x,y,batch_size,1,verbose,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,validation_batch_size,validation_freq,max_queue_size,workers,use_multiprocessing]
        return self.call(training=True)
    
    def predict(self,
                x,
                batch_size=None,
                verbose="auto",
                steps=None,
                callbacks=None,
                max_queue_size=10,
                workers=1,
                use_multiprocessing=False
    ):
        if not self.compiled:
            raise RuntimeError("model should be compiled before predict")
        return self.model.predict( x,
                                    batch_size=batch_size,
                                    verbose=verbose,
                                    steps=steps,
                                    callbacks=callbacks,
                                    max_queue_size=max_queue_size,
                                    workers=workers,
                                    use_multiprocessing=use_multiprocessing
                                 )


    def call(self, x=None, training=False):
        if training:
            if x:
                raise ValueError("Please use 'fit' when training.")
            def fit_epoch():
                # 满足条件动态添加新残差块
                if self.model.getBlocksNum() < self.max_blocks_num and self.add_condition():
                    self.model.freezeBlocks(1)
                    if self.copy_last_block:
                        self.model.copyLastBlock()
                    else:
                        self.model.addNewBlock()
                    self.model.compile(*self.complieArgs, **self.complieKwargs)
                self.model.fit(*self.fitArgs)
            for epoch in range(self.epochs):
                print(f"Epoch {epoch+1}/{self.epochs}")
                # 使用多进程的方式可以释放显存
                p = Thread(target=fit_epoch)
                p.start()
                p.join()
        else:
            return self.model.predict(x)

    def set_epochs(self, interval_of_epochs:int = None) -> None:
        self.epoch = 0
        self.last_change_epoch = 1
        if interval_of_epochs is None:
            self.interval = 1
        else:
            self.interval = interval_of_epochs
        self.add_condition = self.__num_of_epochs

    def __num_of_epochs(self) -> bool:
        self.epoch += 1
        if self.epoch - self.last_change_epoch == self.interval:
            self.last_change_epoch = self.epoch
            return True
        return False

In [9]:
dynamic_model = dynamicResNet(max_blocks_num=5, copy_last_block=False)
dynamic_model.set_epochs(5)
def fit_dinamic_model():
    dynamic_model.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    dynamic_model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])
p = Thread(target=fit_dinamic_model)
p.start()
p.join()
dynamic_model.model.summary()

Epoch 1/30
Cause: mangled names are not yet supported
Cause: mangled names are not yet supported
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
freeze blocks: 1 , total frozen blocks: 1
add new block
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
freeze blocks: 1 , total frozen blocks: 2
add new block
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
freeze blocks: 1 , total frozen blocks: 3
add new block
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
freeze blocks: 1 , total frozen blocks: 4
add new block
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
Model: "my_res_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 res_block0 (ResBlock)       (None, 32, 32, 3, 64)     1280      
                                                                 
 res_block1 (ResBlock)       (None, 32, 32, 3, 64)     

In [13]:
dynamic_model_copy = dynamicResNet(max_blocks_num=5, copy_last_block=True)
dynamic_model_copy.set_epochs(5)

def fit_dinamic_model():
    dynamic_model_copy.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    dynamic_model_copy.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])

p = Thread(target=fit_dinamic_model)
p.start()
p.join()
dynamic_model_copy.model.summary()

Epoch 1/4
Epoch 2/4
freeze blocks: 1 , total frozen blocks: 1
----------
copy last block
copy failed: shape different with last block
Epoch 3/4
freeze blocks: 1 , total frozen blocks: 2
----------
copy last block
Epoch 4/4
freeze blocks: 1 , total frozen blocks: 3
----------
copy last block
Model: "my_res_net_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 res_block0 (ResBlock)       (None, 32, 32, 3, 64)     1280      
                                                                 
 res_block1 (ResBlock)       (None, 32, 32, 3, 64)     37184     
                                                                 
 res_block2 (ResBlock)       (None, 32, 32, 3, 64)     37184     
                                                                 
 res_block3 (ResBlock)       (None, 32, 32, 3, 64)     37184     
                                                                 
 flatten_2 (Flatten)      

In [10]:
from keras import Sequential
static_model = Sequential([ResBlock(), ResBlock(), ResBlock(), ResBlock(), ResBlock(), layers.Flatten(), layers.Dense(10)])
def fit_static_model():
    static_model.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    static_model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])
p = Thread(target=fit_static_model)
p.start()
p.join()
static_model.summary()

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [None]:
%tensorboard --logdir logs/fit