In [1]:
cd ..

c:\Users\leoni\Documents\projects\visionBLT


In [2]:
import torch

from bytelatent.args import (
    TrainArgs, 
    DataloaderArgs, OptimArgs, ByteLatentTransformerArgs, DistributedArgs, TokenizerArgs, PatcherArgs,
    ProfilerArgs, CheckpointArgs, LoggingArgs
)
from bytelatent.checkpoint import SaveEvery
from bytelatent.vision_model.local_models import LocalModelArgs, LocalEncoder
from bytelatent.model.blt import (
    get_encoder_dim_token_emb, get_encoder_dim_patch_emb, compute_hash_embeddings, cross_attn_mask, create_patch_mask_from_ids
)

W0805 21:38:27.163000 5140 Lib\site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_args = TrainArgs(
    name="debug",
    dump_dir="/tmp/",
    seed=777,
    steps=100_000,
    optim=OptimArgs(
        lr=4e-04,
        warmup=500,
        lr_min_ratio=0.1,
        clip=10.0
    ),
    distributed=DistributedArgs(
        fsdp_type="full_shard",
        model_dtype="bf16",
        matmul_allow_tf32=False,
        selective_activation_checkpointing=False,
        tp_size=1
    ),
    model=ByteLatentTransformerArgs(
        n_heads=8,
        dim=512,
        vocab_size=260,
        dim_token=256,
        patch_size=6,
        patching_mode="space",
        tie_local_encoder_decoder_logits=False,
        patch_in_forward=False,
        max_encoder_seq_length=3072,
        pad_to_max_length=True,
        patching_threshold=3.1439168453216553,
        encoder_hash_byte_group_size=None,
        encoder_hash_byte_group_vocab=50002,
        encoder_hash_byte_group_nb_functions=3,
        encoder_enable_byte_ngrams=False,
        cross_attn_encoder=True, # assuming cross_attention is true
        cross_attn_decoder=True, # assuming cross_attention is true
        cross_attn_window_encoder=None,
        cross_attn_window_decoder=None,
        dim_local_encoder=256,
        dim_local_decoder=256,
        cross_attn_k=2,
        cross_attn_nheads=4,
        cross_attn_all_layers_decoder=True,
        cross_attn_all_layers_encoder=True,
        cross_attn_use_flex_attention=False,
        cross_attn_init_by_pooling=True,
        log_patch_lengths=True,
        non_linearity="swiglu",
        use_rope=True,
        recompute_fc1_out=False,
        recompute_fc3_out=False,
        recompute_attn=False,
        custom_bwd=False,
        layer_ckpt="none",
        use_local_encoder_transformer=True,
        init_use_gaussian=True,
        init_use_depth="current",
        attn_impl="xformers",
        attn_bias_type="block_causal",
        alpha_depth="disabled",
        max_length=256,
        local_attention_window_len=512,
        downsampling_by_pooling="max",
    ),
    data=DataloaderArgs(
        root_dir=".",
        sources={"alpaca-cleaned": 1.0},
        dataset_files=["alpaca-cleaned/1.json"],
        batch_size=2,
        prefetch_size=64,
        seq_len=1024,
        max_encoder_seq_length=3072,
        load_async=False,
        preprocess_dir="preprocess_dir",
        patcher_args=PatcherArgs(patching_mode="space"),
        tokenizer_args=TokenizerArgs(name="blt"),
    ),
    profiling=ProfilerArgs(run=False),
    checkpoint=CheckpointArgs(
        path="train_checkpoints",
        dump=SaveEvery(
            every=500,
            keep=3
        ),
        eval=SaveEvery(
            every=1000,
            keep=-1
        )
    ),
    logging=LoggingArgs(freq=10),
    eval_on_gpus=1,
    # env=None
)

args = train_args.model

