In [10]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
from maraboupy import Marabou

In [2]:
# Paths and hyperparameters for training and for the Marabou adversarial search.
MODEL_PATH = "mnist_model"  # SavedModel directory: written after training, read back by Marabou
TRAIN_EPOCHS = 3
NUM_SAMPLES_TO_TRY = 500    # number of leading training images to search around
DELTA = 0.01                # max per-pixel perturbation allowed around each image
EPSILON = 0.1               # margin by which the target logit must beat every other logit
SEED = 0

# Seed NumPy and TensorFlow so weight initialisation and training are repeatable; the
# adversarial examples found later depend on the exact trained weights.
# (GPU kernels may still introduce some nondeterminism.)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [3]:
# Load MNIST, flatten every image into a 784-element row, and divide pixel values by 255.
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()
train_x, test_x = (
    images.reshape(len(images), 784) / 255.0 for images in (train_x, test_x)
)
for split_name, split_x, split_y in (("train", train_x, train_y), ("test", test_x, test_y)):
    print(f"{split_name} data shape", split_x.shape, split_y.shape)

train data shape (60000, 784) (60000,)
test data shape (10000, 784) (10000,)


In [4]:
# Fully-connected classifier: 784 flattened pixels -> 16 ReLU -> 16 ReLU -> 10 logits.
tf_model = tf.keras.models.Sequential([
    # `shape` must be a tuple: `(784)` is just the int 784, which newer Keras rejects.
    tf.keras.layers.Input(shape=(784,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(16, activation="relu"),
    # No activation: the network emits raw logits, matching `from_logits=True` in compile().
    tf.keras.layers.Dense(10),
])

In [5]:
# Print a layer-by-layer overview of the model (output shapes and parameter counts).
tf_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 16)                12560     
_________________________________________________________________
dense_1 (Dense)              (None, 16)                272       
_________________________________________________________________
dense_2 (Dense)              (None, 10)                170       
Total params: 13,002
Trainable params: 13,002
Non-trainable params: 0
_________________________________________________________________


In [6]:
# Cross-entropy on raw logits with integer class labels; the optimizer is Keras'
# default ("rmsprop"), spelled out so the training setup is visible at a glance.
tf_model.compile(
    optimizer="rmsprop",
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

In [7]:
# Train on the full training split, evaluating on the test split as validation data.
tf_model.fit(
    x=train_x,
    y=train_y,
    validation_data=(test_x, test_y),
    epochs=TRAIN_EPOCHS,
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x19fc5be9940>

In [8]:
# Export the trained network as a SavedModel directory; the search cell reloads it from
# MODEL_PATH with Marabou.read_tf(..., modelType="savedModel_v2").
tf.saved_model.save(obj=tf_model, export_dir=MODEL_PATH)



INFO:tensorflow:Assets written to: mnist_model\assets


INFO:tensorflow:Assets written to: mnist_model\assets


In [9]:
# Accumulators filled by the adversarial search cell:
#   adv_x - one row of 784 float pixel values per adversarial example found
#           (float, not int: Marabou returns real-valued inputs)
#   adv_y - the original (true) label of the image each example was derived from
#   num_sat_samples / num_unsat_samples - counts of SAT / UNSAT Marabou queries
adv_x = np.empty((0, 784), dtype=float)
adv_y = np.empty((0,), dtype=int)
num_sat_samples = 0
num_unsat_samples = 0

In [11]:
# Valid pixel range after the `/ 255.0` scaling in the data-loading cell; perturbed
# inputs are clipped to it so every adversarial example is still a legal image.
PIXEL_MIN, PIXEL_MAX = 0.0, 1.0


def find_adversarial_example(image, target_label):
    """Ask Marabou for an input near `image` that the network classifies as `target_label`.

    Each pixel may move by at most DELTA from its value in `image` (clipped to
    [PIXEL_MIN, PIXEL_MAX]), and the target logit must exceed every other logit by at
    least EPSILON.

    Args:
        image: flattened original image, one value per network input variable.
        target_label: class index the adversarial input should be assigned.

    Returns:
        The adversarial input as a list of pixel values, or None when Marabou finds no
        satisfying assignment (UNSAT).
    """
    # Re-read the network for every query so constraints from earlier queries don't accumulate.
    mb_model = Marabou.read_tf(MODEL_PATH, modelType="savedModel_v2")
    mb_input_vars = mb_model.inputVars[0][0]
    mb_output_vars = mb_model.outputVars[0]

    for variable, value in zip(mb_input_vars, image):
        mb_model.setLowerBound(variable, max(PIXEL_MIN, value - DELTA))
        mb_model.setUpperBound(variable, min(PIXEL_MAX, value + DELTA))

    for label in range(10):
        if label != target_label:
            # Marabou's addInequality encodes sum(coeff * var) <= scalar, i.e.
            # output[label] - output[target_label] <= -EPSILON.
            mb_model.addInequality(
                [mb_output_vars[label], mb_output_vars[target_label]],
                [+1.0, -1.0], -1.0 * EPSILON,
            )

    # NOTE(review): newer maraboupy releases return (exit_code, vals, stats) from solve();
    # adjust this unpacking if the installed version does - confirm before upgrading.
    mb_variables, _ = mb_model.solve("marabou.log", verbose=False)
    if not mb_variables:
        return None
    return [mb_variables[variable] for variable in mb_input_vars]


print("start time", datetime.now().strftime("%H:%M:%S"))

samples_to_try = zip(train_x[:NUM_SAMPLES_TO_TRY], train_y[:NUM_SAMPLES_TO_TRY])
for sample_index, (image, actual_label) in enumerate(samples_to_try):

    for target_label in range(10):
        if target_label == actual_label:
            continue

        # Keep the result under a separate name: every target label must be searched
        # around the ORIGINAL image, never around a previously found adversarial one.
        adv_image = find_adversarial_example(image, target_label)
        if adv_image is None:
            num_unsat_samples += 1
            print(f"unsat for sample {sample_index} with target label {target_label}")
            continue
        num_sat_samples += 1
        print(f"sat for sample {sample_index} with target label {target_label}")

        adv_x = np.append(adv_x, np.array([adv_image]), axis=0)
        adv_y = np.append(adv_y, np.array([actual_label]), axis=0)

        # Checkpoint after every hit so an interrupted run keeps what it has found.
        np.save("adv_x.npy", adv_x)
        np.save("adv_y.npy", adv_y)

        print("adv data shape", adv_x.shape, adv_y.shape)
        print(f"num sat samples: {num_sat_samples}, "
              f"num unsat samples: {num_unsat_samples}")

print("end time", datetime.now().strftime("%H:%M:%S"))

start time 12:14:14
unsat for sample 0 with target label 0
unsat for sample 0 with target label 1
unsat for sample 0 with target label 2
unsat for sample 0 with target label 3
unsat for sample 0 with target label 4
unsat for sample 0 with target label 6
unsat for sample 0 with target label 7
unsat for sample 0 with target label 8
unsat for sample 0 with target label 9
unsat for sample 1 with target label 1
unsat for sample 1 with target label 2
unsat for sample 1 with target label 3
unsat for sample 1 with target label 4
unsat for sample 1 with target label 5
unsat for sample 1 with target label 6
unsat for sample 1 with target label 7
unsat for sample 1 with target label 8
unsat for sample 1 with target label 9
unsat for sample 2 with target label 0
unsat for sample 2 with target label 1
unsat for sample 2 with target label 2
unsat for sample 2 with target label 3
unsat for sample 2 with target label 5
unsat for sample 2 with target label 6
unsat for sample 2 with target label 7
unsat