In [1]:
import torch
import ttnn
from llama2.model import ModelArgs
from typing import Tuple
import json

# just a super simple forward without host overhead
from ttcode.lightweightmodule import LightweightModule
checkpoint = "llama2/configs/stories260K.pth"

2024-12-11 11:37:45.723 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/bach/.cache/ttnn,model_cache_path=/home/bach/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


In [2]:
# RMS Norm imported from a common
from ttcode import rmsnorm

In [3]:
def apply_rotatory_emb(self, xqkv_fused, rot_mat, current_pos, page_table=None):
    ##
    # Reshape and rotary embeddings
    ###
    (
        q_heads_pre_rot_1BQD,
        k_heads_pre_rot_1BKD,
        v_heads_1BKD,
    ) = ttnn.experimental.nlp_create_qkv_heads_decode(
        xqkv_fused,
        num_heads=self.n_local_heads,
        num_kv_heads=self.n_local_kv_heads,
        memory_config=ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG,
    )

    ttnn.deallocate(xqkv_fused)

    q_heads_1BQD = ttnn.linear(
        q_heads_pre_rot_1BQD,
        rot_mat,
        program_config=self.model_config["ROT_MAT_BMM_PROGCFG"](
            q_heads_pre_rot_1BQD.shape[-2], q_heads_pre_rot_1BQD.shape[-1], rot_mat.shape[-1]
        ),
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        compute_kernel_config=self.compute_kernel_config_hifi2,
        dtype=ttnn.bfloat16,
    )

    k_heads_1BKD = ttnn.linear(
        k_heads_pre_rot_1BKD,
        rot_mat,
        program_config=self.model_config["ROT_MAT_BMM_PROGCFG"](
            k_heads_pre_rot_1BKD.shape[-2], k_heads_pre_rot_1BKD.shape[-1], rot_mat.shape[-1]
        ),
        memory_config=k_heads_pre_rot_1BKD.memory_config(),
        compute_kernel_config=self.compute_kernel_config_hifi2,
        dtype=ttnn.bfloat16,
    )

    ttnn.deallocate(q_heads_pre_rot_1BQD)
    ttnn.deallocate(k_heads_pre_rot_1BKD)

    return q_heads_1BQD, k_heads_1BKD

