In [None]:
# Import TensorFlow and report the installed version.
import tensorflow as tf
print('tensorflow version is: ',tf.__version__)

# Import NumPy and report the installed version.
import numpy as np
print('numpy version is: ',np.__version__)

# 1. Prepare a dataset

加载数据集并划分为训练集以及测试集


In [None]:
# Load MNIST, unpacked into train and test (images, labels) pairs.
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()

In [None]:
# Show the shapes of the train images/labels and test images/labels, in that order.
for split in (x_train, y_train, x_test, y_test):
    print(split.shape)

In [None]:
# Peek at the first ten training labels.
print(y_train[:10])
# Uncomment to dump the raw pixel values of the second training image.
#print(x_train[1,:,:])

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(1, figsize=(15,10))
plt.subplot(131)
plt.imshow(x_train[0,:].reshape(28,28))
plt.subplot(132)
plt.imshow(x_train[1,:].reshape(28,28))
plt.subplot(133)
plt.imshow(x_train[2,:].reshape(28,28))

# 2. Normalization

Q2. 只需要对x_train做归一化处理吗？这里做x_train的归一化会反映到dataset中去吗？如何确认？

In [None]:
# Flatten every 28x28 image into a 784-long float32 vector and scale it by 1/255.
# reshape(-1, 784) infers the number of samples instead of hard-coding 60000.
# The dtype guard keeps the cell idempotent: once x_train has been converted
# to float32, re-running the cell must not divide by 255 a second time.
if x_train.dtype != 'float32':
    x_train = x_train.reshape(-1, 784).astype('float32') / 255

# Pair each image with its label, shuffle through a 1024-element buffer,
# and group the stream into batches of 64.
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(buffer_size=1024).batch(64)

# 3. Model Creation

In [None]:
# Instantiate a simple classification model: two fully connected ReLU layers
# followed by a 10-unit output layer that has no activation.
model = tf.keras.Sequential()
for units in (784, 256):
    model.add(tf.keras.layers.Dense(units, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10))

# 4. Define loss, metric, optimizer



## 4.1 What is logits?

[Wikipedia]In statistics, the logit (/ˈloʊdʒɪt/ LOH-jit) function or the log-odds is the logarithm of the odds p/(1-p) where p is probability. It is a type of function that creates a map of probability values from [0,1] to (-∞,+∞). It is the inverse of the sigmoidal "logistic" function or logistic transform used in mathematics, especially in statistics. In deep learning, the term logits layer is popularly used for the last neuron layer of neural networks used for classification tasks, which produce raw prediction values as real numbers ranging from (-∞,+∞).

[Stackexchange]
In Math, Logit is a function that maps probabilities ([0, 1]) to R ((-inf, inf)).
A probability of 0.5 corresponds to a logit of 0. Negative logits correspond to probabilities less than 0.5, and positive logits to probabilities greater than 0.5.

In ML, it can be the vector of raw (non-normalized) predictions that a classification model generates, which is ordinarily then passed to a normalization function. If the model is solving a multi-class classification problem, logits typically become an input to the softmax function. The softmax function then generates a vector of (normalized) probabilities with one value for each possible class.

In a binary classification problem, "logits" sometimes also refers to the element-wise inverse of the sigmoid function.

简而言之，在分类神经网络模型中，logits就是全连接层的（未经sigmoid或softmax处理）直接输出。在上面的模型中最后一级是没有激活函数的，也就是说模型的最后一层输出的是logits，因此以下loss函数中要指定from_logits = True.

In [None]:
# Instantiate a cross-entropy loss that expects integer class labels as targets.
# from_logits=True because the predictions passed to it are the raw outputs of
# the model's last layer, which has no activation.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True)

In [None]:
# Instantiate an accuracy metric; the training loop updates it batch by batch
# with the labels and the model outputs, and reads the running result.
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

In [None]:
# Instantiate an Adam optimizer; no hyper-parameters are overridden here.
optimizer = tf.keras.optimizers.Adam()

# 5. Training the model

Q: In the following training loop, what does 'one step' refer to? Is it related to the 'dataset.shuffle()' call above, especially its buffer_size parameter, and to the batch size?

In [None]:
# Iterate once over the batches of the dataset; each iteration ("step")
# consumes one batch of images x and labels y.
for step, (x, y) in enumerate(dataset):
    # Record the forward pass so gradients can be derived from it.
    with tf.GradientTape() as tape:
        # Forward pass; training=True states explicitly that this is a
        # training-mode call.
        logits = model(x, training=True)
        # Loss for this batch
        loss_value = loss(y, logits)

    # Gradients of the loss w.r.t. every trainable weight of the model
    weights = model.trainable_weights
    gradients = tape.gradient(loss_value, weights)

    # Apply one optimizer update to all of those weights
    optimizer.apply_gradients(zip(gradients, weights))

    # Update the running accuracy with this batch
    accuracy.update_state(y, logits)

    # Logging every 100 steps. The loss is converted with float() so it is
    # printed as a plain number, matching the accuracy line, instead of as a
    # tensor repr.
    if step % 100 == 0:
        print('step = ',step)
        print('Loss from the last step: ', float(loss_value))
        print('Total running accuracy so far: ', float(accuracy.result()))
    

# 6. Prediction/Evaluation with test set

In [None]:
# Apply the same preprocessing used for the training images: flatten each
# 28x28 image to 784 values and scale by 1/255. reshape(-1, 784) infers the
# sample count instead of hard-coding 10000, and the dtype guard prevents a
# second division by 255 when the cell is re-run.
if x_test.dtype != 'float32':
    x_test = x_test.reshape(-1, 784).astype('float32') / 255

# Raw model outputs (logits) for the whole test set, computed in inference mode.
logits_test = model(x_test, training=False)

In [None]:
# Raw model outputs for the first test image.
print(logits_test[0,:])

In [None]:
# Predicted class for each of the first ten test images: the index of the
# largest logit in each row.
print(tf.argmax(logits_test[:10], axis=1))

In [None]:
# Labels of the first ten test images, for comparison with the predictions.
print(y_test[:10])

In [None]:
# Show the first three test digits side by side in figure 2.
# A loop replaces three copy-pasted subplot/imshow pairs; as a side benefit the
# cell no longer ends in a bare expression, so no AxesImage repr is echoed.
plt.figure(2, figsize=(15, 10))
for idx in range(3):
    plt.subplot(1, 3, idx + 1)  # same panels as the 131/132/133 shorthand
    # x_test rows are flattened to 784 values by this point, so each one is
    # reshaped back to 28x28 for display.
    plt.imshow(x_test[idx, :].reshape(28, 28))