In [1]:
cd ..

c:\Users\leoni\Documents\projects\visionBLT


In [2]:
import torch

from bytelatent.args import (
    TrainArgs, 
    DataloaderArgs, OptimArgs, ByteLatentTransformerArgs, DistributedArgs, TokenizerArgs, PatcherArgs,
    ProfilerArgs, CheckpointArgs, LoggingArgs
)
from bytelatent.checkpoint import SaveEvery
from bytelatent.model.local_models import LocalModelArgs, LocalEncoder, VisionModelArgs
from bytelatent.model.blt import (
    get_encoder_dim_token_emb, get_encoder_dim_patch_emb, compute_hash_embeddings, cross_attn_mask, create_patch_mask_from_ids, patch_ids_from_frames, 
)
from bytelatent.model.utils import create_vision_causal_mask

W0807 22:35:58.004000 25636 Lib\site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ---------------------------------------------------------------------------
# Debug training configuration, assembled from small named pieces.
# ---------------------------------------------------------------------------

# Optimiser schedule.
optim_args = OptimArgs(lr=4e-04, warmup=500, lr_min_ratio=0.1, clip=10.0)

# Distributed / precision settings.
distributed_args = DistributedArgs(
    fsdp_type="full_shard",
    model_dtype="bf16",
    matmul_allow_tf32=False,
    selective_activation_checkpointing=False,
    tp_size=1,
)

# Vision front-end: input frame geometry plus tiling / patching parameters.
vision_args = VisionModelArgs(
    img_channels=1,
    img_height=64,
    img_width=64,
    scale_factor=2,
    inner_channels=96,
    norm_channels=4,
    num_process_layers=2,
    tile_height=4,
    tile_width=4,
    patch_size=4,
)

# The model takes many keywords; they are grouped by theme here and merged
# into a single constructor call below.
model_sizes = dict(
    dim=96,
    dim_token=96,
    dim_local_encoder=96,
    dim_local_decoder=96,
    n_heads=2,
    n_heads_local_encoder=2,
    n_heads_local_decoder=2,
    vocab_size=260,
    max_length=256,
)
model_patching = dict(
    patch_size=6,
    patching_mode="space",
    patching_threshold=3.1439168453216553,
    patch_in_forward=False,
    max_encoder_seq_length=3072,
    pad_to_max_length=True,
    log_patch_lengths=True,
    downsampling_by_pooling="max",
)
model_hash_embeddings = dict(
    encoder_hash_byte_group_size=None,
    encoder_hash_byte_group_vocab=50002,
    encoder_hash_byte_group_nb_functions=3,
    encoder_enable_byte_ngrams=False,
)
# Cross-attention is assumed to be enabled in both the encoder and the decoder.
model_cross_attention = dict(
    cross_attn_encoder=True,
    cross_attn_decoder=True,
    cross_attn_window_encoder=None,
    cross_attn_window_decoder=None,
    cross_attn_k=2,
    cross_attn_nheads=2,
    cross_attn_all_layers_decoder=True,
    cross_attn_all_layers_encoder=True,
    cross_attn_use_flex_attention=False,
    cross_attn_init_by_pooling=True,
)
model_architecture = dict(
    non_linearity="swiglu",
    use_rope=True,
    attn_impl="xformers",
    attn_bias_type="block_causal",
    local_attention_window_len=512,
    use_local_encoder_transformer=True,
    tie_local_encoder_decoder_logits=False,
)
model_init_and_recompute = dict(
    init_use_gaussian=True,
    init_use_depth="current",
    alpha_depth="disabled",
    recompute_fc1_out=False,
    recompute_fc3_out=False,
    recompute_attn=False,
    custom_bwd=False,
    layer_ckpt="none",
)

model_args = ByteLatentTransformerArgs(
    vision=vision_args,
    **model_sizes,
    **model_patching,
    **model_hash_embeddings,
    **model_cross_attention,
    **model_architecture,
    **model_init_and_recompute,
)

