In [1]:
# 结合 Focal Loss实现自定义损失函数(model.fit)

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals


import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np

print(tf.__version__)
print(np.__version__)

NUM_CLASSES = 10   # depth of the one-hot label vectors
BATCH_SIZE = 32

# np.load on an .npz archive returns an NpzFile that holds the file open;
# the context manager closes it as soon as the arrays have been read.
with np.load("mnist.npz") as mnist:
    x_train, y_train = mnist['x_train'], mnist['y_train']
    x_test, y_test = mnist['x_test'], mnist['y_test']

# Scale by 1/255 (presumably 8-bit pixel values -> [0, 1]; verify against the
# archive). The division promotes to float64, so cast down to float32, which
# is what the Keras layers compute in anyway and halves the memory footprint.
x_train = (x_train / 255.0).astype(np.float32)
x_test = (x_test / 255.0).astype(np.float32)
y_train = np.int32(y_train)
y_test = np.int32(y_test)

# Append a trailing channel axis so the images match the Conv2D input layout.
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# One-hot labels are required by the categorical loss / accuracy metric.
y_train = tf.one_hot(y_train, depth=NUM_CLASSES)
y_test = tf.one_hot(y_test, depth=NUM_CLASSES)

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(BATCH_SIZE)
)
# Evaluation metrics do not depend on sample order, so the test set is only
# batched; shuffling it would just add work on every epoch.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

2.0.0-rc0
1.18.4


In [3]:
# Sanity check: after adding the trailing channel axis the test images
# should be 4-D, i.e. (num_samples, 28, 28, 1).
x_test.shape

(10000, 28, 28, 1)

In [4]:
def MyModel():
    """Build the digit classifier with the Keras functional API.

    Architecture: 28x28x1 input -> Conv2D(32 filters, 3x3, ReLU) -> Flatten
    -> Dense(128, ReLU) -> Dense(10, softmax).

    Returns:
        An uncompiled ``tf.keras.Model`` whose output layer ``'predictions'``
        emits softmax class probabilities (not logits).
    """
    layers = tf.keras.layers
    digits = tf.keras.Input(shape=(28, 28, 1), name='digits')

    # Layers are listed in the order they are applied to the input tensor.
    pipeline = [
        layers.Conv2D(32, 3, activation='relu'),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax', name='predictions'),
    ]

    tensor = digits
    for layer in pipeline:
        tensor = layer(tensor)

    return tf.keras.Model(inputs=digits, outputs=tensor)

In [5]:
def FocalLoss(gamma=2.0, alpha=0.25, from_logits=False):
    """Create a multi-class focal loss usable as a Keras ``loss=`` callable.

    For one-hot targets the per-sample loss is

        FL = sum_c( -alpha * y_true_c * (1 - p_c) ** gamma * log(p_c) )

    where ``p`` are the predicted class probabilities.

    Args:
        gamma: Focusing exponent. Larger values shrink the contribution of
            examples that are already classified with high confidence;
            ``gamma=0`` gives alpha-scaled categorical cross-entropy.
        alpha: Balancing weight multiplied into the loss. A scalar scales all
            classes equally; a per-class vector broadcasts over the last axis.
        from_logits: Whether ``y_pred`` contains raw logits. When True a
            softmax is applied first. Defaults to False because the loss is
            normally fed the output of a softmax layer; applying softmax a
            second time would flatten the distribution and distort the loss.

    Returns:
        A function ``focal_loss_fixed(y_true, y_pred)`` that returns one loss
        value per sample (the class axis is summed out).
    """
    def focal_loss_fixed(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        if from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)

        # Keep probabilities away from 0 so that log() stays finite.
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)

        y_true = tf.cast(y_true, y_pred.dtype)
        alpha_weight = tf.cast(alpha, y_pred.dtype)

        # (1 - p) ** gamma is the modulating factor that down-weights easy
        # examples; y_true masks out every class except the target one.
        modulating_factor = tf.math.pow(1.0 - y_pred, gamma)
        loss = -alpha_weight * y_true * modulating_factor * tf.math.log(y_pred)

        # Sum over the class axis (last axis) -> per-sample loss.
        return tf.math.reduce_sum(loss, axis=-1)
    return focal_loss_fixed

In [6]:
# Instantiate the network and attach the training configuration:
# Adam optimizer, focal loss on one-hot targets, categorical accuracy metric.
model = MyModel()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=FocalLoss(gamma=2.0, alpha=0.25),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

In [7]:
# Train for 5 passes over train_ds, using test_ds as the validation data.
# The call returns a Keras History object, which is what the cell displays.
model.fit(train_ds, epochs=5, validation_data=test_ds)

Epoch 1/5
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x2b52c1a9b08>