In [2]:
using Revise

In [3]:
using Images, ImageIO
using ONNXNaiveNASflux, NaiveNASflux, .NaiveNASlib
using LinearAlgebra
using OpenCV
using Flux
using CSV
using DataFrames
using DataStructures

In [3]:

# Place a window of length `side` starting at index `start` inside `1:extent`.
# Returns the (possibly shifted) start index and how many entries of the window
# actually lie inside the image: the whole extent when `side >= extent`,
# otherwise `side`, with the start pulled back so the window ends at `extent`.
function _fit_window(start, side, extent)
    side >= extent && return 1, extent
    return min(start, extent + 1 - side), side
end

"""
    square_img(image, bbox, img_size)

Crop a square patch from `image` that covers the bounding box `bbox`.

`image` is indexed as `image[row, col, channel]` (rows = y, cols = x).
`bbox` must support `bbox["bbox_x1"]`, `bbox["bbox_y1"]`, `bbox["bbox_x2"]` and
`bbox["bbox_y2"]`, e.g. a `DataFrameRow` or a `Dict`. The corner
`(bbox_x1, bbox_y1)` is shifted by +1 to obtain Julia indices, and coordinates
are rounded to integers. `img_size` is `(height, width)` of `image`.

The square's side is the longer bbox side. Along the shorter side the window
starts at the bbox corner. It is shifted back inside the image if it would run
past the border. If the side exceeds the image extent, the whole extent is
taken and the rest of the square stays zero.

Returns `(sq_img, offset_tl)`:
- `sq_img` is a `side × side × size(image, 3)` `UInt8` array.
- `offset_tl = [col, row]` is the top-left corner of the crop in `image`.
"""
function square_img(image, bbox, img_size)
    bbox_width = round(Int, bbox["bbox_x2"] - bbox["bbox_x1"])
    bbox_height = round(Int, bbox["bbox_y2"] - bbox["bbox_y1"])
    # round (not Int.) so fractional coordinates do not throw InexactError
    offset_tl = round.(Int, [bbox["bbox_x1"], bbox["bbox_y1"]] .+ 1)
    h, w = img_size
    side = max(bbox_width, bbox_height)
    if bbox_width >= bbox_height
        # wide box: keep the bbox columns, fit a `side`-tall window vertically
        offset_tl[2], nrows = _fit_window(offset_tl[2], side, h)
        ncols = side
    else
        # tall box: keep the bbox rows, fit a `side`-wide window horizontally
        offset_tl[1], ncols = _fit_window(offset_tl[1], side, w)
        nrows = side
    end
    rows = offset_tl[2]:(offset_tl[2] + nrows - 1)
    cols = offset_tl[1]:(offset_tl[1] + ncols - 1)
    sq_img = zeros(UInt8, side, side, size(image, 3))
    sq_img[1:nrows, 1:ncols, :] = image[rows, cols, :]
    return sq_img, offset_tl
end

