In [1]:
import gc
import torch
from coreml_llama import *

In [2]:
# Checkpoint directory holding params.json, consolidated.00.pth and tokenizer.model.
# Set the LLAMA_MODEL_DIR environment variable to point elsewhere instead of
# editing the hard-coded local path below.
import os

M_path = os.environ.get("LLAMA_MODEL_DIR", "/Volumes/무제/llama3_2/Llama3.2-1B-Instruct")

In [3]:
# Build the model from the checkpoint's params.json and load its weights on CPU.
import gc
import json

with open(f"{M_path}/params.json", "r", encoding="utf-8") as st_json:
    params = json.load(st_json)

args = ModelArgs(**params)
transformer = Transformer(args)

model_pth = torch.load(f"{M_path}/consolidated.00.pth", map_location="cpu", weights_only=True)
# strict=False tolerates key mismatches silently; report them so a wrong or
# incompatible checkpoint does not go unnoticed.
load_result = transformer.load_state_dict(model_pth, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")

# load_state_dict copies the tensors into the module's parameters, so the raw
# checkpoint dict is a second full copy of the weights; release it.
del model_pth
gc.collect()

transformer.eval()

Transformer(
  (tok_embeddings): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=2048, out_features=2048, bias=False)
        (wk): Linear(in_features=2048, out_features=512, bias=False)
        (wv): Linear(in_features=2048, out_features=512, bias=False)
        (wo): Linear(in_features=2048, out_features=2048, bias=False)
        (rope): RoPE()
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=2048, out_features=8192, bias=False)
        (w2): Linear(in_features=8192, out_features=2048, bias=False)
        (w3): Linear(in_features=2048, out_features=8192, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=2048, out_features=128256, bias=False)
)

In [4]:
from tokenizer import Tokenizer, ChatFormat

# The tokenizer model file is read from the same directory as the checkpoint.
tokenizer_file = f"{M_path}/tokenizer.model"
tok = Tokenizer(tokenizer_file)

# Chat-template helper built on top of `tok`; used to turn dialogs into prompt token ids.
formatter = ChatFormat(tok)

In [5]:
# Move all parameters and buffers to the MPS (Apple Metal) device; the following
# cells also allocate their tensors on "mps".
transformer = transformer.to(torch.device("mps"))

In [6]:
# A batch containing a single one-turn user dialog.
dialogs = [
    [{"role": "user", "content": "hello!😆"}],
]

# Render every dialog through the chat template into prompt token ids.
prompt_tokens = list(map(formatter.encode_dialog_prompt, dialogs))

# Presumably (batch, prompt_len) since encode_dialog_prompt appears to return a
# flat list of ids; torch.tensor needs every dialog to encode to the same length.
prompt = torch.tensor(prompt_tokens, device="mps")

In [7]:

# Fixed-size buffer for prompt + generated tokens; unfilled positions hold pad_id.
MAX_TOTAL_LEN = 1000  # buffer length; the generation loop must stay below it
pad_id = tok.pad_id
bsz = len(prompt)  # number of dialogs in the batch
tokens = torch.full((bsz, MAX_TOTAL_LEN), pad_id, dtype=torch.long, device="mps")

# `prompt` is already a tensor on the device, so copy each row directly instead of
# re-wrapping it with torch.tensor(...), which emits a copy-construct UserWarning.
for k, t in enumerate(prompt):
    tokens[k, : len(t)] = t.to(dtype=torch.long)
token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device="mps")  # not used by the greedy loop

# One "finished" flag per batch row.
eos_reached = torch.zeros(bsz, dtype=torch.bool, device="mps")
# True at prompt positions; the decode loop keeps those tokens instead of replacing them.
input_text_mask = tokens != pad_id

temperature = 0  # informational only: decoding uses argmax (greedy) and ignores it
stop_tokens = torch.tensor(list(tok.stop_tokens), device="mps")
prev_pos = 0


In [8]:
# Greedy decoding: the first step feeds the whole prompt, later steps feed only
# the newest position and pass the same cache object so earlier positions are reused.
# Autograd is disabled: transformer.eval() does not stop graph construction, and
# graphs kept alive across steps (e.g. through the cache) waste memory.
MAX_POS = 800  # generation stops before this position (and never past the buffer)

# Reset loop state so re-running this cell alone starts a fresh generation.
prev_pos = 0
eos_reached = torch.zeros_like(eos_reached)

with torch.no_grad():
    cache = KVCache(
        transformer.caches_shape, dtype=torch.float16, device="mps"
    )

    for cur_pos in range(len(prompt[0]), min(MAX_POS, tokens.size(1))):
        chunk = tokens[:, prev_pos:cur_pos]
        seqlen = chunk.size(1)

        # Causal mask for the new chunk, preceded by zero columns so it may attend
        # to all prev_pos earlier positions.
        mask = torch.full((seqlen, seqlen), -1e9, device="mps")
        mask = torch.triu(mask, diagonal=1)
        mask = torch.hstack(
            [torch.zeros((seqlen, prev_pos), device="mps"), mask]
        )[None, None, :, :]

        logits = transformer(chunk, mask, cache)
        next_token = torch.argmax(logits[..., -1, :], dim=-1).reshape(-1)

        # Keep prompt tokens where they exist; only fill positions past the prompt.
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token

        # A row is done once it *generates* (not copies from the prompt) a stop token.
        eos_reached |= (~input_text_mask[:, cur_pos]) & torch.isin(next_token, stop_tokens)
        prev_pos = cur_pos
        if bool(eos_reached.all()):
            break
    # break

In [11]:
# The generation loop leaves `prev_pos` at the last position it wrote. Include that
# position (otherwise the final token is lost when the length cap is hit), but drop
# it when it is the stop token that ended generation.
output_ids = tokens[0, : prev_pos + 1].tolist()
if output_ids and output_ids[-1] in tok.stop_tokens:
    output_ids = output_ids[:-1]
print(tok.decode(output_ids))

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

hello!😆<|eot_id|><|start_header_id|>assistant<|end_header_id|>

😊 Hello! How's your day going so far?


In [12]:
# Sanity check: Trained_Transformer must accept this model's weights under the default
# strict=True, which raises on any missing or unexpected key.
load_check = Trained_Transformer(args).load_state_dict(transformer.state_dict())
load_check

<All keys matched successfully>