# Imports

In [170]:
import numpy as np
from tensorflow.keras.datasets import mnist
import os
from tensorflow.keras import layers, models
import cv2
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
from keras.models import *
from keras.layers import *
from keras.utils import *
from tensorflow.keras.utils import to_categorical
import pickle
import hashlib
import random

In [2]:
def identity(input_shape):
  """Build a square Dense layer that reproduces its input unchanged.

  Args:
    input_shape: an int used both as the number of input features and as
      the number of units of the layer.

  Returns:
    A built ``tf.keras.layers.Dense`` with an identity kernel, a zero bias
    and a linear activation.
  """
  width = input_shape
  layer = tf.keras.layers.Dense(
      units=width,
      use_bias=True,
      activation=tf.keras.activations.linear,
      # Identity kernel + zero bias => y = x.
      kernel_initializer=tf.constant_initializer(np.eye(width)),
      bias_initializer=tf.constant_initializer(np.zeros((width,))),
  )
  # Build eagerly so the kernel/bias exist and get_weights() works right away.
  layer.build((None, width))
  return layer

# Classes

In [88]:
class Block:
  """One link of the DeepRing hash chain, wrapping a single Dense-like layer.

  Each block stores the kernel/bias/activation of its own layer (``layer``)
  and of the layer that follows it in the ring (``layer_aft``). Its ``hash``
  combines the previous block's hash with digests of both layers'
  parameters. Changing any of those weights therefore changes this hash,
  and, through ``previousHash``, the hashes of blocks built after it.

  Args:
    id: position of the block in the ring; block 0 has no predecessor and
      its ``previousHash`` is not used for hashing.
    previousHash: hex digest (str) of the preceding block.
    layer: layer whose forward pass this block computes. It must expose
      ``get_weights() -> [kernel, bias]`` and an ``activation`` callable.
    layer_aft: next layer in the ring; only its parameters enter the hash.
  """

  def __init__(self, id, previousHash, layer, layer_aft):
    self.id = id
    self.previousHash = previousHash
    self.layer = layer
    self.wi, self.bi = layer.get_weights()
    self.acti = tf.keras.activations.get(layer.activation.__name__)
    self.wi_aft, self.bi_aft = layer_aft.get_weights()
    self.acti_aft = tf.keras.activations.get(layer_aft.activation.__name__)

    own_digest = self._params_digest(self.wi, self.bi)
    next_digest = self._params_digest(self.wi_aft, self.bi_aft)
    if self.id == 0:
      # The origin block has no predecessor to chain from.
      self.hash = hashlib.sha256(own_digest + next_digest).hexdigest()
    else:
      self.hash = hashlib.sha256(
          self.previousHash.encode() + own_digest + next_digest).hexdigest()

  @staticmethod
  def _params_digest(weights, bias):
    """Return hex(sha256(hex(sha256(weights)) + hex(sha256(bias)))) as bytes."""
    w_hex = hashlib.sha256(weights.tobytes()).hexdigest().encode()
    b_hex = hashlib.sha256(bias.tobytes()).hexdigest().encode()
    return hashlib.sha256(w_hex + b_hex).hexdigest().encode()

  # This method computes the output of the layer's block
  def output(self, input):
    """Return ``activation(input @ W + b)`` for this block's layer.

    A non-rank-2 input (e.g. a single sample of shape (features,)) is
    reshaped to a batch of one. Rank-2 inputs are used as-is. NumPy arrays
    and tf tensors are both accepted.
    """
    if len(input.shape) != 2:
      # tf.reshape handles ndarrays and tensors alike; tf.Tensor has no
      # .reshape method by default.
      input = tf.reshape(input, (1, input.shape[0]))
    return self.acti(tf.matmul(input, self.wi) + self.bi)

