-
Notifications
You must be signed in to change notification settings - Fork 71
/
position_embed.jl
76 lines (64 loc) · 2.34 KB
/
position_embed.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
    PositionEmbedding(size::Int, max_len::Int = 1024; trainable::Bool = false)

The position embedding layer. `size` is the number of neurons (the embedding dimension) and
`max_len` is the maximum acceptable length of the input.

If `trainable` is `false`, a sin/cos position encoding is used and `max_len` is adjusted
dynamically to the longest input length requested so far. If `trainable` is `true`, the
embedding is initialized with random values and `max_len` is fixed: requesting a longer
length raises an error.
"""
mutable struct PositionEmbedding{F, W <: AbstractArray{F}} <: AbstractBroadcastEmbed{F}
# `true` selects the randomly initialized table, `false` the sin/cos table.
trainable::Bool
# Embedding table built as `size × max_len` by the outer constructor; the struct is
# mutable so that `resize_pe!` can rebind this field to a larger table.
embedding::W
end
# Register `PositionEmbedding` with `@functor` (presumably the Flux/Functors.jl macro; the
# import is not visible in this part of the file — confirm).
@functor PositionEmbedding
# Fields reported through `Flux.trainable`: the `embedding` table when the layer's
# `trainable` flag is set, otherwise an empty NamedTuple.
function Flux.trainable(pe::PositionEmbedding)
    if pe.trainable
        return (embedding = pe.embedding,)
    else
        return (;)
    end
end
# Apply the position embedding to the first entry of the `xs` NamedTuple. `name` is part
# of the `get_value` signature but is not used by this method.
function get_value(e::PositionEmbedding, name::Symbol, xs::NamedTuple)
    input = first(xs)
    return e(input)
end
"""
    PE(size, pos, i::Int)

Return the sin/cos position-encoding value for position `pos` at feature index `i` of an
embedding with `size` features.

Even `i` gives `sin(pos / 1e4^(i/size))`; odd `i` gives `cos(pos / 1e4^((i-1)/size))`.
"""
function PE(size, pos, i::Int)
    # Odd indices reuse the exponent of the even index just below them.
    isodd(i) && return cos(pos/1e4^((i-1)/size))
    return sin(pos/1e4^(i/size))
end
# Outer constructor. Builds a `size × max_len` Float32 table: random normal values when
# `trainable`, otherwise the sin/cos encoding computed by `PE`.
function PositionEmbedding(size::Int, max_len::Int = 1024; trainable::Bool = false)
    table = if trainable
        randn(Float32, size, max_len)
    else
        encoding = Matrix{Float32}(undef, size, max_len)
        # Column `l` holds position `l`; row `i` holds feature `i`.
        for l in axes(encoding, 2), i in axes(encoding, 1)
            encoding[i, l] = PE(size, l, i)
        end
        encoding
    end
    return PositionEmbedding(trainable, table)
end
"""
    resize_pe!(pe::PositionEmbedding, len::Int)

Make sure the embedding table of `pe` has at least `len` columns and return `nothing`.

Nothing happens when the table is already long enough. Otherwise a non-trainable layer
gets a larger table (existing columns are copied, new columns are filled with `PE`), while
a trainable layer raises an error.
"""
function resize_pe!(pe::PositionEmbedding, len::Int)
    old = pe.embedding
    current_len = size(old, 2)
    len > current_len || return nothing
    pe.trainable && error("position embedding length exceeded")

    feature_dim = size(old, 1)
    grown = similar(old, feature_dim, len)
    # Same row count, so a linear copy places the old columns at the front of `grown`.
    copyto!(grown, 1, old, 1, length(old))
    for l in (current_len + 1):len, i in 1:feature_dim
        grown[i, l] = PE(feature_dim, l, i)
    end
    pe.embedding = grown
    return nothing
end
# `Int` array input: the requested length is the size of the first dimension of `x`.
function (pe::PositionEmbedding)(x::AbstractArray{Int})
    seq_len = size(x, 1)
    return pe(seq_len)
end
# `OneHotArray` input: the requested length is the size of the second dimension of `x`.
function (pe::PositionEmbedding)(x::OneHotArray)
    seq_len = size(x, 2)
    return pe(seq_len)
end
# Array input whose element type matches the layer's `F`: the requested length is the size
# of the second dimension of `x`.
function (pe::PositionEmbedding{F})(x::AbstractArray{F}) where F
    seq_len = size(x, 2)
    return pe(seq_len)
end
# Return a copy of the first `len` columns of the embedding table, growing the table first
# when needed (see `resize_pe!`).
function (pe::PositionEmbedding)(len::Int)
    # The resize mutates `pe`; it runs inside `Flux.Zygote.ignore`, presumably to keep the
    # mutation out of differentiation.
    Flux.Zygote.ignore(() -> resize_pe!(pe, len))
    return pe.embedding[:, Base.OneTo(len)]
end
# Compact display: the feature size is always shown; `max_len` is shown only for trainable
# layers.
function Base.show(io::IO, pe::PositionEmbedding)
    s, max_len = size(pe.embedding)
    text = if pe.trainable
        "PositionEmbedding($(s), max_len=$(max_len))"
    else
        "PositionEmbedding($(s))"
    end
    return print(io, text)
end