In [45]:
import tensorflow  as tf
from tensorflow.keras import layers,Sequential,optimizers,datasets

In [41]:
class BasicBlock(tf.keras.Model):
    """Residual "basic" block: conv3x3-BN-ReLU -> conv3x3-BN, plus a shortcut, then ReLU.

    filters: number of channels of both 3x3 convolutions (and of the block output).
    strides: stride of the first 3x3 convolution; a value other than 1
        downsamples the spatial dimensions.

    The shortcut is the identity when the input can be added to the residual
    branch as-is. Otherwise (strides != 1, or input channels != filters) it is a
    1x1 convolution + BatchNormalization projection. That choice is made in
    build(), because the input channel count is only known there.
    """
    def __init__(self,filters,strides=1):
        super(BasicBlock,self).__init__()
        self.filters = filters
        self.strides = strides

        # Attribute name `con1` (sic) is kept unchanged so weight/checkpoint names do not change.
        self.con1 = layers.Conv2D(filters=filters,kernel_size=(3,3),strides=strides,padding="same")
        self.bn1 = layers.BatchNormalization()
        # ReLU has no weights, so a single instance is reused at both activation points.
        self.relu = layers.ReLU()
        self.conv2 = layers.Conv2D(filters=filters,kernel_size=(3,3),strides=1,padding="same")
        self.bn2 = layers.BatchNormalization()
        # Set in build(); None means an identity shortcut.
        self.shortcut = None

    def build(self, input_shape):
        """Create a projection shortcut when the input cannot be added directly.

        An unknown (None) channel count is treated as a mismatch, so a
        projection is created; that is always shape-safe.
        """
        in_channels = input_shape[-1]
        if self.strides != 1 or in_channels != self.filters:
            self.shortcut = Sequential([
                layers.Conv2D(filters=self.filters,kernel_size=(1,1),strides=self.strides),
                # Normalise the projected path like the residual branch (bn2) it is added to.
                layers.BatchNormalization(),
            ])

    def call(self,inputs,training=None):
        # Residual branch; `training` is forwarded explicitly so BN picks
        # batch statistics (training) or moving averages (inference).
        out = self.con1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        if self.shortcut is None:
            identity = inputs
        else:
            identity = self.shortcut(inputs, training=training)
        out = out + identity
        return self.relu(out)
class ResNet(tf.keras.Model):
    """ResNet for small (CIFAR-sized) images, built from four stages of BasicBlocks.

    layer_dims: number of BasicBlocks in each of the four stages,
        e.g. [2, 2, 2, 2] for ResNet-18.
    num_classes: length of the output logit vector.

    This subclasses tf.keras.Model rather than layers.Layer (Model is itself a
    Layer subclass, so existing callers keep working). With Model,
    `build(input_shape=...)` creates all weights up front, and
    `summary()`, `compile()` and `fit()` become available.
    """
    def __init__(self,layer_dims,num_classes=10):
        super(ResNet,self).__init__()

        # Stem: 3x3 conv + BN + ReLU, then a 2x2/stride-2 max-pool that halves height and width.
        self.stem = Sequential([
            layers.Conv2D(filters=64,kernel_size=(3,3),strides=1,padding="same"),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPool2D(pool_size=(2,2),strides=2),
        ])
        self.layer1 = self.build_(64,layer_dims[0])
        # Stages 2-4 downsample (strides=2 in their first block) while widening the channels.
        self.layer2 = self.build_(128,layer_dims[1],strides=2)
        self.layer3 = self.build_(256,layer_dims[2],strides=2)
        self.layer4 = self.build_(512,layer_dims[3],strides=2)

        self.avgpool = layers.GlobalAveragePooling2D()
        # No activation: the model outputs raw logits.
        self.fc1 = layers.Dense(num_classes)

    def call(self,inputs,training=None):
        # `training` is forwarded explicitly so every BatchNormalization layer runs in the requested mode.
        out = self.stem(inputs, training=training)
        out = self.layer1(out, training=training)
        out = self.layer2(out, training=training)
        out = self.layer3(out, training=training)
        out = self.layer4(out, training=training)
        out = self.avgpool(out)
        return self.fc1(out)

    def build_(self,filters,blocks,strides=1):
        """Stack `blocks` BasicBlocks into one stage.

        :param filters: channel count of every block in the stage.
        :param blocks: number of BasicBlocks in the stage (must be >= 1).
        :param strides: stride of the first block only; later blocks use stride 1.
        :return: a Sequential containing the blocks.
        :raises ValueError: if blocks < 1.
        """
        if blocks < 1:
            raise ValueError("blocks must be >= 1, got {}".format(blocks))
        res_blocks = Sequential()
        # Only the first block may downsample and/or change the channel count.
        res_blocks.add(BasicBlock(filters=filters,strides=strides))

        # Subsequent blocks do not downsample.
        for _ in range(1,blocks):
            res_blocks.add(BasicBlock(filters=filters))
        return res_blocks

