
Bert token type embedding #213

Closed
eyalmazuz opened this issue Dec 2, 2023 · 2 comments

Comments

eyalmazuz commented Dec 2, 2023

I was looking at the BERT example you give in the code, but unlike the original paper, I didn't see a way in the source code to add token type embeddings.

One solution is to manage it myself by doing something like:

import torch
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Encoder

class Bert(nn.Module):
    def __init__(self, num_tokens, num_types, dim):
        super().__init__()

        # feature-in / feature-out wrapper, so embeddings are handled outside of it
        self.model = ContinuousTransformerWrapper(
            dim_in = dim,
            dim_out = dim,
            max_seq_len = 1024,
            attn_layers = Encoder(
                dim = dim,
                depth = 12,
                heads = 8
            )
        )

        # separate token and token type (segment) embedding tables
        self.type_emb = nn.Embedding(num_types, dim)
        self.token_emb = nn.Embedding(num_tokens, dim)

    def forward(self, tokens, types, mask):
        emb = self.token_emb(tokens)
        type_emb = self.type_emb(types)

        # sum both embeddings and feed the continuous wrapper directly
        out = self.model(emb + type_emb, mask = mask)

        return out
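
A minimal usage sketch of the workaround above (the vocabulary size, segment count, and shapes here are just illustrative):

bert = Bert(num_tokens = 20000, num_types = 2, dim = 512)

tokens = torch.randint(0, 20000, (1, 1024))
types = torch.randint(0, 2, (1, 1024))
mask = torch.ones(1, 1024).bool()

out = bert(tokens, types, mask) # (1, 1024, 512), since dim_out = dim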

But it seems odd to handle the word embedding matrix myself when I could just use the regular TransformerWrapper.
Is there a way to add token type embeddings to the model, so that I could just write something like:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    max_token_type = 3,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
token_type = torch.randint(0, 3, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()

model(x, token_type = token_type, mask = mask) # (1, 1024, 20000)

lucidrains (Owner) commented Dec 2, 2023

@eyalmazuz hey Eyal! thanks for bringing this up

do you want to see if the following works for you in the latest version?

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(type = 5),   # extra embedding table named 'type' with 5 possible ids
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 256, (1, 1024))
types = torch.randint(0, 5, (1, 1024))

logits = model(x, embed_ids = dict(type = types))
logits.shape # (1, 1024, 20000)
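
A BERT-style variant would presumably combine the same embed_num_tokens / embed_ids API with the bidirectional Encoder and padding mask from the original request. This is just a sketch stitched together from the two examples above, not a verified snippet from the library docs:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(type = 2),   # two segment types, as in BERT
    attn_layers = Encoder(               # bidirectional attention for BERT-style training
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
types = torch.randint(0, 2, (1, 1024))
mask = torch.ones_like(x).bool()

logits = model(x, embed_ids = dict(type = types), mask = mask)
logits.shape # (1, 1024, 20000)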

eyalmazuz (Author)

@lucidrains Hey Phil!
Thanks for the quick response and fix!
I think that solution is great and could work for me.

Thanks again for the feature. I'll close the issue now.
