In [None]:
import tensorflow as tf
import numpy as np
import tensorflow.keras as keras
from tensorflow.keras import layers

In [None]:
from tensorflow.keras.mixed_precision.experimental import Policy

In [None]:
## Mixed precision setting
## With the mixed_float16 policy, layers compute in float16 so Tensor Cores can be
## used; without it (float32 policy), Tensor Cores are not used.
##
## Toggle the benchmark with USE_MIXED_PRECISION instead of skipping this cell:
## set_policy() changes a *global* policy that persists in the running kernel, so
## skipping the cell after an earlier mixed-precision run would silently keep
## mixed_float16 active. Setting the policy explicitly in both cases makes the
## result depend only on the flag, not on which cells were run before.
USE_MIXED_PRECISION = True

policy = Policy('mixed_float16' if USE_MIXED_PRECISION else 'float32')
tf.keras.mixed_precision.experimental.set_policy(policy)
print('Compute dtype: %s' % policy.compute_dtype)
print('Variable dtype: %s' % policy.variable_dtype)

In [None]:
# Build the input and the two hidden Dense layers of the MLP.
# The TensorFlow global seed is set here, right before any weights are created,
# so that runs with and without mixed precision start from a reproducible
# initialization and their numbers are easier to compare. (It is deliberately
# not in the mixed-precision cell, which may be toggled off. GPU kernels may
# still introduce some nondeterminism.)
SEED = 42
tf.random.set_seed(SEED)

inputs = keras.Input(shape=(784,), name='digits')
if tf.config.list_physical_devices('GPU'):
  print('The model will run with 4096 units on a GPU')
  num_units = 4096
else:
  # Use fewer units on CPUs so the model finishes in a reasonable amount of time
  print('The model will run with 64 units on a CPU')
  num_units = 64
dense1 = layers.Dense(num_units, activation='relu', name='dense_1')
x = dense1(inputs)
dense2 = layers.Dense(num_units, activation='relu', name='dense_2')
x = dense2(x)

In [None]:
# Show the dtype of the hidden activations next to the dtype of dense1's weights,
# i.e. the active policy's compute dtype vs. its variable dtype.
print('x.dtype:', x.dtype.name)
# 'kernel' is dense1's weight variable
print('dense1.kernel.dtype:', dense1.kernel.dtype.name)

In [None]:
# INCORRECT (under the mixed_float16 policy): softmax and model output will be
# float16, when it should be float32.
# Demonstration only: the following cell rebinds `outputs` to the correct head.
softmax_head = layers.Dense(10, activation='softmax', name='predictions')
outputs = softmax_head(x)
print('Outputs dtype:', outputs.dtype.name)

In [None]:
# CORRECT: softmax and model output are float32.
# The logits layer still runs in the policy's compute dtype. The final Activation
# layer is given dtype='float32', so its input is cast and the softmax, and hence
# the model output, are computed in float32.
# The logits get their own name instead of rebinding `x`. Rebinding made the cell
# non-idempotent: a re-run stacked a second 'dense_logits' layer on top of the
# logits (duplicate layer names break keras.Model), and re-running the INCORRECT
# cell afterwards would have fed it the logits instead of the hidden features.
logits = layers.Dense(10, name='dense_logits')(x)
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(logits)
print('Outputs dtype: %s' % outputs.dtype.name)

In [None]:
# A 'linear' Activation is the identity function; declaring it with
# dtype='float32' makes it cast its input to float32. Here `outputs` already
# comes from a float32 softmax layer, so this is a no-op. The same pattern
# turns a float16 model output into float32.
to_float32 = layers.Activation('linear', dtype='float32')
outputs = to_float32(outputs)

In [None]:
# Assemble the functional model from the symbolic `inputs` / `outputs` built above
# and compile it. sparse_categorical_crossentropy takes integer class labels and
# the per-class probabilities produced by the softmax output layer.
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])

# Load MNIST, flatten each image into a 784-element vector to match
# Input(shape=(784,)), and scale pixel values by 1/255.
# reshape(-1, 784) infers the number of samples instead of hard-coding
# 60000 / 10000, so the cell keeps working on a subset of the data.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255

In [None]:
# Train for 5 epochs with very large batches, holding out 20% of the training
# data for validation, then report loss and accuracy on the test set.
history = model.fit(
    x_train,
    y_train,
    batch_size=8192,
    epochs=5,
    validation_split=0.2,
)
test_scores = model.evaluate(x_test, y_test, verbose=2)
# evaluate() returns [loss, accuracy] because a single metric was compiled.
test_loss, test_accuracy = test_scores[0], test_scores[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)


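## Results with mixed precision (`mixed_float16` policy)
### Test GPU: RTX 2080 Ti

Summary: with mixed precision a training step takes about 37 ms, versus about 86 ms without it (see the second log below), roughly a 2.3x speedup. The gap in final accuracy between the two runs (0.9485 vs 0.8537 test accuracy) should not be attributed to mixed precision. Validation accuracy swings strongly from epoch to epoch in both runs: 0.9015 → 0.8050 → 0.9500 here, and 0.9470 → 0.8552 in the run without mixed precision. The difference is therefore within run-to-run noise.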

Epoch 1/5  
6/6 [==============================] - 0s 64ms/step - loss: 5.1665 - accuracy: 0.4094 - val_loss: 0.7263 - val_accuracy: 0.8307  
Epoch 2/5  
6/6 [==============================] - 0s 37ms/step - loss: 0.6760 - accuracy: 0.7938 - val_loss: 0.4567 - val_accuracy: 0.8476  
Epoch 3/5  
6/6 [==============================] - 0s 38ms/step - loss: 0.3648 - accuracy: 0.8856 - val_loss: 0.3276 - val_accuracy: 0.9015  
Epoch 4/5  
6/6 [==============================] - 0s 37ms/step - loss: 0.3125 - accuracy: 0.8989 - val_loss: 0.5661 - val_accuracy: 0.8050  
Epoch 5/5  
6/6 [==============================] - 0s 37ms/step - loss: 0.3052 - accuracy: 0.9035 - val_loss: 0.1625 - val_accuracy: 0.9500  
313/313 - 1s - loss: 0.1676 - accuracy: 0.9485  
Test loss: 0.16764894127845764  
Test accuracy: 0.9484999775886536  

## Results without mixed precision (default float32 policy)
### Test GPU: RTX 2080 Ti
Epoch 1/5  
6/6 [==============================] - 1s 106ms/step - loss: 4.2295 - accuracy: 0.4273 - val_loss: 0.7785 - val_accuracy: 0.7990  
Epoch 2/5  
6/6 [==============================] - 1s 86ms/step - loss: 0.6895 - accuracy: 0.7926 - val_loss: 0.3278 - val_accuracy: 0.9089  
Epoch 3/5  
6/6 [==============================] - 1s 86ms/step - loss: 0.3661 - accuracy: 0.8807 - val_loss: 0.2844 - val_accuracy: 0.9078  
Epoch 4/5  
6/6 [==============================] - 1s 86ms/step - loss: 0.3186 - accuracy: 0.9001 - val_loss: 0.1877 - val_accuracy: 0.9470  
Epoch 5/5  
6/6 [==============================] - 1s 87ms/step - loss: 0.1981 - accuracy: 0.9417 - val_loss: 0.5932 - val_accuracy: 0.8552  
313/313 - 1s - loss: 0.6030 - accuracy: 0.8537  
Test loss: 0.6030229926109314  
Test accuracy: 0.8536999821662903  