"""
    preprocess_image(img_path, csv_path; input_size = 256)

Load the image at `img_path`, crop a square patch around its ground-truth
bounding box, resize it to `input_size × input_size` and scale it by `1/255`.

The bounding box is read from the CSV at `csv_path` (columns `bbox_x1`,
`bbox_y1`, `bbox_x2`, `bbox_y2`). The row is chosen from the numeric suffix of
the file name: `…_<k>.<ext>` selects row `k + 1`. The image is converted to
grayscale and then back to 3 channels before cropping.

The result is the resized crop after `permutedims(…, (3, 2, 1))` of OpenCV's
output. NOTE(review): confirm this layout matches what the Flux model expects.
"""
function preprocess_image(img_path, csv_path; input_size = 256)
    image = OpenCV.imread(img_path)
    _, w, h = size(image)  # OpenCV layout (C, W, H), e.g. (3, 1920, 1200)
    # grayscale, then back to 3 channels: colour is dropped, channel count kept
    image = OpenCV.cvtColor(image, OpenCV.COLOR_BGR2GRAY)
    image = OpenCV.cvtColor(image, OpenCV.COLOR_GRAY2RGB)
    image = permutedims(image, (3, 2, 1))  # (C, W, H) -> (H, W, C)

    truth_df = DataFrame(CSV.File(csv_path))
    bbox = truth_df[!, [:bbox_x1, :bbox_y1, :bbox_x2, :bbox_y2]]
    # use the file stem so any extension length (".png", ".jpeg") and
    # underscores in directory names do not affect the parsed index
    stem = first(splitext(basename(img_path)))
    bbox = bbox[parse(Int, last(split(stem, '_'))) + 1, :]
    sq_image, _ = square_img(image, bbox, (h, w))

    sq_image = permutedims(sq_image, (3, 2, 1))  # (H, W, C) -> (C, W, H) for OpenCV
    resize_img = OpenCV.resize(sq_image, OpenCV.Size{Int32}(input_size, input_size), interpolation=OpenCV.INTER_AREA)
    resize_img = resize_img / 255.0
    return permutedims(resize_img, (3, 2, 1))
end

"""
    get_parallel_chains(comp_vertices, index_more_than_one_outputs)

Collapse the residual block that starts at
`comp_vertices[index_more_than_one_outputs]` into a single Flux layer.

The branching vertex must have exactly two outputs. Each branch is followed
through single-input vertices, and their layers are collected into a `Chain`.
The walk stops at the first vertex with more than one input (the merge vertex).
Both branches must end at the same merge vertex, and its name must contain "Add".

Returns `(layer, inner_iter)`:
- `inner_iter` is the index of the merge vertex in `comp_vertices`.
- `layer` is `SkipConnection(chain, +)` when one branch has no layers.
  Otherwise it is `Parallel(+; α = chain1, β = chain2)`.

Throws `ArgumentError` if the graph around the branch does not have this shape.
"""
function get_parallel_chains(comp_vertices, index_more_than_one_outputs)
    # Follow single-input vertices from `vertex` until a merge vertex is reached.
    function get_chain(vertex)
        m = Any[]
        curr_vertex = vertex
        while length(inputs(curr_vertex)) == 1
            push!(m, layer(curr_vertex))
            next_vertices = outputs(curr_vertex)
            isempty(next_vertices) && throw(ArgumentError(
                "branch starting at $(name(vertex)) ends at $(name(curr_vertex)) without reaching a merge vertex"))
            # a fork inside a branch would lose every consumer but the first
            length(next_vertices) == 1 || throw(ArgumentError(
                "nested branch at $(name(curr_vertex)) is not supported"))
            curr_vertex = next_vertices[1]
        end
        return Chain(m...), curr_vertex
    end

    outs = outputs(comp_vertices[index_more_than_one_outputs])
    length(outs) == 2 || throw(ArgumentError(
        "branching vertex must have exactly 2 outputs, got $(length(outs))"))
    chain1, merge_vertex = get_chain(outs[1])
    chain2, merge_vertex2 = get_chain(outs[2])
    name(merge_vertex2) == name(merge_vertex) || throw(ArgumentError(
        "branches merge at different vertices: $(name(merge_vertex)) and $(name(merge_vertex2))"))
    occursin("Add", name(merge_vertex)) || throw(ArgumentError(
        "merge vertex $(name(merge_vertex)) is not an Add"))
    inner_iter = findfirst(v -> name(v) == name(merge_vertex), comp_vertices)

    if length(chain1) == 0
        return SkipConnection(chain2, (+)), inner_iter
    elseif length(chain2) == 0
        return SkipConnection(chain1, (+)), inner_iter
    else
        return Parallel(+; α = chain1, β = chain2), inner_iter
    end
end

