# In [None]:
import tensorflow as tf
import tensorboard
import numpy as np
import matplotlib.pyplot as plt

def generate_train_data(f = lambda x: x, start = -1.0, end = 1.0, step = 0.0001):
  """Sample ``f`` on a randomly ordered, evenly spaced grid.

  Args:
    f: Callable applied to the whole array of sample points in one call.
    start: First value passed to ``np.arange``.
    end: Stop value passed to ``np.arange``.
    step: Spacing between neighbouring grid points before shuffling.

  Returns:
    A tuple ``(x, y)``: ``x`` is the shuffled grid and ``y`` is ``f(x)``.
  """
  samples = np.arange(start, end, step)
  # Shuffle in place *before* evaluating f so inputs and targets stay aligned.
  np.random.shuffle(samples)
  return samples, f(samples)

def split_data(x, y, test_size=0.2):
  """Split paired samples into a leading train part and a trailing test part.

  The split is positional: the first elements go to the train split and the
  last ``int(len(x) * test_size)`` elements go to the test split. Nothing is
  shuffled here.

  Args:
    x: Inputs; must support ``len()`` and slicing.
    y: Targets; must have the same length as ``x``.
    test_size: Fraction in ``[0, 1]`` of the samples placed in the test split.

  Returns:
    A tuple ``(x_train, x_test, y_train, y_test)``.

  Raises:
    ValueError: If ``x`` and ``y`` differ in length, or if ``test_size`` is
      outside ``[0, 1]``.
  """
  # Explicit checks instead of ``assert`` so they still run under ``python -O``.
  if len(x) != len(y):
    raise ValueError(
      f"x and y must have the same length, got {len(x)} and {len(y)}"
    )
  # A fraction above 1 would make train_len negative and silently produce
  # overlapping slices; a negative one would yield a nonsensical split.
  if not 0.0 <= test_size <= 1.0:
    raise ValueError(f"test_size must be within [0, 1], got {test_size!r}")

  test_len = int(len(x) * test_size)
  train_len = len(x) - test_len
  return x[:train_len], x[train_len:], y[:train_len], y[train_len:]

def plot_graph(functions, start=-1.0, end=1.0, step=0.1, title=''):
    """Draw every callable in ``functions`` on one matplotlib figure and show it.

    Args:
        functions: Iterable of callables; each is called once with the whole
            array of x values and must expose ``__name__`` (used as the
            legend label).
        start: Start value passed to ``np.arange`` for the x grid.
        end: Stop value passed to ``np.arange`` for the x grid.
        step: Spacing of the x grid.
        title: Figure title.
    """
    grid = np.arange(start, end, step)
    for func in functions:
        # One curve per callable, evaluated on the shared grid.
        plt.plot(grid, func(grid), label=func.__name__)

    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.grid(True)
    plt.legend()
    plt.show()

def training_function(x):
  """Return ``cos(x) * sin(x)``, the target curve sampled for training.

  ``x`` may be anything ``np.cos`` / ``np.sin`` accept (scalar or array-like);
  the result is computed element-wise.
  """
  # Alternative targets that can be swapped in for experiments:
  #   x * 2 + 0.3
  #   np.sin(x)
  #   np.cos(x)
  #   np.exp(x)
  cosine = np.cos(x)
  sine = np.sin(x)
  return cosine * sine

# Sample the target curve, then hold out the tail of the shuffled samples.
x, y = generate_train_data(training_function)
x_train, x_test, y_train, y_test = split_data(x, y)

# Scalar-in / scalar-out regressor: two 32-unit ReLU hidden layers followed
# by a single linear output unit.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(32, activation='relu', input_shape=(1,)))
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='linear'))

# Adam with its default hyper-parameters.
optimizer = tf.keras.optimizers.Adam()

# Extra quantities reported next to the loss during fit/evaluate.
metrics = [
  'mean_absolute_error',
  tf.keras.metrics.RootMeanSquaredError(),
  'mean_absolute_percentage_error',
]

model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=metrics)

def trained_function(x):
   """Evaluate the module-level ``model`` on ``x``.

   ``x`` is first reshaped into a single-column 2-D array (one sample per
   row); the value returned is whatever ``model.predict`` yields for it.
   """
   column = np.reshape(x, (-1, 1))
   return model.predict(column)

# Baseline plot: target curve next to the network before any training.
plot_graph([training_function, trained_function], -1.0, 1.0)


def _plot_after_epoch(epoch, logs):
  # Redraw the comparison after each epoch so the fit can be watched evolving.
  plot_graph([training_function, trained_function], -1.0, 1.0)


history = model.fit(
  x_train,
  y_train,
  batch_size=128,
  epochs=10,
  validation_split=0.2,
  callbacks=[tf.keras.callbacks.LambdaCallback(on_epoch_end=_plot_after_epoch)],
)

# Final comparison once training has finished.
plot_graph([training_function, trained_function], -1.0, 1.0)

# Loss and compiled metrics on the held-out test split.
print(model.evaluate(x_test, y_test))

# Spot-check a single input value.
prediction = model.predict(np.array([[0.5]]))
print(prediction)

# Dump the learned parameters of every layer.
weights = model.get_weights()
print(weights)