In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel

from quiet_star.config import Config, ModelConfig
from quiet_star.torch.openelm import OpenELMThoughtModel
from quiet_star.torch.pretrained import PretrainedThoughtModel
from quiet_star.torch.qwen import QwenThoughtModel
from quiet_star.torch.utils import torch_dtype


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
config = Config(
    batch_size=2,
    lookahead_tokens=3,
    thought_length=3,
    model=ModelConfig(
        attn_type="torch",
        dtype="float32",
        device=device,
        dropout_attn=0.0,
        dropout_embed=0.0,
        model_name="apple/OpenELM-270M-Instruct",
        tokenizer_name="meta-llama/Llama-2-7b-hf",
        max_length=32,
    ),
)

In [3]:
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    config.model.model_name,
    torch_dtype=torch_dtype(config.model.dtype),
    trust_remote_code=True,
).to(config.model.device)

tokenizer = AutoTokenizer.from_pretrained(
    config.model.tokenizer_name,
    trust_remote_code=True,
)

text = "This is a longer test sentence."
x = tokenizer(
    text,
    padding="do_not_pad",
    truncation=True,
    max_length=config.model.max_length - config.thought_length - 2,
    return_tensors="np",
    return_attention_mask=False,
)["input_ids"][0].tolist()
input_ids = torch.tensor(
    [x for _ in range(config.batch_size)],
    dtype=torch.int64,
    device=config.model.device,
)


In [4]:
model.modules

