In [1]:
from PIL import Image
import torch
from transformers import CLIPImageProcessor, AutoTokenizer
import torch
import json
from safetensors.torch import load_file 

# Every tensor in this notebook is placed on the Apple-silicon GPU (MPS backend).
device = torch.device('mps')

# `model` is a flat {checkpoint tensor name: tensor} dict, materialised directly on `device`.
model = load_file(filename="./model/model.safetensors", device=device.type)

  from .autonotebook import tqdm as notebook_tqdm


### Vision Tower

In [2]:
# PREPARE INPUT
# Resize the shortest edge to 1024 and centre-crop to 1024x1024. With mean 0 / std 1 the
# normalisation step is an identity, so only the processor's other steps affect pixel values.
image = Image.open('./lion.jpg').convert("RGB")

image_processor = CLIPImageProcessor(
    crop_size={"height": 1024, "width": 1024},
    image_mean=[0.0, 0.0, 0.0],
    image_std=[1.0, 1.0, 1.0],
    size={"shortest_edge": 1024},
)
# `return_tensors` is an argument of the processor *call*. When it is given to the constructor
# instead, the call returns a list of numpy arrays. Wrapping that list in torch.tensor() hits
# PyTorch's slow list-of-ndarrays path (the UserWarning this cell used to emit).
img_tensor = image_processor(image, return_tensors='pt')['pixel_values'].to(device=device, dtype=torch.bfloat16)

