In [None]:
import os

# Run from the repository root so the `bytelatent` package is importable.
# The guard keeps this cell idempotent: unlike a bare `cd ..`, re-running it
# does not walk one more directory up.
if not os.path.isdir("bytelatent"):
    os.chdir("..")

In [1]:
import os
import logging

import torch

# from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import (
    ByteLatentTransformer, 
    patch_ids_from_lengths,
    get_blt_input,
    compute_hash_embeddings,
    decoder_patch_ids_from_lengths,
    cross_attn_mask
)
from bytelatent.model.utils import downsample
from bytelatent.distributed import (
    DistributedArgs,
    dist_max,
    dist_min,
    dist_sum,
    get_device_mesh,
    setup_torch_distributed,
)
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

logger = logging.getLogger()

W0614 00:20:49.846000 21176 Lib\site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# distributed_args = DistributedArgs()
# distributed_args.configure_world()
# if not torch.distributed.is_initialized():
#     setup_torch_distributed(distributed_args)

In [3]:
def get_max_length(input_tokens: list[list[int]] | None) -> int:
    # reduce max length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        max_length = 0
    else:
        max_length = max([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        max_length = int(dist_max(max_length))
    return max_length


def get_min_length(input_tokens: list[list[int]] | None) -> int:
    # reduce min length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        # TODO: Double check this change from int(1e9) is correct
        min_length = 0
    else:
        min_length = min([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        min_length = int(dist_min(min_length))
    return min_length


def get_generation_range(
    prompt_tokens: list[list[int]] | None, max_gen_len: int
) -> tuple[int, int]:
    """Compute the ``(start, end)`` positions of the generation loop.

    Generation starts at the shortest prompt's length and stops once the
    longest prompt has been extended by ``max_gen_len`` tokens. Both bounds
    are agreed across processes by the length helpers.
    """
    # Keep min before max: the helpers may issue collectives, whose order must
    # match on every process.
    start = get_min_length(prompt_tokens)
    end = get_max_length(prompt_tokens) + max_gen_len
    return start, end


def sample_top_k(probs, k):
    """Sample one token id per row from the ``k`` most likely entries.

    Probabilities strictly below the k-th largest value in each row are
    zeroed, and the rest are renormalized. Ties with the k-th value are kept.

    Args:
        probs: Probability tensor over the vocabulary on its last dim
            (1-D or 2-D, as accepted by ``torch.multinomial``). Not modified.
        k: Number of top entries to keep.

    Returns:
        Sampled indices with shape ``probs.shape[:-1] + (1,)``.
    """
    topk_value, _ = torch.topk(probs, k)
    # `[..., -1:]` keeps the trailing dim for broadcasting and works for 1-D too.
    kth_value = topk_value[..., -1:]
    # Out-of-place: the original zeroed entries of the caller's tensor.
    filtered = probs.masked_fill(probs < kth_value, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return torch.multinomial(filtered, num_samples=1)


def sample_top_p(probs, p):
    """Nucleus sampling: draw one token id per row from the top-p mass.

    Tokens are ranked by probability. A token stays in the nucleus while the
    probability mass ranked strictly before it is at most ``p``. The kept mass
    is renormalized and sampled, and the sampled rank is mapped back to the
    original vocabulary index.

    Returns:
        Sampled indices with shape ``(batch, 1)`` for 2-D ``probs``.
    """
    ranked, order = probs.sort(dim=-1, descending=True)
    mass_before = ranked.cumsum(dim=-1) - ranked
    ranked = ranked.masked_fill(mass_before > p, 0.0)
    ranked = ranked / ranked.sum(dim=-1, keepdim=True)
    picked_rank = torch.multinomial(ranked, num_samples=1)
    return order.gather(-1, picked_rank)

In [None]:
# python -m bytelatent.train config=bytelatent/configs/debug.yaml
# python -m bytelatent.checkpoint consolidate train_checkpoints/0000018000

In [4]:
model_name = "blt_1b"
# Consolidated checkpoint from the debug training run (see the commands above).
# To load the released weights instead, use:
#     checkpoint_dir = f"hf-weights/{model_name}"
checkpoint_dir = "train_checkpoints/0000018000/consolidated"

# Report the directory actually loaded; `model_name` is not what gets loaded
# unless the hf-weights path above is selected.
print(f"Loading BLT model from: {checkpoint_dir}")
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(checkpoint_dir)
assert isinstance(model, ByteLatentTransformer)
assert isinstance(tokenizer, BltTokenizer)

# Deep-copy the training patcher config so the overrides below do not mutate
# train_cfg.
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = True

print("Loading entropy model and patcher")
patcher_args.entropy_model_checkpoint_dir = "hf-weights/entropy_model"
patcher = patcher_args.build()

Loading BLT model: blt_1b




Loading entropy model and patcher


In [25]:
# Alternative (truncated) prompt, kept for reference:
# prompt = "### Instruction:\nGive thre"
prompt = "### Instruction:\nCreate a sentence using the following words: \"apple, banana, pencil.\""

In [26]:
# Sanity check: patch a single (batch of 1) encoded string and show the result.
patcher.patch(torch.tensor(tokenizer.encode("hello world l ksjhlkjshdlkjhsdlkjsdh")).unsqueeze(0))

(tensor([[ 1,  6,  6,  2, 23]]), None)

In [27]:
# prompts = [prompt]

# max_prompt_len: int = 256
# max_gen_len: int = 5
# use_sampling: bool = False
# temp: float = 1.0
# top_k: int = 0
# top_p: float = 0.0
# remove_prompts: bool = True

# model.eval()

# prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
# n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
# total_truncated_prompts = dist_sum(n_truncated_prompts)

# # Truncation
# prompt_tokens = [
#     t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
#     for t in prompt_tokens
# ]

# if total_truncated_prompts > 0:
#     logger.info(
#         f"There are {total_truncated_prompts} prompts that are truncated on the left, "
#         f"length greater than max_prompt_len = {max_prompt_len}, "
#         f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
#     )

# start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
# batch_size = len(prompt_tokens)
# tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()

# # Copy inputs to tensor for generated tokens
# for i, row_tokens in enumerate(prompt_tokens):
#     tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
# input_text_mask = tokens != tokenizer.pad_id

# for i, curr_pos in enumerate(range(start_pos, end_pos)):
#     current_tokens = tokens[:, :curr_pos]
#     patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
#     logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]

#     if use_sampling:
#         probs = torch.softmax(logits / temp, dim=-1)
#         if top_p > 0.0:
#             next_token = sample_top_p(probs, top_p)
#         elif top_k > 0:
#             next_token = sample_top_k(probs, top_k)
#         else:
#             next_token = torch.multinomial(probs, num_samples=1)
#     else:
#         next_token = torch.argmax(logits, dim=-1)

#     next_token = torch.where(
#         input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
#     )
#     tokens[:, curr_pos] = next_token

# if remove_prompts:
#     generated_tokens = [
#         t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
#         for i, t in enumerate(tokens)
#     ]
# else:
#     generated_tokens = [
#         t[: len(prompt_tokens[i]) + max_gen_len].tolist()
#         for i, t in enumerate(tokens)
#     ]

In [None]:
# Greedy/sampled generation with the BLT forward pass inlined (instead of
# `model(current_tokens, patch_lengths=...)`) so intermediates can be printed.
# Runs at notebook top level on purpose: later cells read `generated_tokens`,
# `cross_attn_mask_enc` and `cross_attn_mask_dec` from here.
torch.cuda.empty_cache()
prompts = [prompt]

max_prompt_len: int = 256
max_gen_len: int = 10
use_sampling: bool = False
temp: float = 1.0
top_k: int = 0
top_p: float = 0.0
remove_prompts: bool = True

model.eval()

prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
n_truncated_prompts = sum(max_prompt_len < len(t) for t in prompt_tokens)
# Only reduce across processes when a process group exists, matching
# get_max_length/get_min_length (the distributed setup cell is commented out).
if torch.distributed.is_initialized():
    total_truncated_prompts = dist_sum(n_truncated_prompts)
else:
    total_truncated_prompts = n_truncated_prompts

# Truncation: keep the last max_prompt_len tokens of over-long prompts.
prompt_tokens = [
    t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
    for t in prompt_tokens
]

if total_truncated_prompts > 0:
    logger.info(
        f"There are {total_truncated_prompts} prompts that are truncated on the left, "
        f"length greater than max_prompt_len = {max_prompt_len}, "
        f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
    )

start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
batch_size = len(prompt_tokens)
tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()

# Copy inputs to tensor for generated tokens
for i, row_tokens in enumerate(prompt_tokens):
    tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
# True where a prompt token is present; those positions are never overwritten.
input_text_mask = tokens != tokenizer.pad_id

# Inference only: without this, every step builds and keeps an autograd graph.
with torch.no_grad():
    for curr_pos in range(start_pos, end_pos):
        current_tokens = tokens[:, :curr_pos]
        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=False)
        # logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]

        ################################
        #### START MODEL PROCESSING ####
        ################################
        ngram_ids = None
        bs, N = current_tokens.shape  # Batch size and sequence length

        print(f"batch size: {bs}, sequence length: {N}")

        # Get megabyte inputs
        nb_boe = int(0 if model.patching_mode != "" else model.patch_size - 1)
        print(f"nb_boe={nb_boe}")
        local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
            tokens=current_tokens,
            enforce_patch_size_multiple=False,
            nb_boe=nb_boe,
            patch_size=model.patch_size,
            boe_id=model.boe_id,
        )
        print(f"tokens: {current_tokens}")
        print(f"local_encoder_tokens: {local_encoder_tokens}")
        print(f"local_decoder_tokens: {local_decoder_tokens}")

        # Patching
        if nb_boe > 0:
            patch_lengths[:, 0] += nb_boe

        assert torch.min(patch_lengths) >= 0

        # Generate patch IDs from patch_lengths
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        print(f"Patch IDs: {patch_ids}, Patch lengths: {patch_lengths}")
        assert torch.max(patch_ids) + 1 <= torch.max(
            (patch_lengths != 0).sum(dim=-1)
        ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"

        cross_attn_mask_enc = None
        # Cross-attention encoder
        if model.cross_attn_encoder:
            cross_attn_mask_enc = cross_attn_mask(
                patch_ids,
                patch_lengths,
                N,
                patches_as_queries=True,
                cross_attn_k=model.cross_attn_k,
                window=model.cross_attn_window_encoder,
                block_mask=model.cross_attn_use_flex_attention,
            )
            print(f"Cross attention mask encoder shape: {cross_attn_mask_enc.shape}")
            print(f"Cross attention mask encoder: {cross_attn_mask_enc}")
            # print(f"Cross attention mask encoder: {cross_attn_mask_enc.to_dense()}")

        # Hashing and embedding
        print(
            f"encoder_hash_tok_embedding={model.encoder_hash_tok_embedding}",
            f"encoder_hash_byte_group_nb_functions={model.encoder_hash_byte_group_nb_functions}",
            f"encoder_hash_byte_group_size={model.encoder_hash_byte_group_size}",
            f"encoder_hash_byte_group_vocab={model.encoder_hash_byte_group_vocab}",
        )
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=model.local_encoder,
            encoder_hash_tok_embedding=model.encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=model.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=model.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=model.encoder_hash_byte_group_vocab,
        )
        # `if tensor:` raises for multi-element tensors; test for presence instead.
        if local_encoder_embeds is not None:
            print(f"local_encoder_embeds.shape={local_encoder_embeds.shape}")
            print(f"local_encoder_embeds={local_encoder_embeds}")

        # N-gram table embeddings
        if model.encoder_ngram_embedding is not None:
            assert ngram_ids is not None, "ngram_ids must be provided"
            if local_encoder_embeds is None:
                local_encoder_embeds = model.local_encoder.tok_embeddings(
                    local_encoder_tokens
                )
            assert len(ngram_ids) == len(
                model.encoder_ngram_embedding
            ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(model.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
            for ngram_idx in range(ngram_ids.shape[0]):
                ngram_embedding = model.encoder_ngram_embedding[ngram_idx]
                ngram_embeds = ngram_embedding(ngram_ids[ngram_idx])
                assert (
                    local_encoder_embeds.shape == ngram_embeds.shape
                ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
                local_encoder_embeds = local_encoder_embeds + ngram_embeds

        # Local encoder
        (h_encoder, h_cross), cache_encoder = model.local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )
        print(f"Encoder output shape: {h_encoder.shape}, Cross output shape: {h_cross.shape}")
        print(f"Encoder `h_encoder` output: {h_encoder[0]}")

        # Downsampling
        if not model.cross_attn_encoder:
            assert (
                patch_ids.shape[1] == h_encoder.shape[1]
            ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
            h = downsample(
                h_encoder,
                patch_lengths.shape[1],
                patch_lengths,
                patch_ids,
                downsampling_by_pooling=model.downsampling_by_pooling,
                patch_size=model.patch_size,
            )
        else:
            # Reshape h_cross
            h = h_cross.view(bs, patch_lengths.shape[1], -1)
        print(f"Global transformer input shape: {h.shape}, Global transformer input: {h}")

        # Global transformer: one token per patch, BOE by default and EOS for
        # patches that contain an EOS byte.
        global_tokens = current_tokens.new_full((h.shape[0], h.shape[1]), model.boe_id)
        rows, cols = torch.where(local_encoder_tokens == model.eos_id)
        eos_patch_ids = patch_ids[rows, cols]
        global_tokens[rows, eos_patch_ids] = model.eos_id

        h, _ = model.global_transformer(
            embeds=h,
            tokens=global_tokens,
        )
        print(f"Global transformer output shape: {h.shape}, Global transformer output: {h}")

        # Unpatching
        dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
        print(f"Decoder embeddings `dec_embeds` shape: {dec_embeds.shape}, Decoder embeddings: {dec_embeds[0]}")

        # Generate decoder patch IDs
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
        )
        print(f"Decoder patch IDs shape: {decoder_patch_ids.shape}, Decoder patch IDs: {decoder_patch_ids}")
        assert (
            torch.max(decoder_patch_ids) + 1 <= h.shape[1]
        ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
        assert (
            decoder_patch_ids.shape[1] == dec_embeds.shape[1]
        ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"

        # Cross-attention decoder
        if not model.cross_attn_decoder:
            h = torch.gather(
                h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
            )
            cross_attn_mask_dec = None
            assert local_decoder_tokens.shape == h.shape[:-1]
        else:
            cross_attn_mask_dec = cross_attn_mask(
                decoder_patch_ids,
                patch_lengths,
                N,
                patches_as_queries=False,
                cross_attn_k=model.cross_attn_k,
                window=model.cross_attn_window_decoder,
                block_mask=model.cross_attn_use_flex_attention,
            )
            print(f"Cross attention mask decoder shape: {cross_attn_mask_dec.shape}")
            print(f"Cross attention mask decoder: {cross_attn_mask_dec}")
            # print(f"Cross attention mask decoder: {cross_attn_mask_dec.to_dense()}")

        # Local decoder
        logits, _ = model.local_decoder(
            embeds=dec_embeds,
            patch_embeds=h,
            tokens=local_decoder_tokens,
            cross_mask=cross_attn_mask_dec,
        )
        logits = logits[:, -1]
        print(f"Decoder logits shape: {logits.shape}, Decoder logits: {logits}")

        ##############################
        #### END MODEL PROCESSING ####
        ##############################

        if use_sampling:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = sample_top_p(probs, top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, top_k)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            # Samplers return (bs, 1); flatten to (bs,) like argmax so the
            # torch.where below does not broadcast to (bs, bs).
            next_token = next_token.squeeze(-1)
        else:
            next_token = torch.argmax(logits, dim=-1)

        # Keep prompt tokens where present; otherwise write the new token.
        next_token = torch.where(
            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
        )
        tokens[:, curr_pos] = next_token

if remove_prompts:
    generated_tokens = [
        t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
        for i, t in enumerate(tokens)
    ]
else:
    generated_tokens = [
        t[: len(prompt_tokens[i]) + max_gen_len].tolist()
        for i, t in enumerate(tokens)
    ]

batch size: 1, sequence length: 87
nb_boe=0
tokens: tensor([[  1,  39,  39,  39,  36,  77, 114, 119, 120, 118, 121, 103, 120, 109,
         115, 114,  62,  14,  71, 118, 105, 101, 120, 105,  36, 101,  36, 119,
         105, 114, 120, 105, 114, 103, 105,  36, 121, 119, 109, 114, 107,  36,
         120, 108, 105,  36, 106, 115, 112, 112, 115, 123, 109, 114, 107,  36,
         123, 115, 118, 104, 119,  62,  36,  38, 101, 116, 116, 112, 105,  48,
          36, 102, 101, 114, 101, 114, 101,  48,  36, 116, 105, 114, 103, 109,
         112,  50,  38]], device='cuda:0')
local_encoder_tokens: tensor([[  1,  39,  39,  39,  36,  77, 114, 119, 120, 118, 121, 103, 120, 109,
         115, 114,  62,  14,  71, 118, 105, 101, 120, 105,  36, 101,  36, 119,
         105, 114, 120, 105, 114, 103, 105,  36, 121, 119, 109, 114, 107,  36,
         120, 108, 105,  36, 106, 115, 112, 112, 115, 123, 109, 114, 107,  36,
         123, 115, 118, 104, 119,  62,  36,  38, 101, 116, 116, 112, 105,  48,
          36, 

In [31]:
# Decode each generated sequence and show it next to its prompt.
text_outputs = list(map(tokenizer.decode, generated_tokens))
for prompt_text, completion in zip(prompts, text_outputs):
    # The trailing "\n" reproduces the blank separator line between samples.
    print(f'Prompt: "{prompt_text}" Completion: "{completion}"\n')

Prompt: "### Instruction:
Create a sentence using the following words: "apple, banana, pencil."" Completion: "

### Response:
Apple, banana, pencil, pencil, pencil, pencil, pencil, pencil, and pencil.### Inst"



In [47]:
# Row-by-row dump of the encoder cross-attention mask (first batch, first head).
for row_idx, row in enumerate(cross_attn_mask_enc[0, 0]):
    print(row_idx, [float(v) for v in row.tolist()])

0 [0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
1 [0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
2 [-inf, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
3 [-inf, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -i

In [48]:
# Row-by-row dump of the decoder cross-attention mask (first batch, first head).
for row_idx, row in enumerate(cross_attn_mask_dec[0, 0]):
    print(row_idx, [float(v) for v in row.tolist()])

0 [0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
1 [-inf, -inf, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
2 [-inf, -inf, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
3 [-inf, -inf, 0.0, 0.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
4 [-inf, -inf, -inf, -inf, 0.0, 0.0,

In [None]:
# prompts = [prompt]
# outputs = generate_nocache(
#     prompts, model=model, tokenizer=tokenizer, patcher=patcher, max_gen_len=16
# )
# text_outputs = [tokenizer.decode(t) for t in outputs]
# for p, t in zip(prompts, text_outputs):
#     print(f'Prompt: "{p}" Completion: "{t}"')
#     print()

batch size: 1, sequence length: 45
Patch IDs: tensor([[ 0,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  6,  6,  6,  6,  6,  7,
          8,  9,  9, 10, 10, 10, 10, 10, 10, 11, 12, 12, 12, 12, 13, 13, 13, 13,
         14, 15, 16, 16, 16, 17, 18, 18, 19]], device='cuda:0'), Patch lengths: tensor([[1, 1, 3, 3, 3, 1, 5, 1, 1, 2, 6, 1, 4, 4, 1, 1, 3, 1, 2, 1, 1]],
       device='cuda:0')
Cross attention mask encoder shape: (1, 1, 42, 45)
Cross attention mask encoder: BlockMask(shape=(1, 1, 42, 45), sparsity=-766.88%, 
(0, 0)
██
)
local_encoder_embeds.shape=torch.Size([1, 45, 1024])
local_encoder_embeds=tensor([[[ 0.1084,  0.0052, -0.0938,  ..., -0.0598, -0.1279,  0.0181],
         [-0.0135,  0.0476,  0.0266,  ..., -0.0229, -0.0381, -0.0337],
         [ 0.0349,  0.0091, -0.0310,  ...,  0.0245, -0.0337,  0.0134],
         ...,
         [-0.0157,  0.0247, -0.0125,  ..., -0.0310,  0.0352, -0.0135],
         [-0.0713,  0.0004,  0.0305,  ...,  0.0398,  0.0195,  0.0193],
         [ 0.0371, -0.0303

In [None]:
# ---------------------------------------------------------------------------
# NotImplementedError                       Traceback (most recent call last)
# Cell In[5], line 2
#       1 prompts = [prompt]
# ----> 2 outputs = generate_nocache(
#       3     prompts, model=model, tokenizer=tokenizer, patcher=patcher
#       4 )
#       5 text_outputs = [tokenizer.decode(t) for t in outputs]
#       6 for p, t in zip(prompts, text_outputs):

# File c:\Users\leoni\Documents\projects\blt\bytelatent\generate_blt.py:131, in generate_nocache(prompts, model, tokenizer, patcher, max_prompt_len, max_gen_len, use_sampling, temp, top_k, top_p, remove_prompts)
#     129 for i, curr_pos in enumerate(range(start_pos, end_pos)):
#     130     current_tokens = tokens[:, :curr_pos]
# --> 131     patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
#     132     logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]
#     134     if use_sampling:

# File c:\Users\leoni\Documents\projects\blt\bytelatent\data\patcher.py:565, in Patcher.patch(self, tokens, include_next_token, preds, entropies, threshold)
#     563 else:
#     564     start_entropies = time.time()
# --> 565     scores, _ = calculate_entropies(
#     566         tokens,
#     567         self.entropy_model,
#     568         self.patching_batch_size,
#     569         self.device,
#     570     )
#     571 if self.log_time:
#     572     self.log["calculate_entropies"] += time.time() - s

# File c:\Users\leoni\Documents\projects\blt\bytelatent\data\patcher.py:96, in calculate_entropies(tokens, entropy_model, patching_batch_size, device, enable_grad)
#      94     split = split.to(device)
#      95 # assert torch.all(split >= 0) and torch.all(split < 260)
# ---> 96 pred = entropy_model(split)
#      97 pred = pred.reshape(-1, pred.shape[-1])[
#      98     : split.numel() - pad_size, :
#      99 ]  # [batch_size * seq_len, vocab]
#     100 preds.append(pred)

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
#    1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
#    1735 else:
# -> 1736     return self._call_impl(*args, **kwargs)

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
#    1742 # If we don't have any hooks, we want to skip the rest of the logic in
#    1743 # this function, and just call forward.
#    1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
#    1745         or _global_backward_pre_hooks or _global_backward_hooks
#    1746         or _global_forward_hooks or _global_forward_pre_hooks):
# -> 1747     return forward_call(*args, **kwargs)
#    1749 result = None
#    1750 called_always_called_hooks = set()

# File c:\Users\leoni\Documents\projects\blt\bytelatent\transformer.py:131, in LMTransformer.forward(self, token_values, target, tok_idx, mask, attn_impl)
#     117 h = self.tok_embeddings(token_values)
#     119 mask = (
#     120     mask
#     121     if mask is not None
#    (...)    129     )
#     130 )
# --> 131 h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
#     133 logits = self.output(self.norm(h))
#     134 if target is not None:

# File c:\Users\leoni\Documents\projects\blt\bytelatent\base_transformer.py:617, in BaseTransformer.forward(self, h, tok_idx, mask, attn_impl)
#     614 freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
#     616 for i, layer in enumerate(self.layers):
# --> 617     h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
#     618 return h

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
#    1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
#    1735 else:
# -> 1736     return self._call_impl(*args, **kwargs)

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
#    1742 # If we don't have any hooks, we want to skip the rest of the logic in
#    1743 # this function, and just call forward.
#    1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
#    1745         or _global_backward_pre_hooks or _global_backward_hooks
#    1746         or _global_forward_hooks or _global_forward_pre_hooks):
# -> 1747     return forward_call(*args, **kwargs)
#    1749 result = None
#    1750 called_always_called_hooks = set()

# File c:\Users\leoni\Documents\projects\blt\bytelatent\base_transformer.py:556, in TransformerBlock.forward(self, x, freq_cis, tok_idx, mask, attn_impl)
#     548 def forward(
#     549     self,
#     550     x: torch.Tensor,
#    (...)    554     attn_impl: str = "sdpa",
#     555 ) -> torch.Tensor:
# --> 556     attn_out = self.attention(
#     557         self.attention_norm(x),
#     558         freq_cis,
#     559         tok_idx=tok_idx,
#     560         mask=mask,
#     561         attn_impl=attn_impl,
#     562     )
#     563     h = x + attn_out
#     564     h_norm = self.ffn_norm(h)

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
#    1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
#    1735 else:
# -> 1736     return self._call_impl(*args, **kwargs)

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
#    1742 # If we don't have any hooks, we want to skip the rest of the logic in
#    1743 # this function, and just call forward.
#    1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
#    1745         or _global_backward_pre_hooks or _global_backward_hooks
#    1746         or _global_forward_hooks or _global_forward_pre_hooks):
# -> 1747     return forward_call(*args, **kwargs)
#    1749 result = None
#    1750 called_always_called_hooks = set()

# File c:\Users\leoni\Documents\projects\blt\bytelatent\base_transformer.py:401, in Attention.forward(self, x, freq_cis, tok_idx, mask, attn_impl)
#     399 query_shape = xq.shape
#     400 xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
# --> 401 output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
#     402 output = output.view(query_shape)
#     403 # This uses B S H D instead of B H S D of pytorch

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\xformers\ops\fmha\__init__.py:306, in memory_efficient_attention(query, key, value, attn_bias, p, scale, op, output_dtype)
#     194 def memory_efficient_attention(
#     195     query: torch.Tensor,
#     196     key: torch.Tensor,
#    (...)    203     output_dtype: Optional[torch.dtype] = None,
#     204 ) -> torch.Tensor:
#     205     """Implements the memory-efficient attention mechanism following
#     206     `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
#     207 
#    (...)    304     :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
#     305     """
# --> 306     return _memory_efficient_attention(
#     307         Inputs(
#     308             query=query,
#     309             key=key,
#     310             value=value,
#     311             p=p,
#     312             attn_bias=attn_bias,
#     313             scale=scale,
#     314             output_dtype=output_dtype,
#     315         ),
#     316         op=op,
#     317     )

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\xformers\ops\fmha\__init__.py:467, in _memory_efficient_attention(inp, op)
#     462 def _memory_efficient_attention(
#     463     inp: Inputs, op: Optional[AttentionOp] = None
#     464 ) -> torch.Tensor:
#     465     # fast-path that doesn't require computing the logsumexp for backward computation
#     466     if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
# --> 467         return _memory_efficient_attention_forward(
#     468             inp, op=op[0] if op is not None else None
#     469         )
#     471     output_shape = inp.normalize_bmhk()
#     473     op_fw = _serialize_op(op[0] if op is not None else None)

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\xformers\ops\fmha\__init__.py:486, in _memory_efficient_attention_forward(inp, op)
#     484 output_shape = inp.normalize_bmhk()
#     485 if op is None:
# --> 486     op = _dispatch_fw(inp, False)
#     487 else:
#     488     _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\xformers\ops\fmha\dispatch.py:135, in _dispatch_fw(inp, needs_gradient)
#     126 def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
#     127     """Computes the best operator for forward
#     128 
#     129     Raises:
#    (...)    133         AttentionOp: The best operator for the configuration
#     134     """
# --> 135     return _run_priority_list(
#     136         "memory_efficient_attention_forward",
#     137         _dispatch_fw_priority_list(inp, needs_gradient),
#     138         inp,
#     139     )

# File c:\Users\leoni\Documents\projects\blt\venv\Lib\site-packages\xformers\ops\fmha\dispatch.py:76, in _run_priority_list(name, priority_list, inp, extra_op_reasons)
#      74     for op, not_supported in extra_op_reasons:
#      75         msg += "\n" + _format_not_supported_reasons(op, not_supported)
# ---> 76 raise NotImplementedError(msg)

# NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
#      query       : shape=(1, 8192, 12, 64) (torch.bfloat16)
#      key         : shape=(1, 8192, 12, 64) (torch.bfloat16)
#      value       : shape=(1, 8192, 12, 64) (torch.bfloat16)
#      attn_bias   : <class 'xformers.ops.fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask'>
#      p           : 0.0
# `fa2F@v2.5.7-pt` is not supported because:
#     xFormers wasn't build with CUDA support
# `cutlassF-pt` is not supported because:
#     xFormers wasn't build with CUDA support