In [1]:
cd ~/llama-models/

/home/jovyan/llama-models


In [2]:
from models.llama4 import *
import torch
import codecs
import io
import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional

import torch
import torch.nn.functional as F

In [3]:
# !pip install fairscale

In [4]:
from models.checkpoint import maybe_reshard_state_dict
from models.datatypes import GenerationResult, QuantizationMode
from models.llama4.args import ModelArgs
from models.llama4.chat_format import ChatFormat, RawContent, RawMessage
from models.llama4.datatypes import LLMInput, MaskedEmbedding, TransformerInput
# from models.llama4.model import Transformer
from models.llama4.tokenizer import Tokenizer
import numpy as np
import re

In [5]:
# Directory holding the checkpoint shards (*.pth) and params.json.
# Set the LLAMA4_CKPT_DIR environment variable to point the notebook at a
# different location; the previous hard-coded path remains the default.
ckpt_dir = os.environ.get(
    "LLAMA4_CKPT_DIR",
    "/data/llama/checkpoints/Llama-4-Scout-17B-16E-Instruct",
)

In [6]:
# Make the first CUDA device the current device for this process.
local_rank = 'cuda:0'
torch.cuda.set_device(local_rank)

# Checkpoint shard files, in sorted filename order.
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
if not ckpt_paths:
    # Fail here with a clear message instead of silently carrying an empty
    # shard list into the loading cell.
    raise FileNotFoundError(f"No *.pth checkpoint shards found in {ckpt_dir}")

# max_seq_len / max_batch_size are forwarded to ModelArgs when the model
# configuration is built; world_size records that a single process is used.
max_seq_len = 128
max_batch_size = 1
world_size = 1

In [7]:
# Read the checkpoint's params.json; its entries become ModelArgs keyword arguments.
with open(Path(ckpt_dir) / "params.json", "r") as f:
    params = json.load(f)

# Combine the stored hyper-parameters with the notebook's sequence/batch limits.
model_args: ModelArgs = ModelArgs(
    **params,
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
)

# The vocabulary size is taken from the tokenizer rather than from params.json.
tokenizer = Tokenizer.get_instance()
model_args.vocab_size = tokenizer.n_words

In [8]:
# Re-sort the shard paths and wrap them in a NumPy array (an object array of
# Path entries that is only iterated below).
# NOTE(review): ckpt_paths is already sorted where it is first built, so this
# second sort looks redundant — confirm whether the array form is needed later.
ckpt_paths = np.array(sorted(ckpt_paths))
# Load every shard onto the CPU, memory-mapped rather than read eagerly.
map_location = 'cpu'
mmap = True
# NOTE(review): torch.load unpickles the file unless weights_only is enabled;
# only load checkpoints from a trusted source. If the shards contain plain
# tensors, passing weights_only=True would be safer — TODO confirm.
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in ckpt_paths]

In [9]:
from typing import Dict, Any

def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
    """Reshape routed-expert tensors in ``state_dict`` in place and return it.

    Every key matching one of the ``_MOE_WEIGHT_ROW_KEY`` /
    ``_MOE_WEIGHT_COLUMN_KEY`` regex fragments has its first dimension split
    into ``(num_experts, -1)``; a leading dimension of size 1 is then squeezed
    away. A converted entry is popped and re-inserted, so it moves to the end
    of the dict's iteration order.

    Args:
        state_dict: Mapping of parameter names to tensors; mutated in place.
        num_experts: Size of the new leading dimension.

    Returns:
        The same ``state_dict`` object.
    """
    pattern = re.compile("|".join(_MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY))
    # Snapshot the matching names first: the dict is modified inside the loop.
    for name in [k for k in state_dict.keys() if pattern.search(k)]:
        flat = state_dict.pop(name)
        state_dict[name] = flat.unflatten(0, (num_experts, -1)).squeeze(dim=0)
    return state_dict

In [10]:
# Parameter-name patterns for dense weights, grouped as "row" keys.
# NOTE(review): neither _WEIGHT_ROW_KEY nor _WEIGHT_COLUMN_KEY is referenced in
# the visible cells; presumably they mirror the key sets used when resharding
# checkpoints — TODO confirm they are still needed here.
_WEIGHT_ROW_KEY = {
    "feed_forward.w2",
    "feed_forward.mlp.fc2",
    "attention.wo",
    "feed_forward.mlp.fc2_weight",
    "feed_forward.w_out_shared_DF.weight",
    "attn.wo.weight",
    "mlp.c_proj.weight",
}
# Regex fragment for routed-expert input weights. convert_moe_weights unions
# this set with _MOE_WEIGHT_COLUMN_KEY and joins the fragments with "|" into a
# single pattern, so "." and "(a|b)" here are interpreted as regex syntax.
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}

# Parameter-name patterns for dense weights, grouped as "column" keys.
_WEIGHT_COLUMN_KEY = {
    "output",
    "feed_forward.(w1|w3)",
    "feed_forward.mlp.(fc1|fc3)",
    "feed_forward.mlp.fc1_weight",
    "attention.(wk|wq|wv|wqkv).weight",
    "feed_forward.(w_in_shared_FD|w_swiglu_FD)",
    "attn.(wk|wq|wv).weight",
    "attn.(wk|wq|wv).bias",
    "mlp.c_fc.weight",
    "mlp.c_fc.bias",
    "conv1._linear.weight",
    "tok_embeddings.weight",
    "vision_projection.weight",
}
# Regex fragment for routed-expert output weights (also consumed by convert_moe_weights).
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


In [11]:
from typing import Any, Dict, List

import torch
from torch import nn
from torch.nn import functional as F


class FeedForward(nn.Module):
    """SwiGLU feed-forward block built from three bias-free ``nn.Linear`` layers.

    ``forward`` computes ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Size of the last dimension of both the input and the output.
        hidden_dim: Width of the intermediate projections.
        do_reduce: Stored on the instance; ``forward`` never consults it.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        do_reduce: bool = True,
    ):
        super().__init__()
        self.do_reduce = do_reduce

        # Ordinary linear layers stand in for the model-parallel variants.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Rename fused checkpoint weights to this module's parameter names.

        When ``<prefix>mlp.fc1_weight`` is present it is split in half along
        dim 0 into ``w1.weight`` and ``w3.weight``, and ``<prefix>mlp.fc2_weight``
        is moved to ``w2.weight``. Otherwise ``state_dict`` is left untouched.
        """
        fused_key = prefix + "mlp.fc1_weight"
        if fused_key not in state_dict:
            return

        first_half, second_half = state_dict.pop(fused_key).chunk(2, dim=0)
        state_dict[prefix + "w1.weight"] = first_half
        state_dict[prefix + "w3.weight"] = second_half
        state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")

    def forward(self, x):
        """Apply the gated MLP to ``x`` (last dimension ``dim``) and return the result."""
        gate = F.linear(x, self.w1.weight)
        up = F.linear(x, self.w3.weight)
        # Single process: the projected output is returned without any reduction.
        return F.linear(F.silu(gate) * up, self.w2.weight)


In [12]:
from typing import Any, Dict, List

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from models.llama4.args import MoEArgs
# from models.llama4.ffn import FeedForward


def divide_exact(numerator: int, denominator: int) -> int:
    """Return ``numerator // denominator``, asserting that nothing is left over.

    Raises:
        AssertionError: if ``numerator`` is not a multiple of ``denominator``.
    """
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0, f"{numerator} is not divisible by {denominator}"
    return quotient


