In [None]:
import time

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

from prw import PRW

In [None]:
class MyModel(keras.Model):
    """Small CNN classifier ending in a 10-way softmax, with a watermark-aware train step.

    ``fit`` may feed either a plain ``(x, y)`` batch or a 3-tuple of ``(x, y)`` pairs
    ``(clean, null_embedded, true_embedded)``. In the latter case one optimizer step is
    taken on the clean loss plus the weighted null/true embedding losses.

    Args:
        null_embedding_weight: Multiplier on the null-embedding loss (default 0.25).
        true_embedding_weight: Multiplier on the true-embedding loss (default 0.25).
    """

    def __init__(self, null_embedding_weight=0.25, true_embedding_weight=0.25):
        super().__init__()
        self.null_embedding_weight = null_embedding_weight
        self.true_embedding_weight = true_embedding_weight
        # The hidden convolutions are linear and LeakyReLU provides the non-linearity.
        # Putting activation="relu" on the conv would zero every negative value first,
        # which turns the following LeakyReLU into a no-op.
        self.seq = keras.Sequential([
            keras.layers.Conv2D(32, kernel_size=(3, 3)),
            keras.layers.LeakyReLU(alpha=0.3),
            keras.layers.Conv2D(64, kernel_size=(3, 3), strides=2),
            keras.layers.LeakyReLU(alpha=0.3),
            keras.layers.Conv2D(128, kernel_size=(3, 3), strides=2),
            keras.layers.LeakyReLU(alpha=0.3),
            keras.layers.Conv2D(1, kernel_size=(1, 1), activation="relu"),
            keras.layers.Flatten(),
            keras.layers.Dense(10, activation="softmax"),
        ])

    # No @tf.function here or below: Keras already compiles fit/predict/evaluate into
    # graphs, so extra decorators only add retracing.
    def call(self, x, training=None):
        """Return softmax class probabilities for a batch of images ``x``.

        ``training`` is forwarded so any mode-dependent layers behave correctly.
        """
        return self.seq(x, training=training)

    def train_step(self, data):
        """Dispatch one optimization step based on the batch structure.

        A 3-element batch is treated as ``(clean, null_embedded, true_embedded)``
        ``(x, y)`` pairs. Anything else is treated as a plain ``(x, y)`` batch.
        """
        # NOTE(review): a standard (x, y, sample_weight) batch also has length 3 and
        # would be routed to the embedding path; sample weights are not supported.
        if len(data) == 3:
            return self.train_on_embedding(data)
        return self.train_on_normal(data)

    def train_on_normal(self, data):
        """One optimizer step on a plain ``(x, y)`` batch; returns ``{"loss": loss}``.

        Uses ``self.loss`` (the loss passed to ``compile``) and ``self.optimizer``.
        """
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.loss(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {"loss": loss}

    def train_on_embedding(self, data):
        """One optimizer step on clean data plus the weighted watermark-embedding data.

        Args:
            data: ``((x_train, y_train), (x_null_e, y_null_e), (x_true_e, y_true_e))``.

        Returns:
            Dict with the clean loss, the two weighted embedding losses and their sum.
        """
        (x_train, y_train), (x_null_e, y_null_e), (x_true_e, y_true_e) = data
        with tf.GradientTape() as tape:
            train_loss = self.loss(y_train, self(x_train, training=True))
            null_e_loss = self.null_embedding_weight * self.loss(
                y_null_e, self(x_null_e, training=True))
            true_e_loss = self.true_embedding_weight * self.loss(
                y_true_e, self(x_true_e, training=True))
            total_loss = train_loss + null_e_loss + true_e_loss

        # Gradients are linear in the loss, so a single backward pass over the summed
        # loss equals the sum of the three per-loss gradients. This needs no persistent tape.
        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {
            "train_loss": train_loss,
            "null_embedding_loss": null_e_loss,
            "true_embedding_loss": true_e_loss,
            "total_loss": total_loss,
        }

In [None]:
# Signature version string: a fixed identifier followed by the current Unix time in
# whole seconds, so each run of this cell produces a new version.
uniqid = "edfdabd4-7c42-4559-9335-36f282c2899f"
v = f"{uniqid}_{int(time.time())}"

def cast_to_float(example, label):
    """Cast an image to float64 and divide it by 255; the label passes through unchanged.

    Presumably ``example`` holds 0-255 pixel values (MNIST), giving outputs in [0, 1].
    """
    as_float = tf.cast(example, tf.float64)
    return as_float / 255.0, label

# Create the signature for version `v` and the watermark transform derived from it.
# NOTE(review): the numeric arguments mirror those later passed to prw.verify;
# presumably image height/width and class count come first — confirm against PRW.
prw = PRW()
sig = prw.create_signature(v)
y_w = prw.transform(sig, 28, 28, 10, 6, 1)

# Load MNIST and convert the images to float64 scaled by 1/255.
ds_train = tfds.load("mnist", split="train", as_supervised=True)
ds_test = tfds.load("mnist", split="test", as_supervised=True)
ds_train = ds_train.map(cast_to_float)
ds_test = ds_test.map(cast_to_float)

# Watermark datasets: both are maps over the same clean ds_train.
ds_train_null = ds_train.map(prw.apply_null_embedding)
ds_train_embed = ds_train.map(prw.apply_true_embedding)


def show_first_example(dataset, title):
    """Plot the first image of an unbatched ``(image, label)`` dataset as a 28x28 image."""
    for image, label in dataset.take(1):
        # Clip for display; presumably the embedding can push pixels outside [0, 1].
        pixels = image.numpy().reshape((28, 28)).clip(0, 1)
        fig, ax = plt.subplots()
        ax.imshow(pixels, cmap="gray", vmin=0, vmax=1)
        ax.set_title(f"{title} (label {int(label)})")
        ax.axis("off")
        plt.show()


show_first_example(ds_train_null, "Null-embedded example")
show_first_example(ds_train_embed, "True-embedded example")

# Each element is ((x, y), (x_null, y_null), (x_true, y_true)). A 3-tuple makes
# MyModel.train_step take its embedding-aware update path.
ds_full_train = tf.data.Dataset.zip((ds_train, ds_train_null, ds_train_embed))
ds_full_train = ds_full_train.shuffle(2048).batch(128).prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

optim = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)

