In [1]:
import types
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import Model, layers
from threading import Thread

In [2]:
%load_ext tensorboard
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + current_time
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
# test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
# train_summary_writer = tf.summary.create_file_writer(train_log_dir)
# test_summary_writer = tf.summary.create_file_writer(test_log_dir)

In [3]:
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)

In [4]:
# load images dataset
def load_dataset(name:str="mnist", size:int=None):
    if name == "mnist":
        (train_x, train_y), (test_x, test_y) = keras.datasets.mnist.load_data()
    elif name == "cifar10":
        (train_x, train_y), (test_x, test_y) = keras.datasets.cifar10.load_data()
    train_x, test_x = train_x/255.0, test_x/255.0

    if size:
        train_x = train_x[:size][..., tf.newaxis].astype("float32")
        test_x = test_x[:size][..., tf.newaxis].astype("float32")
        train_y, test_y = train_y[:size], test_y[:size]
    return (train_x, train_y), (test_x, test_y)

In [5]:
batch_size=32
epochs=30
(train_x, train_y), (test_x, test_y) = load_dataset("cifar10")

In [6]:
class ResBlock(layers.Layer):

    def __init__(self, filters, kernel_size, strides = (1, 1), padding: str = 'same', *args, **wargs):
        super().__init__(*args, **wargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        # convolution
        self.conv1 = layers.Conv2D(self.filters, self.kernel_size, strides=self.strides, padding=self.padding, activation="relu")
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(self.filters, self.kernel_size, padding=self.padding)
        self.bn2 = layers.BatchNormalization()
        # self.conv3 = layers.Conv2D(self.filters, self.kernel_size, padding=self.padding, activation="relu")
        # self.bn3 = layers.BatchNormalization()
        # shortcut
        self.downconv = layers.Conv2D(self.filters, 1, padding=self.padding)
        self.downbn = layers.BatchNormalization()

    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)
    
    def shortcut(self, x):
        x = self.downconv(x)
        x = self.downbn(x)
        return x

    def call(self, inputs:np.ndarray, training=None, mask=None):
        x:np.ndarray = inputs
        fx:np.ndarray = x
        # f(x)
        fx = self.conv1(fx, training=training)
        fx = self.bn1(fx, training=training)
        fx = self.conv2(fx, training=training)
        fx = self.bn2(fx, training=training)
        # fx = self.conv3(fx, training=training)
        # fx = self.bn3(fx, training=training)
        # h(x) = x + f(x)
        if fx.shape[-1] != x.shape[-1]:
            # x = self.downconv(x, training=training)
            # x = self.downbn(x, training=training)
            x = self.shortcut(x)
        try:
            return fx + x
        except:
            raise RuntimeError(x.shape, fx.shape, inputs.shape)
    
    def get_weights(self):
        return [self.conv1.get_weights(), self.bn1.get_weights(),
                self.conv2.get_weights(), self.bn2.get_weights()
        ]
    
    def set_weights(self, weights:list):
        self.conv1.set_weights(weights[0])
        self.bn1.set_weights(weights[1])
        self.conv2.set_weights(weights[2])
        self.bn2.set_weights(weights[3])

