In [116]:
import tensorflow as tf
from tensorflow.keras import layers,datasets,models,Sequential,optimizers
import datetime
# Load Fashion-MNIST as a (train, test) pair, each split being an (images, labels) tuple.
_train_split, _test_split = datasets.fashion_mnist.load_data()
x_train, y_train = _train_split
x_test, y_test = _test_split
# Drop the temporaries so only the four arrays remain in the notebook namespace.
del _train_split, _test_split

In [117]:
def process(x, y):
    """Turn one (image, label) element into model-ready tensors.

    :param x: image tensor; cast to float32 and scaled by 1/255.
    :param y: label tensor; cast to int32.
    :return: the (image, label) pair after casting.
    """
    image = tf.cast(x, dtype=tf.float32) / 255.
    label = tf.cast(y, dtype=tf.int32)
    return image, label


batch_sz = 128

# Training pipeline: preprocess every element, shuffle through a 10000-element buffer, then batch.
train_db = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(process)
    .shuffle(10000)
    .batch(batch_sz)
)

# Test pipeline: preprocess and batch only -- no shuffling, so evaluation order is fixed.
test_db = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(process)
    .batch(batch_sz)
)


In [118]:
# Pull one batch from a fresh iterator over the training pipeline and display
# the shape of each of its components (image batch, label batch).
train_db_iter = iter(train_db)
sample_iter = next(train_db_iter)
tuple(component.shape for component in sample_iter)

(TensorShape([128, 28, 28]), TensorShape([128]))

In [119]:
# Fully connected classifier over flattened 28x28 images (784 features):
# four ReLU hidden layers of shrinking width, then a 10-unit output layer with
# no activation, i.e. the model emits raw logits.
# NOTE(review): 258 may be a typo for 256 -- confirm before changing it, since
# that alters the parameter count reported by the summary.
model = Sequential(
    [layers.Dense(units, activation='relu') for units in (258, 128, 64, 32)]
    + [layers.Dense(10)]
)
# Build with an unknown batch dimension so the summary can report weight shapes.
model.build(input_shape=(None, 784))
model.summary()

Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_50 (Dense)            (None, 258)               202530    
                                                                 
 dense_51 (Dense)            (None, 128)               33152     
                                                                 
 dense_52 (Dense)            (None, 64)                8256      
                                                                 
 dense_53 (Dense)            (None, 32)                2080      
                                                                 
 dense_54 (Dense)            (None, 10)                330       
                                                                 
Total params: 246,348
Trainable params: 246,348
Non-trainable params: 0
_________________________________________________________________


In [120]:
# Automatic optimization: Adam applies the gradients computed in `forward`.
optimizer = optimizers.Adam(learning_rate=1e-3)

# Root folder for TensorBoard event files; each run gets a timestamped sub-directory.
# NOTE(review): absolute, machine-specific path -- consider a path relative to the notebook.
LOG_ROOT = "C:/Users/lenovo/Desktop/dengruizhe/tensorflow/tensorflow/projects/logs/fit/"
log_dir = LOG_ROOT + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tf.summary.create_file_writer(log_dir)
NUM_EPOCHS = 5


def forward(epoch=0):
    """
    Run one training epoch over ``train_db``: forward pass, backprop and an Adam
    update for every batch.

    The gradients come from softmax cross-entropy on the raw logits. Every 100 steps
    the step index, the MSE between one-hot labels and logits (a diagnostic only --
    it never contributes to the gradients) and the cross-entropy are printed. Both
    are also written to TensorBoard as 'train_mse' and 'train_loss'.

    :param epoch: zero-based epoch index. It offsets the TensorBoard step so that
        successive epochs do not overwrite each other's points. Passed explicitly
        instead of being read from a global loop variable.
    :return: None
    """
    steps_per_epoch = len(train_db)
    for step, (x, y) in enumerate(train_db):
        # Flatten (batch, 28, 28) images to (batch, 784) for the Dense stack.
        x = tf.reshape(x, [-1, 784])
        y_onehot = tf.one_hot(y, depth=10)
        with tf.GradientTape() as tape:
            logits = model(x)
            # The choice of loss used for backpropagation is the key decision here;
            # this is the objective that is actually optimized.
            ce_loss = tf.reduce_mean(
                tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))
        # Compute gradients of the cross-entropy and update the parameters.
        gradients = tape.gradient(ce_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        if step % 100 == 0:
            # MSE on the pre-update logits, kept only for comparison with the cross-entropy.
            mse_loss = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
            print(step, float(mse_loss), float(ce_loss))
            global_step = epoch * steps_per_epoch + step
            with writer.as_default():
                # 'train_loss' must track the optimized objective, not the diagnostic.
                tf.summary.scalar('train_loss', ce_loss, step=global_step)
                tf.summary.scalar('train_mse', mse_loss, step=global_step)


def evaluate_accuracy(dataset):
    """
    Return the classification accuracy of ``model`` on ``dataset`` in percent.

    :param dataset: batched dataset of (images, labels) as produced by ``process``.
    :return: accuracy as a float in [0, 100].
    :raises ValueError: if ``dataset`` yields no samples.
    """
    corrects, total = 0, 0
    for x, y in dataset:
        logits = model(tf.reshape(x, (-1, 784)))
        # argmax(softmax(z)) == argmax(z), so the logits can be ranked directly.
        pred = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.int32)
        corrects += int(tf.reduce_sum(tf.cast(tf.equal(pred, y), dtype=tf.int32)))
        total += x.shape[0]
    if total == 0:
        raise ValueError("evaluate_accuracy: dataset yielded no samples")
    return corrects / total * 100


for epoch in range(NUM_EPOCHS):
    forward(epoch)
    accuracy = evaluate_accuracy(test_db)
    print("Accuracy:", accuracy, "%")
    with writer.as_default():
        tf.summary.scalar('test_accuracy', accuracy, step=epoch)
    # The writer otherwise only flushes periodically; flush so TensorBoard sees each epoch promptly.
    writer.flush()

0 0.158280611038208 2.3191261291503906
100 14.673531532287598 0.5984320044517517
200 20.077957153320312 0.541550874710083
300 15.91042423248291 0.5609288215637207
400 19.596179962158203 0.3715458810329437
Accuracy: 84.58 %
0 18.518386840820312 0.4450332224369049
100 20.434173583984375 0.3024790287017822
200 24.790096282958984 0.4904400706291199
300 22.672534942626953 0.2866445779800415
400 24.996231079101562 0.40415680408477783
Accuracy: 86.00999999999999 %
0 22.913501739501953 0.3332185745239258
100 29.017803192138672 0.3272151052951813
200 32.18880844116211 0.37596845626831055
300 28.082218170166016 0.29535555839538574
400 29.024959564208984 0.4233640730381012
Accuracy: 87.09 %
0 31.399925231933594 0.3599300980567932
100 34.22286605834961 0.38533082604408264
200 29.671043395996094 0.43384841084480286
300 33.37353515625 0.25913184881210327
400 32.95117950439453 0.3634488880634308
Accuracy: 87.38 %
0 35.841739654541016 0.22969673573970795
100 42.03129959106445 0.29542598128318787
200 3