In [89]:
class Ouroboros(Block):
  """Ring origin (block 0) that gates the network's output on chain integrity.

  ``origHash`` starts as None and is meant to be set by the owner to a
  reference hash. ``authenticity`` is the owner-maintained verdict that
  ``output`` consults.
  """

  def __init__(self, id, previousHash, layer, layer_aft):
    super().__init__(id, previousHash, layer, layer_aft)
    self.origHash = None
    # Initial verdict; the owner is expected to recompute it later.
    self.authenticity = self.origHash == self.previousHash
    # When True, output() announces "Query Mode" on the authentic path.
    self.flag = False

  def output(self, input):
    """Pass ``input`` through if authentic; otherwise report and return None."""
    if not self.authenticity:
      print("Tracking Mode! \n")
      return None
    if self.flag == True:
      print("Query Mode: \n")
    return input

In [177]:
class DeepRing():
  """Wrap a trained Keras model in a hash ring that detects weight tampering.

  Block 0 is an ``Ouroboros`` built on an identity layer. Blocks 1..L wrap
  the model's L layers. Each block's hash chains the previous block's hash
  with the parameters of its own layer and of the next one (the last block
  wraps around to layer 0). The final hash is fed back into block 0 as
  ``previousHash``, and ``origHash`` keeps the value computed at setup. Any
  later change to a layer's weights breaks that equality.

  Args:
    model: Keras model whose layers each expose ``get_weights() ->
      [kernel, bias]`` and an ``activation`` (e.g. a stack of Dense layers).
    input: one sample of shape (features,) or a batch (n, features).
    label: label(s) for ``input``; only used by ``evaluate``.
  """

  def __init__(self, model, input, label):
    self.model = model
    self.label = label
    self.input = input
    self.blockchain = {}
    self.blockchain = self.setup()

  def _make_block(self, i):
    """Build block ``i`` (1..L) from the model's current weights.

    Block i wraps layer i-1. Its successor is layer i, or layer 0 for the
    last block, which closes the ring.
    """
    model_layers = self.model.layers
    return Block(i, self.blockchain[i - 1].hash,
                 model_layers[i - 1], model_layers[i % len(model_layers)])

  def setup(self):
    """Build the reference ring from the model's current weights.

    Returns:
      dict mapping block index (0..L) to its Ouroboros/Block instance.
    """
    n_features = self.input.shape[-1]
    n_layers = len(self.model.layers)
    self.blockchain[0] = Ouroboros(0, None, identity(n_features), self.model.layers[0])
    for i in range(1, n_layers + 1):
      self.blockchain[i] = self._make_block(i)

    ring_hash = self.blockchain[n_layers].hash
    origin = self.blockchain[0]
    origin.previousHash = ring_hash
    # Reference value that every later inference() is compared against.
    origin.origHash = ring_hash
    origin.authenticity = True
    return self.blockchain

  def inference(self):
    """Re-hash the ring from the model's current weights and predict.

    Returns:
      A list of predicted class indices (one per sample) if the ring hash
      still matches the reference. Otherwise block 0 prints
      "Tracking Mode!" and None is returned.
    """
    origin = self.blockchain[0]
    n_layers = len(self.model.layers)

    # Block 0 is an identity pass-through, so feed the input straight in.
    # Routing it through origin.output() would consult the verdict left by
    # the *previous* call. A stale False there returns None and crashes the
    # first layer on a repeated inference after tampering.
    activations = self.input
    for i in range(1, n_layers + 1):
      self.blockchain[i] = self._make_block(i)
      activations = self.blockchain[i].output(activations)

    origin.previousHash = self.blockchain[n_layers].hash
    origin.authenticity = (origin.origHash == origin.previousHash)

    # Argmax over the class axis: one predicted index per sample, whatever
    # the batch size.
    predictions = tf.argmax(activations, axis=1).numpy().tolist()
    origin.flag = True
    return origin.output(predictions)

  def evaluate(self):
    """Run ``model.evaluate`` on the stored data only if the ring is intact.

    Returns:
      The result of ``model.evaluate``, or None when tampering was detected.
    """
    self.inference()
    if not self.blockchain[0].authenticity:
      return None
    x, y = self.input, self.label
    if len(x.shape) == 1:
      # A single sample: add the batch dimension Keras expects.
      x = np.expand_dims(x, 0)
      y = np.expand_dims(y, 0)
    return self.model.evaluate(x, y)

# Test

## Prepare model and data

