In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def generate_samples(n_samples=100, func=lambda x: np.cos(2 * np.pi * x), xlim=(-1, 1), seed=42):
    """Draw uniformly distributed inputs and evaluate ``func`` on them.

    Parameters
    ----------
    n_samples : int
        Number of points to draw.
    func : callable
        Function applied to the whole array of drawn inputs at once.
    xlim : tuple
        Arguments forwarded to ``np.random.uniform`` ahead of the sample
        count, i.e. the ``(low, high)`` bounds of the sampling interval.
    seed : int or None
        Seed for NumPy's global random generator. The default of 42
        reproduces the samples this function has always returned. Pass
        ``None`` to draw from the generator in its current state.

    Returns
    -------
    tuple
        ``(x, y)`` where both arrays carry one extra trailing axis of
        length 1.
    """
    if seed is not None:
        # The *global* generator is seeded deliberately: anything that draws
        # from it right after this call becomes repeatable as well.
        np.random.seed(seed)
    x = np.random.uniform(*xlim, n_samples)
    y = func(x)
    return x[..., np.newaxis], y[..., np.newaxis]


In [None]:
from sklearn.model_selection import train_test_split

n_samples = 100

# Other candidate target functions kept for reference (not used below):
#   lambda x: -.6*x+(1 - (-.5 < x) * (x < -.08)) * np.cos(4 * np.pi * x / (.8 * x ** 2 + .7)) * np.exp(-x ** 2 + .3*x)
#   lambda x: np.exp(-5*x**2)+ .4*np.exp(-10*(x-1.8)**2)


def func(x):
    """Target function the network is trained to reproduce."""
    return np.cos(10*x)/(1+np.abs(x))


X0, Y0 = generate_samples(n_samples, func, xlim=(-1,1))

# random_state pins the shuffle explicitly instead of relying on whatever
# state the global NumPy generator happens to be in at this point.
X, X_val, Y, Y_val = train_test_split(X0, Y0, train_size=0.8, random_state=42)

print(X.shape, Y.shape, X_val.shape, Y_val.shape)

In [None]:
# Target curve on a dense grid, overlaid with the training and held-out points.
fig, ax = plt.subplots(figsize=(12, 7))
x = np.linspace(-2, 2, 200)
ax.plot(x, func(x), label="func", color="k")
ax.plot(X[:, 0], Y[:, 0], "o", label="train", color="C0")
ax.plot(X_val[:, 0], Y_val[:, 0], ".", label="test", color="C1")
ax.legend()
plt.show()

In [None]:
def build_model(activation = "relu", n_neurons=32, n_hidden = 2, seed = 42):
    """Build a fully connected network mapping one input to one output.

    Parameters
    ----------
    activation : str or callable
        Activation passed to every hidden ``Dense`` layer.
    n_neurons : int
        Number of units in each hidden layer.
    n_hidden : int
        Number of hidden layers; they are named ``hidden_1``, ``hidden_2``, ...
    seed : int or None
        Seed applied to both NumPy's and TensorFlow's global generators
        before the layers are created. Pass ``None`` to skip seeding.

    Returns
    -------
    Model
        Uncompiled Keras model with a linear single-unit output layer.
    """
    if seed is not None:
        np.random.seed(seed)
        # Keras weight initialisers draw from TensorFlow's generator, so the
        # NumPy seed alone does not make the initial weights repeatable.
        tf.random.set_seed(seed)

    inp = Input((1,), name = "input")
    layer = inp
    for n in range(n_hidden):
        layer = Dense(n_neurons, activation=activation, name= "hidden_%s"%(n+1))(layer)

    out = Dense(1, activation=None, name =  "output")(layer)
    return Model(inp, out)

In [None]:
from IPython.display import clear_output
from collections import defaultdict
import tensorflow.keras.backend as K