"""
    build_flux_model(onnx_model_path; skip_first = 4, relu_tag = "#213")

Convert the ONNX model at `onnx_model_path` into a flat Flux `Chain` in test mode.

Vertices are visited in the order returned by `vertices(comp_graph)`.

- The first `skip_first` vertices are dropped. For the ResNet model inspected in
  this notebook these are the input, `/Constant`, `/Sub` and `/Div` vertices.
  Any input normalisation they perform is therefore NOT part of the returned model.
- A vertex whose layer prints as `relu_tag` is replaced by `NNlib.relu`.
  NOTE(review): `"#213"` is an auto-generated closure name and may differ
  across package versions or sessions.
- A vertex with more than one output starts a residual block. That block is
  collapsed by `get_parallel_chains` into one layer, and the vertices it
  consumed (up to the merging Add) are skipped.
"""
function build_flux_model(onnx_model_path; skip_first = 4, relu_tag = "#213")
    comp_graph = ONNXNaiveNASflux.load(onnx_model_path)
    all_vertices = vertices(comp_graph)
    model_vec = Any[]
    inner_iter = 0  # index of the last vertex already folded into a residual block
    for (index, vertex) in enumerate(all_vertices)
        (index <= skip_first || index <= inner_iter) && continue
        vertex_layer = layer(vertex)
        push!(model_vec, string(vertex_layer) == relu_tag ? NNlib.relu : vertex_layer)
        if length(outputs(vertex)) > 1
            parallel_chain, inner_iter = get_parallel_chains(all_vertices, index)
            push!(model_vec, parallel_chain)
        end
    end
    model = Chain(model_vec...)
    Flux.testmode!(model)
    return model
end

build_flux_model (generic function with 1 method)

In [31]:
# Inspect the ResNet ONNX graph: print each vertex name followed by its Flux layer.
onnx_model_path = "/home/verification/ModelVerification.jl/onnx_parser/resnet_model.onnx"
comp_graph = ONNXNaiveNASflux.load(onnx_model_path, infer_shapes=false)

for vertex in vertices(comp_graph)
    println(NaiveNASflux.name(vertex))
    println(layer(vertex))
end

Set(Any["BatchNormalization", "Div", "Relu", "Constant", "Sub", "Conv", "ConvTranspose", "AveragePool", "Add"])
onnx::Sub_0
NaiveNASflux.LayerTypeWrapper{NaiveNASflux.GenericFluxConvolutional{2}}(NaiveNASflux.GenericFluxConvolutional{2}())
/Constant
nothing
/Sub
#341
/Div
#202
/conv1/Conv
Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false)
/bn1/BatchNormalization
BatchNorm(64, relu)
/avgpool/AveragePool
MeanPool((3, 3), pad=1, stride=2)
/layer1/layer1.0/conv1/Conv
Conv((3, 3), 64 => 64, pad=1, bias=false)
/layer1/layer1.0/bn1/BatchNormalization
BatchNorm(64, relu)
/layer1/layer1.0/conv2/Conv
Conv((3, 3), 64 => 64, pad=1, bias=false)
/layer1/layer1.0/bn2/BatchNormalization
BatchNorm(64)
/layer1/layer1.0/Add
#341
/layer1/layer1.0/relu_1/Relu
#213
/layer1/layer1.1/conv1/Conv
Conv((3, 3), 64 => 64, pad=1, bias=false)
/layer1/layer1.1/bn1/BatchNormalization
BatchNorm(64, relu)
/layer1/layer1.1/conv2/Conv
Conv((3, 3), 64 => 64, pad=1, bias=false)
/layer1/layer1.1/bn2/BatchNormalization
BatchN

In [25]:
# +++++++++++++++++++ build flux model +++++++++++++++++++
# Scratch cell prototyping the graph bookkeeping used by `propagate`:
#   batch_info[name] -> Dict with the vertex's "vertex", "layer", "index", "inputs", "outputs"
#   global_info      -> "start_node" (names of vertex 1's outputs) and
#                       "final_node" (name of a vertex with no outputs)
onnx_model_path = "/home/verification/ModelVerification.jl/mlp.onnx"
queue = Queue{Any}()
batch_info = Dict()
global_info = Dict()
comp_graph = ONNXNaiveNASflux.load(onnx_model_path)

