In [1]:
import time
import torch
import ttnn
from llama2.model import ModelArgs
import random
from models.utility_functions import comp_pcc

random.seed(42)
torch.manual_seed(42)

2025-02-04 18:01:21.680 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/bach/.cache/ttnn,model_cache_path=/home/bach/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


<torch._C.Generator at 0x7f7a5ffb42f0>

In [2]:
TEST = {
    "Attention": True,
    "RMSNorm": True,
    "FeedForward": True,
    "TransformerBlock": True,
    "Transformer": True,
    "Generation": True
}
TEST_LAYER_NUM = 0

In [3]:
device_id = 0
device = ttnn.open_device(device_id=device_id)
# device.enable_program_cache()

[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver

[32m2025-02-04 18:01:21.894[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.30.0, IOMMU: disabled
[32m2025-02-04 18:01:21.895[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected PCI devices: [0]
[32m2025-02-04 18:01:21.895[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using local chip ids: {0} and remote chip ids {}
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0


In [4]:
# Load checkpoint dict
checkpoint_dict = torch.load("llama2/configs/stories260K.pth")
model_args = checkpoint_dict['model_args']
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

args = ModelArgs()
print(args)

ModelArgs(dim=64, n_layers=5, n_heads=8, n_kv_heads=4, vocab_size=512, hidden_dim=None, multiple_of=4, norm_eps=1e-05, max_seq_len=512, dropout=0.5)


In [5]:
batch_size = 1
tokens_num = args.max_seq_len

In [6]:
from llama2.model import precompute_freqs_cis
torch_freqs_cos, torch_freqs_sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)
freqs_cos = torch_freqs_cos.bfloat16()
freqs_sin = torch_freqs_sin.bfloat16()

In [7]:
# torch.set_printoptions(linewidth=100, precision=3,profile='full')
torch.set_printoptions(profile='short')
def check_close(torch_tensor, ttnn_tensor, atol=0.01):
    ttnn_tensor = torch.Tensor(ttnn.to_torch(ttnn_tensor))
    equals = torch.sum(torch.isclose(torch_tensor, ttnn_tensor, atol=atol))
    perc = equals/torch_tensor.numel() * 100
    print(f"Close values: {perc:.3f}% ({torch_tensor.numel()})")
    print(f"PCC: {comp_pcc(torch_tensor, ttnn_tensor)[1]}")


### Compare Attention layer


In [13]:
from llama2.model import Attention as torchAttention
from ttllama2 import Attention

if TEST['Attention']:
    # ttnn.close_device(device)
    # device = ttnn.open_device(device_id=device_id)
    # ttnn.enable_cache(device)
    x_torch = torch.rand((batch_size, tokens_num, args.dim), dtype=torch.bfloat16) 

    x = ttnn.from_torch(
        x_torch,
        layout=ttnn.TILE_LAYOUT, 
        device=device,
        dtype=ttnn.bfloat16
    )

    print("Attention init:")
    start = time.time()
    attention = Attention(args, state_dict, TEST_LAYER_NUM, device)
    print(f"ttnn: {time.time() - start:.3f}")

    print("Attention forward:")
    
    start = time.time()
    o_attention = attention.forward(x, freqs_cos, freqs_sin)
    print(f"ttnn: {time.time() - start:.3f}")
    
    start = time.time()
    torch_attention = torchAttention(args, state_dict, TEST_LAYER_NUM)
    print(f"torch: {time.time() - start:.3f}")

    to_attention = torch_attention.forward(x_torch, torch_freqs_cos, torch_freqs_sin)
    
    check_close(to_attention, o_attention)

Attention init:
ttnn: 0.016
Attention forward:
1°: 0.020
2°: 0.045
3°: 0.014
4°: 0.047
ttnn: 0.126
torch: 0.039
Close values: 89.059% (32768)
PCC: 0.9998018312999007


### Compare FeedForward layer

In [41]:
from llama2.model import FeedForward as torchFeedForward
from ttllama2 import FeedForward

if TEST['FeedForward']:

    ttnn.close_device(device)
    device = ttnn.open_device(device_id=device_id)
    x_torch = torch.rand((batch_size, tokens_num, args.dim), dtype=torch.bfloat16) 
    x = ttnn.from_torch(
        x_torch,
        layout=ttnn.TILE_LAYOUT, 
        device=device,
        dtype=ttnn.bfloat16
    )
    tt_ff = FeedForward(args.dim, args.hidden_dim, args.multiple_of, TEST_LAYER_NUM, state_dict, device)
    torch_ff = torchFeedForward(args.dim, args.hidden_dim, args.multiple_of, 0, state_dict, TEST_LAYER_NUM)
    
    start = time.time()
    o_ff = tt_ff.forward(x)
    print("Transformer:")
    print(f"ttnn: {time.time() - start:.3f}")
    start = time.time()
    to_ff = torch_ff.forward(x_torch)
    print(f"torch: {time.time() - start:.3f}")

    check_close(to_ff, o_ff)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Disabling and clearing program cache on device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0
Transformer:
ttnn: 0.014
torch: 0.013
Close values: 37.268% (32768)
PCC: 0.9997631032750476


### Compare Transformer Block


In [42]:
from llama2.model import TransformerBlock as torchTransformerBlock
from ttllama2 import TransformerBlock 


if TEST['TransformerBlock']:
    ttnn.close_device(device)
    device = ttnn.open_device(device_id=device_id)
    x_torch = torch.rand((batch_size, tokens_num, args.dim), dtype=torch.bfloat16)
        
    x = ttnn.from_torch(
        x_torch,
        device=device,
        dtype=ttnn.bfloat16
    )
    x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)

    tt_tb = TransformerBlock(TEST_LAYER_NUM, args, state_dict, device)
    torch_tb = torchTransformerBlock(TEST_LAYER_NUM, args, state_dict)
    _, seq_len, _ = x_torch.shape

    start = time.time()
    o_tb = tt_tb.forward(x, freqs_cos[:seq_len], freqs_sin[:seq_len])
    print("TransformerBlock:")
    print(f"ttnn: {time.time() - start:.3f}")
    start = time.time()
    to_tb = torch_tb.forward(x_torch, freqs_cos[:seq_len], freqs_sin[:seq_len])
    print(f"torch: {time.time() - start:.3f}")
    check_close(to_tb.bfloat16(), o_tb) 

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Disabling and clearing program cache on device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0
TransformerBlock:
ttnn: 2.479
torch: 0.048
Close values: 22.971% (32768)
PCC: 0.9992009902669368


### Compare Transformer

Close values drops hard here...

In [89]:
from llama2.model import Transformer as torchTransformer
from ttllama2 import Transformer
from llama2.tokenizer import Tokenizer

if TEST['Transformer']:
    ttnn.close_device(device)
    device = ttnn.open_device(device_id=device_id)
    
    x_torch = torch.randint(1, args.vocab_size-1, (1, 152))
    seq_len = 1

    x = ttnn.from_torch(x_torch, device=device)
    x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
    
    tt_trans = Transformer(args, state_dict, device)
    torch_trans = torchTransformer(args, state_dict)
    start = time.time()
    o_trans = tt_trans.forward(x)
    print("Transformer:")
    print(f"ttnn: {time.time() - start:.3f}")
    start = time.time()
    to_trans = torch_trans.forward(x_torch)
    print(f"torch: {time.time() - start:.3f}")

    check_close(to_trans.bfloat16(), o_trans)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Disabling and clearing program cache on device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0
Transformer:
ttnn: 0.820
torch: 0.058
Close values: 2.344% (512)
PCC: 0.9999211092957268


### Generation
#### WIP: Speaking really bad, need to adjust something somewhere


In [45]:
def get_next_token(logits, temperature, top_k):
    logits = logits[:, -1, :] # crop to just the final time step

    if temperature == 0.0:
        # "sample" the single most likely index
        _, idx_next = torch.topk(logits, k=1, dim=-1)
    else:
        # pluck the logits at the final step and scale by desired temperature
        logits = logits / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = torch.functional.F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
    
    return idx_next

In [55]:
from llama2.tokenizer import Tokenizer

enc = Tokenizer(tokenizel_path="./llama2/tokenizer.model")
temperature = 1.0
top_k = 300

start = "Dream comes true!"
start_ids = enc.encode(start, bos=True, eos=False)
x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])


In [86]:
torch_trans = torchTransformer(args, state_dict)

# [1,13]
x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])

for _ in range(100):
    # if the sequence context is growing too long we must crop it at block_size
    # x = x if idx.size(1) <= args.max_seq_len else idx[:, -args.max_seq_len:]
    
    # forward the model to get the logits for the index in the sequence
    logits = torch_trans.forward(x)

    idx_next = get_next_token(logits, temperature, top_k)
    # append sampled index to the running sequence and continue
    x = torch.cat((x, idx_next), dim=1)


print(enc.decode(x[0].tolist()))

Dream comes true! Daddy was very excited. He explained it was something yummy.
One day, Dizzy found a stubborn box on a hole. It was coming from the box and Benny was helpless. Don't worry,


In [87]:
ttnn.close_device(device)
device = ttnn.open_device(device_id=device_id)
# device.enable_program_cache()

tt_trans = Transformer(args, state_dict, device)

x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])

time_forward = 0
for _ in range(100):
    tt_x = ttnn.from_torch(x, device=device)
    tt_x = ttnn.to_layout(tt_x, layout=ttnn.ROW_MAJOR_LAYOUT)

    start = time.time()
    logits = tt_trans.forward(tt_x)
    time_forward += time.time() - start

    logits = ttnn.to_torch(logits)
    idx_next = get_next_token(logits, temperature, top_k)    
    x = torch.cat((x, idx_next), dim=1)


print(enc.decode(x[0].tolist()))

print(f"Execution time: {time_forward}")

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Disabling and clearing program cache on device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Enabling program cache on device 0
Dream comes true! "edadadit ofout herenaesestb LilyeflfendJeliforon niin linll ankow brieiverke itedlMeom hisveldganIy akireherarotayitan Heendll are” butadltilannstoanv time.or ofs. toomary”Blenot
Execution time: 14.889288902282715
