In [0]:
from pathlib import Path

from google.colab import drive

# Mount Google Drive; DATA_DIR below points inside the mounted tree.
drive.mount('/content/drive')
DATA_DIR = Path("/content/drive/My Drive/SH/report/data")

KeyboardInterrupt: ignored

In [0]:
# Colab-only line magic selecting the TensorFlow 2.x runtime; plain Jupyter
# does not define it.
%tensorflow_version 2.x
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras

# Reset Keras' global state (see the `clear_session` docs) so re-running the
# notebook in the same kernel starts clean.
tf.keras.backend.clear_session()
# Show the TensorFlow version in use as this cell's output.
tf.__version__

In [0]:
@tf.function
def rbf(x):
    """Gaussian radial-basis activation: exp(-x**2), applied element-wise.

    Peaks at 1 for x == 0 and decays symmetrically towards 0 as |x| grows.
    """
    squared = tf.pow(x, 2)
    return tf.exp(-squared)

In [0]:
# A single Dense unit with no bias and rbf activation, so the network computes
# rbf(w1 * x1 + w2 * x2) for 2-feature inputs.
model = keras.models.Sequential(
    layers=[
        keras.layers.Dense(
            units=1,
            activation=rbf,
            use_bias=False,
            input_shape=(2,),
        ),
    ],
)

In [0]:
# Three training points. Targets are rbf values: (2, 2) -> 1 (= rbf(0)),
# while (2, 0) and (0, 2) -> rbf(2) (= exp(-4)).
x = np.array([(2, 2), (2, 0), (0, 2)])
y = np.array([1, rbf(2.), rbf(2.)])
x, y

In [0]:
def get_weights_history(initial_weights, learning_rate=10, momentum=.9, epochs: int=10000):
    """Train the global ``model`` on the global ``x``/``y`` and record its kernel.

    The model is re-compiled with a newly constructed SGD optimizer on every
    call, so no optimizer state carries over between runs. Its kernel is then
    overwritten with ``initial_weights``, and it is trained full-batch (one step
    per epoch) for ``epochs`` epochs.

    Args:
        initial_weights: Two starting weight values; reshaped to the (2, 1) kernel.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        epochs: Number of training epochs.

    Returns:
        np.ndarray with one row of two weights per training step.
    """
    model.compile(loss="mse",
                  optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                                    momentum=momentum))
    start = np.asarray(initial_weights, dtype=np.float32).reshape((2, 1))
    model.weights[0].assign(tf.constant(start))

    class WeightsHistory(keras.callbacks.Callback):
        """Snapshot the model's kernel after every training batch."""

        def on_train_begin(self, logs=None):
            self.weights = []

        def on_train_batch_end(self, batch, logs=None):
            # squeeze (2, 1) -> (2,) so the history stacks into an (n_steps, 2) array
            self.weights.append(self.model.weights[0].numpy().squeeze())

    weights_history = WeightsHistory()
    model.fit(x, y,
              batch_size=len(y),  # full batch: exactly one recorded step per epoch
              epochs=epochs,
              verbose=False,
              callbacks=[weights_history])
    return np.array(weights_history.weights)

In [0]:
# One training run per (initial weights, learning rate, momentum) configuration;
# the starting points cover all four sign quadrants of (w1, w2).
weights = [
    get_weights_history(initial, learning_rate=lr, momentum=mom)
    for initial, lr, mom in [
        ([1.8, 1.5], 10, .9),
        ([1.5, 1.8], 10, .9),
        ([2., 2.], 100, .7),
        ([1.8, -1.5], 1, .9),
        ([1.5, -1.8], 1, .9),
        ([2., -2.], 100, .7),
        ([-1.8, -1.5], 10, .9),
        ([-1.5, -1.8], 10, .9),
        ([-2., -2.], 100, .7),
        ([-1.8, 1.5], 1, .9),
        ([-1.5, 1.8], 1, .9),
        ([-2., 2.], 100, .7),
    ]
]

In [0]:
# Plot every recorded weight trajectory in the (w1, w2) plane, with equal axis
# scaling so the geometry is not distorted.
fig, ax = plt.subplots()
ax.axis("equal")
for trajectory in weights:
    # Each row is one (w1, w2) snapshot; .T splits the run into w1 and w2 series.
    ax.plot(*trajectory.T)
ax.set(xlabel="$w_1$", ylabel="$w_2$", title="Weight trajectories during SGD training")
plt.show()

In [0]:
# Save data
def save_weights_history(weights, filename: str, num_samples=500):
    """Write a subsampled weight trajectory as a tab-separated text file.

    Keeps every ``len(weights) // num_samples``-th row, and every row when there
    are fewer rows than ``num_samples``. Each kept row is written on its own
    line with its values joined by tabs.

    Args:
        weights: 2-D array with one weight snapshot per row.
        filename: Destination path (``str`` or ``pathlib.Path``); overwritten.
        num_samples: Target number of rows to keep (approximate).
    """
    # Clamp the slice step to 1: with fewer rows than num_samples the integer
    # division gives 0, and ``weights[::0]`` raises ValueError.
    step = max(1, weights.shape[0] // num_samples)
    with open(filename, "w") as f:
        for row in weights[::step]:
            f.write("\t".join(map(str, row)) + "\n")
# Write each run to its own 1-indexed .dat file under DATA_DIR.
for run_index, trajectory in enumerate(weights, start=1):
    out_path = DATA_DIR / f"stripe_problem_weights_history_{run_index}.dat"
    save_weights_history(trajectory, out_path)

In [0]:
# Final (w1, w2) of every run. Taking the last row per run avoids stacking all
# histories into one array, which only works if every run has the same length.
np.array([history[-1] for history in weights])