# Vertex[i] is the i-th graph vertex; slot 1 (the input vertex) is left as `nothing`.
Vertex = []
for (index, vertex) in enumerate(vertices(comp_graph))
    push!(Vertex, index < 2 ? nothing : vertex)
end

for (index, vertex) in enumerate(vertices(comp_graph))
    if index == 1
        # Vertex 1 carries no layer; its outputs are the start nodes (there may be several).
        global_info["start_node"] = Any[NaiveNASflux.name(node) for node in outputs(vertex)]
        continue
    end
    vertex_name = NaiveNASflux.name(vertex)
    batch_info[vertex_name] = Dict{Any,Any}(
        "vertex" => vertex,
        "layer" => layer(vertex),
        "index" => index,
        "inputs" => inputs(vertex),
        "outputs" => outputs(vertex),
    )
    if isempty(outputs(vertex))  # the final node has no output nodes
        global_info["final_node"] = vertex_name
    end
    if index == 7
        # `startswith` cannot throw on names shorter than "Flatten"
        println(startswith(string(vertex_name), "Flatten"))
    end
end

# Start nodes read the model input directly, so they get no recorded inputs.
start_node = global_info["start_node"]
Isvisit = Dict()
for node in start_node
    enqueue!(queue, node)
    batch_info[node]["inputs"] = nothing
end

# Inspect the "add_0" vertex, then store `+` as its layer.
println(length(batch_info["add_0"]["inputs"]))
batch_info["add_0"]["layer"] = +
println(batch_info["add_0"]["layer"])
println(typeof(batch_info["add_0"]["layer"]))

Set(Any["BatchNormalization", "Flatten", "Relu", "Conv", "Gemm", "AveragePool", "Add"])
true
2
+


typeof(+)


In [None]:
# Build the per-vertex table (keyed by vertex name) and the graph-level info
# ("start_node" names, "final_node" name) that `propagate` walks over.
function _prop_collect_graph_info(comp_graph)
    batch_info = Dict()
    global_info = Dict()
    for (index, vertex) in enumerate(ONNXNaiveNASflux.vertices(comp_graph))
        if index == 1
            # vertex 1 carries no layer; its outputs are the start nodes (possibly several)
            global_info["start_node"] = Any[NaiveNASflux.name(node) for node in outputs(vertex)]
            continue
        end
        vertex_name = NaiveNASflux.name(vertex)
        new_dict = Dict()
        push!(new_dict, "vertex" => vertex)
        push!(new_dict, "layer" => NaiveNASflux.layer(vertex))
        push!(new_dict, "index" => index)
        # "inputs"/"outputs" hold vertex objects; NaiveNASflux.name gives their table keys
        push!(new_dict, "inputs" => inputs(vertex))
        push!(new_dict, "outputs" => outputs(vertex))
        if startswith(string(vertex_name), "Flatten")
            push!(new_dict, "layer" => Flux.flatten)
        end
        push!(batch_info, vertex_name => new_dict)
        if isempty(outputs(vertex))  # the final node has no output nodes
            global_info["final_node"] = vertex_name
        end
    end
    return batch_info, global_info
end

# Predecessors whose "output_bound" `_prop_forward_node` reads: none for a
# start node, both for a two-input vertex, otherwise only the first.
function _prop_required_inputs(node_inputs)
    isnothing(node_inputs) && return ()
    return length(node_inputs) == 2 ? node_inputs : first(node_inputs, 1)
end

# True once every required predecessor has stored an "output_bound".
function _prop_inputs_ready(batch_info, node_inputs)
    return all(_prop_required_inputs(node_inputs)) do v
        pred = get(batch_info, NaiveNASflux.name(v), nothing)
        !isnothing(pred) && haskey(pred, "output_bound")
    end
end

