Skip to content
Merged
2 changes: 1 addition & 1 deletion prompting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# DEALINGS IN THE SOFTWARE.

# Define the version of the template module.
__version__ = "2.2.0"
__version__ = "2.3.0"
version_split = __version__.split(".")
__spec_version__ = (
(10000 * int(version_split[0]))
Expand Down
2 changes: 2 additions & 0 deletions prompting/llms/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(
mock=False,
model_kwargs: dict = None,
return_streamer: bool = False,
gpus: int = 1,
llm_max_allowed_memory_in_gb: int = 0
):
super().__init__()
self.model = model_id
Expand Down
50 changes: 41 additions & 9 deletions prompting/llms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@ def contains_gpu_index_in_device(device: str) -> bool:
pattern = r"^cuda:\d+$"
return bool(re.match(pattern, device))


def calculate_gpu_requirements(
device: str, max_allowed_memory_allocation_in_bytes: int = 20e9
) -> float:
"""Calculates the memory utilization requirements for the model to be loaded on the device.
Args:
device (str): The device to load the model to.
max_allowed_memory_allocation_in_bytes (int, optional): The maximum allowed memory allocation in bytes. Defaults to 20e9 (20GB).
"""
def calculate_single_gpu_requirements(device: str, max_allowed_memory_allocation_in_bytes: int):
if contains_gpu_index_in_device(device):
device_with_gpu_index = device
else:
Expand All @@ -39,4 +31,44 @@ def calculate_gpu_requirements(
f'{gpu_utilization * 100}% of the GPU memory will be utilized for loading the model to device "{device}".'
)

return gpu_utilization

def calculate_multiple_gpu_requirements(device: str, gpus: int, max_allowed_memory_allocation_in_bytes: int):
    """Return the share of pooled free GPU memory that the memory budget represents.

    Free and total memory are summed over the devices ``cuda:0`` ..
    ``cuda:{gpus - 1}``; the budget is then divided by the summed free memory.

    Args:
        device (str): Not read by this function; accepted so the signature
            mirrors ``calculate_single_gpu_requirements``.
        gpus (int): Number of CUDA devices to include, starting at index 0.
        max_allowed_memory_allocation_in_bytes (int): Memory budget, in bytes,
            compared against the free memory summed over all included devices.

    Returns:
        float: ``budget / summed free memory``, rounded to two decimals.

    Raises:
        ValueError: If ``gpus`` is smaller than 1.
        torch.cuda.CudaError: If the summed free memory is smaller than the budget.
    """
    # Reject a non-positive GPU count up front: the loop below would otherwise
    # run zero times and the failure would surface later as a misleading
    # "not enough memory" error (or a division by zero for a zero budget).
    if gpus < 1:
        raise ValueError(f"`gpus` must be a positive integer, got {gpus}.")

    bytes_per_gb = 1e9  # same value the original spelled as 10e8

    torch.cuda.synchronize()
    total_free_memory = 0
    total_gpu_memory = 0

    for gpu_index in range(gpus):
        free_bytes, total_bytes = torch.cuda.mem_get_info(device=f"cuda:{gpu_index}")
        total_free_memory += free_bytes
        total_gpu_memory += total_bytes

    bt.logging.info(f"Total available free memory across all visible {gpus} GPUs: {round(total_free_memory / bytes_per_gb, 2)} GB")
    bt.logging.info(f"Total GPU memory across all visible {gpus} GPUs: {round(total_gpu_memory / bytes_per_gb, 2)} GB")

    if total_free_memory < max_allowed_memory_allocation_in_bytes:
        raise torch.cuda.CudaError(
            f"Not enough memory across all specified {gpus} GPUs to allocate for the model. Please ensure you have at least {max_allowed_memory_allocation_in_bytes / bytes_per_gb} GB of free GPU memory."
        )

    # NOTE: the ratio is taken against the summed *free* memory, not the summed
    # total memory, even though the log line below words it as "total".
    gpu_utilization = round(max_allowed_memory_allocation_in_bytes / total_free_memory, 2)
    bt.logging.info(
        f"{gpu_utilization * 100}% of the total GPU memory across all GPUs will be utilized for loading the model."
    )

    return gpu_utilization


def calculate_gpu_requirements(
    device: str, gpus: int, max_allowed_memory_allocation_in_bytes: float,
) -> float:
    """Calculates the memory utilization requirements for the model to be loaded on the device.

    Delegates to ``calculate_single_gpu_requirements`` when ``gpus`` is 1 and to
    ``calculate_multiple_gpu_requirements`` for any larger count.

    Args:
        device (str): The device to load the model to.
        gpus (int): The number of GPUs to take into account. Must be at least 1.
        max_allowed_memory_allocation_in_bytes (float): The maximum allowed memory
            allocation in bytes. Required; this function defines no default.

    Returns:
        float: The ``gpu_utilization`` value produced by the selected helper.

    Raises:
        ValueError: If ``gpus`` is smaller than 1.
    """
    # Without this guard a zero or negative count would silently take the
    # multi-GPU branch and query no devices at all.
    if gpus < 1:
        raise ValueError(f"`gpus` must be a positive integer, got {gpus}.")

    if gpus == 1:
        return calculate_single_gpu_requirements(device, max_allowed_memory_allocation_in_bytes)

    return calculate_multiple_gpu_requirements(
        device,
        gpus,
        max_allowed_memory_allocation_in_bytes=max_allowed_memory_allocation_in_bytes,
    )
61 changes: 21 additions & 40 deletions prompting/llms/vllm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,67 +31,48 @@ def clean_gpu_cache():
destroy_model_parallel()
gc.collect()
torch.cuda.empty_cache()
torch.distributed.destroy_process_group()
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

# Wait for the GPU to clean up
time.sleep(10)
torch.cuda.synchronize()


def load_vllm_pipeline(model_id: str, device: str, mock=False):
def load_vllm_pipeline(model_id: str, device: str, gpus: int, max_allowed_memory_in_gb: int, mock=False):
"""Loads the VLLM pipeline for the LLM, or a mock pipeline if mock=True"""
if mock or model_id == "mock":
return MockPipeline(model_id)

# Calculates the gpu memory utilization required to run the model within 20GB of GPU
max_allowed_memory_in_gb = 20
# Calculates the GPU memory utilization required to run the model within the configured `max_allowed_memory_in_gb` budget
max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9
gpu_mem_utilization = calculate_gpu_requirements(
device, max_allowed_memory_allocation_in_bytes
device, gpus, max_allowed_memory_allocation_in_bytes
)

try:
# Attempt to initialize the LLM
return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization)
except ValueError as e:
bt.logging.error(
f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}"
)