In [7]:
class ResNet_34(Model):
    def __init__(self, units, blocks=None, dynamic=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__blocks_num = 1
        self.__frozen_blocks_num = 0
        self.__units = units
        self.dynamic_depts = dynamic
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.conv0 = layers.Conv2D(64, 7, strides=2, padding="same", name="conv0")
        self.blocks1 = [
            ResBlock(64, 3, name="res_block_64_0"),
            ResBlock(64, 3, name="res_block_64_1"),
            ResBlock(64, 3, name="res_block_64_2")
        ]
        self.blocks2 = [
            ResBlock(128, 3, name="res_block_128_3"),
            ResBlock(128, 3, name="res_block_128_4"),
            ResBlock(128, 3, name="res_block_128_5"),
            ResBlock(128, 3, name="res_block_128_6")
        ]
        self.blocks3 = [
            ResBlock(256, 3, name="res_block_256_7"),
            ResBlock(256, 3, name="res_block_256_8"),
            ResBlock(256, 3, name="res_block_256_9"),
            ResBlock(256, 3, name="res_block_256_10"),
            ResBlock(256, 3, name="res_block_256_11"),
            ResBlock(256, 3, name="res_block_256_12")
        ]
        self.blocks4 = [
            ResBlock(512, 3, name="res_block_512_13"),
            ResBlock(512, 3, name="res_block_512_14"),
            ResBlock(512, 3, name="res_block_512_15"),
        ]
        self.blocks = blocks
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(units)
    
    def build(self, input_shape):
        # resolve output shape in model summary
        input_layer = layers.Input(shape=input_shape[1:], batch_size=input_shape[0])
        self.call(input_layer)
        return super().build(input_shape)

    def call(self, x, training=None, mask=None):
        x = self.conv0(x, training=training)
        if self.dynamic_depts:
            if self.blocks is None:
                # print("called dynamic")
                for i in range(min(len(self.blocks1), self.__blocks_num)):
                    x = self.blocks1[i](x, training=training)
                for i in range(min(len(self.blocks2), self.__blocks_num)):
                    x = self.blocks2[i](x, training=training)
                for i in range(min(len(self.blocks3), self.__blocks_num)):
                    x = self.blocks3[i](x, training=training)
                for i in range(min(len(self.blocks4), self.__blocks_num)):
                    x = self.blocks4[i](x, training=training)
            else:
                for i in range(self.__blocks_num):
                    x = self.blocks[i](x, training=training)
        else:
            if self.blocks is None:
                # print("called static")
                for block in self.blocks1:
                    x = block(x, training=training)
                for block in self.blocks2:
                    x = block(x, training=training)
                for block in self.blocks3:
                    x = block(x, training=training)
                for block in self.blocks4:
                    x = block(x, training=training)
            else:
                for block in self.blocks:
                    x = block(x, training=training)
        x = self.flatten(x, training=training)
        x = self.dense(x, training=training)
        return x

    def getBlocksNum(self):
        return self.__blocks_num
    
    def freezeBlocks(self, num):
        for i in range(self.__frozen_blocks_num, min(self.__frozen_blocks_num+num, self.__blocks_num)):
            if self.blocks is None:
                if len(self.blocks1) > self.__blocks_num:
                    self.blocks1[i].trainable = False
                if len(self.blocks2) > self.__blocks_num:
                    self.blocks2[i].trainable = False
                if len(self.blocks3) > self.__blocks_num:
                    self.blocks3[i].trainable = False
                if len(self.blocks4) > self.__blocks_num:
                    self.blocks4[i].trainable = False
            else:
                self.blocks[i].trainable = False
        self.__frozen_blocks_num = min(self.__frozen_blocks_num+num, self.__blocks_num)

    def addNewBlock(self):
        print("----------")
        print("add new blocks")
        self.freezeBlocks(1)
        self.__blocks_num += 1
        print(f"this is the {self.__blocks_num} added blocks")#, block name: {self.blocks[self.__blocks_num-1].name}")
    
    def copyLastBlock(self):
        print("----------")
        print("copy last block")
        self.freezeBlocks(1)
        if self.blocks is not None:
            return
        if len(self.blocks1) > self.__blocks_num:
            newBlock = self.blocks1[self.__blocks_num]
            last_block:ResBlock = self.blocks1[self.__blocks_num-1]
            newBlock(last_block.output)
            if last_block.input_shape == newBlock.input_shape and last_block.output_shape == newBlock.output_shape:
                newBlock.set_weights(last_block.get_weights())
            else:
                print("block1 copy failed: shape different with last block")
        if len(self.blocks2) > self.__blocks_num:
            newBlock = self.blocks2[self.__blocks_num]
            last_block:ResBlock = self.blocks2[self.__blocks_num-1]
            newBlock(last_block.output)
            if last_block.input_shape == newBlock.input_shape and last_block.output_shape == newBlock.output_shape:
                newBlock.set_weights(last_block.get_weights())
            else:
                print("block2 copy failed: shape different with last block")
        if len(self.blocks3) > self.__blocks_num:
            newBlock = self.blocks3[self.__blocks_num]
            last_block:ResBlock = self.blocks3[self.__blocks_num-1]
            newBlock(last_block.output)
            if last_block.input_shape == newBlock.input_shape and last_block.output_shape == newBlock.output_shape:
                newBlock.set_weights(last_block.get_weights())
            else:
                print("block3 copy failed: shape different with last block")
        if len(self.blocks4) > self.__blocks_num:
            newBlock = self.blocks4[self.__blocks_num]
            last_block:ResBlock = self.blocks4[self.__blocks_num-1]
            newBlock(last_block.output)
            if last_block.input_shape == newBlock.input_shape and last_block.output_shape == newBlock.output_shape:
                newBlock.set_weights(last_block.get_weights())
            else:
                print("block4 copy failed: shape different with last block")
        self.__blocks_num += 1

In [8]:
class dynamicResNet:
    def __init__(self, is_dynamic=True, condition: types.FunctionType = None, max_blocks_num:int = 2, copy_last_block:bool = False,*args, **wargs) -> None:
        """
        Args:
            is_dynamic: bool, this model's depth should be dynamical increase or not
                模型深度是否动态增加
            condition: A function, which will be called in every epoch and returns a boolean value representing whether to add a new block.
                每个epoch会被调用一次，返回值为布尔类型，代表是否添加新的块
            max_blocks_num: int, total num of blocks which will be added into model in the last
                最终会被添加到模型中的总残差块数
            copy_last_block: bool, whether copy the last block's weight to new block
                新的block是否复制最后一个block的权重
        """
        super(dynamicResNet, self).__init__(*args, **wargs)
        self.dynamic = is_dynamic
        if condition is None:
            self.add_condition = self.set_epochs
            self.add_condition()
        else:
            if callable(condition):
                self.add_condition = condition
            else:
                raise ValueError("'condition' must be a function")
        self.max_blocks_num = max_blocks_num
        self.copy_last_block = copy_last_block
        # build model //创建模型
        self.model = ResNet_34(10, dynamic=self.dynamic)
        self.compiled = False

    def compile(self,
                optimizer="rmsprop",
                loss=None,
                metrics=None,
                loss_weights=None,
                weighted_metrics=None,
                run_eagerly=None,
                steps_per_execution=None,
                **kwargs
    ):
        self.complieArgs = [optimizer, loss, metrics, loss_weights, weighted_metrics, run_eagerly, steps_per_execution]
        self.complieKwargs = kwargs
        self.model.compile(*self.complieArgs, **kwargs)
        self.compiled = True

    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose="auto",
            callbacks=None,
            validation_split=0.0,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False
    ):
        if not self.compiled:
            raise RuntimeError("model should be compiled before fit")
        self.epochs = epochs
        self.fitArgs = [x,y,batch_size,1,verbose,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,validation_batch_size,validation_freq,max_queue_size,workers,use_multiprocessing]
        return self.call(training=True)
    
    def predict(self,
                x,
                batch_size=None,
                verbose="auto",
                steps=None,
                callbacks=None,
                max_queue_size=10,
                workers=1,
                use_multiprocessing=False
    ):
        if not self.compiled:
            raise RuntimeError("model should be compiled before predict")
        return self.model.predict( x,
                                    batch_size=batch_size,
                                    verbose=verbose,
                                    steps=steps,
                                    callbacks=callbacks,
                                    max_queue_size=max_queue_size,
                                    workers=workers,
                                    use_multiprocessing=use_multiprocessing
                                 )

    def call(self, x=None, training=False):
        if training:
            def fit_epoch():
                # 满足条件动态添加新残差块
                if self.model.getBlocksNum() < self.max_blocks_num and self.add_condition():
                    if self.copy_last_block:
                        self.model.copyLastBlock()
                    else:
                        self.model.addNewBlock()
                    self.model.compile(*self.complieArgs, **self.complieKwargs)
                self.model.fit(*self.fitArgs)
            for epoch in range(self.epochs):
                print("Epoch: ", epoch)
                p = Thread(target=fit_epoch)
                p.start()
                p.join()
        else:
            return self.model.predict(x)

    def set_epochs(self, interval_of_epochs:int = None) -> None:
        self.epoch = 0
        self.last_change_epoch = 1
        if interval_of_epochs is None:
            self.interval = 1
        else:
            self.interval = interval_of_epochs
        self.add_condition = self.__num_of_epochs

    def __num_of_epochs(self) -> bool:
        self.epoch += 1
        if self.epoch - self.last_change_epoch == self.interval:
            self.last_change_epoch = self.epoch
            return True
        return False