model = MyModel()
model.build(input_shape=(None, 28, 28, 1))
model.compile(
    optimizer=optim,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["sparse_categorical_accuracy"],
)
model.fit(epochs=10, x=ds_full_train, validation_data=ds_test)

In [None]:
# Ask PRW whether `model` carries the watermark for signature `sig` / version `v`,
# probing it with `ds_train`.
# NOTE(review): this expects the unbatched ds_train from the training cell; running it
# after the fine-tuning cell (which rebinds ds_train to a batched dataset) changes input.
# The meaning of 28, 28, 10, 6, 0.8 and True is defined by PRW.verify — confirm.
succ = prw.verify(model, 28, 28, 10, 6, sig, v, ds_train, 0.8, True)

print(
    "Verification was successful, it's our model."
    if succ
    else "Verification was not successful, maybe it's not our model."
)

In [None]:
# Fine-tune the trained model on clean MNIST only, with no embedding data.
# Presumably this tests whether the watermark survives fine-tuning — confirm.
# NOTE(review): this rebinds ds_train/ds_test to batched datasets, so re-running the
# verification cell afterwards feeds it different input than before.
ds_train = tfds.load("mnist", split="train", as_supervised=True)
ds_test = tfds.load("mnist", split="test", as_supervised=True)
ds_train = ds_train.map(cast_to_float)
ds_test = ds_test.map(cast_to_float)

ds_train = ds_train.shuffle(2048).batch(128).prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

# Test accuracy before any fine-tuning epoch, as a baseline.
zero_epoch_acc = model.evaluate(ds_test, return_dict=True)['sparse_categorical_accuracy']

# Train on the training split and validate on the held-out test split. Training and
# validating on the same split would make the reported validation accuracy meaningless.
model.fit(epochs=10, x=ds_train, validation_data=ds_test)