class Experts(nn.Module):
    """Stack of per-expert SwiGLU MLPs evaluated with batched matmuls.

    Parameters (allocated uninitialised with ``torch.empty`` in the default dtype):
        w1, w3: ``[num_local_experts, dim, hidden_dim]``
        w2:     ``[num_local_experts, hidden_dim, dim]``
    """

    def __init__(
        self,
        num_local_experts: int,
        dim: int,
        hidden_dim: int,
    ) -> None:
        super().__init__()

        self.num_local_experts = num_local_experts
        self.dim = dim

        param_dtype = torch.get_default_dtype()
        # One process holds every expert, so the hidden width is divided by 1.
        local_hidden = divide_exact(hidden_dim, 1)

        def _uninitialised(*shape: int) -> nn.Parameter:
            return nn.Parameter(torch.empty(*shape, dtype=param_dtype))

        # Registration order (w1, w2, w3) is kept as-is.
        self.w1: nn.Parameter = _uninitialised(num_local_experts, dim, local_hidden)
        self.w2: nn.Parameter = _uninitialised(num_local_experts, local_hidden, dim)
        self.w3: nn.Parameter = _uninitialised(num_local_experts, dim, local_hidden)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Record ``prefix`` and rename merged checkpoint tensors to ``w1``/``w2``/``w3``.

        Only acts when ``<prefix>moe_w_in_eD_F`` is present; each merged tensor
        is popped and stored as a per-expert 3-D view.
        """
        self.prefix = prefix
        if prefix + "moe_w_in_eD_F" not in state_dict:
            return

        n_experts, width = self.num_local_experts, self.dim
        # (target name, checkpoint name, view shape), in insertion order.
        renames = (
            ("w1", "moe_w_in_eD_F", (n_experts, width, -1)),
            ("w2", "moe_w_out_eF_D", (n_experts, -1, width)),
            ("w3", "moe_w_swiglu_eD_F", (n_experts, width, -1)),
        )
        for target, source, shape in renames:
            state_dict[prefix + target] = state_dict.pop(prefix + source).view(*shape)

    def forward(
        self,
        routed_in_egD: torch.Tensor,  # [e*G, D]
    ) -> torch.Tensor:
        """Run each expert on its slice of the routed input.

        The flat ``[e*G, D]`` input is viewed as ``[e, G, D]``, passed through
        the batched SwiGLU and flattened back to ``[e*G, D]``.
        """
        per_expert_in = routed_in_egD.view(self.num_local_experts, -1, self.dim)
        per_expert_out = self.batched_swiglu(per_expert_in, self.w1, self.w3, self.w2)
        return per_expert_out.view(-1, self.dim)

    def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
        """Compute ``bmm(silu(bmm(x, w1)) * bmm(x, w3), w2)`` over the expert batch dim."""
        gate = F.silu(torch.bmm(x, w1))
        up = torch.bmm(x, w3)
        return torch.bmm(gate * up, w2)


class MoE(nn.Module):
    """Mixture-of-experts feed-forward layer: routed experts plus one shared expert.

    Tensor-name suffixes used in this class:
    - x_bsD: [batch_size, seq_len, D]
    - router_DE: [D, E]
    - "a" is the flattened token axis (batch_size * seq_len).
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        ffn_dim_multiplier: float,
        multiple_of: int,
        moe_args: MoEArgs,
    ) -> None:
        """Build the routed experts, the router matrix and the shared expert.

        Args:
            dim: Model width D.
            hidden_dim: Base hidden width; rescaled below before use.
            ffn_dim_multiplier: Multiplier applied to the rescaled hidden width.
            multiple_of: The final hidden width is rounded up to a multiple of this.
            moe_args: Read for ``auto_scale_F``, ``capacity_factor``,
                ``num_experts`` (here) and ``top_k`` (in ``forward``).
        """
        super().__init__()

        self.moe_args = moe_args

        # Derive the expert hidden width: 2/3 of hidden_dim, times the custom
        # multiplier, optionally divided by (capacity_factor + 1), then rounded
        # up to the next multiple of `multiple_of`.
        hidden_dim_denom: float = 1.0
        if moe_args.auto_scale_F:
            hidden_dim_denom = moe_args.capacity_factor + 1.0

        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        if moe_args.auto_scale_F:
            hidden_dim = int(hidden_dim / hidden_dim_denom)
        # -x % m is the distance from x up to the next multiple of m.
        hidden_dim += -hidden_dim % multiple_of

        num_local_experts: int = moe_args.num_experts
        dtype: torch.dtype = torch.get_default_dtype()

        # Create Experts module (all experts live in this one module)
        self.experts = Experts(
            num_local_experts,
            dim,
            hidden_dim,
        )

        # Router projection: [D, E]; allocated uninitialised (torch.empty).
        self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))

        # Shared expert (dense FFN) applied to every token.
        self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Rename shared-expert checkpoint keys to ``shared_expert.w{1,2,3}.weight``.

        Only acts when ``<prefix>w_in_shared_FD.weight`` is present; the other
        two source keys are then popped unconditionally.
        """
        # Remap shared expert weights if needed
        if prefix + "w_in_shared_FD.weight" in state_dict:
            state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
            state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
            state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")

    def forward(self, x_bsD: Tensor) -> Tensor:
        """Return shared-expert output plus the score-weighted routed-expert outputs.

        Args:
            x_bsD: Input of shape [batch_size, seq_len, D]; must support ``view``.

        Returns:
            Tensor of shape [batch_size, seq_len, D].
        """
        bsz, slen, D = x_bsD.shape
        # Flatten tokens: [bsz * slen, D]
        x_aD = x_bsD.view(-1, D)
        a = x_aD.shape[0]

        # Compute router logits: [a, E] transposed to [E, a]
        router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)

        # Find top-k experts per token: router_indices_aK: [a, top_k], router_scores_aK: [a, top_k]
        router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)

        # Keep only the top-k logits per token; every other entry becomes -inf,
        # which the sigmoid below turns into a weight of exactly 0.
        mask_full = torch.full_like(router_scores.transpose(0, 1), float("-inf"))
        mask_full = mask_full.scatter_(1, router_indices_aK, router_scores_aK)
        router_scores = mask_full.transpose(0, 1)

        # Token indices for gathering: [E, a]. Every expert row lists every
        # token, so each expert processes all `a` tokens.
        router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)

        # Squash the logits to (0, 1) weights; there is no normalisation across experts.
        router_scores = torch.sigmoid(router_scores)

        # Gather routed inputs: [E * a, D] (one copy of every token per expert)
        routed_in_EG_D: Tensor = torch.gather(
            x_aD,
            dim=0,
            index=router_indices.reshape(-1, 1).expand(-1, D),
        )
        # The routing weight is applied to the expert *input*; unselected
        # (token, expert) pairs are multiplied by 0.
        routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)

        # The shared expert runs on the unweighted tokens and seeds the output.
        out_aD = self.shared_expert(x_aD)

        # Run the experts on detached routed inputs (detach cuts the autograd
        # link from the expert outputs back to the inputs and router weights).
        routed_out_eg_D = self.experts(routed_in_EG_D.detach())

        # Scatter-add expert outputs back onto their token rows, in place on out_aD
        router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
        out_aD.scatter_add_(dim=0, index=router_indices_EG_D, src=routed_out_eg_D.view(-1, D))

        # Single process: the model-parallel reduction is intentionally left disabled.
        # out_aD = reduce_from_model_parallel_region(out_aD)

        # Reshape back to [bsz, slen, D]
        return out_aD.view(bsz, slen, D)


In [13]:
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum

from models.llama4.args import ModelArgs
from models.llama4.model import Attention


class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` subclass whose forward calls ``F.layer_norm`` directly.

    The input is normalised with this module's own shape, weight, bias and eps;
    no dtype conversion is performed here.
    """

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized


