In [None]:
using Pkg
Pkg.activate("../", io=devnull)

In [None]:
using VeryDiff
using LinearAlgebra
using Plots
using DataStructures
using Gurobi, JuMP

include("util.jl")

# Set VeryDiff's package-level `NEW_HEURISTIC` flag to `false` before any of the
# propagation cells below run.
VeryDiff.NEW_HEURISTIC = false

In [None]:
# Example zonotopes in the plane, all centred at [0, 0]. Each one is built from a
# 2×k generator matrix (k = 3, 4, 5, 6 columns), a centre vector and `nothing` as the
# third constructor argument.
# NOTE(review): the meaning of the third argument is not visible here — confirm in VeryDiff.
hexagon = Zonotope([0.25 0.5 0.25; -0.5 0.0 0.5], [0.0, 0.0], nothing)
octagon = Zonotope([-0.5 0.25 0.5 0.25; 0.25 -0.5 0.25 0.5], [0.0, 0.0], nothing)
# The broadcast with [4.0, 8 / 3] rescales the first generator row by 4 and the second by 8/3.
decagon = Zonotope([4.0, 8 / 3] .* [0.125 -0.125 0.0625 0.125 0.0625; 0.125 0.125 -0.25 0.0 0.25], [0.0, 0.0], nothing)
# All generators are scaled uniformly by 5/4.
dodecagon = Zonotope((5 / 4) * [0.0 0.5 -0.5 0.25 0.5 0.25; 0.5 0.25 0.25 -0.5 0.0 0.5], [0.0, 0.0], nothing)


# Overlay the four shapes in one figure: `plot` starts the figure, each `plot!` adds to it.
# The final `plot!` is the cell's last expression and therefore its displayed output.
plot(dodecagon, alpha=0.5, framestyle=:origin, aspect_ratio=:equal)
plot!(decagon, alpha=0.5, framestyle=:origin, aspect_ratio=:equal)
plot!(octagon, alpha=0.5, framestyle=:origin, aspect_ratio=:equal)
plot!(hexagon, alpha=0.5, framestyle=:origin, aspect_ratio=:equal)

In [None]:
# Demo cell: apply the ReLU layer to a small zonotope with and without a neuron split
# and plot the results next to the exact ReLU curve.

# Base line width for the curves drawn below (Z₂ uses twice this width).
linewidth = 1.5

# Element-wise ReLU: max(x, 0) with a zero of the same type as `x`.
relu = x -> max.(x, zero(x))

# Two-dimensional input zonotope with two generator columns, centred at the origin.
Z = Zonotope([-1.0 0.0; 1.0 -1.0], [0.0, 0.0], nothing)
# One row per dimension; column 1 is used as the lower bound and column 2 as the upper bound.
bounds = zono_bounds(Z)
lower = bounds[:, 1]
upper = bounds[:, 2]

# One 1000-point range per dimension; only the first one (`x[1]`) is plotted below.
x = range.(lower, upper, length=1000)

# Z₀: ReLU applied with a PropState that carries no split nodes.
# Z₁ / Z₂: ReLU applied with a single split node whose fourth argument is 1 / -1.
# NOTE(review): the fourth SplitNode argument appears to be the split direction, and the
# trailing `1, 1` arguments of the layer call presumably select network/layer — confirm.
Z₀ = ReLU()(Z, PropState(true), 1, 1)
Z₁ = ReLU()(Z, PropState(true, SplitNode[SplitNode(1, 1, 1, 1)]), 1, 1)
Z₂ = ReLU()(Z, PropState(true, SplitNode[SplitNode(1, 1, 1, -1)]), 1, 1)

println("Z₀: $Z₀")
println("Z₁: $Z₁")
println("Z₂: $Z₂")

# Overlay the input, the three ReLU results and the exact ReLU function on the first range.
plot(Z, label="Z", framestyle=:origin, alpha=0.5, title="Neuron Splitting")
plot!(Z₀, label="Z₀", framestyle=:origin, alpha=0.5)
plot!(Z₁, label="Z₁", framestyle=:origin, alpha=0.5, lc=:green, lw=linewidth)
plot!(Z₂, label="Z₂", framestyle=:origin, alpha=0.5, lw=2*linewidth)
plot!(x[1], relu, label="ReLU(Z)", framestyle=:origin, alpha=0.5, lc=:blue, lw=linewidth)

