From 7f7891df9e7e29eb5482f85317d57790b08594c1 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 14:17:19 -0400 Subject: [PATCH 1/5] minor --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index ba8db9ced..ae5da49bc 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -140,6 +140,7 @@ def _eval_batch_pmapped(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] + print(summed_loss) return summed_loss def _eval_batch(self, From f081dd10dfe968dc43cba0bfbaa7b4d3aa49786e Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 14:28:03 -0400 Subject: [PATCH 2/5] minor --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index ae5da49bc..1ae818764 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -6,7 +6,7 @@ from flax import jax_utils import jax import jax.numpy as jnp - +import numpy as np from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models @@ -140,8 +140,7 @@ def _eval_batch_pmapped(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] - print(summed_loss) - return summed_loss + return np.array(summed_loss, dtype=np.float64) def _eval_batch(self, params: spec.ParameterContainer, From 62d9ad746623d3eafe7d5248cdb375b232960a26 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 14:31:12 -0400 Subject: [PATCH 3/5] minor --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 1ae818764..f3dc0c66d 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -140,14 +140,14 @@ def _eval_batch_pmapped(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] - return np.array(summed_loss, dtype=np.float64) + return summed_loss def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return self._eval_batch_pmapped(params, batch).sum() + return np.array(self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): From 19accc709a35c18132e73751b7f5763148d6946f Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 17:10:17 -0400 Subject: [PATCH 4/5] Lint fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index f3dc0c66d..dc1696bd0 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp import numpy as np + from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models From a1844c768d62ba3592bac897aa92cbe15b7ef9cd Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 17:13:36 -0400 Subject: [PATCH 5/5] Lint fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index dc1696bd0..a76a70289 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -148,7 +148,8 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return np.array(self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + return np.array( + self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):