In [1]:
import math
import time
import torch
import ttnn
from llama2.model import ModelArgs
from typing import Tuple

# just a super simple forward without host overhead
from ttcode.lightweightmodule import LightweightModule
import random
random.seed(42)
torch.manual_seed(42)

2024-12-13 17:52:08.811 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/bach/.cache/ttnn,model_cache_path=/home/bach/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


<torch._C.Generator at 0x7f92db17fbb0>

In [2]:
device_id = 0
device = ttnn.open_device(device_id=device_id)

[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver

[32m2024-12-13 17:52:08.923[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 1 PCI device : [0]
[32m2024-12-13 17:52:08.923[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.30.0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0


In [3]:
# Load checkpoint dict
import os
checkpoint_dict = torch.load("llama2/configs/stories260K.pth")
model_args = checkpoint_dict['model_args']
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

args = ModelArgs()
# args.dim = 4
args.max_seq_len = 512
print(args)

ModelArgs(dim=64, n_layers=5, n_heads=8, n_kv_heads=4, vocab_size=512, hidden_dim=None, multiple_of=4, norm_eps=1e-05, max_seq_len=512, dropout=0.5)


In [4]:
batch_size = 1
tokens_num = args.max_seq_len

In [5]:
from llama2.model import precompute_freqs_cis
torch_freqs_cos, torch_freqs_sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)
# freqs_cos = ttnn.from_torch(torch_freqs_cos, device=device)
#freqs_sin = ttnn.from_torch(torch_freqs_sin, device=device)
freqs_cos = torch_freqs_cos.bfloat16()
freqs_sin = torch_freqs_sin.bfloat16()

In [6]:
def reshape_for_broadcast(freqs_cis: ttnn.Tensor, x: ttnn.Tensor):
    ndim = len(x.shape)
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return ttnn.reshape(freqs_cis, shape)

In [7]:
def apply_rotary_emb(
    xq: ttnn.Tensor,
    xk: ttnn.Tensor,
    freqs_cos: ttnn.Tensor,
    freqs_sin: ttnn.Tensor
) -> Tuple[ttnn.Tensor, ttnn.Tensor]:
    print("xq:", xq.shape)
    assert tuple(xq.shape)[0] == 1, "Only works with batch 1 :-C"
    xq = ttnn.reshape(xq, (tuple(xq.shape)[:-1] + (-1,2)))
    # Cannot unbind, cannot slice with [:..], must use ttnn.slice
    # xq_r, xq_i = xq.unbind(-1)
    # Squeeze because to_layour only supports 4D max tensor
    xq = ttnn.squeeze(xq, 0)
    xq = ttnn.to_layout(xq, layout = ttnn.ROW_MAJOR_LAYOUT)
    xq = ttnn.unsqueeze(xq, 0)
    xq_r = ttnn.slice(xq, [0,0,0,0,0], list(tuple(xq.shape)[:-1] + (1,)))
    xq_r = ttnn.squeeze(xq_r, -1)
    print("xq_r:", xq_r.shape)
    # ttnn.deallocate(xq_r)
    xq_i = ttnn.slice(xq, [0,0,0,0,1], list(tuple(xq.shape)[:-1] + (2,)))
    xq_i = ttnn.squeeze(xq_i, -1)    
    print("xq_i:", xq_i.shape)

    xk = ttnn.squeeze(xk, 0)
    xk = ttnn.to_layout(xk, layout = ttnn.ROW_MAJOR_LAYOUT)
    xk = ttnn.unsqueeze(xk, 0)
    xk = ttnn.unsqueeze(xk, 0)
    xk_r = ttnn.slice(xk, [0,0,0,0,0],tuple(xk.shape)[:-1] + (1,))
    xk_r = ttnn.squeeze(xk_r, -1)
    # ttnn.deallocate(xk_r)
    xk_i = ttnn.slice(xk, [0,0,0,0,1], tuple(xk.shape)[:-1] + (2,))
    xk_i = ttnn.squeeze(xk_i, -1)  

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    freqs_sin = ttnn.to_layout(freqs_sin, layout = ttnn.TILE_LAYOUT)
    freqs_cos = ttnn.to_layout(freqs_cos, layout = ttnn.TILE_LAYOUT)

    # apply rotation using real numbers
    xq_r = ttnn.to_layout(xq_r, layout = ttnn.TILE_LAYOUT)
    xq_i = ttnn.to_layout(xq_i, layout = ttnn.TILE_LAYOUT)
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos

    xk_r = ttnn.to_layout(xk_r, layout = ttnn.TILE_LAYOUT)
    xk_i = ttnn.to_layout(xk_i, layout = ttnn.TILE_LAYOUT)
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
    
    # there's no ttnn.stack nor ttnn.flatten :-)
    xq_out_r = ttnn.to_layout(xq_out_r, layout = ttnn.ROW_MAJOR_LAYOUT)
    xq_out_i = ttnn.to_layout(xq_out_i, layout = ttnn.ROW_MAJOR_LAYOUT)
    print(xq_out_r.shape, xq_out_i.shape)
    # Create new dimension
    xq_out_r = ttnn.unsqueeze(xq_out_r, -1)
    # Concatenate along the new dimension
    xq_out = torch.concatenate([xq_out_r, xq_out_i], dim=-1)
    # todo: implement flatten
    print(xq_out.shape)

    # xq_out = ttnn.flatten(xq_out, 3)
    return 

In [8]:
from llama2.model import apply_rotary_emb as apply_rotary_emb_torch

def apply_rotary_emb_host(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
    dtype=ttnn.bfloat16
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq = ttnn.to_torch(xq)
    xk = ttnn.to_torch(xk) 
    xq_out, xk_out = apply_rotary_emb_torch(xq, xk, freqs_cos, freqs_sin)

    return (
        ttnn.from_torch(xq_out, device=device, dtype=ttnn.bfloat16),#layout=ttnn.TILE_LAYOUT),
        ttnn.from_torch(xk_out, device=device, dtype=ttnn.bfloat16)#layout=ttnn.TILE_LAYOUT)
    )


In [9]:
def repeat_kv(x: ttnn.Tensor, n_rep: int) -> ttnn.Tensor:
    return ttnn.repeat_interleave(x, dim=2, repeats=n_rep)

In [10]:
class Attention(LightweightModule):
    def __init__(self, args: ModelArgs, state_dict: dict, layer_num, device, dtype=ttnn.bfloat16):
        super().__init__()
        self.dtype = dtype
        self.max_batch_size = 1
        self.device = device
        # Indicates the number of heads for the Keys and Values
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Indicates the number of heads for the Queries
        self.n_q_heads = args.n_heads
        assert args.n_heads % self.n_kv_heads == 0
        # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for
        self.head_dim = args.dim // args.n_heads        
        # Indicates how many times the Keys and Values should be repeated        
        self.n_rep = args.n_heads // self.n_kv_heads
        
        prefix = f"layers.{layer_num}.attention."

        self.wq = ttnn.as_tensor(
            torch.transpose(state_dict[f"{prefix}wq.weight"], -2, -1,),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )
        self.wk = ttnn.as_tensor(
            torch.transpose(state_dict[f"{prefix}wk.weight"], -2, -1,),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )
        self.wv = ttnn.as_tensor(
            torch.transpose(state_dict[f"{prefix}wv.weight"], -2, -1,),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )
        self.wo = ttnn.as_tensor(
            torch.transpose(state_dict[f"{prefix}wo.weight"], -2, -1,),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )

        self.wq = ttnn.to_layout(self.wq, ttnn.TILE_LAYOUT)
        self.wk = ttnn.to_layout(self.wk, ttnn.TILE_LAYOUT)
        self.wv = ttnn.to_layout(self.wv, ttnn.TILE_LAYOUT)
        self.wo = ttnn.to_layout(self.wo, ttnn.TILE_LAYOUT)

        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        self.mask = torch.triu(mask, diagonal=1).bfloat16()
        self.mask = ttnn.from_torch(self.mask, device=device)
        self.mask = ttnn.to_layout(self.mask, ttnn.TILE_LAYOUT)


    def forward_opt(self, x: ttnn.Tensor, freqs_cos:torch.Tensor, freqs_sin: torch.Tensor):
        ... 

    def forward(self, x: ttnn.Tensor, freqs_cos:torch.Tensor, freqs_sin: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq = ttnn.linear(
            x,
            self.wq,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        )

        xq = ttnn.to_layout(xq, layout=ttnn.ROW_MAJOR_LAYOUT)
        xq = ttnn.reshape(xq, (bsz, seqlen, self.n_q_heads, self.head_dim))
        xq = ttnn.to_layout(xq, layout=ttnn.TILE_LAYOUT)

        xk = ttnn.linear(
            x,
            self.wk,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        )
        xk = ttnn.to_layout(xk, layout=ttnn.ROW_MAJOR_LAYOUT)
        xk = ttnn.reshape(xk, (bsz, seqlen, self.n_kv_heads, self.head_dim))
        xk = ttnn.to_layout(xk, layout=ttnn.TILE_LAYOUT)


        xv = ttnn.linear(
            x,
            self.wv,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        )
        xv = ttnn.to_layout(xv, layout=ttnn.ROW_MAJOR_LAYOUT)
        xv = ttnn.reshape(xv, (bsz, seqlen, self.n_kv_heads, self.head_dim))
        xv = ttnn.to_layout(xv, layout=ttnn.TILE_LAYOUT)
        
        # Apply RoPE
        xq, xk = apply_rotary_emb_host(xq, xk, freqs_cos, freqs_sin)
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)

        # premute instead of transpose 
        # (B, 1, H_Q, Head_Dim) -> (B, H_Q, 1, Head_Dim)
        xq = ttnn.permute(xq, (0, 2, 1, 3))
        # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
        xk = ttnn.permute(xk, (0, 2, 1, 3))
        xv = ttnn.permute(xv, (0, 2, 1, 3))

        xq = ttnn.to_layout(
            xq, 
            layout=ttnn.TILE_LAYOUT, 
            memory_config=ttnn.DRAM_MEMORY_CONFIG)
        xk = ttnn.to_layout(
            xk, 
            layout=ttnn.TILE_LAYOUT, 
            memory_config=ttnn.DRAM_MEMORY_CONFIG)
        xv = ttnn.to_layout(
            xv, 
            layout=ttnn.TILE_LAYOUT, 
            memory_config=ttnn.DRAM_MEMORY_CONFIG)
    
        # use flash attention, shape problem
        # xk = ttnn.permute(xk, (0, 1, 3, 2))
        if False:
            output = ttnn.transformer.scaled_dot_product_attention(
                xq, 
                xk, 
                xv, 
                attn_mask=None, 
                is_causal=True
            )
        attention_scores = ttnn.matmul(
            xq,
            ttnn.permute(xk, (0, 1, 3, 2)),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            # core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
        )
        # attention_scores = xq @ ttnn.permute(xk, (0, 2, 1, 3))
        attention_scores = ttnn.div(attention_scores, math.sqrt(self.head_dim))
        attention_scores = attention_scores + self.mask[:, :, :seqlen, :seqlen]
        attention_scores = ttnn.softmax(attention_scores, dim=-1)

        output = ttnn.matmul(
            attention_scores,
            xv,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype
        )
        output = ttnn.permute(output, (0, 2, 1, 3))
        output = ttnn.to_layout(output, layout=ttnn.ROW_MAJOR_LAYOUT)
        output = ttnn.reshape(output, (bsz, seqlen, -1))
        output = ttnn.to_layout(output, layout=ttnn.TILE_LAYOUT)

        output = ttnn.linear(
            output,
            self.wo,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        )
        return output
        

In [11]:
# torch.set_printoptions(linewidth=100, precision=3,profile='full')
torch.set_printoptions(profile='short')
def check_close(a, b, atol=0.001):
    b = torch.Tensor(ttnn.to_torch(b))
    # print(a, b)
    equals = torch.sum(torch.isclose(a, b, atol=atol))
    print(f"{equals}/{a.numel()}")


In [12]:
from llama2.model import Attention as llama2Attention
if False:
    # Compare Attention layer
    layer_num = 0
    attention = Attention(args, state_dict, layer_num, device)
    # attention(ttnn.from_torch(torch.rand((1,64)), layout=ttnn.TILE_LAYOUT), 1)

    x_torch = torch.rand((batch_size, tokens_num, args.dim), dtype=torch.bfloat16) 

    x = ttnn.from_torch(
        x_torch,
        layout=ttnn.TILE_LAYOUT, 
        device=device,
        dtype=ttnn.bfloat16
    )

    # output = attention.forward(x, freqs_cos, freqs_sin)

    torch_attention = llama2Attention(args, state_dict, layer_num)
    # torch_attention.load_state_dict(state_dict, strict=False)
    # torch_output = torch_attention.forward(x_torch, torch_freqs_cos, torch_freqs_sin)

In [13]:
class FeedForward(LightweightModule):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, layer_num, device, dtype=ttnn.bfloat16):
        super().__init__()
        self.dtype = dtype
        self.device = device
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            # Round the hidden_dim to the nearest multiple of the multiple_of parameter
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        prefix = f"layers.{layer_num}.feed_forward."

        self.w1 = ttnn.as_tensor(
            torch.transpose(state_dict[f"{prefix}w1.weight"], -2, -1,),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )
        self.w2 = ttnn.as_tensor(
            torch.transpose(state_dict[f"{prefix}w2.weight"], -2, -1,),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )
        self.w3 = ttnn.as_tensor(
            torch.transpose(state_dict[f"{prefix}w3.weight"], -2, -1,),
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )

        self.w1 = ttnn.to_layout(self.w1, ttnn.TILE_LAYOUT)
        self.w2 = ttnn.to_layout(self.w2, ttnn.TILE_LAYOUT)
        self.w3 = ttnn.to_layout(self.w3, ttnn.TILE_LAYOUT)


    def forward(self, x):
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        x1 = ttnn.linear(
            x,
            self.w1,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        ) 
        swish = ttnn.silu(x1)
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        x_V = ttnn.linear(
            x,
            self.w3,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        ) 
        # (B, Seq_Len, Hidden_Dim) * (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Hidden_Dim)
        x = ttnn.mul(swish, x_V)
        # (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Dim)
        x = ttnn.linear(
            x,
            self.w2,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        ) 
        return x

In [14]:
from llama2.model import FeedForward as llamaFF
# Compare FeedForward layer
layer_num = 0

x_torch = torch.ones((batch_size, tokens_num, args.dim), dtype=torch.bfloat16) 

x = ttnn.from_torch(
    x_torch,
    layout=ttnn.TILE_LAYOUT, 
    device=device,
    dtype=ttnn.bfloat16
)

tt_ff = FeedForward(args.dim, args.hidden_dim, args.multiple_of, layer_num, device)
torch_ff = llamaFF(args.dim, args.hidden_dim, args.multiple_of, 0, state_dict, layer_num)

# ttnn_output = tt_ff.forward(x)
# torch_output = torch_ff.forward(x_torch)

In [15]:
# RMS Norm imported from a common
class RMSNorm(LightweightModule):
    def __init__(self, device, dim, eps: float, state_dict, prefix, dtype=ttnn.bfloat16, add_prefix=True):
        super().__init__()
        self.eps = eps
        self.device = device
        self.dtype=dtype
        # The gamma parameter
        if add_prefix:
            prefix = f"layers.{prefix}weight"
        torch_weight = state_dict[prefix].unsqueeze(0).view(1, 1, dim)
        self.weight = ttnn.as_tensor(
            # torch.transpose(state_dict[prefix], -2, -1,),
            torch_weight,
            device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=dtype,
        )
        self.weight = ttnn.to_layout(self.weight, ttnn.TILE_LAYOUT)

    def forward(self, x):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        norm_x = ttnn.mul(x,ttnn.rsqrt(ttnn.mean(ttnn.pow(x,2), -1)) + self.eps)
        return ttnn.mul(norm_x, self.weight)

In [16]:
class TransformerBlock(LightweightModule):
    def __init__(self, layer_id: int, args: ModelArgs, state_dict, device=device):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args, state_dict, layer_id, device)
        self.feed_forward = FeedForward(
            args.dim,
            args.hidden_dim,
            args.multiple_of,
            layer_id,
            device
        )
        self.layer_id = layer_id
        # Normalization BEFORE the attention block
        self.attention_norm = RMSNorm(device, args.dim, args.norm_eps, state_dict, f"{layer_id}.attention_norm.")
        # Normalization BEFORE the feed forward block
        self.ffn_norm = RMSNorm(device, args.dim, args.norm_eps, state_dict, f"{layer_id}.ffn_norm.")

    def forward(self, x, freqs_cos, freqs_sin):
        # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

In [17]:
print(freqs_cos.shape)

torch.Size([512, 4])


In [None]:
from llama2.model import TransformerBlock as llamaTransformerBlock
# Compare Transformer Block
if False:
    layer_num = 0

    x_torch = torch.rand((batch_size, tokens_num, args.dim), dtype=torch.bfloat16)
    # x_torch = torch_trans_output
    print(x_torch.shape)
    _, seq_len, _ = x_torch.shape
        
    x = ttnn.from_torch(
        x_torch,
        device=device,
        dtype=ttnn.bfloat16
    )

    x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)

    tt_tb = TransformerBlock(layer_num,args, state_dict, device)
    torch_tb = llamaTransformerBlock(layer_num, args, state_dict)
    start = time.time()
    ttnn_tb_output = tt_tb.forward(x, freqs_cos[:seq_len], freqs_sin[:seq_len])
    print(time.time() - start)
    start = time.time()
    torch_tb_output = torch_tb.forward(x_torch, freqs_cos[:seq_len], freqs_sin[:seq_len])
    print(time.time() - start)
    check_close(torch_tb_output.bfloat16(), ttnn_tb_output, atol=0.05) 

torch.Size([1, 512, 64])
10.076623439788818
0.04893326759338379
32742/32768


In [19]:
class Transformer(LightweightModule):
    # last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs, state_dict, device, dtype=ttnn.bfloat16):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.dtype = dtype
        self.device = device
        # self.dropout = nn.Dropout(params.dropout)
        self.layers = [
            TransformerBlock(i, params, state_dict, device)
            for i in range(self.n_layers)
        ]
        self.norm = RMSNorm(device, params.dim, params.norm_eps, state_dict, prefix="norm.weight", add_prefix=False)
        self.output = ttnn.as_tensor(
            torch.transpose(state_dict[f"output.weight"], -2, -1,),
            layout=ttnn.TILE_LAYOUT,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )
        
        self.tok_embeddings = ttnn.as_tensor(
            #state_dict[f"output.weight"],
            state_dict[f"tok_embeddings.weight"], # .unsqueeze(0).unsqueeze(0),
            # torch.transpose(state_dict[f"tok_embeddings.weight"], -2, -1,),
            layout=ttnn.ROW_MAJOR_LAYOUT,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
            device=self.device
        )

        # some useful precompute for the RoPE relative positional embeddings
        self.freqs_cos, self.freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)

    from typing import Optional

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
        # (B, Seq_Len)
        _bsz, seqlen = tokens.shape
        # (B, Seq_Len) -> (B, Seq_Len, Dim)
        # h = self.tok_embeddings(tokens)
        h = ttnn.embedding(
            tokens, 
            self.tok_embeddings,
            layout=ttnn.TILE_LAYOUT)
        # h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]
        # Consecutively apply all the encoder layers
        # return self.layers[0](h, freqs_cos, freqs_sin)
        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)

        h = self.norm(h)
        print(h.shape)
        # inference-time mini-optimization: only forward the output on the very last position
        h = ttnn.to_layout(h, ttnn.ROW_MAJOR_LAYOUT)
        h = ttnn.slice(h, [0, seqlen-1, 0], [_bsz, seqlen, args.dim])
        h = ttnn.to_layout(h, ttnn.TILE_LAYOUT, device=self.device)

        logits = ttnn.linear(
            h,
            self.output,
            bias=None,
            memory_config=ttnn.L1_MEMORY_CONFIG,
            dtype=self.dtype,
        )

        return logits

