In [None]:
# Path of the MatConvNet ResNet-152 weights (.mat). Defaults to the author's local
# copy; set the RESNET152_MAT_PATH environment variable to use a different location
# without editing the notebook.
model_file_path = get(ENV, "RESNET152_MAT_PATH",
    "/home/gsoykan/Desktop/comp541/comp541_term_project/results/imagenet-resnet-152-dag.mat")
# URL of the demo input image classified at the end of the notebook.
cat_img_url = "https://nextjournal.com/data/QmXNbi2LE7u6yBdBXaQ9E2zGb48FELg3TxjrLiPKBmdvZc?filename=Qat.jpg&content-type=image/jpeg"

In [None]:
# https://nextjournal.com/mpd/image-classification-with-knet

In [None]:
import CUDA

In [None]:
using MAT, OffsetArrays, FFTViews, ArgParse, Images, Knet, ImageMagick
include("modular.resnet.jl")

In [None]:
# Prefer GPU-backed arrays when CUDA is usable; otherwise fall back to CPU arrays.
atype = CUDA.functional() ? KnetArray{Float32} : Array{Float32}
# Register it as Knet's default array type by setting Knet's `array_type` Ref, which
# this notebook also reads directly (`Knet.array_type[]` in `init_model`). This avoids
# overwriting Knet's own `atype()` method from outside the package.
# NOTE(review): assumes `Knet.atype()` returns `Knet.array_type[]` in the installed
# Knet version — confirm.
Knet.array_type[] = atype

In [None]:
"""
    get_params(params, atype)

Split MatConvNet parameter records into network weights `ws` and batch-norm
moments `ms`, converting every array to `atype`.

`params` must support `params["name"][k]` and `params["value"][k]` (as the
`model["params"]` entry read with `matread` does). Entries are processed in index
order. Each value is first converted to `Array{Float32}` and then handled by name:

- name ends with `moments`      → column 1 and column 2 are each reshaped to
                                  `1×1×C×1` and pushed to `ms` (in that order)
- name starts with `bn`         → reshaped to `1×1×length×1`, pushed to `ws`
- `fc…` ending with `filter`    → a `1×1×A×B` filter is flattened to `A×B` and
                                  transposed to a `B×A` matrix, pushed to `ws`
- `conv…` ending with `bias`    → reshaped to `1×1×length×1`, pushed to `ws`
- anything else                 → pushed to `ws` unchanged

Returns the tuple `(ws, ms)`.
"""
function get_params(params, atype)
    names = params["name"]
    values = params["value"]
    # Heterogeneous array shapes, so the containers are deliberately `Any[]`.
    ws, ms = Any[], Any[]
    for k in eachindex(values)
        name = names[k]
        value = convert(Array{Float32}, values[k])

        if endswith(name, "moments")
            channels = size(value, 1)
            push!(ms, reshape(value[:, 1], (1, 1, channels, 1)))
            push!(ms, reshape(value[:, 2], (1, 1, channels, 1)))
        elseif startswith(name, "bn")
            push!(ws, reshape(value, (1, 1, length(value), 1)))
        elseif startswith(name, "fc") && endswith(name, "filter")
            # `permutedims` materializes the transpose, so `convert(atype, …)` below
            # receives a plain Matrix rather than a lazy `Transpose` wrapper.
            flat = reshape(value, (size(value, 3), size(value, 4)))
            push!(ws, permutedims(flat))
        elseif startswith(name, "conv") && endswith(name, "bias")
            push!(ws, reshape(value, (1, 1, length(value), 1)))
        else
            push!(ws, value)
        end
    end
    return map(wi -> convert(atype, wi), ws), map(mi -> convert(atype, mi), ms)
end

