From 81abb6640349a29d8d0b6e86445696136ed8d1f6 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Wed, 8 May 2024 17:35:54 +0000 Subject: [PATCH 1/8] Use llama 3 prompt template --- prompting/llms/vllm_llm.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index 97f9496b1..f723093d4 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -44,15 +44,17 @@ def load_vllm_pipeline(model_id: str, device: str, mock=False): return MockPipeline(model_id) # Calculates the gpu memory utilization required to run the model within 20GB of GPU - max_allowed_memory_in_gb = 20 - max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9 - gpu_mem_utilization = calculate_gpu_requirements( - device, max_allowed_memory_allocation_in_bytes - ) + max_allowed_memory_in_gb = 160 + # max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9 + # gpu_mem_utilization = calculate_gpu_requirements( + # device, max_allowed_memory_allocation_in_bytes + # ) try: # Attempt to initialize the LLM - return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization) + _ = LLM(model=model_id, gpu_memory_utilization = 1, tensor_parallel_size = 4, ) + _.llm_engine.tokenizer.eos_token_id = 128009 + return _ except ValueError as e: bt.logging.error( f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}" @@ -155,17 +157,17 @@ def _make_prompt(self, messages: List[Dict[str, str]]): for message in messages: if message["role"] == "system": composed_prompt += ( - f'<|im_start|>system\n{message["content"]} <|im_end|>' + f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>' ) elif message["role"] == "user": - composed_prompt += f'<|im_start|>user\n{message["content"]} <|im_end|>' + composed_prompt += f'<|start_header_id|>user<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>' elif 
message["role"] == "assistant": composed_prompt += ( - f'<|im_start|>assistant\n{message["content"]} <|im_end|>' + f'<|start_header_id|>assistant<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>' ) # Adds final tag indicating the assistant's turn - composed_prompt += "<|im_start|>assistant\n" + composed_prompt += "<|start_header_id|>assistant<|end_header_id|>" return composed_prompt def forward(self, messages: List[Dict[str, str]]): From 22ac11b92fa4519ca46e2e532586bcf1d1a7c971 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Mon, 13 May 2024 19:27:40 +0000 Subject: [PATCH 2/8] set vllm restrictions for llama --- prompting/llms/vllm_llm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index f723093d4..1f8b39e8e 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -44,15 +44,15 @@ def load_vllm_pipeline(model_id: str, device: str, mock=False): return MockPipeline(model_id) # Calculates the gpu memory utilization required to run the model within 20GB of GPU - max_allowed_memory_in_gb = 160 - # max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9 - # gpu_mem_utilization = calculate_gpu_requirements( - # device, max_allowed_memory_allocation_in_bytes - # ) + max_allowed_memory_in_gb = 60 + max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9 + gpu_mem_utilization = calculate_gpu_requirements( + device, max_allowed_memory_allocation_in_bytes + ) try: # Attempt to initialize the LLM - _ = LLM(model=model_id, gpu_memory_utilization = 1, tensor_parallel_size = 4, ) + _ = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ") _.llm_engine.tokenizer.eos_token_id = 128009 return _ except ValueError as e: @@ -69,7 +69,7 @@ def load_vllm_pipeline(model_id: str, device: str, mock=False): clean_gpu_cache() # Increase the memory allocation for the second attempt - 
max_allowed_memory_in_gb_second_attempt = 24 + max_allowed_memory_in_gb_second_attempt = 70 max_allowed_memory_allocation_in_bytes = ( max_allowed_memory_in_gb_second_attempt * 1e9 ) From ba1d98a03c8f2890ec79fc25e89a1acc15e66861 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Mon, 13 May 2024 20:40:59 +0000 Subject: [PATCH 3/8] updates default model neuron id --- prompting/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/utils/config.py b/prompting/utils/config.py index 5db083b4a..84784f360 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -270,7 +270,7 @@ def add_validator_args(cls, parser): "--neuron.model_id", type=str, help="The model to use for the validator.", - default="NousResearch/Nous-Hermes-2-SOLAR-10.7B", + default="casperhansen/llama-3-70b-instruct-awq", ) parser.add_argument( From 1a3db2cc86590d1237c0912ddf3416df99601af4 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Tue, 14 May 2024 18:47:38 +0000 Subject: [PATCH 4/8] adds initial implementation of handling multiple gpus --- prompting/llms/utils.py | 50 +++++++++++++++++++++++++++++++------- prompting/llms/vllm_llm.py | 19 ++++++++------- prompting/utils/config.py | 7 ++++++ prompting/validator.py | 1 + 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/prompting/llms/utils.py b/prompting/llms/utils.py index 75ce1fb02..92317e1f8 100644 --- a/prompting/llms/utils.py +++ b/prompting/llms/utils.py @@ -7,15 +7,7 @@ def contains_gpu_index_in_device(device: str) -> bool: pattern = r"^cuda:\d+$" return bool(re.match(pattern, device)) - -def calculate_gpu_requirements( - device: str, max_allowed_memory_allocation_in_bytes: int = 20e9 -) -> float: - """Calculates the memory utilization requirements for the model to be loaded on the device. - Args: - device (str): The device to load the model to. - max_allowed_memory_allocation_in_bytes (int, optional): The maximum allowed memory allocation in bytes. Defaults to 20e9 (20GB). 
- """ +def calculate_single_gpu_requirements(device: str, max_allowed_memory_allocation_in_bytes: int): if contains_gpu_index_in_device(device): device_with_gpu_index = device else: @@ -39,4 +31,44 @@ def calculate_gpu_requirements( f'{gpu_utilization * 100}% of the GPU memory will be utilized for loading the model to device "{device}".' ) + return gpu_utilization + +def calculate_multiple_gpu_requirements(device: str, gpus: int, max_allowed_memory_allocation_in_bytes: int): + torch.cuda.synchronize() + total_free_memory = 0 + total_gpu_memory = 0 + + for i in range(gpus): + gpu_device = f"cuda:{i}" + global_free, total_memory = torch.cuda.mem_get_info(device=gpu_device) + total_free_memory += global_free + total_gpu_memory += total_memory + + bt.logging.info(f"Total available free memory across all visible {gpus} GPUs: {round(total_free_memory / 10e8, 2)} GB") + bt.logging.info(f"Total GPU memory across all visible GPUs: {gpus} {round(total_gpu_memory / 10e8, 2)} GB") + + if total_free_memory < max_allowed_memory_allocation_in_bytes: + raise torch.cuda.CudaError( + f"Not enough memory across all specified {gpus} GPUs to allocate for the model. Please ensure you have at least {max_allowed_memory_allocation_in_bytes / 10e8} GB of free GPU memory." + ) + + gpu_utilization = round(max_allowed_memory_allocation_in_bytes / total_free_memory, 2) + bt.logging.info( + f"{gpu_utilization * 100}% of the total GPU memory across all GPUs will be utilized for loading the model." + ) + return gpu_utilization + + +def calculate_gpu_requirements( + device: str, gpus: int, max_allowed_memory_allocation_in_bytes: int = 60e9, +) -> float: + """Calculates the memory utilization requirements for the model to be loaded on the device. + Args: + device (str): The device to load the model to. + max_allowed_memory_allocation_in_bytes (int, optional): The maximum allowed memory allocation in bytes. Defaults to 20e9 (20GB). 
+ """ + if gpus == 1: + return calculate_single_gpu_requirements(device, max_allowed_memory_allocation_in_bytes) + else: + return calculate_multiple_gpu_requirements(device, gpus, max_allowed_memory_allocation_in_bytes=max_allowed_memory_allocation_in_bytes) diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index 1f8b39e8e..99c057a8c 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -38,7 +38,7 @@ def clean_gpu_cache(): torch.cuda.synchronize() -def load_vllm_pipeline(model_id: str, device: str, mock=False): +def load_vllm_pipeline(model_id: str, device: str, gpus: int, mock=False): """Loads the VLLM pipeline for the LLM, or a mock pipeline if mock=True""" if mock or model_id == "mock": return MockPipeline(model_id) @@ -47,14 +47,14 @@ def load_vllm_pipeline(model_id: str, device: str, mock=False): max_allowed_memory_in_gb = 60 max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9 gpu_mem_utilization = calculate_gpu_requirements( - device, max_allowed_memory_allocation_in_bytes + device, gpus, max_allowed_memory_allocation_in_bytes ) try: # Attempt to initialize the LLM - _ = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ") - _.llm_engine.tokenizer.eos_token_id = 128009 - return _ + llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus) + llm.llm_engine.tokenizer.eos_token_id = 128009 + return llm except ValueError as e: bt.logging.error( f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}" @@ -77,11 +77,11 @@ def load_vllm_pipeline(model_id: str, device: str, mock=False): f"Retrying to load with {max_allowed_memory_in_gb_second_attempt}GB..." 
) gpu_mem_utilization = calculate_gpu_requirements( - device, max_allowed_memory_allocation_in_bytes + device, gpus, max_allowed_memory_allocation_in_bytes, gpus ) # Attempt to initialize the LLM again with increased memory allocation - return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization) + return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization, tensor_parallel_size=gpus) except Exception as e: bt.logging.error( f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb_second_attempt}GB: {e}" @@ -90,10 +90,11 @@ def load_vllm_pipeline(model_id: str, device: str, mock=False): class vLLMPipeline(BasePipeline): - def __init__(self, model_id: str, device: str = None, mock=False): + def __init__(self, model_id: str, device: str = None, gpus: int = 1, mock=False): super().__init__() - self.llm = load_vllm_pipeline(model_id, device, mock) + self.llm = load_vllm_pipeline(model_id, device, gpus, mock) self.mock = mock + self.gpus = gpus def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str: if self.mock: diff --git a/prompting/utils/config.py b/prompting/utils/config.py index 84784f360..220a16e26 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -71,6 +71,13 @@ def add_args(cls, parser): help="Device to run on.", default="cuda" if torch.cuda.is_available() else "cpu", ) + + parser.add_argument( + "--neuron.gpus", + type=int, + help="The number of visible GPUs to be considered in the llm initialization", + default=1, + ) parser.add_argument( "--neuron.epoch_length", diff --git a/prompting/validator.py b/prompting/validator.py index 4a54930af..91ee6840c 100644 --- a/prompting/validator.py +++ b/prompting/validator.py @@ -19,6 +19,7 @@ def __init__(self, config=None): self.llm_pipeline = vLLMPipeline( model_id=self.config.neuron.model_id, + gpus=self.config.gpus, device=self.device, mock=self.config.mock, ) From 25f37076d7d70498e95225be3bcf144eaa7ee6b1 Mon Sep 17 00:00:00 2001 From: p-ferreira 
Date: Tue, 14 May 2024 19:48:06 +0000 Subject: [PATCH 5/8] adds llm_max_allowed_memory_in_gb param --- prompting/llms/vllm_llm.py | 43 ++++++++------------------------------ prompting/utils/config.py | 9 +++++++- prompting/validator.py | 3 ++- 3 files changed, 19 insertions(+), 36 deletions(-) diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index 99c057a8c..a63ce8dd0 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -31,20 +31,20 @@ def clean_gpu_cache(): destroy_model_parallel() gc.collect() torch.cuda.empty_cache() - torch.distributed.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() # Wait for the GPU to clean up time.sleep(10) torch.cuda.synchronize() -def load_vllm_pipeline(model_id: str, device: str, gpus: int, mock=False): +def load_vllm_pipeline(model_id: str, device: str, gpus: int, max_allowed_memory_in_gb: int, mock=False): """Loads the VLLM pipeline for the LLM, or a mock pipeline if mock=True""" if mock or model_id == "mock": return MockPipeline(model_id) - # Calculates the gpu memory utilization required to run the model within 20GB of GPU - max_allowed_memory_in_gb = 60 + # Calculates the gpu memory utilization required to run the model within 20GB of GPU max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9 gpu_mem_utilization = calculate_gpu_requirements( device, gpus, max_allowed_memory_allocation_in_bytes @@ -55,44 +55,19 @@ def load_vllm_pipeline(model_id: str, device: str, gpus: int, mock=False): llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus) llm.llm_engine.tokenizer.eos_token_id = 128009 return llm - except ValueError as e: - bt.logging.error( - f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}" - ) - - # If the first attempt fails, retry with increased memory allocation - try: - bt.logging.info( - "Trying to cleanup GPU and 
retrying to load the model with extra allocation..." - ) - # Clean the GPU from memory before retrying - clean_gpu_cache() - - # Increase the memory allocation for the second attempt - max_allowed_memory_in_gb_second_attempt = 70 - max_allowed_memory_allocation_in_bytes = ( - max_allowed_memory_in_gb_second_attempt * 1e9 - ) - bt.logging.warning( - f"Retrying to load with {max_allowed_memory_in_gb_second_attempt}GB..." - ) - gpu_mem_utilization = calculate_gpu_requirements( - device, gpus, max_allowed_memory_allocation_in_bytes, gpus - ) - - # Attempt to initialize the LLM again with increased memory allocation - return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization, tensor_parallel_size=gpus) except Exception as e: bt.logging.error( - f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb_second_attempt}GB: {e}" + f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}" ) + raise e + class vLLMPipeline(BasePipeline): - def __init__(self, model_id: str, device: str = None, gpus: int = 1, mock=False): + def __init__(self, model_id: str, llm_max_allowed_memory_in_gb:int, device: str = None, gpus: int = 1, mock=False): super().__init__() - self.llm = load_vllm_pipeline(model_id, device, gpus, mock) + self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock) self.mock = mock self.gpus = gpus diff --git a/prompting/utils/config.py b/prompting/utils/config.py index 220a16e26..2f4c148f6 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -75,9 +75,16 @@ def add_args(cls, parser): parser.add_argument( "--neuron.gpus", type=int, - help="The number of visible GPUs to be considered in the llm initialization", + help="The number of visible GPUs to be considered in the llm initialization. 
This parameter currently reflects on the property `tensor_parallel_size` of vllm", default=1, ) + + parser.add_argument( + "--neuron.llm_max_allowed_memory_in_gb", + type=int, + help="The max gpu memory utilization set for initializing the model. This parameter currently reflects on the property `gpu_memory_utilization` of vllm", + default=60, + ) parser.add_argument( "--neuron.epoch_length", diff --git a/prompting/validator.py b/prompting/validator.py index 91ee6840c..030bbd44f 100644 --- a/prompting/validator.py +++ b/prompting/validator.py @@ -19,7 +19,8 @@ def __init__(self, config=None): self.llm_pipeline = vLLMPipeline( model_id=self.config.neuron.model_id, - gpus=self.config.gpus, + gpus=self.config.neuron.gpus, + llm_max_allowed_memory_in_gb=self.config.neuron.llm_max_allowed_memory_in_gb, device=self.device, mock=self.config.mock, ) From 3fcafe72a0c930ec949ec4f1fae8d5ebc84f2365 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Tue, 14 May 2024 20:03:23 +0000 Subject: [PATCH 6/8] properly document eos token overwrite --- prompting/llms/vllm_llm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index a63ce8dd0..2c88ba37e 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -52,14 +52,17 @@ def load_vllm_pipeline(model_id: str, device: str, gpus: int, max_allowed_memory try: # Attempt to initialize the LLM - llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus) + llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus) + # This solution implemented by @bkb2135 sets the eos_token_id directly for efficiency in vLLM usage. + # This approach avoids the overhead of loading a tokenizer each time the custom eos token is needed. + # Using the Hugging Face pipeline, the eos token specific to llama models was fetched and saved (128009). 
+ # This method provides a straightforward solution, though there may be more optimal ways to manage custom tokens. llm.llm_engine.tokenizer.eos_token_id = 128009 return llm except Exception as e: bt.logging.error( f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}" - ) - + ) raise e From f500bfb34298f841a7097f676617fc1eefce8b4f Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Tue, 14 May 2024 20:39:24 +0000 Subject: [PATCH 7/8] fix unit tests --- prompting/llms/hf.py | 2 ++ prompting/llms/utils.py | 2 +- tests/test_llm.py | 68 +++++++++++++---------------------------- 3 files changed, 25 insertions(+), 47 deletions(-) diff --git a/prompting/llms/hf.py b/prompting/llms/hf.py index a6f0234e2..869b32411 100644 --- a/prompting/llms/hf.py +++ b/prompting/llms/hf.py @@ -106,6 +106,8 @@ def __init__( mock=False, model_kwargs: dict = None, return_streamer: bool = False, + gpus: int = 1, + llm_max_allowed_memory_in_gb: int = 0 ): super().__init__() self.model = model_id diff --git a/prompting/llms/utils.py b/prompting/llms/utils.py index 92317e1f8..8de32ba81 100644 --- a/prompting/llms/utils.py +++ b/prompting/llms/utils.py @@ -61,7 +61,7 @@ def calculate_multiple_gpu_requirements(device: str, gpus: int, max_allowed_memo def calculate_gpu_requirements( - device: str, gpus: int, max_allowed_memory_allocation_in_bytes: int = 60e9, + device: str, gpus: int, max_allowed_memory_allocation_in_bytes: float, ) -> float: """Calculates the memory utilization requirements for the model to be loaded on the device. 
Args: diff --git a/tests/test_llm.py b/tests/test_llm.py index 92653d41b..4d26c2d5f 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -44,7 +44,7 @@ def test_llm_clean_response( def test_load_pipeline_mock(pipeline: BasePipeline): # Note that the model_id will be used internally as static response for the mock pipeline model_id = "gpt2" - pipeline_instance = pipeline(model_id=model_id, device="cpu", mock=True) + pipeline_instance = pipeline(model_id=model_id, device="cpu", gpus=1, llm_max_allowed_memory_in_gb=0, mock=True) pipeline_message = pipeline_instance("") mock_message = MockPipeline(model_id).forward(messages=[]) @@ -110,7 +110,7 @@ def test_calculate_gpu_requirements( mock_mem_get_info.return_value = (available_memory, available_memory) # Mock current_device to return a default device index if needed mock_current_device.return_value = 0 - result = calculate_gpu_requirements(device, max_allowed_memory_allocation_in_bytes) + result = calculate_gpu_requirements(device=device, gpus=1, max_allowed_memory_allocation_in_bytes=max_allowed_memory_allocation_in_bytes) assert result == expected_result @@ -133,55 +133,31 @@ def test_calulate_gpu_requirements_raises_cuda_error( # Test 1: Success on first attempt @patch("prompting.llms.vllm_llm.calculate_gpu_requirements") @patch("prompting.llms.vllm_llm.LLM") -def test_load_vllm_pipeline_success_first_try( +def test_load_vllm_pipeline_success( mock_llm, mock_calculate_gpu_requirements ): # Mocking calculate_gpu_requirements to return a fixed value mock_calculate_gpu_requirements.return_value = 5e9 # Example value - # Mocking LLM to return a mock LLM object without raising an exception - mock_llm.return_value = MagicMock(spec=LLM) - result = load_vllm_pipeline(model_id="test_name", device="cuda") - assert isinstance(result, MagicMock) # or any other assertion you find suitable - mock_llm.assert_called_once() # Ensures LLM was called exactly once - - -# # Test 2: Success on second attempt with larger memory 
allocation -@patch("prompting.llms.vllm_llm.clean_gpu_cache") -@patch("prompting.llms.vllm_llm.calculate_gpu_requirements") -@patch( - "prompting.llms.vllm_llm.LLM", - side_effect=[ValueError("First attempt failed"), MagicMock(spec=LLM)], -) -def test_load_vllm_pipeline_success_second_try( - mock_llm, mock_calculate_gpu_requirements, mock_clean_gpu_cache -): - mock_calculate_gpu_requirements.return_value = 5e9 # Example value for both calls + # Creating a mock for the tokenizer with the desired eos_token_id + mock_tokenizer = MagicMock() + mock_tokenizer.eos_token_id = 12345 - result = load_vllm_pipeline(model_id="test", device="cuda") - assert isinstance(result, MagicMock) - assert mock_llm.call_count == 2 # LLM is called twice - mock_clean_gpu_cache.assert_called_once() # Ensures clean_gpu_cache was called + # Creating a mock for llm_engine and setting its tokenizer + mock_llm_engine = MagicMock() + mock_llm_engine.tokenizer = mock_tokenizer + # Creating the main mock LLM object and setting its llm_engine + mock_llm_instance = MagicMock() + mock_llm_instance.llm_engine = mock_llm_engine -# # Test 3: Exception on second attempt -@patch("prompting.llms.vllm_llm.clean_gpu_cache") -@patch("prompting.llms.vllm_llm.calculate_gpu_requirements") -@patch( - "prompting.llms.vllm_llm.LLM", - side_effect=[ - ValueError("First attempt failed"), - Exception("Second attempt failed"), - ], -) -def test_load_vllm_pipeline_exception_second_try( - mock_llm, mock_calculate_gpu_requirements, mock_clean_gpu_cache -): - mock_calculate_gpu_requirements.return_value = ( - 5e9 # Example value for both attempts - ) - - with pytest.raises(Exception, match="Second attempt failed"): - load_vllm_pipeline(model_id="HuggingFaceH4/zephyr-7b-beta", device="gpu0") - assert mock_llm.call_count == 2 # LLM is called twice - mock_clean_gpu_cache.assert_called_once() # Ensures clean_gpu_cache was called + # Setting the return value of the LLM mock to the mock LLM instance + mock_llm.return_value = 
mock_llm_instance + + + result = load_vllm_pipeline(model_id="test_name", device="cuda", gpus=1, max_allowed_memory_in_gb=0) + assert isinstance(result, MagicMock) # or any other assertion you find suitable + mock_llm.assert_called_once() # Ensures LLM was called exactly once + + # Verify the nested property (Specific assert for llama3) + assert result.llm_engine.tokenizer.eos_token_id == 128009 From 291a7efc195fa7f8ee4ecb9397b1d6e6aedbd057 Mon Sep 17 00:00:00 2001 From: p-ferreira <38992619+p-ferreira@users.noreply.github.com> Date: Thu, 16 May 2024 11:29:58 -0400 Subject: [PATCH 8/8] drops python 3.9 from gh action --- .github/workflows/python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 41e1fbc61..84b5dd82e 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.10"] steps: - uses: actions/checkout@v3