In [1]:
import torch
from time import perf_counter
import numpy as np
from transformers import LlamaForCausalLM, LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = '/dataset/crosspipe/llama-2-chat/Llama-2-7b-chat-hf'
model = LlamaForCausalLM.from_pretrained(model_name).cuda()
tokenizer = LlamaTokenizer.from_pretrained(model_name)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.20s/it]


In [3]:
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [6]:
text = "Hello, my name is Llama."
inputs = tokenizer(text, return_tensors='pt').to('cuda')

In [10]:
# 包装模型的前向传播，简化输出
class WrappedModel(torch.nn.Module):
    def __init__(self, model):
        super(WrappedModel, self).__init__()
        self.model = model

    def forward(self, input_ids):
        outputs = self.model(input_ids)
        return outputs.logits

wrapped_model = WrappedModel(model)

In [11]:
wrapped_model

WrappedModel(
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096, padding_idx=0)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLUActivation()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_

In [26]:
traced_model = torch.jit.trace(wrapped_model, (inputs.input_ids,))

  if input_shape[-1] > 1:
  if seq_len > self.max_seq_len_cached:
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):


In [12]:
traced_model

WrappedModel(
  original_name=WrappedModel
  (model): LlamaForCausalLM(
    original_name=LlamaForCausalLM
    (model): LlamaModel(
      original_name=LlamaModel
      (embed_tokens): Embedding(original_name=Embedding)
      (layers): ModuleList(
        original_name=ModuleList
        (0): LlamaDecoderLayer(
          original_name=LlamaDecoderLayer
          (self_attn): LlamaAttention(
            original_name=LlamaAttention
            (q_proj): Linear(original_name=Linear)
            (k_proj): Linear(original_name=Linear)
            (v_proj): Linear(original_name=Linear)
            (o_proj): Linear(original_name=Linear)
            (rotary_emb): LlamaRotaryEmbedding(original_name=LlamaRotaryEmbedding)
          )
          (mlp): LlamaMLP(
            original_name=LlamaMLP
            (gate_proj): Linear(original_name=Linear)
            (up_proj): Linear(original_name=Linear)
            (down_proj): Linear(original_name=Linear)
            (act_fn): SiLUActivation(origina

In [13]:
def timer(f, *args):
    torch.cuda.synchronize()  # 确保之前所有的GPU操作完成
    start = perf_counter()
    f(*args)
    torch.cuda.synchronize()  # 确保当前所有的GPU操作完成
    return 1000 * (perf_counter() - start)

# 测量traced_model的执行时间
times = [timer(traced_model, inputs.input_ids) for _ in range(100)]
print(np.mean(times))

570.3120026789838


In [None]:
inputs = tokenizer(text, return_tensors='pt').to('cuda')

In [14]:
import time


In [15]:

def timer(f, *args):
    torch.cuda.synchronize()  # 确保之前所有的GPU操作完成
    start = perf_counter()
    f(*args)
    torch.cuda.synchronize()  # 确保当前所有的GPU操作完成
    return 1000 * (perf_counter() - start)

# 定义包装函数以简化模型输出
def model_forward(input_ids):
    outputs = model(input_ids)
    return outputs.logits

# 测量模型的执行时间
times = []
for _ in range(100):
    elapsed_time = timer(model_forward, inputs.input_ids)
    times.append(elapsed_time)

# 计算平均时间
average_time = np.mean(times)
print(f"Average inference time over 100 runs: {average_time:.2f} ms")

Average inference time over 100 runs: 55.32 ms


In [25]:
time1=time.time()
timer(model_forward, inputs.input_ids)
time2 = time.time()
time_elapsed = time2-time1
print(time_elapsed)

0.06358981132507324


In [31]:
import torch 
 
class Model(torch.nn.Module): 
    def __init__(self, n): 
        super().__init__() 
        self.n = n 
        self.conv = torch.nn.Conv2d(3, 3, 3) 
 
    def forward(self, x): 
        for i in range(self.n): 
            x = self.conv(x) 
        return x 

In [28]:
models = [Model(2), Model(3)] 
model_names = ['model_2', 'model_3'] 

In [34]:
for model, model_name in zip(models, model_names): 
    dummy_input = torch.rand(1, 3, 10, 10) 
    dummy_output = model(dummy_input) 
    model_trace = torch.jit.trace(model, dummy_input) 
    model_script = torch.jit.script(model) 
    # 跟踪法与直接 torch.onnx.export(model, ...)等价 
    torch.onnx.export(model_trace, dummy_input, f'/home/Graph_module/oonx/{model_name}_trace.onnx') 
    # 记录法必须先调用 torch.jit.sciprt 
    torch.onnx.export(model_script, dummy_input, f'/home/Graph_module/oonx/{model_name}_script.onnx') 

