In [1]:
import random

import numpy as np
import torch
from transformers import LlamaForCausalLM, LlamaConfig

def count_parameters(model):
    """
    Count the trainable scalar parameters of a model.

    Args:
        model: Any object exposing `parameters()` that yields items with a
            `requires_grad` flag and a `numel()` method.

    Returns:
        `int`: Sum of `numel()` over the parameters whose `requires_grad` is truthy.
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters are excluded from the count.
        if param.requires_grad:
            total += param.numel()
    return total

def set_seed(seed: int):
    """
    Seed every random number generator this notebook relies on.

    Seeds, in order: Python's `random`, NumPy's global generator, torch's
    default generator, and torch's CUDA generators via `manual_seed_all`.

    Args:
        seed (`int`): The seed to set.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Small single-layer configuration shared by the reference model and the
# model under test.
config = LlamaConfig(
    num_hidden_layers=1,
    hidden_size=1024,
    intermediate_size=3000,
)

# Seed immediately before construction; the same seed is applied before the
# second model is built so both start from the same RNG state.
set_seed(42)
hf_llama = LlamaForCausalLM(config)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# NOTE: this import rebinds `LlamaForCausalLM`, which the first cell imported
# from `transformers`. From here on the name refers to the local
# `modeling_llama` implementation.
from modeling_llama import LlamaForCausalLM

# Re-apply the seed used for `hf_llama` before building the model under test.
set_seed(42)
llama = LlamaForCausalLM(config)

In [3]:
# Compare every tensor in the reference state dict against the tensor stored
# under the same key in the local model (absolute tolerance 1e-2).
with torch.no_grad():
    ref_state_dict = hf_llama.state_dict()
    model_state_dict = llama.state_dict()

    # Only keys present in the reference are visited; keys that exist solely
    # in `model_state_dict` are not examined.
    for k in ref_state_dict:
        ref = ref_state_dict[k].float()
        current = model_state_dict[k].cpu().float()
        if torch.allclose(ref, current, atol=1e-2):
            continue
        raise AssertionError(
            f"Model state dict does not match the reference model state dict for key {k}. Difference: {(ref - current).abs().max()}"
        )

    print("Model state dict matches the reference model state dict")

Model state dict matches the reference model state dict


In [4]:
# Snapshot the weights of the local model before any optimizer step.
torch.save(obj=llama.state_dict(), f="initial_llama.pth")

In [4]:
import os

from transformers import AutoTokenizer
from torch.optim import Adam

# Read the Hugging Face access token from the environment so it is never
# hardcoded in the notebook. Raises KeyError if HF_TOKEN is not set.
# (Fixed: the original `os..environ` was a SyntaxError that stopped the whole
# cell from running.)
hf_token = os.environ['HF_TOKEN']

optimizer = Adam(llama.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-8)
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=hf_token)

# Single short prompt used as both input and target.
test_input = 'test llama forward backward'
inputs = tokenizer(test_input, return_tensors='pt')
# Labels are the input ids themselves.
# NOTE(review): presumably the model shifts labels internally for the
# next-token loss; confirm in modeling_llama.
inputs['labels'] = inputs['input_ids']

loss = llama(**inputs).loss
print("loss: ", loss)

loss.backward()

# Persist selected gradients before the optimizer step.
torch.save(llama.model.layers[0].self_attn.q_proj.weight.grad, 'q_proj_grad.pth')
torch.save(llama.lm_head.weight.grad, 'lm_head_grad.pth')
torch.save(llama.model.embed_tokens.weight.grad, 'embed_grad.pth')

# Apply exactly one Adam update to `llama`; `hf_llama` is not touched here.
optimizer.step()


In [7]:
# Re-run the reference-vs-local comparison after the optimizer step.
# NOTE(review): by this point `llama` has taken one Adam step while no visible
# cell updates `hf_llama`, so this assertion is expected to fail (see the
# recorded output below). Confirm that the failure is the intended result.
with torch.no_grad():
    ref_state_dict = hf_llama.state_dict()
    model_state_dict = llama.state_dict()

    # Only keys present in the reference state dict are compared.
    for k in ref_state_dict:
        ref = ref_state_dict[k].float()
        current = model_state_dict[k].cpu().float()
        if torch.allclose(ref, current, atol=1e-2):
            continue
        raise AssertionError(
            f"Model state dict does not match the reference model state dict for key {k}. Difference: {(ref - current).abs().max()}"
        )

    print("Model state dict matches the reference model state dict")

AssertionError: Model state dict does not match the reference model state dict for key model.embed_tokens.weight. Difference: 0.10000000894069672

In [8]:
# Snapshot the weights of the local model after the single optimizer step.
torch.save(obj=llama.state_dict(), f="one_step_llama.pth")