Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 2 additions & 0 deletions prompting/llms/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(
mock=False,
model_kwargs: dict = None,
return_streamer: bool = False,
gpus: int = 1,
llm_max_allowed_memory_in_gb: int = 0
):
super().__init__()
self.model = model_id
Expand Down
50 changes: 41 additions & 9 deletions prompting/llms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@ def contains_gpu_index_in_device(device: str) -> bool:
pattern = r"^cuda:\d+$"
return bool(re.match(pattern, device))


def calculate_gpu_requirements(
device: str, max_allowed_memory_allocation_in_bytes: int = 20e9
) -> float:
"""Calculates the memory utilization requirements for the model to be loaded on the device.
Args:
device (str): The device to load the model to.
max_allowed_memory_allocation_in_bytes (int, optional): The maximum allowed memory allocation in bytes. Defaults to 20e9 (20GB).
"""
def calculate_single_gpu_requirements(device: str, max_allowed_memory_allocation_in_bytes: int):
if contains_gpu_index_in_device(device):
device_with_gpu_index = device
else:
Expand All @@ -39,4 +31,44 @@ def calculate_gpu_requirements(
f'{gpu_utilization * 100}% of the GPU memory will be utilized for loading the model to device "{device}".'
)

return gpu_utilization

def calculate_multiple_gpu_requirements(device: str, gpus: int, max_allowed_memory_allocation_in_bytes: int):
    """Return the share of pooled free GPU memory needed for the requested budget.

    Free and total memory are summed over the devices ``cuda:0`` through
    ``cuda:{gpus - 1}``; the requested budget is then expressed as a fraction
    of the summed free memory, rounded to two decimals.

    Args:
        device (str): Not used by this function; kept so the signature stays
            unchanged for existing callers.
        gpus (int): Number of GPUs to inspect, counted from index 0. Must be >= 1.
        max_allowed_memory_allocation_in_bytes (int): Memory budget in bytes.

    Returns:
        float: ``max_allowed_memory_allocation_in_bytes / total_free_memory``
        rounded to 2 decimal places.

    Raises:
        ValueError: If ``gpus`` is smaller than 1.
        torch.cuda.CudaError: If the summed free memory is below the budget.
    """
    # Fail fast on a nonsensical GPU count: with no devices inspected the totals
    # stay at zero, which previously surfaced as a misleading out-of-memory
    # error or a ZeroDivisionError further down.
    if gpus < 1:
        raise ValueError(f"gpus must be a positive integer, got {gpus}.")

    bytes_per_gb = 1e9  # identical value to the former `10e8` literal, spelled plainly

    torch.cuda.synchronize()

    # One (free_bytes, total_bytes) pair per inspected device.
    memory_per_gpu = [
        torch.cuda.mem_get_info(device=f"cuda:{gpu_index}") for gpu_index in range(gpus)
    ]
    total_free_memory = sum(free_bytes for free_bytes, _ in memory_per_gpu)
    total_gpu_memory = sum(total_bytes for _, total_bytes in memory_per_gpu)

    bt.logging.info(
        f"Total available free memory across all visible {gpus} GPUs: {round(total_free_memory / bytes_per_gb, 2)} GB"
    )
    # The GPU count now sits inside the sentence instead of being printed
    # directly in front of the memory figure.
    bt.logging.info(
        f"Total GPU memory across all visible {gpus} GPUs: {round(total_gpu_memory / bytes_per_gb, 2)} GB"
    )

    if total_free_memory < max_allowed_memory_allocation_in_bytes:
        # NOTE(review): torch.cuda.CudaError's constructor appears to expect an
        # integer CUDA error code rather than a message string — confirm this
        # raise behaves as intended. The exception type is kept for callers.
        raise torch.cuda.CudaError(
            f"Not enough memory across all specified {gpus} GPUs to allocate for the model. Please ensure you have at least {max_allowed_memory_allocation_in_bytes / bytes_per_gb} GB of free GPU memory."
        )

    gpu_utilization = round(max_allowed_memory_allocation_in_bytes / total_free_memory, 2)
    bt.logging.info(
        f"{gpu_utilization * 100}% of the total GPU memory across all GPUs will be utilized for loading the model."
    )

    return gpu_utilization


def calculate_gpu_requirements(
    device: str, gpus: int, max_allowed_memory_allocation_in_bytes: float,
) -> float:
    """Compute the GPU memory utilization needed to load the model.

    Delegates to the single-GPU calculation when exactly one GPU is requested
    and to the multi-GPU calculation for any other value of ``gpus``.

    Args:
        device (str): The device to load the model to.
        gpus (int): The number of GPUs to take into account.
        max_allowed_memory_allocation_in_bytes (float): The maximum allowed
            memory allocation in bytes. Required; there is no default.

    Returns:
        float: The utilization value produced by the selected helper.
    """
    # Anything other than exactly one GPU goes through the pooled calculation.
    if gpus != 1:
        return calculate_multiple_gpu_requirements(
            device,
            gpus,
            max_allowed_memory_allocation_in_bytes=max_allowed_memory_allocation_in_bytes,
        )

    return calculate_single_gpu_requirements(device, max_allowed_memory_allocation_in_bytes)