# If the first attempt fails, retry with increased memory allocation
try:
bt.logging.info(
"Trying to cleanup GPU and retrying to load the model with extra allocation..."
)
# Clean the GPU from memory before retrying
clean_gpu_cache()

# Increase the memory allocation for the second attempt
max_allowed_memory_in_gb_second_attempt = 24
max_allowed_memory_allocation_in_bytes = (
max_allowed_memory_in_gb_second_attempt * 1e9
)
bt.logging.warning(
f"Retrying to load with {max_allowed_memory_in_gb_second_attempt}GB..."
)
gpu_mem_utilization = calculate_gpu_requirements(
device, max_allowed_memory_allocation_in_bytes
)

# Attempt to initialize the LLM again with increased memory allocation
return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization)
llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus)
# This solution implemented by @bkb2135 sets the eos_token_id directly for efficiency in vLLM usage.
# This approach avoids the overhead of loading a tokenizer each time the custom eos token is needed.
# Using the Hugging Face pipeline, the eos token specific to llama models was fetched and saved (128009).
# This method provides a straightforward solution, though there may be more optimal ways to manage custom tokens.
llm.llm_engine.tokenizer.eos_token_id = 128009
return llm
except Exception as e:
bt.logging.error(
f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb_second_attempt}GB: {e}"
f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}"
)
raise e



class vLLMPipeline(BasePipeline):
def __init__(self, model_id: str, device: str = None, mock=False):
def __init__(self, model_id: str, llm_max_allowed_memory_in_gb: int, device: str = None, gpus: int = 1, mock: bool = False):
super().__init__()
self.llm = load_vllm_pipeline(model_id, device, mock)
self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock)
self.mock = mock
self.gpus = gpus

def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str:
if self.mock:
Expand Down Expand Up @@ -155,17 +136,17 @@ def _make_prompt(self, messages: List[Dict[str, str]]):
for message in messages:
if message["role"] == "system":
composed_prompt += (
f'<|im_start|>system\n{message["content"]} <|im_end|>'
f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
)
elif message["role"] == "user":
composed_prompt += f'<|im_start|>user\n{message["content"]} <|im_end|>'
composed_prompt += f'<|start_header_id|>user<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
elif message["role"] == "assistant":
composed_prompt += (
f'<|im_start|>assistant\n{message["content"]} <|im_end|>'
f'<|start_header_id|>assistant<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
)

# Adds final tag indicating the assistant's turn
composed_prompt += "<|im_start|>assistant\n"
composed_prompt += "<|start_header_id|>assistant<|end_header_id|>"
return composed_prompt

def forward(self, messages: List[Dict[str, str]]):
Expand Down
16 changes: 15 additions & 1 deletion prompting/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ def add_args(cls, parser):
help="Device to run on.",
default="cuda" if torch.cuda.is_available() else "cpu",
)

parser.add_argument(
"--neuron.gpus",
type=int,
help="The number of visible GPUs to be considered in the llm initialization. This parameter currently reflects on the property `tensor_parallel_size` of vllm",
default=1,
)

parser.add_argument(
"--neuron.llm_max_allowed_memory_in_gb",
type=int,
help="The max gpu memory utilization set for initializing the model. This parameter currently reflects on the property `gpu_memory_utilization` of vllm",
default=60,
)

parser.add_argument(
"--neuron.epoch_length",
Expand Down Expand Up @@ -270,7 +284,7 @@ def add_validator_args(cls, parser):
"--neuron.model_id",
type=str,
help="The model to use for the validator.",
default="NousResearch/Nous-Hermes-2-SOLAR-10.7B",
default="casperhansen/llama-3-70b-instruct-awq",
)