In [4]:
local_encoder_args = LocalModelArgs(
    # Updated args
    dim=args.dim_local_encoder,
    n_layers=args.n_layers_local_encoder,
    n_heads=args.n_heads_local_encoder,
    dim_token_emb=get_encoder_dim_token_emb(args),
    dim_patch_emb=get_encoder_dim_patch_emb(args),
    cross_attn_encoder=args.cross_attn_encoder,
    cross_attn_decoder=False,
    cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
    cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
    # Defaults
    head_dim=args.head_dim,
    max_seqlen=args.max_encoder_seq_length,
    dropout=args.dropout,
    vocab_size=args.vocab_size + args.pm_size,
    norm_eps=args.norm_eps,
    patch_size=args.patch_size,
    sliding_window=args.local_attention_window_len,
    use_rope=args.use_rope,
    rope_theta=args.rope_theta,
    rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
    init_base_std=args.init_base_std,
    init_std_factor=args.init_std_factor,
    n_kv_heads=args.n_kv_heads,
    attn_impl=args.attn_impl,
    attn_bias_type="local_block_causal",
    multiple_of=args.multiple_of,
    ffn_dim_multiplier=args.ffn_dim_multiplier,
    patching_mode=args.patching_mode,
    use_local_encoder_transformer=args.use_local_encoder_transformer,
    downsampling_by_pooling=args.downsampling_by_pooling,
    encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
    cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
    cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
    cross_attn_nheads=args.cross_attn_nheads,
    eos_id=args.eos_id,
)

In [5]:
local_encoder = LocalEncoder(local_encoder_args).to("cuda")

In [6]:
local_encoder

LocalEncoder(
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=256, out_features=256, bias=False)
        (wk): Linear(in_features=256, out_features=256, bias=False)
        (wv): Linear(in_features=256, out_features=256, bias=False)
        (wo): Linear(in_features=256, out_features=256, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=256, out_features=768, bias=False)
        (w3): Linear(in_features=256, out_features=768, bias=False)
        (w2): Linear(in_features=768, out_features=256, bias=False)
      )
      (attention_norm): RMSNorm((256,), eps=1e-05, elementwise_affine=True)
      (ffn_norm): RMSNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
  (rope): RotaryEmbedding()
  (patch_embedding_projection): Linear(in_features=256, out_features=512, bias=False)
  (tok_embeddings): Embedding(260, 256)
  (cross_attn_layers): ModuleList(
    (0-7): 8 x CrossAttenti

In [7]:
tokens = torch.tensor([[1, 39, 39, 39, 39, 36]]).to("cuda")
bs, N = tokens.shape  # Batch size and sequence length

local_encoder_tokens = tokens
patch_ids = torch.tensor([[0, 1, 1, 2, 2, 3]]).to("cuda")
patch_lengths = torch.tensor([[1, 2, 2, 1]]).to("cuda")
local_encoder_embeds = None

cross_attn_mask_enc = cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=True,
    cross_attn_k=args.cross_attn_k,
    window=args.cross_attn_window_encoder,
    block_mask=args.cross_attn_use_flex_attention,
)

(h_encoder, h_cross), cache_encoder = local_encoder(
    tokens=local_encoder_tokens,
    embeds=local_encoder_embeds,
    patch_embeds=None,
    cross_mask=cross_attn_mask_enc,
    num_patches=patch_lengths.shape[1],
    patch_ids=patch_ids,
)

In [8]:
patch_ids, patch_lengths.shape[1]

(tensor([[0, 1, 1, 2, 2, 3]], device='cuda:0'), 4)

In [9]:
patch_lengths

tensor([[1, 2, 2, 1]], device='cuda:0')

In [28]:
bs, seq_len = patch_ids.shape
num_patches = patch_lengths.shape[1]

q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
kv_ids = (
    torch.arange(num_patches, device=patch_ids.device)
    .unsqueeze(0)
    .unsqueeze(0)
    .expand(bs, seq_len, num_patches)
)
# q_ids, kv_ids

In [11]:
q_ids==kv_ids

tensor([[[ True, False, False, False],
         [False,  True, False, False],
         [False,  True, False, False],
         [False, False,  True, False],
         [False, False,  True, False],
         [False, False, False,  True]]], device='cuda:0')

