From 09de0e69073134378b5c5203c46a99b12403c57a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Sat, 18 Nov 2023 00:00:41 -0800 Subject: [PATCH] fix mnist weights bug --- .../workloads/mnist/mnist_pytorch/workload.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index b7f33b94b..e638df078 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -80,9 +80,7 @@ def _build_input_queue( weights = torch.as_tensor( batch['weights'], dtype=torch.bool, device=DEVICE) else: - weights = torch.ones((batch['targets'].shape[-1],), - dtype=torch.bool, - device=DEVICE) + weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: dist.broadcast(inputs, src=0)