In [1]:
using Revise

In [2]:
using ModelVerification

In [3]:
using LazySets
using PyCall
using CSV
using ONNX
using Flux
using Test
using NNlib
using ONNXNaiveNASflux
using NaiveNASflux
using Zygote
# using DataFrames
# import Flux: flatten

In [4]:
# using Flux: onehotbatch, onecold, flatten
# using Flux.Losses: logitcrossentropy
# using Statistics: mean
using CUDA
using MLDatasets: CIFAR10, MNIST
using MLUtils: splitobs, DataLoader
using Accessors
using Profile
using LinearAlgebra
using Einsum

In [None]:
# ReLU relaxation layer used to back-substitute ONE side (lower or upper) of a linear
# bound through a ReLU, alpha-CROWN style. Calling the layer on [A, bias] returns the
# back-substituted [A, bias].
#
# The call method always pairs `upper_bias` with non-negative coefficients and
# `lower_bias` with negative ones, independently of `lower`. A lower-side layer is
# therefore constructed with the two relaxation intercepts in swapped positions.
# NOTE(review): confirm that every construction site follows this convention.
struct AlphaLayer
    node           # node identifier (e.g. "relu_1"); not read by the call method
    alpha          # trainable lower-relaxation slope; clamped to [0, 1] before use
    lower          # Bool: true -> this layer propagates the lower bound, false -> the upper
    unstable_mask  # multiplies clamped alpha; presumably 1 for unstable neurons — TODO confirm
    lower_mask     # added to the masked alpha; presumably 1 for always-active neurons — TODO confirm
    upper_slope    # slope of the ReLU upper relaxation (presumably from relu_upper_bound)
    lower_bias     # intercept the call method applies to negative coefficients
    upper_bias     # intercept the call method applies to non-negative coefficients
end
# Only `alpha` is exposed to Flux (e.g. `Flux.params`); every other field is a constant.
Flux.@functor AlphaLayer (alpha,)

# Back-substitute one side of a linear bound through the ReLU described by `f`.
#
# `x` is a 2-element collection [Last_A, Last_bias]: the coefficients reaching this
# ReLU and the bias accumulated so far. Returns [New_A, New_bias] in the same layout.
#   * The lower relaxation slope is clamp(alpha, 0, 1) on unstable neurons plus `lower_mask`.
#   * Lower side: non-negative coefficients take the lower slope and negative ones the
#     upper slope. The upper side reverses this.
#   * The relaxation intercepts are weighted by the sign-split coefficients (Last_A)
#     and ADDED to the incoming bias. Bias accumulates through back-substitution.
# Assumes Last_bias broadcasts against the intercept contribution, presumably
# (spec, batch) — TODO confirm against the caller.
function (f::AlphaLayer)(x)
    Last_A = x[1]
    Last_bias = x[2]
    # No coefficients on this side: there is nothing to propagate.
    if isnothing(Last_A)
        return [nothing, nothing]
    end

    lower_slope = clamp.(f.alpha, 0, 1) .* f.unstable_mask .+ f.lower_mask
    if f.lower
        New_A = bound_oneside(Last_A, lower_slope, f.upper_slope)
    else
        New_A = bound_oneside(Last_A, f.upper_slope, lower_slope)
    end

    # An untracked (`nothing`) bias stays untracked, matching the previous behaviour.
    if isnothing(Last_bias)
        return [New_A, nothing]
    end
    # The sign split that selects each intercept is taken over the coefficients Last_A,
    # not over the bias itself.
    intercept_term = multiply_bias(Last_A, f.upper_slope, f.upper_bias, f.lower_bias)
    New_bias = Last_bias .+ intercept_term

    return [New_A, New_bias]
end

# CROWN upper relaxation of ReLU over pre-activation bounds [lower, upper].
# Returns (upper_slope, upper_bias) for the line through (l, 0) and (u, u). Here
# l = min(lower, 0) and u = max(upper, 0); u is nudged to at least l + 1e-8 so that
# u - l stays positive. Stably-active neurons get slope 1 and bias 0. Stably-inactive
# neurons get slope 0 and bias 0.
function relu_upper_bound(lower, upper)
    l = clamp.(lower, -Inf, 0)
    u = max.(clamp.(upper, 0, Inf), l .+ 1e-8)
    slope = u ./ (u .- l)
    intercept = -l .* slope
    return slope, intercept