In [12]:
cross_mask = create_patch_mask_from_ids(
    patch_ids,
    patch_lengths.shape[1],
    window=None,
    patches_as_queries=True,
).repeat_interleave(args.cross_attn_k, dim=1 if True else -1)
cross_mask

tensor([[[ True, False, False, False, False, False],
         [ True, False, False, False, False, False],
         [False,  True,  True, False, False, False],
         [False,  True,  True, False, False, False],
         [False, False, False,  True,  True, False],
         [False, False, False,  True,  True, False],
         [False, False, False, False, False,  True],
         [False, False, False, False, False,  True]]], device='cuda:0')

In [13]:
h_encoder.shape, h_cross.shape

(torch.Size([1, 6, 256]), torch.Size([1, 8, 256]))

In [14]:
cross_attn_mask_enc

tensor([[[[0., -inf, -inf, -inf, -inf, -inf],
          [0., -inf, -inf, -inf, -inf, -inf],
          [-inf, 0., 0., -inf, -inf, -inf],
          [-inf, 0., 0., -inf, -inf, -inf],
          [-inf, -inf, -inf, 0., 0., -inf],
          [-inf, -inf, -inf, 0., 0., -inf],
          [-inf, -inf, -inf, -inf, -inf, 0.],
          [-inf, -inf, -inf, -inf, -inf, 0.]]]], device='cuda:0')

In [15]:
for i, line in enumerate(cross_attn_mask_enc[0, 0]):
    print(i, [float(x) for x in line])

0 [0.0, -inf, -inf, -inf, -inf, -inf]
1 [0.0, -inf, -inf, -inf, -inf, -inf]
2 [-inf, 0.0, 0.0, -inf, -inf, -inf]
3 [-inf, 0.0, 0.0, -inf, -inf, -inf]
4 [-inf, -inf, -inf, 0.0, 0.0, -inf]
5 [-inf, -inf, -inf, 0.0, 0.0, -inf]
6 [-inf, -inf, -inf, -inf, -inf, 0.0]
7 [-inf, -inf, -inf, -inf, -inf, 0.0]


# Experiment with visionBLT attention mask

In [16]:
import torch

frames_batch = torch.tensor([
    [
        [
            [111, 112, 113, 114],
            [121, 122, 123, 124],
            [131, 132, 133, 134],
            [141, 142, 143, 144],
        ],
        [
            [211, 212, 213, 214],
            [221, 222, 223, 224],
            [231, 232, 233, 234],
            [241, 242, 243, 244],
        ],
        [
            [311, 312, 313, 314],
            [321, 322, 323, 324],
            [331, 332, 333, 334],
            [341, 342, 343, 344],
        ],
        [
            [411, 412, 413, 414],
            [421, 422, 423, 424],
            [431, 432, 433, 434],
            [441, 442, 443, 444],
        ],
        [
            [511, 512, 513, 514],
            [521, 522, 523, 524],
            [531, 532, 533, 534],
            [541, 542, 543, 544],
        ]
    ]
]).to("cuda")

# Parameters
batch_size, num_frames, height, width = frames_batch.shape
frame_height  = 4
frame_width   = 4
tile_height   = 2
tile_width    = 2
patch_size    = 2   # number of consecutive frames sharing the same patch IDs
device        = "cuda"

# During the explanation the 4x4 frame with tile 2x2 will and patch size 2 be used for illustration
# Compute how many tiles (patches) fit along each spatial axis
tiles_y = height // tile_height
tiles_x = width  // tile_width
patches_per_frame = tiles_y * tiles_x

# 1) Create a grid of patch indices
tile_grid = torch.arange(patches_per_frame, device=device).view(tiles_y, tiles_x)
# e.g. tile_grid = tensor([[0, 1],
#                          [2, 3]])

