Closed
Description
I have a question about a surprising behaviour I see when wrapping the symbolic gradient expressions provided by @tullio verbose=1 in their own functions. I expected this to give the same gradients as Tullio's auto-generated functions, but that is not the case in this example. Is this a bug?
using Tullio, Zygote
function convlocal(x, W)
    @tullio c[s, t, c2, b] := x[s+i-1, t+j-1, c1, b] * W[s+i-1, t+j-1, i, j, c1, c2]
    return c
end
function ∇x!(𝛥x, W, 𝛥ℛ)
    @tullio grad=false 𝛥x[(s + i) - 1, (t + j) - 1, c1, b] = 𝛥x[(s + i) - 1, (t + j) - 1, c1, b] + 𝛥ℛ[s, t, c2, b] * conj(W[(s + i) - 1, (t + j) - 1, i, j, c1, c2])
    return 𝛥x
end
function ∇W!(𝛥W, x, 𝛥ℛ)
    @tullio grad=false 𝛥W[(s + i) - 1, (t + j) - 1, i, j, c1, c2] = 𝛥W[(s + i) - 1, (t + j) - 1, i, j, c1, c2] + 𝛥ℛ[s, t, c2, b] * conj(x[(s + i) - 1, (t + j) - 1, c1, b])
    return 𝛥W
end
kernel_width, kernel_height, ch_in, ch_out = 3, 3, 1, 2
img_width, img_height, batchsize = 10, 10, 30
x = rand(Float32, img_width, img_height, ch_in, batchsize)
W = rand(Float32, img_width, img_height, kernel_width, kernel_height, ch_in, ch_out)
Δy = rand(Float32, img_width-2, img_height-2, ch_out, batchsize)
# method 1: grads computed using ∇x and ∇W
Δx1, ΔW1 = zeros(Float32, size(x)), zeros(Float32, size(W))
∇x!(Δx1, W, Δy)
∇W!(ΔW1, x, Δy)
# method 2: grads computed via Tullio's auto-generated functions
_, back = Zygote._pullback(convlocal, x, W)
# call the pullback once; it returns (∂convlocal, ∂x, ∂W)
_, Δx2, ΔW2 = back(Δy)
@show Δx1 ≈ Δx2
@show ΔW1 ≈ ΔW2
I was expecting Δx1 to match Δx2, but this is not the case (except when the last dimension of W is a singleton). The output is:
Δx1 ≈ Δx2 = false
ΔW1 ≈ ΔW2 = true
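To tell which of the two Δx results is off, a third, independent reference may help. The following is a minimal sketch using plain serial loops (no Tullio, no threading). The names convlocal_loops and ∇x_loops! are hypothetical and not part of Tullio or Zygote. It assumes the index ranges are the ones implied by the array sizes above, i.e. s and t run over the first two dimensions of Δy, and i and j over the kernel dimensions of W.

```julia
# Hypothetical plain-loop reference; names are illustrative only.

# Forward pass: c[s,t,c2,b] = Σ_{i,j,c1} x[s+i-1,t+j-1,c1,b] * W[s+i-1,t+j-1,i,j,c1,c2]
function convlocal_loops(x, W)
    kw, kh = size(W, 3), size(W, 4)
    S, T = size(x, 1) - kw + 1, size(x, 2) - kh + 1   # assumed output extent
    c = zeros(promote_type(eltype(x), eltype(W)), S, T, size(W, 6), size(x, 4))
    for b in axes(x, 4), c2 in axes(W, 6), c1 in axes(x, 3),
        j in 1:kh, i in 1:kw, t in 1:T, s in 1:S
        c[s, t, c2, b] += x[s+i-1, t+j-1, c1, b] * W[s+i-1, t+j-1, i, j, c1, c2]
    end
    return c
end

# Gradient with respect to x: every (s, t, i, j, c2) combination adds exactly
# one term to the matching entry of Δx, accumulated serially.
function ∇x_loops!(Δx, W, Δy)
    kw, kh = size(W, 3), size(W, 4)
    for b in axes(Δy, 4), c2 in axes(Δy, 3), c1 in axes(Δx, 3),
        j in 1:kh, i in 1:kw, t in axes(Δy, 2), s in axes(Δy, 1)
        Δx[s+i-1, t+j-1, c1, b] += Δy[s, t, c2, b] * conj(W[s+i-1, t+j-1, i, j, c1, c2])
    end
    return Δx
end
```

With the arrays from the example, Δx3 = ∇x_loops!(zeros(Float32, size(x)), W, Δy) could then be compared against both Δx1 and Δx2 using ≈.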