Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 58 additions & 17 deletions src/instructlab/generator/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

# Local
from ..config import DEFAULT_MULTIPROCESSING_START_METHOD, get_model_family
from ..utils import chunk_document, read_taxonomy
from ..utils import (
chunk_document,
max_seed_example_tokens,
num_chars_from_tokens,
read_taxonomy,
)
from . import utils
from .utils import GenerateException

Expand Down Expand Up @@ -298,18 +303,35 @@ def get_instructions_from_model(
stop=["* Task 5"],
)
request_start = time.time()
results = utils.openai_completion(
api_base=api_base,
api_key=api_key,
prompts=batch_inputs,
model_name=model_name,
tls_insecure=tls_insecure,
tls_client_cert=tls_client_cert,
tls_client_key=tls_client_key,
tls_client_passwd=tls_client_passwd,
batch_size=request_batch_size,
decoding_args=decoding_args,
)
try:
results = utils.openai_completion(
api_base=api_base,
api_key=api_key,
prompts=batch_inputs,
model_name=model_name,
tls_insecure=tls_insecure,
tls_client_cert=tls_client_cert,
tls_client_key=tls_client_key,
tls_client_passwd=tls_client_passwd,
batch_size=request_batch_size,
decoding_args=decoding_args,
)
except GenerateException as exc:
# Attempt to log and gracefully recover from exceeding the server's
# maximum context length. This won't work for all servers.
#
# Both llama_cpp_python and vllm use this exact string in their error
# responses when exceeding the model's max content length. Other
# OpenAI-compatible servers may as well, but no guarantees.
if "model's maximum context length" in str(exc):
logger.warn(
"Generated prompt exceeded the server's maximum context length. "
"If you see this warning many times during generation, lower "
"the length of your example question and answers or raise the "
"server's maximum context size using `max_ctx_size`."
)
return [], 0
raise exc
request_duration = time.time() - request_start

post_process_start = time.time()
Expand Down Expand Up @@ -362,6 +384,7 @@ def generate_data(
tls_client_passwd: Optional[str] = None,
):
seed_instruction_data = []
machine_seed_instruction_data = []
generate_start = time.time()

if not os.path.exists(output_dir):
Expand All @@ -377,6 +400,22 @@ def generate_data(
else:
raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.")

prompt_template = check_prompt_file(
prompt_file_path, get_model_family(model_family, model_name)
)
max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template))
max_seed_chars = num_chars_from_tokens(max_seed_tokens)
for seed_example in seed_instruction_data:
if (
len(seed_example["instruction"])
+ len(seed_example["input"])
+ len(seed_example["output"])
>= max_seed_chars
):
raise SystemExit(
f"Error: An example in the taxonomy path {seed_example['taxonomy_path']} is too long for the server context size of {server_ctx_size}. Ensure the total number of characters across the combined question, answer, and context is less than {max_seed_chars} for each example or use a server with a larger context size."
)

seeds = len(seed_instruction_data)
logger.debug(f"Loaded {seeds} human-written seed instructions from {taxonomy}")
if not seeds:
Expand Down Expand Up @@ -449,9 +488,6 @@ def unescape(s):
scorer._tokenizer.tokenize(inst) for inst in all_instructions
]