class MyCallback(Callback):
    """Keras callback that records training history and live-plots the fit.

    After every epoch the reported metrics are stored twice (raw and
    exponentially smoothed) together with a flattened copy of all model
    weights. Every ``n_interval`` epochs the cell output is cleared and a
    two-panel figure is drawn: metric curves on the left, model predictions
    against a reference curve on the right.

    Parameters
    ----------
    X, Y : np.ndarray
        Training inputs and targets. ``X`` is fed to ``model.predict`` and its
        first column is used as the x-axis; ``Y`` is only stored.
    xlim_test : tuple
        ``(start, stop)`` of the evenly spaced 200-point evaluation grid.
    n_interval : int
        A figure is drawn on epochs that are a multiple of this value.
    smooth : float
        Weight given to the previous smoothed value; ``1 - smooth`` is the
        weight of the newest raw value.
    yscale : str
        Scale passed to ``set_yscale`` for the metric panel.
    true_func : callable, optional
        Reference function drawn in the right panel. When omitted, the
        module-level name ``func`` is looked up at plot time.

    Attributes
    ----------
    axs
        Axes of the most recently drawn figure, ``None`` before the first plot.
    """

    def __init__(self, X, Y, xlim_test =(-1.5,1.5), n_interval = 5, smooth = .4, yscale = "log",
                 true_func = None):
        super(MyCallback, self).__init__()
        self._n_interval = n_interval
        self._logs = defaultdict(list)          # metric name -> raw values per epoch
        self._logs_smooth = defaultdict(list)   # metric name -> smoothed values per epoch
        self._weights = []                      # one flat weight vector per epoch
        self._X_test = np.linspace(*xlim_test,200).reshape(-1,1)
        self._X_train = X
        self._Y_train = Y
        self._yscale = yscale
        self._smooth = smooth
        self._true_func = true_func
        self.axs = None

    def on_epoch_end(self, epoch, logs=None):
        """Record metrics and weights; redraw the figure on plotting epochs."""
        logs = logs or {}

        for k, v in logs.items():
            self._logs[k].append(v)
            smoothed = self._logs_smooth[k]
            if smoothed:
                smoothed.append((1-self._smooth)*v + self._smooth*smoothed[-1])
            else:
                # First value seen for this metric: nothing to blend with yet.
                smoothed.append(v)

        ws = np.concatenate([w.flatten() for w in self.model.get_weights()])
        self._weights.append(ws)

        # plot every self._n_interval epoch
        if (epoch % self._n_interval) == 0:
            self._plot()

    def _plot(self):
        """Draw the metric curves and the current predictions."""
        # Predictions are only needed for drawing, so they are computed here
        # rather than on every epoch; verbose=0 keeps progress bars out of the
        # cell output that is about to be replaced.
        Y_pred_test  = self.model.predict(self._X_test, verbose=0)
        Y_pred_train = self.model.predict(self._X_train, verbose=0)
        reference = self._true_func if self._true_func is not None else func

        clear_output(wait=True)
        _, axs = plt.subplots(1, 2, figsize=(12,4))
        self.axs = axs = axs.flatten()

        for i, k in enumerate(self._logs.keys()):
            # Index by the stored history length so the axes always match the
            # data, even if the callback has seen more epochs than `epoch + 1`.
            epochs = np.arange(len(self._logs[k]))
            axs[0].plot(epochs, self._logs[k], color = "C%s"%i, alpha = .2)
            axs[0].plot(epochs, self._logs_smooth[k], color = "C%s"%i, label = k)

        axs[0].legend()
        axs[0].set_yscale(self._yscale)

        x_test = self._X_test[:,0]
        axs[1].plot(x_test, reference(self._X_test)[:,0], "-", color = "k", label = "true (test)")
        axs[1].plot(x_test, Y_pred_test[:,0], ".", color = "C0", label = "pred (test)")
        if Y_pred_test.shape[-1] > 1:
            # With more than one output column, column 1 is drawn as a band
            # around column 0.
            # NOTE(review): presumably a predicted spread, 1.414 ~ sqrt(2) - confirm.
            m = Y_pred_test[:,0]
            s = Y_pred_test[:,1]
            axs[1].fill_between(x_test, m-1.414*s, m+1.414*s, color = "C0", alpha =.2)

        axs[1].plot(self._X_train[:,0], Y_pred_train[:,0], ".", color = "C1", label = "pred (train)")

        axs[1].set_ylim(-1.5,1.2)
        axs[1].legend()

        plt.show()
    



In [None]:
model = build_model(n_hidden = 2)
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and is not honoured by newer Keras releases.
model.compile(loss="mse", optimizer=Adam(learning_rate=0.005))

model.summary()


In [None]:
# Train with live plotting: the callback redraws every 5th epoch on a
# 200-point grid spanning [-2, 2].
cb = MyCallback(X, Y, n_interval=5, xlim_test=(-2, 2))
model.fit(
    X,
    Y,
    epochs=1000,
    batch_size=2,
    callbacks=[cb],
    validation_data=(X_val, Y_val),
)


In [None]:
# First two weight arrays of the model (kernel and bias of the first Dense layer).
w, b = model.get_weights()[:2]
print(w, b)

# X_val holds randomly drawn x values in no particular order; sort them so the
# line plot traces the prediction from left to right instead of zig-zagging.
order = np.argsort(X_val[:, 0])
plt.plot(X_val[order, 0], model.predict(X_val)[order, 0])
plt.show()

In [None]:
# Mean over all flattened weights, one value per epoch recorded by the callback.
mean_weight_per_epoch = [epoch_weights.mean() for epoch_weights in cb._weights]
plt.plot(mean_weight_per_epoch)