In [2]:
import torch
from config import BitformerConfig
from bitformer import BitformerForLM
from transformers import LlamaTokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [30]:
tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', legacy=True)
cfg = BitformerConfig()
cfg.is_causal = True
cfg.hidden_size = 768
cfg.intermediate_size = 2048
cfg.num_hidden_layers = 12
cfg.num_local_experts = 8
cfg.num_experts_per_tok = 2
cfg.bos_token_id = tokenizer.bos_token_id
cfg.eos_token_id = tokenizer.eos_token_id
cfg.pad_token_id = tokenizer.pad_token_id
cfg.output_router_logits = False # needs to be False if going to use .generate
model = BitformerForLM(config=cfg).to(device)
model

BitformerForLM(
  (model): BitformerModel(
    (embed_tokens): Embedding(32000, 256)
    (layers): ModuleList(
      (0-5): 6 x BitformerLayer(
        (self_attn): SelfAttention(
          (q_proj): BitLinear(
            in_features=256, out_features=256, bias=False
            (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          )
          (k_proj): BitLinear(
            in_features=256, out_features=256, bias=False
            (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          )
          (v_proj): BitLinear(
            in_features=256, out_features=256, bias=False
            (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          )
          (o_proj): BitLinear(
            in_features=256, out_features=256, bias=False
            (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          )
          (rotary_emb): RotaryEmbedding()
        )
        (MLP): TokenTopKMoeBlock(
          (router): Linear(in_feat

In [8]:
input = tokenizer('Hello world', return_tensors='pt')
input = {k:v.to(device) for k,v in input.items()}
input

{'input_ids': tensor([[    1, 15043,  3186]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1]], device='cuda:0')}

In [14]:
out = model(**input, output_router_logits=cfg.moe) # set to true to get aux loss

In [31]:
gen = model.generate(input['input_ids'], max_new_tokens=20)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [34]:
tokenizer.decode(gen[0])

'<s>Hello world abol{` youngisceisceoin cru arbitrץ Febaset SendbewInputInput mentparams maintenance carriage rép'