# Data pipeline.
data_args = DataloaderArgs(
    root_dir=".",
    sources={"alpaca-cleaned": 1.0},
    dataset_files=["alpaca-cleaned/1.json"],
    batch_size=2,
    prefetch_size=64,
    seq_len=1024,
    max_encoder_seq_length=3072,
    load_async=False,
    preprocess_dir="preprocess_dir",
    patcher_args=PatcherArgs(patching_mode="space"),
    tokenizer_args=TokenizerArgs(name="blt"),
)

# Checkpoint cadence for regular dumps and for evaluation snapshots.
checkpoint_args = CheckpointArgs(
    path="train_checkpoints",
    dump=SaveEvery(every=500, keep=3),
    eval=SaveEvery(every=1000, keep=-1),
)

train_args = TrainArgs(
    name="debug",
    dump_dir="/tmp/",
    seed=777,
    steps=100_000,
    optim=optim_args,
    distributed=distributed_args,
    model=model_args,
    data=data_args,
    profiling=ProfilerArgs(run=False),
    checkpoint=checkpoint_args,
    logging=LoggingArgs(freq=10),
    eval_on_gpus=1,
)

# Short alias for the model section; the following cells read from it.
args = train_args.model

In [4]:
# Build the local-encoder config from the model config in four groups.

# 1) Fields copied verbatim: same attribute name on `args` and on LocalModelArgs.
passthrough_fields = (
    "vision",
    "cross_attn_encoder",
    "cross_attn_init_by_pooling",
    "head_dim",
    "dropout",
    "norm_eps",
    "patch_size",
    "use_rope",
    "rope_theta",
    "rope_use_fp32_in_outer_product",
    "init_base_std",
    "init_std_factor",
    "n_kv_heads",
    "multiple_of",
    "ffn_dim_multiplier",
    "patching_mode",
    "use_local_encoder_transformer",
    "downsampling_by_pooling",
    "encoder_hash_byte_group_size",
    "cross_attn_all_layers_encoder",
    "cross_attn_all_layers_decoder",
    "cross_attn_nheads",
    "eos_id",
)
encoder_kwargs = {field: getattr(args, field) for field in passthrough_fields}

# 2) Fields that live under an encoder-specific name on the model config.
encoder_kwargs.update(
    dim=args.dim_local_encoder,
    n_layers=args.n_layers_local_encoder,
    n_heads=args.n_heads_local_encoder,
    max_seqlen=args.max_encoder_seq_length,
    sliding_window=args.local_attention_window_len,
)

# 3) Values derived from the model config.
encoder_kwargs.update(
    dim_token_emb=get_encoder_dim_token_emb(args),
    dim_patch_emb=get_encoder_dim_patch_emb(args),
    vocab_size=args.vocab_size + args.pm_size,
    # cross_attn_k is only forwarded when the encoder uses cross-attention.
    cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
)

# 4) Fixed values that deliberately do not follow the model config
#    (the model config above asks for "xformers" / "block_causal" and enables
#    decoder cross-attention).
encoder_kwargs.update(
    cross_attn_decoder=False,
    attn_impl="sdpa",
    attn_bias_type="local_block_causal",
)

local_encoder_args = LocalModelArgs(**encoder_kwargs)

In [5]:
# Use the GPU when one is available; otherwise fall back to the CPU so that
# the encoder can still be instantiated on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
local_encoder = LocalEncoder(local_encoder_args).to(device)

In [6]:
# Display the module tree of the encoder built above (transformer layers,
# rotary embedding, patch projection and image encoder) as a sanity check.
local_encoder