In [None]:
# Adapted from vgg.jl.
"""
    data(img, averageImage)

Load and preprocess one image for the network.

`img` is a local path, or a URL (any string containing `"://"`), which is
downloaded first. The image is converted to `RGB`, resized so that its shorter
side is 224 pixels (aspect ratio kept), and center-cropped to 224×224. Pixel
values are scaled to 0–255 and `averageImage` is subtracted. The result is an
`Array{Float32}` of size `224×224×3×1`.

Converting to `RGB` first means grayscale images get three identical channels
instead of failing in `permutedims`, and an alpha channel is dropped.
"""
function data(img, averageImage)
    if occursin("://", img)
        @info "Downloading $img"
        img = download(img)
    end
    # Normalize the color type: Gray → replicated RGB, RGBA → RGB (alpha discarded).
    a0 = RGB.(load(img))
    # Resize so that the shorter side becomes exactly 224.
    new_size = ntuple(i -> div(size(a0, i) * 224, minimum(size(a0))), 2)
    a1 = Images.imresize(a0, new_size)
    # Central 224×224 crop.
    i1 = div(size(a1, 1) - 224, 2)
    j1 = div(size(a1, 2) - 224, 2)
    b1 = a1[i1+1:i1+224, j1+1:j1+224]
    # channelview gives 3×H×W; permuting with (3,2,1) yields W×H×3.
    c1 = permutedims(channelview(b1), (3, 2, 1))
    d1 = convert(Array{Float32}, c1)
    e1 = reshape(d1, (224, 224, 3, 1))
    # NOTE(review): averageImage is subtracted while the data is in W×H×C order; this
    # only matters if averageImage is not spatially constant — confirm its shape.
    f1 = 255 .* e1 .- averageImage
    # Back to H×W×C×N.
    return permutedims(f1, (2, 1, 3, 4))
end

In [None]:
# OLD IMPLEMENTATION
# Batch normalization layer (old, hand-written implementation).
# Works for both convolutional (4-d) and fully connected inputs.
#
# w       : w[1] is the scale and w[2] the shift applied after normalization.
# x       : input. 4-d inputs use statistics over dims (1,2,4); otherwise over dim 2.
# ms      : queue of stored moments. In test mode the mean and then the divisor are
#           popped from the front (the divisor is used directly, i.e. as a standard
#           deviation). In both modes the moments actually used are pushed back onto
#           the end, so the queue rotates rather than shrinks.
# mode    : 0 => train (batch statistics), 1 => test (stored moments).
# epsilon : added to the batch variance before the square root (train mode only).
#
# NOTE(review): this shares its name with Knet's exported `batchnorm(x, moments,
# params)`. In this notebook, call the Knet one as `Knet.batchnorm`.
function batchnorm(w, x, ms; mode=1, epsilon=1e-5)
    if mode == 0
        d = ndims(x) == 4 ? (1, 2, 4) : (2,)
        s = prod(size(x, d...))
        # Julia ≥ 1.0 needs the `dims=` keyword and broadcasting here:
        # `sum(x, d)`, `scalar + array` and `sqrt(array)` all throw.
        mu = sum(x; dims=d) ./ s
        x0 = x .- mu
        sigma = sqrt.(epsilon .+ sum(x0 .* x0; dims=d) ./ s)
    elseif mode == 1
        mu = popfirst!(ms)
        sigma = popfirst!(ms)
    else
        throw(ArgumentError("mode must be 0 (train) or 1 (test), got $mode"))
    end

    # Store untracked values (AutoGrad.value) so the saved moments are not part of
    # the gradient tape.
    push!(ms, AutoGrad.value(mu), AutoGrad.value(sigma))
    xhat = (x .- mu) ./ sigma
    return w[1] .* xhat .+ w[2]
end

# Convolution followed by batch normalization, with no activation.
# w[1] is the convolution filter; w[2:3] is passed to `batchnorm` as scale/shift.
function reslayerx0(w, x, ms; padding=0, stride=1, mode=1)
    convolved = conv4(w[1], x; padding=padding, stride=stride)
    return batchnorm(w[2:3], convolved, ms; mode=mode)
end

# `reslayerx0` (conv + batchnorm) followed by an element-wise ReLU.
reslayerx1(w, x, ms; padding=0, stride=1, mode=1) =
    relu.(reslayerx0(w, x, ms; padding=padding, stride=stride, mode=mode))

# Bottleneck branch: two conv+BN+ReLU units followed by one conv+BN unit without
# activation. `w` holds 9 arrays (3 per unit). `pads` and `strides` give one entry
# per unit, in order.
function reslayerx2(w, x, ms; pads=[0,1,0], strides=[1,1,1], mode=1)
    h = x
    for unit in 1:2
        span = (3unit - 2):(3unit)
        h = reslayerx1(w[span], h, ms; padding=pads[unit], stride=strides[unit], mode=mode)
    end
    return reslayerx0(w[7:9], h, ms; padding=pads[3], stride=strides[3], mode=mode)