In [21]:
from llama2.model import Transformer as llamaTransformer
# Compare Transformer
if False:
    x_torch = torch.randint(1,args.vocab_size-1, (1, 152))
    seq_len = 1

    x = ttnn.from_torch(
        x_torch,
        device=device,
        dtype=ttnn.bfloat16
    )
    x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
    tt_trans = Transformer(args, state_dict, device)
    torch_trans = llamaTransformer(args, state_dict)
    start = time.time()
    ttnn_trans_output = tt_trans.forward(x)
    print(time.time() - start)
    start = time.time()
    torch_trans_output = torch_trans.forward(x_torch)
    print(time.time() - start)

    check_close(torch_trans_output.bfloat16(), ttnn_trans_output, atol=0.1)

## WIP: Speaking really bad, need to adjust something somewhere

In [28]:
from llama2.tokenizer import Tokenizer

tt_trans = Transformer(args, state_dict, device)
enc = Tokenizer(tokenizel_path="./llama2/tokenizer.model")
temperature = 1.0
top_k = 300

start = "Dream comes true!"
start_ids = enc.encode(start, bos=True, eos=False)
x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])
tt_x = ttnn.from_torch(
        x_torch,
        device=device,
        dtype=ttnn.bfloat16
    )
tt_x = ttnn.to_layout(tt_x, layout=ttnn.ROW_MAJOR_LAYOUT)

for _ in range(100):
    # if the sequence context is growing too long we must crop it at block_size
    # x = x if idx.size(1) <= args.max_seq_len else idx[:, -args.max_seq_len:]
    # forward the model to get the logits for the index in the sequence
    logits = tt_trans.forward(tt_x)
    logits = ttnn.to_torch(logits)
    logits = logits[:, -1, :] # crop to just the final time step
    if temperature == 0.0:
        # "sample" the single most likely index
        _, idx_next = torch.topk(logits, k=1, dim=-1)
    else:
        # pluck the logits at the final step and scale by desired temperature
        logits = logits / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = torch.functional.F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
    # append sampled index to the running sequence and continue
    x = torch.cat((x, idx_next), dim=1)

print(enc.decode(x[0].tolist()))

ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape([1, 152[160], 64])
ttnn.Shape