In [166]:
# Load the MNIST train/test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [167]:
# Cast pixels to float32 and flatten each 28x28 image into a 784-long vector,
# the input width of the model's first Dense layer.
x_train = np.reshape(x_train.astype('float32'), (len(x_train), 28 * 28))
x_test = np.reshape(x_test.astype('float32'), (len(x_test), 28 * 28))

In [168]:
# Scale pixel intensities into [0, 1] (assumes raw values in 0-255).
# NOTE: not idempotent -- re-running this cell divides by 255 again.
x_train, x_test = x_train / 255.0, x_test / 255.0

In [171]:
def createModel(input_dim=784, num_classes=10):
  """Build and compile the MLP classifier wrapped by DeepRing.

  Architecture: input_dim -> 900 -> 600 -> 300 -> 100 (ReLU) -> num_classes
  (softmax), all Dense layers.

  Args:
    input_dim: number of input features (784 for flattened 28x28 images).
    num_classes: width of the softmax output layer.

  Returns:
    A ``models.Sequential`` compiled with Adam and sparse categorical
    cross-entropy, so labels are expected as integer class ids.
  """
  model = models.Sequential()
  model.add(layers.Dense(900, activation='relu', input_shape=(input_dim,)))
  for units in (600, 300, 100):
    model.add(layers.Dense(units, activation='relu'))
  model.add(layers.Dense(num_classes, activation='softmax'))

  model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  return model

In [172]:
def fitModel(model, x_train, y_train, x_test, y_test, batch_size=64, epochs=10):
  """Train ``model`` on the training split, validating on the test split.

  Args:
    model: compiled Keras model.
    x_train, y_train: training inputs and labels.
    x_test, y_test: passed as ``validation_data``, so they are only scored
      after each epoch, not trained on.
    batch_size: mini-batch size.
    epochs: number of passes over the training data.

  Returns:
    The ``History`` object produced by ``model.fit``.
  """
  return model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                   validation_data=(x_test, y_test))

In [173]:
# Seed Python, NumPy and TensorFlow RNGs so weight init, batch shuffling and
# the np.random noise used by simpleAttack repeat across Restart & Run All
# (some GPU kernels may still be nondeterministic).
tf.keras.utils.set_random_seed(42)
model = createModel()
history = fitModel(model, x_train, y_train, x_test, y_test)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


## Test a tampering attack on the DeepRing model

In [174]:
def simpleAttack(layer, std=1.0, rng=None):
  """Tamper with ``layer`` by adding zero-mean Gaussian noise to its kernel.

  The new kernel is written back with ``set_weights``; the bias is left
  unchanged. Used to check that DeepRing detects modified weights.

  Args:
    layer: object exposing ``get_weights() -> [kernel, bias]`` and
      ``set_weights([kernel, bias])`` (e.g. a Keras Dense layer).
    std: standard deviation of the noise.
    rng: optional ``np.random.Generator`` for reproducible noise. When None,
      NumPy's global RNG (``np.random.normal``) is used.
  """
  sampler = np.random if rng is None else rng
  kernel, bias = layer.get_weights()
  noise = sampler.normal(0, std, size=kernel.shape)
  # Keep the kernel's original dtype (float32 for Keras) when writing back.
  layer.set_weights([(kernel + noise).astype(kernel.dtype), bias])

Set up the DeepRing model around a single training sample

In [181]:
# Wrap the trained model around one training sample (a 784-long vector) and
# display its ground-truth label for comparison with inference().
my_model = DeepRing(model, x_train[2], y_train[2])
y_train[2]

4

Run inference on the untampered model

In [182]:
# Untampered ring: prints "Query Mode" and returns the predicted class list.
my_model.inference()

Query Mode: 



[4]

Apply a tampering attack to one of the wrapped layers

In [183]:
# Add Gaussian noise to the kernel of the layer wrapped by block 2
# (the model's second Dense layer).
simpleAttack(my_model.blockchain[2].layer)

Run inference again on the tampered model

In [184]:
# Tampered ring: the recomputed hash no longer matches the reference, so this
# prints "Tracking Mode!" and returns None.
my_model.inference()

Tracking Mode! 