In [None]:
# Build a predicate `(N₁, N₂, cex) -> Bool` for the tolerance `ϵ`.
# The predicate is `true` exactly when `VeryDiff.get_sample_distance(N₁, N₂, cex)`
# is strictly greater than `ϵ`.
function is_epsilon_counterexample(ϵ)
    exceeds_tolerance(N₁, N₂, cex) = VeryDiff.get_sample_distance(N₁, N₂, cex) > ϵ
    return exceeds_tolerance
end

In [None]:
# Draw the three parts of a differential zonotope (`Z₁`, `Z₂` and `∂Z`) into one figure.
#   ∂Z    – object with fields `Z₁`, `Z₂` and `∂Z`, each of which can be plotted
#   title – figure title (keyword, defaults to "Output")
# Returns whatever the final `plot!` call returns.
function plot_diff_zono(∂Z; title="Output")
    shared = (; alpha=0.5)
    # `plot` opens the figure; the two `plot!` calls add to it.
    plot(∂Z.Z₁; label="Z₁", framestyle=:origin, title=title, shared...)
    plot!(∂Z.Z₂; label="Z₂", shared...)
    return plot!(∂Z.∂Z; label="∂Z", shared...)
end

In [None]:
# Branch the search-tree entry `prev_split` on the neuron described by `node`.
#   node       – split candidate; its `network`, `layer`, `neuron`, `g` and `c` fields are
#                copied into the two new split nodes
#   prev_split – `(mask, split_nodes)` entry to branch from
# Returns two entries `(mask, split_nodes)`: the first extends the split list with a node of
# direction -1, the second with a node of direction 1.
# `prev_split` itself is left unmodified: both branches work on their own deep copy, so the
# caller's vector is never grown behind its back.
function split_neuron(node :: SplitNode, prev_split :: Tuple{BitMatrix, Vector{SplitNode}})
    mask₁, split₁ = deepcopy(prev_split)
    mask₂, split₂ = deepcopy(prev_split)

    push!(split₁, SplitNode(node.network, node.layer, node.neuron, -1, node.g, node.c))
    push!(split₂, SplitNode(node.network, node.layer, node.neuron, 1, node.g, node.c))

    return (mask₁, split₁), (mask₂, split₂)
end

In [None]:
# Push `Zin` through the notebook-global network `N` under `prop_state`.
# Returns the propagated output together with the same `prop_state` object that was passed in.
function propagate(Zin, prop_state)
    Zout = N(Zin, prop_state)
    return Zout, prop_state
end

