In [1]:
import tensorflow as tf

In [2]:
# 数据集处理

In [3]:
def preprocess(x, y):
    """Turn one (image, label) element into model-ready tensors.

    Args:
        x: Image tensor (presumably uint8 pixels in [0, 255]).
        y: Label tensor.

    Returns:
        Tuple ``(image, label)``: the image cast to float32 and divided by
        255, and the label cast to int32.
    """
    return tf.cast(x, tf.float32) / 255., tf.cast(y, dtype=tf.int32)

In [4]:
# 加载数据集

In [5]:
# Load the CIFAR-10 train and test splits: x / x_test hold the images and
# y / y_test the integer class labels (presumably shaped (N, 1) — hence the
# squeeze before one-hot encoding in the next cell).
(x,y),(x_test,y_test) = tf.keras.datasets.cifar10.load_data()

In [6]:
# One-hot encode the integer labels over the 10 classes, matching the
# CategoricalCrossentropy loss used in compile().
# tf.reshape(..., [-1]) flattens labels shaped (N, 1) or (N,) to rank 1.
# Unlike a bare tf.squeeze, it never drops the sample axis (a split with a
# single example would otherwise collapse to a scalar label).
y = tf.one_hot(tf.reshape(y, [-1]), depth=10)
y_test = tf.one_hot(tf.reshape(y_test, [-1]), depth=10)

In [7]:
# 构建 tf.data 数据管道 (build the tf.data train/test input pipelines)

In [8]:
banchsize = 128  # batch size shared by the train and test pipelines
# Training pipeline. shuffle() runs on the raw (presumably uint8) images
# *before* map(preprocess), so the 10000-element shuffle buffer holds 1-byte
# pixels rather than 4-byte float32 ones.
# map() runs preprocess in parallel, and prefetch() lets the next batch be
# prepared while the current training step runs.
train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = (train_db.shuffle(10000)
            .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(banchsize)
            .prefetch(tf.data.AUTOTUNE))
# Test pipeline: same preprocessing, no shuffling, so evaluation order is fixed
# (parallel map keeps element order by default).
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = (test_db.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(banchsize)
           .prefetch(tf.data.AUTOTUNE))

In [9]:
# 抽样检测

In [10]:
# Sanity check: pull one batch from the training pipeline and print the
# image-batch and label-batch shapes.
sample = next(iter(train_db.take(1)))
print(*(part.shape for part in sample))

(128, 32, 32, 3) (128, 10)


In [11]:
# 设置自定义层

In [12]:
class MyDense(tf.keras.layers.Layer):
    """Fully connected layer with no bias term: ``outputs = inputs @ kernel``.

    Args:
        intp_dim: Size of the last axis of the inputs.
        outp_dim: Number of output units.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, intp_dim, outp_dim, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        # name/shape are passed by keyword: the first positional parameter of
        # add_weight is `name` in Keras 2 but `shape` in Keras 3, so the old
        # positional call add_weight('w', [...]) is not portable.
        self.kernel = self.add_weight(name='w', shape=[intp_dim, outp_dim])

    def call(self, inputs, training=None):
        # `training` is accepted for Keras call-signature compatibility; the
        # computation is the same in training and inference.
        return inputs @ self.kernel

In [16]:
class MyNetWork(tf.keras.models.Model):
    """Bias-free five-layer MLP mapping 32x32x3 images to 10 logits.

    Each input is flattened to a 3072-long vector and passed through four
    MyDense + ReLU stages (256 -> 128 -> 64 -> 32 units). A final MyDense
    projection to 10 units has no activation, so the model returns raw logits
    (the loss is compiled with from_logits=True).
    """

    def __init__(self):
        super(MyNetWork, self).__init__()
        # Kept as individually named attributes created in this order.
        # TF-format weight checkpoints (save_weights/load_weights) key the
        # variables by these attribute names.
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs):
        # Flatten (batch, 32, 32, 3) images; -1 keeps the batch size dynamic.
        hidden = tf.reshape(inputs, [-1, 32 * 32 * 3])
        # All hidden stages share the same dense -> ReLU pattern.
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = tf.nn.relu(dense(hidden))
        # Output projection: raw logits, no softmax.
        return self.fc5(hidden)
        
        

In [17]:
# Instantiate the custom network defined above (MyDense creates its kernel
# in __init__ via add_weight).
netWork = MyNetWork()

In [18]:
# Build the model for inputs of shape (batch, 32, 32, 3).
# None leaves the batch size unspecified. This matches the image batch shape
# printed by the sample check, and lets summary() below list the layers.
netWork.build(input_shape=[None,32,32,3])

In [19]:
# Print the layer table. Each MyDense has intp_dim * outp_dim parameters
# (no bias), e.g. 3072 * 256 = 786432 for the first layer.
netWork.summary()

Model: "my_net_work_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 my_dense_5 (MyDense)        multiple                  786432    
                                                                 
 my_dense_6 (MyDense)        multiple                  32768     
                                                                 
 my_dense_7 (MyDense)        multiple                  8192      
                                                                 
 my_dense_8 (MyDense)        multiple                  2048      
                                                                 
 my_dense_9 (MyDense)        multiple                  320       
                                                                 
Total params: 829,760
Trainable params: 829,760
Non-trainable params: 0
_________________________________________________________________


In [20]:
# Adam with learning rate 1e-3. The network outputs raw logits, so the loss
# applies the softmax itself (from_logits=True). Labels are one-hot, hence
# categorical (not sparse) cross-entropy.
netWork.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [21]:
# Train for 5 epochs, evaluating on the test pipeline after every epoch.
netWork.fit(
    train_db,
    epochs=5,
    validation_data=test_db,
    validation_freq=1,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x158c32d0f40>

In [22]:
# 评估并保存模型权重 (evaluate, then save the model weights)

In [23]:
# Test-set [loss, accuracy] of the trained model. This is the baseline that
# the reloaded model is compared against at the end of the notebook.
netWork.evaluate(test_db)



[1.5431029796600342, 0.4530999958515167]

In [24]:
# Save only the weights (the architecture is re-created from MyNetWork when
# reloading). The .ckpt path selects the TF checkpoint format, and the
# relative path resolves against the notebook's working directory.
netWork.save_weights('ckpt/weights/weight_demo01.ckpt')

In [30]:
# Discard the trained model, so the weights below are restored into a fresh
# instance rather than into the one that produced them.
del netWork

In [32]:
# Fresh, untrained instance compiled with the same settings as the original,
# ready to receive the saved weights.
netWork = MyNetWork()
netWork.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [34]:
# Restore the weights written by save_weights() above into the fresh instance.
# The returned CheckpointLoadStatus is only displayed, not checked.
# NOTE(review): consider .assert_existing_objects_matched() to fail loudly on a
# mismatch, or .expect_partial() if warnings about unrestored optimizer state
# appear — confirm which fits the intended workflow.
netWork.load_weights('ckpt/weights/weight_demo01.ckpt')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x158c89088b0>

In [35]:
# Should reproduce the earlier evaluate() result exactly if the weights were
# restored correctly.
netWork.evaluate(test_db)



[1.5431029796600342, 0.4530999958515167]