In [1]:
import tensorflow as tf
from tensorflow import keras

class MyMetric(keras.metrics.Metric):
    """Streaming coefficient of determination (R^2) since the last reset.

    R^2 = 1 - SSE / TSS, where SSE is the (weighted) sum of squared residuals
    and TSS is the (weighted) sum of squared deviations of ``y_true`` from its
    mean over *all* samples accumulated so far. All output columns are pooled
    into a single score.

    TSS is accumulated with the pairwise (Chan et al.) merge of a running mean
    and sum of squared deviations. The result is therefore identical no matter
    how the data is split into batches. Summing each batch's TSS around that
    batch's own mean would understate the global TSS, biasing R^2 downward by an
    amount that depends on the batch size.
    """

    def __init__(self, name="r2_score", **kwargs):
        super().__init__(name=name, **kwargs)
        # Weighted sum of squared residuals (y_true - y_pred)^2.
        self.sse_sum = self.add_weight(name="sse_sum", initializer="zeros")
        # Weighted sum of squared deviations of y_true around the running mean.
        self.tss_sum = self.add_weight(name="tss_sum", initializer="zeros")
        # Running weighted mean of y_true, and the total weight behind it.
        self.y_mean = self.add_weight(name="y_mean", initializer="zeros")
        self.weight_sum = self.add_weight(name="weight_sum", initializer="zeros")
        # Number of rows (first-axis entries) seen; informational, unused by result().
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Fold one batch into the running SSE / mean / TSS statistics.

        Args:
            y_true: targets. Must hold the same number of elements per row as
                ``y_pred``. They are reshaped to ``y_pred``'s flattened layout, so
                ``(batch,)`` labels line up with ``(batch, 1)`` predictions instead
                of broadcasting to ``(batch, batch)``.
            y_pred: predictions; the first axis is the batch axis.
            sample_weight: optional scalar, per-row ``(batch,)``/``(batch, 1)``,
                or per-element weights matching ``y_pred``.
        """
        # Cast to the accumulators' dtype (not y_pred's) so assign_add never
        # mixes dtypes, e.g. under mixed precision.
        dtype = self.sse_sum.dtype
        y_pred = tf.cast(y_pred, dtype)
        y_pred = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1])
        y_true = tf.reshape(tf.cast(y_true, dtype), tf.shape(y_pred))

        if sample_weight is None:
            weights = tf.ones_like(y_true)
        else:
            weights = tf.cast(sample_weight, dtype)
            if weights.shape.rank != 0:
                # Per-row weights become (batch, 1) and then spread across columns.
                weights = tf.reshape(weights, [tf.shape(y_true)[0], -1])
            weights = tf.broadcast_to(weights, tf.shape(y_true))

        # Statistics of this batch alone.
        batch_weight = tf.reduce_sum(weights)
        batch_sse = tf.reduce_sum(weights * tf.square(y_true - y_pred))
        batch_mean = tf.math.divide_no_nan(
            tf.reduce_sum(weights * y_true), batch_weight)
        batch_m2 = tf.reduce_sum(weights * tf.square(y_true - batch_mean))

        # Chan et al. merge. Everything is derived from the *old* accumulator
        # values before any assignment below. The divide_no_nan calls make the
        # first batch and zero-weight batches safe.
        total_weight = self.weight_sum + batch_weight
        delta = batch_mean - self.y_mean
        tss_increment = batch_m2 + tf.square(delta) * tf.math.divide_no_nan(
            self.weight_sum * batch_weight, total_weight)
        mean_increment = delta * tf.math.divide_no_nan(batch_weight, total_weight)

        self.sse_sum.assign_add(batch_sse)
        self.tss_sum.assign_add(tss_increment)
        self.y_mean.assign_add(mean_increment)
        self.weight_sum.assign_add(batch_weight)
        self.total_samples.assign_add(tf.shape(y_true)[0])

    def result(self):
        """Return 1 - SSE / (TSS + epsilon).

        The epsilon keeps the division finite when every target seen so far is
        identical (TSS == 0).
        """
        r2_score = 1 - (self.sse_sum / (self.tss_sum + tf.keras.backend.epsilon()))
        return r2_score

    def reset_state(self):
        """Zero every accumulator (Keras calls this at the start of each epoch)."""
        for variable in (self.sse_sum, self.tss_sum, self.y_mean, self.weight_sum):
            variable.assign(0.)
        self.total_samples.assign(0)

In [2]:
from tensorflow.keras.datasets import imdb
import numpy as np
from tensorflow.keras import layers
# Load only the training split, keeping the 10,000 most frequent words;
# the test split returned second is not used in this cell.
train_split, _ = imdb.load_data(num_words=10000)
train_data, train_labels = train_split

def vectorize_sequences(sequences, dimension=10000, dtype=np.float32):
    """Multi-hot encode integer index sequences into a dense 2-D array.

    Args:
        sequences: a sized collection of integer index sequences; every index
            must lie in ``[0, dimension)``.
        dimension: width of the output (number of distinct indices).
        dtype: element type of the result. Defaults to float32, which halves
            memory compared with NumPy's float64 default (a 25000 x 10000 matrix
            is ~1 GB instead of ~2 GB). The 0/1 values are exact in float32, and
            float32 is Keras' default float type.

    Returns:
        Array of shape ``(len(sequences), dimension)``. Row ``i`` is 1 at every
        index occurring in ``sequences[i]`` (repeats still give 1) and 0 elsewhere.
    """
    results = np.zeros((len(sequences), dimension), dtype=dtype)
    for row, sequence in enumerate(sequences):
        results[row, sequence] = 1
    return results
# Seed Python, NumPy and TensorFlow in one call so that weight initialisation and
# the per-epoch shuffling done by `fit` repeat across Restart & Run All
# (some GPU kernels may still be nondeterministic).
SEED = 42
keras.utils.set_random_seed(SEED)

# Training hyper-parameters, named so they can be tuned in one place.
EPOCHS = 20
BATCH_SIZE = 128
# `fit` takes the validation rows from the *end* of the data, before shuffling.
VALIDATION_SPLIT = 0.4

# Multi-hot encode the reviews. The name `train_data` is reused, not replaced,
# because later cells may refer to the vectorized matrix under this name.
train_data = vectorize_sequences(train_data)

# Small binary classifier: two ReLU hidden layers and a sigmoid output that
# gives the predicted probability of label 1.
model = keras.Sequential([
    layers.Dense(16, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(1, activation="sigmoid"),
])
# Track accuracy plus the custom streaming R^2 of the sigmoid outputs
# against the 0/1 labels.
model.compile(optimizer="rmsprop",
              loss="binary_crossentropy",
              metrics=["accuracy", MyMetric()])
history_original = model.fit(
    train_data,
    train_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
