Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 33 additions & 34 deletions prompting/llms/vllm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,22 @@
from prompting.llms import BasePipeline, BaseLLM
from prompting.mock import MockPipeline
from prompting.llms.utils import calculate_gpu_requirements
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel


def clean_gpu_cache():
    """Tear down parallel/distributed state and release cached GPU memory.

    Takes no arguments and returns nothing. The call always blocks for a
    fixed 10 seconds before the final CUDA synchronize.
    """
    # Helper imported from vLLM's parallel_state module; presumably it tears
    # down vLLM's model-parallel groups -- confirm against the vLLM version in use.
    destroy_model_parallel()
    # Run the Python garbage collector before asking PyTorch to release its
    # cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
    # Only destroy the torch.distributed process group if one was initialized.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

    # Wait for the GPU to clean up
    time.sleep(10)
    # Block until outstanding CUDA work on the current device has finished.
    torch.cuda.synchronize()


def load_vllm_pipeline(model_id: str, device: str, gpus: int, max_allowed_memory_in_gb: int, mock=False):
"""Loads the VLLM pipeline for the LLM, or a mock pipeline if mock=True"""
if mock or model_id == "mock":
return MockPipeline(model_id)

# Calculates the gpu memory utilization required to run the model within 20GB of GPU
# Calculates the gpu memory utilization required to run the model.
max_allowed_memory_allocation_in_bytes = max_allowed_memory_in_gb * 1e9
gpu_mem_utilization = calculate_gpu_requirements(
device, gpus, max_allowed_memory_allocation_in_bytes
)

try:
# Attempt to initialize the LLM
llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus)
llm = LLM(model=model_id, gpu_memory_utilization = gpu_mem_utilization, quantization="AWQ", tensor_parallel_size=gpus)
# This solution implemented by @bkb2135 sets the eos_token_id directly for efficiency in vLLM usage.
# This approach avoids the overhead of loading a tokenizer each time the custom eos token is needed.
# Using the Hugging Face pipeline, the eos token specific to llama models was fetched and saved (128009).
Expand All @@ -68,7 +55,14 @@ def load_vllm_pipeline(model_id: str, device: str, gpus: int, max_allowed_memory


class vLLMPipeline(BasePipeline):
def __init__(self, model_id: str, llm_max_allowed_memory_in_gb: int, device: str = None, gpus: int = 1, mock: bool = False):
def __init__(
self,
model_id: str,
llm_max_allowed_memory_in_gb: int,
device: str = None,
gpus: int = 1,
mock: bool = False
):
super().__init__()
self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock)
self.mock = mock
Expand Down Expand Up @@ -109,7 +103,13 @@ def __init__(

# Keep track of generation data using messages and times
self.messages = [{"content": self.system_prompt, "role": "system"}]
self.times = [0]
self.times: List[float] = [0]
self._role_template = {
"system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
"user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
"assistant": "<|start_header_id|>assistant<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
"end": "<|start_header_id|>assistant<|end_header_id|>",
}

def query(
self,
Expand All @@ -125,29 +125,25 @@ def query(
response = self.forward(messages=messages)
response = self.clean_response(cleaner, response)

self.messages = messages + [{"content": response, "role": "assistant"}]
self.times = self.times + [0, time.time() - t0]
self.messages = messages
self.messages.append({"content": response, "role": "assistant"})
self.times.extend((0, time.time() - t0))

return response

def _make_prompt(self, messages: List[Dict[str, str]]):
composed_prompt = ""
def _make_prompt(self, messages: List[Dict[str, str]]) -> str:
composed_prompt: List[str] = []

for message in messages:
if message["role"] == "system":
composed_prompt += (
f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
)
elif message["role"] == "user":
composed_prompt += f'<|start_header_id|>user<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
elif message["role"] == "assistant":
composed_prompt += (
f'<|start_header_id|>assistant<|end_header_id|>\n{{{{ {message["content"]} }}}}<|eot_id|>'
)
role = message["role"]
if role not in self._role_template:
continue
content = message["content"]
composed_prompt.append(self._role_template[role].format(content))

# Adds final tag indicating the assistant's turn
composed_prompt += "<|start_header_id|>assistant<|end_header_id|>"
return composed_prompt
composed_prompt.append(self._role_template["end"])
return "".join(composed_prompt)

def forward(self, messages: List[Dict[str, str]]):
# make composed prompt from messages
Expand All @@ -164,7 +160,10 @@ def forward(self, messages: List[Dict[str, str]]):
if __name__ == "__main__":
# Example usage
llm_pipeline = vLLMPipeline(
model_id="HuggingFaceH4/zephyr-7b-beta", device="cuda", mock=False
model_id="casperhansen/llama-3-70b-instruct-awq",
device="cuda",
llm_max_allowed_memory_in_gb=80,
gpus=1,
)
llm = vLLM_LLM(llm_pipeline, system_prompt="You are a helpful AI assistant")

Expand Down