# 2) Expand each grid cell into a tile of size (tile_height x tile_width)
#    Result: a (height x width) map where each entry is its patch index
spatial_patch_map = (
    tile_grid
    .repeat_interleave(tile_height, dim=0)
    .repeat_interleave(tile_width,  dim=1)
)
# spatial_patch_map.shape == (4, 4)
# spatial_patch_map =
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [2, 2, 3, 3],
#         [2, 2, 3, 3]])

# 3) Flatten the spatial map to shape (height*width,)
flat_patch_ids = spatial_patch_map.view(-1)

# 4) Compute a "group index" for each frame: 
frame_indices = torch.arange(num_frames, device=device)
group_index  = frame_indices // patch_size  # shape: (T,)
# [0//2, 1//2, 2//2, 3//2] = [0, 0, 1, 1]

# 5) For each frame, add an offset of (group_index x patches_per_frame)
#    to the base flat_patch_ids. Broadcasting yields shape (T, H*W)
ids_per_frame = flat_patch_ids[None, :] + group_index[:, None] * patches_per_frame

# 6) Reshape to a single vector per batch and repeat for all batches
#    Final shape: (batch_size, T * height * width)
patch_ids = (
    ids_per_frame
    .reshape(-1)             # (T*H*W,)
    .unsqueeze(0)            # (1, T*H*W)
    .expand(batch_size, -1)  # (B, T*H*W)
    .to(device)
)

tokens = frames_batch.flatten(start_dim=1)  # Remain batch separation

In [17]:
patches_per_frame

4

In [18]:
tokens

tensor([[111, 112, 113, 114, 121, 122, 123, 124, 131, 132, 133, 134, 141, 142,
         143, 144, 211, 212, 213, 214, 221, 222, 223, 224, 231, 232, 233, 234,
         241, 242, 243, 244, 311, 312, 313, 314, 321, 322, 323, 324, 331, 332,
         333, 334, 341, 342, 343, 344, 411, 412, 413, 414, 421, 422, 423, 424,
         431, 432, 433, 434, 441, 442, 443, 444, 511, 512, 513, 514, 521, 522,
         523, 524, 531, 532, 533, 534, 541, 542, 543, 544]], device='cuda:0')

In [19]:
patch_ids

tensor([[ 0,  0,  1,  1,  0,  0,  1,  1,  2,  2,  3,  3,  2,  2,  3,  3,  0,  0,
          1,  1,  0,  0,  1,  1,  2,  2,  3,  3,  2,  2,  3,  3,  4,  4,  5,  5,
          4,  4,  5,  5,  6,  6,  7,  7,  6,  6,  7,  7,  4,  4,  5,  5,  4,  4,
          5,  5,  6,  6,  7,  7,  6,  6,  7,  7,  8,  8,  9,  9,  8,  8,  9,  9,
         10, 10, 11, 11, 10, 10, 11, 11]], device='cuda:0')

In [20]:
len(torch.unique(patch_ids))

12

In [21]:
patch_ids

tensor([[ 0,  0,  1,  1,  0,  0,  1,  1,  2,  2,  3,  3,  2,  2,  3,  3,  0,  0,
          1,  1,  0,  0,  1,  1,  2,  2,  3,  3,  2,  2,  3,  3,  4,  4,  5,  5,
          4,  4,  5,  5,  6,  6,  7,  7,  6,  6,  7,  7,  4,  4,  5,  5,  4,  4,
          5,  5,  6,  6,  7,  7,  6,  6,  7,  7,  8,  8,  9,  9,  8,  8,  9,  9,
         10, 10, 11, 11, 10, 10, 11, 11]], device='cuda:0')

In [22]:
# ENCODER CROSS-ATTENTION
bs, seq_len = patch_ids.shape
num_patches = len(torch.unique(patch_ids))

q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
kv_ids = (
    torch.arange(num_patches, device=patch_ids.device)
    .unsqueeze(0)
    .unsqueeze(0)
    .expand(bs, seq_len, num_patches)
)
# q_ids, kv_ids