LocalEncoder(
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=96, out_features=96, bias=False)
        (wk): Linear(in_features=96, out_features=96, bias=False)
        (wv): Linear(in_features=96, out_features=96, bias=False)
        (wo): Linear(in_features=96, out_features=96, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=96, out_features=256, bias=False)
        (w3): Linear(in_features=96, out_features=256, bias=False)
        (w2): Linear(in_features=256, out_features=96, bias=False)
      )
      (attention_norm): RMSNorm((96,), eps=1e-05, elementwise_affine=True)
      (ffn_norm): RMSNorm((96,), eps=1e-05, elementwise_affine=True)
    )
  )
  (rope): RotaryEmbeddingNd()
  (patch_embedding_projection): Linear(in_features=96, out_features=192, bias=False)
  (image_encoder): Sequential(
    (0): Conv2d(1, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): 

In [10]:
# Work on whichever device the encoder's weights live on instead of
# hard-coding "cuda" at every call site.
device = next(local_encoder.parameters()).device

# Dummy clip dimensions. The per-frame shape is read from the vision config so
# the random input cannot drift away from the configured frame geometry.
batch_size, num_frames = 4, 10
channels = args.vision.img_channels
height = args.vision.img_height
width = args.vision.img_width

# Private, seeded generator: the random clip is identical on every run and the
# global RNG state is left untouched.
rng = torch.Generator().manual_seed(train_args.seed)
video = torch.randn(
    batch_size, num_frames, channels, height, width, generator=rng
).to(device)

# The image encoder is applied per frame, so fold the frame axis into the batch axis.
frames = video.reshape(batch_size * num_frames, channels, height, width)
encoder_input = local_encoder.image_encoder(frames)

# Spatial size after the image encoder (may differ from the raw frame size).
frame_height, frame_width = encoder_input.shape[2:]
patch_ids = patch_ids_from_frames(
    batch_size=batch_size,
    num_frames=num_frames,
    height=frame_height,
    width=frame_width,
    tile_height=args.vision.tile_height,
    tile_width=args.vision.tile_width,
    patch_size=args.vision.patch_size,
    device=encoder_input.device,
)

# Number of encoded positions per batch item across all frames.
frame_elements = frame_height * frame_width * num_frames

# Move channels last, then merge frames and spatial positions into a single
# sequence axis per batch item: (batch, sequence, inner_channels).
encoder_input = encoder_input.permute(0, 2, 3, 1).reshape(
    batch_size, -1, local_encoder.inner_channels
)

# NOTE: despite its name, this holds the distinct patch ids with a leading
# axis of size 1. The name is kept because the forward-pass cell reads
# `patch_lengths.shape[1]` as the number of patches.
patch_lengths = torch.unique(patch_ids).unsqueeze(dim=0).to(device)
# Not read by any cell visible here; kept in case a later cell references it.
local_encoder_embeds = None

cross_attn_mask_enc = cross_attn_mask(
    patch_ids,
    patch_lengths,
    frame_elements,
    patches_as_queries=True,
    cross_attn_k=args.cross_attn_k,
    window=args.cross_attn_window_encoder,
    block_mask=args.cross_attn_use_flex_attention,
).to(device)
print("cross_attn_mask_enc:", cross_attn_mask_enc.shape)

mask = create_vision_causal_mask(
    patch_ids.shape[1],
    frame_elements,
    local_encoder_args.attn_impl,
    "causal",
).to(device)
print("mask:", mask.shape)

cross_attn_mask_enc: torch.Size([4, 1, 384, 10240])
mask: torch.Size([10240, 10240])


In [12]:
# Gather the forward-pass inputs first so the call itself stays short.
encoder_inputs = dict(
    frames=encoder_input,
    mask=mask,
    cross_mask=cross_attn_mask_enc,
    patch_embeds=None,
    num_patches=patch_lengths.shape[1],
    patch_ids=patch_ids,
)

# The encoder returns a pair whose first item is itself a pair; unpack in two steps.
encoder_states, cache_encoder = local_encoder(**encoder_inputs)
h_encoder, h_cross = encoder_states

In [13]:
# Inspect the shapes of the two tensors returned by the encoder forward pass.
h_encoder.shape, h_cross.shape

(torch.Size([4, 10240, 96]), torch.Size([4, 384, 96]))