assert img_tensor.shape == (1, 3, 1024, 1024), img_tensor.shape
# [1, 3, 1024, 1024]

  img_tensor = torch.tensor(CLIPImageProcessor(


In [3]:
# STEM
# Three re-parameterised convs (patch_embed.0 .. patch_embed.2), each followed by GELU.
# conv2d requires its input and weights to share dtype and device, so every output is already
# bf16 on `device`. A `.to(device, dtype=torch.bfloat16)` after the first two convs was a no-op
# and has been dropped.
stem_prefix = 'model.vision_tower.vision_tower.model.patch_embed'
# (block index, stride, padding, groups); index 1 is depthwise (groups == channels == 96).
stem_specs = ((0, 2, 1, 1), (1, 2, 1, 96), (2, 1, 0, 1))

x = img_tensor
for idx, stride, padding, groups in stem_specs:
    x = torch.nn.functional.gelu(
        torch.nn.functional.conv2d(
            x,
            model[f'{stem_prefix}.{idx}.reparam_conv.weight'],
            bias=model[f'{stem_prefix}.{idx}.reparam_conv.bias'],
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=1,
        )
    )

assert x.shape == (1, 96, 256, 256), x.shape
# [1, 96, 256, 256]

In [4]:
# STAGE 1
# Two RepMixer blocks at 96 channels (network.0). Each block does two things:
#   token mixer: a single re-parameterised depthwise conv (groups == channels);
#   ConvFFN:     depthwise conv + BatchNorm -> fc1 -> GELU -> fc2, added back to the input
#                through a per-channel layer_scale.
for blk in range(2):
    prefix = f'model.vision_tower.vision_tower.model.network.0.{blk}'

    # TOKENMIXER - RepMixer
    x = torch.nn.functional.conv2d(
        x,
        weight=model[f'{prefix}.token_mixer.reparam_conv.weight'],
        bias=model[f'{prefix}.token_mixer.reparam_conv.bias'],
        padding=1,
        groups=96,
    )

    # CONVFFN: the depthwise conv has no bias of its own; the BatchNorm supplies the shift.
    y = torch.nn.functional.conv2d(
        x,
        weight=model[f'{prefix}.convffn.conv.conv.weight'],
        padding=3,
        groups=96,
    )
    y = torch.nn.functional.batch_norm(
        y,
        model[f'{prefix}.convffn.conv.bn.running_mean'],
        model[f'{prefix}.convffn.conv.bn.running_var'],
        weight=model[f'{prefix}.convffn.conv.bn.weight'],
        bias=model[f'{prefix}.convffn.conv.bn.bias'],
        training=False,
    )
    y = torch.nn.functional.conv2d(y, model[f'{prefix}.convffn.fc1.weight'], model[f'{prefix}.convffn.fc1.bias'])
    y = torch.nn.functional.gelu(y)
    y = torch.nn.functional.conv2d(y, model[f'{prefix}.convffn.fc2.weight'], model[f'{prefix}.convffn.fc2.bias'])

    scale = model[f'{prefix}.layer_scale'].view(1, -1, 1, 1)  # broadcast over (B, C, H, W)
    x = x + scale * y

# [1, 96, 256, 256]

In [5]:
# PATCH EMBED (network.1): 96 -> 192 channels, 256x256 -> 128x128.
proj = 'model.vision_tower.vision_tower.model.network.1.proj'

# Re-parameterised grouped conv (groups = input channels), stride 2, then GELU.
x = torch.nn.functional.conv2d(
    x, model[f'{proj}.0.lkb_reparam.weight'], model[f'{proj}.0.lkb_reparam.bias'],
    stride=2, padding=3, groups=96,
)
x = torch.nn.functional.gelu(x)

# Re-parameterised dense conv (stride 1, no padding), then GELU.
x = torch.nn.functional.conv2d(x, model[f'{proj}.1.reparam_conv.weight'], model[f'{proj}.1.reparam_conv.bias'])
x = torch.nn.functional.gelu(x)

# [1, 192, 128, 128]

In [6]:
# STAGE 2
# Twelve RepMixer blocks at 192 channels. Each block does two things:
#   token mixer: one re-parameterised depthwise conv (groups == channels);
#   ConvFFN:     depthwise conv + BatchNorm -> fc1 -> GELU -> fc2, added back to the input
#                through a per-channel layer_scale.
# `layer_idx` builds every checkpoint key, so the stage index is stated once. The channel
# count is checked up front, so running cells out of order fails loudly here rather than
# deep inside conv2d.
layer_idx = 2
dim = 192
assert x.shape[1] == dim, f"expected {dim} input channels, got {x.shape[1]}"
for i in range(12):
    p = f'model.vision_tower.vision_tower.model.network.{layer_idx}.{i}'

    # TOKENMIXER - RepMixer
    x = torch.nn.functional.conv2d(
        x,
        weight=model[f'{p}.token_mixer.reparam_conv.weight'],
        bias=model[f'{p}.token_mixer.reparam_conv.bias'],
        stride=1,
        padding=1,
        groups=dim,
        dilation=1
    )

    # CONVFFN (the depthwise conv has no bias; BatchNorm supplies the shift)
    x_c = torch.nn.functional.conv2d(
        x,
        weight=model[f'{p}.convffn.conv.conv.weight'],
        stride=1,
        padding=3,
        groups=dim,
        dilation=1
    )
    x_c = torch.nn.functional.batch_norm(
        x_c,
        running_mean=model[f'{p}.convffn.conv.bn.running_mean'],
        running_var=model[f'{p}.convffn.conv.bn.running_var'],
        weight=model[f'{p}.convffn.conv.bn.weight'],
        bias=model[f'{p}.convffn.conv.bn.bias'],
        training=False,
    )
    x_c = torch.nn.functional.gelu(
        torch.nn.functional.conv2d(
            x_c,
            weight=model[f'{p}.convffn.fc1.weight'],
            bias=model[f'{p}.convffn.fc1.bias'],
            stride=1,
            padding=0,
            groups=1,
            dilation=1
        )
    )
    x_c = torch.nn.functional.conv2d(
        x_c,
        weight=model[f'{p}.convffn.fc2.weight'],
        bias=model[f'{p}.convffn.fc2.bias'],
        stride=1,
        padding=0,
        groups=1,
        dilation=1
    )
    x = x + model[f'{p}.layer_scale'].view(1, -1, 1, 1) * x_c

# [1, 192, 128, 128]

In [7]:
# PATCH EMBED (network.3): 192 -> 384 channels, 128x128 -> 64x64.
down3 = 'model.vision_tower.vision_tower.model.network.3.proj'

# Step 1: re-parameterised grouped conv, stride 2, groups = input channels (192).
grouped_w = model[f'{down3}.0.lkb_reparam.weight']
grouped_b = model[f'{down3}.0.lkb_reparam.bias']
x = torch.nn.functional.gelu(torch.nn.functional.conv2d(x, grouped_w, grouped_b, stride=2, padding=3, groups=192))

# Step 2: re-parameterised dense conv, stride 1, no padding.
dense_w = model[f'{down3}.1.reparam_conv.weight']
dense_b = model[f'{down3}.1.reparam_conv.bias']
x = torch.nn.functional.gelu(torch.nn.functional.conv2d(x, dense_w, dense_b))

# [1, 384, 64, 64]

In [8]:
# STAGE 3
# Twenty-four RepMixer blocks at 384 channels. Each block does two things:
#   token mixer: one re-parameterised depthwise conv (groups == channels);
#   ConvFFN:     depthwise conv + BatchNorm -> fc1 -> GELU -> fc2, added back to the input
#                through a per-channel layer_scale.
# `layer_idx` builds every checkpoint key, so the stage index is stated once. The channel
# count is checked up front, so running cells out of order fails loudly here.
layer_idx = 4
dim = 384
assert x.shape[1] == dim, f"expected {dim} input channels, got {x.shape[1]}"
for i in range(24):
    p = f'model.vision_tower.vision_tower.model.network.{layer_idx}.{i}'

    # TOKENMIXER - RepMixer
    x = torch.nn.functional.conv2d(
        x,
        weight=model[f'{p}.token_mixer.reparam_conv.weight'],
        bias=model[f'{p}.token_mixer.reparam_conv.bias'],
        stride=1,
        padding=1,
        groups=dim,
        dilation=1
    )

    # CONVFFN (the depthwise conv has no bias; BatchNorm supplies the shift)
    x_c = torch.nn.functional.conv2d(
        x,
        weight=model[f'{p}.convffn.conv.conv.weight'],
        stride=1,
        padding=3,
        groups=dim,
        dilation=1
    )
    x_c = torch.nn.functional.batch_norm(
        x_c,
        running_mean=model[f'{p}.convffn.conv.bn.running_mean'],
        running_var=model[f'{p}.convffn.conv.bn.running_var'],
        weight=model[f'{p}.convffn.conv.bn.weight'],
        bias=model[f'{p}.convffn.conv.bn.bias'],
        training=False,
    )
    x_c = torch.nn.functional.gelu(
        torch.nn.functional.conv2d(
            x_c,
            weight=model[f'{p}.convffn.fc1.weight'],
            bias=model[f'{p}.convffn.fc1.bias'],
            stride=1,
            padding=0,
            groups=1,
            dilation=1
        )
    )
    x_c = torch.nn.functional.conv2d(
        x_c,
        weight=model[f'{p}.convffn.fc2.weight'],
        bias=model[f'{p}.convffn.fc2.bias'],
        stride=1,
        padding=0,
        groups=1,
        dilation=1
    )
    x = x + model[f'{p}.layer_scale'].view(1, -1, 1, 1) * x_c

# [1, 384, 64, 64]

In [9]:
# PATCH EMBED (network.5): 384 -> 768 channels, 64x64 -> 32x32.
# Two re-parameterised convs, each followed by GELU:
#   proj.0: grouped (groups = 384 input channels), stride 2, padding 3
#   proj.1: dense, stride 1, no padding
pe5 = 'model.vision_tower.vision_tower.model.network.5.proj'
for weight_key, bias_key, conv_kwargs in (
    (f'{pe5}.0.lkb_reparam.weight', f'{pe5}.0.lkb_reparam.bias', dict(stride=2, padding=3, groups=384)),
    (f'{pe5}.1.reparam_conv.weight', f'{pe5}.1.reparam_conv.bias', dict(stride=1, padding=0, groups=1)),
):
    x = torch.nn.functional.gelu(
        torch.nn.functional.conv2d(x, weight=model[weight_key], bias=model[bias_key], dilation=1, **conv_kwargs)
    )

# [1, 768, 32, 32]

In [10]:
# Conditional Positional Embeddings (CPE), network.6.
# One re-parameterised depthwise conv (groups == channels == 768); the spatial shape is preserved.
cpe_weight = model['model.vision_tower.vision_tower.model.network.6.reparam_conv.weight']
cpe_bias = model['model.vision_tower.vision_tower.model.network.6.reparam_conv.bias']
x = torch.nn.functional.conv2d(x, cpe_weight, cpe_bias, stride=1, padding=3, dilation=1, groups=768)

# [1, 768, 32, 32]

In [11]:
# STAGE 4
# Four attention blocks at 768 channels. Each block does two residual updates:
#   x = x + layer_scale_1 * Attention(channel-LayerNorm(x))
#   x = x + layer_scale_2 * ConvFFN(x)
# `layer_idx` builds every checkpoint key, so the stage index is stated once.
layer_idx = 7
dim = 768
head_dim = 32
assert x.shape[1] == dim, f"expected {dim} input channels, got {x.shape[1]}"
for i in range(4):
    p = f'model.vision_tower.vision_tower.model.network.{layer_idx}.{i}'

    # TOKENMIXER - Attention

    # x is (B, C, H, W). Rather than permuting C to the last axis, normalise over dim=1
    # by hand (a channel-wise LayerNorm with eps=1e-5).
    mean = x.mean(dim=1, keepdim=True)
    var = x.var(dim=1, unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + 1e-5)
    # [:, None, None] turns the (C,) affine params into (C, 1, 1) so they broadcast over (B, C, H, W).
    x_norm = model[f'{p}.norm.weight'][:, None, None] * x_hat + model[f'{p}.norm.bias'][:, None, None]

    B, C, H, W = x_norm.shape
    n_heads = C // head_dim
    N = H * W  # one token per spatial position
    x_norm = x_norm.flatten(2).transpose(-2, -1)  # (B, N, C)
    qkv: torch.Tensor = (
        (x_norm @ model[f'{p}.token_mixer.qkv.weight'].T)  # (B, N, 3*C); no qkv bias is applied
        .reshape(B, N, 3, n_heads, head_dim)
        .permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N, head_dim)
    )
    q, k, v = qkv.unbind(0)  # each (B, n_heads, N, head_dim)
    score = (q @ k.transpose(-1, -2)) / (head_dim ** 0.5)  # (B, n_heads, N, N)
    score = score.softmax(dim=-1)
    out = score @ v  # (B, n_heads, N, head_dim)
    out = out.transpose(1, 2).reshape(B, N, C)
    x_attn = torch.nn.functional.linear(out, model[f'{p}.token_mixer.proj.weight'], model[f'{p}.token_mixer.proj.bias'])
    x_attn = x_attn.transpose(1, 2).reshape(B, C, H, W)  # back to channels-first
    x = x + model[f'{p}.layer_scale_1'].view(1, -1, 1, 1) * x_attn

    # CONVFFN (the depthwise conv has no bias; BatchNorm supplies the shift)
    x_c = torch.nn.functional.conv2d(
        x,
        weight=model[f'{p}.convffn.conv.conv.weight'],
        stride=1,
        padding=3,
        groups=dim,
        dilation=1
    )
    x_c = torch.nn.functional.batch_norm(
        x_c,
        running_mean=model[f'{p}.convffn.conv.bn.running_mean'],
        running_var=model[f'{p}.convffn.conv.bn.running_var'],
        weight=model[f'{p}.convffn.conv.bn.weight'],
        bias=model[f'{p}.convffn.conv.bn.bias'],
        training=False,
    )
    x_c = torch.nn.functional.gelu(
        torch.nn.functional.conv2d(
            x_c,
            weight=model[f'{p}.convffn.fc1.weight'],
            bias=model[f'{p}.convffn.fc1.bias'],
            stride=1,
            padding=0,
            groups=1,
            dilation=1
        )
    )
    x_c = torch.nn.functional.conv2d(
        x_c,
        weight=model[f'{p}.convffn.fc2.weight'],
        bias=model[f'{p}.convffn.fc2.bias'],
        stride=1,
        padding=0,
        groups=1,
        dilation=1
    )
    x = x + model[f'{p}.layer_scale_2'].view(1, -1, 1, 1) * x_c

# [1, 768, 32, 32]

In [12]:
# PATCH EMBED (network.8): 768 -> 1536 channels, 32x32 -> 16x16.
pe8 = 'model.vision_tower.vision_tower.model.network.8.proj'

# Re-parameterised grouped conv: groups = 768 input channels, stride 2, padding 3.
x = torch.nn.functional.conv2d(
    input=x,  # (B, 768, 32, 32)
    weight=model[f'{pe8}.0.lkb_reparam.weight'],
    bias=model[f'{pe8}.0.lkb_reparam.bias'],
    stride=2,
    padding=3,
    dilation=1,
    groups=768,
)
x = torch.nn.functional.gelu(x)  # (B, 768, 16, 16)

# Re-parameterised dense conv: stride 1, no padding.
x = torch.nn.functional.conv2d(
    input=x,
    weight=model[f'{pe8}.1.reparam_conv.weight'],
    bias=model[f'{pe8}.1.reparam_conv.bias'],
)
x = torch.nn.functional.gelu(x)

# [1, 1536, 16, 16]

In [13]:
# REPCPE (network.9): one re-parameterised depthwise conv (groups == channels == 1536);
# the spatial shape is preserved.
repcpe = 'model.vision_tower.vision_tower.model.network.9.reparam_conv'
x = torch.nn.functional.conv2d(
    x,
    model[f'{repcpe}.weight'],
    model[f'{repcpe}.bias'],
    padding=3,
    groups=1536,
)

# [1, 1536, 16, 16]

In [14]:
# STAGE 5
# Two attention blocks at 1536 channels (network.10). Each block does two residual updates:
#   x += layer_scale_1 * Attention(channel-LayerNorm(x))
#   x += layer_scale_2 * ConvFFN(x)
for blk in range(2):
    pfx = f'model.vision_tower.vision_tower.model.network.10.{blk}'

    # TOKENMIXER - Attention
    # Channel-wise LayerNorm over dim=1 of the (B, C, H, W) map, eps = 1e-5.
    mu = x.mean(dim=1, keepdim=True)
    sigma2 = x.var(dim=1, unbiased=False, keepdim=True)
    normed = (x - mu) / torch.sqrt(sigma2 + 1e-5)
    gamma = model[f'{pfx}.norm.weight'][:, None, None]
    beta = model[f'{pfx}.norm.bias'][:, None, None]
    x_norm = gamma * normed + beta

    head_dim = 32
    B, C, H, W = x_norm.shape
    n_heads = C // head_dim
    N = H * W

    tokens = x_norm.flatten(2).transpose(-2, -1)  # [B, N, C]
    qkv_w = model[f'{pfx}.token_mixer.qkv.weight']
    qkv = (tokens @ qkv_w.T).reshape(B, N, 3, n_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # each [B, n_heads, N, head_dim]

    probs = ((q @ k.transpose(-1, -2)) / (head_dim ** 0.5)).softmax(dim=-1)  # [B, n_heads, N, N]
    ctx = (probs @ v).transpose(1, 2).reshape(B, N, C)  # [B, N, C]
    x_attn = torch.nn.functional.linear(
        ctx,
        model[f'{pfx}.token_mixer.proj.weight'],
        model[f'{pfx}.token_mixer.proj.bias'],
    ).transpose(1, 2).reshape(B, C, H, W)
    x = x + model[f'{pfx}.layer_scale_1'].view(1, -1, 1, 1) * x_attn

    # CONVFFN
    ffn = torch.nn.functional.conv2d(x, model[f'{pfx}.convffn.conv.conv.weight'], padding=3, groups=1536)
    ffn = torch.nn.functional.batch_norm(
        ffn,
        model[f'{pfx}.convffn.conv.bn.running_mean'],
        model[f'{pfx}.convffn.conv.bn.running_var'],
        weight=model[f'{pfx}.convffn.conv.bn.weight'],
        bias=model[f'{pfx}.convffn.conv.bn.bias'],
        training=False,
    )
    ffn = torch.nn.functional.gelu(
        torch.nn.functional.conv2d(ffn, model[f'{pfx}.convffn.fc1.weight'], model[f'{pfx}.convffn.fc1.bias'])
    )
    ffn = torch.nn.functional.conv2d(ffn, model[f'{pfx}.convffn.fc2.weight'], model[f'{pfx}.convffn.fc2.bias'])
    x = x + model[f'{pfx}.layer_scale_2'].view(1, -1, 1, 1) * ffn

# [1, 1536, 16, 16]

In [15]:
# MOBILE BLOCK
# conv_exp: re-parameterised grouped conv (groups=1536) taking the 1536-channel map to
# 3072 channels, then squeeze-and-excite, then GELU.
x = torch.nn.functional.conv2d(
    x,
    weight=model['model.vision_tower.vision_tower.model.conv_exp.reparam_conv.weight'],
    bias=model['model.vision_tower.vision_tower.model.conv_exp.reparam_conv.bias'],
    stride=1,
    padding=1,
    dilation=1,
    groups=1536,
)

# SQUEEZE-AND-EXCITE MODULE
# Squeeze: global average pool over the whole spatial extent. The window used to be
# hard-coded to [16, 16], which only produces (B, C, 1, 1) for a 16x16 map (a 1024x1024
# input). Pooling over (H, W) keeps the squeeze global for any input resolution and is
# identical for the 16x16 case.
B, C, H, W = x.shape
x_se = torch.nn.functional.avg_pool2d(x, kernel_size=(H, W))  # (B, C, 1, 1)
x_se = torch.nn.functional.conv2d(
    x_se,
    weight=model['model.vision_tower.vision_tower.model.conv_exp.se.reduce.weight'],
    bias=model['model.vision_tower.vision_tower.model.conv_exp.se.reduce.bias'],
    stride=1
)
x_se = torch.nn.functional.relu(x_se)
x_se = torch.nn.functional.conv2d(
    x_se,
    weight=model['model.vision_tower.vision_tower.model.conv_exp.se.expand.weight'],
    bias=model['model.vision_tower.vision_tower.model.conv_exp.se.expand.bias'],
    stride=1
)
x_se = torch.sigmoid(x_se)  # per-channel gates in (0, 1), broadcast over H and W below
x = torch.nn.functional.gelu(x * x_se)

# [1, 3072, 16, 16]

In [16]:
# PROJECTION
# Flatten the vision feature map into one token per spatial position, then map each token
# into the LLM embedding space with mm_projector.0 -> GELU -> mm_projector.2.
img_tokens = x.flatten(start_dim=2).transpose(1, 2)  # (B, H*W, C)
proj_hidden = torch.nn.functional.linear(img_tokens, model['model.mm_projector.0.weight'], model['model.mm_projector.0.bias'])
proj_hidden = torch.nn.functional.gelu(proj_hidden)
img_tokens = torch.nn.functional.linear(proj_hidden, model['model.mm_projector.2.weight'], model['model.mm_projector.2.bias'])

# [1, 256, 1536] = 256 tokens of the llm's embedding size

## Language Model

In [17]:
# Tokenize and embed
# Build the LLM input sequence: [chat text before <image>] + [image tokens] + [chat text after].

tokenizer = AutoTokenizer.from_pretrained('./model')
with open("./model/config.json", "r") as f:
    config = json.load(f)

# Wrap the checkpoint's embedding matrix directly. Building a fresh nn.Embedding and calling
# load_state_dict would allocate a second vocab_size x hidden_size parameter and copy into it.
# freeze=True also keeps the lookups out of autograd, since nothing here is trained.
embed_weight = model["model.embed_tokens.weight"]
assert tuple(embed_weight.shape) == (config["vocab_size"], config["hidden_size"]), embed_weight.shape
embd = torch.nn.Embedding.from_pretrained(embed_weight, freeze=True)

messages = [{"role": "user", "content": "<image>\nDescribe this image in detail."}]

# Render the chat template as text, then split at the image placeholder so the projected
# image tokens can be spliced between the two halves.
rendered = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
assert "<image>" in rendered, "chat template dropped the <image> placeholder"
pre, post = rendered.split("<image>", 1)
pre_ids = tokenizer(pre, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
post_ids = tokenizer(post, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
emb_pre = embd(pre_ids)
emb_post = embd(post_ids)

x = torch.cat([emb_pre, img_tokens, emb_post], dim=1) # [1, S, H]

In [18]:
# RoPE

def rotate_half(t: torch.Tensor) -> torch.Tensor:
    """Map the last dim (a, b) -> (-b, a), where a and b are its first and second halves.

    This is the split-halves RoPE layout: dimension i is paired with i + D/2 (not interleaved).
    """
    d2 = t.shape[-1] // 2
    return torch.cat((-t[..., d2:], t[..., :d2]), dim=-1)

def apply_rope(qk: torch.Tensor, base: float, offset: int = 0) -> torch.Tensor:
    """Apply rotary position embeddings to a 4-D query/key tensor.

    Args:
        qk: tensor of shape (B, n_heads, S, D); D must be even.
        base: RoPE theta. Frequency j is ``base ** (-2j / D)``.
        offset: absolute position of the first element along the S axis. The default 0
            rotates positions 0..S-1 (the full-sequence case). Pass the number of tokens
            already processed when rotating only newly appended tokens, e.g. with a KV cache.

    Returns:
        Tensor with the same shape and dtype as ``qk``.
    """
    _, _, S, D = qk.shape
    assert D % 2 == 0, f"head dim must be even, got {D}"
    # Angles are computed in float32; only the final cos/sin tables are cast to qk's dtype.
    inv = torch.arange(0, D, 2, device=qk.device, dtype=torch.float32) / D
    inv_freq = base ** (-inv) # (D/2,)
    t = torch.arange(offset, offset + S, device=qk.device, dtype=torch.float32) # (S,) positions
    freqs = torch.einsum("s,d->sd", t, inv_freq) # (S, D/2)
    emb = torch.cat([freqs, freqs], dim=-1) # (S, D): same angle for dims i and i + D/2
    cos = emb.cos().to(qk.dtype)[None, None, :, :] # (1,1,S,D)
    sin = emb.sin().to(qk.dtype)[None, None, :, :]
    return (qk * cos) + (rotate_half(qk) * sin)

In [None]:
# Greedy decoding with the Qwen2-style decoder built directly from the checkpoint tensors.
# NOTE: there is no KV cache. Every step re-runs all layers over the whole sequence
# (prompt + image tokens + generated text), so each new token costs a full forward pass.
MAX_NEW_TOKENS = 1024  # safety cap: stop even if EOS is never produced

hidden = config["hidden_size"]
n_heads = config["num_attention_heads"]
n_kv = config["num_key_value_heads"]
head_dim = hidden // n_heads
n_layers = config["num_hidden_layers"]
eps = config["rms_norm_eps"]
theta = config["rope_theta"]


def _rms(t: torch.Tensor, weight_key: str) -> torch.Tensor:
    """RMSNorm over the hidden dim using checkpoint weight `weight_key`, returned as bf16."""
    return torch.nn.functional.rms_norm(
        t, normalized_shape=(hidden,), weight=model[weight_key], eps=eps,
    ).to(torch.bfloat16)


def _decoder_layer(h: torch.Tensor, layer: int, causal: torch.Tensor) -> torch.Tensor:
    """Run decoder layer `layer` on h ([1, S, hidden]) and return the updated hidden state.

    Pre-norm residual layout:
      h = h + o_proj(attention(RMSNorm(h)))
      h = h + down_proj(silu(gate_proj(RMSNorm(h))) * up_proj(RMSNorm(h)))
    `causal` is an additive [S, S] mask (0 on/below the diagonal, -inf above).
    """
    S = h.shape[1]
    p = f"model.layers.{layer}"

    # --- self-attention ---
    h_rms = _rms(h, f"{p}.input_layernorm.weight")
    q = h_rms @ model[f"{p}.self_attn.q_proj.weight"].T + model[f"{p}.self_attn.q_proj.bias"]
    k = h_rms @ model[f"{p}.self_attn.k_proj.weight"].T + model[f"{p}.self_attn.k_proj.bias"]
    v = h_rms @ model[f"{p}.self_attn.v_proj.weight"].T + model[f"{p}.self_attn.v_proj.bias"]

    q = q.view(1, S, n_heads, head_dim).transpose(1, 2)  # [1, n_heads, S, D]
    k = k.view(1, S, n_kv, head_dim).transpose(1, 2)     # [1, n_kv, S, D]
    v = v.view(1, S, n_kv, head_dim).transpose(1, 2)

    # Grouped-query attention: each KV head serves n_heads // n_kv consecutive query heads.
    if n_heads != n_kv:
        reps = n_heads // n_kv
        k = k.repeat_interleave(reps, dim=1)
        v = v.repeat_interleave(reps, dim=1)

    q = apply_rope(q, theta)
    k = apply_rope(k, theta)

    scores = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)  # [1, n_heads, S, S]
    attn = torch.softmax(scores + causal, dim=-1)
    attn_o = (attn @ v).transpose(1, 2).reshape(1, S, hidden) @ model[f"{p}.self_attn.o_proj.weight"].T
    h = h + attn_o

    # --- SwiGLU feed-forward ---
    h_rms = _rms(h, f"{p}.post_attention_layernorm.weight")
    gate = h_rms @ model[f"{p}.mlp.gate_proj.weight"].T
    up = h_rms @ model[f"{p}.mlp.up_proj.weight"].T
    ffn = (torch.nn.functional.silu(gate) * up) @ model[f"{p}.mlp.down_proj.weight"].T
    return h + ffn


generated_id = None
generated_ids = []  # every non-EOS token produced so far
printed = ""        # decoded text already written to stdout
with torch.no_grad():
    for _ in range(MAX_NEW_TOKENS):
        S = x.shape[1]
        # One additive causal mask per step, shared by all layers.
        causal = torch.triu(torch.full((S, S), float("-inf"), device=device, dtype=torch.bfloat16), 1)

        h = x
        for layer in range(n_layers):
            h = _decoder_layer(h, layer, causal)

        # Greedy decoding only needs the last position. RMSNorm is per-position, so
        # normalising just that row avoids a [S, vocab] logits matrix.
        h_last = _rms(h[:, -1:, :], "model.norm.weight")
        logits = h_last @ model["lm_head.weight"].T  # [1, 1, vocab]
        generated_id = int(logits[0, -1].argmax(dim=-1))
        if generated_id == tokenizer.eos_token_id:
            break
        generated_ids.append(generated_id)

        next_embed = embd(torch.tensor([[generated_id]], device=device))
        x = torch.cat([x, next_embed], dim=1)

        # Stream output by decoding the whole generation and printing only the new suffix.
        # Decoding token-by-token garbles characters whose bytes span several tokens. A
        # trailing U+FFFD means such a character is still incomplete, so hold it back.
        text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        if not text.endswith("\ufffd"):
            print(text[len(printed):], end='', flush=True)
            printed = text

The image depicts a detailed bronze statue of a lion, positioned on a stone pedestal. The lion is lying down with its front paws extended forward and its head turned to the right, showcasing its open mouth and visible teeth. The statue is highly realistic, capturing the lion's muscular build and intricate details, including its mane and facial features. The lion is situated outdoors, with a backdrop of a clear blue sky adorned with fluffy white clouds. In the distance, there is a prominent building featuring a large dome and a clock tower, which appears to be a church or a significant historical structure. The building is surrounded by lush green trees, adding to the serene and majestic atmosphere of the scene. The overall composition of the image highlights the grandeur and artistic craftsmanship of the lion statue against the picturesque and tranquil setting.