In [None]:
using Revise 

In [None]:
using ModelVerification

In [None]:
using Flux
using LazySets

In [None]:
# Smoke test on a tiny one-input network before running the ACAS Xu benchmarks.
# small_nnet encodes the simple function 24*max(x + 1.5, 0) + 18.5; if that is accurate,
# inputs in [-2.5, 2.5] map to [18.5, 114.5], which lies inside `output` below.
small_nnet_file = "test/networks/small_nnet.nnet"  # resolved relative to the working directory
small_nnet = ModelVerification.read_nnet(small_nnet_file)
model = Flux.Chain(small_nnet)

# Verification query: presumably "does `model` map every point of `input` into `output`?"
input = Hyperrectangle(low = [-2.5], high = [2.5])
output = Hyperrectangle(low = [18.4], high = [114.5])
problem = Problem(model, input, output)

# Solver configuration (search strategy, input-splitting rule, bound propagation).
split_method = Bisect(1)
search_method = BFS(max_iter = 20, batch_size = 4)
prop_method = Ai2z()
verify(search_method, split_method, prop_method, problem)

In [None]:
using NeuralVerification, LazySets, PyCall, CSV
include("vnnlib_parser.jl")

"""
    onnx_to_nnet(onnx_file)

Convert the ONNX network at `onnx_file` to the NNet format by calling `onnx2nnet`
from the Python module `NNet`, which is imported from the directory `@__DIR__`.

The result is written next to the input as `<name>.nnet`, for both `<name>.onnx` and
`<name>.onnx.gz` inputs. If that file already exists, nothing is done (not even the
Python import). For a `.onnx.gz` path, the uncompressed `<name>.onnx` is what gets
converted. This function does not decompress, so that file must already exist.

Returns `nothing`. Throws `ArgumentError` if a conversion is needed but the
uncompressed ONNX file is missing.
"""
function onnx_to_nnet(onnx_file)
    stem, ext = splitext(onnx_file)
    onnx_path = ext == ".gz" ? stem : onnx_file
    nnet_file = first(splitext(onnx_path)) * ".nnet"

    # The conversion is cached on disk; skip Python entirely when it's already done.
    isfile(nnet_file) && return nothing
    isfile(onnx_path) || throw(ArgumentError(
        "cannot convert $onnx_file: uncompressed ONNX file $onnx_path not found"))

    # Make the local `NNet` Python package importable without growing sys.path
    # by one duplicate entry on every call.
    sys_path = PyVector(pyimport("sys")."path")
    script_dir = @__DIR__
    script_dir in sys_path || pushfirst!(sys_path, script_dir)

    nnet = pyimport("NNet")
    nnet.onnx2nnet(onnx_path, nnetFile = nnet_file)
    return nothing
end

"""
    verify_an_instance(onnx_file, spec_file)

Check one VNN-LIB property against the `.nnet` conversion of `onnx_file` using
NeuralVerification.jl.

The network is read from `<name>.nnet` next to `onnx_file`, which may be a `.onnx` or
`.onnx.gz` path. That file must already exist (it is produced by `onnx_to_nnet`).
`spec_file` is parsed with `read_vnnlib_simple` into `(input_ranges, output_constraints)`
pairs. Each output constraint `(rows, b)` is turned into `A = hcat(rows...)'` and
treated as the adversarial region `A*y <= b`, as the name `Y_adv` in the original
suggests:

- with more than one row, the checked output set is the complement of that polytope,
  solved with `MIPVerify`;
- with a single row, the checked output set is the opposite half-space `-A*y <= -b`,
  solved with `ReluVal`.

Returns `"violated"` or `"unknown"` as soon as any sub-query reports that status,
and `"holds"` otherwise.
"""
function verify_an_instance(onnx_file, spec_file)
    stem, ext = splitext(onnx_file)
    onnx_path = ext == ".gz" ? stem : onnx_file
    nnet_file = first(splitext(onnx_path)) * ".nnet"

    # Qualified on purpose: an earlier cell already bound `Problem` (and possibly
    # `read_nnet`/`solve`) to ModelVerification's exports, so the bare names would
    # not refer to NeuralVerification here.
    net = NeuralVerification.read_nnet(nnet_file)
    n_in = size(first(net.layers).weights, 2)
    n_out = length(last(net.layers).bias)

    for (X_range, Y_cons) in read_vnnlib_simple(spec_file, n_in, n_out)
        X = Hyperrectangle(low = [bd[1] for bd in X_range],
                           high = [bd[2] for bd in X_range])
        for (rows, b) in Y_cons
            A = hcat(rows...)'
            res = if length(b) > 1
                Y = Complement(HPolytope(A, b))
                NeuralVerification.solve(NeuralVerification.MIPVerify(),
                                         NeuralVerification.Problem(net, X, Y))
            else
                # Closed half-space approximation of "outside A*y <= b".
                Y = HPolytope(-A, -b)
                NeuralVerification.solve(NeuralVerification.ReluVal(max_iter = 1e5),
                                         NeuralVerification.Problem(net, X, Y))
            end

            res.status == :violated && return "violated"
            res.status == :unknown && return "unknown"
        end
    end
    return "holds"