class ColumnParallelConv2dPatch(nn.Module):
    """Patch embedding implemented as ``nn.Unfold`` followed by ``nn.Linear``.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Patch size; an int is expanded to a square ``(k, k)``.
        stride: Stride handed to ``nn.Unfold``.
        bias (default False): Whether the linear projection has a bias.
    Input: (bsz, in_channels, height, width)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: Optional[bool] = False,
    ) -> None:
        super().__init__()
        kernel = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self._unfold = nn.Unfold(kernel_size=kernel, stride=stride)
        # Each unfolded column holds one flattened patch across all channels.
        patch_features = in_channels * kernel[0] * kernel[1]
        self._linear = nn.Linear(patch_features, out_channels, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract patches from ``x`` and project each one to ``out_channels``."""
        columns = self._unfold(x)            # [bsz, patch_features, num_tokens]
        tokens = columns.permute(0, 2, 1)    # [bsz, num_tokens, patch_features]
        return self._linear(tokens)          # [bsz, num_tokens, out_channels]


class _FeedForward(nn.Module):
    """Two-layer MLP with bias: ``c_proj(dropout(act(c_fc(x))))``.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the intermediate activation.
        dropout: Dropout probability applied after the activation (training mode only).
        act_layer: Zero-argument factory for the activation module.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float,
        act_layer: Callable = nn.GELU,
    ):
        super().__init__()
        # Plain linear layers; both carry a bias term.
        self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
        self.c_proj = nn.Linear(hidden_dim, dim, bias=True)
        self.non_linearity = act_layer()
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.non_linearity(self.c_fc(x))   # [*, hidden_dim]
        dropped = F.dropout(activated, p=self.dropout, training=self.training)
        return self.c_proj(dropped)                    # [*, dim]


class _TransformerBlock(nn.Module):
    """Residual block: LayerNorm -> attention, then LayerNorm -> MLP.

    Each branch is added back onto its input. When ``gated`` is true, each
    branch is first scaled by ``tanh`` of a learnable scalar that starts at zero.

    Args:
        d_model: Model width; must be divisible by ``n_head``.
        n_head: Number of attention heads (also used as the KV head count).
        max_batch_size: Forwarded to the attention's ``ModelArgs``.
        max_seq_len: Forwarded to the attention's ``ModelArgs``.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        act_layer: Activation factory for the MLP.
        gated: Enable the learnable tanh gates.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        max_batch_size: int,
        max_seq_len: int,
        mlp_ratio: float = 4.0,
        act_layer: Callable = nn.GELU,
        gated: bool = False,
    ):
        super().__init__()
        assert d_model % n_head == 0
        self.n_heads = n_head
        self.head_dim = d_model // self.n_heads

        self.attn = Attention(
            ModelArgs(
                dim=d_model,
                head_dim=self.head_dim,
                n_heads=self.n_heads,
                n_kv_heads=self.n_heads,
                max_batch_size=max_batch_size,
                max_seq_len=max_seq_len,
            ),
            use_rope=True,
            use_qk_norm=False,
            add_bias=True,
        )
        self.ln_1 = LayerNorm(d_model)
        self.mlp = _FeedForward(
            dim=d_model,
            hidden_dim=int(mlp_ratio * d_model),
            dropout=0.0,
            act_layer=act_layer,
        )
        self.ln_2 = LayerNorm(d_model)

        self.gated = gated
        if self.gated:
            self.gate_attn = nn.Parameter(torch.zeros(1))
            self.gate_ffn = nn.Parameter(torch.zeros(1))

    def attention(
        self,
        x: torch.Tensor,
        freq_cis: Optional[torch.Tensor] = None,
    ):
        """Run the attention module on ``x``, always from position 0."""
        return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        freq_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply both residual branches; ``mask`` is accepted but not used."""
        if self.gated:
            attn_scale = self.gate_attn.tanh()
            ffn_scale = self.gate_ffn.tanh()
        else:
            attn_scale = ffn_scale = 1

        x = x + attn_scale * self.attention(self.ln_1(x), freq_cis=freq_cis)
        x = x + ffn_scale * self.mlp(self.ln_2(x))
        return x


class _Transformer(nn.Module):
    """A sequential stack of ``layers`` ``_TransformerBlock`` modules.

    All constructor arguments other than ``layers`` are forwarded to every block.
    """

    def __init__(
        self,
        dim: int,
        layers: int,
        heads: int,
        max_batch_size: int,
        max_seq_len: int,
        mlp_ratio: float = 4.0,
        act_layer: Callable = nn.GELU,
        gated: bool = False,
    ):
        super().__init__()
        self.resblocks = nn.ModuleList()
        for _ in range(layers):
            self.resblocks.append(
                _TransformerBlock(
                    d_model=dim,
                    n_head=heads,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    gated=gated,
                    max_batch_size=max_batch_size,
                    max_seq_len=max_seq_len,
                )
            )

    def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
        """Run ``x`` through every block.

        Args:
            x: Input activations.
            return_intermediate: Optional container of layer indices. The
                *input* to each listed layer is captured, and the captures are
                stacked along a new last dimension.
            mask: Forwarded to every block.
            freq_cis: Forwarded to every block.

        Returns:
            ``x`` alone when ``return_intermediate`` is None, otherwise
            ``(x, stacked_intermediates)``.
        """
        collect = return_intermediate is not None
        captured = []
        for layer_idx, block in enumerate(self.resblocks):
            if collect and layer_idx in return_intermediate:
                captured.append(x)
            x = block(x, mask=mask, freq_cis=freq_cis)

        if not collect:
            return x
        return x, torch.stack(captured, dim=-1)


class PackingIndex:
    """Column indices of the per-token metadata vector, plus sentinel token ids.

    Slots, in order:
        Z          time coordinate of the token in the original sample
        Y          height coordinate of the token in the original sample
        X          width coordinate of the token in the original sample
        TIME       total number of time units (frames) in the original sample
        HEIGHT     height of the original sample
        WIDTH      width of the original sample
        IDX        full index of the token in the original sample (x + y * w + z * w * h)
        BATCH_IDX  which batch element this token belongs to
    NUM_METADATA is the number of slots listed above.
    """

    (Z, Y, X, TIME, HEIGHT, WIDTH, IDX, BATCH_IDX, NUM_METADATA) = range(9)

    # Sentinel values stored in the IDX slot for non-patch tokens.
    ID_CLS_TOKEN = -2
    ID_PAD_TOKEN = -1


# Limits handed to the vision transformer stack by VisionEncoder as its
# max_batch_size / max_seq_len arguments.
ENCODER_MAX_BATCH_SIZE = 32
ENCODER_MAX_SEQ_LEN = 1024


class VisionEncoder(nn.Module):
    """Patch-based image encoder: patch embedding, class token, positional
    embedding, transformer stack with 2-D rotary frequencies, final LayerNorm.

    The class token is appended *after* the patch tokens and is dropped again
    at the end of ``forward``.
    """

    def __init__(
        self,
        image_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        dim: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
    ):
        """
        Args:
            image_size: (height, width) of one input image chunk.
            patch_size: (height, width) of one patch.
            dim: Embedding width.
            layers: Number of transformer blocks.
            heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
        """
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        # Number of patches along (height, width); integer division drops any remainder.
        self.grid_size = (
            self.image_size[0] // self.patch_size[0],
            self.image_size[1] // self.patch_size[1],
        )

        # Patch embedding (Unfold + Linear version defined in this notebook)
        self.conv1 = ColumnParallelConv2dPatch(
            in_channels=3,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        # Randomly initialised class token and positional table
        # (one row per patch plus one for the class token).
        scale = dim**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(dim))
        self.positional_embedding_vlm = nn.Parameter(
            scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
        )

        self.ln_pre = LayerNorm(dim)
        self.ln_post = LayerNorm(dim)

        # Build the transformer stack
        self.transformer = _Transformer(
            dim,
            layers,
            heads,
            ENCODER_MAX_BATCH_SIZE,
            ENCODER_MAX_SEQ_LEN,
            mlp_ratio,
            act_layer=nn.GELU,
        )

        # Build per-token metadata (x, y, grid width/height, flat index) used
        # for the rotary frequencies below and by load_hook.
        image_h, image_w = self.image_size
        patch_h, patch_w = self.patch_size
        idx_h, idx_w = image_h // patch_h, image_w // patch_w

        # Flat patch indices 0..N-1 as a column, plus one extra trailing row
        # that is overwritten with the class-token sentinel.
        img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
        img_idx = img_idx.reshape(idx_h * idx_w, 1)
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN

        # NOTE: allocated with torch.empty and only the Y, X, HEIGHT, WIDTH and
        # IDX slots are written; the Z and TIME slots stay uninitialised.
        packed_img_idx = torch.empty(
            img_idx.shape[0],
            img_idx.shape[1],
            PackingIndex.NUM_METADATA - 1,
            dtype=torch.int32,
        )
        packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
        packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
        packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
        packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
        packed_img_idx[:, :, PackingIndex.IDX] = img_idx
        packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
        # Plain attribute (not a registered buffer); read by load_hook.
        self.packed_img_idx = packed_img_idx  # for positional embedding load hook

        # Compute RoPE frequencies from the (1-based) x and y patch coordinates.
        rope_freq = self.get_rope_freqs(dim // heads // 2)
        freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
        freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
        # compute_rope_freqs duplicates each entry; [..., ::2] keeps every other one.
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        # Tokens with a negative IDX (the class-token sentinel) get zero angle.
        freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
        # Complex tensor with real part cos(freqs) and imaginary part sin(freqs).
        self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
        self.freq_cis = self.freq_cis.squeeze(0)

        # All heads are kept locally (no division by a world size)
        self.n_heads = heads

        self._register_load_state_dict_pre_hook(self.load_hook)

    def get_rope_freqs(self, dim, theta=10000):
        """Return the inverse-frequency vector ``1 / theta**(i / dim)`` for even ``i`` below ``dim``."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        return freqs

    @torch.amp.autocast("cuda", enabled=False)
    def compute_rope_freqs(self, freqs, t):
        """Outer product of positions ``t`` with ``freqs``, each value repeated twice along the last dim.

        Runs with CUDA autocast disabled.
        """
        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = freqs.repeat_interleave(2, dim=-1)
        return freqs

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool = True,
        missing_keys: Optional[List[str]] = None,
        unexpected_keys: Optional[List[str]] = None,
        error_msgs: Optional[List[str]] = None,
        return_state_dict: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """Derive ``positional_embedding_vlm`` from a checkpoint's ``positional_embedding``.

        If ``<prefix>positional_embedding`` is present, its patch rows are
        resampled with ``F.grid_sample`` at this encoder's patch coordinates,
        the class-token row is copied from entry 0, and the result is written
        to ``<prefix>positional_embedding_vlm``. The source key is read with
        ``get`` and is not removed.

        Returns:
            ``state_dict`` when ``return_state_dict`` is true, otherwise None.

        Raises:
            ValueError: if the checkpoint embedding's trailing two dimensions
                differ from those of ``positional_embedding_vlm``.
        """
        orig_pos_embed = state_dict.get(prefix + "positional_embedding")
        if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
            raise ValueError(
                f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
            )

        # One "window" per token, each containing a single metadata row.
        batch_size, token_per_image, _ = self.packed_img_idx.shape
        idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
        total_windows, window_size, _ = idx.shape

        # Sampling grid: (x / width, y / height) rescaled by * 2 - 1 for grid_sample.
        grid = (
            (idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
        )[None, ...]

        if orig_pos_embed is not None:
            # Rows 1.. are treated as the patch embeddings laid out on the grid;
            # row 0 is used for the class token further down.
            posemb = (
                orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            posemb = posemb.to(device=grid.device, dtype=grid.dtype)
            sample = F.grid_sample(
                posemb, grid, padding_mode="zeros"
            )
            sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
            # The class-token position takes the checkpoint's row 0 instead of a sampled value.
            sample = torch.where(
                idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
                orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
                sample,
            )
            new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
            state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)

        if return_state_dict:
            return state_dict

    def apply_class_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """Append the class embedding as one extra token at the end of dim 1."""
        # Adding to a zeros tensor broadcasts the embedding to [x.shape[0], 1, dim].
        cls = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        return torch.cat([x, cls], dim=1)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode image chunks into patch embeddings.

        Args:
            images: 5-D ``[bsz, num_chunks, channels, h, w]``; any other rank is
                unpacked as 6-D ``[bsz, num_concurrent_media, num_chunks, channels, h, w]``.

        Returns:
            Tensor ``[bsz * num_concurrent_media, tokens, dim]`` with the last
            (class) token removed.
        """
        if images.ndim == 5:
            num_concurrent_media = 1
            bsz, num_chunks, nch, h, w = images.shape
        else:
            bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape

        # Treat every chunk as an independent image for patch embedding.
        images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
        x = self.conv1(images)  # [*, num_patches, dim]
        _, ntok, dim = x.shape
        x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)

        x = self.apply_class_embedding(x)
        ntok += 1

        # positional_embedding_vlm is always created in __init__; the check is defensive.
        if self.positional_embedding_vlm is not None:
            x = x + self.positional_embedding_vlm.to(x.dtype)

        x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
        x = self.ln_pre(x)
        # Merge the chunk and token axes into one sequence axis.
        x = x.view(bsz * num_concurrent_media, -1, dim)
        # freq_cis is a plain attribute, so it is moved to the input's device on every call.
        freq_cis = self.freq_cis.to(images.device)

        tf_output = self.transformer(x, freq_cis=freq_cis)

        # The transformer returns a tuple only when return_intermediate is
        # passed, which this call does not do, so int_x stays None here.
        int_x = None
        if isinstance(tf_output, tuple):
            x, int_x = tf_output
        else:
            x = tf_output
        x = self.ln_post(x)

        x = x[:, :-1, :]  # remove cls token

        if int_x is not None:
            int_x = int_x[:, :-1, :, :].reshape(bsz * num_concurrent_media, ntok - 1, -1)
            x = torch.cat([x, int_x], dim=-1)

        return x


In [14]:
import math
from typing import Any, Callable, Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.llama4.args import VisionArgs
# from .encoder import VisionEncoder


class PixelShuffle(nn.Module):
    """Rearranges a square grid of patch tokens with ``pixel_shuffle_op``.

    Args:
        ps_ratio: Ratio forwarded to ``pixel_shuffle_op``; must not be None
            when ``forward`` is called.
    """

    def __init__(self, ps_ratio):
        super().__init__()
        self.ps_ratio = ps_ratio

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Shuffle ``x`` of shape [B, N, C], where N is treated as a square grid.

        The grid side is ``int(sqrt(N))``. Returns a 3-D tensor whose token and
        channel sizes are determined by ``pixel_shuffle_op``.
        """
        assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
        assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"

        side = int(math.sqrt(x.shape[1]))
        grid = x.reshape(x.shape[0], side, side, -1)  # [B, H, W, C]
        shuffled = pixel_shuffle_op(grid, ps_ratio=self.ps_ratio)
        # Collapse the two spatial axes back into a single token axis.
        return shuffled.reshape(shuffled.shape[0], -1, shuffled.shape[-1])


def pixel_shuffle_op(input_x: torch.Tensor, ps_ratio: float) -> torch.Tensor:
    """Trade spatial resolution for channel depth (or the reverse) on a 4-D tensor.

    Args:
        input_x: Tensor unpacked as ``[n, w, h, c]``; must support ``view``.
        ps_ratio: Scale applied to each spatial axis; channels are divided by
            ``ps_ratio ** 2``.

    Returns:
        Contiguous tensor of shape
        ``[n, int(w * ps_ratio), int(h * ps_ratio), int(c / ps_ratio**2)]``.
    """
    n, w, h, c = input_x.size()
    scaled_h = int(h * ps_ratio)
    scaled_w = int(w * ps_ratio)

    # Pass 1: rescale the third axis against channels, then swap the spatial axes.
    stage = input_x.view(n, w, scaled_h, int(c / ps_ratio))
    stage = stage.permute(0, 2, 1, 3).contiguous()

    # Pass 2: rescale the other spatial axis the same way, then swap back.
    stage = stage.view(n, scaled_h, scaled_w, int(c / (ps_ratio * ps_ratio)))
    return stage.permute(0, 2, 1, 3).contiguous()


class SimpleMLP(nn.Module):
    """Two linear layers with the activation applied after *each* of them.

    Args:
        dim: Input width.
        hidden_dim: Width of both the intermediate and the output features.
        bias: Whether the linear layers carry a bias.
        dropout: Dropout probability between the layers (training mode only).
        act_layer: Zero-argument factory for the activation module.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        bias: bool = True,
        dropout: float = 0.0,
        act_layer: Callable = nn.GELU,
    ):
        super().__init__()
        # Plain nn.Linear layers; note that c_proj maps hidden_dim -> hidden_dim.
        self.c_fc = nn.Linear(dim, hidden_dim, bias=bias)
        self.c_proj = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.non_linearity = act_layer()
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.non_linearity(self.c_fc(x))          # [*, hidden_dim]
        dropped = F.dropout(activated, p=self.dropout, training=self.training)
        projected = self.c_proj(dropped)                      # [*, hidden_dim]
        return self.non_linearity(projected)


class PixelShuffleMLP(nn.Module):
    """Pixel-shuffle the patch tokens, then project them with a ``SimpleMLP``.

    Args:
        ps_ratio: Pixel-shuffle ratio.
        input_dim: Channel width of the incoming patches.
        output_dim: Width produced by the MLP.
        add_fc: When true, a bias-free ``output_dim -> output_dim`` linear layer
            follows the MLP; otherwise ``fc`` is an identity.
    """

    def __init__(
        self,
        ps_ratio: float,
        input_dim: int,
        output_dim: int = 4096,
        add_fc: bool = False,
    ):
        super().__init__()
        self.pixel_shuffle = PixelShuffle(ps_ratio)

        # MLP input width is input_dim // ps_ratio**2, truncated to int.
        mlp_in_dim = int(input_dim // (ps_ratio**2))
        self.mlp = SimpleMLP(
            mlp_in_dim,
            output_dim,
            bias=False,
            dropout=0.0,
            act_layer=nn.GELU,
        )

        # Optional plain nn.Linear head.
        self.fc = nn.Linear(output_dim, output_dim, bias=False) if add_fc else nn.Identity()

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        """Map ``[B, N, C]`` patches to ``[B, N', output_dim]``."""
        shuffled = self.pixel_shuffle(encoded_patches)  # [B, N', C']
        projected = self.mlp(shuffled)
        return self.fc(projected)


class VisionEmbeddings(nn.Module):
    """Encodes images and writes their embeddings into a token-aligned tensor.

    Pipeline: ``VisionEncoder`` (cast to bfloat16) -> ``PixelShuffleMLP``
    adapter -> ``scatter_embeddings`` into a zero tensor shaped like the text
    sequence.
    """

    def __init__(self, args: VisionArgs):
        super().__init__()
        self.args = args

        img, patch = args.image_size, args.patch_size
        encoder = VisionEncoder(
            image_size=(img.height, img.width),
            patch_size=(patch.height, patch.width),
            dim=args.dim,
            layers=args.n_layers,
            heads=args.n_heads,
            mlp_ratio=args.mlp_ratio,
        )
        self.vision_encoder = encoder
        self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
        self.vision_adapter = PixelShuffleMLP(
            ps_ratio=args.pixel_shuffle_ratio,
            input_dim=args.dim,
            output_dim=args.output_dim,
        )

        self.output_dim = args.output_dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool = True,
        missing_keys: List[str] = None,
        unexpected_keys: List[str] = None,
        error_msgs: List[str] = None,
        return_state_dict: bool = False,
    ) -> None:
        """Reshape empty 1-D checkpoint tensors to this module's parameter shapes.

        Any entry under ``prefix`` whose shape is exactly ``(0,)`` is reshaped
        to the shape of the identically named entry of ``self.state_dict()``.
        ``return_state_dict`` is accepted but not used.
        """
        reference = self.state_dict()
        # Only values of existing keys are replaced, so iterating the dict is safe.
        for key in state_dict:
            if not key.startswith(prefix):
                continue
            tensor = state_dict[key]
            if tuple(tensor.shape) == (0,):
                state_dict[key] = tensor.reshape(reference[key[len(prefix) :]].shape)

    def _get_empty_sequence(self, h: torch.Tensor) -> torch.Tensor:
        """Zeros of shape ``[h.shape[0], h.shape[1], output_dim]`` on ``h``'s device and dtype."""
        batch, seq_len = h.shape[0], h.shape[1]
        return torch.zeros(batch, seq_len, self.output_dim, device=h.device, dtype=h.dtype)

    def forward(
        self,
        image_batch: List[List[torch.Tensor]],
        image_mask: torch.Tensor,
        h_ref: torch.Tensor,
    ) -> torch.Tensor:
        """Embed every image in ``image_batch`` and place the result at masked positions.

        Args:
            image_batch: One list of image tensors per sample.
            image_mask: Per-sample mask of the positions that receive image embeddings.
            h_ref: Reference activations; supplies batch size, sequence length,
                dtype and device.

        Returns:
            Zero tensor of shape ``[B, T, output_dim]`` with the image
            embeddings scattered in.
        """
        # Stack every image of every sample, then add a singleton axis at dim 1.
        stacked = torch.vstack([image for sample in image_batch for image in sample])
        stacked = stacked.unsqueeze(1).to(h_ref.dtype).to(h_ref.device)

        encoded = self.vision_encoder(stacked)
        projected = self.vision_adapter(encoded)

        return scatter_embeddings(image_batch, image_mask, self._get_empty_sequence(h_ref), projected)


def scatter_embeddings(
    image_batch: List[List[torch.Tensor]],
    image_mask: torch.Tensor,
    h_image: torch.Tensor,
    encoded_patches_proj: torch.Tensor,
) -> torch.Tensor:
    """Write projected patch embeddings into ``h_image`` at the masked positions.

    Args:
        image_batch: One list of image tensors per sample; ``image.size(0)`` of
            each image is counted as its number of chunks.
        image_mask: Per-sample mask, expanded along the last dim to the feature width.
        h_image: Destination tensor; modified in place.
        encoded_patches_proj: Embeddings whose dim 0 spans all chunks of all samples.

    Returns:
        ``h_image`` (the same object).

    Raises:
        AssertionError: if the embeddings contain NaN, if the chunk counts do
            not add up, or if a sample's mask selects more positions than it
            has embedding rows.
    """
    num_images_per_sequence = []
    for sample_images in image_batch:
        num_images_per_sequence.append(sum(image.size(0) for image in sample_images))

    assert not torch.isnan(encoded_patches_proj).any()
    assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
        f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
    )

    feature_dim = h_image.size(-1)
    per_sample_patches = encoded_patches_proj.split(num_images_per_sequence, dim=0)

    for row in range(h_image.size(0)):
        patches = per_sample_patches[row]
        row_mask = image_mask[row]

        # Samples without images keep their zero row.
        if patches.numel() != 0:
            # Flatten (chunks, tokens) into a single list of embedding rows.
            flat_patches = patches.contiguous().view(-1, patches.size(-1))

            n_tokens_to_fill = row_mask.sum()
            assert n_tokens_to_fill <= flat_patches.size(0)

            # h_image[row] is a view, so this writes into h_image itself.
            h_image[row].masked_scatter_(
                row_mask.expand(-1, feature_dim),
                flat_patches[:n_tokens_to_fill],
            )

    return h_image


