diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 993d82c9d..55b68fb2f 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -233,7 +233,7 @@ def _eval_batch(self,
     summed_loss = self.loss_fn(
         label_batch=batch['targets'], logits_batch=logits,
         mask_batch=weights)['summed']
-    return summed_loss
+    return summed_loss.to(dtype=torch.float64)
 
 
 class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py
index 801716de7..ef971bb75 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -63,11 +63,11 @@ def num_eval_train_examples(self) -> int:
 
   @property
   def num_validation_examples(self) -> int:
-    return 89_000_000
+    return 83_274_637
 
   @property
   def num_test_examples(self) -> int:
-    return 89_274_637
+    return 95_000_000
 
   @property
   def train_mean(self):