Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make CodeEval respect device_eval_batch_size #2969

Merged
merged 29 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 25 additions & 8 deletions composer/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,11 +1234,13 @@ class InContextLearningCodeEvalDataset(InContextLearningDataset):
def __init__(
self,
generations_per_sample: int,
pass_at_k: int = 1,
pass_at_k: Union[int, list[int]] = 1,
josejg marked this conversation as resolved.
Show resolved Hide resolved
*args,
**kwargs,
):
if generations_per_sample < pass_at_k:
if isinstance(pass_at_k, int):
pass_at_k = [pass_at_k]
if generations_per_sample < max(pass_at_k):
raise ValueError(
f'generations_per_sample ({generations_per_sample}) must be greater than or equal to pass_at_k ({pass_at_k}) for code evaluation.'
)
Expand All @@ -1250,13 +1252,14 @@ def __init__(
'entry_points': 'entry_point',
'test_inputs': 'test_inputs',
'test_outputs': 'test_outputs',
'languages': 'language'
'languages': 'language',
'sample_id': 'sample_id',
}
# Linting complains if these are not set in init
self.max_prompt_length = 0
self.max_answer_length = 0
static_keys = ['mode', 'pass_at_k', 'generation_length', 'generation_kwargs']
list_keys = ['prompts', 'tests', 'entry_points', 'test_inputs', 'test_outputs', 'languages', 'labels']
static_keys = ['mode', 'pass_at_k', 'generation_length', 'generation_kwargs', 'generations_per_sample', 'dataset_size']
list_keys = ['prompts', 'tests', 'entry_points', 'test_inputs', 'test_outputs', 'languages', 'labels', 'sample_id']
tensor_keys = ['input_ids', 'attention_mask']
super().__init__(
context_key='prompt',
Expand All @@ -1272,7 +1275,9 @@ def __init__(
**kwargs,
)
self._set_max_prompt_and_answer_lengths()
dataset_size = len(self.dataset)
self.dataset = self.dataset.map(self._trim_padding)
self.dataset = self.repeat_dataset(self.dataset, generations_per_sample)
self.base_batch = {
'input_ids': [],
'mode': 'generate',
Expand All @@ -1288,15 +1293,27 @@ def __init__(
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'num_beams': 1, # single beam
'num_return_sequences': generations_per_sample,
josejg marked this conversation as resolved.
Show resolved Hide resolved
'do_sample': True,
'temperature': 0.2, # good default for code
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id
}
},
'sample_id': [],
'pass_at_k': list(pass_at_k),
'generations_per_sample': generations_per_sample,
'dataset_size': dataset_size,
}
if 'generation_kwargs' in kwargs:
self.update_generation_kwargs(kwargs['generation_kwargs'])

def repeat_dataset(self, dataset: HFDataset, repetitions: int) -> HFDataset:
def repeated_dataset():
for i, sample in enumerate(dataset):
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
for _ in range(repetitions):
yield {'sample_id':i, **sample}
from datasets import Dataset as HFDataset
return HFDataset.from_generator(repeated_dataset)
josejg marked this conversation as resolved.
Show resolved Hide resolved