In [44]:
def resnet18(num_classes=10):
    """Build a ResNet-18: four stages of 2 BasicBlocks each.

    num_classes: number of output logits (default 10, matching CIFAR-10).
    """
    return ResNet([2,2,2,2], num_classes=num_classes)
def resnet34(num_classes=10):
    """Build a ResNet-34: stages of 3, 4, 6 and 3 BasicBlocks.

    num_classes: number of output logits (default 10, matching CIFAR-10).
    """
    return ResNet([3,4,6,3], num_classes=num_classes)

In [48]:
def process(x,y):
    """Map one (image, label) pair to the dtypes the model expects.

    The image becomes float32 scaled from [0, 255] into [0, 1];
    the label becomes int32.
    """
    image = tf.cast(x, dtype=tf.float32) / 255.0
    label = tf.cast(y, dtype=tf.int32)
    return image, label
# Tunable settings for the input pipeline.
SEED = 42
BATCH_SIZE = 128
SHUFFLE_BUFFER = 10000
# Seed TF's global RNG so shuffling order and weight initialisation are reproducible.
tf.random.set_seed(SEED)

(train_x, train_y), (test_x, test_y) = datasets.cifar10.load_data()
# Labels arrive with a trailing length-1 axis; squeeze only that axis so a
# single-sample split cannot collapse to a scalar.
train_y = tf.squeeze(train_y, axis=-1)
test_y = tf.squeeze(test_y, axis=-1)

# Shuffle the raw uint8 images before `process` casts them to float32,
# so the shuffle buffer holds 4x less data.
train_db = (
    tf.data.Dataset.from_tensor_slices((train_x, train_y))
    .shuffle(SHUFFLE_BUFFER, seed=SEED)
    .map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
test_db = (
    tf.data.Dataset.from_tensor_slices((test_x, test_y))
    .map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

model = resnet18()
model.build(input_shape=[None,32,32,3])

In [54]:
EPOCHS = 10
# `optimizer` was used but never defined in any visible cell (it only worked
# through leftover kernel state). It is defined here so Restart & Run All works.
# NOTE(review): Adam(1e-3) is an assumption; match whatever the lost cell used.
optimizer = optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(x, y):
    """Run one optimisation step on a batch and return the batch's mean loss."""
    with tf.GradientTape() as tape:
        # training=True puts BatchNormalization in training mode: batch
        # statistics plus moving-average updates. Calling model(x) with
        # training left as None runs BN in inference mode.
        logits = model(x, training=True)
        # Integer labels with sparse cross-entropy are equivalent to
        # one_hot + categorical_crossentropy.
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y, logits, from_logits=True))
    # Gradients and updates are computed outside the tape so they are not recorded.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


@tf.function
def eval_step(x, y):
    """Return the number of correctly classified samples in a batch."""
    logits = model(x, training=False)
    preds = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return tf.reduce_sum(tf.cast(tf.equal(preds, y), tf.int32))


for epoch in range(EPOCHS):
    for step, (x, y) in enumerate(train_db):
        loss = train_step(x, y)
        if step % 100 == 0:
            print("epoch:{},step:{},loss:{}".format(epoch,step,float(loss)))

    # Evaluate on the held-out split after every epoch.
    correct, total = 0, 0
    for x, y in test_db:
        correct += int(eval_step(x, y))
        total += int(y.shape[0])
    print("epoch:{},test_acc:{:.4f}".format(epoch, correct / total))

epoch:0,step:0,loss:2.2989563941955566


KeyboardInterrupt: 