In [124]:
import tensorflow as tf
import tf_keras
from tf_keras.layers import Dense
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import io

In [125]:
def plot_to_image(figure):
  """Convert a Matplotlib figure into a batched PNG image tensor and close it.

  Args:
    figure: the Matplotlib figure to render. It is closed before returning,
      so it must not be reused by the caller afterwards.

  Returns:
    The PNG-decoded image (4 channels requested) with a leading batch
    dimension of size 1 added, which is the layout passed to
    `tf.summary.image` by this notebook.
  """
  buf = io.BytesIO()
  # Save the figure that was passed in. `plt.savefig` would save whatever
  # pyplot currently considers the *active* figure, which is only the same
  # object by coincidence.
  figure.savefig(buf, format='png')
  # Free the figure so repeated per-epoch calls do not accumulate open figures.
  plt.close(figure)
  buf.seek(0)
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  # Add the batch dimension.
  image = tf.expand_dims(image, 0)
  return image

In [126]:
# Fixed seed so that Restart Kernel -> Run All regenerates the same samples.
SEED = 42
rng = np.random.default_rng(SEED)

# 100 samples of x drawn uniformly from [0, 10), with the noiseless target
# y = 2x + 1.
X_train = rng.random((100, 1)) * 10
y_train = 2 * X_train + 1

# NOTE(review): the "validation" arrays are the training arrays themselves
# (same objects), so the evaluation below measures training fit, not
# generalisation. Kept as-is to preserve the notebook's behaviour.
X_val = X_train
y_val = y_train

In [127]:
# Seed the Python, NumPy and TensorFlow random generators before any layer is
# created, so weight initialisation and batch shuffling repeat across
# Restart Kernel -> Run All. (Some GPU kernels may still be nondeterministic.)
tf_keras.utils.set_random_seed(42)

# Small fully-connected regressor: scalar input -> 10 -> 10 -> scalar output.
model = tf_keras.Sequential([Dense(10, activation="relu", input_shape=(1,)),
                             Dense(10, activation="relu"),
                             Dense(1, activation="linear")])

# Mean squared error loss, optimised with Adam.
model.compile(optimizer=tf_keras.optimizers.Adam(learning_rate=1E-3),
              loss="mse")

In [128]:
# Each run writes to its own timestamped sub-directory so that earlier runs
# are not overwritten.
logdir = f"../logs/pictures/{datetime.now():%Y%m%d-%H%M%S}"

# Summary writer used by the per-epoch plotting callback.
file_writer = tf.summary.create_file_writer(logdir)

def plot(X_train, y_train, X_val, y_pred):
    """Draw the true targets and the model predictions on one set of axes.

    Each series is sorted by its x-values before plotting. `ax.plot` connects
    points in array order, so unsorted samples would otherwise be joined by
    criss-crossing line segments instead of forming a curve.

    Args:
        X_train, y_train: x-values and true targets, one value per sample.
        X_val, y_pred: x-values and predicted targets, one value per sample.

    Returns:
        The created Matplotlib figure (not shown and not closed here).

    Raises:
        ValueError: if a series' x and y hold a different number of values.
    """
    def _sorted_by_x(x, y):
        # Flatten to 1-D (the notebook passes (N, 1) arrays) and order by x.
        x = np.asarray(x).reshape(-1)
        y = np.asarray(y).reshape(-1)
        if x.size != y.size:
            raise ValueError(
                f"x and y must have the same number of values, got {x.size} and {y.size}"
            )
        order = np.argsort(x, kind="stable")
        return x[order], y[order]

    fig, ax = plt.subplots(figsize=(5,4))

    ax.plot(*_sorted_by_x(X_train, y_train), label="true")
    ax.plot(*_sorted_by_x(X_val, y_pred), label="pred")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend()

    return fig

def log_plot(epoch, logs=None):
    """Log a true-vs-predicted plot for the current model as an image summary.

    Intended as an `on_epoch_end` hook; it reads the notebook-level `model`,
    `X_train`, `y_train`, `X_val` and `file_writer`.

    Args:
        epoch: epoch index, used as the summary step.
        logs: metrics dict supplied by the callback machinery; unused here.
    """
    # verbose=0: otherwise predict prints a progress bar at every epoch even
    # when fit itself is run silently.
    y_pred = model.predict(X_val, verbose=0)
    figure = plot(X_train, y_train, X_val, y_pred)
    image = plot_to_image(figure)
    with file_writer.as_default():
        tf.summary.image("plot", image, step=epoch)
    # Push the event to disk now so the image is available to TensorBoard
    # without waiting for the writer's periodic flush.
    file_writer.flush()


# Run log_plot at the end of every training epoch.
plot_callback = tf_keras.callbacks.LambdaCallback(on_epoch_end=log_plot)

In [129]:
# Train silently; the callback logs one prediction plot per epoch.
model.fit(
    x=X_train,
    y=y_train,
    batch_size=32,
    epochs=25,
    callbacks=[plot_callback],
    verbose=0,
)

# Last expression of the cell: the loss on the validation arrays is displayed.
model.evaluate(x=X_val, y=y_val)



38.42515563964844