In [1]:
import torch
from transformers import LlamaTokenizer

from utils import get_yaml
from models.config import BitformerConfig
from models.model_zoo import BitformerForLM

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', legacy=True)

In [3]:
yargs = get_yaml('yamls/small_bitformer.yaml')
cfg = BitformerConfig(**yargs['model_config'])
cfg.bos_token_id = tokenizer.bos_token_id
cfg.eos_token_id = tokenizer.eos_token_id
cfg.pad_token_id = tokenizer.pad_token_id

In [4]:
# GPT-like
cfg.is_causal = True
cfg.output_router_logits = False # needs to be False if going to use .generate
model = BitformerForLM(config=cfg).to(device)
print(model)
input = tokenizer('Hello world', return_tensors='pt')
input = {k:v.to(device) for k,v in input.items()}
print(input)
out = model(**input,
            labels=input['input_ids'],
            output_hidden_states=True,
            output_attentions=True,
            output_router_logits=cfg.moe # set to true to get aux loss
            )
print(out.loss)
print(out.aux_loss)
print(out.logits.shape)
print(out.hidden_states[0].shape)
print(out.attentions[0].shape)
print(out.router_logits[0].shape)
gen = model.generate(input['input_ids'], max_new_tokens=20)
print(tokenizer.decode(gen[0])) # random because random weights

BitformerForLM(
  (model): BitformerModel(
    (embed_tokens): Embedding(32000, 512)
    (layers): ModuleList(
      (0-11): 12 x BitformerLayer(
        (self_attn): SelfAttention(
          (q_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (k_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (v_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (o_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (rotary_emb): RotaryEmbedding()
        )
        (MLP): TokenTopKMoeBlock(
          (router): Linear(in_features=512, out_features=4, bias=False)
          (experts): ModuleList(
            (0-3): 4 x MLP(
              (w1): BitLinear(
                in_features=

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


tensor(10.9279, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.1160, device='cuda:0', grad_fn=<MulBackward0>)
torch.Size([1, 3, 32000])
torch.Size([1, 3, 512])
torch.Size([1, 8, 3, 3])
torch.Size([3, 4])
<s>Hello world solid KnoB Nations Objectńska attr Apache playingugg listopada Objectsuper Writ Staff neglectobjects dicembre externeчни


In [1]:
from datasets import load_dataset

data = load_dataset('allenai/dolma', split='train', streaming=True)


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


{'id': '600f7d0e70e779b5c95464411000c5998ea252ba',
 'added': '2023-04-25T05:49:46.922Z',
 'created': '2006-01-09T00:00:00.000Z',
 'source': 'gutenberg'}

In [2]:
next(iter(data)).keys()

dict_keys(['id', 'text', 'added', 'created', 'source'])

In [5]:
# BERT-like
cfg.is_causal = False
cfg.output_router_logits = False # needs to be False if going to use .generate
model = BitformerForLM(config=cfg).to(device)
print(model)
input = tokenizer('Hello world', return_tensors='pt')
input = {k:v.to(device) for k,v in input.items()}
print(input)
out = model(**input,
            labels=input['input_ids'],
            output_hidden_states=True,
            output_attentions=True,
            output_router_logits=cfg.moe # set to true to get aux loss
            )
print(out.loss)
print(out.aux_loss)
print(out.logits.shape)
print(out.hidden_states[0].shape)
print(out.attentions[0].shape)
print(out.router_logits[0].shape)

BitformerForLM(
  (model): BitformerModel(
    (embed_tokens): Embedding(32000, 512)
    (layers): ModuleList(
      (0-11): 12 x BitformerLayer(
        (self_attn): SelfAttention(
          (q_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (k_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (v_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (o_proj): BitLinear(
            in_features=512, out_features=512, bias=False
            (rms_norm): RMSNorm()
          )
          (rotary_emb): RotaryEmbedding()
        )
        (MLP): SentenceTopKMoeBlock(
          (router): Linear(in_features=512, out_features=4, bias=False)
          (experts): ModuleList(
            (0-3): 4 x MLP(
              (w1): BitLinear(
                in_featur