In [34]:
import tensorflow as tf
from tensorflow.keras import layers,datasets,models,Sequential,optimizers,metrics
import datetime
# Fashion-MNIST comes back as a pair of (images, labels) splits: train first, then test.
train_split, test_split = datasets.fashion_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [35]:
def process(x, y):
    """Convert one (image, label) element into the dtypes the model trains on.

    :param x: image tensor; cast to float32 and divided by 255 (presumably raw
        0-255 pixel values, which this maps into [0, 1] — confirm input dtype).
    :param y: class label; cast to int32.
    :return: ``(image_float32, label_int32)`` tuple.
    """
    scaled_image = tf.cast(x, dtype=tf.float32) / 255.
    int_label = tf.cast(y, dtype=tf.int32)
    return scaled_image, int_label
batch_sz = 128
# Shuffle *before* `process`: the shuffle buffer then holds the raw (presumably
# uint8) arrays instead of float32 copies, so it can cheaply span the whole
# training set. A buffer smaller than the dataset only shuffles within a
# sliding window; len(x_train) gives a uniform shuffle every epoch.
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(len(x_train)).map(process).batch(batch_sz)
# Test data is only evaluated, so it needs no shuffling.
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(process).batch(batch_sz)


In [36]:
# Pull a single training batch to sanity-check the pipeline's output shapes:
# displays (image-batch shape, label-batch shape).
train_db_iter = iter(train_db)
sample_iter = next(train_db_iter)
tuple(component.shape for component in sample_iter)

(TensorShape([128, 28, 28]), TensorShape([128]))

In [37]:
# Fully connected classifier over flattened 784-feature inputs: a stack of ReLU
# hidden layers followed by a 10-unit output layer with no activation, i.e. it
# emits raw logits (the training code uses from_logits=True).
# NOTE(review): 258 is possibly a typo for 256; kept as-is so the model matches
# the recorded summary/parameter counts.
HIDDEN_UNITS = (258, 128, 64, 32)
model = Sequential(
    [layers.Dense(units, activation='relu') for units in HIDDEN_UNITS]
    + [layers.Dense(10)]
)
model.build(input_shape=(None, 784))
model.summary()

Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_30 (Dense)             (None, 258)               202530    
_________________________________________________________________
dense_31 (Dense)             (None, 128)               33152     
_________________________________________________________________
dense_32 (Dense)             (None, 64)                8256      
_________________________________________________________________
dense_33 (Dense)             (None, 32)                2080      
_________________________________________________________________
dense_34 (Dense)             (None, 10)                330       
Total params: 246,348
Trainable params: 246,348
Non-trainable params: 0
_________________________________________________________________


In [38]:
# Optimization setup: Adam optimizer, streaming metrics and a TensorBoard writer.
optimizer = optimizers.Adam(learning_rate=1e-3)

acc_metrics = metrics.Accuracy()
loss_metrics = metrics.Mean()

# Log directory relative to the notebook's working directory rather than a
# machine-specific absolute path; view with `tensorboard --logdir logs`.
# Edit LOG_ROOT to redirect the logs. Each run gets its own timestamped subdir.
LOG_ROOT = "logs/"
log_dir = LOG_ROOT + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tf.summary.create_file_writer(log_dir)
def forward(epoch=0):
    """Train ``model`` for one epoch over ``train_db``.

    Despite its name, this performs the full training step for every batch:
    forward pass, cross-entropy gradient, and an optimizer update.

    Every 100 steps it prints ``step, mse, cross_entropy`` for the current batch.
    It also writes the mean cross-entropy accumulated since the previous log point
    to TensorBoard as ``train_loss``.

    :param epoch: zero-based epoch index. It is only used to build a global
        TensorBoard step that keeps increasing across epochs.
    :return: None
    """
    batches_per_epoch = len(train_db)
    for step, (x, y) in enumerate(train_db):
        # Flatten images to the 784-feature vectors the model was built for.
        x = tf.reshape(x, [-1, 784])
        y_onehot = tf.one_hot(y, depth=10)
        with tf.GradientTape() as tape:
            logits = model(x)
            # Cross-entropy on raw logits: this is the loss that is optimized.
            loss2 = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))
        gradients = tape.gradient(loss2, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # MSE between one-hot targets and *raw logits*, kept only as a printed
        # comparison. It is never differentiated, and it grows as the logits grow,
        # so it is not a meaningful measure of training progress.
        loss = tf.reduce_mean(tf.keras.losses.MSE(y_onehot, logits))
        loss_metrics.update_state(loss2)

        if step % 100 == 0:
            print(step, float(loss), float(loss2))
            with writer.as_default():
                tf.summary.scalar('train_loss', loss_metrics.result(),
                                  step=step + epoch * batches_per_epoch)
            loss_metrics.reset_states()


def evaluate():
    """Evaluate ``model`` on ``test_db``.

    :return: ``(accuracy_percent, accuracy_metric_value)``. The first value is
        computed by hand from counts; the second comes from ``acc_metrics``.
    """
    corrects, total = 0, 0
    acc_metrics.reset_states()
    for x, y in test_db:
        x = tf.reshape(x, (-1, 784))
        logits = model(x)
        # softmax is monotonic, so argmax over logits picks the same class.
        pred = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.int32)
        corrects += int(tf.reduce_sum(tf.cast(tf.equal(pred, y), dtype=tf.int32)))
        total += x.shape[0]
        acc_metrics.update_state(y, pred)
    return corrects / total * 100, acc_metrics.result().numpy()


for epoch in range(3):
    forward(epoch)
    test_accuracy, metric_accuracy = evaluate()
    print("Accuracy:", test_accuracy, "%", metric_accuracy)
    with writer.as_default():
        tf.summary.scalar('test_accuracy', test_accuracy, step=epoch)
    writer.flush()

0 0.152599036693573 2.3029565811157227
100 17.141860961914062 0.5382728576660156
200 22.620223999023438 0.33125609159469604
300 20.928707122802734 0.39459073543548584
400 19.317501068115234 0.48414361476898193
Accuracy: 84.39999999999999 % 0.844
0 19.764753341674805 0.38713309168815613
100 21.10894012451172 0.3928174674510956
200 20.91429901123047 0.2857171297073364
300 23.24082374572754 0.5189142823219299
400 20.17760467529297 0.3921857476234436
Accuracy: 85.26 % 0.8526
0 21.058652877807617 0.3926309645175934
100 22.836620330810547 0.3252035975456238
200 23.857646942138672 0.3607388436794281
300 25.43111801147461 0.29877251386642456
400 31.291725158691406 0.28947895765304565
Accuracy: 85.81 % 0.8581
