In [1]:
using Revise

In [13]:
using LazySets
using ModelVerification
using PyCall
using CSV
using ONNX
using Flux
using Test
# using DataFrames
using MLUtils
# import Flux: flatten

In [3]:
# using Flux: onehotbatch, onecold, flatten
# using Flux.Losses: logitcrossentropy
# using Statistics: mean
using CUDA
using MLDatasets: CIFAR10
using MLUtils: splitobs, DataLoader
using Accessors

In [4]:
"""
    test_mlp(prop_method; search_method, split_method, small_nnet_file)

Check `prop_method` on the one-input `small_nnet` network against four output
specifications whose expected verdicts are known in closed form.

`small_nnet` encodes y = 24*max(x + 1.5, 0) + 18.5. For x ∈ [-2.5, 2.5] the exact
output range is therefore y ∈ [18.5, 114.5]. Each specification below either
contains that interval (expected `:holds`) or cuts into it (expected `:violated`).

# Keywords
- `search_method`: search strategy passed to `verify`
  (default `BFS(max_iter=100, batch_size=10)`).
- `split_method`: input-splitting strategy passed to `verify` (default `Bisect(1)`).
- `small_nnet_file`: path to the `.nnet` file. The default is a relative path and is
  resolved against the current working directory.
"""
function test_mlp(prop_method;
                  search_method = BFS(max_iter=100, batch_size=10),
                  split_method = Bisect(1),
                  small_nnet_file = "../test/networks/small_nnet.nnet")
    small_nnet = read_nnet(small_nnet_file,
                           last_layer_activation = ModelVerification.ReLU())
    flux_model = Flux.Chain(small_nnet)
    # Per-layer ranges for x ∈ [-2.5, 2.5]:
    #   layer 1: max(x + 1.5, 0)              ∈ [0, 4]
    #   layer 2: 4*max(x + 1.5, 0) + 2.5      ∈ [2.5, 18.5]
    #   layer 3: 24*max(x + 1.5, 0) + 18.5    ∈ [18.5, 114.5]
    in_hyper = Hyperrectangle(low = [-2.5], high = [2.5])

    # 19 ≤ y ≤ 114: the true range [18.5, 114.5] pokes out of it, so `:violated`.
    out_violated = Hyperrectangle(low = [19.0], high = [114.0])
    # 18 ≤ y ≤ 115: contains the whole range [18.5, 114.5], so `:holds`.
    out_holds = Hyperrectangle(low = [18.0], high = [115.0])
    # y < 10 or y > 19: outputs in [18.5, 19] fall inside the excluded box, so `:violated`.
    comp_violated = Complement(Hyperrectangle(low = [10.0], high = [19.0]))
    # y < 115 or y > 118: the whole range lies below 115, so `:holds`.
    comp_holds = Complement(Hyperrectangle(low = [115.0], high = [118.0]))

    # Verification status of `flux_model` on `in_hyper` for a given output spec.
    verdict(out_set) = verify(search_method, split_method, prop_method,
                              Problem(flux_model, in_hyper, out_set)).status

    @test verdict(out_holds) == :holds
    @test verdict(out_violated) == :violated
    @test verdict(comp_holds) == :holds
    @test verdict(comp_violated) == :violated
end
# Time 100 repeated runs of the small-MLP checks using the Ai2z propagation method.
# The measurement may include compilation triggered by the first call.
# Other propagation methods can be swapped in here, e.g. Crown() or Ai2s().
@timed foreach(_ -> test_mlp(Ai2z()), 1:100)

(value = nothing, time = 4.835242157, bytes = 779664866, gctime = 0.216948196, gcstats = Base.GC_Diff(779664866, 50, 10, 14542456, 5619, 126, 216948196, 3, 0))

In [20]:
# Small convolutional network exercising the layer types used with ImageStar here:
# strided Conv, BatchNorm, MeanPool, a residual SkipConnection, ConvTranspose,
# flatten and Dense. Shapes for a 32×32×3 input:
#   Conv stride 2 → 16×16×5, MeanPool → 8×8×5, skip branch keeps 8×8×5,
#   ConvTranspose stride 2 → 16×16×2, flatten → 512 (hence Dense(512, 100)).
model = Chain([
    # pad=SamePad() ensures size(output,d) == size(x,d) / stride.
    Conv((3, 3), 3 => 5, relu, pad=SamePad(), stride=(2, 2)),
    BatchNorm(5),
    MeanPool((2,2)),
    SkipConnection(
        Chain([
            Conv((5, 5), 5 => 5, relu, pad=SamePad(), stride=(1, 1))
            ]),
        +
    ),
    # pad=SamePad() ensures size(output,d) == size(x,d) * stride.
    ConvTranspose((3, 3), 5 => 2, relu, pad=SamePad(), stride=(2, 2)),
    Flux.flatten,
    Dense(512, 100, relu),
    Dense(100, 10)
])
# Put layers such as BatchNorm in inference mode (use running statistics).
testmode!(model)

# Construct the CIFAR-10 training split once instead of once per seed image.
# The `let` keeps the full dataset out of global scope.
# (A batched alternative would be CIFAR10(:train)[1:5].features, 32 x 32 x 3 x 5.)
image_seeds = let train_set = CIFAR10(:train)
    [train_set[i].features for i in 1:2]
end
search_method = BFS(max_iter=1, batch_size=1)
split_method = Bisect(1)
prop_method = ImageStar()
# Output spec: every one of the 10 logits must stay in [-1, 1]
# (L∞ ball of radius 1 around 0).
output_set = BallInf(zeros(10), 1.0)
# NOTE(review): how `Problem` turns the seed images into an input set is defined by
# ModelVerification and is not visible here — confirm.
verify(search_method, split_method, prop_method, Problem(model, image_seeds, output_set))
# test_mlp(ImageStarZono())