In [15]:
import math
from typing import Any, Dict, List, Optional, Tuple

# from models.llama4.vision.embedding import VisionEmbeddings


import torch
import torch.nn.functional as F
from torch import nn

from models.llama4.args import ModelArgs
from models.llama4.datatypes import TransformerInput, TransformerOutput
# from models.llama4.ffn import FeedForward
# from models.llama4.moe import MoE


def rmsnorm(x, eps):
    """Scale ``x`` by the reciprocal root-mean-square of its last dimension.

    The computation runs in float32 and the result is cast back to the dtype
    of ``x``. ``eps`` is added to the mean square before the reciprocal sqrt.
    """
    x_fp32 = x.float()
    inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    return (x_fp32 * inv_rms).type_as(x)


class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature gain.

    Args:
        dim: Size of the normalised (last) dimension; the gain has this length.
        eps: Stabiliser forwarded to ``rmsnorm``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at 1 so a freshly built layer only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normalised = rmsnorm(x, self.eps)
        return normalised * self.weight


def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
    """Rescale rotary frequencies band by band according to their wavelength.

    * wavelength below ``old_context_len / high_freq_factor``: kept unchanged;
    * wavelength above ``old_context_len / low_freq_factor``: divided by
      ``scale_factor``;
    * in between: linear blend of the two, weighted by ``smooth``.

    Args:
        freqs: 1-D tensor of frequencies (iterated element by element).
        scale_factor: Divisor applied to long-wavelength frequencies.
        high_freq_factor: Sets the short-wavelength cutoff.

    Returns:
        A new tensor with the dtype and device of ``freqs``.
    """
    low_freq_factor = 1
    old_context_len = 8192  # original llama3 length

    long_cutoff = old_context_len / low_freq_factor
    short_cutoff = old_context_len / high_freq_factor

    def _rescale(freq):
        wavelen = 2 * math.pi / freq
        if wavelen < short_cutoff:
            return freq
        if wavelen > long_cutoff:
            return freq / scale_factor
        # Transition band: only reachable when the two cutoffs differ.
        assert long_cutoff != short_cutoff
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        return (1 - smooth) * freq / scale_factor + smooth * freq

    rescaled = [_rescale(freq) for freq in freqs]
    return torch.tensor(rescaled, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float,
    use_scaled: bool,
    scale_factor: float,
    high_freq_factor: float,
):
    """Build the table of complex rotary factors for positions ``0 .. end-1``.

    Args:
        dim: Head dimension; ``dim // 2`` frequencies are generated.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.
        use_scaled: When true, frequencies are passed through ``apply_scaling``.
        scale_factor: Forwarded to ``apply_scaling``.
        high_freq_factor: Forwarded to ``apply_scaling``.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` holding unit-magnitude
        values ``exp(i * position * frequency)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    if use_scaled:
        inv_freq = apply_scaling(inv_freq, scale_factor, high_freq_factor)
    angles = torch.outer(positions, inv_freq)
    # Unit magnitude, phase = angle -> complex64
    return torch.polar(torch.ones_like(angles), angles)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex factors in ``freqs_cis``.

    Adjacent pairs along the last dimension are interpreted as (real, imag),
    multiplied by ``freqs_cis`` (broadcast via ``reshape_for_broadcast``) and
    flattened back from dim 3 onward. The maths runs in float32; each output
    is cast back to the dtype of its input.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d/2) complex, pairing neighbouring features.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


class Attention(nn.Module):
    """Grouped-query self-attention with a per-layer key/value cache.

    Single-device variant (no model parallelism): every projection is a plain
    ``nn.Linear``. Keys and values of processed positions are kept in
    ``cache_k`` / ``cache_v`` so later calls only project the new tokens.

    Args:
        args: Model hyper-parameters. Reads ``dim``, ``n_heads``,
            ``n_kv_heads``, ``max_batch_size``, ``max_seq_len``, ``norm_eps``,
            ``attn_temperature_tuning``, ``floor_scale`` and ``attn_scale``.
        use_qk_norm: Apply weight-free RMS normalisation to queries and keys.
        use_rope: Apply rotary position embeddings to queries and keys.
        add_bias: Whether the four projections carry a bias term.
    """

    def __init__(
        self,
        args: "ModelArgs",
        use_qk_norm: bool,
        use_rope: bool,
        add_bias: bool = False,
    ):
        super().__init__()
        self.use_rope = use_rope
        self.use_qk_norm = use_qk_norm
        # For attention temperature tuning
        self.attn_temperature_tuning = args.attn_temperature_tuning
        self.floor_scale = args.floor_scale
        self.attn_scale = args.attn_scale

        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_local_heads = self.n_heads  # no model parallel split
        self.n_local_kv_heads = self.n_kv_heads
        # Each KV head is shared by n_rep query heads.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # Replaced ColumnParallelLinear with nn.Linear
        self.wq = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=add_bias)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=add_bias)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=add_bias)

        # Replaced RowParallelLinear with nn.Linear
        self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias=add_bias)

        # KV cache. Allocated on the default device rather than unconditionally
        # on CUDA so the module can also be constructed on CPU-only hosts;
        # forward() moves the cache to the device/dtype of the activations.
        cache_shape = (
            args.max_batch_size,
            args.max_seq_len,
            self.n_local_kv_heads,
            self.head_dim,
        )
        self.cache_k = torch.zeros(cache_shape)
        self.cache_v = torch.zeros(cache_shape)

        self.norm_eps = args.norm_eps
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Split a fused ``wqkv.weight`` entry into ``wq`` / ``wk`` / ``wv`` weights.

        Runs before ``load_state_dict`` and mutates ``state_dict`` in place.
        Does nothing when the fused key is absent.

        Raises:
            ValueError: If the fused weight's row count is not a multiple of
                ``n_heads + 2 * n_kv_heads``.
        """
        if prefix + "wqkv.weight" in state_dict:
            wqkv = state_dict.pop(prefix + "wqkv.weight")
            # d = rows per head; the fused layout is [q heads | k heads | v heads].
            d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
            if r != 0:
                raise ValueError(
                    f"shape={tuple(wqkv.shape)} is not divisible by "
                    f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
                )
            wq, wk, wv = wqkv.split(
                [d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0
            )
            state_dict[prefix + "wq.weight"] = wq
            state_dict[prefix + "wk.weight"] = wk
            state_dict[prefix + "wv.weight"] = wv

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        """Attend the new tokens in ``x`` to all cached positions plus themselves.

        Args:
            x: Activations of shape ``(bsz, seqlen, dim)``.
            start_pos: Cache offset at which the new keys/values are written.
            freqs_cis: Rotary factors for the new positions; only used when
                ``use_rope`` is set.
            mask: Optional mask forwarded to
                ``F.scaled_dot_product_attention`` as ``attn_mask``.

        Returns:
            Tensor of shape ``(bsz, seqlen, dim)`` after the output projection.
        """
        bsz, seqlen, _ = x.shape
        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        if self.use_rope:
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if self.use_qk_norm:
            xq = rmsnorm(xq, self.norm_eps)
            xk = rmsnorm(xk, self.norm_eps)

        # Attention temperature tuning (NoPE layers): queries are scaled by
        # log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1.
        if self.attn_temperature_tuning and not self.use_rope:
            seq_positions = torch.arange(
                start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32
            )
            attn_scales = (
                torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0)
                * self.attn_scale
                + 1.0
            )
            attn_scales = attn_scales.view(1, seqlen, 1, 1)
            xq = xq * attn_scales

        # Keep the cache on the same device AND dtype as the activations.
        # Without the dtype alignment, a model cast after construction would
        # read float32 keys/values back next to lower-precision queries, which
        # scaled_dot_product_attention rejects. A matching cache is a no-op.
        self.cache_k = self.cache_k.to(device=xk.device, dtype=xk.dtype)
        self.cache_v = self.cache_v.to(device=xv.device, dtype=xv.dtype)

        # Write into cache
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Read back every position seen so far, including the ones just written.
        xk = self.cache_k[:bsz, : start_pos + seqlen]
        xv = self.cache_v[:bsz, : start_pos + seqlen]

        # (bsz, seq, heads, head_dim) -> (bsz, heads, seq, head_dim)
        xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
        # Expand the KV heads so every query head has a matching key/value head.
        xk = xk.repeat_interleave(self.n_rep, dim=1)
        xv = xv.repeat_interleave(self.n_rep, dim=1)

        attn_output = F.scaled_dot_product_attention(
            xq, xk, xv, attn_mask=mask, dropout_p=0.0
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(attn_output)
        return output


class TransformerBlock(nn.Module):
    """One decoder layer: pre-norm attention followed by a pre-norm feed-forward.

    The feed-forward is a ``MoE`` on layers selected by
    ``args.moe_args.interleave_moe_layer_step`` and a dense ``FeedForward``
    otherwise. Layers selected by ``args.nope_layer_interval`` ("NoPE" layers)
    run attention without rotary embeddings and without QK-norm.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        if args.head_dim is None:
            self.head_dim = args.dim // args.n_heads
        else:
            self.head_dim = args.head_dim

        # Both interval checks below count layers starting from 1.
        layer_number = layer_id + 1

        nope_interval = args.nope_layer_interval
        self.is_nope_layer = nope_interval is not None and layer_number % nope_interval == 0

        self.attention = Attention(
            args,
            use_rope=not self.is_nope_layer,
            use_qk_norm=args.use_qk_norm and not self.is_nope_layer,
        )

        uses_moe = args.moe_args and layer_number % args.moe_args.interleave_moe_layer_step == 0
        if uses_moe:
            self.feed_forward = MoE(
                dim=args.dim,
                hidden_dim=int(args.ffn_exp * args.dim),
                ffn_dim_multiplier=args.ffn_dim_multiplier,
                multiple_of=args.multiple_of,
                moe_args=args.moe_args,
            )
        else:
            # Dense width: 2/3 of 4*dim, optionally scaled, then rounded up to
            # the next multiple of args.multiple_of.
            dense_hidden = int(2 * int(4 * args.dim) / 3)
            if args.ffn_dim_multiplier is not None:
                dense_hidden = int(args.ffn_dim_multiplier * dense_hidden)
            multiple = args.multiple_of
            dense_hidden = multiple * ((dense_hidden + multiple - 1) // multiple)

            self.feed_forward = FeedForward(dim=args.dim, hidden_dim=dense_hidden)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Rename legacy norm-weight keys and drop ``_extra_state`` entries.

        Runs before ``load_state_dict`` and mutates ``state_dict`` in place.
        """
        fused_attn_norm = prefix + "attention.wqkv.layer_norm_weight"
        if fused_attn_norm in state_dict:
            state_dict[prefix + "attention_norm.weight"] = state_dict.pop(fused_attn_norm)

        # Only the first legacy FFN-norm key that is present gets renamed.
        for legacy_suffix in ("feed_forward.mlp.layer_norm_weight", "feed_forward.norm.weight"):
            legacy_key = prefix + legacy_suffix
            if legacy_key in state_dict:
                state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(legacy_key)
                break

        for module_path in (
            "feed_forward.experts.mlp",
            "feed_forward.mlp_shared",
            "attention.wo",
            "attention.wqkv",
        ):
            state_dict.pop(prefix + module_path + "._extra_state", None)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        global_attn_mask: Optional[torch.Tensor],
        local_attn_mask: Optional[torch.Tensor],
    ):
        """Apply attention and feed-forward, each with a residual connection.

        NoPE layers, and any layer when ``local_attn_mask`` is ``None``, use
        ``global_attn_mask``; all other layers use ``local_attn_mask``.
        """
        use_global = self.is_nope_layer or local_attn_mask is None
        mask = global_attn_mask if use_global else local_attn_mask

        attn_out = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        h = x + attn_out
        return h + self.feed_forward(self.ffn_norm(h))