parser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions prompting/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def __init__(self, config=None):

self.llm_pipeline = vLLMPipeline(
model_id=self.config.neuron.model_id,
gpus=self.config.neuron.gpus,
llm_max_allowed_memory_in_gb=self.config.neuron.llm_max_allowed_memory_in_gb,
device=self.device,
mock=self.config.mock,
)
Expand Down
68 changes: 22 additions & 46 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_llm_clean_response(
def test_load_pipeline_mock(pipeline: BasePipeline):
# Note that the model_id will be used internally as static response for the mock pipeline
model_id = "gpt2"
pipeline_instance = pipeline(model_id=model_id, device="cpu", mock=True)
pipeline_instance = pipeline(model_id=model_id, device="cpu", gpus=1, llm_max_allowed_memory_in_gb=0, mock=True)
pipeline_message = pipeline_instance("")

mock_message = MockPipeline(model_id).forward(messages=[])
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_calculate_gpu_requirements(
mock_mem_get_info.return_value = (available_memory, available_memory)
# Mock current_device to return a default device index if needed
mock_current_device.return_value = 0
result = calculate_gpu_requirements(device, max_allowed_memory_allocation_in_bytes)
result = calculate_gpu_requirements(device=device, gpus=1, max_allowed_memory_allocation_in_bytes=max_allowed_memory_allocation_in_bytes)
assert result == expected_result


Expand All @@ -133,55 +133,31 @@ def test_calulate_gpu_requirements_raises_cuda_error(
# Test 1: Success on first attempt
@patch("prompting.llms.vllm_llm.calculate_gpu_requirements")
@patch("prompting.llms.vllm_llm.LLM")
def test_load_vllm_pipeline_success_first_try(
def test_load_vllm_pipeline_success(
mock_llm, mock_calculate_gpu_requirements
):
# Mocking calculate_gpu_requirements to return a fixed value
mock_calculate_gpu_requirements.return_value = 5e9 # Example value
# Mocking LLM to return a mock LLM object without raising an exception
mock_llm.return_value = MagicMock(spec=LLM)

result = load_vllm_pipeline(model_id="test_name", device="cuda")
assert isinstance(result, MagicMock) # or any other assertion you find suitable
mock_llm.assert_called_once() # Ensures LLM was called exactly once


# # Test 2: Success on second attempt with larger memory allocation
@patch("prompting.llms.vllm_llm.clean_gpu_cache")
@patch("prompting.llms.vllm_llm.calculate_gpu_requirements")
@patch(
"prompting.llms.vllm_llm.LLM",
side_effect=[ValueError("First attempt failed"), MagicMock(spec=LLM)],
)
def test_load_vllm_pipeline_success_second_try(
mock_llm, mock_calculate_gpu_requirements, mock_clean_gpu_cache
):
mock_calculate_gpu_requirements.return_value = 5e9 # Example value for both calls
# Creating a mock for the tokenizer with the desired eos_token_id
mock_tokenizer = MagicMock()
mock_tokenizer.eos_token_id = 12345

result = load_vllm_pipeline(model_id="test", device="cuda")
assert isinstance(result, MagicMock)
assert mock_llm.call_count == 2 # LLM is called twice
mock_clean_gpu_cache.assert_called_once() # Ensures clean_gpu_cache was called
# Creating a mock for llm_engine and setting its tokenizer
mock_llm_engine = MagicMock()
mock_llm_engine.tokenizer = mock_tokenizer

# Creating the main mock LLM object and setting its llm_engine
mock_llm_instance = MagicMock()
mock_llm_instance.llm_engine = mock_llm_engine

# # Test 3: Exception on second attempt
@patch("prompting.llms.vllm_llm.clean_gpu_cache")
@patch("prompting.llms.vllm_llm.calculate_gpu_requirements")
@patch(
"prompting.llms.vllm_llm.LLM",
side_effect=[
ValueError("First attempt failed"),
Exception("Second attempt failed"),
],
)
def test_load_vllm_pipeline_exception_second_try(
mock_llm, mock_calculate_gpu_requirements, mock_clean_gpu_cache
):
mock_calculate_gpu_requirements.return_value = (
5e9 # Example value for both attempts
)

with pytest.raises(Exception, match="Second attempt failed"):
load_vllm_pipeline(model_id="HuggingFaceH4/zephyr-7b-beta", device="gpu0")
assert mock_llm.call_count == 2 # LLM is called twice
mock_clean_gpu_cache.assert_called_once() # Ensures clean_gpu_cache was called
# Setting the return value of the LLM mock to the mock LLM instance
mock_llm.return_value = mock_llm_instance


result = load_vllm_pipeline(model_id="test_name", device="cuda", gpus=1, max_allowed_memory_in_gb=0)
assert isinstance(result, MagicMock) # or any other assertion you find suitable
mock_llm.assert_called_once() # Ensures LLM was called exactly once

# Verify the nested property (Specific assert for llama3)
assert result.llm_engine.tokenizer.eos_token_id == 128009