In [8]:
import numpy as np
import torch
import gurobipy as grb
import torch.nn as nn
import matplotlib.pyplot as plt

from src.neural_networks.verinet_nn import VeriNetNN
from src.algorithm.verinet import VeriNet
from src.data_loader.input_data_loader import load_neurify_mnist, load_cifar10_human_readable
from src.data_loader.nnet import NNET
from src.algorithm.verification_objectives import LocalRobustnessObjective
from src.algorithm.verinet_util import Status

### Load 50 images and create bounds

In [16]:
images = load_neurify_mnist("../../data/mnist_neurify/test_images_100/", list(range(100))).reshape(-1, 784)
print(images.shape)
eps = 5
input_bounds = np.zeros((*images.shape, 2), dtype=np.float32)
input_bounds[:, :, 0] = images - eps
input_bounds[:, :, 1] = images + eps

(100, 784)


### Load network, normalise data, and initialise solver

In [17]:
nnet = NNET("../../data/models_nnet/neurify/mnist24.nnet")
images = nnet.normalize_input(images)
input_bounds = nnet.normalize_input(input_bounds)

model = nnet.from_nnet_to_verinet_nn()
model.eval()
solver = VeriNet(model, max_procs=20)
targets = model(torch.Tensor(images)).argmax(dim=1).numpy()

### Verifiy inputs

In [None]:
counter_examples = []
counter_examples_idx = []

for i, input_bound in enumerate(input_bounds):
    objective = LocalRobustnessObjective(targets[i], input_bound, output_size=10)
    status = solver.verify(objective, timeout=3600, no_split=False, verbose=False)
    
    if status == Status.Unsafe:
        counter_examples.append(solver.counter_example)
        counter_examples_idx.append(i)
        
    print(f"Image {i:3}: {status:13}, branches explored: {solver.branches_explored:3}, max depth: {solver.max_depth:2}")

Image   0: Status.Safe  , branches explored:   1, max depth:  0
Image   1: Status.Safe  , branches explored:  21, max depth:  6
Image   2: Status.Unsafe, branches explored:   1, max depth:  0
Image   3: Status.Unsafe, branches explored:   1, max depth:  0
Image   4: Status.Safe  , branches explored:   1, max depth:  0
Image   5: Status.Safe  , branches explored:  17, max depth:  4
Image   6: Status.Safe  , branches explored:   1, max depth:  0
Image   7: Status.Safe  , branches explored:   1, max depth:  0
Image   8: Status.Safe  , branches explored: 341, max depth: 15
Image   9: Status.Safe  , branches explored:   1, max depth:  0
Image  10: Status.Safe  , branches explored:   1, max depth:  0
Image  11: Status.Safe  , branches explored:   1, max depth:  0
Image  12: Status.Unsafe, branches explored:   1, max depth:  0
Image  13: Status.Unsafe, branches explored:   1, max depth:  0
Image  14: Status.Safe  , branches explored:   1, max depth:  0
Image  15: Status.Safe  , branches explo

### Visualise counterexamples

In [7]:
for idx in range(len(counter_examples_idx)):
    plt.figure(figsize=(7,7))
    counter_example = counter_examples[idx]
    image = images[counter_examples_idx[idx]]
    
    diff = (abs(counter_example - image) * 255).astype(np.int32) * 10
    
    plt.subplot(1,3,1)
    plt.imshow(image.reshape((28,28)), cmap="gray")
    plt.axis('off')
    plt.title(f"Class={model(torch.Tensor(image)).argmax(dim=1).numpy()[0]}")
    plt.subplot(1,3,2)
    plt.imshow(diff.reshape((28, 28)), cmap="gray", vmin=0, vmax=255)
    plt.axis('off')
    plt.title("Noise x10")
    plt.subplot(1,3,3)
    plt.imshow(counter_example.reshape(28, 28), cmap="gray")
    plt.axis('off');
    plt.title(f"Predicted={model(torch.Tensor(counter_example)).argmax(dim=1).numpy()[0]}");