63 changes: 22 additions & 41 deletions prompting/llms/vllm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,67 +31,48 @@ def clean_gpu_cache():
destroy_model_parallel()
gc.collect()
torch.cuda.empty_cache()
torch.distributed.destroy_process_group()
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

# Wait for the GPU to clean up
time.sleep(10)
torch.cuda.synchronize()


def load_vllm_pipeline(model_id: str, device: str, mock=False):
def load_vllm_pipeline(model_id: str, device: str, gpus: int, max_allowed_memory_in_gb: int, mock=False):
"""Loads the VLLM pipeline for the LLM, or a mock pipeline if mock=True"""
if mock or model_id == "mock":
return MockPipeline(model_id)

# Calculates the gpu memory utilization required to run the model within 20GB of GPU
max_allowed_memory_in_gb = 20
# Calculates the gpu memory utilization required to run the model within 20GB of GPU
max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9
gpu_mem_utilization = calculate_gpu_requirements(
device, max_allowed_memory_allocation_in_bytes
device, gpus, max_allowed_memory_allocation_in_bytes
)

try:
# Attempt to initialize the LLM
return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization)
except ValueError as e:
bt.logging.error(
f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}"
)

# If the first attempt fails, retry with increased memory allocation
try:
bt.logging.info(
"Trying to cleanup GPU and retrying to load the model with extra allocation..."
)
# Clean the GPU from memory before retrying
clean_gpu_cache()

# Increase the memory allocation for the second attempt
max_allowed_memory_in_gb_second_attempt = 24
max_allowed_memory_allocation_in_bytes = (
max_allowed_memory_in_gb_second_attempt * 1e9
)
bt.logging.warning(
f"Retrying to load with {max_allowed_memory_in_gb_second_attempt}GB..."
)
gpu_mem_utilization = calculate_gpu_requirements(
device, max_allowed_memory_allocation_in_bytes
)

# Attempt to initialize the LLM again with increased memory allocation
return LLM(model=model_id, gpu_memory_utilization=gpu_mem_utilization)
llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus)
# This solution implemented by @bkb2135 sets the eos_token_id directly for efficiency in vLLM usage.
# This approach avoids the overhead of loading a tokenizer each time the custom eos token is needed.
# Using the Hugging Face pipeline, the eos token specific to llama models was fetched and saved (128009).
# This method provides a straightforward solution, though there may be more optimal ways to manage custom tokens.
llm.llm_engine.tokenizer.eos_token_id = 128009
return llm
except Exception as e:
bt.logging.error(
f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb_second_attempt}GB: {e}"
)
f"Error loading the VLLM pipeline within {max_allowed_memory_in_gb}GB: {e}"
)
raise e



class vLLMPipeline(BasePipeline):
def __init__(self, model_id: str, device: str = None, mock=False):
def __init__(self, model_id: str, llm_max_allowed_memory_in_gb:int, device: str = None, gpus: int = 1, mock=False):
super().__init__()
self.llm = load_vllm_pipeline(model_id, device, mock)
self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock)
self.mock = mock
self.gpus = gpus

def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str:
if self.mock:
Expand Down Expand Up @@ -155,17 +136,17 @@ def _make_prompt(self, messages: List[Dict[str, str]]):
for message in messages:
if message["role"] == "system":
composed_prompt += (
f'<|im_start|>system\n{message["content"]} <|im_end|>'
f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
)
elif message["role"] == "user":
composed_prompt += f'<|im_start|>user\n{message["content"]} <|im_end|>'
composed_prompt += f'<|start_header_id|>user<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
elif message["role"] == "assistant":
composed_prompt += (
f'<|im_start|>assistant\n{message["content"]} <|im_end|>'
f'<|start_header_id|>assistant<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
)

# Adds final tag indicating the assistant's turn
composed_prompt += "<|im_start|>assistant\n"
composed_prompt += "<|start_header_id|>assistant<|end_header_id|>"
return composed_prompt

def forward(self, messages: List[Dict[str, str]]):
Expand Down
16 changes: 15 additions & 1 deletion prompting/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ def add_args(cls, parser):
help="Device to run on.",
default="cuda" if torch.cuda.is_available() else "cpu",
)

parser.add_argument(
"--neuron.gpus",
type=int,
help="The number of visible GPUs to be considered in the llm initialization. This parameter currently reflects on the property `tensor_parallel_size` of vllm",
default=1,
)

parser.add_argument(
"--neuron.llm_max_allowed_memory_in_gb",
type=int,
help="The max gpu memory utilization set for initializing the model. This parameter currently reflects on the property `gpu_memory_utilization` of vllm",
default=60,
)

parser.add_argument(
"--neuron.epoch_length",
Expand Down Expand Up @@ -270,7 +284,7 @@ def add_validator_args(cls, parser):
"--neuron.model_id",
type=str,
help="The model to use for the validator.",
default="NousResearch/Nous-Hermes-2-SOLAR-10.7B",
default="casperhansen/llama-3-70b-instruct-awq",
)

