<a href="https://colab.research.google.com/github/joaogui1/Recreational/blob/master/weight_evolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install tensorflow==2.1

In [7]:
 import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers, Sequential
from keras.datasets import mnist
from keras.utils import  to_categorical
import numpy as np
from scipy.stats import entropy as kl_div
from numpy.linalg import svd
print(tf.__version__)

2.1.0


In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.reshape(-1, 784) - 177.5)/177.5
x_test = (x_test.reshape(-1, 784) - 177.5)/177.5
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz


In [3]:
x_train.shape

(60000, 784)

In [0]:
model = Sequential([layers.Dense(1024, activation='relu', input_dim=(28*28)),
                    layers.Dense(512, activation='relu'),
                    layers.Dense(10, activation='softmax')])

In [0]:
def plot_layers(model, epoch=0):
  for idx, layer in enumerate(model.layers):
    W, b = layer.weights[0].numpy(), layer.weights[1].numpy()
    plt.hist(W.reshape(np.prod(W.shape)))
    plt.savefig(f'W_{idx}_{epoch}.png')
    plt.close()
    plt.hist(b)
    plt.savefig(f'b_{idx}_{epoch}.png')
    plt.close()

In [0]:
class PlotCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    plot_layers(self.model, epoch + 1)
  

In [0]:
class PrintSingularValues(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=[]):
    for idx, layer in enumerate(self.model.layers):
      W = layer.weights[0].numpy()
      __, s, _ = svd(W)
      s = np.sort(s)
      print(f"Largest Singular Value for {idx}th layer:", s[-1])

In [0]:
#WIP
class PrintKL(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=[]):
    if epoch > 0:
      for idx, layer in enumerate(self.model.layers):
        W_hist, b_hist = np.histogram(layer.weights[0].numpy()), np.histogram(layer.weights[1].numpy())
        kl_W = kl_div(W_hist, self.W_hist[idx])
        kl_b = kl_div(b_hist, self.b_hist[idx])
        print(f"Weights KL Divergence: {kl_W}, bias KL Divergence {kl_b}")
        self.W_hist[idx], self.b_hist[idx] = W_hist, b_hist
    else: 
      self.W_hist = []
      self.b_hist = []
      print(logs)
      for idx, layer in enumerate(self.model.layers):
        W_hist, b_hist = np.histogram(layer.weights[0].numpy()), np.histogram(layer.weights[1].numpy())
        self.W_hist.append(W_hist)
        self.b_hist.append(b_hist)

In [0]:
plot_layers(model)

In [0]:
model.compile('adam', 'categorical_crossentropy', metrics=['acc'])

In [19]:
model.fit(x_train, y_train, batch_size=128, epochs=10, callbacks=[PlotCallback(), PrintSingularValues()])

Train on 60000 samples
Epoch 1/10
Largest Singular Value for 1th layer: 9.572143
Largest Singular Value for 2th layer: 3.529023
Epoch 2/10
Largest Singular Value for 1th layer: 9.697352
Largest Singular Value for 2th layer: 3.5807707
Epoch 3/10
Largest Singular Value for 1th layer: 9.853334
Largest Singular Value for 2th layer: 3.6637583
Epoch 4/10
Largest Singular Value for 1th layer: 10.077137
Largest Singular Value for 2th layer: 3.7294214
Epoch 5/10
Largest Singular Value for 1th layer: 10.193441
Largest Singular Value for 2th layer: 3.766798
Epoch 6/10
Largest Singular Value for 1th layer: 10.374943
Largest Singular Value for 2th layer: 3.8121274
Epoch 7/10
Largest Singular Value for 1th layer: 10.5168295
Largest Singular Value for 2th layer: 3.8493502
Epoch 8/10
Largest Singular Value for 1th layer: 10.6215725
Largest Singular Value for 2th layer: 3.8751643
Epoch 9/10
Largest Singular Value for 1th layer: 10.877382
Largest Singular Value for 2th layer: 3.8904655
Epoch 10/10
Large

<tensorflow.python.keras.callbacks.History at 0x7fbd013ac320>

In [0]:
from google.colab import files
for epoch in range(11):
  for idx in range(3):
    files.download(f'W_{idx}_{epoch}.png')
    files.download(f'b_{idx}_{epoch}.png')