In [31]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Reference implementation of the Llama model code explored in this notebook:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

# Checkpoint identifier handed to both `from_pretrained` calls in the loading cell.
model_id = 'meta-llama/Llama-3.2-1B-Instruct'
# Device string the loaded model is moved to with `.to(device)`.
device = 'cpu'

In [None]:
# Load the tokenizer and the causal-LM weights for `model_id`. The weights are
# requested in bfloat16 and the resulting module is then moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = model.to(device)


In [None]:
from llama.configuration_llama import LlamaConfig

# Build a default-constructed LlamaConfig and grab the config attached to the
# loaded checkpoint so the two can be compared side by side.
config = LlamaConfig()
my_config = model.config

print(config)
print(my_config)

# Last expression of the cell: display the checkpoint's config.
my_config

In [None]:
# Prompt text; note that it ends mid-word ("wo"), so the model is asked to
# continue a partial word.
input_text = 'Hi, I am a dog and I like to wo'
# Tokenize a batch containing the single prompt and request PyTorch tensors ('pt').
# NOTE(review): the name `input` shadows Python's built-in input(). Other cells
# read input['input_ids'] and input['attention_mask'], so any rename has to be
# applied to all of them at once.
input = tokenizer([input_text], return_tensors = 'pt')

# Show the encoded batch.
print(input)

In [None]:
# Generation options, kept together so they are easy to tweak.
generation_kwargs = dict(
    max_length=50,    # Maximum length of generated text (adjust as needed)
    do_sample=False,  # Whether to sample for more diverse generations
)

# `**input` unpacks the tokenizer output (input_ids and attention_mask) into
# keyword arguments for generate().
output = model.generate(**input, **generation_kwargs)

# Decode the first sequence of the batch, dropping special tokens; as the last
# expression of the cell the decoded string is displayed.
first_sequence = output[0]
tokenizer.decode(first_sequence, skip_special_tokens=True)

In [None]:
with torch.no_grad():
    # One forward pass over the whole prompt through the full model, without
    # gradient tracking.
    r = model(input_ids=input['input_ids'], attention_mask=input['attention_mask'])
    # Greedy choice: index of the largest logit at the last sequence position.
    output_id = r['logits'][:, -1, :].argmax(dim=-1)
    print(tokenizer.decode(output_id))

In [None]:
# NOTE(review): this cell repeats the previous cell's computation; it is a
# candidate for removal once the notebook is cleaned up.
with torch.no_grad():
    r = model(input['input_ids'], input['attention_mask'])

    # Logits at the final sequence position, then the arg-max id over the last axis.
    last_token_logits = r['logits'][:, -1, :]
    output_id = last_token_logits.argmax(dim=-1)
    print(tokenizer.decode(output_id))


In [None]:
with torch.no_grad():
    # Same prediction as calling `model(...)`, but with the two stages made
    # explicit: the backbone (`model.model`) followed by the output head
    # (`model.lm_head`).
    base_output = model.model(
        input_ids=input['input_ids'],
        attention_mask=input['attention_mask'],
    )
    # Element 0 of the backbone output is what gets fed to the head.
    backbone_hidden = base_output[0]
    logits = model.lm_head(backbone_hidden)

    # Greedy choice at the last sequence position.
    output_id = logits[:, -1, :].argmax(dim=-1)
    print(tokenizer.decode(output_id))

In [None]:
with torch.no_grad():
    # Hand-unrolled replacement for `model.model(input_ids, attention_mask)`:
    # embedding lookup -> mask / rotary set-up -> each decoder layer in order
    # -> final norm, followed by the LM head.
    embeds = model.model.embed_tokens(input['input_ids'])

    # One position index per prompt token, counted from 0.
    cache_position = torch.arange(embeds.shape[1], device=embeds.device)
    position_ids = cache_position.unsqueeze(0)
    # NOTE(review): `_update_causal_mask` is a private helper of the backbone;
    # the trailing (None, False) are presumably "no KV cache" and
    # "output_attentions off" -- confirm against the installed transformers.
    causal_mask = model.model._update_causal_mask(input['attention_mask'], embeds, cache_position, None, False)
    print(causal_mask)

    # Create position embeddings to be shared across transformer blocks
    position_embeddings = model.model.rotary_emb(embeds, position_ids)

    hidden_state = embeds
    for layer in model.model.layers:
        # Keyword arguments make the call independent of the layer's
        # positional parameter order.
        layer_output = layer(
            hidden_state,
            attention_mask=causal_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )

        # Element 0 of the layer output is carried forward as the next input.
        hidden_state = layer_output[0]

    hidden_states = model.model.norm(hidden_state)

    logits = model.lm_head(hidden_states)

    # Greedy choice at the last sequence position.
    output_id = torch.argmax(logits[:, -1, :], dim=-1)
    print(tokenizer.decode(output_id))

In [None]:
with torch.no_grad():
    # Hand-unrolled forward pass that also opens up each decoder layer:
    # pre-norm -> self-attention -> residual add, then
    # pre-norm -> MLP -> residual add.
    embeds = model.model.embed_tokens(input['input_ids'])

    # One position index per prompt token, counted from 0.
    cache_position = torch.arange(embeds.shape[1], device=embeds.device)
    position_ids = cache_position.unsqueeze(0)
    # NOTE(review): `_update_causal_mask` is a private helper of the backbone;
    # the trailing (None, False) are presumably "no KV cache" and
    # "output_attentions off" -- confirm against the installed transformers.
    causal_mask = model.model._update_causal_mask(input['attention_mask'], embeds, cache_position, None, False)
    print(causal_mask)

    # Create position embeddings to be shared across transformer blocks
    position_embeddings = model.model.rotary_emb(embeds, position_ids)

    hidden_state = embeds
    for layer in model.model.layers:
        # --- self-attention sub-block ---
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)

        # Arguments are passed by keyword and only element 0 of the result is
        # used, so the call depends neither on the positional parameter order
        # of `self_attn` nor on how many extra values it returns.
        # NOTE(review): both are understood to differ between transformers
        # releases -- confirm against the installed version.
        sa_output = layer.self_attn(
            hidden_state,
            attention_mask=causal_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )[0]
        hidden_state = residual + sa_output

        # --- feed-forward sub-block ---
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)

        mlp_output = layer.mlp(hidden_state)
        hidden_state = residual + mlp_output

    hidden_states = model.model.norm(hidden_state)

    logits = model.lm_head(hidden_states)

    # Greedy choice at the last sequence position.
    output_id = torch.argmax(logits[:, -1, :], dim=-1)
    print(tokenizer.decode(output_id))

In [87]:
# Recompute the rotary position embeddings on their own, outside the
# forward-pass cells.
# NOTE(review): this relies on `embeds` and `position_ids` left in the kernel
# by one of the manual forward-pass cells; on a fresh kernel it fails unless
# such a cell has been run first.
position_embeddings = model.model.rotary_emb(embeds, position_ids)