parser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions prompting/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def __init__(self, config=None):

self.llm_pipeline = vLLMPipeline(
model_id=self.config.neuron.model_id,
gpus=self.config.neuron.gpus,
llm_max_allowed_memory_in_gb=self.config.neuron.llm_max_allowed_memory_in_gb,
device=self.device,
mock=self.config.mock,
)
Expand Down
68 changes: 22 additions & 46 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_llm_clean_response(
def test_load_pipeline_mock(pipeline: BasePipeline):
# Note that the model_id will be used internally as static response for the mock pipeline
model_id = "gpt2"
pipeline_instance = pipeline(model_id=model_id, device="cpu", mock=True)
pipeline_instance = pipeline(model_id=model_id, device="cpu", gpus=1, llm_max_allowed_memory_in_gb=0, mock=True)
pipeline_message = pipeline_instance("")

mock_message = MockPipeline(model_id).forward(messages=[])
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_calculate_gpu_requirements(
mock_mem_get_info.return_value = (available_memory, available_memory)
# Mock current_device to return a default device index if needed
mock_current_device.return_value = 0
result = calculate_gpu_requirements(device, max_allowed_memory_allocation_in_bytes)
result = calculate_gpu_requirements(device=device, gpus=1, max_allowed_memory_allocation_in_bytes=max_allowed_memory_allocation_in_bytes)
assert result == expected_result


Expand All @@ -133,55 +133,31 @@ def test_calulate_gpu_requirements_raises_cuda_error(
# Test 1: Success on first attempt
@patch("prompting.llms.vllm_llm.calculate_gpu_requirements")
@patch("prompting.llms.vllm_llm.LLM")
def test_load_vllm_pipeline_success_first_try(
def test_load_vllm_pipeline_success(
mock_llm, mock_calculate_gpu_requirements
):
# Mocking calculate_gpu_requirements to return a fixed value
mock_calculate_gpu_requirements.return_value = 5e9 # Example value
# Mocking LLM to return a mock LLM object without raising an exception
mock_llm.return_value = MagicMock(spec=LLM)

result = load_vllm_pipeline(model_id="test_name", device="cuda")
assert isinstance(result, MagicMock) # or any other assertion you find suitable
mock_llm.assert_called_once() # Ensures LLM was called exactly once


# # Test 2: Success on second attempt with larger memory allocation
@patch("prompting.llms.vllm_llm.clean_gpu_cache")
@patch("prompting.llms.vllm_llm.calculate_gpu_requirements")
@patch(
"prompting.llms.vllm_llm.LLM",
side_effect=[ValueError("First attempt failed"), MagicMock(spec=LLM)],
)
def test_load_vllm_pipeline_success_second_try(
mock_llm, mock_calculate_gpu_requirements, mock_clean_gpu_cache
):
mock_calculate_gpu_requirements.return_value = 5e9 # Example value for both calls
# Creating a mock for the tokenizer with the desired eos_token_id
mock_tokenizer = MagicMock()
mock_tokenizer.eos_token_id = 12345

result = load_vllm_pipeline(model_id="test", device="cuda")
assert isinstance(result, MagicMock)
assert mock_llm.call_count == 2 # LLM is called twice
mock_clean_gpu_cache.assert_called_once() # Ensures clean_gpu_cache was called
# Creating a mock for llm_engine and setting its tokenizer
mock_llm_engine = MagicMock()
mock_llm_engine.tokenizer = mock_tokenizer

# Creating the main mock LLM object and setting its llm_engine
mock_llm_instance = MagicMock()
mock_llm_instance.llm_engine = mock_llm_engine

# # Test 3: Exception on second attempt
@patch("prompting.llms.vllm_llm.clean_gpu_cache")
@patch("prompting.llms.vllm_llm.calculate_gpu_requirements")
@patch(
"prompting.llms.vllm_llm.LLM",
side_effect=[
ValueError("First attempt failed"),
Exception("Second attempt failed"),
],
)
def test_load_vllm_pipeline_exception_second_try(
mock_llm, mock_calculate_gpu_requirements, mock_clean_gpu_cache
):
mock_calculate_gpu_requirements.return_value = (
5e9 # Example value for both attempts
)

with pytest.raises(Exception, match="Second attempt failed"):
load_vllm_pipeline(model_id="HuggingFaceH4/zephyr-7b-beta", device="gpu0")
assert mock_llm.call_count == 2 # LLM is called twice
mock_clean_gpu_cache.assert_called_once() # Ensures clean_gpu_cache was called
# Setting the return value of the LLM mock to the mock LLM instance
mock_llm.return_value = mock_llm_instance


result = load_vllm_pipeline(model_id="test_name", device="cuda", gpus=1, max_allowed_memory_in_gb=0)
assert isinstance(result, MagicMock) # or any other assertion you find suitable
mock_llm.assert_called_once() # Ensures LLM was called exactly once

# Verify the nested property (Specific assert for llama3)
assert result.llm_engine.tokenizer.eos_token_id == 128009