In [None]:
# Try to decide the ε-property on one branch of the search tree.
#
# Arguments
#   N₁, N₂              – the two networks; used to evaluate candidate counterexamples
#   Zin, Zout           – differential zonotopes before / after propagation; this function
#                         reads `Zin.Z₁` and `Zout.∂Z`
#   prop_state          – propagation state; its `split_nodes` supply the LP split constraints
#   (mask, split_nodes) – search-tree entry. `mask` is either empty or a Bool matrix with one
#                         row per output dimension that flags which lower (column 1) and
#                         upper (column 2) bounds are still open. The `split_nodes` component
#                         is replaced by `prop_state.split_nodes` before it is used.
#
# Returns
#   (VeryDiff.UNSAFE, cex)      – a concrete counterexample was found
#   (VeryDiff.UNKNOWN, nothing) – at least one bound could not be ruled out
#   (VeryDiff.SAFE, nothing)    – the property check passed or every open bound was infeasible
#
# Reads the notebook globals `property_check` and `epsilon`.
function search(N₁, N₂, Zin, Zout, prop_state, (mask, split_nodes))
    prop_satisfied, cex, _, _, _ = property_check(N₁, N₂, Zin, Zout, nothing)

    if !prop_satisfied
        if !isnothing(cex)
            return VeryDiff.UNSAFE, cex[1]
        end

        # The split nodes recorded during propagation are the ones the LP is built from
        split_nodes = prop_state.split_nodes
        # Initialize the LP solver
        model = Model(() -> Gurobi.Optimizer(VeryDiff.GRB_ENV[]))
        set_time_limit_sec(model, 10)

        # One LP variable in [-1, 1] per generator column of the output difference zonotope
        var_num = size(Zout.∂Z.G, 2)
        @variable(model, -1.0 <= x[1:var_num] <= 1.0)

        # Add split constraints: direction * (g' * x + c) >= 0 for every split node
        Gₛ = zeros(Float64, size(split_nodes, 1), var_num)
        cₛ = zeros(Float64, size(split_nodes))
        dₛ = zeros(Float64, size(split_nodes))
        for (i, split_node) in enumerate(split_nodes)
            # A split node's `g` may be shorter than `var_num`; the remaining entries stay zero
            Gₛ[i, 1:size(split_node.g, 1)] = split_node.g
            cₛ[i] = split_node.c
            dₛ[i] = split_node.direction
        end
        @constraint(model, dₛ .* (Gₛ * x + cₛ) .>= 0.0)

        bounds = zono_bounds(Zout.∂Z)
        # Compute all output dimensions that still need to be proven
        mask = hcat(bounds[:, 1] .< -epsilon, bounds[:, 2] .> epsilon) .&& (isempty(mask) ? true : mask)

        # Number of leading LP variables that map back to an input point. Taken from the
        # input zonotope itself so the function does not depend on a notebook global.
        input_dim = size(Zin.Z₁.G, 2)

        # For each unproven output dimension we solve a LP for corresponding lower and upper bound
        for i in (1:size(mask, 1))[mask[:, 1] .|| mask[:, 2]]
            # (column, sign): column 1 with σ = -1 is the lower bound, column 2 with σ = 1 the upper
            for (j, σ) in [(1, -1), (2, 1)][mask[i, :]]

                @objective(model, Max, σ * (Zout.∂Z.G[i, :]' * x + Zout.∂Z.c[i]))
                optimize!(model)

                if is_solved_and_feasible(model)
                    # Map the LP solution back to a concrete input and test it on both networks
                    cex = Zin.Z₁.G * value.(x)[1:input_dim] + Zin.Z₁.c
                    sample_distance = VeryDiff.get_sample_distance(N₁, N₂, cex)
                    if sample_distance > epsilon
                        return VeryDiff.UNSAFE, cex
                    end
                end

                # Only an infeasible LP closes this bound; any other status keeps it open
                mask[i, j] = termination_status(model) != MOI.INFEASIBLE
            end
        end

        if any(mask)
            return VeryDiff.UNKNOWN, nothing
        end
    end

    return VeryDiff.SAFE, nothing
end

In [None]:
# Branch the entry `split` on `split_candidate` via `split_neuron` and append both resulting
# entries to the back of `splits`, in the order `split_neuron` returns them.
# Returns the result of `push!`.
function branch_and_bound(splits, split, split_candidate)
    children = split_neuron(split_candidate, split)
    return push!(splits, children...)
end

In [None]:
# First network: three 2×2 dense layers with a ReLU after the first and the second.
W1 = [1.0 1.0; 1.0 -1.0]
b1 = zeros(2)
W2 = [1.0 1.0; 1.0 -1.0]
b2 = [-0.5, 0.0]
W3 = [-1.0 0.0; 1.0 1.0]
b3 = zeros(2)

layers1 = [Dense(W1, b1), ReLU(), Dense(W2, b2), ReLU(), Dense(W3, b3)]

# Second network: same layer sequence and shapes, with some weights and one bias shifted
# by 0.1 relative to the first network.
W1a = [1.0 1.1; 1.0 -0.9]
b1a = zeros(2)
W2a = [1.1 1.0; 0.9 -1.0]
b2a = [-0.4, 0.0]
W3a = [-1.0 0.0; 1.0 1.1]
b3a = zeros(2)

layers2 = [Dense(W1a, b1a), ReLU(), Dense(W2a, b2a), ReLU(), Dense(W3a, b3a)]

In [None]:
# Tolerance shared by the property check and the counterexample predicate.
epsilon = 0.2
property_check = get_epsilon_property(epsilon)
is_counterexample = is_epsilon_counterexample(epsilon)

# The two networks under comparison and their combined `GeminiNetwork`.
N₁ = Network(layers1)
N₂ = Network(layers2)
N = GeminiNetwork(N₁, N₂)

# Input zonotope: identity generators, zero centre, and a 2×2 identity as third argument.
Z = Zonotope([1.0 0.0; 0.0 1.0], zeros(2), I(2))
# Difference zonotope with all-zero generators and centre, same shapes as `Z`.
∂Z = Zonotope(zeros(Float64, size(Z.G)), zeros(size(Z.c)), nothing)
# Both networks start from the same input (`Z` and a deep copy of it).
# NOTE(review): the meaning of the three trailing zeros is not visible here — confirm.
Zin = DiffZonotope(Z, deepcopy(Z), ∂Z, 0, 0, 0)

# Number of generator columns of the input zonotope.
input_dim = size(Z.G, 2)

# Work queue of `(mask, split_nodes)` entries, seeded with an empty 0×2 mask and no splits.
splits = Deque{Tuple{BitMatrix, Vector{SplitNode}}}()
push!(splits, (hcat(falses(0), falses(0)), SplitNode[]))

# Counter incremented each time a branch is split further.
neuron_splits = 0

In [None]:
# Take the root entry off the queue and branch on it by hand, using a split node built
# with arguments (1, 1, 1, 0). The direction-1 branch (`split₂`) is enqueued ahead of the
# direction -1 branch (`split₁`), so it is the first one `popfirst!` returns.
# The entry is named `root_split` rather than `split` so that the notebook does not
# rebind/shadow `Base.split` in `Main` (which also errors if `split` was used earlier).
root_split = popfirst!(splits)
split₁, split₂ = split_neuron(SplitNode(1, 1, 1, 0), root_split)
push!(splits, split₂, split₁)

In [None]:
# Step 1: dequeue the next `(mask, split_nodes)` entry and propagate the input under it.
mask, split_nodes = popfirst!(splits)
Zout, prop_state = propagate(Zin, PropState(true, split_nodes))
# Last expression of the cell: shows the split nodes active on this branch.
split_nodes

In [None]:
# Plot the three parts of the propagated output for this branch.
plot_diff_zono(Zout)

In [None]:
# Check the property on this branch; `cex` is `nothing` unless the status is UNSAFE.
_status, cex = search(N₁, N₂, Zin, Zout, prop_state, (mask, split_nodes))

In [None]:
# Undecided branch: split it on the candidate stored in `prop_state`, enqueue both
# children and count the split.
if _status == VeryDiff.UNKNOWN
    branch_and_bound(splits, (mask, split_nodes), prop_state.split_candidate)
    neuron_splits += 1
end

In [None]:
# Step 2: dequeue the next `(mask, split_nodes)` entry and propagate the input under it.
mask, split_nodes = popfirst!(splits)
Zout, prop_state = propagate(Zin, PropState(true, split_nodes))
# Last expression of the cell: shows the split nodes active on this branch.
split_nodes

In [None]:
# Plot the three parts of the propagated output for this branch.
plot_diff_zono(Zout)

In [None]:
# Check the property on this branch; `cex` is `nothing` unless the status is UNSAFE.
_status, cex = search(N₁, N₂, Zin, Zout, prop_state, (mask, split_nodes))

In [None]:
# Undecided branch: split it on the candidate stored in `prop_state`, enqueue both
# children and count the split.
if _status == VeryDiff.UNKNOWN
    branch_and_bound(splits, (mask, split_nodes), prop_state.split_candidate)
    neuron_splits += 1
end

In [None]:
# Step 3: dequeue the next `(mask, split_nodes)` entry and propagate the input under it.
mask, split_nodes = popfirst!(splits)
Zout,prop_state = propagate(Zin, PropState(true, split_nodes))
# Last expression of the cell: shows the split nodes active on this branch.
split_nodes

In [None]:
# Plot the three parts of the propagated output for this branch.
plot_diff_zono(Zout)

In [None]:
# Check the property on this branch; `cex` is `nothing` unless the status is UNSAFE.
_status, cex = search(N₁, N₂, Zin, Zout, prop_state, (mask, split_nodes))

In [None]:
# Undecided branch: split it on the candidate stored in `prop_state`, enqueue both
# children and count the split.
if _status == VeryDiff.UNKNOWN
    branch_and_bound(splits, (mask, split_nodes), prop_state.split_candidate)
    neuron_splits += 1
end

In [None]:
# Step 4: dequeue the next `(mask, split_nodes)` entry and propagate the input under it.
mask, split_nodes = popfirst!(splits)
Zout, prop_state = propagate(Zin, PropState(true, split_nodes))
# Last expression of the cell: shows the split nodes active on this branch.
split_nodes

In [None]:
# Plot the three parts of the propagated output for this branch.
plot_diff_zono(Zout)

In [None]:
# Check the property on this branch; `cex` is `nothing` unless the status is UNSAFE.
_status, cex = search(N₁, N₂, Zin, Zout, prop_state, (mask, split_nodes))