Skip to content

Commit

Permalink
[Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)
Browse files Browse the repository at this point in the history
fix dependency in pytest
  • Loading branch information
yuanheng-zhao committed Feb 26, 2024
1 parent bc1da87 commit 1906118
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/test_infer/test_ops/triton/test_rmsnorm_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import triton
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm

from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize

try:
pass
import triton # noqa

HAS_TRITON = True
except ImportError:
Expand Down Expand Up @@ -85,6 +84,11 @@ def benchmark_rms_layernorm(
SEQUENCE_TOTAL: int,
HIDDEN_SIZE: int,
):
try:
from vllm.model_executor.layers.layernorm import RMSNorm
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

warmup = 10
rep = 1000

Expand Down

0 comments on commit 1906118

Please sign in to comment.