In [7]:
class Attention(LightweightModule):

    def kv_cache(self, q_heads_1BQD, k_heads_1BKD, v_heads_1BKD, current_pos):
        ###
        # KV update
        ###
        keys = self.layer_past[0]
        values = self.layer_past[1]

        # k_heads, [seqlen, n_kv_heads, bsz, head_dim]
        # v_heads [seqlen, n_kv_heads, bsz, head_dim]
        # keys, [max_batch_size, n_kv_heads // configuration.num_devices, sliding_window, head_dim]
        ttnn.experimental.paged_update_cache(
            keys, k_heads_1BKD, update_idxs_tensor=current_pos
            )
        
        ttnn.experimental.paged_update_cache(
            values, v_heads_1BKD, update_idxs_tensor=current_pos
            )
        
        self.layer_past[0] = keys
        self.layer_past[1] = values

        ttnn.deallocate(k_heads_1BKD)
        ttnn.deallocate(v_heads_1BKD)

        attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode(
            q_heads_1BQD,
            keys,
            values,
            cur_pos_tensor=current_pos,
            scale=self.scale,
            # program_config=self.model_config["SDPA_DECODE_PROGCFG"],
            # compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"],
            memory_config=ttnn.DRAM_MEMORY_CONFIG,  # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG?
        )

        ttnn.deallocate(q_heads_1BQD)

        attn_output_11BH = ttnn.to_memory_config(
            attn_output_1G4D, 
            # memory_config=self.model_config["SCORES_BATCHED_MM_OUTPUT_MEMCFG"]
        )
        attn_output_cat = ttnn.experimental.nlp_concat_heads_decode(
            attn_output_11BH,
            num_heads=self.n_local_heads,
        )
        ttnn.deallocate(attn_output_11BH)
        ttnn.deallocate(attn_output_1G4D)

        return attn_output_cat

    def __init__(self, args: ModelArgs, state_dict: dict, layer_num, device):
        super().__init__()
        self.state_dict = state_dict
        self.max_batch_size = 1
        self.device = device
        # Indicates the number of heads for the Keys and Values
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Indicates the number of heads for the Queries
        self.n_q_heads = args.n_heads
        assert args.n_heads % self.n_kv_heads == 0
        # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for
        self.head_dim = args.dim // args.n_heads        
        # Indicates how many times the Keys and Values should be repeated        
        self.n_rep = args.n_heads // self.n_kv_heads
        
        prefix = f"layers.{layer_num}.attention."

        ## Concatenated weights
        self.wqkv = ttnn.as_tensor(
                torch.concat([
                        torch.transpose(self.state_dict[f"{prefix}wq.weight"], -2, -1,),
                        torch.transpose(self.state_dict[f"{prefix}wk.weight"], -2, -1,),
                        torch.transpose(self.state_dict[f"{prefix}wv.weight"], -2, -1,),
                    ],
                    dim=-1
                ),
                # device=self.device,
                # mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1),
                # dtype=self.dtype,
                # memory_config=wqkv_mem_config,
                # layout=self.model_config["ATTN_W_LAYOUT_TILE"],
                # cache_file_name=cache_name("wqkv_sharded"),
        )
        
        ## Output weights
        self.wo = ttnn.as_tensor(
            torch.transpose(self.state_dict[f"{prefix}wo.weight"], -2, -1,),
            # device=self.device,
            # mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2),
            # dtype=self.dtype,
            # memory_config=wo_mem_config,
            # layout=self.model_config["ATTN_W_LAYOUT_TILE"],
            # cache_file_name=cache_name("wo_height_sharded"),
            )

        ## Cache KV 
        
        cache_k = torch.zeros((
            self.paged_attention_config.max_num_blocks,
            self.n_kv_heads,
            self.paged_attention_config.block_size,
            self.head_dim,)
        )
        cache_v = torch.zeros((
            self.paged_attention_config.max_num_blocks,
            self.n_kv_heads,
            self.paged_attention_config.block_size,
            self.head_dim,)
        )

        self.layer_past = [
            ttnn.as_tensor(
                k_or_v,
                device=self.mesh_device,
                # mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1),
                # layout=self.model_config["ATTN_W_LAYOUT_TILE"],
                dtype=self.dtype,
                # cache_file_name=f"{weight_cache_path}/kvcache_{k_or_v.shape}"
                # if weight_cache_path and not configuration.dummy_weights
                # else None,
                memory_config=ttnn.DRAM_MEMORY_CONFIG,
            )
            for k_or_v in [cache_k, cache_v]
        ]

        self.scale = self.head_dim**-0.5
        
    
    def forward(self, x, current_pos, rot_mats=None, transformation_mats=None, user_id=0, mode="decode", page_table=None
    ):
        # Two inference phase, prefil and decode
        # Prefil: generate KV cache for subsequent tokens generations
        # Decode: generate tokens
        if mode == "prefill":
            return self.forward_prefill(x, rot_mats, transformation_mats, user_id, page_table)
        else:
            return self.forward_decode(x, current_pos, rot_mats, page_table)

    def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = 0, page_table=None):
        # todo
        ...

    def forward_decode(self, x: ttnn.Tensor, current_pos, rot_mat=None, page_table=None,) -> ttnn.Tensor:
        """
        x: (seq_len, 1, batch, dim)
        current_pos: (batch_size), current token position in the sequence for each user
        """
        assert self.max_batch_size * self.n_kv_heads < 64
        ###
        # QKV matmuls
        # Use HiFi2 for DRAM-sharded matmuls as they are otherwise flop-bound. Loses 1 bit of activation precision.
        ###
        
        self.compute_kernel_config_hifi2 = ttnn.GrayskullComputeKernelConfig(
                math_fidelity=ttnn.MathFidelity.HiFi2,
            )
        
        
        ## Linear transformation with qkv concatened
        xqkv_fused_sharded = ttnn.linear(
            x,
            self.wqkv,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            # memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
            # program_config=self.model_config["XQKV_DECODE_PROGCFG"],
            compute_kernel_config=self.compute_kernel_config_hifi2,
            device=self.device,
            dtype=ttnn.bfloat16,
        )
        ttnn.deallocate(x)

        ## ??? I suppose to join sharded qkv but sharded where? devices or cores?
        xqkv_fused = ttnn.sharded_to_interleaved(xqkv_fused_sharded, ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(xqkv_fused_sharded)

        # Reshape such that true unpadded batch is tracked in shape
        fqkv_shape = xqkv_fused.shape
        xqkv_fused = ttnn.reshape(
            xqkv_fused, ttnn.Shape((1, 1, self.max_batch_size, fqkv_shape[3]), (1, 1, 32, fqkv_shape[3]))
        )

        # RoPE relative positional embeddings
        attn_output_cat = self.apply_rotatory_emb(xqkv_fused)        

        # program config matched to output of nlp_concat_heads_decode
        dense_out_sharded = ttnn.linear(
            attn_output_cat,
            self.wo,
            # program_config=self.model_config["ATTN_OUTPUT_PROGCFG"],
            compute_kernel_config=self.compute_kernel_config_hifi2,
            memory_config=attn_output_cat.memory_config(),
        )  # seqlen, 1, batch, hidden_size

        ttnn.deallocate(attn_output_cat)

        dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"])
        return dense_out_sharded

In [8]:
# Load checkpoint dict
import os
checkpoint_dict = torch.load(checkpoint)
model_args = checkpoint_dict['model_args']
print(model_args)
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

{'dim': 64, 'n_layers': 5, 'n_heads': 8, 'n_kv_heads': 4, 'vocab_size': 512, 'multiple_of': 4, 'max_seq_len': 512, 'dropout': 0.05}


In [9]:
device = ttnn.open_device(device_id=0)
args = ModelArgs()
layer_num = 0
attention = Attention(args, state_dict, layer_num, device)
attention(ttnn.from_torch(torch.rand((1,64)), layout=ttnn.TILE_LAYOUT), 1)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0


AttributeError: 'Attention' object has no attribute 'paged_attention_config'