end

# Residual block whose shortcut is itself conv+BN (12 weight arrays).
# w[1:3] is the shortcut path; w[4:12] is the bottleneck branch. Both paths read the
# same input `x`, and their sum goes through a ReLU. Entry 1 of `pads`/`strides` is
# for the shortcut; entries 2:4 are for the bottleneck units.
function reslayerx3(w, x, ms; pads=[0,0,1,0], strides=[2,2,1,1], mode=1) # 12
    shortcut = reslayerx0(w[1:3], x, ms; stride=strides[1], padding=pads[1], mode=mode)
    branch = reslayerx2(w[4:12], x, ms; strides=strides[2:4], pads=pads[2:4], mode=mode)
    return @. relu(shortcut + branch)
end

# Residual block with an identity shortcut: ReLU(x + bottleneck(x)).
function reslayerx4(w, x, ms; pads=[0,1,0], strides=[1,1,1], mode=1)
    branch = reslayerx2(w, x, ms; pads=pads, strides=strides, mode=mode)
    return relu.(x .+ branch)
end

# One residual stage: a conv-shortcut block built from w[1:12], followed by
# identity-shortcut blocks that each take the next 9 arrays of `w`.
# `strides` is forwarded only to the first block.
function reslayerx5(w, x, ms; strides=[2,2,1,1], mode=1)
    h = reslayerx3(w[1:12], x, ms; strides=strides, mode=mode)
    start = 13
    while start <= length(w)
        h = reslayerx4(w[start:start+8], h, ms; mode=mode)
        start += 9
    end
    return h
end

# ResNet-152 forward pass (old implementation).
# `w` is indexed up to w[467] in the fixed layout produced by `get_params`, and `ms`
# holds the matching batch-norm moments. mode: 0 => train, 1 => test.
# Returns the output of the final fully connected layer (fc1000).
function resnet152(w, x, ms; mode=1)
    # Stem: conv+BN+ReLU (padding 3, stride 2), then pooling (window 3, stride 2).
    h = reslayerx1(w[1:3], x, ms; padding=3, stride=2, mode=mode)
    h = pool(h; window=3, stride=2)

    # Four residual stages, each reading its own contiguous slice of `w`. The first
    # uses unit strides; the rest use [2,2,1,1] (reslayerx5's default).
    stages = ((4:33,    [1,1,1,1]),
              (34:108,  [2,2,1,1]),
              (109:435, [2,2,1,1]),
              (436:465, [2,2,1,1]))
    for (span, stage_strides) in stages
        h = reslayerx5(w[span], h, ms; strides=stage_strides, mode=mode)
    end

    # Average pooling (window 7, stride 1, Knet pool mode 2), flatten, then the fully
    # connected layer with weight w[466] and bias w[467].
    pooled = pool(h; stride=1, window=7, mode=2)
    return w[466] * mat(pooled) .+ w[467]
end

In [None]:
# Run options for the demo. `:atype` reuses the CUDA-aware `atype` computed earlier
# (KnetArray when CUDA is functional, Array otherwise), so the notebook also runs on
# CPU-only machines instead of hard-requiring KnetArray.
o = Dict(
  :atype => atype,
  :model => model_file_path,
  :image => cat_img_url,
  :top   => 10
)

In [None]:
	# Interactive loading: read the model, its mean image and class descriptions,
	# split the parameters, and preprocess the demo image. Leaves model, avgimg,
	# description, w, ms and img as globals.
	@info "Reading $(o[:model])"
	model = matread(abspath(o[:model]))
	avgimg = convert(Array{Float32}, model["meta"]["normalization"]["averageImage"])
	description = model["meta"]["classes"]["description"]
	w, ms = get_params(model["params"], o[:atype])

	@info "Reading $(o[:image])"
	img = convert(o[:atype], data(o[:image], avgimg))


