In [1]:
import os
from pathlib import Path

# Move to the repository root so the `bytelatent` imports and the relative data paths resolve.
# Unlike a bare `cd ..`, this is guarded: re-running the cell once already at the root is a no-op.
# NOTE(review): assumes the repo root is the directory that contains the `bytelatent` package.
if not Path("bytelatent").is_dir():
    os.chdir("..")
print(f"Working directory: {Path.cwd()}")

c:\Users\leoni\Documents\projects\visionBLT


In [2]:
import torch

from bytelatent.args import (
    TrainArgs, 
    DataloaderArgs, OptimArgs, ByteLatentTransformerArgs, DistributedArgs, TokenizerArgs, PatcherArgs,
    ProfilerArgs, CheckpointArgs, LoggingArgs
)
from bytelatent.checkpoint import SaveEvery
from bytelatent.model.local_models import LocalModelArgs, LocalEncoder, LocalDecoder, VisionModelArgs
from bytelatent.model.blt import (
    GlobalTransformer, 
    get_global_dim_patch_emb, compute_hash_embeddings, cross_attn_mask, create_patch_mask_from_ids, patch_ids_from_frames, 
    create_local_encoder, create_local_decoder, 
    get_encoder_dim_token_emb, get_encoder_dim_patch_emb,
    get_decoder_dim_token_emb, decoder_patch_ids_from_lengths, 
)
from bytelatent.model.utils import create_vision_causal_mask

W0809 10:11:43.272000 16600 Lib\site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  from .autonotebook import tqdm as notebook_tqdm


# Define arguments

In [3]:
# Small debug configuration for the vision BLT: 1-channel 64x64 frames, 96-dim local
# encoder/decoder and a 96*2 = 192-dim global transformer.
train_args = TrainArgs(
    name="debug",
    dump_dir="/tmp/",
    seed=777,
    steps=100_000,
    optim=OptimArgs(
        lr=4e-04,
        warmup=500,
        lr_min_ratio=0.1,
        clip=10.0
    ),
    distributed=DistributedArgs(
        fsdp_type="full_shard",
        model_dtype="bf16",
        matmul_allow_tf32=False,
        selective_activation_checkpointing=False,
        tp_size=1
    ),
    model=ByteLatentTransformerArgs(
        vision=VisionModelArgs(
            img_channels=1,
            img_height=64,
            img_width=64,
            scale_factor=2,
            norm_channels=4,
            num_process_layers=2,
            tile_height=4,
            tile_width=4,
        ),
        n_heads=2,
        n_heads_local_encoder=2,
        n_heads_local_decoder=2,
        dim=96,
        cross_attn_k=2,
        dim_global=96*2,
        vocab_size=260,
        dim_token=96,
        patch_size=5,  # number of consecutive frames sharing the same tile
        patching_mode="space",
        tie_local_encoder_decoder_logits=False,
        patch_in_forward=False,
        max_encoder_seq_length=3072,
        pad_to_max_length=True,
        patching_threshold=3.1439168453216553,
        encoder_hash_byte_group_size=None,
        encoder_hash_byte_group_vocab=50002,
        encoder_hash_byte_group_nb_functions=3,
        encoder_enable_byte_ngrams=False,
        cross_attn_encoder=True,  # encoder: patch queries attend to tokens
        cross_attn_decoder=True,  # decoder: token queries attend to global patch states
        cross_attn_window_encoder=None,
        cross_attn_window_decoder=None,
        dim_local_encoder=96,
        dim_local_decoder=96,
        cross_attn_nheads=2,
        cross_attn_all_layers_decoder=True,
        cross_attn_all_layers_encoder=True,
        cross_attn_use_flex_attention=False,
        cross_attn_init_by_pooling=True,
        log_patch_lengths=True,
        non_linearity="swiglu",
        use_rope=True,
        recompute_fc1_out=False,
        recompute_fc3_out=False,
        recompute_attn=False,
        custom_bwd=False,
        layer_ckpt="none",
        use_local_encoder_transformer=True,
        init_use_gaussian=True,
        init_use_depth="current",
        attn_impl="sdpa",
        attn_bias_type="block_causal",
        alpha_depth="disabled",
        max_length=256,
        local_attention_window_len=512,
        downsampling_by_pooling="max",
    ),
    data=DataloaderArgs(
        root_dir=".",
        sources={"alpaca-cleaned": 1.0},
        dataset_files=["alpaca-cleaned/1.json"],
        batch_size=2,
        prefetch_size=64,
        seq_len=1024,
        max_encoder_seq_length=3072,
        load_async=False,
        preprocess_dir="preprocess_dir",
        patcher_args=PatcherArgs(patching_mode="space"),
        tokenizer_args=TokenizerArgs(name="blt"),
    ),
    profiling=ProfilerArgs(run=False),
    checkpoint=CheckpointArgs(
        path="train_checkpoints",
        dump=SaveEvery(
            every=500,
            keep=3
        ),
        eval=SaveEvery(
            every=1000,
            keep=-1
        )
    ),
    logging=LoggingArgs(freq=10),
    eval_on_gpus=1,
)

