Bert token type embedding #213
eyalmazuz

I was looking at the BERT example you give in the code, but unlike the original paper, I didn't see a way in the source code to add token type embeddings. One solution is to manage it myself, doing something like the sketch below, but it seems weird to handle the word embedding matrix myself when I could just use the regular TransformerWrapper. Is there a way to add token type embeddings to the model, so I could just pass the type ids in directly?
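(The questioner's original snippet isn't preserved in this capture. The following is a minimal sketch of that kind of manual workaround, assuming one keeps separate word and token type embedding tables in plain PyTorch and sums them before the attention layers; the table sizes and dimensions are illustrative, not from the thread.)

```python
import torch
import torch.nn as nn

# hypothetical manual workaround, with illustrative sizes:
# keep your own word and token type embedding tables and sum their
# outputs, the way BERT combines word and segment embeddings
word_emb = nn.Embedding(20000, 512)   # vocab size, model dim
type_emb = nn.Embedding(2, 512)       # 2 token types (segment A / B)

ids = torch.randint(0, 20000, (1, 1024))
types = torch.randint(0, 2, (1, 1024))

x = word_emb(ids) + type_emb(types)   # (1, 1024, 512)
# x would then be fed to attention layers that accept raw embeddings,
# which is the "handling the word embedding matrix myself" the
# questioner wants to avoid
```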
lucidrains added a commit that referenced this issue on Dec 2, 2023
@eyalmazuz hey Eyal! Thanks for bringing this up. Do you want to see if the following works for you in the latest version?

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(type = 5),  # extra embedding table with 5 token type ids
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 256, (1, 1024))
types = torch.randint(0, 5, (1, 1024))

# pass the token type ids under the same key used in embed_num_tokens
logits = model(x, embed_ids = dict(type = types))

logits.shape # (1, 1024, 20000)
```
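Assuming the new embed_num_tokens / embed_ids API shown above works the same with a bidirectional Encoder, a BERT-flavored variant might look like this (the BERT-base sizes and the two segment ids are illustrative, not from the thread):

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# hypothetical BERT-style configuration: two token types (segments A / B),
# with an Encoder in place of the Decoder from the snippet above
model = TransformerWrapper(
    num_tokens = 30522,                 # BERT-base vocab size (illustrative)
    max_seq_len = 512,
    embed_num_tokens = dict(type = 2),  # two segment ids, as in BERT
    attn_layers = Encoder(
        dim = 768,
        depth = 12,
        heads = 12
    )
)

x = torch.randint(0, 30522, (2, 512))
segments = torch.randint(0, 2, (2, 512))

logits = model(x, embed_ids = dict(type = segments))  # (2, 512, 30522)
```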
lucidrains added another commit that referenced this issue on Dec 2, 2023
@lucidrains Hey Phil! Thanks again for the feature!