end

# Sign-split product of back-substitution coefficients with per-neuron slopes.
# Non-negative entries of `last_A` are scaled by `slope_pos` and negative entries by
# `slope_neg`. `last_A` carries a leading spec dimension that the slopes lack.
# Each slope is reshaped to have a singleton leading dimension, and broadcasting
# expands it across size(last_A, 1). This gives the same result as explicitly
# `repeat`ing the slopes, without materialising the copies.
function clamp_mutiply_A(last_A, slope_pos, slope_neg)
    A_pos = clamp.(last_A, 0, Inf)
    A_neg = clamp.(last_A, -Inf, 0)
    slope_pos_spec = reshape(slope_pos, 1, size(slope_pos)...)  # add spec dim for slope_pos
    slope_neg_spec = reshape(slope_neg, 1, size(slope_neg)...)  # add spec dim for slope_neg
    New_A = slope_pos_spec .* A_pos .+ slope_neg_spec .* A_neg
    return New_A
end


# Contract sign-split coefficients with the ReLU relaxation intercepts:
#   New_bias[s, b] = Σ_r max(A, 0)[s, r, b] * bias_pos[r, b]
#                  + Σ_r min(A, 0)[s, r, b] * bias_neg[r, b]
# `last_A` is indexed as (spec, r, batch) and each bias as (r, batch). The result is a
# (spec, batch) matrix whose eltype is at least Float64.
# A `nothing` bias contributes zero. Previously that case hit an undefined variable.
# The contraction uses broadcasting + `sum` instead of in-place writes, because Zygote
# cannot differentiate through array mutation.
function clamp_mutiply_bias(last_A, bias_pos, bias_neg)
    A_pos = clamp.(last_A, 0, Inf)
    A_neg = clamp.(last_A, -Inf, 0)
    New_bias = zeros(size(last_A, 1), size(last_A, ndims(last_A)))  # spec dim x batch dim
    if !isnothing(bias_pos)
        New_bias = New_bias .+ _contract_spec_bias(A_pos, bias_pos)
    end
    if !isnothing(bias_neg)
        New_bias = New_bias .+ _contract_spec_bias(A_neg, bias_neg)
    end
    return New_bias
end

# Σ_r A[s, r, b] * bias[r, b] -> (s, b); equivalent to torch.einsum('srb,rb->sb', A, bias).
_contract_spec_bias(A, bias) =
    dropdims(sum(A .* reshape(bias, 1, size(bias)...); dims = 2); dims = 2)

# Compute New_A from last_A by scaling its non-negative entries with `slope_pos` and its
# negative entries with `slope_neg`.
# 1-D slopes (the LSTM special case) broadcast directly against last_A. Higher-rank
# slopes go through clamp_mutiply_A, which adds the spec dimension.
function multiply_by_A_signs(last_A, slope_pos, slope_neg)
    ndims(slope_pos) == 1 || return clamp_mutiply_A(last_A, slope_pos, slope_neg)
    positive_part = clamp.(last_A, 0, Inf)
    negative_part = clamp.(last_A, -Inf, 0)
    return positive_part .* slope_pos .+ negative_part .* slope_neg
end

# Weight the relaxation intercepts by the signs of `last_A`: non-negative coefficients
# take `bias_pos` and negative ones take `bias_neg`.
# `upper_slope` is consulted only for its rank. A 1-D slope (the LSTM special case) means
# a plain elementwise product; otherwise the spec-batched contraction is used.
function multiply_bias(last_A, upper_slope, bias_pos, bias_neg)
    ndims(upper_slope) == 1 || return clamp_mutiply_bias(last_A, bias_pos, bias_neg)
    coeff_pos = clamp.(last_A, 0, Inf)
    coeff_neg = clamp.(last_A, -Inf, 0)
    return coeff_pos .* bias_pos .+ coeff_neg .* bias_neg
