[Sync] Bump version 0.2.3 (#957)
Leymore committed Mar 12, 2024
1 parent 64fde73 commit ab6cdb2
Showing 2 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion opencompass/__init__.py
@@ -1 +1 @@
-__version__ = '0.2.2'
+__version__ = '0.2.3'
18 changes: 10 additions & 8 deletions opencompass/openicl/icl_inferencer/icl_chat_inferencer.py
@@ -175,6 +175,7 @@ def __init__(
             temperature: Optional[float] = 0.0,
             do_sample: Optional[bool] = False,
             infer_mode: str = 'last',
+            max_out_len: int = 512,
             **kwargs) -> None:
         super().__init__(
             model=model,
@@ -193,6 +194,7 @@ def __init__(
             save_every = 1
         self.save_every = save_every
         self.dialogue_mode = False
+        self.max_out_len = max_out_len

     def _set_meta_template(self, model):
         origin = model.template_parser
@@ -334,8 +336,8 @@ def infer_last(self, chat: List[dict], index: int, output_handler):
         ]

         history = chat[:assistant_indices[-1]]
-        output = self.model.generate_from_template([history],
-                                                   max_out_len=512)[0]
+        output = self.model.generate_from_template(
+            [history], max_out_len=self.max_out_len)[0]
         output_handler.save_results(
             origin_prompt=history,
             prediction=output,
@@ -356,11 +358,11 @@ def infer_every(self, chat: List[dict], index: int, output_handler):
                     [history],
                     do_sample=self.do_sample,
                     temperature=self.temperature,
-                    max_out_len=512)[0]
+                    max_out_len=self.max_out_len)[0]
             else:
-                output = self.model.generate_from_template([history],
-                                                           do_sample=False,
-                                                           max_out_len=512)[0]
+                output = self.model.generate_from_template(
+                    [history], do_sample=False,
+                    max_out_len=self.max_out_len)[0]
             chat[i]['content'] = output
             if not self.dialogue_mode:
                 output_handler.save_multiround_results(
@@ -397,8 +399,8 @@ def infer_every_with_gt(self, chat: List[dict], index: int,

         for i in assistant_indices:
             history = chat[:i]
-            output = self.model.generate_from_template([history],
-                                                       max_out_len=512)[0]
+            output = self.model.generate_from_template(
+                [history], max_out_len=self.max_out_len)[0]
             output_handler.save_multiround_results(
                 origin_prompt=history[-1]['content'],
                 prediction=output,
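Taken together, the second file's change threads a configurable max_out_len through the chat inferencer: the constructor now accepts and stores it, and all three inference paths (infer_last, infer_every, infer_every_with_gt) pass it to generate_from_template in place of the previously hardcoded 512. A minimal sketch of what this enables, assuming the class defined in icl_chat_inferencer.py is ChatInferencer (consistent with the methods shown above) and that a chat model wrapper named model has been built elsewhere via the usual OpenCompass config machinery:

# Hedged sketch, not a drop-in script: the import and the `model`
# object are assumptions; only the constructor arguments mirror the
# signature shown in the hunks above.
from opencompass.openicl.icl_inferencer.icl_chat_inferencer import \
    ChatInferencer

inferencer = ChatInferencer(
    model=model,          # assumed: an OpenCompass chat model wrapper
    infer_mode='every',   # regenerate every assistant turn
    do_sample=False,      # takes the greedy branch in infer_every
    max_out_len=1024,     # new in 0.2.3; previously fixed at 512
)

With this, per-turn generations are capped at whatever the caller configures rather than always at 512 tokens.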
