In [1]:
import torch

In [2]:
# n_layers = 8
# batch_size = 4
# n_heads = 8
# sequence_len = 1024
# head_dim = 64
# # KV Cache
# kv_shape = (n_layers, 2, batch_size, n_heads, sequence_len, head_dim)
# kv_cache = torch.empty(kv_shape, dtype=torch.bfloat16)
# pos = 0

# # Insert Key and Value tensors
# key = torch.randn((batch_size, n_heads, sequence_len, head_dim))
# value = torch.randn((batch_size, n_heads, sequence_len, head_dim))
# layer_idx = 0  # Example layer index

# # Insert new keys and values into the cache and return full cache so far
# B, H, T_add, D = key.size()
# t0, t1 = pos, pos + T_add

# # Dynamically grow the cache if needed
# if t1 > sequence_len:
#     t_needed = t1 + 1024  # Grow by 1024 tokens
#     t_needed = t_needed +(t_needed + 1023) & ~1023  # Align to 1024
#     append_shape = (n_layers, 2, batch_size, n_heads, t_needed - sequence_len, head_dim)
#     append_cache = torch.empty(append_shape, dtype=torch.bfloat16)
#     kv_cache = torch.cat([kv_cache, append_cache], dim=4)
#     kv_shape = kv_cache.shape

# kv_cache[layer_idx, 0, :, :, t0:t1, :] = key
# kv_cache[layer_idx, 1, :, :, t0:t1, :] = value

# key_view = kv_cache[layer_idx, 0, :, :, :t1, :]
# value_view = kv_cache[layer_idx, 1, :, :, :t1, :]

# pos = t1
# # Prefill

# # Other kv cache
# other_kv_cache = torch.randn((n_layers, 2, 1, n_heads, sequence_len - 512, head_dim), dtype=torch.bfloat16)
# other_pos = pos-512

# other_n_layers, other_kv, other_batch_size, other_n_heads, other_sequence_len, other_head_dim = other_kv_cache.shape

# assert other_n_layers == n_layers, "Number of layers must match"
# assert other_n_heads == n_heads, "Number of heads must match"
# assert other_head_dim == head_dim, "Head dimension must match"
# # Batch size can be expanded
# assert other_batch_size == 1 or other_batch_size == batch_size, "Other batch size must be 1 or equal to current batch size"
# assert sequence_len >= other_sequence_len, "Other sequence length must be less than or equal to current sequence length"
# # Copy the data over
# kv_cache = torch.empty(kv_shape, dtype=torch.bfloat16)
# kv_cache[:, :, :, :, :other_pos, :] = other_kv_cache

# pos = other_pos

In [None]:
class KVCache:
    def __init__(self, n_layers, n_heads, head_dim, batch_size, initial_sequence_len=1024, growth_size=1024):
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.growth_size = growth_size
        self.sequence_len = initial_sequence_len
        self.pos = 0
    
    def get_pos(self):
        return self.pos

    def reset(self):
        self.pos = 0

    def insert_kv(self, layer_idx, key, value):
        B, H, T_add, D = key.size()
        t0, t1 = self.pos, self.pos + T_add

        if t1 > self.sequence_len:
            t_needed = t1 + self.growth_size
            t_needed = t_needed + (t_needed + 1023) & ~1023  # Align to 1024
            append_shape = (self.n_layers, 2, self.batch_size, self.n_heads, t_needed - self.sequence_len, self.head_dim)
            append_cache = torch.empty(append_shape, dtype=torch.bfloat16)
            self.kv_cache = torch.cat([self.kv_cache, append_cache], dim=4)
            self.sequence_len = self.kv_cache.shape[4]

        self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = key
        self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = value

        key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
        value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]

        self.pos = t1

        return key_view, value_view

    def prefill(self, other_kv_cache, other_pos):
        other_n_layers, other_kv, other_batch_size, other_n_heads, other_sequence_len, other_head_dim = other_kv_cache.shape

        assert other_n_layers == self.n_layers, "Number of layers must match"
        assert other_n_heads == self.n_heads, "Number of heads must match"
        assert other_head_dim == self.head_dim, "Head dimension must match"
        assert other_batch_size == 1 or other_batch_size == self.batch_size, "Other batch size must be 1 or equal to current batch size"
        assert self.sequence_len >= other_sequence_len, "Other sequence length must be less than or equal to current sequence length"

        self.kv_cache[:, :, :, :, :other_pos, :] = other_kv_cache
        self.pos = other_pos

In [4]:
import torch.nn.functional as F

def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    assert temperature >= 0.0, "Temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))
        values, ids = torch.topk(logits, top_k)
        values = values / temperature
        probs = F.softmax(values, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return ids.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)

In [None]:
# Inference
query = "print('Hello, World!')"


from minichat.gpt import GPTConfig, GPT
from minichat.tokenizer import RustBPETokenizer, get_tokenizer

tokenizer = get_tokenizer()

config = GPTConfig(
    vocab_size=tokenizer.get_vocab_size(),
    sequence_len=1024,
    n_layers=8,
    n_heads=8,
    emb_dim=512,
)
# Recreate the model to apply the fix
model = GPT(config)
model = model.to("cuda")

In [6]:
python_start = tokenizer.encode_special("<|python_start|>")
python_end = tokenizer.encode_special("<|python_end|>")
output_start = tokenizer.encode_special("<|output_start|>")
output_end = tokenizer.encode_special("<|output_end|>")
bos = tokenizer.get_bos_token_id()

tokens = [bos] + [python_start] + tokenizer.encode(query) + [python_end] + [output_start]

In [7]:
m = model.config

In [8]:
kv_model_kwargs = {
    "n_heads": m.n_heads,
    "head_dim": m.emb_dim // m.n_heads,
    "n_layers": m.n_layers,
}

In [9]:
kv_cache_prefill = KVCache(
    n_layers=m.n_layers,
    n_heads=m.n_heads,
    head_dim=m.emb_dim // m.n_heads,
    batch_size=1,
    initial_sequence_len = len(tokens)
)

In [10]:
ids = torch.tensor([tokens], dtype=torch.long, device=model.get_device())

In [11]:
ids.shape

torch.Size([1, 11])

In [13]:
model.forward(idx=ids)

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float