<a href="https://colab.research.google.com/github/bklooste/tensorflowcollab/blob/master/mnistwithpretrained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, Model
from tensorflow.keras.utils import to_categorical

In [23]:
from google.colab import drive
drive.mount('/content/gdrive')

model_save_name = 'mnistconv.h5'
#torch.save(model.state_dict(), path)


!ls '/content/gdrive/My Drive/saved_models'
path = F"/content/gdrive/My Drive/saved_models/{model_save_name}" 

#model.load(path) 


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
mnistconv.h5  mnistconv.pt


In [0]:
def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, lambda_=0.5):
    """
    Capsule-network margin loss (Eq. 4 of Sabour et al., "Dynamic Routing Between Capsules").

    Per class k:
        L_k = T_k * max(0, m_plus - v_k)^2 + lambda_ * (1 - T_k) * max(0, v_k - m_minus)^2
    The per-class terms are summed over axis 1 and averaged over the batch.

    Multi-hot targets (more than one `1` in a row of y_true) are handled by the
    same formula, but that case has not been tested.

    NOTE(review): the formula assumes y_pred is roughly in [0, 1] (capsule
    lengths). In this notebook it is fed raw, unbounded Dense(10) outputs — confirm
    this is intended.

    :param y_true: [None, n_classes] one-hot (or multi-hot) targets; cast to y_pred's dtype.
    :param y_pred: [None, num_capsule] predicted scores.
    :param m_plus: lower margin a present class should reach (default 0.9).
    :param m_minus: upper margin an absent class should stay below (default 0.1).
    :param lambda_: down-weighting of the absent-class term (default 0.5).
    :return: a scalar loss value.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Integer labels would otherwise fail to multiply with float predictions.
    y_true = tf.cast(y_true, y_pred.dtype)

    # relu(x) == max(0, x), and it avoids mixing a Python float with y_pred's dtype.
    present = y_true * tf.square(tf.nn.relu(m_plus - y_pred))
    absent = lambda_ * (1. - y_true) * tf.square(tf.nn.relu(y_pred - m_minus))

    return tf.reduce_mean(tf.reduce_sum(present + absent, axis=1))


In [25]:
# Load the pretrained convnet. A function loss is serialized under its
# __name__, so the custom_objects key must be 'margin_loss' to match if
# compile=True is ever used; with compile=False the training config is skipped.
base_model = tf.keras.models.load_model(
    path, custom_objects={'margin_loss': margin_loss}, compile=False)

# Freeze BEFORE compiling: Keras captures `trainable` at compile time, so
# changing it after compile() has no effect until compile() is called again.
base_model.trainable = False
base_model.compile(loss=margin_loss, optimizer=tf.keras.optimizers.Adam())

# Let's take a look at the base model architecture
base_model.summary()



Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1 (Conv2D)               (32, 20, 20, 256)         20992     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (32, 10, 10, 256)         0         
_________________________________________________________________
flatten_2 (Flatten)          (32, 25600)               0         
_________________________________________________________________
dense_4 (Dense)              (32, 64)                  1638464   
_________________________________________________________________
dense_5 (Dense)              (32, 10)                  650       
Total params: 1,660,106
Trainable params: 0
Non-trainable params: 1,660,106
_________________________________________________________________


In [26]:
# The saved network still carries its own pooling/dense head. Keep only the
# pretrained `conv1` layer (frozen via base_model.trainable above) as a
# feature extractor, and attach a fresh, trainable head on top of it.
layer_name = 'conv1'
conv1_output = base_model.get_layer(layer_name).output
intermediate_base_model = Model(inputs=base_model.input, outputs=conv1_output)

model = models.Sequential([
    intermediate_base_model,
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10),  # linear output: one raw score per class
])
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
model_1 (Model)              (32, 20, 20, 256)         20992     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (32, 10, 10, 256)         0         
_________________________________________________________________
flatten_1 (Flatten)          (32, 25600)               0         
_________________________________________________________________
dense_2 (Dense)              (32, 64)                  1638464   
_________________________________________________________________
dense_3 (Dense)              (32, 10)                  650       
Total params: 1,660,106
Trainable params: 1,639,114
Non-trainable params: 20,992
_________________________________________________________________


In [27]:
mnist = tf.keras.datasets.mnist
NUM_CLASSES = 10  # digits 0-9; must match the width of the model's Dense(10) output

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a single channel axis for Conv2D and scale pixel values by 1/255.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.

# One-hot encode the integer labels (to_categorical returns float32).
# num_classes is explicit so the encoding stays 10 wide even when a label
# subset is used (e.g. training on digits 0-8 only). Otherwise the width
# would be inferred from the largest label and mismatch the model output.
y_train = to_categorical(y_train, num_classes=NUM_CLASSES)
y_test = to_categorical(y_test, num_classes=NUM_CLASSES)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [0]:
# Reconstruction-loss weight, presumably carried over from a CapsNet with a
# decoder output. Kept for reference; it is NOT used below because this model
# has a single output and no reconstruction branch.
am_recon = 0.392

# One output -> one loss. The former loss=[margin_loss, 'mse'] with
# loss_weights=[1., am_recon] describes a two-output (class + reconstruction)
# network and does not fit this model.
model.compile(optimizer='adam',
              loss=margin_loss,
              metrics=['accuracy'])

In [30]:
model.fit(x=x_train, y=y_train, epochs=5)  # train only the new head; conv1 stays frozen

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f99c6e31048>

In [0]:
# TODO: try both the pooling-based and the capsule-trained networks.

# Plan: take the conv network and train it on digits 0-8 only,
# then add digit 9 and see the result.
# Add some more neurons and try again.
# Suspect both will have issues; then try "shake and bake".

# Use a separate notebook for each experiment.

# Then try the same for capsule networks.

In [31]:
model.evaluate(x=x_test, y=y_test, verbose=2)  # held-out loss and accuracy

313/313 - 5s - loss: 0.0174 - accuracy: 0.9890


[0.017385095357894897, 0.9890175461769104]

In [32]:
# Append a Softmax to the trained network so its raw scores are normalized
# into a per-class distribution that sums to 1.
softmax_head = tf.keras.layers.Softmax()
probability_model = tf.keras.Sequential([model, softmax_head])
probability_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_1 (Sequential)    (32, 10)                  1660106   
_________________________________________________________________
softmax (Softmax)            (32, 10)                  0         
Total params: 1,660,106
Trainable params: 1,639,114
Non-trainable params: 20,992
_________________________________________________________________


In [33]:
# Softmax outputs for the first 32 test images (one batch).
probability_model(x_test[0:32])

<tf.Tensor: shape=(32, 10), dtype=float32, numpy=
array([[2.37217522e-04, 4.54253890e-03, 1.03620710e-02, 7.16515630e-03,
        2.55332212e-03, 7.29817839e-04, 6.44535394e-05, 9.71943259e-01,
        4.09561559e-04, 1.99262775e-03],
       [2.69478100e-04, 9.20898019e-05, 9.99260366e-01, 5.11646340e-06,
        4.31508779e-05, 2.69276643e-06, 6.48281275e-05, 2.14018310e-05,
        5.59200453e-05, 1.85009383e-04],
       [4.88716643e-04, 9.91415918e-01, 8.24047835e-04, 9.99032491e-05,
        8.58804851e-04, 3.29593336e-03, 6.44594606e-04, 1.37863657e-03,
        6.83921040e-04, 3.09527590e-04],
       [9.08543468e-01, 1.44310463e-02, 2.52515939e-03, 1.38960350e-02,
        1.02825277e-02, 6.58020843e-03, 1.94084775e-02, 8.09461623e-03,
        3.98677681e-03, 1.22516463e-02],
       [1.59934023e-03, 1.98725518e-03, 9.78483469e-04, 4.16089955e-04,
        9.87470329e-01, 6.23942062e-04, 1.42019871e-03, 8.14167841e-04,
        3.97015177e-03, 7.20130163e-04],
       [6.50969218e-04, 9