end

# Print the aggregate statistics to stdout and append them to `io`. Each summary line
# is followed by an extra blank line, matching the original report layout.
function _report_summary(io::IO, times::AbstractVector{Float64}, hold_number::Integer)
    # Values for an empty run mirror the original initial values: NaN average (0/0),
    # max 0, min Inf.
    ave_time = isempty(times) ? NaN : sum(times) / length(times)
    max_time = isempty(times) ? 0.0 : maximum(times)
    min_time = isempty(times) ? Inf : minimum(times)
    foreach(println, (ave_time, max_time, min_time, hold_number))
    for line in ("average time is $ave_time.\n", "max time is $max_time.\n",
                 "min time is $min_time.\n", "holds number is $hold_number.\n")
        write(io, line, "\n")
    end
    return nothing
end

"""
    main(args; result_path = "/home/verification/ModelVerification.jl/output.txt",
         header = false)

Verify every benchmark instance listed in the CSV file at path `args` and report results.

Each CSV row gives an ONNX path (column 1) and a VNN-LIB path (column 2). Both are
relative to the CSV file's directory. For each row, the network is converted with
`onnx_to_nnet` and checked with `verify_an_instance`, which is timed with `@timed`.
The instance index and the `@timed` result are printed to stdout, and
`"This is instance i."` plus the verdict is written to `result_path`. At the end, the
average, max and min runtime and the number of `"holds"` verdicts are printed and
appended to the file.

`header = false` because the VNN-COMP instance lists (e.g. `acasxu_instances.csv`)
have no header row. With CSV.jl's default `header = true`, the first instance would be
consumed as column names and silently skipped. Pass `header = true` for CSVs that do
have one.
"""
function main(args; result_path = "/home/verification/ModelVerification.jl/output.txt",
              header = false)
    dirpath = dirname(args)
    rows = CSV.File(args; header = header)
    times = Float64[]
    statuses = String[]
    # `do` block guarantees the report file is closed even if an instance throws.
    open(result_path, "w") do result_file
        for (idx, row) in enumerate(rows)
            instance_num = idx - 1
            onnx_file = joinpath(dirpath, string(row[1]))
            vnnlib_file = joinpath(dirpath, string(row[2]))
            onnx_to_nnet(onnx_file)
            # NOTE: the first instance's time likely includes JIT compilation.
            result = @timed verify_an_instance(onnx_file, vnnlib_file)
            println(instance_num)
            println(result)
            push!(times, result.time)
            push!(statuses, result.value)
            write(result_file, "This is instance $instance_num.\n")
            write(result_file, result.value, "\n")
        end
        _report_summary(result_file, times, count(==("holds"), statuses))
    end
    return nothing
end

# Entry point: run every ACAS Xu instance from the VNN-COMP 2021 benchmark checkout.
let instances_csv = "/home/verification/vnncomp2021/benchmarks/acasxu/acasxu_instances.csv"
    main(instances_csv)
end