In [9]:
dynamic_model = dynamicResNet(is_dynamic=True, max_blocks_num=6, copy_last_block=True)
dynamic_model.set_epochs(5)
def fit_dynamic_model():
    dynamic_model.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    dynamic_model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])
p = Thread(target=fit_dynamic_model)
p.start()
p.join()
dynamic_model.model.summary()

Epoch:  0
Cause: mangled names are not yet supported
Cause: mangled names are not yet supported
Epoch:  1
Epoch:  2
Epoch:  3
Epoch:  4
Epoch:  5
----------
copy last block
1 (None, 16, 16, 64) (None, 16, 16, 64) (None, 16, 16, 64) (None, 16, 16, 64) (None, 16, 16, 64)
block2 copy failed: shape different with last block
block3 copy failed: shape different with last block
block4 copy failed: shape different with last block
Epoch:  6
Epoch:  7
Epoch:  8
Epoch:  9
Epoch:  10
----------
copy last block
2 (None, 16, 16, 64) (None, 16, 16, 64) (None, 16, 16, 64) (None, 16, 16, 64) (None, 16, 16, 64)
Epoch:  11
Epoch:  12
Epoch:  13
Epoch:  14
Epoch:  15
----------
copy last block
Epoch:  16
Epoch:  17
Epoch:  18
Epoch:  19
Epoch:  20
----------
copy last block
Epoch:  21
Epoch:  22
Epoch:  23
Epoch:  24
Epoch:  25
----------
copy last block
Epoch:  26
Epoch:  27
Epoch:  28
Epoch:  29
Model: "res_net_34"
_________________________________________________________________
Layer (type)           

In [10]:
static_model = ResNet_34(units=10, dynamic=False)
def fit_static_model():
    static_model.compile(optimizer="Adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])
    static_model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, callbacks=[tensorboard_callback])
p = Thread(target=fit_static_model)
p.start()
p.join()
static_model.summary()

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
Model: "res_net_34_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv0 (Conv2D)               (None, 16, 16, 64)        9472      
_________________________________________________________________
res_block_64_0 (ResBlock)    (None, 16, 16, 64)        74368     
_________________________________________________________________
res_block_64_1 (ResBlock)    (None, 16, 16, 64)        74368     
_________________________________________________________________
res_block_64_2 (ResBlock)    (None, 16, 16, 64)        74368     
_________________________________