In [None]:
"""
    predict(o)

Classify the image `o[:image]` with the MatConvNet model stored at `o[:model]`.

The model is read with `matread` and its parameters are split by `get_params`
(converted to `o[:atype]`). The image is preprocessed with `data` using the
model's average image, then fed to the network built by
`generate_resnet_from_weights`.

Returns `(y1, description)`: the network output, and the class-description array
from the model metadata.
"""
function predict(o)
    @info "Reading $(o[:model])"
    model = matread(abspath(o[:model]))
    meta = model["meta"]
    avgimg = convert(Array{Float32}, meta["normalization"]["averageImage"])
    description = meta["classes"]["description"]
    w, ms = get_params(model["params"], o[:atype])

    @info "Reading $(o[:image])"
    img = convert(o[:atype], data(o[:image], avgimg))

    @info "Classifying."
    network = generate_resnet_from_weights(w, ms)
    return network(img), description
end

In [None]:
#model = matread(abspath(o[:model]))
#w, ms = get_params(model["params"], o[:atype])


In [None]:
# Display the default array type Knet currently reports.
Knet.atype()

In [None]:
# Batch-norm fixing experiments: scratch code for moving to Knet's `batchnorm`

# Model initializer, apparently adapted from Knet's batchnorm example. It has three
# conv layers (3→16→32→64 channels) and two dense layers. Every layer except the last
# is followed by `bnparams` (gammas and betas); the last has a zero bias instead.
# Returns `(w, m)`: the parameters converted with `Knet.array_type[]`, and one
# `bnmoments()` object for each of the 4 batch-normalized layers.
function init_model(; et=Float32)
    weights = Any[]
    for (cin, cout) in ((3, 16), (16, 32), (32, 64))
        push!(weights, kaiming(et, 3, 3, cin, cout), bnparams(et, cout))
    end
    push!(weights, xavier(et, 100, 8 * 8 * 64), bnparams(et, 100))
    push!(weights, xavier(et, 10, 100), zeros(et, 10, 1))
    moments = Any[bnmoments() for _ in 1:4]
    return map(Knet.array_type[], weights), moments
end

# Conv block from the batchnorm example. Applies convolution (padding 1), then Knet
# batch normalization (moments `m`, packed scale/bias `w[2]`), then ReLU, then
# pooling when `maxpool` is true.
function conv_layer(w, m, x; maxpool=true)
    h = conv4(w[1], x; padding=1)
    # Qualified on purpose: this notebook also defines its own
    # `batchnorm(w, x, ms; mode, epsilon)`. An unqualified positional call would pick
    # that method and treat `h` as the weights and `m` as the input.
    h = Knet.batchnorm(h, m, w[2])
    h = relu.(h)
    return maxpool ? pool(h) : h
end

# Reference usage of Knet's batchnorm API, kept as a comment. `C` (channel count)
# and `x` are placeholders, and the bare `...` line is not valid Julia, so running
# this snippet made the cell stop with a parse error.
#
#     moments = bnmoments()
#     params = bnparams(C)
#     ...
#     ### size(x) -> (H, W, C, N)
#     y = batchnorm(x, moments, params)

# Reload the model and preprocess the demo image for the batch-norm experiments.
# Leaves model, avgimg, description, w, ms and img as globals.
model = matread(abspath(o[:model]))
avgimg = convert(Array{Float32}, model["meta"]["normalization"]["averageImage"])
description = model["meta"]["classes"]["description"]
w, ms = get_params(model["params"], o[:atype])
@info "Reading $(o[:image])"
img = convert(o[:atype], data(o[:image], avgimg));

# Build the modular counterpart of the stem's first conv unit from w[1:3] (padding 3,
# stride 2, the same settings the old `resnet152` uses for its stem).
# NOTE(review): `ResLayerX0` comes from modular.resnet.jl; whether it consumes
# entries of `ms` at construction is not visible here — confirm.
res_conv_0 = ResLayerX0(w[1:3], ms; padding=3, stride=2)

 # Local helpers, apparently mirroring Knet's private batchnorm utilities, used to
