[Fix] fix clp potential error and support bs>1 #439

Merged
merged 4 commits into from
Sep 27, 2023
Changes from 3 commits
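
The diff below touches the CLP (conditional log probability) inferencer, which scores each multiple-choice option by the probability the model assigns to the option's first token immediately after the prompt. As background, here is a minimal sketch of that idea, assuming a generic HuggingFace causal LM ('gpt2' and the prompt/choices are illustrative stand-ins, not taken from this PR):

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')   # illustrative model choice
model = AutoModelForCausalLM.from_pretrained('gpt2')

prompt = 'Question: Is the sky blue? Answer:'
choices = [' yes', ' no']
# first token id of each choice
choice_ids = [tokenizer.encode(c)[0] for c in choices]

with torch.no_grad():
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    logits = model(input_ids).logits        # (1, seq_len, vocab)

# log-probabilities of the next token at the last prompt position
log_probs = F.log_softmax(logits[0, -1], dim=-1)
# renormalise over the candidate choices only
choice_probs = F.softmax(log_probs[choice_ids], dim=-1)
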
76 changes: 52 additions & 24 deletions opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -119,7 +119,7 @@ def inference(self,
if self.single_token:
index = 0
prompt_list = []
choice_target_ids = []
target_pos = []
# TODO: Hard-coded temporarily, needs to be modified here
choices = retriever.test_ds[0]['choices']
try:
@@ -149,7 +149,7 @@ def inference(self,
ice[idx],
ice_template=ice_template,
prompt_template=prompt_template)
prompt = self.model.parse_template(prompt, mode='ppl')
prompt = self.model.parse_template(prompt, mode='gen')
if self.max_seq_len is not None:
prompt_token_num = get_token_len(prompt)
# add one because additional token will be added in the end
@@ -165,15 +165,19 @@ def inference(self,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = get_token_len(prompt)
# Add a single token to the prompt; this token can be any token
prompt += 'yes'
prompt_list.append(prompt)
# in case prompt token num reaches
# in case prompt token num reaches max
if self.max_seq_len is not None and \
prompt_token_num + 1 > self.max_seq_len:
prompt_token_num = self.max_seq_len - 1
# minus the bos token
choice_target_ids.append(prompt_token_num - 1)

# get the target position index
if self.model.tokenizer.padding_side == 'left':
# always the last position
target_pos.append(-1)
else:
# the last position of the original prompt
target_pos.append(prompt_token_num - 1)
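
# Worked example of the two padding cases (hypothetical token ids, not from
# the original file), for a 3-token prompt padded to a batch length of 5:
#   left padding:  [PAD, PAD, t0, t1, t2] -> the last prompt token is always
#                  at index -1, whatever length the other prompts in the
#                  batch have
#   right padding: [t0, t1, t2, PAD, PAD] -> the last prompt token sits at
#                  prompt_token_num - 1 (here 2), so the per-prompt length
#                  recorded above is the correct index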

# 4.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
@@ -187,14 +191,31 @@ def inference(self,
len(prompt_list),
self.batch_size,
disable=not self.is_main_process):
# get batch data
sub_prompt_list = prompt_list[idx:idx + self.batch_size]
sub_golds = gold_ans[idx:idx + self.batch_size]
sub_choice_target_ids = choice_target_ids[idx:idx +
self.batch_size]
sub_res = self.__get_cond_prob(sub_prompt_list,
sub_choice_target_ids,
choice_ids)
sub_target_pos = target_pos[idx:idx + self.batch_size]

# get probability result
if hasattr(self.model, 'batch_padding'):
# get batch padding for huggingface model
batch_padding = self.model.batch_padding
else:
# defaults to True for internal models
batch_padding = True

if batch_padding and self.batch_size > 1:
sub_res = self._get_cond_prob(sub_prompt_list,
sub_target_pos, choice_ids)
else:
sub_res = []
for prompt, position in zip(sub_prompt_list,
sub_target_pos):
sub_res.extend(
self._get_cond_prob([prompt], [position],
choice_ids))

# save all the results
for res, prompt, gold in zip(sub_res, sub_prompt_list,
sub_golds):
example_input = prompt.replace(ice[idx], '')
@@ -217,22 +238,29 @@ def inference(self,
for sample in output_handler.results_dict.values()
]

def __get_cond_prob(self,
input_texts: List[str],
sub_choice_target_ids,
choice_ids,
mask_length=None):
# TODO: support multiple tokens
def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
choice_ids: List[int]):
"""Get the condition probability of next token.

Args:
input_texts (List[str]): All the input prompt to be tested.
target_pos (List[int]): Target position of next token.
choice_ids (List[int]): Choice ids of target tokens.
"""
if hasattr(self.model, 'generator'):
outputs, _ = self.model.generator.get_logits(input_texts)
get_logits = self.model.generator.get_logits
else:
outputs, _ = self.model.get_logits(input_texts)
get_logits = self.model.get_logits

outputs, _ = get_logits(input_texts)

shift_logits = outputs[..., :-1, :].contiguous().float()
# we want to get the next-token probability
# therefore no shift here
logits = outputs.contiguous().float()

shift_logits = F.log_softmax(shift_logits, dim=-1)
logits = F.log_softmax(logits, dim=-1)
log_probs = []
for logits, target_ids in zip(shift_logits, sub_choice_target_ids):
for logit, target_ids in zip(logits, target_pos):
log_probs.append(
F.softmax(logits[target_ids, choice_ids], dim=-1).tolist())
F.softmax(logit[target_ids, choice_ids], dim=-1).tolist())
return log_probs
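
For reference, a self-contained sketch of the probability computation performed by _get_cond_prob above, run on dummy logits (the shapes, token ids, and the random tensor standing in for get_logits are all illustrative assumptions):

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 6, 10
outputs = torch.randn(batch, seq_len, vocab)  # stand-in for get_logits(input_texts)
target_pos = [-1, 3]                          # per-sample target positions
choice_ids = [4, 7, 9]                        # token ids of the candidate choices

logits = F.log_softmax(outputs.contiguous().float(), dim=-1)
log_probs = []
for logit, pos in zip(logits, target_pos):
    # probabilities over the choice tokens at the target position,
    # renormalised so they sum to 1 across the choices
    log_probs.append(F.softmax(logit[pos, choice_ids], dim=-1).tolist())
print(log_probs)  # one list of len(choice_ids) probabilities per sample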