# Apply the configured seed to torch's RNGs (all devices) so that model initialisation and the
# synthetic frames below are reproducible under Restart & Run All.
torch.manual_seed(train_args.seed)

# Model-level arguments, used by every cell below.
args = train_args.model

# Initialize models

In [4]:
# Build the local encoder described by `args`, then place its parameters on the GPU.
local_encoder = create_local_encoder(args=args)
local_encoder = local_encoder.to(device="cuda")

In [5]:
# The global transformer runs over patch-level embeddings (see the Global cell). It reuses the
# model args with the global width, depth and head counts swapped in, and with the local attention
# window and both cross-attention paths switched off.
global_transformer = GlobalTransformer(
    args.model_copy(
        deep=True,
        update={
            "dim": args.dim_global,
            "n_layers": args.n_layers_global,
            "n_heads": args.n_heads_global,
            "n_kv_heads": args.n_kv_heads_global,
            "local_attention_window_len": None,
            "dim_token_emb": get_global_dim_patch_emb(args),
            "dim_patch_emb": None,
            "cross_attn_encoder": False,
            "cross_attn_decoder": False,
        },
    )
).to("cuda")

In [6]:
# Build the local decoder described by `args`, then place its parameters on the GPU.
local_decoder = create_local_decoder(args=args)
local_decoder = local_decoder.to(device="cuda")

# Run the model

## Preprocess

In [7]:
# Synthetic input: `batch_size` clips of `num_frames` frames each. Channel count and resolution
# are read from the vision config, so the data always matches what the image encoder was built for.
batch_size, num_frames = 4, 10
channels = args.vision.img_channels
height = args.vision.img_height
width = args.vision.img_width
device = next(local_encoder.parameters()).device  # follow the encoder instead of hard-coding "cuda"

# Allocate directly on the target device (no CPU tensor + copy).
frames = torch.randn(batch_size, num_frames, channels, height, width, device=device)
# Fold frames into the batch dimension: the image encoder is fed a 4-D (B*F, C, H, W) batch.
frames = frames.reshape(batch_size * num_frames, channels, height, width)
encoder_input = local_encoder.image_encoder(frames)

# Spatial size of the encoded feature map (32x32 for the current config, per the printed shapes).
frame_height, frame_width = encoder_input.shape[2:]
patch_ids = patch_ids_from_frames(
    batch_size=batch_size,
    num_frames=num_frames,
    height=frame_height,
    width=frame_width,
    tile_height=args.vision.tile_height,
    tile_width=args.vision.tile_width,
    patch_size=args.patch_size,
    device=encoder_input.device,
    skip_from_start=1
)

# (B*F, C, h, w) -> (B, F*h*w, C): one token per feature-map position, frames in order.
encoder_input = encoder_input.permute(0, 2, 3, 1).reshape(batch_size, -1, local_encoder.dim)
print("encoder_input:", encoder_input.shape)

# Tokens per sequence; reused by the encoder and decoder cells.
N = encoder_input.shape[1]
assert patch_ids.shape == (batch_size, N), f"patch_ids {tuple(patch_ids.shape)} does not match {N} tokens"

# Tokens per patch. NOTE(review): assumes every batch row shares the same patch layout (the frame
# grid is identical for all clips), so the counts of row 0 stand in for the whole batch.
# `patch_ids` already lives on the encoder's device, so no extra `.to(...)` is needed.
patch_lengths = torch.unique(patch_ids[0], return_counts=True)[1].unsqueeze(dim=0)
print("patch_lengths:", patch_lengths.shape)

encoder_input: torch.Size([4, 10240, 96])
patch_lengths: torch.Size([1, 192])


## Encoder

In [8]:
# Number of tokens in one encoded frame (feature-map height x width).
frame_elements = frame_height * frame_width
print("frame_elements:", frame_elements)

# Encoder cross-attention mask: patches act as queries (cross_attn_k query slots per patch) over
# the N token positions.
cross_attn_mask_enc = cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=True,
    cross_attn_k=args.cross_attn_k,
    window=args.cross_attn_window_encoder,
    block_mask=args.cross_attn_use_flex_attention,
)
cross_attn_mask_enc = cross_attn_mask_enc.to("cuda")
print("cross_attn_mask_enc:", cross_attn_mask_enc.shape)

# Token-level self-attention mask over the whole clip, parameterised by tokens per frame.
causal_mask_enc = create_vision_causal_mask(
    patch_ids.shape[1],
    frame_elements,
    args.attn_impl,
    "causal",
)
causal_mask_enc = causal_mask_enc.to("cuda")
print("causal_mask_enc:", causal_mask_enc.shape)

# The encoder returns (token states, patch states from cross-attention) plus a cache.
(h_encoder, h_cross), cache_encoder = local_encoder(
    frames=encoder_input,
    patch_ids=patch_ids,
    num_patches=patch_lengths.shape[1],
    patch_embeds=None,
    mask=causal_mask_enc,
    cross_mask=cross_attn_mask_enc,
)
print("h_encoder:", h_encoder.shape)
print("h_cross:", h_cross.shape)