# inspect how parameters and moments are shaped and split.
#
# _wsize(y):     shape with all ones except the second-to-last dimension of `y`,
#                e.g. (1, 1, C, 1) for a 4-d `y`, or (size(y, 1), 1) for a matrix.
# _bnscale(p):   first half of `p`.
# _bnbias(p):    second half of `p` (gets the extra element when length is odd).
function _wsize(y)
    nd = ndims(y)
    ones_prefix = ntuple(_ -> 1, max(nd - 2, 0))
    return (ones_prefix..., size(y)[end-1], 1)
end

function _bnscale(param)
    half = length(param) ÷ 2
    return param[1:half]
end

function _bnbias(param)
    half = length(param) ÷ 2
    return param[half+1:end]
end

# Inspect how the 2-element vector w[2:3] is split by the helpers: the first half
# holds w[2] and the second half holds w[3].
_bnscale(w[2:3])

# First (and only) element of the second half, i.e. w[3].
_bnbias(w[2:3])[begin]

# Batch-norm experiments. Run the stem's first convolution, then compare Knet's
# `batchnorm` (fed with the stored MatConvNet moments) against the modular layer.
#
# The conv output is stored in `conv_out`, not `o`. `o` is the options Dict that
# later cells read (`o[:atype]`, `o[:top]`), and reassigning it here broke them.
conv_out = conv4(w[1], img; padding=3, stride=2)

# First stored moment pair (consumes two entries of `ms`).
# NOTE(review): the old `batchnorm` in this notebook divides by the second moment
# directly, i.e. treats it as a standard deviation, while Knet's `var` is a variance.
# Confirm whether it needs squaring.
res_mean = popfirst!(ms)
res_variance = popfirst!(ms)

f_res_mean = convert(Array{Float32}, res_mean)
f_res_variance = convert(Array{Float32}, res_variance)

# Initializers for `bnmoments(meaninit=..., varinit=...)`. They are assumed to be
# called as `init(T, dims...)` and must return an array of that shape, here built from
# the stored moments instead of zeros/ones. The previous versions returned
# `convert(eltype, array)` (an error), and `mean_function` read its local
# `f_res_mean` before assigning it (UndefVarError).
# NOTE(review): on GPU the result may need converting to the input's array type.
function var_function(T, dims...)
    return reshape(convert(Array{T}, f_res_variance), dims...)
end

function mean_function(T, dims...)
    return reshape(convert(Array{T}, f_res_mean), dims...)
end

# `meaninit`/`varinit` take initializer functions, not arrays.
batch_ms = bnmoments(meaninit=mean_function, varinit=var_function)

bnmoments()

f_batch_ms = bnmoments(mean=res_mean, var=res_variance)

_wsize(conv_out)

# Pack the BN scale (w[2]) and bias (w[3]) into one vector: scale in the first half,
# bias in the second, matching the `_bnscale`/`_bnbias` split.
w2 = w[2]
w3 = w[3]
vec_w2 = vec(w2)
vec_w3 = vec(w3)
vcatted_ws = vcat(vec_w2, vec_w3)

# Must come after `vcatted_ws` is defined; it was previously used before assignment.
_bnscale(vcatted_ws)

# Knet's batchnorm, qualified because this notebook also defines an unrelated
# `batchnorm(w, x, ms; ...)` with a different argument order.
Knet.batchnorm(conv_out, f_batch_ms, vcatted_ws)

Knet.batchnorm(conv_out, bnmoments(), vcatted_ws)

res_conv_0(img)

In [None]:
# Rebuild weights and moments from scratch. The experiments above pop entries from
# `ms` with popfirst!, so the queue must be fresh before building the full network.
w, ms = get_params(model["params"], o[:atype]);

In [None]:
# Build the modular ResNet-152 from the loaded weights/moments and run it on the demo image.
modular_resnet152 = generate_resnet_from_weights(w, ms)
y1 = modular_resnet152(img)

In [None]:
# Bring the scores to the CPU as a flat vector, rank class indices from highest to
# lowest score, and turn the scores into probabilities (exp of Knet's log-softmax).
z1 = reshape(Array(y1), :)
s1 = sortperm(z1; rev=true)
p1 = exp.(logp(z1))

In [None]:
using Printf

# Print the `o[:top]` highest-scoring classes with their probability in percent.
for cls in s1[1:o[:top]]
    @printf("%s: %.2f%%\n", description[cls], p1[cls] * 100)
end