diff --git a/basaran/model.py b/basaran/model.py
index b681cf30..2a529ae4 100644
--- a/basaran/model.py
+++ b/basaran/model.py
@@ -46,6 +46,9 @@ def __call__(
         else:
             raise TypeError("prompt must be a string or a 1-d tensor")
 
+        # Remember whether max_tokens was zero before it is clamped below.
+        max_tokens_is_zero = (max_tokens == 0)
+
         # Ensure arguments are non-negative.
         min_tokens = max(min_tokens, 0)
         max_tokens = max(max_tokens, 1)
@@ -60,15 +63,6 @@ def __call__(
         for i in range(n):
             detokenizers.append(StreamTokenizer(self.tokenizer))
 
-        # Echo prompt tokens if required.
-        for token in input_ids:
-            samples = self._sample(token, 0, [], []) if logprobs > 0 else {}
-            for i in range(n):
-                text = detokenizers[i].decode(token)
-                offset = detokenizers[i].start
-                if echo:
-                    yield map_choice(text, i, text_offset=offset, **samples)
-
         # Generate completion tokens.
         for (
             tokens,
@@ -76,6 +70,7 @@ def __call__(
             top_tokens,
             top_logprobs,
             status,
+            prompt_logprobs,
         ) in self.generate(
             input_ids[None, :].repeat(n, 1),
             logprobs=logprobs,
@@ -84,37 +79,55 @@ def __call__(
             temperature=temperature,
             top_p=top_p,
         ):
-            for i in range(n):
-                # Check and update the finish status of the sequence.
-                if finish_reasons[i]:
-                    continue
-                if status[i] == 0:
-                    finish_reasons[i] = "stop"
-                elif status[i] == -1:
-                    finish_reasons[i] = "length"
-
-                # Collect samples of the most likely tokens if required.
-                samples = (
-                    self._sample(
-                        token=tokens[i],
-                        token_logprob=token_logprobs[i],
-                        top_tokens=top_tokens[i],
-                        top_logprobs=top_logprobs[i],
+
+            # Echo prompt tokens if required.
+            if echo and prompt_logprobs is not None:
+
+                for position, token in enumerate(input_ids):
+
+                    token_logprob = 0
+                    if position > 0 and not self.model.config.is_encoder_decoder:
+                        token_logprob = prompt_logprobs[position - 1, 0]
+
+                    samples = self._sample(token, token_logprob, [], []) if logprobs > 0 else {}
+                    for i in range(n):
+                        text = detokenizers[i].decode(token)
+                        offset = detokenizers[i].start
+                        yield map_choice(text, i, text_offset=offset, **samples)
+
+            # Yield completion tokens unless the user requested max_tokens == 0.
+            if not max_tokens_is_zero:
+                for i in range(n):
+                    # Check and update the finish status of the sequence.
+                    if finish_reasons[i]:
+                        continue
+                    if status[i] == 0:
+                        finish_reasons[i] = "stop"
+                    elif status[i] == -1:
+                        finish_reasons[i] = "length"
+
+                    # Collect samples of the most likely tokens if required.
+                    samples = (
+                        self._sample(
+                            token=tokens[i],
+                            token_logprob=token_logprobs[i],
+                            top_tokens=top_tokens[i],
+                            top_logprobs=top_logprobs[i],
+                        )
+                        if logprobs > 0
+                        else {}
                     )
-                    if logprobs > 0
-                    else {}
-                )
 
-                # Yield predicted tokens.
-                text = detokenizers[i].decode(tokens[i])
-                offset = detokenizers[i].start
-                yield map_choice(
-                    text,
-                    i,
-                    text_offset=offset,
-                    finish_reason=finish_reasons[i],
-                    **samples,
-                )
+                    # Yield predicted tokens.
+                    text = detokenizers[i].decode(tokens[i])
+                    offset = detokenizers[i].start
+                    yield map_choice(
+                        text,
+                        i,
+                        text_offset=offset,
+                        finish_reason=finish_reasons[i],
+                        **samples,
+                    )
 
     def _sample(self, token, token_logprob, top_tokens, top_logprobs):
         """Sample log probabilities of the most likely tokens."""
@@ -228,6 +241,9 @@ def generate(self, input_ids, logprobs=0, **kwargs):
         # Keep track of which sequences are already finished.
         unfinished = input_ids.new_ones(batch_size)
 
+        # Track the current generation step.
+        generation_step = 0
+
         # Start auto-regressive generation.
         while True:
             inputs = self.model.prepare_inputs_for_generation(
@@ -241,6 +257,28 @@ def generate(self, input_ids, logprobs=0, **kwargs):
                 output_hidden_states=False,
             )
 
+            # Compute log probabilities for the prompt tokens (only needed at the first step).
+            prompt_logprobs = None
+            if generation_step == 0:
+                input_probs = torch.nn.functional.softmax(outputs.logits[:, :-1, :], dim=-1)
+                input_logprobs = []
+
+                # Collect each prompt token's probability across the batch, starting from the second token.
+                for sequence_index in range(input_ids.size(1) - 1):
+                    probs_at_step = input_probs[:, sequence_index, :]
+                    batch_probs = []
+                    for batch_id in range(input_ids.size(0)):
+                        probs_for_b = probs_at_step[batch_id]
+                        token_prob = probs_for_b[input_ids[batch_id, sequence_index + 1]]
+                        batch_probs.append(token_prob)
+                    input_logprobs.append(batch_probs)
+
+                # Convert probabilities to log probabilities (epsilon avoids log(0)).
+                prompt_logprobs = torch.log(torch.tensor(input_logprobs).to(self.model.device) + 1e-7)
+
+            # Advance the generation step counter.
+            generation_step += 1
+
             # Pre-process the probability distribution of the next tokens.
             logits = outputs.logits[:, -1, :]
             with torch.inference_mode():
@@ -290,7 +328,7 @@ def generate(self, input_ids, logprobs=0, **kwargs):
             status = 0 - status
 
             # Yield predictions and status.
-            yield tokens, token_logprobs, top_tokens, top_logprobs, status
+            yield tokens, token_logprobs, top_tokens, top_logprobs, status, prompt_logprobs
 
             # Stop when finished or exceeded the max length.
             if status.max() <= 0:
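Note: the nested loops added in generate() above can also be expressed as a single vectorized gather over the shifted logits. The sketch below is illustrative only and not part of the patch; the function name prompt_token_logprobs and its standalone signature are assumptions. It takes logits of shape (batch, seq_len, vocab) and input_ids of shape (batch, seq_len) and returns the same (position, batch) layout that prompt_logprobs[position - 1, 0] expects. Using log_softmax also avoids the + 1e-7 epsilon needed when taking the log of a softmax.

import torch

def prompt_token_logprobs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Return log P(token_t | tokens_<t) for t >= 1, with shape (seq_len - 1, batch)."""
    # Logits at position t score the token at position t + 1, so drop the final step.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    # Pick out the log probability assigned to each actual next prompt token.
    targets = input_ids[:, 1:].unsqueeze(-1)              # (batch, seq_len - 1, 1)
    gathered = log_probs.gather(-1, targets).squeeze(-1)  # (batch, seq_len - 1)
    # Transpose to the (position, batch) layout used by the loop in the patch.
    return gathered.transpose(0, 1)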