In [1]:
import math
import random

import lightning
import lightning.pytorch
import torch
import torch.nn
import torch.utils.data
from torch.nn import functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel

from quiet_star.config import Config, ModelConfig
from quiet_star.constants import END_THOUGHT_TOKEN, START_THOUGHT_TOKEN
from quiet_star.torch.utils import assert_shape, expand_dims, torch_dtype
from quiet_star.torch.pretrained import PretrainedThoughtModel

In [2]:
def prepare_test_inputs(
    model: lightning.LightningModule, config: Config, text: str, max_thought_length: int
) -> tuple[list[int], list[list[int]]]:
    x = model.tokenizer(
        text,
        padding="do_not_pad",
        truncation=True,
        max_length=config.model.max_length - config.thought_length - 2,
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"][0].tolist()
    start_thought_token = model.tokenizer(
        START_THOUGHT_TOKEN, return_tensors="np", return_attention_mask=False
    )["input_ids"][0, 0].tolist()
    thought_tokens = [[start_thought_token] for _ in range(len(x))]
    if max_thought_length > 0:
        next_tokens = [
            [
                random.randrange(0, len(model.tokenizer))
                for _ in range(max_thought_length)
            ]
            for _ in range(len(x))
        ]
        thought_tokens = [thought_tokens[i] + next_tokens[i] for i in range(len(x))]

    return x, thought_tokens

def prepare_next_thought_token_input(
    x: list[int],
    thought_tokens: list[list[int]],
    thought_length: int,
    batch_size: int,
    device: str | torch.device,
    last_thought_token_only: bool,
) -> torch.Tensor:
    if not last_thought_token_only:
        thoughts = [tokens[:thought_length] for tokens in thought_tokens]
        x = torch.tensor(x, dtype=torch.int64, device=device)
        x = torch.unsqueeze(x, dim=-1)
        thoughts = torch.tensor(thoughts, dtype=torch.int64, device=device)
        inputs = torch.concatenate([x, thoughts], dim=-1).tolist()
    else:
        thoughts = [[tokens[thought_length - 1]] for tokens in thought_tokens]
        inputs = thoughts

    return torch.tensor(
        [inputs for _ in range(batch_size)], dtype=torch.int64, device=device
    )  # add batch dimension

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"
config = Config(
    batch_size=2,
    thought_length=3,
    model=ModelConfig(
        attn_type="torch",
        device=device,
        dropout_attn=0.0,
        dropout_embed=0.0,
        dtype="bfloat16",
        model_name="Qwen/Qwen1.5-0.5B-Chat",
        max_length=32,
    ),
)
model = PretrainedThoughtModel(config).to(config.model.device)
activation_cache = None
x, thought_tokens = prepare_test_inputs(
    model, config, "This is a test.", config.thought_length
)
t = 1
i = 0
# correct
xi = torch.tensor(
    [x[: i + 1] + thought_tokens[i][:t]],
    dtype=torch.int64,
    device=model.device,
)
# testing
ax = prepare_next_thought_token_input(
    x,
    thought_tokens,
    t,
    config.batch_size,
    model.device,
    last_thought_token_only=(t > 1),
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


number of parameters: 468.19M


In [14]:
# correct
b, l = xi.shape
# testing
ab, al, ad = ax.shape
if activation_cache is None:
    activation_cache = [{} for _ in range(len(model.layers))]
print(ab, al, ad)

2 5 2


In [15]:
# correct
causal_mask1 = torch.triu(
    torch.full((l, l), float("-inf"), dtype=model._dtype, device=model.device),
    diagonal=1,
)
causal_mask1 = causal_mask1.unsqueeze(0)
# testing
acausal_mask1 = torch.triu(
    torch.full((al, al), float("-inf"), dtype=model._dtype, device=model.device),
    diagonal=1,
)
acausal_mask1 = expand_dims(acausal_mask1, (0, 1, 3))
acausal_mask2 = torch.triu(
    torch.full(
        (t + 1, t + 1), float("-inf"), dtype=model._dtype, device=model.device
    ),
    diagonal=1,
)
acausal_mask2 = expand_dims(acausal_mask2[t - ad + 1 :, 1:], (0, 1, 2))

In [16]:
# correct
row = torch.arange(0, l, dtype=torch.int64, device=model.device)
position_ids = row.reshape(1, l).tile((b, 1))
xi2 = model.tok_emb(xi)
# testing
arows = torch.arange(0, ad, dtype=torch.int64, device=model.device).reshape(1, ad)
arow_offsets = torch.arange(
    t - ad + 1, al + t - ad + 1, dtype=torch.int64, device=model.device
).reshape(al, 1)
aposition_ids = (arows + arow_offsets).reshape((1, al, ad)).tile((ab, 1, 1))
ax2 = model.tok_emb(ax)

print("position_ids:", position_ids.shape, position_ids)
print("aposition_ids:", aposition_ids.shape, aposition_ids)
assert torch.allclose(xi2, ax2[:1, 0, :2, :])

position_ids: torch.Size([1, 2]) tensor([[0, 1]], device='cuda:0')
aposition_ids: torch.Size([2, 5, 2]) tensor([[[0, 1],
         [1, 2],
         [2, 3],
         [3, 4],
         [4, 5]],

        [[0, 1],
         [1, 2],
         [2, 3],
         [3, 4],
         [4, 5]]], device='cuda:0')


In [17]:
i = 0
layer = model.layers[0]

In [18]:
# correct
residual = xi2
xi3 = layer.input_layernorm(xi2)
# testing
aresidual = ax2
ax3 = model.bfloat_safe_apply(layer.input_layernorm, ax2)
ax3p = layer.input_layernorm(ax2)
assert torch.allclose(xi3, ax3[:1, 0, :2, :])
assert torch.allclose(xi3, ax3p[:1, 0, :2, :])

In [19]:
# correct
q = (
    layer.self_attn.q_proj(xi3)
    .reshape(b, l, model.num_heads, -1)
    .permute([0, 2, 1, 3])
)
k = (
    layer.self_attn.k_proj(xi3)
    .reshape(b, l, model.num_heads, -1)
    .permute([0, 2, 1, 3])
)
v = (
    layer.self_attn.v_proj(xi3)
    .reshape(b, l, model.num_heads, -1)
    .permute([0, 2, 1, 3])
)
# testing
aq = (
    model.bfloat_safe_apply(layer.self_attn.q_proj, ax3)
    .reshape(ab, al, ad, model.num_heads, -1)
    .permute([0, 3, 1, 2, 4])
)
ak1 = (
    model.bfloat_safe_apply(layer.self_attn.k_proj, ax3[:, :, :1, :])
    .reshape(ab, al, 1, model.num_heads, -1)
    .permute([0, 3, 1, 2, 4])
)
ak2 = (
    model.bfloat_safe_apply(layer.self_attn.k_proj, ax3[:, :, 1:, :])
    .reshape(ab, al, ad - 1, model.num_heads, -1)
    .permute([0, 3, 1, 2, 4])
)
av1 = (
    model.bfloat_safe_apply(layer.self_attn.v_proj, ax3[:, :, :1, :])
    .reshape(ab, al, 1, model.num_heads, -1)
    .permute([0, 3, 2, 1, 4])
)
av2 = (
    model.bfloat_safe_apply(layer.self_attn.v_proj, ax3[:, :, 1:, :])
    .reshape(ab, al, ad - 1, model.num_heads, -1)
    .permute([0, 3, 1, 2, 4])
)
activation_cache[i]["k1"] = ak1
activation_cache[i]["v1"] = av1
assert torch.allclose(q, aq[:1, :, 0, :2, :])
assert torch.allclose(k[:, :, :1, :], ak1[:1, :, 0, :1, :]), f"\nk: {k[:, :, :1, :]}\nak1:{ak1[:1, :, 0, :1, :]}"
assert torch.allclose(k[:, :, 1:, :], ak2[:1, :, 0, :1, :]), f"\nk: {k[:, :, 1:, :]}\nak2:{ak2[:1, :, 0, :1, :]}"
assert torch.allclose(v[:, :, :1, :], av1[:1, :, 0, :1, :])
assert torch.allclose(v[:, :, 1:, :], av2[:1, :, 0, :1, :])

In [21]:
# correct
cos, sin = layer.self_attn.rotary_emb(v, seq_len=l)
qp = model.apply_rotary_pos_emb(q, cos, sin, position_ids)
kp = model.apply_rotary_pos_emb(k, cos, sin, position_ids)
# testing
acos, asin = layer.self_attn.rotary_emb(av1, seq_len=al + t + 1)
aqp = model.apply_rotary_pos_emb(aq, acos, asin, aposition_ids)
ak1p = model.apply_rotary_pos_emb(ak1, acos, asin, aposition_ids[:, :, :1])
if ad > 1:  # only apply to k2 if it is nonempty
      ak2p = model.apply_rotary_pos_emb(ak2, acos, asin, aposition_ids[:, :, 1:])
assert torch.allclose(qp, aqp[:1, :, 0, :2, :]), f"\nq: {qp}\naq: {aqp[:1, :, 0, :2, :]}"
assert torch.allclose(kp[:, :, :1, :], ak1p[:1, :, 0, :1, :]), f"\nk: {kp[:, :, :1, :]}\nak1:{ak1p[:1, :, 0, :1, :]}"
assert torch.allclose(kp[:, :, 1:, :], ak2p[:1, :, 0, :1, :]), f"\nk: {kp[:, :, 1:, :]}\nak2:{ak2p[:1, :, 0, :1, :]}"


In [28]:
# correct
a = torch.nn.functional.softmax(
    (torch.matmul(qp, kp.permute([0, 1, 3, 2])) + causal_mask1)
    / math.sqrt(model.embed_dim / model.num_heads),
    dim=-1,
)
# testing
aa = torch.nn.functional.softmax(
    torch.concatenate(
        [
            # attend to tokens in original string
            # (B, H, L, D, E) @ (B, H, 1, E, L) => (B, H, L, D, L)
            torch.matmul(aqp, ak1p.permute([0, 1, 3, 4, 2])) + acausal_mask1,
            # attend to thought tokens generated so far
            # (B, H, L, D, E) @ (B, H, L, E, T) => (B, H, L, D, T)
            torch.matmul(aqp, ak2p.permute([0, 1, 2, 4, 3])) + acausal_mask2,
        ],
        dim=-1,
    )
    / math.sqrt(model.embed_dim / model.num_heads),
    dim=-1,
)
aa1 = aa[:, :, :, :, :al]
aa2 = aa[:, :, :, :, al:]
assert torch.allclose(a[:, :, :, :1], aa1[:1, :, 0, :, :1])
assert torch.allclose(a[:, :, :, 1:], aa2[:1, :, 0, :, :]), f"\na:{a[:, :, :, 1:]}\naa2:{aa2[:1, :, 0, :, :]}"

In [33]:
# correct
attn_out = torch.matmul(a, v)
# testing
def safe_matmul(a, b):
    ashape = list(a.shape)
    ap = a.reshape(math.prod(ashape[:-2]), ashape[-2], ashape[-1])
    bshape = list(b.shape)
    bp = b.reshape(math.prod(bshape[:-2]), bshape[-2], bshape[-1])
    result = torch.matmul(ap, bp)
    out_shape = ashape[:-2] + [ashape[-2]] + [bshape[-1]]
    return result.reshape(*out_shape)

aattn_out = (
    # contributions of tokens in original string
    # (B, H, L, D, L) @ (B, H, 1, L, E) => (B, H, L, D, E)
    safe_matmul(aa1, torch.tile(av1, (1, 1, aa1.shape[2], 1, 1)))
    # contributions of thought tokens generated so far
    # (B, H, L, D, T) @ (B, H, L, T, E) => (B, H, L, D, E)
    + safe_matmul(aa2, av2)
)
assert torch.allclose(attn_out, aattn_out[:1, :, 0, :, :]), f"\nattn_out: {attn_out.shape},{attn_out}\naattn_out: {aattn_out[:1, :, 0, :, :].shape},{aattn_out[:1, :, 0, :, :]}"

AssertionError: 
attn_out: torch.Size([1, 16, 2, 64]),tensor([[[[ 0.0223,  0.0140,  0.0221,  ...,  0.0361, -0.0466,  0.0210],
          [ 0.0123,  0.0040,  0.0118,  ...,  0.0154, -0.0209,  0.0109]],

         [[-0.0042,  0.0049,  0.0140,  ..., -0.0007,  0.0082, -0.0055],
          [-0.0033,  0.0045,  0.0096,  ..., -0.0007,  0.0047, -0.0022]],

         [[-0.0107, -0.0145, -0.0123,  ..., -0.0063, -0.0124, -0.0059],
          [-0.0064, -0.0114, -0.0081,  ..., -0.0054, -0.0089, -0.0035]],

         ...,

         [[-0.0317,  0.0034, -0.0103,  ...,  0.0058,  0.0129, -0.0070],
          [-0.0280,  0.0036, -0.0079,  ...,  0.0047,  0.0114, -0.0060]],

         [[ 0.0179,  0.0233,  0.0038,  ..., -0.0051, -0.0023,  0.0128],
          [ 0.0164,  0.0205,  0.0033,  ..., -0.0039, -0.0020,  0.0116]],

         [[-0.0223,  0.0110, -0.0292,  ..., -0.0175,  0.1069, -0.0026],
          [-0.0146,  0.0084, -0.0177,  ..., -0.0134,  0.0718, -0.0036]]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)
aattn_out: torch.Size([1, 16, 2, 64]),tensor([[[[ 0.0223,  0.0140,  0.0221,  ...,  0.0361, -0.0466,  0.0210],
          [ 0.0123,  0.0040,  0.0118,  ...,  0.0154, -0.0209,  0.0109]],

         [[-0.0042,  0.0049,  0.0140,  ..., -0.0007,  0.0082, -0.0055],
          [-0.0033,  0.0045,  0.0096,  ..., -0.0007,  0.0047, -0.0022]],

         [[-0.0107, -0.0145, -0.0123,  ..., -0.0063, -0.0124, -0.0059],
          [-0.0064, -0.0114, -0.0081,  ..., -0.0054, -0.0089, -0.0035]],

         ...,

         [[-0.0317,  0.0034, -0.0103,  ...,  0.0058,  0.0129, -0.0070],
          [-0.0281,  0.0036, -0.0080,  ...,  0.0047,  0.0114, -0.0059]],

         [[ 0.0179,  0.0233,  0.0038,  ..., -0.0051, -0.0023,  0.0128],
          [ 0.0164,  0.0205,  0.0033,  ..., -0.0039, -0.0020,  0.0116]],

         [[-0.0223,  0.0110, -0.0292,  ..., -0.0175,  0.1069, -0.0026],
          [-0.0146,  0.0084, -0.0177,  ..., -0.0134,  0.0718, -0.0035]]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SliceBackward0>)

In [None]:
# correct
attn_out2 = layer.self_attn.o_proj(
    attn_out.permute([0, 2, 1, 3]).reshape(b, l, model.embed_dim)
)
# testing
aattn_out2 = model.bfloat_safe_apply(
    layer.self_attn.o_proj,
    aattn_out.permute([0, 2, 3, 1, 4]).reshape(ab, al, ad, model.embed_dim),
)

In [None]:
# correct
x4 = residual + attn_out2
# testing
ax4 = aresidual + aattn_out2

In [None]:
# correct
x5 = layer.post_attention_layernorm(x4)
# testing
ax5 = model.bfloat_safe_apply(layer.post_attention_layernorm, ax4)

In [None]:
# correct
x6 = layer.mlp(x5)
# testing
ax6 = model.bfloat_safe_apply(layer.mlp, ax5)

In [None]:
# correct
x7 = x4 + x6
# testing
ax7 = ax4 + ax6