Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Models] Llama2 fix #333

Merged
merged 12 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions python/hidet/testing/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tqdm import tqdm

import torch
import transformers
from transformers import LlamaConfig, LlamaTokenizer
from transformers import LlamaForCausalLM as hfLm

Expand All @@ -34,10 +35,12 @@ def copy_weights(torch_model, hidet_model):
print(type(mod))
raise ValueError(f"hidet/hf mismatch at {name}")

src = hidet.from_torch(tensor).to(mod.dtype, mod.device)
if len(src.shape) != len(mod.shape) or any(a != b for a, b in zip(src.shape, mod.shape)):
print(transformers.__version__)
raise RuntimeError(f"hidet/hf shape mismatch at {name}, hidet: {mod.shape}, torch: {src.shape}")
found_tensors.append(mod)
mod.copy_(hidet.from_torch(tensor).to(mod.dtype, mod.device))
if mod.shape != tensor.shape:
print(f"hidet/hf shape mismatch at {name}, hidet: {mod.shape}, torch: {tensor.shape}")
mod.copy_(src)

buffer_names = set(name for name, _ in torch_model.named_buffers())

Expand Down Expand Up @@ -145,7 +148,7 @@ def repeat_kv(hidden_states: hidet.Tensor, n_rep: int) -> hidet.Tensor:
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
return hidden_states.reshape((batch, num_key_value_heads * n_rep, slen, head_dim))


class LlamaAttention(nn.Module):
Expand Down Expand Up @@ -188,8 +191,12 @@ def forward(
raise RuntimeError("Pretraining TP > 1 is not supported yet")

query_states = self.q_proj(hidden_states).reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose(1, 2)
key_states = self.k_proj(hidden_states).reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose(1, 2)
value_states = self.v_proj(hidden_states).reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose(1, 2)
key_states = (
self.k_proj(hidden_states).reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states).reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose(1, 2)
)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
Expand Down Expand Up @@ -388,7 +395,7 @@ def build_flow_graph(model, batch_size=1, device='cuda', dtype='float16'):
position_ids = hidet.symbol([batch_size, config.max_position_embeddings], dtype=hidet.int32, device=device)

get_sym = lambda: hidet.symbol(
[batch_size, config.num_attention_heads, "prev_seq_len", config.hidden_size // config.num_attention_heads],
[batch_size, config.num_key_value_heads, "prev_seq_len", config.hidden_size // config.num_key_value_heads],
device=device,
dtype=dtype,
)
Expand All @@ -415,7 +422,7 @@ def generate(text: str, model, tokenizer, config, num_tokens=20, device='cuda',
position_ids = hidet.arange(0, config.max_position_embeddings, dtype=hidet.int32, device=device).unsqueeze(0)

make_past = lambda: hidet.zeros(
[1, config.num_attention_heads, 0, config.hidden_size // config.num_attention_heads], device=device, dtype=dtype
[1, config.num_key_value_heads, 0, config.hidden_size // config.num_key_value_heads], device=device, dtype=dtype
)
past_keys_values = [make_past() for _ in range(config.num_hidden_layers * 2)]

Expand All @@ -438,7 +445,7 @@ def generate_torch(input_ids: str, tokenizer, torch_model, num_tokens, device='c
attention_mask = torch.ones([1, config.max_position_embeddings]).to(device=device, dtype=dtype)
# position_ids = torch.arange(0, config.max_position_embeddings, device='cuda').unsqueeze(0)
make_past = lambda: torch.zeros(
[1, config.num_attention_heads, 0, config.hidden_size // config.num_attention_heads]
[1, config.num_key_value_heads, 0, config.hidden_size // config.num_key_value_heads]
).to(device=device, dtype=dtype)
key_value_cache = [(make_past(), make_past()) for _ in range(config.num_hidden_layers)]
outputs = []
Expand Down
2 changes: 1 addition & 1 deletion tests/frontends/torch/models/test_torch_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('seq_length', [128])
@pytest.mark.parametrize('use_fp16,use_tensor_core', [(False, False), (False, True), (True, True)])
@pytest.mark.parametrize('dynamic', [False, True])
@pytest.mark.parametrize('dynamic', [False]) # TODO: enable dynamic when torch dynamo is fixed
def test_bert(batch_size: int, seq_length: int, use_fp16, use_tensor_core, dynamic):
tokens_tensor = torch.zeros((batch_size, seq_length), dtype=torch.long, device='cuda')
segments_tensors = torch.zeros((batch_size, seq_length), dtype=torch.long, device='cuda')
Expand Down
60 changes: 59 additions & 1 deletion tests/models/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
# %%
import pytest
from hidet.testing.models.llama import get_compiled_model, generate
from hidet.testing.models.llama import get_compiled_model, generate, convert_model
from hidet.runtime.storage import current_memory_pool


Expand Down Expand Up @@ -63,3 +63,61 @@ def test_llama2(device, opt):
print(current_memory_pool("cuda"))
print(current_memory_pool("cpu"))
print(current_memory_pool("vcuda"))


def test_model_architecture():
import torch
import hidet
from transformers.models.llama import LlamaForCausalLM as hfLm, LlamaConfig

config = LlamaConfig(
**{
"architectures": ["LlamaForCausalLM"],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 4096,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 2,
"num_key_value_heads": 1,
"pad_token_id": 0,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": None,
"tie_word_embeddings": False,
"torch_dtype": "float16",
"use_cache": True,
"vocab_size": 32000,
}
)

with torch.device("cuda"):
hf_model = hfLm(config).eval()

model = convert_model(hf_model, device='cuda', dtype=hidet.float32)

def build_flow_graph(model, batch_size=1, device='cuda', dtype='float16'):
config = model.config
input_ids = hidet.symbol([batch_size, 'seq_len'], dtype=hidet.int32, device=device)
position_ids = hidet.symbol([batch_size, config.max_position_embeddings], dtype=hidet.int32, device=device)

y = model(input_ids, position_ids=position_ids, past_key_values=None) # key_value_cache)
inputs = [input_ids, position_ids]

outputs = [y['logits']]
return hidet.trace_from(outputs, inputs)

cmodel = build_flow_graph(model, batch_size=1, device='cuda', dtype=hidet.float32)

x = torch.randint(0, 32000, (1, 512), dtype=torch.int32).cuda()
pos_ids = torch.arange(0, config.max_position_embeddings, dtype=torch.int32).reshape(1, -1).cuda()
res1 = hf_model(x)
res2 = cmodel(hidet.from_torch(x), hidet.from_torch(pos_ids))

logits1 = res1.logits
logits2 = res2.torch()
assert torch.allclose(logits1, logits2, rtol=1e-3, atol=1e-3)