Skip to content

Add ModelGenerator for LLama2, driver proxy and Vicuna #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion databricks/labs/doc_qa/evaluators/templated_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__.split(".")[0])


class ParameterType(Enum):
Expand Down
2 changes: 0 additions & 2 deletions databricks/labs/doc_qa/llm_providers/openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@


openai_token = os.getenv('OPENAI_API_KEY')
openai_org = os.getenv('OPENAI_ORGANIZATION')

class StatusCode429Error(Exception):
pass
Expand All @@ -24,7 +23,6 @@ def request_openai(messages, functions=[], temperature=0.0, model="gpt-4"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {openai_token}",
"OpenAI-Organization": openai_org,
}
data = {
"model": model,
Expand Down
222 changes: 212 additions & 10 deletions databricks/labs/doc_qa/model_generators/model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import concurrent.futures

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Instead of using full name, only use the module name
logger = logging.getLogger(__name__.split(".")[0])

class RowGenerateResult:
"""
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(
self._prompt_formatter = prompt_formatter
self._batch_size = batch_size
self._concurrency = concurrency
self.input_variables = prompt_formatter.variables

def _generate(
self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
Expand All @@ -95,7 +96,7 @@ def run_tasks(
Returns:
EvalResult: the evaluation result
"""
prompt_batches = []
task_batches = []
# First, traverse the input dataframe using batch size
for i in range(0, len(input_df), self._batch_size):
# Get the current batch
Expand All @@ -107,9 +108,13 @@ def run_tasks(
# Format the input dataframe into prompts
prompt = self._prompt_formatter.format(**row)
prompts.append(prompt)
prompt_batches.append(prompts)
task = {
"prompts": prompts,
"df": batch_df,
}
task_batches.append(task)
logger.info(
f"Generated total number of batches for prompts: {len(prompt_batches)}"
f"Generated total number of batches for prompts: {len(task_batches)}"
)

# Call the _generate in parallel using multiple threads, each call with a batch of prompts
Expand All @@ -118,15 +123,28 @@ def run_tasks(
) as executor:
future_to_batch = {
executor.submit(
self._generate, prompts, temperature, max_tokens, system_prompt
): prompts
for prompts in prompt_batches
self._generate,
task["prompts"],
temperature,
max_tokens,
system_prompt,
): task
for task in task_batches
}
batch_generate_results = []
for future in concurrent.futures.as_completed(future_to_batch):
prompts = future_to_batch[future]
task = future_to_batch[future]
try:
result = future.result()
batch_df = task["df"]
# Add the columns from batch_df where the column name is in the input_variables, add as attribute and value to the RowEvalResult
for index, row in enumerate(result.rows):
for input_variable in self.input_variables:
setattr(
row,
input_variable,
batch_df[input_variable].iloc[index],
)
batch_generate_results.append(result)
except Exception as exc:
logger.error(f"Exception occurred when running the task: {exc}")
Expand Down Expand Up @@ -268,7 +286,7 @@ def __init__(
self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
if system_prompt_opt is None:
if system_prompt_opt is not None:
texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
texts.append(f"{message.strip()} [/INST]")
return "".join(texts)
Expand Down Expand Up @@ -323,3 +341,187 @@ def _generate(
is_successful=True,
error_msg=None,
)


class VicunaModelGenerator(BaseModelGenerator):
def __init__(
self,
prompt_formatter: PromptTemplate,
model_name_or_path: str,
batch_size: int = 1,
concurrency: int = 1,
) -> None:
"""
Args:
prompt_formatter (PromptTemplate): the prompt format to format the input dataframe into prompts row by row according to the column names
model_name (str): the model name
batch_size (int, optional): Batch size that will be used to run tasks. Defaults to 1, which means it's sequential.

Recommendations:
- for A100 80GB, use batch_size 1 for vicuna-33b
- for A100 80GB x 2, use batch_size 64 for vicuna-33b
"""
super().__init__(prompt_formatter, batch_size, concurrency)
# require the concurrency to be 1 to avoid race condition during inference
if concurrency != 1:
raise ValueError(
"VicunaModelGenerator currently only supports concurrency 1"
)
self._model_name_or_path = model_name_or_path
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)

if torch.cuda.is_available():
self._model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, torch_dtype=torch.float16, device_map="auto"
)
else:
raise ValueError("VicunaModelGenerator currently only supports GPU")
self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
if system_prompt_opt is not None:
return f"""{system_prompt_opt}

USER: {message}
ASSISTANT:
"""
else:
return f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.

USER: {message}
ASSISTANT:
"""

def _generate(
self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
) -> BatchGenerateResult:
from transformers import pipeline

all_formatted_prompts = [
self._format_prompt(message=message, system_prompt_opt=system_prompt)
for message in prompts
]

top_p = 0.95
repetition_penalty = 1.15
pipe = pipeline(
"text-generation",
model=self._model,
tokenizer=self._tokenizer,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
return_full_text=False,
)
responses = pipe(all_formatted_prompts)
rows = []
for index, response in enumerate(responses):
response_content = response[0]["generated_text"]
row_generate_result = RowGenerateResult(
is_successful=True,
error_msg=None,
answer=response_content,
temperature=temperature,
max_tokens=max_tokens,
model_name=self._model_name_or_path,
top_p=top_p,
repetition_penalty=repetition_penalty,
prompts=all_formatted_prompts[index],
)
rows.append(row_generate_result)

return BatchGenerateResult(
num_rows=len(rows),
num_successful_rows=len(rows),
rows=rows,
is_successful=True,
error_msg=None,
)


class DriverProxyModelGenerator(BaseModelGenerator):
def __init__(
self,
url: str,
pat_token: str,
prompt_formatter: PromptTemplate,
batch_size: int = 32,
concurrency: int = 1,
) -> None:
"""
Args:
prompt_formatter (PromptTemplate): the prompt format to format the input dataframe into prompts row by row according to the column names
model_name (str): the model name
batch_size (int, optional): Batch size that will be used to run tasks. Defaults to 1, which means it's sequential.

Recommendations:
- for A100 80GB, use batch_size 16 for llama-2-13b-chat
"""
super().__init__(prompt_formatter, batch_size, concurrency)
self._url = url
self._pat_token = pat_token

def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
if system_prompt_opt is not None:
texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
texts.append(f"{message.strip()} [/INST]")
return "".join(texts)
else:
texts = [f"[INST] \n\n"]
texts.append(f"{message.strip()} [/INST]")
return "".join(texts)

def _generate(
self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
) -> BatchGenerateResult:
top_p = 0.95

all_formatted_prompts = [
self._format_prompt(message=message, system_prompt_opt=system_prompt)
for message in prompts
]

import requests
import json

headers = {
"Authentication": f"Bearer {self._pat_token}",
"Content-Type": "application/json",
}

data = {
"prompts": all_formatted_prompts,
"temperature": temperature,
"max_tokens": max_tokens,
}

response = requests.post(self._url, headers=headers, data=json.dumps(data))

# Extract the "outputs" as a JSON array from the response
outputs = response.json()["outputs"]
rows = []
for index, response_content in enumerate(outputs):
row_generate_result = RowGenerateResult(
is_successful=True,
error_msg=None,
answer=response_content,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
prompts=all_formatted_prompts[index],
)
rows.append(row_generate_result)

return BatchGenerateResult(
num_rows=len(rows),
num_successful_rows=len(rows),
rows=rows,
is_successful=True,
error_msg=None,
)
Loading