end

# Back-substitute one side (upper or lower) of a bound through the ReLU.
# Non-negative coefficients in `last_A` are scaled by `slope_pos` and negative ones by
# `slope_neg`. Returns the new coefficient array, or `nothing` when there is no
# coefficient array to propagate. Callers use the result as a single A, so a
# (nothing, nothing) tuple would be the wrong shape.
function bound_oneside(last_A, slope_pos, slope_neg)
    isnothing(last_A) && return nothing
    return multiply_by_A_signs(last_A, slope_pos, slope_neg)
end


# Smoke test: a lower-side and an upper-side AlphaLayer on scalar-sized masks are
# chained and applied to [Last_A, Last_bias] = [2, 0].
alpha_lower = [20]
alpha_upper = [1]
unstable_mask = [1]
lower_mask = [1]
upper_slope = [2]
upper_bias = [0]
lower_bias = [0]
lower = upper = true
# Positional order is (node, alpha, lower, unstable_mask, lower_mask, upper_slope,
# lower_bias, upper_bias). The lower-side layer receives the two bias vectors in
# swapped positions, presumably on purpose — confirm.
if lower
    Alpha_Lower_Layer = AlphaLayer("relu_1", alpha_lower, true, unstable_mask, lower_mask,
                                   upper_slope, upper_bias, lower_bias)
end
if upper
    Alpha_Upper_Layer = AlphaLayer("relu_1", alpha_upper, false, unstable_mask, lower_mask,
                                   upper_slope, lower_bias, upper_bias)
end
a = Any[Alpha_Lower_Layer, Alpha_Upper_Layer]
println(Flux.params(Alpha_Lower_Layer))
println(Flux.params(a))
a = Chain(a)
println(a([2, 0]))

In [None]:
# Clear the :split_active entry for every activation node and reset scratch state.
for activation_node in model_info.activation_nodes
    batch_info[activation_node][:split_active] = []
end
# `None` is Python and undefined in Julia (UndefVarError). Julia's absent value is `nothing`.
primals, duals, mini_inp = nothing, nothing, nothing
# Upper bounds start at +Inf, with the same shape as `lower_bound`.
upper_bound = fill(Inf, size(lower_bound))

In [67]:
# Regression check: `prop_method` must verify four properties of small_nnet, which
# encodes y = 24*max(x + 1.5, 0) + 18.5. Over x ∈ [-2.5, 2.5] the exact output range
# is [18.5, 114.5].
#
# Keywords (both default to the previous behaviour):
#   onnx_path      - ONNX model verified by default; presumably an export of
#                    small_nnet.nnet — TODO confirm.
#   use_flux_model - verify the Flux chain loaded from small_nnet.nnet instead.
# Returns the result of the last `@test`.
function test_mlp(prop_method;
                  onnx_path = "/home/verification/ModelVerification.jl/small_nnet.onnx",
                  use_flux_model = false)
    small_nnet_file = "/home/verification/ModelVerification.jl/test/networks/small_nnet.nnet"
    small_nnet = read_nnet(small_nnet_file, last_layer_activation = ModelVerification.ReLU())
    flux_model = Flux.Chain(small_nnet)
    model = use_flux_model ? flux_model : onnx_path

    in_hyper = Hyperrectangle(low = [-2.5], high = [2.5])            # expected out: [18.5, 114.5]
    out_violated = Hyperrectangle(low = [19], high = [114])          # 19 ≤ y ≤ 114: too narrow
    out_holds = Hyperrectangle(low = [18], high = [115.0])           # 18 ≤ y ≤ 115: contains range
    comp_violated = Complement(Hyperrectangle(low = [10], high = [19]))   # y outside [10, 19]
    comp_holds = Complement(Hyperrectangle(low = [115], high = [118]))    # y outside [115, 118]

    search_method = BFS(max_iter = 100, batch_size = 1)
    split_method = Bisect(1)
    @test verify(search_method, split_method, prop_method, Problem(model, in_hyper, out_holds)).status == :holds
    @test verify(search_method, split_method, prop_method, Problem(model, in_hyper, out_violated)).status == :violated
    @test verify(search_method, split_method, prop_method, Problem(model, in_hyper, comp_holds)).status == :holds
    @test verify(search_method, split_method, prop_method, Problem(model, in_hyper, comp_violated)).status == :violated
