Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _eval_batch(self,
summed_loss = self.loss_fn(
label_batch=batch['targets'], logits_batch=logits,
mask_batch=weights)['summed']
return summed_loss
return summed_loss.to(dtype=torch.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
Expand Down
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def num_eval_train_examples(self) -> int:

@property
def num_validation_examples(self) -> int:
return 89_000_000
return 83_274_637

@property
def num_test_examples(self) -> int:
return 89_274_637
return 95_000_000

@property
def train_mean(self):
Expand Down