class Transformer(nn.Module):
    """Decoder-only transformer built from plain (non-parallel) torch layers.

    Holds the token embedding, ``args.n_layers`` ``TransformerBlock`` layers,
    a final ``RMSNorm`` and the output projection. When ``args.vision_args``
    is set it also owns the vision encoder and the projection of its
    embeddings into the model dimension.
    """

    def __init__(self, args: ModelArgs, **kwargs) -> None:
        super().__init__()
        self.args = args

        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers

        # Replaced VocabParallelEmbedding with nn.Embedding
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

        self.layers = nn.ModuleList(
            [TransformerBlock(layer_id, args) for layer_id in range(args.n_layers)]
        )

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        # Replaced ColumnParallelLinear with nn.Linear
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # Rotary table for twice the maximum sequence length; kept as a plain
        # attribute and moved to the activation device inside forward().
        self.freqs_cis = precompute_freqs_cis(
            args.dim // args.n_heads,
            args.max_seq_len * 2,
            args.rope_theta,
            args.use_scaled_rope,
            args.rope_scaling_factor,
            args.rope_high_freq_factor,
        )

        vision_args = self.args.vision_args
        if vision_args:
            # circular import otherwise until we refactor out Attention
            # from .vision.embedding import VisionEmbeddings

            self.vision_embeddings = VisionEmbeddings(vision_args)
            # Replaced ColumnParallelLinear with nn.Linear
            self.vision_projection = nn.Linear(vision_args.output_dim, args.dim, bias=False)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Drop the checkpoint's ``rope.freqs`` entry; the table is rebuilt in ``__init__``."""
        state_dict.pop(prefix + "rope.freqs", None)

    @torch.inference_mode()
    def forward(self, model_input: TransformerInput) -> TransformerOutput:
        """Run the model on ``model_input`` and return float32 logits.

        ``model_input.tokens`` is a 2-D tensor of token ids and
        ``model_input.tokens_position`` must be a single ``int`` shared by the
        whole batch. When an image embedding is present, the masked positions
        of the token embeddings are replaced by the projected image embedding.
        """
        tokens = model_input.tokens
        start_pos = model_input.tokens_position
        assert isinstance(start_pos, int), (
            "This implementation does not support different start positions per batch item"
        )

        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)

        image_embedding = model_input.image_embedding
        if image_embedding:
            h_image = self.vision_projection(image_embedding.embedding)
            h = h * ~image_embedding.mask + h_image * image_embedding.mask

        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        # Masks are only built for multi-token calls; single-token steps run unmasked.
        global_attn_mask = None
        local_attn_mask = None
        if seqlen > 1:
            all_neg_inf = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            global_attn_mask = torch.triu(all_neg_inf, diagonal=1).type_as(h)

            # Handle MPS bug where triu produces NaNs instead of 0
            if global_attn_mask.device.type == torch.device("mps").type:
                global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)

            chunk_size = self.args.attention_chunk_size
            if chunk_size:
                local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)

        for block in self.layers:
            h = block(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)

        logits = self.output(self.norm(h)).float()
        return TransformerOutput(logits=logits)