end
# Time the small_nnet regression properties under AlphaCrown.
# Other propagation methods tried here: Ai2z(), Crown(true, true),
# StarSet(Crown(true, true)).
@timed begin
    test_mlp(AlphaCrown(Crown(true, true), true, false,
                        Flux.Optimiser(Flux.ADAM(0.1)), 10))
end

In [45]:
# Scratch check: batched_mul of a (1,1,1) coefficient array with a (1,1) input, plus a
# (1,1) bias. The result keeps the 3-D (1,1,1) shape.
last_A = [1;;;]
x = [18.5;;]
bias = [0;;]
foreach(arr -> println(size(arr)), (last_A, x, bias))
out = NNlib.batched_mul(last_A, x) .+ bias
println(size(out))
println(out)

(1, 1, 1)
(1, 1)
(1, 1)


(1, 1, 1)
[18.5;;;]


In [25]:
# Scratch check: (1,2,1) coefficients times a (2,1) bias collapse to a (1,1,1) result.
last_A = [9.6 9.6;;;]
bias = [1.5; 1.5;;]
println(bias)
New_bias = NNlib.batched_mul(last_A, bias)
foreach(println, (size(New_bias), New_bias))


[1.5; 1.5;;]
(1, 1, 1)
[28.799999999999997;;;]


In [None]:
# Small CIFAR-10 CNN with a residual block; spatial sizes follow from SamePad + stride.
model = Chain([
    Conv((3, 3), 3 => 8, relu, pad=SamePad(), stride=(2, 2)), # 32x32x3 -> 16x16x8
    BatchNorm(8),
    MeanPool((2,2)),                                           # -> 8x8x8
    SkipConnection(
        Chain([
            Conv((5, 5), 8 => 8, relu, pad=SamePad(), stride=(1, 1))
            ]),
        +
    ),                                                         # -> 8x8x8
    Flux.flatten,                                              # -> 512
    Dense(512, 100, relu),
    Dense(100, 10)
])
testmode!(model)
# Load the training split ONCE. Calling CIFAR10(:train) inside the comprehension
# reloaded the dataset for every image.
image_seeds = let cifar_train = CIFAR10(:train)
    [cifar_train[i].features for i in 1:2]
end
input_set = ImageConvexHull(image_seeds)
search_method = BFS(max_iter=1, batch_size=1)
split_method = Bisect(1)
output_set = BallInf(zeros(10), 1.0)
onnx_model_path = "/home/verification/ModelVerification.jl/mlp.onnx"
flux_model = model
# Keep the batch dimension in sync with the number of seeds (was hard-coded to 5 for 2 images).
image_shape = (32, 32, 3, length(image_seeds))
println(image_seeds)

In [23]:
# Verify `flux_model` on `input_set` / `output_set` with ImageStar propagation (timed).
prop_method = ImageStar()
@timed verify(search_method, split_method, prop_method,
              Problem(flux_model, input_set, output_set))

In [None]:
# ImageStarZono variant. Unlike the ImageStar run, the Problem is built from the ONNX path
# and the raw seed images.
# NOTE(review): confirm `onnx_model_path` encodes the same network as `flux_model`.
prop_method = ImageStarZono()
@timed verify(search_method, split_method, prop_method,
              Problem(onnx_model_path, image_seeds, output_set))

In [None]:
# Small MNIST MLP: flattened 28x28 image -> 200 (relu) -> 10, verified on one seed image.
model = Chain([
    Flux.flatten,
    Dense(784, 200, relu),
    Dense(200, 10)
])
image_seeds = let mnist_train = MNIST(:train)
    [mnist_train[i].features for i in 1:1]
