using ..WordPieceModel
using ..WordPieceModel: DAT
using FuncPipelines
using TextEncodeBase
using TextEncodeBase: nested2batch, nestedcall
using TextEncodeBase: BaseTokenization, WrappedTokenization, MatchTokenization, Splittable,
    ParentStages, TokenStages, SentenceStage, WordStage, Batch, Sentence, getvalue, getmeta
using TextEncodeBase: SequenceTemplate, ConstTerm, InputTerm, RepeatedTerm
# bert tokenizer

# pre-tokenization stages that split sentences with the basic cased / uncased BERT tokenizer
struct BertCasedPreTokenization <: BaseTokenization end
struct BertUnCasedPreTokenization <: BaseTokenization end
TextEncodeBase.splitting(::BertCasedPreTokenization, s::SentenceStage) = bert_cased_tokenizer(getvalue(s))
TextEncodeBase.splitting(::BertUnCasedPreTokenization, s::SentenceStage) = bert_uncased_tokenizer(getvalue(s))
const BertTokenization = Union{BertCasedPreTokenization, BertUnCasedPreTokenization}
Base.show(io::IO, ::BertCasedPreTokenization) = print(io, nameof(bert_cased_tokenizer))
Base.show(io::IO, ::BertUnCasedPreTokenization) = print(io, nameof(bert_uncased_tokenizer))
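
# Example of the pre-tokenization split (a sketch of the behavior; the basic BERT
# tokenizers separate punctuation and, in the uncased variant, also lowercase the text):
#
#     bert_cased_tokenizer("Peter Piper picked a peck!")
#     # => ["Peter", "Piper", "picked", "a", "peck", "!"]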
# encoder constructors: each method normalizes one argument form and forwards to the next
function BertTextEncoder(tkr::AbstractTokenizer, vocab::AbstractVocabulary{String}, process,
                         startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int})
    return TransformerTextEncoder(tkr, vocab, process, startsym, endsym, padsym, trunc)
end
BertTextEncoder(::typeof(bert_cased_tokenizer), args...; kws...) =
    BertTextEncoder(BertCasedPreTokenization(), args...; kws...)
BertTextEncoder(::typeof(bert_uncased_tokenizer), args...; kws...) =
    BertTextEncoder(BertUnCasedPreTokenization(), args...; kws...)
BertTextEncoder(bt::BertTokenization, wordpiece::WordPiece, args...; kws...) =
    BertTextEncoder(WordPieceTokenization(bt, wordpiece), args...; kws...)
# `match_tokens` lists strings (typically special tokens) that must be matched verbatim
# instead of being tokenized; when given, the tokenization is wrapped in `MatchTokenization`.
function BertTextEncoder(t::WordPieceTokenization, args...; match_tokens = nothing, kws...)
    if isnothing(match_tokens)
        return BertTextEncoder(TextTokenizer(t), Vocab(t.wordpiece), args...; kws...)
    else
        match_tokens = match_tokens isa AbstractVector ? match_tokens : [match_tokens]
        return BertTextEncoder(TextTokenizer(MatchTokenization(t, match_tokens)), Vocab(t.wordpiece), args...; kws...)
    end
end
function BertTextEncoder(t::AbstractTokenization, vocab::AbstractVocabulary, args...; match_tokens = nothing, kws...)
    if isnothing(match_tokens)
        return BertTextEncoder(TextTokenizer(t), vocab, args...; kws...)
    else
        match_tokens = match_tokens isa AbstractVector ? match_tokens : [match_tokens]
        return BertTextEncoder(TextTokenizer(MatchTokenization(t, match_tokens)), vocab, args...; kws...)
    end
end
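
# Example usage (a minimal sketch; assumes a pretrained `wordpiece::WordPiece` has
# already been loaded from a vocabulary file):
#
#     enc = BertTextEncoder(bert_uncased_tokenizer, wordpiece;
#                           match_tokens = ["[CLS]", "[SEP]", "[PAD]"], trunc = 512)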
# recover the vocabulary list from the wordpiece trie, ordered by wordpiece index
function _wp_vocab(wp::WordPiece)
    vocab = Vector{String}(undef, length(wp.trie))
    for (str, id) in wp.trie
        vocab[wp.index[id]] = str
    end
    return vocab
end
TextEncodeBase.Vocab(wp::WordPiece) = Vocab(_wp_vocab(wp), DAT.decode(wp.trie, wp.unki))
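
# A sketch of the resulting vocabulary (assumes `wp` is a loaded `WordPiece`;
# `lookup` comes from TextEncodeBase):
#
#     vocab = Vocab(wp)
#     lookup(vocab, "hello")  # token id; unknown words map to the unk token's id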
function BertTextEncoder(tkr::AbstractTokenizer, vocab::AbstractVocabulary, process;
                         startsym = "[CLS]", endsym = "[SEP]", padsym = "[PAD]", trunc = nothing)
    check_vocab(vocab, startsym) || @warn "startsym $startsym not in vocabulary, this might cause problems."
    check_vocab(vocab, endsym) || @warn "endsym $endsym not in vocabulary, this might cause problems."
    check_vocab(vocab, padsym) || @warn "padsym $padsym not in vocabulary, this might cause problems."
    return BertTextEncoder(tkr, vocab, process, startsym, endsym, padsym, trunc)
end
function BertTextEncoder(tkr::AbstractTokenizer, vocab::AbstractVocabulary;
                         fixedsize = false, trunc_end = :tail, pad_end = :tail, process = nothing,
                         kws...)
    enc = BertTextEncoder(tkr, vocab, TextEncodeBase.process(AbstractTextEncoder); kws...)
    # default processing pipelines for bert encoder
    return BertTextEncoder(enc) do e
        bert_default_preprocess(; trunc = e.trunc, startsym = e.startsym, endsym = e.endsym, padsym = e.padsym,
                                fixedsize, trunc_end, pad_end, process)
    end
end
BertTextEncoder(builder, e::TrfTextEncoder) = TrfTextEncoder(builder, e)
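
# Example (a sketch; assumes `tkr` and `vocab` built as above): the keyword form
# installs `bert_default_preprocess` below as the processing pipeline:
#
#     enc = BertTextEncoder(tkr, vocab; trunc = 128, fixedsize = true)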
# preprocess
function bert_default_preprocess(; startsym = "[CLS]", endsym = "[SEP]", padsym = "[PAD]",
                                 fixedsize = false, trunc = nothing, trunc_end = :tail, pad_end = :tail,
                                 process = nothing)
    truncf = get_trunc_pad_func(fixedsize, trunc, trunc_end, pad_end)
    maskf = get_mask_func(trunc, pad_end)
    if isnothing(process)
        process =
            # group input for SequenceTemplate
            Pipeline{:token}(grouping_sentence, :token) |>
            # add start & end symbols, compute the segment ids and merge sentences
            Pipeline{:token_segment}(
                SequenceTemplate(
                    ConstTerm(startsym, 1), InputTerm{String}(1), ConstTerm(endsym, 1),
                    RepeatedTerm(InputTerm{String}(2), ConstTerm(endsym, 2); dynamic_type_id = true)
                ), :token
            ) |>
            Pipeline{:token}(nestedcall(first), :token_segment) |>
            Pipeline{:segment}(nestedcall(last), :token_segment)
    end
    # get the tokens and convert them to strings
    return Pipeline{:token}(nestedcall(string_getvalue), 1) |>
        process |>
        Pipeline{:attention_mask}(maskf, :token) |>
        # truncate inputs that exceed the length limit and pad them to equal length
        Pipeline{:token}(truncf(padsym), :token) |>
        # convert to dense array
        Pipeline{:token}(nested2batch, :token) |>
        # truncate & pad segment
        Pipeline{:segment}(truncf(1), :segment) |>
        Pipeline{:segment}(nested2batch, :segment) |>
        # sequence mask
        Pipeline{:sequence_mask}(identity, :attention_mask) |>
        # return the token, segment and mask fields
        PipeGet{(:token, :segment, :attention_mask, :sequence_mask)}()
end
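
# The pipeline produces a named tuple with the four selected fields (a sketch of the
# result; actual shapes depend on the input and `trunc`):
#
#     :token           batched tokens, padded with `padsym`
#     :segment         segment ids (1 for the first sentence, 2 afterwards)
#     :attention_mask  marks real tokens vs. padding
#     :sequence_mask   the attention mask exposed under a second name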