Combine internal generation functions
jncraton committed Dec 28, 2023
1 parent 11b286f commit 79ebb32
Showing 2 changed files with 9 additions and 41 deletions.
languagemodels/__init__.py (13 changes: 6 additions & 7 deletions)
@@ -5,8 +5,7 @@

 from languagemodels.config import config
 from languagemodels.inference import (
-    generate_instruct,
-    generate_code,
+    generate,
     rank_instruct,
     parse_chat,
     list_tokens,

@@ -34,7 +33,7 @@ def complete(prompt: str) -> str:
     'she was sure she was safe'
     """

-    result = generate_instruct(
+    result = generate(
         "Write a sentence", prefix=prompt,
         max_tokens=config["max_tokens"], temperature=0.7, topk=40
     )

@@ -63,7 +62,7 @@ def do(prompt: str) -> str:
     >>> do("Is the following positive or negative: I love Star Trek.")
     'Positive.'
     """
-    result = generate_instruct(prompt, max_tokens=config["max_tokens"], topk=1)
+    result = generate(prompt, max_tokens=config["max_tokens"], topk=1)

     if len(result.split()) == 1:
         result = result.title()

@@ -153,7 +152,7 @@ def chat(prompt: str) -> str:
     if prompt.startswith("System:"):
         prompt = prompt[7:].strip()

-    response = generate_instruct(
+    response = generate(
         prompt,
         max_tokens=config["max_tokens"],
         repetition_penalty=1.3,

@@ -187,7 +186,7 @@ def code(prompt: str) -> str:
     >>> code("def return_4():")
     '...return 4...'
     """
-    result = generate_code(prompt, max_tokens=config["max_tokens"], topk=1)
+    result = generate(prompt, max_tokens=config["max_tokens"], topk=1, model="code")

     return result

@@ -213,7 +212,7 @@ def extract_answer(question: str, context: str) -> str:
     '...Guido van Rossum...'
     """

-    return generate_instruct(f"{context}\n\n{question}")
+    return generate(f"{context}\n\n{question}")


 def classify(doc: str, label1: str, label2: str) -> str:
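The caller-side effect of this commit is visible in the wrappers above: every public helper now funnels through a single generate entry point, and the code model is selected with a model keyword rather than a separate function. A minimal before/after sketch of a call site, using only the signatures shown in this diff:

# Before: two near-identical entry points
result = generate_instruct(prompt, max_tokens=config["max_tokens"], topk=1)
result = generate_code(prompt, max_tokens=config["max_tokens"], topk=1)

# After: one entry point; the model keyword picks the backend
result = generate(prompt, max_tokens=config["max_tokens"], topk=1)
result = generate(prompt, max_tokens=config["max_tokens"], topk=1, model="code")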
languagemodels/inference.py (37 changes: 3 additions & 34 deletions)
@@ -110,14 +110,15 @@ def chat_oa(engine, prompt, max_tokens=200, temperature=0):
         raise InferenceException(f"OpenAI error: {resp}")


-def generate_instruct(
+def generate(
     instruction,
     max_tokens=200,
     temperature=0.1,
     topk=1,
     repetition_penalty=1.3,
     prefix="",
     suppress=[],
+    model="instruct",
 ):
     """Generates one completion for a prompt using an instruction-tuned model

@@ -132,7 +133,7 @@ def generate_instruct(
     if os.environ.get("LANGUAGEMODELS_OA_KEY"):
         return chat_oa("gpt-3.5-turbo", instruction, max_tokens).strip()

-    tokenizer, model = get_model("instruct")
+    tokenizer, model = get_model(model)

     suppress = [tokenizer.encode(s, add_special_tokens=False).tokens for s in suppress]

@@ -173,38 +174,6 @@ def generate_instruct(
     return text


-def generate_code(
-    prompt,
-    max_tokens=200,
-    temperature=0.1,
-    topk=1,
-    repetition_penalty=1.2,
-    prefix="",
-):
-    """Generates one completion for a prompt using a code-tuned model
-    >>> generate_code("# Print Hello, World!\\n")
-    'print("Hello, World!")\\n'
-    """
-
-    tokenizer, model = get_model("code")
-
-    results = model.translate_batch(
-        [tokenizer.encode(prompt).tokens],
-        target_prefix=[tokenizer.encode(prefix, add_special_tokens=False).tokens],
-        repetition_penalty=repetition_penalty,
-        max_decoding_length=max_tokens,
-        sampling_temperature=temperature,
-        sampling_topk=topk,
-        beam_size=1,
-    )
-    output_tokens = results[0].hypotheses[0]
-    output_ids = [tokenizer.token_to_id(t) for t in output_tokens]
-    text = tokenizer.decode(output_ids, skip_special_tokens=True)
-
-    return text
-
-
 def rank_instruct(input, targets):
     """Sorts a list of targets by their probabilities
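The shape of the refactor in inference.py is dispatch by key: instead of hard-coding get_model("instruct") and duplicating the whole decoding loop in generate_code, the unified generate looks its backend up by name. A self-contained sketch of that pattern follows; the registry contents and return values are illustrative assumptions, not the actual languagemodels internals:

# Illustrative sketch of dispatch-by-key; names and registry entries are
# hypothetical stand-ins, not the real languagemodels implementation.
MODEL_REGISTRY = {
    "instruct": "instruction-tuned seq2seq backend",
    "code": "code-tuned seq2seq backend",
}

def get_model(name):
    """Return the backend registered under the given key."""
    return MODEL_REGISTRY[name]

def generate(prompt, model="instruct"):
    """Single entry point; the model keyword replaces per-model functions."""
    backend = get_model(model)
    return f"<completion of {prompt!r} via {backend}>"

print(generate("Write a sentence"))
print(generate("def return_4():", model="code"))

One behavioral nuance is worth noting: the removed generate_code defaulted to repetition_penalty=1.2, while the unified generate defaults to 1.3, so code completions now inherit the slightly stronger penalty unless a caller overrides it.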
