I tried to use the `colbert-ir/colbertv2.0` pretrained checkpoint for a task (it's essentially a BERT model plus a linear layer; for this issue we focus only on the BERT model). Here is how I loaded the model:
```julia
using CUDA
using Flux
using OneHotArrays
using Test
using Transformers
using Transformers.TextEncoders

const PRETRAINED_BERT = "colbert-ir/colbertv2.0"

bert_config = Transformers.load_config(PRETRAINED_BERT)
bert_tokenizer = Transformers.load_tokenizer(PRETRAINED_BERT)
bert_model = Transformers.load_model(PRETRAINED_BERT)

const VOCABSIZE = size(bert_tokenizer.vocab.list)[1]
```
Now we'll simply run `bert_model` over a bunch of sentences:
```julia
docs = [
    "hello world",
    "thank you!",
    "a",
    "this is some longer text, so length should be longer",
]

encoded_text = encode(bert_tokenizer, docs)
ids, mask = encoded_text.token, encoded_text.attention_mask
```
Above, by default, `ids` is a `OneHotArray`. We convert it to an integer matrix containing the integer token IDs:
```julia
integer_ids = Matrix(onecold(ids))
```
As expected, `bert_model` gives the same results on the integer IDs as on the one-hot encodings:
```julia
julia> @test isequal(bert_model((token = integer_ids, attention_mask = mask)),
                     bert_model((token = ids, attention_mask = mask)))
Test Passed
```
Note that we can also convert from `integer_ids` back to the `OneHotArray` using the `onehotbatch` function. Here's a test just as a sanity check:
```julia
julia> @test isequal(ids, onehotbatch(integer_ids, 1:VOCABSIZE))  # test passes
Test Passed
```
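To see the two conversion helpers in isolation, here is a miniature round trip with OneHotArrays.jl (illustrative only; independent of the model and tokenizer):

```julia
using OneHotArrays

oh = onehotbatch([1, 3, 2], 1:3)  # 3×3 one-hot matrix; column j encodes the j-th label
onecold(oh, 1:3)                  # recovers [1, 3, 2]
```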
However, if we convert back from the integer IDs to one-hot encodings and use the converted one-hot encodings in `bert_model`, the model throws an error. Am I missing something here?
You should use `integer_ids = reinterpret(Int32, ids)` and `OneHotArray{VOCABSIZE}(integer_ids)` instead. The `OneHotArray` used in Transformers is different from the one in Flux, and the error happens because Flux's `OneHotArray` does not overload `gather`.
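Spelled out, the suggested fix looks roughly like this (a sketch, assuming the one-hot type returned by `encode` is the one from PrimitiveOneHot.jl, a Transformers.jl dependency; `bert_model`, `ids`, `mask`, and `VOCABSIZE` are as defined above):

```julia
using PrimitiveOneHot: OneHotArray  # assumption: the one-hot package Transformers uses, not Flux's

# View the encoded batch as its underlying Int32 token IDs (no copy) ...
integer_ids = reinterpret(Int32, ids)

# ... and rebuild the one-hot encoding with the matching type:
onehot_ids = OneHotArray{VOCABSIZE}(integer_ids)

bert_model((token = onehot_ids, attention_mask = mask))  # no `gather` error
```

The two one-hot values compare `isequal` above, but they are distinct concrete types: `onehotbatch` builds Flux's `OneHotArray`, while the model's embedding lookup (`gather`) only supports the one Transformers itself uses.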