end
search_method = BFS(max_iter=1, batch_size=1)
split_method = Bisect(1)
output_set = BallInf(zeros(10), 1.0)
onnx_model_path = "/home/verification/ModelVerification.jl/debug.onnx"
Flux_model = model
image_shape = (28, 28, 1, 1)

In [None]:
# CIFAR-10 CNN with a residual block and a second strided conv; spatial sizes follow
# from SamePad + stride.
model = Chain([
    Conv((3, 3), 3 => 8, relu, pad=SamePad(), stride=(2, 2)), # 32x32x3 -> 16x16x8
    BatchNorm(8),
    MeanPool((2,2)),                                           # -> 8x8x8
    SkipConnection(
        Chain([
            Conv((5, 5), 8 => 8, relu, pad=SamePad(), stride=(1, 1))
            ]),
        +
    ),                                                         # -> 8x8x8
    Conv((3, 3), 8 => 8, relu, pad=SamePad(), stride=(2, 2)),  # -> 4x4x8
    Flux.flatten,                                              # -> 128
    Dense(128, 100, relu),
    Dense(100, 10)
])
testmode!(model)
# Load the training split ONCE instead of once per image.
image_seeds = let cifar_train = CIFAR10(:train)
    [cifar_train[i].features for i in 1:2]
end
search_method = BFS(max_iter=1, batch_size=1)
split_method = Bisect(1)
output_set = BallInf(zeros(10), 1.0)
onnx_model_path = "/home/verification/ModelVerification.jl/mlp.onnx"
Flux_model = model
# Batch dimension tracks the number of seed images (2).
image_shape = (32, 32, 3, length(image_seeds))

In [None]:
# ImageStar run using the ONNX path, Flux model, and explicit image shape (timed).
prop_method = ImageStar()
@timed verify(search_method, split_method, prop_method,
              Problem(onnx_model_path, Flux_model, image_shape, image_seeds, output_set))

In [None]:
# Same problem as the ImageStar run, propagated with ImageStarZono (timed).
prop_method = ImageStarZono()
@timed verify(search_method, split_method, prop_method,
              Problem(onnx_model_path, Flux_model, image_shape, image_seeds, output_set))

In [None]:
# Wide (128-channel) CIFAR-10 CNN with a residual block and a transposed conv; spatial
# sizes follow from SamePad + stride.
model = Chain([
    Conv((3, 3), 3 => 128, relu, pad=SamePad(), stride=(2, 2)),  # 32x32x3 -> 16x16x128
    BatchNorm(128),
    MeanPool((2,2)),                                              # -> 8x8x128
    SkipConnection(
        Chain([
            Conv((5, 5), 128 => 128, relu, pad=SamePad(), stride=(1, 1))
            ]),
        +
    ),                                                            # -> 8x8x128
    ConvTranspose((3, 3), 128 => 128, relu, pad=SamePad(), stride=(2, 2)), # -> 16x16x128
    Flux.flatten,                                                 # -> 32768
    Dense(32768, 100, relu),
    Dense(100, 10)
])
testmode!(model)
# Load the training split ONCE instead of once per image.
image_seeds = let cifar_train = CIFAR10(:train)
    [cifar_train[i].features for i in 1:2]
end
search_method = BFS(max_iter=1, batch_size=1)
split_method = Bisect(1)
output_set = BallInf(zeros(10), 1.0)

In [None]:
# Timed ImageStarZono verification of the wide CNN on the raw seed images.
prop_method = ImageStarZono()
@timed verify(search_method, split_method, prop_method,
              Problem(model, image_seeds, output_set))

In [None]:
# Discard earlier samples, then profile one more run of the same verification.
Profile.clear()
@profile verify(search_method, split_method, prop_method,
                Problem(model, image_seeds, output_set))

In [None]:
# Write the collected profile to ./prof.txt. The wide :displaysize keeps long call-stack
# lines from being truncated.
open("./prof.txt", "w") do io
    wide_io = IOContext(io, :displaysize => (24, 500))
    Profile.print(wide_io)
end