In [1]:
import time
import torch
import ttnn
from llama2.model import ModelArgs
import random
from models.utility_functions import comp_pcc

random.seed(42)
torch.manual_seed(42)

2025-02-05 12:00:56.499 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/bach/.cache/ttnn,model_cache_path=/home/bach/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


<torch._C.Generator at 0x7f7ee19ef2f0>

In [2]:
from llama2.model import precompute_freqs_cis

# Load checkpoint dict
checkpoint_dict = torch.load("llama2/configs/stories260K.pth")
model_args = checkpoint_dict['model_args']
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

args = ModelArgs()
print(args)

torch_freqs_cos, torch_freqs_sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)
freqs_cos = torch_freqs_cos.bfloat16()
freqs_sin = torch_freqs_sin.bfloat16()


ModelArgs(dim=64, n_layers=5, n_heads=8, n_kv_heads=4, vocab_size=512, hidden_dim=None, multiple_of=4, norm_eps=1e-05, max_seq_len=512, dropout=0.5)


In [3]:
batch_size = 1
tokens_num = args.max_seq_len

In [4]:
device_id = 0
device = ttnn.open_device(device_id=device_id)
# device.enable_program_cache()

[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver

[32m2025-02-05 12:00:56.749[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.30.0, IOMMU: disabled
[32m2025-02-05 12:00:56.750[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected PCI devices: [0]
[32m2025-02-05 12:00:56.750[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using local chip ids: {0} and remote chip ids {}
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Profiler started on device 0


### Generation


In [5]:
def get_next_token(logits, temperature, top_k):
    logits = logits[:, -1, :] # crop to just the final time step

    if temperature == 0.0:
        # "sample" the single most likely index
        _, idx_next = torch.topk(logits, k=1, dim=-1)
    else:
        # pluck the logits at the final step and scale by desired temperature
        logits = logits / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = torch.functional.F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
    
    return idx_next

In [53]:
from llama2.tokenizer import Tokenizer

enc = Tokenizer(tokenizel_path="./llama2/tokenizer.model")
temperature = 1.0
top_k = 300
token_to_generate = 100

start = "Dake was walking"
start_ids = enc.encode(start, bos=True, eos=False)

## TTNN llama2

In [54]:
from ttllama2 import Transformer
from tqdm import tqdm

tt_trans = Transformer(args, state_dict, device)

x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])

text = ""

first_token = 0
time_forward = 0
start_generation = time.time()

for _ in tqdm(range(token_to_generate)):
    
    tt_x = ttnn.from_torch(x, device=device)
    tt_x = ttnn.to_layout(tt_x, layout=ttnn.ROW_MAJOR_LAYOUT)

    start_forward = time.time()
    logits = tt_trans.forward(tt_x)
    time_forward += time.time() - start_forward

    if first_token == 0:
        first_token = start_forward

    logits = ttnn.to_torch(logits)
    idx_next = get_next_token(logits, temperature, top_k)    
    x = torch.cat((x, idx_next), dim=1)
    
print(enc.decode(x[0].tolist()))

tot_generation = time.time() - start_generation


100%|██████████| 100/100 [01:13<00:00,  1.37it/s]

Dake was walking in Wain. judge was fast and she hurt herself. She left a mysterious cone, a man with a vase.
"Be carefully, Nappy!" he said.
"Please, please hope this is a napk





In [59]:
print(f"Tot generation time: {tot_generation}")
print(f"Tot forward time: {time_forward}")
print(f"Time To First Token: {first_token}")
print(f"Tokens per Second: {token_to_generate / tot_generation}")

Tot generation time: 73.04002594947815
Tot forward time: 72.6094172000885
Time To First Token: 1738760259.4030492
Tokens per Second: 1.3691123284809632


## Torch CPU

In [49]:
from llama2.model import Transformer as torchTransformer

torch_trans = torchTransformer(args, state_dict)

# [1,13]
x = (torch.tensor(start_ids, dtype=torch.long)[None, ...])

first_token_torch = 0
time_forward_torch = 0
start_generation = time.time()

for _ in tqdm(range(token_to_generate)):
    # if the sequence context is growing too long we must crop it at block_size
    # x = x if idx.size(1) <= args.max_seq_len else idx[:, -args.max_seq_len:]
    
    start_forward = time.time()
    logits = torch_trans.forward(x)   
    time_forward_torch += time.time() - start_forward

    if first_token_torch == 0:
        first_token_torch = start_forward
    # forward the model to get the logits for the index in the sequence

    idx_next = get_next_token(logits, temperature, top_k)
    
    # append sampled index to the running sequence and continue
    x = torch.cat((x, idx_next), dim=1)

tot_generation_torch = time.time() - start_generation

print(enc.decode(x[0].tolist()))

100%|██████████| 100/100 [00:02<00:00, 44.66it/s]

Dream comes true! The duck loved to kick their favorite puppy. One day Diamble, a rich ant named Cuppy who went outside to play. Dino was very clapping and okay.
Max's monkey gave him a big hug. Finally,





In [52]:
print(f"Tot generation time: {tot_generation_torch}")
print(f"Tot forward time: {time_forward_torch}")
print(f"Time To First Token: {first_token_torch}")
print(f"Tokens per Second: {token_to_generate / tot_generation_torch}")

Tot generation time: 2.2409582138061523
Tot forward time: 2.187505006790161
Time To First Token: 1738754933.197594
Tokens per Second: 44.623768254096596
