This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

feat(model): return logprobs for prompt (also for max_tokens=0) #84

Open · wants to merge 3 commits into master
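This change makes the model stream log-probabilities for the prompt tokens themselves when echo is requested, and supports max_tokens=0 so a prompt can be scored without generating a completion. A rough illustration of the intended call pattern follows; it is not part of this diff. Here `model` is assumed to be an instance of the stream-model class whose __call__ is patched below, and the prompt string and choice fields are illustrative, following the OpenAI-style dictionaries that map_choice produces.

# Hypothetical usage sketch, not code from this repository.
for choice in model(
    "The quick brown fox",  # illustrative prompt
    max_tokens=0,           # score the prompt only; emit no completion tokens
    echo=True,              # echo prompt tokens back as choices
    logprobs=1,             # attach per-token log-probabilities
):
    print(choice["text"], choice["logprobs"])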
116 changes: 77 additions & 39 deletions basaran/model.py
@@ -46,6 +46,9 @@ def __call__(
else:
raise TypeError("prompt must be a string or a 1-d tensor")

# Remember whether the caller asked for max_tokens=0; max_tokens is
# clamped to at least 1 below so that generate() still runs a forward
# pass over the prompt, while no completion tokens are emitted.
max_tokens_is_zero = max_tokens == 0

# Ensure arguments are non-negative.
min_tokens = max(min_tokens, 0)
max_tokens = max(max_tokens, 1)
@@ -60,22 +63,14 @@
for i in range(n):
detokenizers.append(StreamTokenizer(self.tokenizer))

# Echo prompt tokens if required.
for token in input_ids:
samples = self._sample(token, 0, [], []) if logprobs > 0 else {}
for i in range(n):
text = detokenizers[i].decode(token)
offset = detokenizers[i].start
if echo:
yield map_choice(text, i, text_offset=offset, **samples)

# Generate completion tokens.
for (
tokens,
token_logprobs,
top_tokens,
top_logprobs,
status,
prompt_logprobs,
) in self.generate(
input_ids[None, :].repeat(n, 1),
logprobs=logprobs,
@@ -84,37 +79,55 @@
temperature=temperature,
top_p=top_p,
):
for i in range(n):
# Check and update the finish status of the sequence.
if finish_reasons[i]:
continue
if status[i] == 0:
finish_reasons[i] = "stop"
elif status[i] == -1:
finish_reasons[i] = "length"

# Collect samples of the most likely tokens if required.
samples = (
self._sample(
token=tokens[i],
token_logprob=token_logprobs[i],
top_tokens=top_tokens[i],
top_logprobs=top_logprobs[i],

# Echo prompt tokens if required.
if echo and prompt_logprobs is not None:
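# prompt_logprobs is only populated on the first item yielded by
# generate(), so the prompt is echoed exactly once per request.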

for position, token in enumerate(input_ids):
    # The first prompt token has no preceding context, so its
    # logprob stays 0; later tokens use the shifted prompt_logprobs.
    token_logprob = 0
    if position > 0 and not self.model.config.is_encoder_decoder:
        token_logprob = prompt_logprobs[position - 1, 0]

    samples = (
        self._sample(token, token_logprob, [], [])
        if logprobs > 0
        else {}
    )
    for i in range(n):
        text = detokenizers[i].decode(token)
        offset = detokenizers[i].start
        yield map_choice(text, i, text_offset=offset, **samples)

# Only yield completion tokens if the caller actually asked for some;
# with max_tokens=0 the single forced generation step is used solely
# to obtain prompt logprobs.
if not max_tokens_is_zero:
for i in range(n):
# Check and update the finish status of the sequence.
if finish_reasons[i]:
continue
if status[i] == 0:
finish_reasons[i] = "stop"
elif status[i] == -1:
finish_reasons[i] = "length"

# Collect samples of the most likely tokens if required.
samples = (
self._sample(
token=tokens[i],
token_logprob=token_logprobs[i],
top_tokens=top_tokens[i],
top_logprobs=top_logprobs[i],
)
if logprobs > 0
else {}
)
if logprobs > 0
else {}
)

# Yield predicted tokens.
text = detokenizers[i].decode(tokens[i])
offset = detokenizers[i].start
yield map_choice(
text,
i,
text_offset=offset,
finish_reason=finish_reasons[i],
**samples,
)
# Yield predicted tokens.
text = detokenizers[i].decode(tokens[i])
offset = detokenizers[i].start
yield map_choice(
text,
i,
text_offset=offset,
finish_reason=finish_reasons[i],
**samples,
)

def _sample(self, token, token_logprob, top_tokens, top_logprobs):
"""Sample log probabilities of the most likely tokens."""
@@ -228,6 +241,9 @@ def generate(self, input_ids, logprobs=0, **kwargs):
# Keep track of which sequences are already finished.
unfinished = input_ids.new_ones(batch_size)

# Track how many generation steps have run; prompt logprobs are only
# computed on the first step.
generation_step = 0

# Start auto-regressive generation.
while True:
inputs = self.model.prepare_inputs_for_generation(
@@ -241,6 +257,28 @@
output_hidden_states=False,
)

# Compute logprobs for the prompt tokens. This is only possible (and
# only needed) on the first step, where the model sees the full
# prompt; on later steps prompt_logprobs stays None.
prompt_logprobs = None
if generation_step == 0:
    # Logits at position t give the distribution over the token at
    # position t + 1, so drop the last position and score the prompt
    # tokens starting from the second one.
    input_probs = torch.nn.functional.softmax(
        outputs.logits[:, :-1, :], dim=-1
    )
    input_logprobs = []

    # Collect the probability of each prompt token across the batch.
    for sequence_index in range(input_ids.size(1) - 1):
        probs_at_step = input_probs[:, sequence_index, :]
        batch_probs = []
        for batch_id in range(input_ids.size(0)):
            probs_for_b = probs_at_step[batch_id]
            token_prob = probs_for_b[
                input_ids[batch_id, sequence_index + 1]
            ]
            batch_probs.append(token_prob)
        input_logprobs.append(batch_probs)

    # Convert to log probabilities; the small epsilon avoids log(0).
    prompt_logprobs = torch.log(
        torch.tensor(input_logprobs).to(self.model.device) + 1e-7
    )

# increment generation step
generation_step += 1

# Pre-process the probability distribution of the next tokens.
logits = outputs.logits[:, -1, :]
with torch.inference_mode():
@@ -290,7 +328,7 @@ def generate(self, input_ids, logprobs=0, **kwargs):
status = 0 - status

# Yield predictions and status.
yield tokens, token_logprobs, top_tokens, top_logprobs, status
yield tokens, token_logprobs, top_tokens, top_logprobs, status, prompt_logprobs

# Stop when finished or exceeded the max length.
if status.max() <= 0:
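For readers tracing the nested loops added to generate(): the patch relies on the usual causal-LM shift, where the logits at position t parameterize the distribution of the token at position t + 1. The following self-contained sketch (plain PyTorch/transformers with an illustrative checkpoint name, not code from this repository) computes the same per-token prompt log-probabilities in a vectorized form, up to the 1e-7 epsilon added above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any causal language model behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

with torch.inference_mode():
    logits = model(input_ids).logits  # (batch, seq_len, vocab)

# Logits at position t describe the token at position t + 1, so drop the
# last position and score input_ids[:, 1:]; the first prompt token has no
# preceding context and therefore gets no score.
log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
prompt_logprobs = log_probs.gather(
    -1, input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)  # shape: (batch, seq_len - 1)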