In [1]:
import torch
import ttnn
from llama2.model import ModelArgs
from typing import Tuple

# just a super simple forward without host overhead
from ttcode.lightweightmodule import LightweightModule

2024-12-11 16:20:51.398 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/bach/.cache/ttnn,model_cache_path=/home/bach/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


In [2]:
device_id = 0
device = ttnn.open_device(device_id=device_id)

[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver

[32m2024-12-11 16:20:51.528[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 1 PCI device : [0]
[32m2024-12-11 16:20:51.529[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.30.0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0


In [3]:
# RMS Norm imported from a common
from ttcode import rmsnorm

In [4]:
# Load checkpoint dict
import os
checkpoint_dict = torch.load("llama2/configs/stories260K.pth")
model_args = checkpoint_dict['model_args']
print(model_args)
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

args = ModelArgs()

{'dim': 64, 'n_layers': 5, 'n_heads': 8, 'n_kv_heads': 4, 'vocab_size': 512, 'multiple_of': 4, 'max_seq_len': 512, 'dropout': 0.05}


In [5]:
batch_size = 1
tokens_num = args.max_seq_len

In [6]:
from llama2.model import precompute_freqs_cis
torch_freqs_cos, torch_freqs_sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)
# freqs_cos = ttnn.from_torch(torch_freqs_cos, device=device)
#freqs_sin = ttnn.from_torch(torch_freqs_sin, device=device)
freqs_cos = torch_freqs_cos
freqs_sin = torch_freqs_sin

tensor([0.0000, 0.2500, 0.5000, 0.7500])
10000.0


In [7]:
def reshape_for_broadcast(freqs_cis: ttnn.Tensor, x: ttnn.Tensor):
    ndim = len(x.shape)
    assert 0 <= 1 < ndim
    print(f"{freqs_cis.shape}, {x.shape[1]},{x.shape[-1]}")
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    print(freqs_cis.shape)
    print(shape)
    return ttnn.reshape(freqs_cis, shape)

In [8]:
def apply_rotary_emb(
    xq: ttnn.Tensor,
    xk: ttnn.Tensor,
    freqs_cos: ttnn.Tensor,
    freqs_sin: ttnn.Tensor
) -> Tuple[ttnn.Tensor, ttnn.Tensor]:
    print("xq:", xq.shape)
    assert tuple(xq.shape)[0] == 1, "Only works with batch 1 :-C"
    xq = ttnn.reshape(xq, (tuple(xq.shape)[:-1] + (-1,2)))
    # Cannot unbind, cannot slice with [:..], must use ttnn.slice
    # xq_r, xq_i = xq.unbind(-1)
    # Squeeze because to_layour only supports 4D max tensor
    xq = ttnn.squeeze(xq, 0)
    xq = ttnn.to_layout(xq, layout = ttnn.ROW_MAJOR_LAYOUT)
    xq = ttnn.unsqueeze(xq, 0)
    xq_r = ttnn.slice(xq, [0,0,0,0,0], list(tuple(xq.shape)[:-1] + (1,)))
    xq_r = ttnn.squeeze(xq_r, -1)
    print("xq_r:", xq_r.shape)
    # ttnn.deallocate(xq_r)
    xq_i = ttnn.slice(xq, [0,0,0,0,1], list(tuple(xq.shape)[:-1] + (2,)))
    xq_i = ttnn.squeeze(xq_i, -1)    
    print("xq_i:", xq_i.shape)

    xk = ttnn.squeeze(xk, 0)
    xk = ttnn.to_layout(xk, layout = ttnn.ROW_MAJOR_LAYOUT)
    xk = ttnn.unsqueeze(xk, 0)
    xk = ttnn.unsqueeze(xk, 0)
    xk_r = ttnn.slice(xk, [0,0,0,0,0],tuple(xk.shape)[:-1] + (1,))
    xk_r = ttnn.squeeze(xk_r, -1)
    # ttnn.deallocate(xk_r)
    xk_i = ttnn.slice(xk, [0,0,0,0,1], tuple(xk.shape)[:-1] + (2,))
    xk_i = ttnn.squeeze(xk_i, -1)  

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    freqs_sin = ttnn.to_layout(freqs_sin, layout = ttnn.TILE_LAYOUT)
    freqs_cos = ttnn.to_layout(freqs_cos, layout = ttnn.TILE_LAYOUT)

    # apply rotation using real numbers
    xq_r = ttnn.to_layout(xq_r, layout = ttnn.TILE_LAYOUT)
    xq_i = ttnn.to_layout(xq_i, layout = ttnn.TILE_LAYOUT)
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos

    xk_r = ttnn.to_layout(xk_r, layout = ttnn.TILE_LAYOUT)
    xk_i = ttnn.to_layout(xk_i, layout = ttnn.TILE_LAYOUT)
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
    
    # there's no ttnn.stack nor ttnn.flatten :-)
    xq_out_r = ttnn.to_layout(xq_out_r, layout = ttnn.ROW_MAJOR_LAYOUT)
    xq_out_i = ttnn.to_layout(xq_out_i, layout = ttnn.ROW_MAJOR_LAYOUT)
    print(xq_out_r.shape, xq_out_i.shape)
    # Create new dimension
    xq_out_r = ttnn.unsqueeze(xq_out_r, -1)
    # Concatenate along the new dimension
    xq_out = torch.concatenate([xq_out_r, xq_out_i], dim=-1)
    # todo: implement flatten
    print(xq_out.shape)

    # xq_out = ttnn.flatten(xq_out, 3)
    return 

In [9]:
from llama2.model import apply_rotary_emb as apply_rotary_emb_torch

def apply_rotary_emb_host(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq = ttnn.to_torch(xq)
    xk = ttnn.to_torch(xk) 
    xq_out, xk_out = apply_rotary_emb_torch(xq, xk, freqs_cos, freqs_sin)

    return ttnn.from_torch(xq_out, device=device), ttnn.from_torch(xk_out, device=device)

In [10]:
def repeat_kv(x: ttnn.Tensor, n_rep: int) -> ttnn.Tensor:
    return ttnn.repeat_interleave(x, dim=2, repeats=n_rep)

In [None]:
from ttnn.model_preprocessing import (
    preprocess_linear_bias,
    preprocess_linear_weight,
)

class Attention(LightweightModule):
    def __init__(self, args: ModelArgs, state_dict: dict, layer_num, device):
        super().__init__()
        self.state_dict = state_dict
        self.max_batch_size = 1
        self.device = device
        # Indicates the number of heads for the Keys and Values
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Indicates the number of heads for the Queries
        self.n_q_heads = args.n_heads
        assert args.n_heads % self.n_kv_heads == 0
        # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for
        self.head_dim = args.dim // args.n_heads        
        # Indicates how many times the Keys and Values should be repeated        
        self.n_rep = args.n_heads // self.n_kv_heads
        
        prefix = f"layers.{layer_num}.attention."

        wq = torch.transpose(self.state_dict[f"{prefix}wq.weight"], -2, -1,)
        wk = torch.transpose(self.state_dict[f"{prefix}wk.weight"], -2, -1,)
        wv = torch.transpose(self.state_dict[f"{prefix}wv.weight"], -2, -1,)

        torch_wqkv = torch.cat([wq, wk, wv], dim=-1)
        self.wqkv = preprocess_linear_weight(torch_wqkv.T, dtype=ttnn.bfloat16)
        self.wqkv = ttnn.to_device(self.wqkv, device=self.device)

        print(self.wqkv.shape)

        self.wo = ttnn.as_tensor(
            torch.transpose(self.state_dict[f"{prefix}wo.weight"], -2, -1,),
            layout=ttnn.TILE_LAYOUT,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            device=self.device,
            dtype=ttnn.bfloat8_b
        )

    def forward(self, x: ttnn.Tensor, freqs_cos:torch.Tensor, freqs_sin: torch.Tensor):
        bsz, seqlen, _ = x.shape
        fused_qkv_output = ttnn.linear(
            x,
            self.wqkv,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=ttnn.bfloat8_b,
            # core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
        )
        print(fused_qkv_output.shape)
        
        assert self.n_kv_heads == self.n_q_heads

        # Always | FATAL    | 
        # Invalid head size: 16. The head size must be a multiple of the tile width (32). Please adjust the dimensions accordingly.
        # 32??
        xq, xk, xv = ttnn.transformer.split_query_key_value_and_split_heads(
            fused_qkv_output,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            num_heads=self.n_kv_heads,
        )

        # Apply RoPE
        xq, xk = apply_rotary_emb_host(xq, xk, freqs_cos, freqs_sin)

        print(xk.shape)
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)

        print(xk.shape)

        # premute instead of transpose 
        # (B, 1, H_Q, Head_Dim) -> (B, H_Q, 1, Head_Dim)
        xq = ttnn.permute(xq, (0, 2, 1, 3))
        # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
        xk = ttnn.permute(xk, (0, 2, 1, 3))
        xv = ttnn.permute(xv, (0, 2, 1, 3))

        xq = ttnn.to_layout(xq, layout=ttnn.TILE_LAYOUT)
        xk = ttnn.to_layout(xk, layout=ttnn.TILE_LAYOUT)
        xv = ttnn.to_layout(xv, layout=ttnn.TILE_LAYOUT)

        print(xq.get_dtype(), xq.shape)
        print(xk.get_dtype(), xk.shape)
        print(xv.get_dtype(), xv.shape)

        """
        output = ttnn.transformer.scaled_dot_product_attention(
            xq, 
            xk, 
            xv, 
            attn_mask=None, 
            is_causal=True
        )
        """
        attention_scores = ttnn.matmul(
            xq,
            xk,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=ttnn.bfloat16,
            # core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
        )
        ttnn.deallocate(xq)
        ttnn.deallocate(xk)

        return xq

In [12]:
layer_num = 0
attention = Attention(args, state_dict, layer_num, device)
# attention(ttnn.from_torch(torch.rand((1,64)), layout=ttnn.TILE_LAYOUT), 1)

x_torch = torch.rand((batch_size,tokens_num,args.dim)) 

x = ttnn.from_torch(
    x_torch,
    layout=ttnn.TILE_LAYOUT, 
    device=device
)

xq = attention.forward(x, freqs_cos, freqs_sin)

ttnn.Shape([64, 128])
ttnn.Shape([1, 512, 128])
[38;2;000;128;000m                 Always[0m | [1m[38;2;255;000;000mFATAL   [0m | Invalid head size: 16. The head size must be a multiple of the tile width (32). Please adjust the dimensions accordingly.


RuntimeError: TT_FATAL @ ../ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp:156: head_size % tt::constants::TILE_WIDTH == 0
info:
Invalid head size: 16. The head size must be a multiple of the tile width (32). Please adjust the dimensions accordingly.
backtrace:
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x691748) [0x7f3f999f4748]
 --- ttnn::operations::transformer::SplitQueryKeyValueAndSplitHeadsOperation::invoke(tt::tt_metal::Tensor const&, std::__1::optional<tt::tt_metal::Tensor> const&, unsigned int, std::__1::optional<unsigned int>, bool, std::__1::optional<tt::tt_metal::MemoryConfig> const&)
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x2057079) [0x7f3f9b3ba079]
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x20587d3) [0x7f3f9b3bb7d3]
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x2058e79) [0x7f3f9b3bbe79]
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x20567b5) [0x7f3f9b3b97b5]
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x2055277) [0x7f3f9b3b8277]
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x2054ae6) [0x7f3f9b3b7ae6]
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x20538b0) [0x7f3f9b3b68b0]
 --- /home/bach/tt-install/tt-metal/ttnn/ttnn/_ttnn.so(+0x166d1e6) [0x7f3f9a9d01e6]
 --- /home/bach/tt-menv/bin/python(PyCFunction_Call+0x59) [0x5f53d9]
 --- /home/bach/tt-menv/bin/python(_PyObject_MakeTpCall+0x29e) [0x5f5fae]
 --- /home/bach/tt-menv/bin/python() [0x50b0b8]
 --- /home/bach/tt-menv/bin/python(PyObject_Call+0x1f7) [0x5f4ce7]
 --- /home/bach/tt-menv/bin/python() [0x59c1cc]
 --- /home/bach/tt-menv/bin/python(PyObject_Call+0x27e) [0x5f4d6e]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x1f35) [0x56c645]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x393) [0x5f5963]
 --- /home/bach/tt-menv/bin/python() [0x59bebf]
 --- /home/bach/tt-menv/bin/python(_PyObject_MakeTpCall+0x29e) [0x5f5fae]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x630c) [0x570a1c]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python(PyEval_EvalCode+0x27) [0x68d2b7]
 --- /home/bach/tt-menv/bin/python() [0x6001d4]
 --- /home/bach/tt-menv/bin/python() [0x5c3a90]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x72d) [0x56ae3d]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x213a) [0x56c84a]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x213a) [0x56c84a]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python() [0x503f46]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x72d) [0x56ae3d]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x393) [0x5f5963]
 --- /home/bach/tt-menv/bin/python() [0x50a98c]
 --- /home/bach/tt-menv/bin/python(PyObject_Call+0x1f7) [0x5f4ce7]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x1f35) [0x56c645]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python() [0x50aa00]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x1882) [0x56bf92]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x213a) [0x56c84a]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x213a) [0x56c84a]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x213a) [0x56c84a]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x213a) [0x56c84a]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x213a) [0x56c84a]
 --- /home/bach/tt-menv/bin/python() [0x500328]
 --- /usr/lib/python3.8/lib-dynload/_asyncio.cpython-38-x86_64-linux-gnu.so(+0x7ef9) [0x7f3ff2ca3ef9]
 --- /usr/lib/python3.8/lib-dynload/_asyncio.cpython-38-x86_64-linux-gnu.so(+0x9083) [0x7f3ff2ca5083]
 --- /home/bach/tt-menv/bin/python(_PyObject_MakeTpCall+0x29e) [0x5f5fae]
 --- /home/bach/tt-menv/bin/python() [0x5fed13]
 --- /home/bach/tt-menv/bin/python() [0x5c3977]
 --- /home/bach/tt-menv/bin/python(PyVectorcall_Call+0x18d) [0x5f525d]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x6a6a) [0x57117a]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x1b6) [0x5f5786]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x858) [0x56af68]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python() [0x50aa00]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x5809) [0x56ff19]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python(PyEval_EvalCode+0x27) [0x68d2b7]
 --- /home/bach/tt-menv/bin/python() [0x6001d4]
 --- /home/bach/tt-menv/bin/python() [0x5c3a90]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x72d) [0x56ae3d]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x393) [0x5f5963]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalFrameDefault+0x72d) [0x56ae3d]
 --- /home/bach/tt-menv/bin/python(_PyEval_EvalCodeWithName+0x26a) [0x56910a]
 --- /home/bach/tt-menv/bin/python(_PyFunction_Vectorcall+0x393) [0x5f5963]
 --- /home/bach/tt-menv/bin/python(PyObject_Call+0x1f7) [0x5f4ce7]
 --- /home/bach/tt-menv/bin/python() [0x6b6a92]
 --- /home/bach/tt-menv/bin/python(Py_RunMain+0x379) [0x6b6e99]
 --- /home/bach/tt-menv/bin/python(Py_BytesMain+0x2d) [0x6b70bd]
 --- /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xf3) [0x7f3ff3554083]
 --- /home/bach/tt-menv/bin/python(_start+0x2e) [0x5fa3ae]
