Feature request: add embedding layer #66

Closed
ngphuoc opened this issue Sep 20, 2017 · 11 comments

@ngphuoc

ngphuoc commented Sep 20, 2017

It would be helpful to have an embedding layer for a large-vocabulary input sequence (http://pytorch.org/docs/master/nn.html#sparse-layers):

Something like:

m = Chain(
  Embed(V, E),
  LSTM(E, 256),
  Dense(256, V),
  softmax)
@MikeInnes
Member

This is trivial with normal slicing:

julia> W = randn(3, 10)
3×10 Array{Float64,2}:
  1.52921     0.34536   -1.54119    …  -0.303022    0.123047    0.00684176
 -0.138502    0.4285    -0.841863      0.0720752  -2.19672    -0.0968025 
 -0.00694323  0.916636   0.834962     -0.239865    0.0237941   0.0236618 

julia> W[:, 5]
3-element Array{Float64,1}:
 -0.972017
  0.293653
 -1.20708 

julia> W[:, [5, 6]]
3×2 Array{Float64,2}:
 -0.972017   2.58503
  0.293653  -2.29474
 -1.20708    1.82632

julia> W[:, [5 6; 7 8]]
3×2×2 Array{Float64,3}:
[:, :, 1] =
 -0.972017  -0.914234
  0.293653   0.333886
 -1.20708   -0.386469

[:, :, 2] =
  2.58503  -0.303022 
 -2.29474   0.0720752
  1.82632  -0.239865

There may be some benefit to having a convenience wrapper though. Should be easy to put together if you want to set up a PR.
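A minimal sketch of what such a convenience wrapper might look like (the Embed name and its fields are hypothetical, not part of Flux; the lookup is just the column slicing shown above):

# Hypothetical convenience wrapper, not part of Flux.
struct Embed{T}
  W::T   # E×V matrix, one column per vocabulary entry
end

Embed(vocab::Integer, dim::Integer) = Embed(param(randn(dim, vocab)))

# x can be an Int, a vector of Ints, or a matrix of Ints, as in the slicing examples.
(e::Embed)(x) = e.W[:, x]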

@oxinabox
Member

Better still, I think, is to use the OneHot machinery; this is pretty much exactly what it is for. A one-hot encoded value takes up no more space than an Int, and one-hot multiplication is slicing under the hood.
http://fluxml.github.io/Flux.jl/latest/data/onehot.html
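For example, multiplying a weight matrix by a one-hot vector or a one-hot batch picks out the corresponding columns, so it behaves like the slicing above without materialising dense indicator vectors (a sketch using Flux's onehot/onehotbatch):

using Flux: onehot, onehotbatch

W = randn(3, 10)

W * onehot(5, 1:10)            # selects a single column, ≈ W[:, 5]
W * onehotbatch([5, 6], 1:10)  # selects several columns, ≈ W[:, [5, 6]]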

@ngphuoc
Author

ngphuoc commented Oct 10, 2017

Thanks. I've tried onehot and got the following error:

julia> x = onehot(1, 1:10)
10-element Flux.OneHotVector:
  true
 false
 false
 false
 false
 false
 false
 false
 false
 false

julia> m = Chain(Dense(10, 5), Dense(5, 2))
Chain(Dense(10, 5), Dense(5, 2))

julia> m(x)
ERROR: MethodError: *(::TrackedArray{…,Array{Float64,2}}, ::Flux.OneHotVector) is ambiguous. Candidates:
  *(A::AbstractArray{T,2} where T, b::Flux.OneHotVector) in Flux at /home/phuoc/.julia/v0.6/Flux/src/onehot.jl:10
  *(a::Flux.Tracker.TrackedArray{T,2,A} where A where T, b::AbstractArray{T,1} where T) in Flux.Tracker at /home/phuoc/.julia/v0.6/Flux/src/tracker/lib.jl:67
Possible fix, define
  *(::Flux.Tracker.TrackedArray{T,2,A} where A where T, ::Flux.OneHotVector)
Stacktrace:
 [1] (::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}})(::Flux.OneHotVector) at /home/phuoc/.julia/v0.6/Flux/src/layers/basic.jl:61
 [2] mapfoldl_impl(::Base.#identity, ::Flux.##45#46, ::Flux.OneHotVector, ::Array{Any,1}, ::Int64) at ./reduce.jl:43
 [3] (::Flux.Chain)(::Flux.OneHotVector) at /home/phuoc/.julia/v0.6/Flux/src/layers/basic.jl:28

@ngphuoc
Author

ngphuoc commented Oct 10, 2017

I also tried the following model but could not train it, since params returned no parameters. Did I miss something?

julia> using Flux
julia> using Flux: onehotbatch, unstack, truncate!, throttle, logloss, initn
julia> using Flux.Tracker: param, back!, data, grad
julia> struct LangModel{E,R,F}
         emb::E
         rnn::R
         fc::F
       end

julia> LangModel(v::Int, e::Int, h::Int) = LangModel(param(initn(e, v)),
                                      LSTM(e, h),
                                      Dense(h, v))
LangModel

julia> Flux.children(m::LangModel) = (m.emb, m.rnn, m.fc,)

julia> (m::LangModel)(x) = softmax(m.fc(m.rnn(m.emb[:,x])))
julia> m = LangModel(2,3,4)
LangModel{TrackedArray{…,Array{Float64,2}},Flux.Recur{Flux.LSTMCell{Flux.Dense{NNlib.#σ,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},Flux.Dense{Base.#tanh,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},TrackedArray{…,Array{Float64,1}}}},Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}}(param([0.00883202 -0.00318895; 0.0140951 0.0126121; -0.0082283 -0.00736651]), Recur(LSTMCell(3, 4)), Dense(4, 2))

julia> params(m)
0-element Array{Any,1}

@MikeInnes
Member

The ambiguity issue is one I can fix.

If you're on the latest Flux, you need to change the Flux.children line to Flux.treelike(LangModel).
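
For reference, a sketch of that change on the Flux version in question (treelike registers the struct's fields so params can walk them; the constructor arguments are just the ones from the example above):

Flux.treelike(LangModel)   # replaces the Flux.children definition

m = LangModel(2, 3, 4)
params(m)                  # should now include the embedding, LSTM and Dense parameters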

@ngphuoc
Author

ngphuoc commented Oct 10, 2017

Thank you. Flux.treelike(LangModel) works perfectly.

@MikeInnes
Member

Closing this for now as I think matmul is fine for embeddings, unless anyone has a specific proposal. I've noted the ambiguity issue though so I'll fix that ASAP.

@datnamer

How about a layer that would provide a wrapper to import weights for pretrained embeddings? https://discuss.pytorch.org/t/can-we-use-pre-trained-word-embeddings-for-weight-initialization-in-nn-embedding/1222

And then a way to freeze them during training.

@oxinabox
Member

oxinabox commented Jan 16, 2019

@datnamer it is gloriously trivial

Load the pretrained weights into a matrix W using Embeddings.jl (the code to do the import lives there).

Then use the one-hot product discussed above. For a word with one-hot vector ei: if you want the embedding to be fine-tuned, use param(W)*ei; if you want to freeze it, use W*ei.
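
A sketch of that recipe, assuming Embeddings.jl's load_embeddings API (which returns a table holding an embeddings matrix and a vocab list); the wordindex helper here is purely illustrative:

using Embeddings, Flux
using Flux: onehot

# Load pretrained vectors; the returned table has `embeddings` (dim × vocab matrix) and `vocab`.
table = load_embeddings(GloVe{:en})
W     = table.embeddings
vocab = table.vocab

# Illustrative helper: position of a word in the vocabulary.
wordindex(w) = findfirst(isequal(w), vocab)

ei = onehot(wordindex("king"), 1:length(vocab))

frozen    = W * ei          # plain matrix: the pretrained weights stay fixed
finetuned = param(W) * ei   # tracked matrix: gradients flow into the embeddings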

@datnamer

@oxinabox thanks, that looks great. Would it be similarly easy with graph embeddings? I'm just starting to play around, but there aren't a lot of tutorials for that.

@oxinabox
Member

I'm not sure. I don't know that pretrained graph embeddings are a thing.

@Drvi mentioned this issue Jan 28, 2019