def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
    """Build a boolean ``(seq_len, seq_len)`` mask for chunk-local causal attention.

    Entry ``[q, k]`` is ``True`` when position ``k`` is not after ``q`` and
    both fall in the same block of ``attention_chunk_size`` positions. The
    mask is built on the default device and then moved to ``device``.
    """
    positions = torch.arange(seq_len)
    chunk_ids = positions // attention_chunk_size
    same_chunk = chunk_ids.unsqueeze(0) == chunk_ids.unsqueeze(1)
    not_future = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return (same_chunk & not_future).to(device)


In [16]:
# Number of routed experts per MoE layer (the "16E" in the checkpoint name).
moe_num_experts = 16
# convert_moe_weights rewrites each shard dict in place, so running this cell
# a second time would unflatten tensors that are already converted.
state_dicts = [convert_moe_weights(shard, moe_num_experts) for shard in state_dicts]

In [17]:
from models.checkpoint import reshard_mp

In [18]:
# Reshard the loaded checkpoint shards for a single rank (size=1, rank=0).
state_dict = reshard_mp(
    state_dicts,
    size=1,
    rank=0,
    # moe_num_experts=model_args.moe_args.num_experts,
)

In [19]:
model = Transformer(model_args)
# strict=False is needed because the checkpoint carries extra entries the
# model does not define (the gate-stats tensors listed in the output below).
# Missing weights, however, would silently leave parameters at their random
# initialisation, so fail loudly if any are reported.
load_result = model.load_state_dict(state_dict, strict=False)
assert not load_result.missing_keys, (
    f"checkpoint is missing {len(load_result.missing_keys)} weights, "
    f"e.g. {load_result.missing_keys[:5]}"
)
load_result

