In [1]:
# Target: finish contract's conjugate function

using Zygote
using Zygote: @adjoint

# contract function copy from Day2
"""
    contract(a, b, dima, dimb)

Contract the axes `dima` of `a` with the axes `dimb` of `b`, pairing them
positionally (`dima[k]` with `dimb[k]`).

The result's axes are the uncontracted axes of `a` (ascending) followed by the
uncontracted axes of `b` (ascending). When `dima` is empty the outer product of
`a` and `b` is returned instead.

Throws an error with the message "size is wrong" when the paired axes do not
have identical sizes.
"""
function contract(a::AbstractArray{Ta, Na}, b::AbstractArray{Tb, Nb}, dima::Tuple, dimb::Tuple) where {Ta, Tb, Na, Nb}
    shape_a = size(a)
    shape_b = size(b)

    # Paired axes must agree in size (and therefore also in count).
    paired_a = [shape_a[d] for d in dima]
    paired_b = [shape_b[d] for d in dimb]
    paired_a == paired_b || error("size is wrong")

    # Nothing to contract: outer product via a column-times-row matrix product.
    if isempty(dima)
        dims_a = [shape_a...]
        dims_b = [shape_b...]
        col = reshape(a, prod(dims_a), 1)
        row = reshape(b, 1, prod(dims_b))
        return reshape(col * row, Tuple(vcat(dims_a, dims_b)))
    end

    # `a` becomes a (free × contracted) matrix.
    free_a = [d for d in 1:Na if !(d in dima)]
    bound_a = [dima...]
    free_sizes_a = [shape_a[d] for d in free_a]
    rows = prod(free_sizes_a)
    inner_a = prod([shape_a[d] for d in bound_a])

    # `b` becomes a (contracted × free) matrix.
    bound_b = [dimb...]
    free_b = [d for d in 1:Nb if !(d in dimb)]
    inner_b = prod([shape_b[d] for d in bound_b])
    free_sizes_b = [shape_b[d] for d in free_b]
    cols = prod(free_sizes_b)

    # Move the contracted axes to the back of `a` and to the front of `b`.
    a_sorted = permutedims(a, vcat(free_a, bound_a))
    b_sorted = permutedims(b, vcat(bound_b, free_b))

    mat_a = reshape(a_sorted, (rows, inner_a))
    mat_b = reshape(b_sorted, (inner_b, cols))

    # One matrix product performs the whole contraction; fold the free axes back out.
    out_shape = Tuple(vcat(free_sizes_a, free_sizes_b))
    return Number.(reshape(mat_a * mat_b, out_shape))
end

contract (generic function with 1 method)

In [11]:
using Zygote: pullback, @adjoint
# Test data: two random 4-D arrays plus the axis tuples later passed to `contract`.
# Axes 3 and 4 of 𝑨 (sizes 4, 5) line up with axes 1 and 3 of 𝑩 (sizes 4, 5).
𝑨 = rand(2,3,4,5)
𝑩 = rand(4,6,5,3)
dim₁ = (3,4)  # axes of 𝑨 to contract
dim₂ = (1,3)  # axes of 𝑩 to contract, paired positionally with dim₁

# Custom Zygote adjoint for `contract(a, b, dima, dimb)`.
#
# Forward pass: the plain `contract` result, whose axes are
# [free axes of a (ascending); free axes of b (ascending)].
#
# Pullback: for an output cotangent Δ,
#   ā = Δ contracted with b over b's free axes,
#   b̄ = Δ contracted with a over a's free axes.
# `contract` always returns [free axes of 1st arg; free axes of 2nd arg], with the
# second argument's free axes in ascending order, so each raw result must be
# *permuted* (not reshaped) back to the axis order of the corresponding input.
# The tuple arguments `dima` / `dimb` are not differentiable and get `nothing`.
#
# NOTE(review): no complex conjugation is applied, so this is written for
# real-valued arrays — confirm before using it with complex inputs.
@adjoint contract(a, b, dima, dimb) = contract(a, b, dima, dimb), Δ -> begin
    dima_leftover = filter(x -> !(x in dima), [1:ndims(a)...])
    dimb_leftover = filter(x -> !(x in dimb), [1:ndims(b)...])

    # Positions inside Δ of a's free axes and of b's free axes.
    dim_left_Δ = [1:length(dima_leftover)...]
    dim_right_Δ = [1:length(dimb_leftover)...] .+ length(dima_leftover)

    # ā: the raw axes are [a's free axes; b's contracted axes in ascending order].
    # b's axis dimb[k] is paired with a's axis dima[k], which gives the label of
    # every raw axis in terms of a's axes; invert that labelling to restore order.
    a_raw = contract(Δ, b, Tuple(dim_right_Δ), Tuple(dimb_leftover))
    a_axes = [dima_leftover; Int[dima[k] for k in sortperm(Int[dimb...])]]
    a_conjugate = permutedims(a_raw, invperm(a_axes))

    # b̄: the raw axes are [b's free axes; a's contracted axes in ascending order].
    b_raw = contract(Δ, a, Tuple(dim_left_Δ), Tuple(dima_leftover))
    b_axes = [dimb_leftover; Int[dimb[k] for k in sortperm(Int[dima...])]]
    b_conjugate = permutedims(b_raw, invperm(b_axes))

    return (a_conjugate, b_conjugate, nothing, nothing)
