fix clp potential error and support bs>1 (open-compass#439)
* [Fix] fix clp potential error and support bs>1

* [Fix] fix clp potential error and support bs>1

* minor fix

* minor fix
yingfhu authored and BunnyRunnerX committed Sep 27, 2023
1 parent 9b995ec commit e2adcb5
Showing 1 changed file with 59 additions and 24 deletions.
opencompass/openicl/icl_inferencer/icl_clp_inferencer.py: 59 additions & 24 deletions
@@ -120,7 +120,7 @@ def inference(self,
         if self.single_token:
             index = 0
             prompt_list = []
-            choice_target_ids = []
+            target_pos = []
             # TODO: Hard code temperaily, need to modified here
             choices = retriever.test_ds[0]['choices']
             try:
@@ -146,14 +146,21 @@ def inference(self,
             # INTERNAL_END
             get_token_len = self.model.get_token_len
 
+            if hasattr(self.model.tokenizer, 'padding_side'):
+                # get padding_side for huggingface model
+                padding_side = self.model.tokenizer.padding_side
+            else:
+                # defaults to left for internal model
+                padding_side = 'left'
+
             # prepare in context for each example and control the length
             for idx in range(len(ice_idx_list)):
                 prompt = retriever.generate_prompt_for_generate_task(
                     idx,
                     ice[idx],
                     ice_template=ice_template,
                     prompt_template=prompt_template)
-                prompt = self.model.parse_template(prompt, mode='ppl')
+                prompt = self.model.parse_template(prompt, mode='gen')
                 if self.max_seq_len is not None:
                     prompt_token_num = get_token_len(prompt)
                     # add one because additional token will be added in the end
@@ -169,15 +176,19 @@ def inference(self,
                             ice_template=ice_template,
                             prompt_template=prompt_template)
                         prompt_token_num = get_token_len(prompt)
-                # Add single token for prompt, this token can be any token
-                prompt += 'yes'
                 prompt_list.append(prompt)
-                # in case prompt token num reaches
+                # in case prompt token num reaches max
                 if self.max_seq_len is not None and \
                         prompt_token_num + 1 > self.max_seq_len:
                     prompt_token_num = self.max_seq_len - 1
-                # minus the bos token
-                choice_target_ids.append(prompt_token_num - 1)
+
+                # get the target position index
+                if padding_side == 'left':
+                    # always the last position
+                    target_pos.append(-1)
+                else:
+                    # the last position of the original prompt
+                    target_pos.append(prompt_token_num - 1)
 
             # 4.1 Fetch and zip prompt & gold answer if output column exists
             ds_reader = retriever.dataset_reader
@@ -186,19 +197,36 @@ def inference(self,
             else:
                 gold_ans = [None] * len(prompt_list)
 
+            if hasattr(self.model, 'batch_padding'):
+                # get batch padding for huggingface model
+                batch_padding = self.model.batch_padding
+            else:
+                # defaults to False for internal model
+                batch_padding = False
+
             logger.info('Calculating conditional log probability for prompts.')
             for idx in trange(0,
                               len(prompt_list),
                               self.batch_size,
                               disable=not self.is_main_process):
                 # get batch data
                 sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                 sub_golds = gold_ans[idx:idx + self.batch_size]
-                sub_choice_target_ids = choice_target_ids[idx:idx +
-                                                          self.batch_size]
-                sub_res = self.__get_cond_prob(sub_prompt_list,
-                                               sub_choice_target_ids,
-                                               choice_ids)
+                sub_target_pos = target_pos[idx:idx + self.batch_size]
+
+                # get probability result
+                if batch_padding and self.batch_size > 1:
+                    sub_res = self._get_cond_prob(sub_prompt_list,
+                                                  sub_target_pos, choice_ids)
+                else:
+                    sub_res = []
+                    for prompt, position in zip(sub_prompt_list,
+                                                sub_target_pos):
+                        sub_res.extend(
+                            self._get_cond_prob([prompt], [position],
+                                                choice_ids))
+
+                # save all the result
                 for res, prompt, gold in zip(sub_res, sub_prompt_list,
                                              sub_golds):
                     example_input = prompt.replace(ice[idx], '')
Expand All @@ -221,22 +249,29 @@ def inference(self,
for sample in output_handler.results_dict.values()
]

def __get_cond_prob(self,
input_texts: List[str],
sub_choice_target_ids,
choice_ids,
mask_length=None):
# TODO: support multiple tokens
def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
choice_ids: List[int]):
"""Get the condition probability of next token.
Args:
input_texts (List[str]): All the input prompt to be tested.
target_pos (List[int]): Target position of next token.
choice_ids (List[int]): Choice ids of target tokens.
"""
if hasattr(self.model, 'generator'):
outputs, _ = self.model.generator.get_logits(input_texts)
get_logits = self.model.generator.get_logits
else:
outputs, _ = self.model.get_logits(input_texts)
get_logits = self.model.get_logits

outputs, _ = get_logits(input_texts)

shift_logits = outputs[..., :-1, :].contiguous().float()
# we want get the next token probability
# therefore no shift here
logits = outputs.contiguous().float()

shift_logits = F.log_softmax(shift_logits, dim=-1)
logits = F.log_softmax(logits, dim=-1)
log_probs = []
for logits, target_ids in zip(shift_logits, sub_choice_target_ids):
for logit, target_ids in zip(logits, target_pos):
log_probs.append(
F.softmax(logits[target_ids, choice_ids], dim=-1).tolist())
F.softmax(logit[target_ids, choice_ids], dim=-1).tolist())
return log_probs
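
The heart of the fix is the target_pos bookkeeping added above. When several prompts of different lengths are batched, the tokenizer pads them to a common width, and the position whose logits predict the answer token depends on which side the padding goes: with left padding every prompt ends at the final position, so index -1 is always correct, while with right padding the last real token sits at prompt_token_num - 1 and anything beyond it is padding. The sketch below is not from the repository; the pad helper and the token ids are made up purely to illustrate the indexing.

def pad(ids, width, pad_id=0, side='left'):
    """Pad a list of token ids to a fixed width on the given side."""
    padding = [pad_id] * (width - len(ids))
    return padding + ids if side == 'left' else ids + padding

prompts = [[11, 12, 13], [21, 22, 23, 24, 25]]  # hypothetical token ids
width = max(len(p) for p in prompts)

left = [pad(p, width, side='left') for p in prompts]
right = [pad(p, width, side='right') for p in prompts]

# Left padding: every prompt ends in the last column, so index -1 always
# points at the token whose logits predict the answer.
assert all(row[-1] == p[-1] for row, p in zip(left, prompts))

# Right padding: the last real token sits at len(prompt) - 1; index -1 here
# would read logits computed on pad tokens instead.
assert all(row[len(p) - 1] == p[-1] for row, p in zip(right, prompts))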
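
The rewritten _get_cond_prob also drops the logit shift: the commit removes both the placeholder token that used to be appended to the prompt and the outputs[..., :-1, :] slicing, since the distribution of interest is simply the one predicted at the prompt's last real token. The choice probabilities are then obtained by renormalising the log-probabilities of the choice token ids at that position. Below is a minimal sketch of that step with random stand-in logits and made-up ids; none of the values come from the repository.

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 50                 # assumed shapes
outputs = torch.randn(batch, seq_len, vocab)     # stand-in for model logits
choice_ids = [7, 19, 42]                         # hypothetical choice token ids
target_pos = [-1, 5]                             # per-prompt target positions

logits = F.log_softmax(outputs.float(), dim=-1)  # log-probs over the vocabulary
choice_probs = []
for logit, pos in zip(logits, target_pos):
    # keep only the choice tokens at the target position and renormalise
    # them so the candidate probabilities sum to one
    choice_probs.append(F.softmax(logit[pos, choice_ids], dim=-1).tolist())

print(choice_probs)  # one list of len(choice_ids) probabilities per prompt

The batch_padding check in the diff dispatches between this batched path and a per-prompt fallback; with a single prompt no padding is applied, so both target-index conventions point at the prompt's final token.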