_IncompatibleKeys(missing_keys=[], unexpected_keys=['layers.0.feed_forward.global_gate_stats_3E', 'layers.0.feed_forward.running_gate_stats_3E', 'layers.1.feed_forward.running_gate_stats_3E', 'layers.1.feed_forward.global_gate_stats_3E', 'layers.2.feed_forward.running_gate_stats_3E', 'layers.2.feed_forward.global_gate_stats_3E', 'layers.3.feed_forward.running_gate_stats_3E', 'layers.3.feed_forward.global_gate_stats_3E', 'layers.4.feed_forward.global_gate_stats_3E', 'layers.4.feed_forward.running_gate_stats_3E', 'layers.5.feed_forward.running_gate_stats_3E', 'layers.5.feed_forward.global_gate_stats_3E', 'layers.6.feed_forward.global_gate_stats_3E', 'layers.6.feed_forward.running_gate_stats_3E', 'layers.7.feed_forward.running_gate_stats_3E', 'layers.7.feed_forward.global_gate_stats_3E', 'layers.8.feed_forward.global_gate_stats_3E', 'layers.8.feed_forward.running_gate_stats_3E', 'layers.9.feed_forward.running_gate_stats_3E', 'layers.9.feed_forward.global_gate_stats_3E', 'layers.10.feed_fo

In [20]:
from models.datatypes import GenerationResult, QuantizationMode

In [21]:
# Quantization scheme; passed to convert_to_quantized_model and read by
# should_quantize_block in later cells.
quantization_mode = QuantizationMode.int4_mixed

In [22]:
from models.llama4.quantization.loader import convert_to_quantized_model

In [23]:
from models.quantize_impls import (
    # Fp8ScaledWeights,
    Int4ScaledWeights,
    # load_fp8,
    load_int4,
    # quantize_fp8,
    quantize_int4,
)

In [24]:
# Rank index used in the int4 scales filename and in status messages below.
rank = 0

In [25]:
# Per-rank int4 scales file inside the checkpoint directory.
# NOTE(review): not referenced by any cell visible in this part of the notebook.
int4_scales_path = os.path.join(ckpt_dir, f"int4_scales_{rank}.pt")


def apply_quantization(_, weight):
    """Quantize ``weight`` with ``quantize_int4``, placing the result on CUDA.

    The first argument (the parameter's key at the call sites) is ignored.
    """
    target_device = torch.device("cuda")
    return quantize_int4(weight, output_device=target_device)

In [26]:
    def should_quantize_block(block: nn.Module) -> bool:
        """Return True when ``block`` is a TransformerBlock whose feed-forward is a MoE.

        In ``fp8_mixed`` mode the first and last layers are additionally skipped.
        Reads the notebook globals ``quantization_mode`` and ``model``.
        """
        if not isinstance(block, TransformerBlock):
            return False

        is_moe = isinstance(block.feed_forward, MoE)
        if quantization_mode == QuantizationMode.fp8_mixed:
            # skip quantization on first and last layers
            return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))

        # Previously the result was computed but never returned, so every
        # TransformerBlock yielded None and nothing was selected.
        return is_moe

