Support static batching in rank_instruct
jncraton committed Dec 28, 2023
1 parent e656d15 commit bd0a339
Showing 2 changed files with 26 additions and 15 deletions.
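
In outline, the change below turns rank_instruct from a single-prompt scorer into a batched one: every prompt is paired with every candidate target, the whole cross product is scored in one score_batch call, and the flat score list is re-split into one ranking per prompt. A minimal sketch of that pairing and re-splitting scheme, using illustrative placeholder data rather than real token lists:

# Sketch of the batch layout the diffs below implement; "A", "B", "x", "y"
# stand in for tokenized prompts and targets.
inputs = ["A", "B"]
targets = ["x", "y"]

# Tile the targets across the inputs and repeat each input once per target,
# mirroring `targ_tok *= len(inputs)` and the in_tok loop in the diff.
tiled_targets = targets * len(inputs)                # ['x', 'y', 'x', 'y']
tiled_inputs = [i for i in inputs for _ in targets]  # ['A', 'A', 'B', 'B']

pairs = list(zip(tiled_inputs, tiled_targets))
assert pairs == [("A", "x"), ("A", "y"), ("B", "x"), ("B", "y")]

# After one batched scoring pass, chunking by len(targets) recovers the
# scores that belong to each prompt, as the final loop in the diff does.
n = len(targets)
chunks = [pairs[k:k + n] for k in range(0, len(pairs), n)]
assert len(chunks) == len(inputs)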
languagemodels/__init__.py (5 changes: 3 additions & 2 deletions)

@@ -257,10 +257,11 @@ def classify(doc: str, label1: str, label2: str) -> str:
     """
 
     results = rank_instruct(
-        f"Classify as {label1} or {label2}: {doc}\n\nClassification:", [label1, label2]
+        [f"Classify as {label1} or {label2}: {doc}\n\nClassification:"],
+        [label1, label2]
     )
 
-    return results[0]
+    return results[0][0]
 
 
 def store_doc(doc: str, name: str = "") -> None:
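
Because rank_instruct now returns one ranked list per prompt, classify wraps its single prompt in a one-element list and takes results[0][0]: the top label of the first (and only) ranking. Its public behavior is unchanged, so existing callers are unaffected; a quick check, assuming the package and its default models are installed:

import languagemodels as lm

# classify's signature and return value are the same as before this commit;
# only the shape of the internal rank_instruct call changed.
print(lm.classify("I love this library!", "positive", "negative"))
# expected output: positive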
languagemodels/inference.py (36 changes: 23 additions & 13 deletions)

@@ -175,33 +175,43 @@ def generate(
     return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]
 
 
-def rank_instruct(input, targets):
+def rank_instruct(inputs, targets):
     """Sorts a list of targets by their probabilities
 
-    >>> rank_instruct("Classify positive or negative: I love python. Classification:",
+    >>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
     ... ['positive', 'negative'])
-    ['positive', 'negative']
+    [['positive', 'negative']]
 
-    >>> rank_instruct("Classify fantasy or documentary: "
-    ... "The wizard raised their want. Classification:",
+    >>> rank_instruct(["Classify fantasy or documentary: "
+    ... "The wizard raised their want. Classification:"],
     ... ['fantasy', 'documentary'])
-    ['fantasy', 'documentary']
+    [['fantasy', 'documentary']]
+
+    >>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
+    [['six', 'seven'], ['seven', 'six']]
     """
    tokenizer, model = get_model("instruct")
 
-    in_tok = tokenizer.encode(input, add_special_tokens=False).tokens
     targ_tok = [tokenizer.encode(t, add_special_tokens=False).tokens for t in targets]
+    targ_tok *= len(inputs)
+
+    in_tok = []
+    for input in inputs:
+        toks = [tokenizer.encode(input, add_special_tokens=False).tokens]
+        in_tok += toks * len(targets)
 
     if "Generator" in str(type(model)):
-        scores = model.score_batch([in_tok + t for t in targ_tok])
+        scores = model.score_batch([i+t for i, t in zip(in_tok, targ_tok)])
     else:
-        scores = model.score_batch([in_tok] * len(targ_tok), target=targ_tok)
-
-    logprobs = [sum(r.log_probs) for r in scores]
+        scores = model.score_batch(in_tok, target=targ_tok)
 
-    results = sorted(zip(targets, logprobs), key=lambda r: -r[1])
+    ret = []
+    for i in range(0, len(inputs) * len(targets), len(targets)):
+        logprobs = [sum(r.log_probs) for r in scores[i:i+len(targets)]]
+        results = sorted(zip(targets, logprobs), key=lambda r: -r[1])
+        ret.append([r[0] for r in results])
 
-    return [r[0] for r in results]
+    return ret
 
 
 def parse_chat(prompt):
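
The new doctest at the end of the docstring shows the point of the change: several prompts can now be ranked against a shared target list in a single call. A usage sketch in the same vein, with illustrative prompts:

from languagemodels.inference import rank_instruct

prompts = [
    "Classify positive or negative: I love python. Classification:",
    "Classify positive or negative: Nothing but bugs. Classification:",
]

# One batched call replaces a per-prompt loop; rankings[k] is the target
# list for prompts[k], sorted from most to least probable.
rankings = rank_instruct(prompts, ["positive", "negative"])

for prompt, ranking in zip(prompts, rankings):
    print(ranking[0])

The model now sees one batch of len(prompts) * len(targets) sequences instead of len(prompts) separate scoring calls, which is what allows the underlying score_batch implementation to process the work as a single batch.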
