
Type unstable functions #8

Closed
janfrancu opened this issue Jun 25, 2020 · 2 comments
Assignees: chengchingwen
Labels: bug (Something isn't working)

Comments

janfrancu commented Jun 25, 2020

The forward step of Transformer is type-unstable. Running the example from the docs

using Transformers

m = Transformer(512, 8, 64, 2048) # define a Transformer block with 8 heads and 64 neurons per head
x = randn(512, 30, 3)             # fake data: feature size 512, length 30, batch of 3

y = m(x)

and checking it with @code_warntype produces:

julia> @code_warntype m(x,nothing)
Variables
  t::Transformer
  x::Array{Float64,3}
  mask::Core.Compiler.Const(nothing, false)
  a::Any
  insize::Any
  res_a::Any
  pwffn::AbstractArray{T,2} where T
  res_pwffn::Any

Body::Any
1 ─       Core.NewvarNode(:(insize))
│         Core.NewvarNode(:(pwffn))
│         Core.NewvarNode(:(res_pwffn))
│   %4  = (:mask,)::Core.Compiler.Const((:mask,), false)
│   %5  = Core.apply_type(Core.NamedTuple, %4)::Core.Compiler.Const(NamedTuple{(:mask,),T} where T<:Tuple, false)
│   %6  = Core.tuple(mask)::Core.Compiler.Const((nothing,), false)
│   %7  = (%5)(%6)::Core.Compiler.Const((mask = nothing,), false)
│   %8  = Base.getproperty(t, :mh)::Transformers.Basic.MultiheadAttention
│   %9  = Core.kwfunc(%8)::Core.Compiler.Const(Core.var"#Any##kw"(), false)
│   %10 = Base.getproperty(t, :mh)::Transformers.Basic.MultiheadAttention
│         (a = (%9)(%7, %10, x, x, x))
│   %12 = Base.getproperty(t, :drop)::Flux.Dropout
│         (a = (%12)(a))
│   %14 = Base.broadcasted(Transformers.Basic.:+, x, a)::Any
│         (res_a = Base.materialize(%14))
│   %16 = ($(Expr(:static_parameter, 2)) == 3)::Core.Compiler.Const(true, false)
└──       goto #3 if not %16
2 ─       (insize = Transformers.Basic.size(res_a))
│   %19 = res_a::Any
│   %20 = Base.getindex(insize, 1)::Any
└──       (res_a = Transformers.Basic.reshape(%19, %20, Transformers.Basic.:(:)))
3 ┄ %22 = Base.getproperty(t, :mhn)::Flux.LayerNorm
│         (res_a = (%22)(res_a))
│   %24 = Base.getproperty(t, :pw)::Transformers.Basic.PwFFN
│         (pwffn = (%24)(res_a))
│   %26 = Base.getproperty(t, :drop)::Flux.Dropout
│         (pwffn = (%26)(pwffn))
│   %28 = Base.broadcasted(Transformers.Basic.:+, res_a, pwffn)::Any
│         (res_pwffn = Base.materialize(%28))
│   %30 = Base.getproperty(t, :pwn)::Flux.LayerNorm
│         (res_pwffn = (%30)(res_pwffn))
│   %32 = ($(Expr(:static_parameter, 2)) == 3)::Core.Compiler.Const(true, false)
└──       goto #5 if not %32
4 ─ %34 = Core.tuple(res_pwffn, Transformers.Basic.:(:))::Core.Compiler.PartialStruct(Tuple{Any,Colon}, Any[Any, Core.Compiler.Const(Colon(), false)])
│   %35 = Base.tail::Core.Compiler.Const(Base.tail, false)
│   %36 = (%35)(insize)::Union{Tuple, NamedTuple}
└──       (res_pwffn = Core._apply_iterate(Base.iterate, Transformers.Basic.reshape, %34, %36))
5 ┄       return res_pwffn

The source of the instability is probably the multihead attention, but I have not been able to distill it any further.
I am using the latest tagged release of Transformers (v0.1.3) on Julia 1.4.1.
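
One way to narrow this down is to run the same check on the sub-layer in isolation. A minimal sketch, with the field name `mh` and the call signature taken from the @code_warntype dump above (`Test.@inferred` is an alternative check that throws whenever the return type is not inferred concretely):

using Test, InteractiveUtils
using Transformers

m = Transformer(512, 8, 64, 2048)
x = randn(512, 30, 3)

Test.@inferred m(x)                        # throws if the block's return type is not inferred concretely
@code_warntype m.mh(x, x, x; mask=nothing) # inspect the attention sub-layer directly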

chengchingwen self-assigned this Jul 3, 2020
chengchingwen added the "invalid" and "bug" labels and removed the "invalid" label Jul 3, 2020
chengchingwen (Owner) commented Jul 3, 2020

I can reproduce this result on Julia 1.4.2 with the master branch. It does look like there are some problems with type inference for multihead attention. I will take some time to fix this.

Thanks for reporting it!
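
To make the symptom concrete, here is a deliberately type-unstable toy function (not Transformers code) and what @code_warntype reports for it:

unstable(flag) = flag ? 1 : 1.0   # return type depends on a runtime value

using InteractiveUtils
@code_warntype unstable(true)     # Body::Union{Float64, Int64}

In the Transformer dump above, the same widening goes all the way to Any (e.g. a::Any, Body::Any), which defeats downstream optimizations.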

chengchingwen (Owner) commented

Should be fixed in the new release (v0.1.7)
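
A quick way to confirm on the new release (a sketch; the upgrade uses the standard Pkg API, and the check mirrors the one in the original report):

using Pkg
Pkg.add(name="Transformers", version="0.1.7")

using Transformers, InteractiveUtils
m = Transformer(512, 8, 64, 2048)
x = randn(512, 30, 3)
@code_warntype m(x, nothing)   # Body should now infer to a concrete array type rather than Any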
