In [9]:
'''
Function Description:手写数字识别
'''
import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32)/255.
    y = tf.cast(y, dtype=tf.int32)
    return x, y 

(x ,y), (x_test, y_test) = datasets.fashion_mnist.load_data()
print(x.shape, y.shape)

batch_size = 128
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(10000).batch(batch_size)

db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess).batch(batch_size)

db_iter = iter(db)
sample = next(db_iter)
print('batch:', sample[0].shape, sample[1].shape)

(60000, 28, 28) (60000,)
batch: (128, 28, 28) (128,)


In [10]:
 modle = Sequential([
     layers.Dense(256, activation = tf.nn.relu),
     layers.Dense(128, activation = tf.nn.relu),
     layers.Dense(64, activation = tf.nn.relu),
     layers.Dense(32, activation = tf.nn.relu),
     layers.Dense(10, activation = tf.nn.relu)
 ])
 modle.build(input_shape=[None,28*28])
 modle.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_5 (Dense)              (None, 256)               200960    
_________________________________________________________________
dense_6 (Dense)              (None, 128)               32896     
_________________________________________________________________
dense_7 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_8 (Dense)              (None, 32)                2080      
_________________________________________________________________
dense_9 (Dense)              (None, 10)                330       
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________


In [11]:
optimizer = optimizers.Adam(lr = 1e-3)



In [12]:
for epoch in range(30):
    for step, (x, y) in enumerate(db):
        x = tf.reshape(x, [-1, 28*28])
        with tf.GradientTape() as tape:
            logits = modle(x)
            y_onehot = tf.one_hot(y, depth=10)

            loss = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
            loss2 = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
            loss2 = tf.reduce_mean(loss2)

        grads = tape.gradient(loss2, modle.trainable_variables)
        optimizer.apply_gradients(zip(grads, modle.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, 'loss:', float(loss2), float(loss))
    total_correct = 0
    total_num = 0
    for x, y in db_test:
        x = tf.reshape(x, [-1, 28*28])
        logits = modle(x)
        porb = tf.nn.softmax(logits, axis=1)
        pred = tf.argmax(porb, axis=1)
        pred = tf.cast(pred, dtype=tf.int32)
        correct = tf.equal(pred, y)
        correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))

        total_correct += int(correct)
        total_num += x.shape[0]

    acc = total_correct / total_num
    print(epoch, 'test acc:', acc)

0 0 loss: 2.3076236248016357 0.10553194582462311
0 100 loss: 0.9101220369338989 6.532830238342285
0 200 loss: 0.9955911636352539 7.688383102416992
0 300 loss: 0.7440605759620667 6.276527404785156
0 400 loss: 0.7618370056152344 8.465690612792969
0 test acc: 0.7437
1 0 loss: 0.679572343826294 8.535686492919922
1 100 loss: 0.6789568662643433 11.624359130859375
1 200 loss: 0.674553394317627 7.913036346435547
1 300 loss: 0.7738601565361023 8.893863677978516
1 400 loss: 0.6247442960739136 9.057350158691406
1 test acc: 0.7469
2 0 loss: 0.6061739921569824 11.332883834838867
2 100 loss: 0.706638514995575 14.089025497436523
2 200 loss: 0.8583351969718933 8.625970840454102
2 300 loss: 0.5021404027938843 13.022235870361328
2 400 loss: 0.5824685096740723 12.232477188110352
2 test acc: 0.7494
3 0 loss: 0.7296172380447388 9.329782485961914
3 100 loss: 0.7809584736824036 12.589255332946777
3 200 loss: 0.7950228452682495 13.361106872558594
3 300 loss: 0.6903210282325745 13.059016227722168
3 400 loss: 0