def _set_max_prompt_and_answer_lengths(self):
"""
Iterates through the dataset and finds the maximum prompt length and sequence lengths
Expand Down Expand Up @@ -1371,7 +1388,7 @@ def build_icl_dataloader(
prelimiter: str, # e.g. 'Question: '
cot_delimiter: str, # e.g. ' ### '
fewshot_random_seed: int,
pass_at_k: int,
pass_at_k: Union[int, list[int]],
generations_per_sample: int,
generation_kwargs: Dict,
early_stopping_criteria: Optional[List[str]] = None,
Expand Down
95 changes: 55 additions & 40 deletions composer/metrics/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torchmetrics import Metric

from composer.utils.eval_client import EvalClient, LambdaEvalClient, LocalEvalClient, MosaicMLLambdaEvalClient
from composer.utils import dist

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -535,8 +536,8 @@ class InContextLearningCodeEvalAccuracy(InContextLearningMetric):
def __init__(self, dist_sync_on_step: bool = False):
# state from multiple processes
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')

self._initialized = False

self.eval_device = os.environ.get('CODE_EVAL_DEVICE', None)
if self.eval_device is not None:
Expand Down Expand Up @@ -580,6 +581,17 @@ def estimator(self, n: int, c: int, k: int) -> float:
return 1.0
return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

def _initialize_state(self, batch: dict[str, Any]):
self.dataset_size = batch['dataset_size']
self.pass_at_k = batch['pass_at_k']
self.num_generations = batch['generations_per_sample']

# We need to defer the accumulator initialization because it depends on dataset size
self.add_state('correct', default=torch.zeros(self.dataset_size), dist_reduce_fx='sum')
self.add_state('total', default=torch.zeros(self.dataset_size), dist_reduce_fx='sum')
dist.barrier()
self._initialized = True

def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
"""Updates the pass@k accuracy of code generation.

Expand All @@ -604,51 +616,54 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
labels (List[str]): A list of the correct code generations, for compatibility with existing HF generate
functionalities. This is not used.
"""
if not self._initialized:
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
self._initialize_state(batch)

del labels # never used
client = self.get_client()

pass_at_k = batch['pass_at_k']
num_generations = batch['generation_kwargs']['num_return_sequences']
processed_outputs = [
outputs[i * num_generations:(i + 1) * num_generations] for i in range(len(batch['prompts']))
]
payloads = []
for sample_outputs, sample_prompt, test_inputs, test_outputs, entry_point, language in zip(
processed_outputs, batch['prompts'], batch['test_inputs'], batch['test_outputs'], batch['entry_points'],
for sample_id, code_gen, sample_prompt, test_inputs, test_outputs, entry_point, language in zip(
batch['sample_id'], outputs, batch['prompts'], batch['test_inputs'], batch['test_outputs'], batch['entry_points'],
batch['languages']):
self.total += torch.tensor(1.0)
prompt_payload = []
for code_gen in sample_outputs:
code_gen = re.split(r'\n[A-Za-z0-9#`]', code_gen)[0] # remove everything after function ends
final_code = sample_prompt + code_gen # combine prompt with the code generation
generation_payload = []
for test_input, test_output in zip(test_inputs, test_outputs):
payload = {
'code': final_code,
'input': test_input,
'output': test_output,
'entry_point': entry_point,
'language': language,
}
generation_payload.append(payload)

prompt_payload.append(generation_payload)
payloads.append(prompt_payload)

results = client.invoke(payloads)
for prompt in results:
num_correct = 0
for generation in prompt:
correct = all(generation)
if correct:
num_correct += 1

pass_at_k_rate = self.estimator(num_generations, num_correct, pass_at_k)
self.correct += torch.tensor(pass_at_k_rate)

idx = sample_id
self.total[idx] += 1.0

code_gen = re.split(r'\n[A-Za-z0-9#`]', code_gen)[0] # remove everything after function ends
final_code = sample_prompt + code_gen # combine prompt with the code generation

test_results = []
for test_input, test_output in zip(test_inputs, test_outputs):
payload = {
'code': final_code,
'input': test_input,
'output': test_output,
'entry_point': entry_point,
'language': language,
}

# TODO: refactor client.invoke API
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
result = client.invoke([[[payload]]])[0][0][0]
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
test_results.append(result)

if all(test_results):
self.correct[idx] += 1.0

client.close() # pyright: ignore [reportOptionalMemberAccess]

def compute(self):
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
return self.correct / self.total
if not (self.total == self.num_generations).all().item():
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Some samples in the dataset have less than {self.num_generations} generations")

results = {}
n = self.num_generations

for k in self.pass_at_k:
results[f'pass@{k}'] = sum([self.estimator(n, c.item(), k) for c in self.correct]) / self.dataset_size

if len(results) == 1: # backwards compatibility
return list(results.values())[0]

return results