In [1]:
import time
import torch
import ttnn
from llama2.model import ModelArgs
import random
from models.utility_functions import comp_pcc

random.seed(42)
torch.manual_seed(42)

2025-02-14 15:48:10.972 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/bach/.cache/ttnn,model_cache_path=/home/bach/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


<torch._C.Generator at 0x7ff6c9ebb2f0>

In [2]:
from llama2.model import precompute_freqs_cis
from llama2.tokenizer import Tokenizer

# Load checkpoint dict
checkpoint_dict = torch.load("llama2/configs/stories42M.pt")
enc = Tokenizer(tokenizel_path="./llama2/tokenizerM.model")


model_args = checkpoint_dict['model_args']
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

args = ModelArgs(**model_args)
print(args)



torch_freqs_cos, torch_freqs_sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)
freqs_cos = torch_freqs_cos.bfloat16()
freqs_sin = torch_freqs_sin.bfloat16()


ModelArgs(dim=512, n_layers=8, n_heads=8, n_kv_heads=8, vocab_size=32000, hidden_dim=None, multiple_of=32, norm_eps=1e-05, max_seq_len=1024, dropout=0.5)


In [3]:
batch_size = 1
tokens_num = args.max_seq_len

In [4]:
device_id = 0
device = ttnn.open_device(device_id=device_id)
device.enable_program_cache()

[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver

[32m2025-02-14 15:48:11.302[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.30.0, IOMMU: disabled
[32m2025-02-14 15:48:11.303[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected PCI devices: [0]
[32m2025-02-14 15:48:11.303[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using local chip ids: {0} and remote chip ids {}
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;23

### Generation


In [5]:
def get_next_token(logits, temperature, top_k):
    logits = logits[:, -1, :] # crop to just the final time step

    if temperature == 0.0:
        # "sample" the single most likely index
        _, idx_next = torch.topk(logits, k=1, dim=-1)
    else:
        # pluck the logits at the final step and scale by desired temperature
        logits = logits / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = torch.functional.F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
    
    return idx_next

In [6]:
temperature = 1.0
top_k = 300
token_to_generate = 100

start = "hiari is trying to make this damn thing work"
start_ids = enc.encode(start, bos=True, eos=False)


## TTNN llama2

In [7]:
from ttllama2 import Transformer

tt_trans = Transformer(args, state_dict, device)

In [8]:
from tqdm import tqdm

x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])

text = ""

first_token = 0
time_forward = 0
start_generation = time.time()

for _ in tqdm(range(token_to_generate)):
    
    x = x if x.size(1) <= args.max_seq_len else x[:, -args.max_seq_len:]

    tt_x = ttnn.from_torch(x, device=device)
    tt_x = ttnn.to_layout(tt_x, layout=ttnn.ROW_MAJOR_LAYOUT)

    start_forward = time.time()
    logits = tt_trans.forward(tt_x)
    time_forward += time.time() - start_forward

    if first_token == 0:
        first_token = time_forward

    logits = ttnn.to_torch(logits)
    idx_next = get_next_token(logits, temperature, top_k)    
    x = torch.cat((x, idx_next), dim=1)
    
print(enc.decode(x[0].tolist()))

tot_generation = time.time() - start_generation


100%|██████████| 100/100 [02:23<00:00,  1.43s/it]

hiari is trying to make this damn thing work. He had worked so hard but his mom was not very happy. She said it was too hard. So, his dad took him on an airplane to make it there.
When they arrived, his dad took him in and got in the airplane. It was so much fun! He was amazed by the sky and the wind was blowing throughout his faces. Everything was so big and beautiful.
When it was over, his dad handed him a small hammer





In [12]:
print(f"Tot generation time: {tot_generation:.3f}")
print(f"Tot forward time: {time_forward:.3f}")
print(f"Time To First Token: {first_token:.3f}")
print(f"Tokens per Second: {token_to_generate / tot_generation:.3f}")

Tot generation time: 143.449
Tot forward time: 136.042
Time To First Token: 10.972
Tokens per Second: 0.697


## Torch CPU

In [10]:
from llama2.model import Transformer as torchTransformer

torch_trans = torchTransformer(args, state_dict)

# [1,13]
x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])

first_token_torch = 0
time_forward_torch = 0
start_generation = time.time()

for _ in tqdm(range(token_to_generate)):
    # if the sequence context is growing too long we must crop it at block_size
    x = x if x.size(1) <= args.max_seq_len else x[:, -args.max_seq_len:]
    
    start_forward = time.time()
    logits = torch_trans.forward(x)   
    time_forward_torch += time.time() - start_forward

    if first_token_torch == 0:
        first_token_torch = time_forward_torch
    # forward the model to get the logits for the index in the sequence

    idx_next = get_next_token(logits, temperature, top_k)
    
    # append sampled index to the running sequence and continue
    x = torch.cat((x, idx_next), dim=1)

tot_generation_torch = time.time() - start_generation

print(enc.decode(x[0].tolist()))

100%|██████████| 100/100 [02:17<00:00,  1.38s/it]

hiari is trying to make this damn thing work all her others.
She wasnby the and a,inter still playing. She had so long, little in each place.
One. particular was not a fierce fire, the can was the M true.
She leaned down to the first, then the night, in the back savally howling, afraid to find himself. Come. In your jail or your friends were und with such the preby of nature of a novel. 10 roo?"
H





In [11]:
print(f"Tot generation time: {tot_generation_torch}")
print(f"Tot forward time: {time_forward_torch}")
print(f"Time To First Token: {first_token_torch}")
print(f"Tokens per Second: {token_to_generate / tot_generation_torch}")

Tot generation time: 137.9171919822693
Tot forward time: 137.38418817520142
Time To First Token: 0.3134133815765381
Tokens per Second: 0.7250727669459516