# Evaluate one vertex and return its (bound, aux_batch_info).
function _prop_forward_node(prop_method, batch_info, node, input_bound, input_aux)
    info = batch_info[node]
    node_inputs = info["inputs"]
    if isnothing(node_inputs)
        # start node: consumes the caller's input bound
        return forward_layer(prop_method, info["layer"], input_bound, input_aux)
    elseif length(node_inputs) == 2
        # NOTE(review): "#341" is the auto-generated name printed for the layers of
        # elementwise Add/Sub vertices when this was written; confirm it is stable.
        string(info["layer"]) == "#341" || throw(ArgumentError(
            "unsupported two-input layer $(info["layer"]) at node $node"))
        in1 = batch_info[NaiveNASflux.name(node_inputs[1])]
        in2 = batch_info[NaiveNASflux.name(node_inputs[2])]
        return forward_skip_batch(prop_method, +, in1["output_bound"], in2["output_bound"],
                                  in1["aux_batch_info"], in2["aux_batch_info"])
    else
        in1 = batch_info[NaiveNASflux.name(node_inputs[1])]
        return forward_layer(prop_method, info["layer"], in1["output_bound"], in1["aux_batch_info"])
    end
end

"""
    propagate(prop_method::ForwardProp, onnx_model_path, Flux_model, input_shape,
              batch_bound, batch_out_spec, aux_batch_info)

Propagate `batch_bound` forward through the ONNX graph stored at `onnx_model_path`.

Graph roles:
- The outputs of vertex 1 are the start nodes. Each start node receives
  (`batch_bound`, `aux_batch_info`) as its input.
- A vertex with no outputs is the final node.
- Vertices whose name starts with "Flatten" use `Flux.flatten` as their layer.

Traversal is breadth-first from the start nodes. A vertex is evaluated only
after every predecessor it reads from has produced an "output_bound". This
ensures a merge vertex runs after both residual branches, whatever their lengths.

Evaluation rules:
- Two-input vertices whose layer prints as "#341" are additions and go through
  `forward_skip_batch`.
- Every other vertex goes through `forward_layer` on its first input.

`Flux_model`, `input_shape` and `batch_out_spec` are currently unused.

Returns the final node's `(batch_bound, aux_batch_info)`.

Throws `ArgumentError` if `onnx_model_path` is `nothing` or a two-input vertex
is not an addition. Throws an error if the final node is never evaluated.
"""
function propagate(prop_method::ForwardProp, onnx_model_path, Flux_model, input_shape, batch_bound, batch_out_spec, aux_batch_info)
    isnothing(onnx_model_path) && throw(ArgumentError("onnx_model_path must not be nothing"))

    comp_graph = ONNXNaiveNASflux.load(onnx_model_path)
    batch_info, global_info = _prop_collect_graph_info(comp_graph)

    # every start node reads the caller's input, so keep it apart from per-node results
    input_bound, input_aux = batch_bound, aux_batch_info

    queue = Queue{Any}()
    Isvisit = Dict()
    for node in global_info["start_node"]
        enqueue!(queue, node)
        batch_info[node]["inputs"] = nothing  # start nodes have no input nodes
    end

    while !isempty(queue)
        node = dequeue!(queue)
        haskey(Isvisit, node) && continue
        info = batch_info[node]
        # not ready yet: the predecessor that finishes last re-enqueues this node
        _prop_inputs_ready(batch_info, info["inputs"]) || continue
        push!(Isvisit, node => true)

        node_bound, node_aux = _prop_forward_node(prop_method, batch_info, node, input_bound, input_aux)
        push!(info, "output_bound" => node_bound)
        push!(info, "aux_batch_info" => node_aux)

        for out in info["outputs"]
            out_name = NaiveNASflux.name(out)
            haskey(Isvisit, out_name) || enqueue!(queue, out_name)
        end
    end

    final_node = global_info["final_node"]
    final_info = batch_info[final_node]
    haskey(final_info, "output_bound") ||
        error("final node $final_node was never evaluated; some predecessor never produced a bound")
    return final_info["output_bound"], final_info["aux_batch_info"]
end