end

# `gradient` is left commented out here; the pullback is requested directly instead.
#gradient(contract, 𝑨, 𝑩, dim₁, dim₂)
# `res` is the forward result; `back` is the pullback closure returned alongside it.
res, back = Zygote.pullback(contract, 𝑨, 𝑩, dim₁, dim₂)
# The forward result itself is fed back in as the output cotangent (seed).
back_res = back(res)

for 𝑨:[0.046808644092217744 0.7698536519539523 0.2783029865399267; 0.4336756471574894 0.48048722742388983 0.8010646955510041]

[0.7395004886457495 0.9004262857022896 0.7452909847316285; 0.669535335261148 0.40893103584605983 0.05947134373048879]

[0.11280999475009379 0.36817654789631726 0.01306666371423204; 0.3174322134731711 0.717478139102171 0.35903640118304403]

[0.7570234466827717 0.06746939829089182 0.28366634042502215; 0.9237329147320681 0.472391059401982 0.6784915629596431]

[0.9933130191774957 0.046542589937060175 0.8560031232107954; 0.8287108870276663 0.22984629056985195 0.01866594926480647]

[0.35986230100415817 0.9379546006072916 0.7569863957534533; 0.4199435152040587 0.7749708976863017 0.8884063352460712]

[0.6458175569066824 0.9956610037216298 0.5623821011787193; 0.20381744649692002 0.49469063341721164 0.4423353877312497]

[0.8060240923046509 0.8418863718478333 0.4083623473967861; 0.17381241597567 0.3666167304531027 0.4688604488995851]

[0.07964745901036929 0.89545981827874

[49.737992562610316 58.227683414490905 44.89360926079648; 55.43428895038263 49.45900983909117 53.534743843125725]

[52.827991259586284 61.073834825653684 47.37901577564899; 59.64400699836834 52.08465886130097 55.35936930830325]

[56.372631783615034 66.03256564411573 51.62149266986516; 66.10911112227703 57.44194687046923 62.5650573234705]

[39.748398494771315 47.12151220653099 37.58561857130065; 45.97167888310427 40.1132543329331 44.18248093631701]

[59.43066794626967 69.99470532270324 54.47410499522308; 67.82260741830532 58.999327524050244 65.5477076937107]

[52.63741792547315 60.27058939316751 46.250536809259884; 59.59184684265684 51.66484269239796 55.47349647761155]

[45.83383955711055 53.00457350201797 40.90442059496638; 53.38168271516665 46.80916518321077 51.221361108806825]

[52.22720010573304 61.48734087314647 48.01233824284062; 58.804750348203996 53.29612005854509 57.01740216644224]

[52.311102297512704 60.1371854042916 46.34049367386441; 57.42756062144855 51.45105185196372 54.2

##### PROBLEM
1. Zygote的gradient，里面的函数输出必须是一个scalar，因为gradient的定义是用scalar进行梯度的求解
2. 针对1问题，可以用res, pbfunc = Zygote.forward(...)其中，res为调用函数后的输出结果，而pbfunc为pullback函数
3. 针对2：在 Julia 1.3 下所用的 Zygote 版本中，`Zygote.forward` 这个函数似乎并不存在；上面的代码改用 `Zygote.pullback` 来得到 `res` 和 pullback 函数。