In [28]:
# for _, block in model.named_modules():
#     if not should_quantize_block(block):
#         continue



In [30]:
        # Per-block int4 quantization of the MoE feed-forward weights.
        # NOTE(review): update_status, experts_batched_swiglu_wrapper,
        # swiglu_wrapper_no_reduce and processed_blocks are not defined in the
        # cells visible here — confirm they are in scope before running this cell.
        for _, block in model.named_modules():
            # Only blocks accepted by should_quantize_block are touched.
            if not should_quantize_block(block):
                continue

            update_status(f"Rank {rank} - Layer {block.layer_id}")

            # Quantize only routed experts, not shared
            prefix = f"layers.{block.layer_id}.feed_forward"
            moe = block.feed_forward
            # Bind the wrapper to this experts instance as its batched_swiglu method.
            moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)

            for key in ("w1", "w3", "w2"):
                param = getattr(moe.experts, key)
                update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
                # The expert weight is transposed on dims 1 and 2 before
                # quantization and the attribute is replaced by the result.
                setattr(
                    moe.experts,
                    key,
                    apply_quantization(
                        f"{prefix}.experts.{key}",
                        param.transpose(1, 2).contiguous(),
                    ),
                )

            # The mode check is commented out, so this branch always runs.
            if True: #quantization_mode == QuantizationMode.int4_mixed:
                # Quantize shared experts
                moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
                for key in ("w1", "w3", "w2"):
                    param = getattr(moe.shared_expert, key)
                    update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
                    # Here `param` is a sub-module; its .weight is replaced by the quantized result.
                    param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)

            processed_blocks += 1


In [29]:
free -h

Exception ignored in: <function _releaseLock at 0x7939eac40ea0>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.12/logging/__init__.py", line 243, in _releaseLock
    def _releaseLock():
    
KeyboardInterrupt: 


               total        used        free      shared  buff/cache   available
Mem:           2.0Ti       656Gi       585Gi       7.1Gi       787Gi       1.3Ti
Swap:             0B          0B          0B


In [None]:
# Move the model's parameters and buffers to the first GPU.
device = 'cuda:0'
model = model.to(device)

In [None]:
def should_quantize_block(block: nn.Module) -> bool:
    """Return True when ``block`` is a TransformerBlock whose feed-forward is a MoE.

    In ``fp8_mixed`` mode the first and last layers are additionally skipped.
    Reads the notebook globals ``quantization_mode`` and ``model``.
    """
    if not isinstance(block, TransformerBlock):
        return False

    is_moe = isinstance(block.feed_forward, MoE)
    if quantization_mode == QuantizationMode.fp8_mixed:
        # skip quantization on first and last layers
        return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))

    # Every other mode (including int4_mixed) quantizes all MoE blocks.
    # Without this return the function yielded None for them, so no block
    # was ever selected.
    return is_moe

In [None]:
# Replace `model` with the result of the library's quantization pass for the selected mode.
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)

In [None]:
# !pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu126/
# !pip install --pre fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu126/
# !pip install fbgemm-gpu==1.2.0 --index-url https://download.pytorch.org/whl/cu126
!pip install fbgemm-gpu-genai

In [None]:
!pip uninstall -y fbgemm-gpu

In [None]:
# torch.set_default_tensor_type is deprecated (PyTorch >= 2.1); set the default
# dtype and device separately. New tensors default to bfloat16 on CUDA.
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")

In [None]:
# First CUDA device, as a torch.device object.
device = torch.device("cuda", 0)

In [None]:
x = "How are you?"

# Token ids must be an integer tensor: the embedding lookup rejects float
# indices, and torch.Tensor(...) would build a float tensor in the current
# default dtype (bfloat16 cannot represent large vocabulary ids exactly).
token_ids = tokenizer.encode(x, bos=False, eos=False)
# Shape [1, num_tokens]: a batch holding the single prompt.
tokens = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(dim=0)

In [None]:
# Make sure the model lives on `device` (repeats the earlier move).
model = model.to(device)