prompt_template = check_prompt_file(
prompt_file_path, get_model_family(model_family, model_name)
)
if console_output:
print(
"Synthesizing new instructions. If you aren't satisfied with the generated instructions, interrupt training (Ctrl-C) and try adjusting your YAML files. Adding more examples may help."
Expand All @@ -471,7 +507,7 @@ def unescape(s):
# Filter the pool
instruction_data_pool = [
e
for e in seed_instruction_data + machine_instruction_data
for e in seed_instruction_data + machine_seed_instruction_data
if e["taxonomy_path"] == selected_taxonomy
]
instruction_data, discarded = get_instructions_from_model(
Expand Down Expand Up @@ -520,6 +556,11 @@ def unescape(s):
# Comment out extra info not currently being used:
# instruction_data_entry["most_similar_instructions"] = most_similar_instructions
# instruction_data_entry["avg_similarity_score"] = float(np.mean(rouge_scores))

# Only add sufficiently small instructions to our machine seeds
if len(new_instruction_tokens) <= max_seed_tokens:
machine_seed_instruction_data.append(instruction_data_entry)

machine_instruction_data.append(instruction_data_entry)
all_instructions.append(instruction_data_entry["instruction"])
all_instruction_tokens.append(new_instruction_tokens)
Expand Down
41 changes: 41 additions & 0 deletions src/instructlab/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,47 @@ def num_chars_from_tokens(num_tokens) -> int:
return int(num_tokens * 4) # 1 token ~ 4 English character


def num_tokens_from_chars(num_chars) -> int:
    """Estimate how many model tokens a given character count represents.

    Rough inverse of ``num_chars_from_tokens``, using the same heuristic
    that one token corresponds to roughly 4 English characters. The result
    is only an estimate; text heavy in numbers or punctuation can tokenize
    quite differently depending on the model's tokenizer.

    Args:
        num_chars (int): Number of characters.

    Returns:
        int: Estimated token count (floor of ``num_chars / 4``).
    """
    # 1 token ~ 4 English characters. Floor division stays exact for
    # arbitrarily large ints, unlike int(num_chars / 4) which round-trips
    # through a float.
    return num_chars // 4


def max_seed_example_tokens(server_ctx_size, prompt_num_chars) -> int:
    """
    Estimates the maximum number of tokens any seed example can have based
    on the server context size and number of characters in the selected prompt.

    A lot has to fit into the given server context size:
    - The prompt itself, which can vary in size a bit based on model family and knowledge vs skill
    - Two seed examples, which we append to the prompt template.
    - A knowledge document chunk, if this is a knowledge example.
    - The generated completion, which can vary substantially in length.

    This is an attempt to roughly estimate the maximum size any seed example
    (question + answer + context values from the yaml) should be to even have
    a hope of not often exceeding the server's maximum context size.

    NOTE: This does not take into account knowledge document chunks. It's meant
    to calculate the maximum size that any seed example should be, whether knowledge
    or skill. Knowledge seed examples will want to stay well below this limit.

    NOTE: This is a very simplistic calculation, and examples with lots of numbers
    or punctuation may have quite a different token count than the estimates here,
    depending on the model (and thus tokenizer) in use. That's ok, as it's only
    meant to be a rough estimate.

    Args:
        server_ctx_size (int): Size of the server context, in tokens.
        prompt_num_chars (int): Number of characters in the prompt (not including the examples)

    Returns:
        int: Estimated per-example token budget. Can be negative if the
        prompt alone (plus the response reserve) exceeds the context size.
    """
    # Reserve at least 1024 tokens of the context for the generated response.
    available_tokens = server_ctx_size - 1024
    # Subtract the (estimated) number of tokens consumed by the prompt template.
    available_tokens -= num_tokens_from_chars(prompt_num_chars)
    # Split what remains between the two seed examples appended to the prompt.
    return available_tokens // 2


def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]:
"""
Iterates over the documents and splits them into chunks based on the word count provided by the user.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_lab_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# First Party
from instructlab import lab
from instructlab.config import get_default_config, write_config
from instructlab.generator.generate_data import generate_data
from instructlab.generator.utils import GenerateException

Expand Down Expand Up @@ -161,6 +162,41 @@ def test_open_ai_server_error(self, get_instructions_from_model):
get_instructions_from_model.assert_called_once()
mt.teardown()

def test_new_data_too_long(self):
runner = CliRunner()
with open("tests/testdata/skill_too_long_answer.yaml", "rb") as qnafile:
with runner.isolated_filesystem():
cfg_file = "small_ctx_config.yaml"
smaller_ctx = 3072
config = get_default_config()
config.serve.max_ctx_size = smaller_ctx
write_config(config, config_file=cfg_file)
mt = MockTaxonomy(pathlib.Path("taxonomy"))
mt.create_untracked(
"compositional_skills/tracked/qna.yaml", qnafile.read()
)
result = runner.invoke(
lab.cli,
[
"--config",
cfg_file,
"generate",
"--taxonomy-base",
"main",
"--taxonomy-path",
mt.root,
"--endpoint-url",
"localhost:8000",
"--server-ctx-size",
smaller_ctx,
],
)
assert (
result.exit_code == 1
), "command finished with an unexpected exit code"
assert "too long for the server context size" in result.output
mt.teardown()

@patch(
"instructlab.generator.generate_data.get_instructions_from_model",
return_value=(testdata.generate_data_return_value, 0),
Expand Down
Loading