frame_elements: 1024
cross_attn_mask_enc: torch.Size([4, 1, 384, 10240])
causal_mask_enc: torch.Size([10240, 10240])
h_encoder: torch.Size([4, 10240, 96])
h_cross: torch.Size([4, 384, 96])


## Global

In [9]:
# Reshape h_cross from (B, num_patches * cross_attn_k, dim) to (B, num_patches, -1), giving one
# global token per patch. NOTE(review): assumes the k query slots of a patch are adjacent rows.
h = h_cross.view(batch_size, patch_lengths.shape[1], -1)
print(f"Global transformer input shape: {h.shape}.")

# With this config each tile of an encoded frame maps to one patch, so the latent "frame" has
# tiles_y * tiles_x tokens. Floor division alone would silently miscount when the tile grid
# does not divide the feature map, so fail loudly instead.
assert frame_height % args.vision.tile_height == 0 and frame_width % args.vision.tile_width == 0, (
    f"feature map {frame_height}x{frame_width} is not divisible by tile "
    f"{args.vision.tile_height}x{args.vision.tile_width}"
)
tiles_y = frame_height // args.vision.tile_height
tiles_x = frame_width // args.vision.tile_width
latent_frame_elements = tiles_y * tiles_x
print("latent_frame_elements:", latent_frame_elements)

# Patch-level self-attention mask, parameterised by latent tokens per frame.
causal_mask_global = create_vision_causal_mask(
    h.shape[1],
    latent_frame_elements,
    args.attn_impl,
    "causal"
).to("cuda")
print("causal_mask_global:", causal_mask_global.shape)

h, _ = global_transformer(
    embeds=h,
    mask=causal_mask_global
)
print(f"Global transformer output shape: {h.shape}.")

Global transformer input shape: torch.Size([4, 192, 192]).
latent_frame_elements: 64
causal_mask_global: torch.Size([192, 192])
Global transformer output shape: torch.Size([4, 192, 192]).


## Decoder

In [10]:
# Unpatching
dec_embeds = h_encoder
print(f"Decoder embeddings `dec_embeds` shape: {dec_embeds.shape}.")

# Generate decoder patch IDs
decoder_patch_ids = patch_ids_from_frames(
    batch_size=batch_size,
    num_frames=num_frames,
    height=frame_height,
    width=frame_width,
    tile_height=args.vision.tile_height, 
    tile_width=args.vision.tile_width, 
    patch_size=args.patch_size,
    device=encoder_input.device,
    skip_from_end=1
)
decoder_patch_lengths = torch.unique(decoder_patch_ids[0], return_counts=True)[1].unsqueeze(dim=0).to("cuda")
print(f"Decoder patch IDs shape: {decoder_patch_ids.shape}.")

# Cross-attention decoder
cross_attn_mask_dec = cross_attn_mask(
    decoder_patch_ids,
    decoder_patch_lengths,
    N,
    patches_as_queries=False,
    cross_attn_k=args.cross_attn_k,
    window=args.cross_attn_window_decoder,
    block_mask=args.cross_attn_use_flex_attention,
).to("cuda")
print("cross_attn_mask_dec:", cross_attn_mask_dec.shape)

causal_mask_dec = create_vision_causal_mask(
    decoder_patch_ids.shape[1],
    frame_elements,
    args.attn_impl,
    "causal"
).to("cuda")
print("causal_mask_dec:", causal_mask_dec.shape)

# Local decoder
decoder_output, _ = local_decoder(
    embeds=dec_embeds,
    patch_embeds=h,
    mask=causal_mask_dec,
    cross_mask=cross_attn_mask_dec,
)
print(f"Decoder output shape: {decoder_output.shape}")

Decoder embeddings `dec_embeds` shape: torch.Size([4, 10240, 96]).
Decoder patch IDs shape: torch.Size([4, 10240]).
cross_attn_mask_dec: torch.Size([4, 1, 10240, 384])
causal_mask_dec: torch.Size([10240, 10240])
Decoder output shape: torch.Size([4, 10240, 96])


## Postprocess

In [11]:
# Undo the Preprocess flattening: (B, F*h*w, C) -> (B*F, h, w, C) -> (B*F, C, h, w).
latent_frames = decoder_output.reshape(
    batch_size * num_frames, frame_height, frame_width, local_decoder.dim
)
latent_frames = latent_frames.permute(0, 3, 1, 2)
print("latent_frames:", latent_frames.shape)

# Decode back to image resolution, then flatten to one feature vector per pixel per clip.
frames_features = local_decoder.image_decoder(latent_frames)
frames_features = frames_features.permute(0, 2, 3, 1)
frames_features = frames_features.reshape(batch_size, -1, local_decoder.dim)
print("frames_features:", frames_features.shape)

# Output head, upcast to float32.
logits = local_decoder.head(frames_features)
logits = logits.float()
print("logits:", logits.shape)

latent_frames: torch.Size([40, 96, 32, 32])
frames_features: torch.Size([4, 40960, 96])
logits: torch.Size([4, 40960, 256])
