Skip to content

Commit

Permalink
Again
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Dec 14, 2023
1 parent a7bf2d5 commit d878b56
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/models/test_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from unittest import TestCase

from parameterized import parameterized
from optimum.nvidia.utils.tests import requires_gpu
from optimum.nvidia.models.llama import LLamaForCausalLM as TrtLlamaForCausalLM



class LLamaForCausalLMTestCase(TestCase):

@requires_gpu
@parameterized.expand(["float16", "bfloat16"])
def test_build_engine_7b_with_tp(self, dtype: str):
model = TrtLlamaForCausalLM.from_pretrained("huggingface/llama-7b", dtype=dtype)
self.assertIsNotNone(model)

0 comments on commit d878b56

Please sign in to comment.