<bound method Module.modules of OpenELMForCausalLM(
  (transformer): OpenELMModel(
    (token_embeddings): Embedding(32000, 1280)
    (layers): ModuleList(
      (0): OpenELMDecoderLayer(
        (attn): OpenELMMultiHeadCausalAttention(
          query_heads=12, key_heads=3, value_heads=3
          (qkv_proj): Linear(in_features=1280, out_features=1152, bias=False)
          (pos_embedding): OpenELMRotaryEmbedding(	model_dim=64, max_seq_length=4096, freq_constant=10000)
          (q_norm): OpenELMRMSNorm(num_features=64, eps=1e-06)
          (k_norm): OpenELMRMSNorm(num_features=64, eps=1e-06)
          (out_proj): Linear(in_features=768, out_features=1280, bias=False)
        )
        (ffn): OpenELMFeedForwardNetwork(
          (ffn_with_glu) : True
          (proj_1): Linear(in_features=1280, out_features=1536, bias=False)
          (proj_2): Linear(in_features=768, out_features=1280, bias=False)
          (act): SiLU()
        )
        (ffn_norm): OpenELMRMSNorm(num_features=1280,

In [6]:
len(model.transformer.layers)

16

In [5]:
input_ids.shape, input_ids

(torch.Size([2, 8]),
 tensor([[    1,   910,   338,   263,  5520,  1243, 10541, 29889],
         [    1,   910,   338,   263,  5520,  1243, 10541, 29889]],
        device='cuda:0'))

In [6]:
outputs = model(input_ids, output_hidden_states=True)
outputs[0].shape, outputs[0]

(torch.Size([2, 8, 32000]),
 tensor([[[ -9.9827,  -9.4185,  -1.3738,  ...,  -5.5373,  -6.8383,  -4.9822],
          [-11.7295, -20.6063,  -5.0599,  ...,  -8.0301,  -9.4964,  -8.1932],
          [-11.1467, -12.0488,  -4.3880,  ...,  -6.7239,  -8.3881,  -7.5842],
          ...,
          [-10.9954, -18.0083,  -1.4600,  ...,  -9.1341, -11.1416,  -9.1180],
          [-11.8424, -17.7331,  -4.0620,  ..., -11.7066, -11.5651,  -7.8645],
          [-11.9498, -12.1675,  -7.3315,  ...,  -6.3045,  -8.0009,  -4.4673]],
 
         [[ -9.9827,  -9.4185,  -1.3738,  ...,  -5.5373,  -6.8383,  -4.9822],
          [-11.7295, -20.6063,  -5.0599,  ...,  -8.0301,  -9.4964,  -8.1932],
          [-11.1467, -12.0488,  -4.3880,  ...,  -6.7239,  -8.3881,  -7.5842],
          ...,
          [-10.9954, -18.0083,  -1.4600,  ...,  -9.1341, -11.1416,  -9.1180],
          [-11.8424, -17.7331,  -4.0620,  ..., -11.7066, -11.5651,  -7.8645],
          [-11.9498, -12.1675,  -7.3315,  ...,  -6.3045,  -8.0009,  -4.4673]]],
 

In [7]:
model.transformer.token_embeddings(torch.tensor([[1]], dtype=torch.int64, device=model.device))

tensor([[[1.3351e-03, 6.4850e-04, 1.4067e-05,  ..., 1.8616e-03,
          8.2016e-04, 7.1335e-04]]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)

In [8]:
outputs[2][0][0].shape, outputs[2][0][0]

(torch.Size([8, 1280]),
 tensor([[ 1.3351e-03,  6.4850e-04,  1.4067e-05,  ...,  1.8616e-03,
           8.2016e-04,  7.1335e-04],
         [ 1.3306e-02,  3.4943e-03,  4.4434e-02,  ...,  3.1250e-02,
           4.6631e-02,  1.3306e-02],
         [-5.0781e-02,  2.6245e-02,  5.4932e-03,  ..., -4.9805e-02,
          -1.2268e-02, -2.9663e-02],
         ...,
         [-2.1729e-02, -2.2705e-02,  2.3804e-02,  ..., -2.7710e-02,
           3.2959e-02, -1.4771e-02],
         [ 1.6357e-02, -3.6377e-02,  6.2561e-03,  ...,  2.4902e-02,
           1.0071e-02, -8.5449e-02],
         [-1.4725e-03,  1.6235e-02,  2.7954e-02,  ..., -1.3550e-02,
           3.4485e-03, -2.5513e-02]], device='cuda:0',
        grad_fn=<SelectBackward0>))

In [9]:
tmodel = OpenELMThoughtModel(config).to(config.model.device)



number of parameters: 278.12M


In [10]:
b, l = input_ids.shape

causal_mask1 = torch.triu(
    torch.full((l, l), float("-inf"), dtype=tmodel._dtype, device=tmodel.device),
    diagonal=1,
)
causal_mask1 = causal_mask1.unsqueeze(0).unsqueeze(1)
causal_mask1 = causal_mask1.repeat(b, 1, 1, 1)

row = torch.arange(0, l, dtype=torch.int64, device=tmodel.device)
position_ids = row.reshape(1, l).tile((b, 1))

x = tmodel.tok_emb(input_ids)
x.shape, x

(torch.Size([2, 8, 1280]),
 tensor([[[ 1.3351e-03,  6.4850e-04,  1.4067e-05,  ...,  1.8616e-03,
            8.2016e-04,  7.1335e-04],
          [ 1.3306e-02,  3.4943e-03,  4.4434e-02,  ...,  3.1250e-02,
            4.6631e-02,  1.3306e-02],
          [-5.0781e-02,  2.6245e-02,  5.4932e-03,  ..., -4.9805e-02,
           -1.2268e-02, -2.9663e-02],
          ...,
          [-2.1729e-02, -2.2705e-02,  2.3804e-02,  ..., -2.7710e-02,
            3.2959e-02, -1.4771e-02],
          [ 1.6357e-02, -3.6377e-02,  6.2561e-03,  ...,  2.4902e-02,
            1.0071e-02, -8.5449e-02],
          [-1.4725e-03,  1.6235e-02,  2.7954e-02,  ..., -1.3550e-02,
            3.4485e-03, -2.5513e-02]],
 
         [[ 1.3351e-03,  6.4850e-04,  1.4067e-05,  ...,  1.8616e-03,
            8.2016e-04,  7.1335e-04],
          [ 1.3306e-02,  3.4943e-03,  4.4434e-02,  ...,  3.1250e-02,
            4.6631e-02,  1.3306e-02],
          [-5.0781e-02,  2.6245e-02,  5.4932e-03,  ..., -4.9805e-02,
           -1.2268e-02, -2.966

In [11]:
torch.allclose(x[0], outputs[2][0][0])

True

In [12]:
outputs[2][1][0].shape, outputs[2][1][0]

(torch.Size([8, 1280]),
 tensor([[ 0.0274,  0.0157,  0.0946,  ..., -0.0227,  0.0035, -0.0185],
         [-0.0532, -0.1701,  0.0188,  ...,  0.1279,  0.1166,  0.0664],
         [-0.4543, -0.0032, -0.1225,  ...,  0.1885, -0.0846, -0.0783],
         ...,
         [-0.1314, -0.1227, -0.0984,  ..., -0.0571,  0.1851,  0.0105],
         [-0.2809,  0.0842,  0.7183,  ...,  0.0822,  0.1731,  0.1994],
         [ 0.0046, -0.0328, -0.0318,  ...,  0.0520,  0.0103,  0.0787]],
        device='cuda:0', grad_fn=<SelectBackward0>))

In [13]:
i, layer = 0, tmodel.layers[0]

residual = x
x2 = layer.attn_norm(x)

attn_out = layer.attn(x2, attention_mask=causal_mask1)[0]

x2 = residual + attn_out
x2 = x2 + layer.ffn(layer.ffn_norm(x2))

x2[0].shape, x2[0]

(torch.Size([8, 1280]),
 tensor([[ 0.0274,  0.0157,  0.0946,  ..., -0.0227,  0.0035, -0.0185],
         [-0.0532, -0.1701,  0.0188,  ...,  0.1279,  0.1166,  0.0664],
         [-0.4543, -0.0032, -0.1225,  ...,  0.1885, -0.0846, -0.0783],
         ...,
         [-0.1314, -0.1227, -0.0984,  ..., -0.0571,  0.1851,  0.0105],
         [-0.2809,  0.0842,  0.7183,  ...,  0.0822,  0.1731,  0.1994],
         [ 0.0046, -0.0328, -0.0318,  ...,  0.0520,  0.0103,  0.0787]],
        device='cuda:0', grad_fn=<SelectBackward0>))

In [21]:
import math

i, layer = 0, tmodel.layers[0]

residual = x
x2 = layer.attn_norm(x)

qkv = (
    layer.attn.qkv_proj(x2)
    .reshape(b, l, 2 * tmodel.num_kv_heads[i] + tmodel.num_query_heads[i], -1)
    .swapaxes(1, 2)
)
q, k, v = qkv.split([tmodel.num_query_heads[i], tmodel.num_kv_heads[i], tmodel.num_kv_heads[i]], dim=1)

q = layer.attn.q_norm(q)
k = layer.attn.k_norm(k)

# apply rotary embedding
q, k = layer.attn.pos_embedding(q, k)
# cos, sin = layer.self_attn.rotary_emb(v, seq_len=l)
# q = self.apply_rotary_pos_emb(q, cos, sin, position_ids)
# k = self.apply_rotary_pos_emb(k, cos, sin, position_ids)

print(layer.attn.num_groups, tmodel.num_gqa_groups)
k = k.repeat_interleave(layer.attn.num_groups, dim=1)
v = v.repeat_interleave(layer.attn.num_groups, dim=1)

a = torch.nn.functional.softmax(
    (torch.matmul(q, k.transpose(-2, -1)) + causal_mask1)
    / math.sqrt(q.size(-1)),
    dim=-1,
)

# attn_out is (B, H, L, E)
attn_out = torch.matmul(a, v)

attn_out = layer.attn.out_proj(
    attn_out.permute([0, 2, 1, 3]).reshape(b, l, tmodel.num_query_heads[i] * tmodel.head_dim)
)
x2 = residual + attn_out
x2 = x2 + layer.ffn(layer.ffn_norm(x2))

x2[0].shape, x2[0]

4 4


(torch.Size([8, 1280]),
 tensor([[ 0.0274,  0.0157,  0.0946,  ..., -0.0227,  0.0035, -0.0185],
         [-0.0532, -0.1701,  0.0188,  ...,  0.1279,  0.1166,  0.0664],
         [-0.4543, -0.0032, -0.1225,  ...,  0.1885, -0.0846, -0.0783],
         ...,
         [-0.1314, -0.1227, -0.0984,  ..., -0.0571,  0.1851,  0.0105],
         [-0.2809,  0.0842,  0.7183,  ...,  0.0822,  0.1731,  0.1994],
         [ 0.0046, -0.0328, -0.0318,  ...,  0.0520,  0.0103,  0.0787]],
        device='cuda:0', grad_fn=<SelectBackward0>))

In [18]:
import inspect
print(inspect.getsource(layer.attn.__class__))

class OpenELMMultiHeadCausalAttention(nn.Module):
    def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        head_dim = config.head_dim
        q_heads = config.num_query_heads[layer_idx]
        k_heads = config.num_kv_heads[layer_idx]
        v_heads = config.num_kv_heads[layer_idx]

        self.qkv_proj = nn.Linear(
            in_features=config.model_dim,
            out_features=(q_heads + k_heads + v_heads) * head_dim,
            bias=False,
        )

        self.pos_embedding = OpenELMRotaryEmbedding(
            model_dim=config.head_dim,
            max_seq_length=config.rope_max_length,
            freq_constant=config.rope_freq_constant,
        )

        if config.normalize_qk_projections:
            self.q_norm = OpenELMRMSNorm(
                num_features=config.head_dim,
            )
            self.k_norm = OpenELMRMSNorm(
                num_features=config.head_dim,
           

In [19]:
layer.attn.__class__.__module__

'transformers_modules.apple.OpenELM-270M-Instruct.eb111ff2e6724348e5b905984063d4064d4bc579.modeling_openelm'