counter = 0
for line in (q_ids==kv_ids)[0]:
    print(counter, end=": ")
    for el in line:
        print(int(el) == 1, end=" ")
    print()
    counter += 1

0: True False False False False False False False False False False False 
1: True False False False False False False False False False False False 
2: False True False False False False False False False False False False 
3: False True False False False False False False False False False False 
4: True False False False False False False False False False False False 
5: True False False False False False False False False False False False 
6: False True False False False False False False False False False False 
7: False True False False False False False False False False False False 
8: False False True False False False False False False False False False 
9: False False True False False False False False False False False False 
10: False False False True False False False False False False False False 
11: False False False True False False False False False False False False 
12: False False True False False False False False False False False False 
13: False False True F

In [23]:
# DECODER CROSS-ATTENTION
bs, seq_len = patch_ids.shape
num_patches = len(torch.unique(patch_ids))

kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
q_ids = (
    torch.arange(num_patches, device=patch_ids.device)
    .unsqueeze(0)
    .unsqueeze(-1)
    .expand(bs, num_patches, seq_len)
)
# q_ids, kv_ids

counter = 0
for line in (q_ids==kv_ids)[0]:
    print(counter, end=": ")
    for el in line:
        print(int(el) == 1, end=" ")
    print()
    counter += 1

0: True True False False True True False False False False False False False False False False True True False False True True False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False 
1: False False True True False False True True False False False False False False False False False False True True False False True True False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False 
2: False False False False False False False Fal

In [None]:
# Causal-at-frame-level mask === Encoder - Decoder Mask
# within each frame: full attention
# between frames: only to current or any past frame (no future)

bs, seq_len = patch_ids.shape
frame_elements = frame_height * frame_width
device       = "cuda"
dtype        = torch.bfloat16

block_full = torch.ones((frame_elements, frame_elements), device=device, dtype=dtype)
block_causal = torch.tril(torch.ones((num_frames, num_frames), device=device, dtype=dtype))
mask_causal_frames = torch.kron(block_causal, block_full)
# → mask_causal_frames[i,j] == 1 iff
#      frame_j ≤ frame_i   (where frame_* = index // frame_elements)

# add optimization to have obsolete or decay in time the amount of connections (it should be automatically handled if have proper summary mechanism)

counter = 0
for line in mask_causal_frames:
    print(counter, end=": ")
    for el in line:
        print(int(el) == 1, end=" ")
    print()
    counter += 1

0: True True True True True True True True True True True True True True True True False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False 
1: True True True True True True True True True True True True True True True True False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False False 
2: True True True True True True True True True True True True T

In [25]:
(32 * 10)**2

102400

In [26]:
32 * 32 * 10

10240

In [27]:
# Causal-at-frame-level mask === Latent model Mask
# within each frame: full attention
# between frames: only to current or any past frame (no future)

bs, seq_len = patch_ids.shape
device       = "cuda"
dtype        = torch.bfloat16

block_full = torch.ones((patches_per_frame, patches_per_frame), device=device, dtype=dtype)
block_causal = torch.tril(torch.ones((num_frames, num_frames), device=device, dtype=dtype))
mask_causal_frames = torch.kron(block_causal, block_full)
# → mask_causal_frames[i,j] == 1 iff
#      frame_j ≤ frame_i   (where frame_* = index // patches_per_frame)

counter = 0
for line in mask_causal_frames:
    print(counter, end=": ")
    for el in line:
        print(int(el) == 1, end=" ")
    print()
    counter += 1

0: True True True True False False False False False False False False False False False False False False False False 
1: True True True True False False False False False False False False False False False False False False False False 
2: True True True True False False False False False False False False False False False False False False False False 
3: True True True True False False False False False False False False False False False False False False False False 
4: True True True True True True True True False False False False False False False False False False False False 
5: True True True True True True True True False False False False False False False False False False False False 
6: True True True True True True True True False False False False False False False False False False False False 
7: True True True True True True True True False False False False False False False False False False False False 
8: True True True True True True True True True True Tru